# other.py
from .common import *
def add_vline(axes, r, color='k', text=None, label=None, order=None, lw=1, text_top=True, text_vertical=True, zorder=-100, **kwargs):
    """Draws a vertical line at x = r on each of the given axes

    :param axes: iterable of axes, entries which are None are skipped
    :param r: x position of the line in data coordinates
    :param color: color of the line and the text
    :param text: optional text placed next to the line; it is only added to the last axes which is not None
    :param label: legend label, only attached on axes whose current x limits contain r
    :param order: if given (and non-zero), the line is drawn with a pattern of one long dash
                  followed by `order` short dashes and with an opacity of 1/order
    :param lw: line width
    :param text_top: if True, anchor the text at the top of the axes, otherwise at the bottom
    :param text_vertical: if True, rotate the text by 90 degrees
    :param zorder: zorder of line and text
    :param kwargs: further arguments passed to axvline
    """
    if order:
        kwargs['ls'] = (0, [11, 2] + [1, 2]*order)
        kwargs['alpha'] = 1/order

    last = None  # last axes that actually exists; this one carries the text
    for a in axes:
        if a is None:
            continue
        lower, upper = a.get_xlim()
        a.axvline(r, c=color, label=label if lower <= r <= upper else None, zorder=zorder, lw=lw, **kwargs)
        last = a

    # Without this guard an empty `axes` or a trailing None entry would fail here
    if text and last is not None:
        last.text(r, .99 if text_top else 0.01, f' {text} ', fontsize='small', c=color, zorder=zorder,
                  alpha=1/order if order else 1, transform=last.get_xaxis_text1_transform(0)[0],
                  rotation=90 if text_vertical else 0, clip_on=True, ha='right' if text_vertical else 'left',
                  va='top' if text_top else 'bottom')

def to_half_intervall(q):
    """Folds a tune onto the half interval [0;0.5]

    The fractional part of q is mirrored at 0.5, i.e. the result is the
    distance of q to the nearest integer (0.5 is returned for half-integer q).
    """
    fractional = q - np.floor(q)
    fractional = np.where(fractional < 0, fractional + 1, fractional)
    mirrored = 1 - fractional
    return np.minimum(fractional, mirrored)

def add_resonance_vlines(axes, max_order, color='r', zorder=100):
    """Adds vertical lines at all resonances n/m up to the given order

    Lines are drawn at h ± n/m for every integer h in the x range covered by the axes.
    Line width and opacity decrease with the order m.

    :param axes: iterable of axes, entries which are None are skipped
    :param max_order: highest resonance order m to draw
    :param color: color of lines and text
    :param zorder: zorder of lines and text
    """
    limits = [a.get_xlim() for a in axes if a is not None]
    if not limits:
        return  # nothing to draw on
    # floor instead of int(): int() truncates towards zero and would cut off negative ranges
    low = int(np.floor(min(min(lim) for lim in limits)))
    high = int(np.floor(max(max(lim) for lim in limits))) + 1

    ratios, positions = set(), set()
    for m in range(1, max_order+1):
        for n in range(m//2+1):
            r = n/m
            if r in ratios: continue  # e.g. 2/4 was already drawn as 1/2
            ratios.add(r)
            for h in range(low, high):
                for s in (+1, -1):
                    x = h + s*to_half_intervall(r)
                    if x in positions: continue  # h+0 == h-0 and h+1/2 == (h+1)-1/2
                    positions.add(x)
                    add_vline(axes, x, color, text=f'{n}/{m} resonance' if m > 1 else None, lw=3/max(1, m-2)**.5,
                              alpha=1/max(1, m-2), text_vertical=True, text_top=True, zorder=zorder)


def subplot_shared_labels(axes, xlabel=None, ylabel=None, clear='auto'):
    """Adds and removes labels to shared axes as needed
    
    The x label is put on the bottom row and on every axes which does not share its x axis
    with the bottom axes of its column. The y label is put on the first column and on every
    axes which does not share its y axis with the first axes of its row.
    
    :param axes: 2D array of axes from subplots (pass squeeze=False to plt.subplots if required)
    :param xlabel: the shared xlabel
    :param ylabel: the shared ylabel
    :param clear: if 'auto' clears duplicate labels; if true clears any existing labels; if false do not clear any labels
    :raises ValueError: if axes is not two-dimensional
    """
    axes = np.array(axes)
    if axes.ndim != 2:
        raise ValueError(f'Expected a 2D array of axes but got {axes.ndim} dimension(s)')
    n_rows, n_cols = axes.shape
    
    for r in reversed(range(n_rows)):
        for c in range(n_cols):
            a = axes[r,c]
            is_bottom, is_left = r == n_rows-1, c == 0
            shares_x = a.get_shared_x_axes().joined(a, axes[-1,c])
            shares_y = a.get_shared_y_axes().joined(a, axes[r,0])
            
            if clear is True:
                a.set(xlabel=None, ylabel=None)
            elif clear == 'auto':
                # only labels which duplicate the one on the outer axes are removed
                if not is_bottom and shares_x:
                    a.set(xlabel=None)
                if not is_left and shares_y:
                    a.set(ylabel=None)
            
            if xlabel is not None and (is_bottom or not shares_x):
                a.set(xlabel=xlabel)
            if ylabel is not None and (is_left or not shares_y):
                a.set(ylabel=ylabel)


def grid_diagonal(ax, **kwargs):
    """Adds a diagonal grid to the given axes
    
    A line of slope 1 is drawn through every x tick on the lower edge and through
    every y tick on the left edge of the current view. The axes limits are left unchanged.
    
    :param ax: the axes
    :param kwargs: optional arguments passed to ax.axline
    """
    for k, v in dict(color='lightgray', lw=1, zorder=-100).items():
        kwargs.setdefault(k, v)
    xlim, ylim = ax.get_xlim(), ax.get_ylim()
    xtick, ytick = ax.get_xticks(), ax.get_yticks()
    # ticks outside the current view are ignored
    for x in xtick:
        if xlim[0] <= x <= xlim[1]:
            ax.axline((x, ylim[0]), slope=1, **kwargs)
    for y in ytick:
        if ylim[0] <= y <= ylim[1]:
            ax.axline((xlim[0], y), slope=1, **kwargs)
    # restore the limits that were read before the lines were added
    ax.set(xlim=xlim, ylim=ylim)
    
    
def fiberplot(ax, datasets, *, labels=(None, None), vertical=True):
    """Create a fiberplot with multiple datasets of x and y data
    
    For vertical=False, x is plotted on the horizontal axis, y determines the height of each
    fiber and the position of each dataset determines where the fiber sits on the vertical axis.
    For vertical=True the roles of the two plot axes are exchanged.
    With more than one group of datasets the fibers are split: groups with even index extend
    to one side of the position, groups with odd index to the other side.
    
    :param ax: the axes to plot onto
    :param datasets: 2D array of datasets [[(position, x, y), ...], ...]
                     or dict with two levels {label: {position: (x,y), ...}, ...}
    :param labels: legend labels for the groups of datasets (taken from the keys if datasets is a dict)
    :param vertical: orientation of the plot, see above
    """
    if isinstance(datasets, dict):
        labels = list(datasets.keys())
        datasets = [[(p, *v) for p, v in dataset.items()] for dataset in datasets.values()]
    positions = sorted({p for dataset in datasets for p, *_ in dataset})
    # half the smallest distance between neighbouring fibers is the maximum extent of a fiber
    # NOTE(review): with fewer than two positions there is no distance to derive this from,
    # a fallback of 0.5 is used instead of failing
    dy = np.min(np.diff(positions))/2 if len(positions) > 1 else 0.5
    
    plot_function = ax.fill_betweenx if vertical else ax.fill_between
    for c, dataset in enumerate(datasets):
        # get_next_color is used because the prop_cycler attribute is not present in newer matplotlib
        color = ax._get_lines.get_next_color()
        label = labels[c] if c < len(labels) else None
        for i, (p, x, y, *_) in enumerate(dataset):
            y = np.asarray(y)
            extent = dy*(y-np.min(y))/np.max(y)
            plot_function(x, p - extent*(1-c%2), p + extent*(c%2 if len(datasets)>1 else 1),
                          color=color, lw=0, label=label if i == 0 else None)
    if vertical:
        ax.set(xticks=positions)
    else:
        ax.set(yticks=positions)

class SpecialMultipleLocator(mpl.ticker.MaxNLocator):
    """A plot tick locator which prefers tick distances from a given list
    """
    def __init__(self, fixed_multiples, n=5, minor_n=None):
        """Create a locator that locks to fixed_multiples with about n ticks
        
        For ranges smaller than the smallest fixed_multiple, the default MaxNLocator is used
        For ranges larger than the largest fixed_multiple, a multiple of the latter is used
        If minor_n is given with same length as fixed_multiples, the ticks are subdivided by the corresponding number
        """
        super().__init__(n)
        self.fixed_multiples = fixed_multiples
        self.n = n
        self.minor_n = minor_n
        
    def _raw_ticks(self, vmin, vmax):
        """Return tick positions covering the interval from vmin to vmax"""
        multiples = self.fixed_multiples
        span = vmax - vmin
        
        # range too small for the smallest multiple: fall back to MaxNLocator
        if span < self.n*multiples[0]:
            return super()._raw_ticks(vmin, vmax)
        
        # smallest multiple giving at most n intervals, else the largest one ...
        step = next((s for s in multiples if span/s <= self.n), multiples[-1])
        # ... which is then increased in units of itself until it fits
        while span/step > self.n:
            step = step + multiples[-1]
        
        if self.minor_n is not None:
            subdivision = self.minor_n[multiples.index(step)] if step in multiples else self.minor_n[-1]
            step = step/subdivision
        
        first = int(vmin/step)*step
        return np.arange(first, vmax+step, step)

class DegreeLocator(SpecialMultipleLocator):
    """A plot tick locator for angles in degrees
    """
    def __init__(self, kind='major'):
        """:param kind: 'major' for major ticks, anything else for subdivided (minor) ticks"""
        steps = (5, 15, 30, 45, 60, 90, 120, 180, 360)
        if kind == 'major':
            subdivisions = None
        else:
            subdivisions = (5, 3, 3, 3, 4, 3, 4, 4, 4)
        super().__init__(steps, 5, subdivisions)

class RadiansLocator(SpecialMultipleLocator):
    """A plot tick locator for angles in radians
    """
    def __init__(self, kind='major'):
        """:param kind: 'major' for major ticks, anything else for subdivided (minor) ticks"""
        steps = list(np.deg2rad((5, 15, 30, 45, 60, 90, 120, 180, 360)))
        if kind == 'major':
            subdivisions = None
        else:
            subdivisions = (5, 3, 3, 3, 4, 3, 4, 4, 4)
        super().__init__(steps, 5, subdivisions)

class RadiansFormatter(mpl.ticker.Formatter):
    """A plot tick formatter for angles in radians
    
    Values are written as multiples or simple fractions of pi.
    """
    def __call__(self, x, pos=None):
        """Return the tick label for the value x
        
        :param x: the tick value in radians
        :param pos: the tick position (not used)
        """
        if x == 0:
            return '0'
        s = '-' if x < 0 else ''
        x = abs(x)  # the sign is carried by s from here on
        if x == np.pi:
            return f'${s}\\pi$'
        for n in (2,3,4,6,8,12):
            m = round(x/(np.pi/n))
            # x is m*pi/n, skipping the cases where m/n is an integer
            if abs(x - m*np.pi/n) < 1e-10 and m/n != m//n:
                if m == 1: m = ''
                return f'${s}{m}\\pi/{n}$'
        # everything else as decimal multiple of pi; the sign must be put back since x is now |x|
        return f'${s}{x/np.pi:g}\\pi$'
    
def format_axis_degrees(yaxis):
    """Sets major and minor tick locators suited for angles in degrees on the given axis"""
    major, minor = DegreeLocator('major'), DegreeLocator('minor')
    yaxis.set_major_locator(major)
    yaxis.set_minor_locator(minor)

def format_axis_radians(yaxis):
    """Sets tick locators and the tick formatter suited for angles in radians on the given axis"""
    major, minor = RadiansLocator('major'), RadiansLocator('minor')
    yaxis.set_major_locator(major)
    yaxis.set_minor_locator(minor)
    yaxis.set_major_formatter(RadiansFormatter())
def add_scale(ax, scale, text=None, *, vertical=False, size=0.01, padding=0.1, loc='auto', color='k', fontsize='x-small'):
    """Make a scale or yardstick patch
    
    :param ax: the axes to add the scale to
    :param scale: length of the scale in data coordinates
    :param text: optional text placed next to the scale
    :param vertical: if True, the scale extends along the y axis, otherwise along the x axis
    :param size: thickness of the scale in axes coordinates
    :param padding: border padding of the anchored box
    :param loc: location of the scale; 'auto' means 'upper left' for vertical and 'lower right' otherwise
    :param color: color of scale and text
    :param fontsize: font size of the text
    """
    if loc == 'auto':
        loc = 'upper left' if vertical else 'lower right'
    
    # length is measured in data coordinates, thickness in axes coordinates
    if vertical:
        width, height = size, scale
        x_transform, y_transform = ax.transAxes, ax.transData
    else:
        width, height = scale, size
        x_transform, y_transform = ax.transData, ax.transAxes
    
    box = mpl.offsetbox.AuxTransformBox(mpl.transforms.blended_transform_factory(x_transform, y_transform))
    box.add_artist(plt.Rectangle((0,0), width, height, fc=color))
    if text:
        if vertical:
            annotation = plt.Text(2*width, height/2, text, color=color, ha='left', va='center', rotation='vertical', fontsize=fontsize)
        else:
            annotation = plt.Text(width/2, 1.5*height, text, color=color, va='bottom', ha='center', fontsize=fontsize)
        box.add_artist(annotation)
    
    anchored = mpl.offsetbox.AnchoredOffsetbox(loc, borderpad=padding, zorder=100, frameon=False)
    anchored.set_child(box)
    ax.add_artist(anchored)
def v(swap_xy, x, y):
    """Returns the pair (x, y), or (y, x) if swap_xy is true"""
    if swap_xy:
        return y, x
    return x, y
    
def smooth_plot(ax, x, y, smoothing=None, swap_xy=False, **kwargs):
    """Plot smoothed data with the raw data as faint trace in the background
    
    :param ax: the axes to plot onto
    :param x: x data
    :param y: y data
    :param smoothing: parameter n passed to smooth; if set, a scale indicating the
                      smoothing width and the raw data are added to the plot
    :param swap_xy: if True, swap plot axis
    :param kwargs: arguments passed to plot function
    :returns: the result of ax.plot for the smoothed data
    """
    # presumably smooth returns a sequence of smoothed arrays, hence the unpacking - verify against its definition
    r = ax.plot(*v(swap_xy, x, *smooth(y, n=smoothing)), **kwargs)
    # NOTE(review): this condition is reconstructed; the original line is missing here and only
    # the indented body survived. Other call sites use the same `if smoothing:` guard for add_scale
    if smoothing:
        add_scale(ax, np.mean(np.diff(x))*smoothing, vertical=swap_xy)
        kwargs.pop('c', None)  # alias of color, passing both makes plot fail
        kwargs.update(lw=1, alpha=0.1, label=None, zorder=-1, color=r[0].get_color())
        ax.plot(*v(swap_xy, x, y), **kwargs)
    # NOTE(review): reconstructed; plot_btf reads the line color from this return value
    return r


# Libera tbt data
##################

def turn_or_time_range(time, turn_range=None, time_range=None):
    """Returns an index selecting the given range of turns or of time
    
    :param time: array of time values, only used if time_range is given
    :param turn_range: (start, stop) tuple of turns
    :param time_range: (start, stop) tuple of time
    :returns: the result of irng for a time_range, otherwise a slice (selecting everything if no range is given)
    :raises ValueError: if both turn_range and time_range are given
    """
    if turn_range is not None and time_range is not None: 
        raise ValueError('Parameters turn_range or time_range are mutually exclusive')
    if time_range is not None:
        return irng(time, *time_range)
    if turn_range is None:
        return slice(None, None)
    # NOTE(review): reconstructed; the original line is missing here. Without it an explicit
    # turn_range yields None, which adds an axis when used as an array index instead of selecting turns
    return slice(*turn_range)
def plot_tbt(ax, libera_data, what='fsxy', *, over_time=True, turn_range=None, time_range=None, averaging=500, show_avg_std=True, show_avg_extrema=True, **kwargs):
    """Plot turn-by-turn data
    
    Signals with the same axis label share one y axis, for every further label a twin axis is added.
    
    :param ax: Axis to plot onto
    :param libera_data: instance of LiberaTBTData
    :param what: signals to plot, any combination of 'f' (revolution frequency), 's' (sum signal), 'x' and/or 'y' (position)
    :param over_time: if True, plot data as function of time rather than turn
    :param turn_range: (start, stop) tuple of turns to plot
    :param time_range: (start, stop) tuple of time in s to plot
    :param averaging: number of consecutive turns to average over
    :param show_avg_std: if True, plot the band of standard deviation around the averaged data
    :param show_avg_extrema: if True, plot the band of min-to-max around the averaged data
    :param kwargs: arguments passed to plot function
    """
    assert isinstance(libera_data, LiberaTBTData), f'Expected LiberaTBTData but got {type(libera_data)}'
    
    axlabels = dict(f='$f_\\mathrm{rev}$ / kHz', s='Pickup sum / a.u.', x='Position / mm', y='Position / mm')
    legend_labels = dict(f='Revolution frequency', s='Pickup sum signal', x='X position', y='Y position')
    default_colors = dict(f=1, s=3, x=0, y=2)  # index into cmap_petroff_10
    
    turn_range = turn_or_time_range(libera_data.t, turn_range, time_range)
    t = libera_data.t[turn_range]
    x = t if over_time else np.arange(0, len(t))
    axes = {} # label: ax
    limits = {} # ax: Bbox
    ls, labels = [], []
    for i, w in enumerate(what):
        if axlabels[w] in axes:
            a = axes[axlabels[w]]
        else:
            a = ax.twinx() if i > 0 else ax
            if i > 0: a.spines.right.set_position(("axes", 0.9+0.1*len(axes)))
            a.set(ylabel=axlabels[w], xlabel='Time / s' if over_time else 'Turn')
            axes.update({axlabels[w]: a})
            limits[a] = mpl.transforms.Bbox([[0,0],[0,0]])
        
        # NOTE(review): this condition is reconstructed; the original line is missing here and
        # only its two branches survived. 'f' is the only signal derived from the time stamps
        if w == 'f':
            values = 1e-3/np.diff(t) # to kHz
        else:
            values = getattr(libera_data, w)[turn_range]
        
        args = dict(**kwargs)
        if 'c' not in args and 'color' not in args: 
            args['c'] = cmap_petroff_10(default_colors[w])
        xx, vv = avg(x[:len(values)], values, n=averaging)
        l, = a.plot(xx, vv, **args)
        ls.append(l)
        labels.append(legend_labels[w])
        # the first dataset on an axes replaces the empty placeholder box, further ones extend it
        limits[a].update_from_data_xy(np.vstack(l.get_data()).T, ignore=limits[a].width==limits[a].height==0)
        if averaging > 1 and show_avg_std:
            _, ve = avg(x[:len(values)], values, n=averaging, function=np.std)
            a.fill_between(xx, vv-ve, vv+ve, color=l.get_color(), alpha=0.4, zorder=-1, lw=0)
        if averaging > 1 and show_avg_extrema:
            _, vmi = avg(x[:len(values)], values, n=averaging, function=np.min)
            _, vma = avg(x[:len(values)], values, n=averaging, function=np.max)
            a.fill_between(xx, vmi, vma, color=l.get_color(), alpha=0.2, zorder=-2, lw=0)
        
    # autoscale to the averaged lines only, ignoring the bands
    for a, lim in limits.items():
        a.dataLim = lim
        a.autoscale_view()
    
    if len(what) > 1: a.legend(ls, labels, fontsize='small')
def plot_btf(axf, axp, data, *, frev=None, filled=False, smoothing=None, **kwargs):
    """Plot beam transfer function
    
    :param axf: axis for magnitude response
    :param axp: axis for phase response (or None)
    :param data: instance of NWAData
    :param frev: if not None, plot fraction of revolution frequency (tune) on x axis
    :param filled: show trace as filled plot instead of line plot
    :param smoothing: apply running average to data
    :param kwargs: arguments passed to plot function
    """
    
    f = data.f*(1/frev if frev else 1)
    
    if filled:
        c = axf.fill_between(f, np.zeros_like(data.m), *smooth(data.m, n=smoothing), **kwargs).get_facecolor()
    else:
        c = smooth_plot(axf, f, data.m, smoothing, **kwargs)[0].get_color()
    
    # unless a color was requested, the phase is plotted in the color of the magnitude
    if 'c' not in kwargs and 'color' not in kwargs: kwargs.update(color=c)
    if axp is not None: smooth_plot(axp, f, data.p, smoothing, **kwargs)
    
    if isinstance(data, NWADataAverage):
        # band of one standard deviation around the averaged data
        kwargs.update(lw=0, alpha=0.5, label=None)
        axf.fill_between(f, *smooth(data.m - data.m_std, data.m + data.m_std, n=smoothing), **kwargs)
        if axp is not None: axp.fill_between(f, *smooth(data.p - data.p_std, data.p + data.p_std, n=smoothing), **kwargs)

    axf.set(ylabel=f'Magnitude response / {data.m_unit}', xlim=(np.min(f), np.max(f)))
    if axp is not None: axp.set(ylabel=f'Phase response / {data.p_unit}')
    (axf if axp is None else axp).set(xlabel='Stimulus tune' if frev else f'Stimulus frequency / {data.f_unit}')


def plot_btf_scan(axf, axp, dataset, *, frev=None, smoothing=None, cmap='gist_heat_r', colorbar=False, **kwargs):
    """Plot beam transfer functions of a parameter scan as color maps
    
    :param axf: axis for magnitude response
    :param axp: axis for phase response (or None)
    :param dataset: dictionary {scan value: instance of NWAData}
    :param frev: if not None, plot fraction of revolution frequency (tune) on x axis
    :param smoothing: apply running average to data
    :param cmap: colormap
    :param colorbar: add colorbar to plot
    :param kwargs: arguments passed to plot function
    """
    keys = list(dataset.keys())
    primary = dataset[keys[0]]
    
    # check dataset for consistency
    assert np.all([primary.f_unit == dataset[k].f_unit for k in keys]), 'Data does not have equal units of frequency!'
    assert np.all([primary.m_unit == dataset[k].m_unit for k in keys]), 'Data does not have equal units of magnitude!'
    assert np.all([primary.p_unit == dataset[k].p_unit for k in keys]), 'Data does not have equal units of phase!'
    
    # collect magnitudes and phases into 2D array
    scale = 1/frev if frev else 1
    if np.all([np.all(primary.f == dataset[k].f) for k in keys]):
        f = primary.f*scale
    else:
        # different frequency rasters, works but might not be what we want
        f = [dataset[k].f*scale for k in keys]
    magnitudes = [smooth(dataset[k].m, n=smoothing)[0] for k in keys]
    
    # plot 2D array
    cmf = axf.pcolormesh(f, keys, magnitudes, cmap=cmap, rasterized=True, **kwargs)
    if smoothing: add_scale(axf, np.mean(np.diff(f))*smoothing)
    if axp is not None: 
        phases = [smooth(dataset[k].p, n=smoothing)[0] for k in keys]
        # the phase map belongs on the phase axis (it used to be drawn over the magnitude map on axf)
        cmp = axp.pcolormesh(f, keys, phases, cmap=cmap, rasterized=True, **kwargs)
        if smoothing: add_scale(axp, np.mean(np.diff(f))*smoothing)
    
    # plot layout
    if colorbar:
        axf.get_figure().colorbar(cmf, label=f'Magnitude response / {primary.m_unit}', ax=axf)
        if axp is not None: axp.get_figure().colorbar(cmp, label=f'Phase response / {primary.p_unit}', ax=axp)
    else:
        axf.set(title='Magnitude response')
        if axp is not None: axp.set(title='Phase response')
    (axf if axp is None else axp).set(xlim=(np.min(f), np.max(f)), xlabel='Stimulus tune' if frev else f'Stimulus frequency / {primary.f_unit}')
    if len(keys) < 12:
        axf.set(yticks=keys)
        if axp is not None: axp.set(yticks=keys)



def plot_tune_spectrum(ax, libera_data, xy, turn_range=None, time_range=None, tune_range=None, fit=False, fitargs=None, smoothing=None, return_spectrum=False, swap_xy=False, scaling='amplitude', **kwargs):
    """Plot a tune spectrum based on turn-by-turn data
    
    :param ax: Axis to plot onto
    :param libera_data: Instance of LiberaTBTData class
    :param xy: either 'x' or 'y'
    :param turn_range: tuple of (start_turn, stop_turn) for range to plot
    :param time_range: tuple of (start_time, stop_time) in seconds for range to plot
    :param tune_range: tuple of (start_tune, stop_tune) for range to plot
    :param fit: if True or any of (fit_lorenzian, fit_gaussian, fit_*), determine the tune from a fit on the spectrum
    :param fitargs: dict of keyword-arguments to pass to fit function
    :param smoothing: if specified, apply a moving average smoothing filter of this width to the data
    :param return_spectrum: if True, also return the spectrum (frequency, magnitude, phase)
    :param swap_xy: if True, swap plot axis
    :param scaling: scaling of fft amplitude, one of 'a' (oscillation amplitude in mm), 'pds' (power density spectrum), 'e' (oscillation energy)
    :param kwargs: arguments passed to plot function
    :returns: with fit: the fitted tune q (None if the fit failed), or (freq, mag, phase, q) if return_spectrum is set
              without fit: (freq, mag, phase) if return_spectrum is set, otherwise nothing
    :raises ValueError: if scaling is not one of the supported values
    """
    if fitargs is None: fitargs = {}
    assert isinstance(libera_data, LiberaTBTData), f'Expected LiberaTBTData but got {type(libera_data)}'
    
    turn_range = turn_or_time_range(libera_data.t, turn_range, time_range)
    tbt_data = getattr(libera_data, xy)[turn_range]
    
    fft = np.fft.rfft(tbt_data)
    freq, mag, phase = np.fft.rfftfreq(len(tbt_data), d=1), np.abs(fft), np.angle(fft)
    mode = scaling.lower()
    if mode.startswith('a'):
        # oscillation amplitude
        label = 'Oscillation amplitude $\\hat{'+xy+'}$ / ' + getattr(libera_data, xy+'_unit')
        mag *= 2/len(tbt_data)
    elif mode == 'pds':
        # power density spectrum
        label = '$|\\mathrm{FFT}|^2$ / a.u.'
        mag = mag**2 # magnitude squared
    elif mode.startswith('e'):
        # oscillation energy
        mag *= 2/len(tbt_data) # amplitude in mm
        mag = 2*np.pi**2 * freq**2 * mag**2 # Energy E/m/f0² in mm² with E=kx²/2=m*(2*pi*f)²*x² and here freq=f/f0
        f0 = np.mean(1/np.diff(libera_data.t)) # revolution frequency f0 in Hz
        mag *= f0**2 # E/m in mm²Hz²
        mag *= SI(getattr(libera_data, xy+'_unit')+'^2 Hz^2').to('eV/u').magnitude # E/m in eV/u
        label = 'Oscillation energy $E/m$ / eV/u'
    else:
        # without this the function would fail later on the undefined label
        raise ValueError(f'Unknown scaling {scaling!r}')
    
    if tune_range is not None:
        mask = irng(freq, *tune_range)
        freq, mag, phase = freq[mask], mag[mask], phase[mask]
    else:
        tune_range = (0, 0.5)
    smooth_plot(ax, freq, mag, smoothing, swap_xy=swap_xy, **kwargs)
    mag, phase = smooth(mag, phase, n=smoothing)
    # NOTE(review): the continuation of this call is missing in the original and reconstructed;
    # ylabel=label is assumed since label is computed above and not used anywhere else
    ax_set(ax, swap_xy=swap_xy, xlim=tune_range, xlabel=f'Tune $q_{xy}$',
           ylabel=label)
    
    # NOTE(review): the two lines opening the `if fit:` and `try:` blocks are missing in the
    # original and reconstructed from the indentation and the surviving `except`.
    if fit:
        q = None  # result if the fit fails
        try:
            if fit is True: fit = fit_lorenzian
            fitr = fit(freq, mag, **fitargs)
            if fit in (fit_lorenzian, fit_gaussian):
                center, w = fitr[0][2], fitr[0][3]
                if center<np.min(freq) or center>np.max(freq) or w > 0.1:
                    raise RuntimeError('Fit failed')
                q = SI.Measurement(center, (fitr[1][2]**2+w**2+fitr[1][3]**2)**0.5, '') # conservative estimate of error including width of peak
                ax.plot(*v(swap_xy, *fitr[-1]), '--', lw=1, label=f'Fit $q_{xy}={q:~L}$', zorder=50)
            elif fit in (fit_multi_lorenzian, ):
                qs, labels = [], []
                for i in range(int(len(fitr[0])/3)):
                    center, w = fitr[0][3*i+2], fitr[0][3*i+3]
                    if center<np.min(freq) or center>np.max(freq) or w > 0.1:
                        raise RuntimeError('Fit failed')
                    qi = SI.Measurement(center, (fitr[1][3*i+2]**2+w**2+fitr[1][3*i+3]**2)**0.5, '') # conservative estimate of error including width of peak
                    qs.append(qi)
                    labels.append(f'Fit $q_{{{xy},{i+1}}}={qi:~L}$')
                ax.plot(*v(swap_xy, *fitr[-1]), '--', lw=1, label='\n'.join(labels), zorder=50)
                # NOTE(review): the surviving code leaves the tune of the last peak as result;
                # possibly the full list qs was meant to be returned - confirm
                q = qs[-1] if qs else None
        except RuntimeError:
            print('Warning: fit failed')
        
        if return_spectrum:
            return freq, mag, phase, q
        return q
    
    if return_spectrum:
        return freq, mag, phase
def plot_tune_spectrum_scan(ax, dataset, xy, turn_range=None, time_range=None, tune_range=None, smoothing=None, cmap='gist_heat_r', colorbar=False, **kwargs):
    """Plot a tune spectrum based on turn-by-turn data for a parameter scan
    
    :param ax: Axis to plot onto
    :param dataset: dictionary {scan value: instance of LiberaTBTData class}
    :param xy: either 'x' or 'y'
    :param turn_range: tuple of (start_turn, stop_turn) for range to plot
    :param time_range: tuple of (start_time, stop_time) in seconds for range to plot
    :param tune_range: tuple of (start_tune, stop_tune) for range to plot
    :param smoothing: if specified, apply a moving average smoothing filter of this width to the data
    :param cmap: colormap
    :param colorbar: add colorbar to plot
    :param kwargs: arguments passed to pcolormesh
    """
    keys = list(dataset.keys())
    
    assert np.all([isinstance(dataset[k], LiberaTBTData) for k in keys]), f'Expected LiberaTBTData but got {[type(dataset[k]) for k in keys]}'
    
    # collect spectra into 2D array
    f, magnitudes = [], []
    for k in keys:
        tbt_data = getattr(dataset[k], xy)[turn_or_time_range(dataset[k].t, turn_range, time_range)]
    
        fft = np.fft.rfft(tbt_data)
        freq, mag = np.fft.rfftfreq(len(tbt_data), d=1), np.abs(fft)
        mag = mag**2 # magnitude squared
        
        # tune_range must not be changed inside the loop,
        # otherwise the first dataset would be treated differently from the following ones
        if tune_range is not None:
            mask = irng(freq, *tune_range)
            freq, mag = freq[mask], mag[mask]
        
        f.append(freq)
        magnitudes.append(*smooth(mag, n=smoothing))
    
    # use a common frequency axis if all spectra share the same raster
    if np.all(np.roll(f, 1, axis=0) == f):
        f = f[0]
    
    # plot 2D array
    cmf = ax.pcolormesh(f, keys, magnitudes, cmap=cmap, rasterized=True, **kwargs)
    if smoothing: add_scale(ax, np.mean(np.diff(f))*smoothing)
    
    # plot layout
    if colorbar:
        ax.get_figure().colorbar(cmf, label='$|\\mathrm{FFT}|^2$ / a.u.', ax=ax)
    # f-string so that the plane is inserted into the label, as in the other tune plots
    ax.set(xlim=(np.min(f), np.max(f)), xlabel=f'Tune $q_{xy}$')
    if len(keys) < 12:
        ax.set(yticks=keys)

def plot_tune_spectrogram(ax, libera_data, xy, *, nperseg=2**12, noverlap=None, ninterpol=None, over_time=True, colorbar=False,
                          turn_range=None, time_range=None, tune_range=None, excitation=None, vmin=None, vmax=None, smoothing=None,
                          cmap='gist_heat_r', bunches=None, show_nperseg=True, fit=False, fitkwarg=None):
    """Plot a tune spectrogram based on turn-by-turn data
    
    :param ax: Axis to plot onto
    :param libera_data: Instance of LiberaTBTData or LiberaBBBData
                        When passing LiberaTBTData, the spectrogram for the single bunch is returned
                        When passing LiberaBBBData, the resulting spectrogram will be an interleave of the single bunch spectra
                        Hint: Use LiberaBBBData.to_tbt_data(h, b) to extract single bunch data
    :param bunches: bunch number in case libera_data is an instance of LiberaBBBData
    :param xy: either 'x' or 'y'
    :param nperseg: number of turns per segment of the STFT
    :param noverlap: number of turns by which the segments overlap, defaults to nperseg - nperseg//ninterpol
    :param ninterpol: determines the default overlap, defaults to 4 for LiberaTBTData and 1 for LiberaBBBData
    
    :param over_time: plot data as function of time (if true) or turn number (otherwise).
                         Note that plotting over time is not suited for interactive plots.
    :param turn_range: tuple of (start_turn, stop_turn) for range to plot
    :param time_range: tuple of (start_time, stop_time) in seconds for range to plot
    :param tune_range: tuple of (start_tune, stop_tune) for range to plot
    :param excitation: tuple of (type, tune, ...) to indicate excitation range. Supported types are:
                         Excitation frequency band: ('band', tune, bandwidth_in_Hz)
                         Sinusoidal excitation: ('sine', tune)
                         Chirp excitation: ('chirp', tune, bandwidth_in_Hz, chirp_frequency_in_Hz, chirp_phase_in_rad)
    :param vmin: minimum value for colorscale
    :param vmax: maximum value for colorscale
    :param smoothing: if specified, apply a moving average smoothing filter of this width to the data along the tune axis
    :param cmap: colormap
    :param colorbar: add colorbar to plot
    :param show_nperseg: if True and plotting over turns, add a scale of length nperseg
    :param fit: optional fit function which is called as fit(tune, xdata, mag)
    :param fitkwarg: dict of keyword-arguments for plotting the fit result
    :returns: tuple (xdata, tune, mag, frev), with the fitted tune change rate (or None) appended if fit is given
    :raises NotImplementedError: for multi-bunch data with overlapping segments or an unknown excitation type
    """
    if fitkwarg is None: fitkwarg = {}
    if bunches is None:
        # STFT on turn by turn data
        assert isinstance(libera_data, LiberaTBTData), f'Either pass number of bunches or single bunch LiberaTBTData (but got {type(libera_data)}).'
        if ninterpol is None: ninterpol = 4
    else:
        # interleaved STFT on bunch-by-bunch data
        assert isinstance(libera_data, LiberaBBBData), f'Passing number of bunches requires LiberaBBBData (but got {type(libera_data)}).'
        if ninterpol is None: ninterpol = 1
        
    noverlap = (nperseg - nperseg//ninterpol) if noverlap is None else noverlap
    # checked before the STFTs are computed since the interleaving below can not handle this case
    if bunches is not None and bunches > 1 and noverlap > 0:
        raise NotImplementedError('Multi-bunch spectra are currently not supported with non-zero overlap in STFT. Either pass TBTData or noverlap=0')
    
    # iterate bunches
    turns, values, times, frevs = [], [], [], []
    iterate_bunches = [0] if bunches is None else range(bunches)
    for b in iterate_bunches:
        # extract turn-by-turn data for single bunch
        _libera_data = libera_data if bunches is None else libera_data.to_tbt_data(bunches, b=b)
        _turn_range = turn_or_time_range(_libera_data.t, turn_range, time_range) # crop by turn or time
        _b_range = slice(None, None) if bunches is None else slice(nperseg*b//bunches, None) # to shift fft window
        _tbt_data = getattr(_libera_data, xy)[_turn_range][_b_range]

        # compute STFT
        tune, idx, _value = scipy.signal.stft(_tbt_data, fs=1, nperseg=nperseg, noverlap=noverlap, window='boxcar', padded=False, scaling='psd')
        idx = idx.astype(int)
        _turn = np.arange(len(_libera_data.t))[_turn_range][_b_range][0] + idx
        _time = _libera_data.t[_turn_range][_b_range][idx]
        _frev = scipy.ndimage.uniform_filter1d(1/np.diff(_libera_data.t)[_turn_range][_b_range], size=nperseg)[idx]
        turns.append(_turn); values.append(_value); times.append(_time); frevs.append(_frev)

    # interleave STFTs from single bunches
    # (we can't just pass fs=bunches to STFT because their phase is not related)
    turn = np.empty((sum([t.shape[0] for t in turns]),), dtype=turns[0].dtype)
    time = np.empty((sum([t.shape[0] for t in times]),), dtype=times[0].dtype)
    frev = np.empty((sum([t.shape[0] for t in frevs]),), dtype=frevs[0].dtype)
    value = np.empty((values[0].shape[0], sum([v.shape[1] for v in values])), dtype=values[0].dtype)
    for b in iterate_bunches:
        # for a single bunch (bunches is None) the step defaults to 1
        turn[b::bunches] = turns[b]
        time[b::bunches] = times[b]
        frev[b::bunches] = frevs[b]
        value[:,b::bunches] = values[b]

    # magnitude squared
    mag = np.abs(value)**2
    
    xdata = time if over_time else turn
    
    if tune_range is not None:
        mask = irng(tune, *tune_range)
        tune, mag = tune[mask], mag[mask, :]
    else:
        tune_range = (0, 0.5)
    
    # smooth data
    if smoothing is not None:
        add_scale(ax, smoothing*np.mean(np.diff(tune)), vertical=True)
        tune = scipy.signal.savgol_filter(tune, smoothing, 0)
        mag = scipy.signal.savgol_filter(mag, smoothing, 0, axis=0)
    
    # compared against None such that an explicit limit of 0 is respected
    if over_time:
        # pcolormesh can handle non-equidistant data, but is not well suited for interactive plots (slow!)
        cm = ax.pcolormesh(xdata, tune, mag, shading='nearest', cmap=cmap, rasterized=True,
                           vmin=np.percentile(mag, 0.5) if vmin is None else vmin,
                           vmax=np.percentile(mag, 99.5) if vmax is None else vmax,
        )
    else:
        # imshow is much faster, but can only handle equidistant data
        cm = ax.imshow(mag, extent=(xdata[0], xdata[-1], tune[-1], tune[0]), aspect='auto', rasterized=True, 
                       cmap=cmap,
                       vmin=np.percentile(mag, 1) if vmin is None else vmin,
                       vmax=np.percentile(mag, 99) if vmax is None else vmax,
        )
    # NOTE(review): the end of the imshow call and this condition are missing in the original and
    # reconstructed from the indented colorbar call and the colorbar parameter
    if colorbar:
        ax.get_figure().colorbar(cm, label='$|\\mathrm{FFT}|^2$ / a.u.', ax=ax)
    if show_nperseg and not over_time:
        add_scale(ax, nperseg)
    
    if excitation is not None:
        ex_type, ex_q, *ex_args = excitation
        ex_style = dict(lw=1, ls=(0, (4, 4)), color='k', alpha=0.5, zorder=50)
        if ex_type == 'band':
            ex_dq = ex_args[0]/frev
            ax.plot(xdata, ex_q+ex_dq, xdata, ex_q-ex_dq, **ex_style)
        elif ex_type == 'sine':
            ax.plot(xdata, ex_q*np.ones_like(xdata), **ex_style)
        elif ex_type == 'chirp':
            ex_dq = ex_args[0]/frev
            ex_fc, ex_pc = ex_args[1:3]
            ax.plot(xdata, ex_q + ex_dq/2*np.sin(2*np.pi*ex_fc*time + ex_pc), **ex_style)
        else:
            raise NotImplementedError(f'Excitation type {ex_type} not implemented')
    ax.set(ylim=tune_range, ylabel=f'Tune $q_{xy}$', xlabel='Time / s' if over_time else 'Turn')
    
    if fit:
        dq_dx = None  # result if the fit fails or the fit function is not known
        try:
            fitr = fit(tune, xdata, mag)
            if fit in (fit_moving_gaussian, ):
                dx = xdata[-1]-xdata[0]
                dq = SI.Measurement(fitr[0][3], fitr[1][3], '1/s' if over_time else '')
                dq_dx = dq/dx
                fit_style = dict(ls=(0, (5, 10)), label=f'Fit $\\partial q_{xy}/\\partial '+('t' if over_time else 'n')+f'={dq_dx:~L}$', zorder=50, color=cmap_petroff_10(9))
                fit_style.update(fitkwarg)  # user settings take precedence over the defaults
                ax.plot(*fitr[-1], **fit_style)
                
        except RuntimeError:
            print('Warning: fit failed')
        
        return xdata, tune, mag, frev, dq_dx
    
    return xdata, tune, mag, frev