diff --git a/plotting.py b/plotting.py index 2b6f66db42ec011097867ee220d52da3b60eb12e..e4dd01227823f6c2e02208eab6f7b5634941b420 100644 --- a/plotting.py +++ b/plotting.py @@ -63,21 +63,29 @@ def add_resonance_vlines(axes, max_order, color='r'): add_vline(axes, to_half_intervall(r), color, text=f'{n}/{m} resonance' if m > 1 else None, lw=3/max(1, m-2)**.5, alpha=1/max(1, m-2), text_vertical=True, text_top=True) -def subplot_shared_labels(axes, xlabel=None, ylabel=None, clear=True): +def subplot_shared_labels(axes, xlabel=None, ylabel=None, clear='auto'): """Adds labels to shared axes as needed :param axes: 2D array of axes from subplots (pass squeeze=False to plt.subplots if required) :param xlabel: the shared xlabel :param ylabel: the shared ylabel - :param clear: if true (default) any existing labels on the axes will get cleared + :param clear: if 'auto' clears duplicate labels; if true clears any existing labels; if false do not clear any labels """ - for r in range(axes.shape[0]): + axes = np.array(axes) + for r in reversed(range(axes.shape[0])): for c in range(axes.shape[1]): - if clear: axes[r,c].set(xlabel=None, ylabel=None) - if c == 0 or not axes[r,c].get_shared_y_axes().joined(axes[r,c], axes[r,0]): - axes[r,c].set(ylabel=ylabel) - if r == axes.shape[0]-1 or not axes[r,c].get_shared_x_axes().joined(axes[r,c], axes[-1,c]): + if clear is True: + axes[r,c].set(xlabel=None, ylabel=None) + elif clear == 'auto': + if r < axes.shape[0]-1 and axes[r,c].get_shared_x_axes().joined(axes[r,c], axes[-1,c]): + axes[r,c].set(xlabel=None) + if c > 0 and axes[r,c].get_shared_y_axes().joined(axes[r,c], axes[r,0]): + axes[r,c].set(ylabel=None) + + if xlabel is not None and (r == axes.shape[0]-1 or not axes[r,c].get_shared_x_axes().joined(axes[r,c], axes[-1,c])): axes[r,c].set(xlabel=xlabel) + if ylabel is not None and (c == 0 or not axes[r,c].get_shared_y_axes().joined(axes[r,c], axes[r,0])): + axes[r,c].set(ylabel=ylabel) def grid_diagonal(ax, **kwargs): for k, v in dict(color='lightgray',lw=1,zorder=-100).items(): @@ -172,7 +180,7 @@ def plot_tbt(ax, libera_data, what='fsxy', *, over_time=True, turn_range=None, t else: a = ax.twinx() if i > 0 else ax if i > 0: a.spines.right.set_position(("axes", 0.9+0.1*len(axes))) - a.set(ylabel=axlabels[w]) + a.set(ylabel=axlabels[w], xlabel='Time / s' if over_time else 'Turn') axes.update({axlabels[w]: a}) if w == 'f': @@ -230,7 +238,7 @@ def plot_tune_spectrum(ax, libera_data, xy, turn_range=None, time_range=None, tu ax.plot(freq, mag, **kwargs) ax.set(xlim=tune_range, xlabel=f'Tune $q_{xy}$', - ylabel='a.u.') + ylabel='Magnitude / a.u.') if fit: q = None