# plotting.py
import scipy
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import pint

from .data import *
from .fitting import *
# Setup pint
# Module-wide pint unit registry (used e.g. for SI.Measurement in the tune fits)
SI = pint.UnitRegistry()
# Enable pint's matplotlib integration so that quantities with units can be passed to plot calls
SI.setup_matplotlib()
# Default string format of quantities; NOTE(review): '~P' should select abbreviated units in pint's
# "pretty" format - confirm against the pint string-formatting docs of the installed version
SI.default_format = '~P'


# Colors
# Petroff colour sequence, see https://arxiv.org/abs/2107.02270
petroff_colors = [
    "#3f90da", "#ffa90e", "#bd1f01", "#94a4a2", "#832db6",
    "#a96b59", "#e76300", "#b9ac70", "#717581", "#92dadd",
]

# Qualitative colormap of all 10 colours; the same colours become the default line colour cycle
cmap_petroff_10 = mpl.colors.ListedColormap(petroff_colors, 'Petroff 10')
mpl.rcParams['axes.prop_cycle'] = mpl.cycler(color=petroff_colors)

# Gradient colormap (installed as default image colormap); values outside the colour range
# are shown in dedicated under/over colours
cmap_petroff_gradient = mpl.colors.LinearSegmentedColormap.from_list(
    'Petroff gradient', [petroff_colors[k] for k in (9, 0, 4, 2, 6, 1)])
cmap_petroff_gradient.set_under(petroff_colors[3])
cmap_petroff_gradient.set_over(petroff_colors[7])
mpl.rcParams['image.cmap'] = cmap_petroff_gradient

# Bipolar colormap (grey in the middle); values outside the colour range are shown in
# dedicated under/over colours
cmap_petroff_bipolar = mpl.colors.LinearSegmentedColormap.from_list(
    'Petroff bipolar', [petroff_colors[k] for k in (2, 6, 1, 3, 9, 0, 4)])
cmap_petroff_bipolar.set_under(petroff_colors[5])
cmap_petroff_bipolar.set_over(petroff_colors[8])



def add_vline(axes, r, color='k', text=None, label=None, order=None, lw=1, text_top=True, text_vertical=True, zorder=-100, **kwargs):
    """Draw a vertical line at x=r on each of the given axes, optionally annotated with text

    :param axes: iterable of axes; entries that are None are skipped
    :param r: x position of the line in data coordinates
    :param color: color of line and text
    :param text: optional annotation, drawn next to the line on the last non-None axis only
    :param label: legend label; only attached on axes whose current x limits contain r
    :param order: if given, draw the line with a dash followed by `order` dots and alpha 1/order
    :param lw: line width
    :param text_top: place the text at the top (if true) or bottom of the axis
    :param text_vertical: rotate the text by 90 degrees
    :param zorder: zorder of the text (the line itself is drawn at zorder 5)
    :param kwargs: further arguments passed to axvline
    """
    if order:
        # dash pattern: one long dash followed by `order` dots
        kwargs['ls'] = (0, [11, 2] + [1, 2]*order)
        kwargs['alpha'] = 1/order

    last = None
    for a in axes:
        if a is None:
            continue
        # sorted: limits may be reversed for inverted x axes
        xmin, xmax = sorted(a.get_xlim())
        a.axvline(r, c=color, label=label if xmin <= r <= xmax else None, zorder=5, lw=lw, **kwargs)
        last = a

    # the annotation goes onto the last drawn axis only (skipped if no axis was drawn)
    if text and last is not None:
        last.text(r, .99 if text_top else 0.01, f' {text} ', fontsize='small', c=color, zorder=zorder,
                  alpha=1/order if order else 1, transform=last.get_xaxis_text1_transform(0)[0],
                  rotation=90 if text_vertical else 0, clip_on=True, ha='right' if text_vertical else 'left',
                  va='top' if text_top else 'bottom')

def to_half_intervall(q):
    """Returns the corresponding fractional tune value in the half interval [0;0.5]

    The integer part of the tune is dropped and the result is mirrored at 0.5,
    so q, 1-q and q+n (n integer) all map to the same value.

    :param q: tune value(s), scalar or array-like
    :returns: folded tune value(s) in [0, 0.5]
    """
    # fractional part; never negative since floor(q) <= q (it can round up to 1.0 for tiny negative q)
    frac = q - np.floor(q)
    return np.minimum(frac, 1 - frac)

def add_resonance_vlines(axes, max_order, color='r'):
    """Mark the tune resonances n/m of all orders m <= max_order with vertical lines

    Every resonance value is drawn only once (at its lowest order). Lines of order m > 3
    get thinner and more transparent with increasing order, and all but the integer
    resonance (m = 1) are annotated with "n/m resonance".

    :param axes: iterable of axes passed on to add_vline
    :param max_order: highest resonance order m to draw
    :param color: color of the lines
    """
    drawn = set()
    for order in range(1, max_order + 1):
        fade = max(1, order - 2)
        for num in range(order//2 + 1):
            tune = num/order
            if tune not in drawn:
                drawn.add(tune)
                add_vline(axes, to_half_intervall(tune), color,
                          text=f'{num}/{order} resonance' if order > 1 else None,
                          lw=3/fade**.5, alpha=1/fade, text_vertical=True, text_top=True)

def subplot_shared_labels(axes, xlabel=None, ylabel=None, clear=True):
    """Adds labels to shared axes as needed

    The ylabel is set on the first column and on every other axis whose y axis is not
    shared with the first axis of its row. Likewise, the xlabel is set on the last row and
    on every other axis whose x axis is not shared with the last axis of its column.

    :param axes: 2D array of axes from subplots (pass squeeze=False to plt.subplots if required)
    :param xlabel: the shared xlabel
    :param ylabel: the shared ylabel
    :param clear: if true (default) existing labels of the kind being set (x and/or y)
                  get cleared on all axes before the new labels are applied
    """
    axes = np.asarray(axes)
    last_row = axes.shape[0] - 1
    for (r, c), ax in np.ndenumerate(axes):
        if ylabel:
            if clear:
                ax.set(ylabel=None)
            # skip axes whose y axis is already labelled via the first axis of the row
            if c == 0 or not ax.get_shared_y_axes().joined(ax, axes[r, 0]):
                ax.set(ylabel=ylabel)
        if xlabel:
            if clear:
                ax.set(xlabel=None)
            # skip axes whose x axis is already labelled via the last axis of the column
            if r == last_row or not ax.get_shared_x_axes().joined(ax, axes[-1, c]):
                ax.set(xlabel=xlabel)
def grid_diagonal(ax, **kwargs):
    """Draw diagonal grid lines with slope 1 through the visible ticks

    One line starts at every x tick on the lower edge and one at every y tick on the
    left edge of the current view. The axis limits read at the start are set again
    at the end.

    :param ax: Axes to draw onto
    :param kwargs: line properties for axline (defaults: color='lightgray', lw=1, zorder=-100)
    """
    style = {**dict(color='lightgray', lw=1, zorder=-100), **kwargs}
    (x0, x1), (y0, y1) = ax.get_xlim(), ax.get_ylim()
    starts = [(x, y0) for x in ax.get_xticks() if x0 <= x <= x1]
    starts += [(x0, y) for y in ax.get_yticks() if y0 <= y <= y1]
    for point in starts:
        ax.axline(point, slope=1, **style)
    ax.set(xlim=(x0, x1), ylim=(y0, y1))
    
    
def _next_cycle_color(ax):
    """Return the next color of the axes' property cycle (advancing the cycle)"""
    # private matplotlib API: `prop_cycler` no longer exists in recent matplotlib versions,
    # which provide `get_next_color` instead
    lines = ax._get_lines
    if hasattr(lines, 'get_next_color'):
        return lines.get_next_color()
    return next(lines.prop_cycler)['color']


def fiberplot(ax, datasets, *, labels=None, vertical=True):
    """Create a fiber plot: one filled profile ("fiber") per dataset and position

    Each profile y(x) is scaled and drawn as a filled area at its position p, at most half the
    minimal spacing between positions wide. With a single dataset the area is symmetric around p;
    with several datasets, even-numbered datasets extend to the lower and odd-numbered ones to the
    upper side of p, so that two datasets can be compared side by side.

    :param ax: Axes to plot onto
    :param datasets: 2D array of datasets (position, x, y) or dict with two levels
                     ({label: {position: (x, y)}})
    :param labels: legend label per dataset (default: no labels); ignored if datasets is a dict,
                   whose keys are used instead
    :param vertical: if true, fibers are vertical (positions along the x axis, x values along the
                     y axis), otherwise horizontal
    """
    if isinstance(datasets, dict):
        labels = list(datasets.keys())
        datasets = [[(p, *v) for p, v in dataset.items()] for dataset in datasets.values()]
    if labels is None:
        labels = [None]*len(datasets)

    positions = sorted({p for dataset in datasets for p, *_ in dataset})
    # half the minimal spacing between neighbouring positions (0.5 if there is only one position)
    dy = np.min(np.diff(positions))/2 if len(positions) > 1 else 0.5
    plot_function = ax.fill_betweenx if vertical else ax.fill_between
    two_sided = len(datasets) == 1

    for c, dataset in enumerate(datasets):
        color = _next_cycle_color(ax)
        lower_scale = 1 - c % 2
        upper_scale = 1 if two_sided else c % 2
        for i, (p, x, y, *_) in enumerate(dataset):
            # NOTE(review): normalised by max(y) rather than the peak-to-peak value; profiles with a
            # non-zero minimum do not reach the full width dy - confirm this is intended
            v = dy*(y - np.min(y))/np.max(y)
            plot_function(x, p - v*lower_scale, p + v*upper_scale,
                          color=color, lw=0, label=labels[c] if i == 0 else None)
    if vertical:
        ax.set(xticks=positions)
    else:
        ax.set(yticks=positions)

def format_axis_radians(yaxis):
    """Format an axis showing angles in radians with ticks at multiples of pi

    Major ticks are placed every 2*pi and labelled as multiples of pi; minor ticks are
    placed every pi, of which only +pi and -pi are labelled.

    :param yaxis: matplotlib axis object to format, e.g. ``ax.yaxis`` or ``ax.xaxis``
    """
    def pi_label(x, minor):
        # compare in units of pi with a tolerance: tick positions computed by the locator
        # may not be bit-identical to the corresponding multiples of np.pi
        k = x/np.pi
        if np.isclose(k, 0):
            return '0'
        if np.isclose(k, 1):
            return '$\\pi$'
        if np.isclose(k, -1):
            return '$-\\pi$'
        return '' if minor else f'${k:g}\\pi$'

    yaxis.set_major_locator(mpl.ticker.MultipleLocator(np.deg2rad(360)))
    yaxis.set_minor_locator(mpl.ticker.MultipleLocator(np.deg2rad(180)))
    yaxis.set_major_formatter(mpl.ticker.FuncFormatter(lambda x, pos: pi_label(x, minor=False)))
    yaxis.set_minor_formatter(mpl.ticker.FuncFormatter(lambda x, pos: pi_label(x, minor=True)))

def format_axis_degrees(yaxis):
    """Format an axis showing angles in degrees

    Major ticks are placed at "nice" multiples (1, 1.5, 3, 6, 9 or 10 times a power of ten)
    chosen automatically by MaxNLocator; minor ticks are placed every 30 units.

    :param yaxis: matplotlib axis object to format, e.g. ``ax.yaxis`` or ``ax.xaxis``
    """
    major = mpl.ticker.MaxNLocator('auto', steps=[1, 1.5, 3, 6, 9, 10])
    yaxis.set_major_locator(major)
    minor = mpl.ticker.MultipleLocator(30)
    yaxis.set_minor_locator(minor)


def add_scale(ax, scale, text=None, *, size=0.01, padding=0.1, loc='lower right', color='k', fontsize='x-small'):
    """Make a scale or yardstick patch

    Draws a filled bar that is ``scale`` x-data units wide and ``size`` (fraction of the axes
    height) high, optionally labelled with ``text`` on top, anchored inside the axes at ``loc``.

    :param ax: Axes to add the scale bar to
    :param scale: width of the bar in x data units
    :param text: optional label placed above the bar
    :param size: height of the bar as fraction of the axes height
    :param padding: border padding of the anchored box (passed as borderpad)
    :param loc: anchor location within the axes, e.g. 'lower right'
    :param color: color of bar and text
    :param fontsize: font size of the text
    """
    # x in data coordinates, y in axes coordinates
    x_data_y_axes = mpl.transforms.blended_transform_factory(ax.transData, ax.transAxes)
    box = mpl.offsetbox.AuxTransformBox(x_data_y_axes)
    box.add_artist(plt.Rectangle((0, 0), scale, size, fc=color))
    if text:
        box.add_artist(plt.Text(0, size, text, color=color, va='bottom', fontsize=fontsize))
    anchored = mpl.offsetbox.AnchoredOffsetbox(loc, borderpad=padding, zorder=100, frameon=False)
    anchored.set_child(box)
    ax.add_artist(anchored)
    


# Libera tbt data
##################

def turn_or_time_range(time, turn_range=None, time_range=None):
    """Return an index (slice or mask) selecting a range of turns or times

    :param time: array of turn times, used only when time_range is given
    :param turn_range: tuple of (start_turn, stop_turn[, step]) selecting turns by index
    :param time_range: tuple of (start_time, stop_time) in seconds, converted via irng
    :returns: slice(*turn_range) for a turn range, the result of irng for a time range,
              or slice(None, None) (everything) if neither is given
    :raises ValueError: if both turn_range and time_range are given
    """
    if turn_range is not None and time_range is not None:
        raise ValueError('Parameters turn_range or time_range are mutually exclusive')
    if time_range is not None:
        return irng(time, *time_range)
    if turn_range is None:
        return slice(None, None)
    # previously missing: without this the function returned None, and indexing with None
    # adds a new axis instead of cropping the data
    return slice(*turn_range)
def plot_tbt(ax, libera_data, what='fsxy', *, over_time=True, turn_range=None, time_range=None):
    """Plot turn-by-turn data

    Every requested signal is passed through avg(..., n=500) (presumably a downsampling
    average - see its definition) and drawn on a y-axis given by its axis label. Signals with
    the same label (x and y position) share one axis; every further axis is a twin of ``ax``
    with its spine moved progressively further to the right. If more than one signal is
    plotted, a legend is added to the axis of the last signal.

    :param ax: Axis to plot onto
    :param libera_data: Instance of LiberaTBTData
    :param what: list of signals to plot: f (revolution frequency), s (pickup sum), x, y (positions)
    :param over_time: plot over time (if true) or over the turn index counted from the first plotted turn
    :param turn_range: tuple of (start_turn, stop_turn) for range to plot
    :param time_range: tuple of (start_time, stop_time) in seconds for range to plot
    """
    assert isinstance(libera_data, LiberaTBTData), f'Expected LiberaTBTData but got {type(libera_data)}'

    axis_label = dict(f='$f_\\mathrm{rev}$ / MHz', s='Pickup sum / a.u.', x='Position / mm', y='Position / mm')
    line_color = dict(f=cmap_petroff_10(1), s=cmap_petroff_10(3), x=cmap_petroff_10(0), y=cmap_petroff_10(2))
    legend_text = dict(f='Revolution frequency', s='Pickup sum signal', x='X position', y='Y position')

    selection = turn_or_time_range(libera_data.t, turn_range, time_range)
    t = libera_data.t[selection]
    abscissa = t if over_time else np.arange(0, len(t))

    axis_for = {}  # y-axis label -> axes object carrying that label
    handles, names = [], []
    for i, w in enumerate(what):
        label = axis_label[w]
        if label in axis_for:
            target = axis_for[label]
        else:
            if i == 0:
                target = ax
            else:
                target = ax.twinx()
                # move each additional twin axis further to the right
                target.spines.right.set_position(("axes", 0.9+0.1*len(axis_for)))
            target.set(ylabel=label)
            axis_for[label] = target

        if w == 'f':
            # revolution frequency in MHz from the duration of each turn (t in seconds)
            values = 1e-6/np.diff(t)
        else:
            values = getattr(libera_data, w)[selection]

        line, = target.plot(*avg(abscissa[:len(values)], values, n=500), c=line_color[w])
        handles.append(line)
        names.append(legend_text[w])

    if len(what) > 1:
        target.legend(handles, names, fontsize='small')
def plot_btf(axf, axp, data, *, frev=None, **kwargs):
    """Plot magnitude and phase of BTF data over stimulus frequency (or tune)

    :param axf: Axes for the magnitude
    :param axp: Axes for the phase
    :param data: object providing the arrays f (stimulus frequency), m (magnitude), p (phase)
                 and their units f_unit, m_unit, p_unit
    :param frev: revolution frequency (in the unit of data.f); if given, the stimulus
                 frequency is converted to tune f/frev
    :param kwargs: passed on to both plot calls
    """
    f = data.f*(1/frev if frev else 1)
    axf.plot(f, data.m, **kwargs)
    axp.plot(f, data.p, **kwargs)

    axf.set(ylabel=f'Magnitude / {data.m_unit}', xlim=(np.min(f), np.max(f)))
    axp.set(ylabel=f'Phase / {data.p_unit}', xlabel='Stimulus tune' if frev else f'Stimulus frequency / {data.f_unit}')
def plot_tune_spectrum(ax, libera_data, xy, turn_range=None, time_range=None, tune_range=None, fit=False, smoothing=None, return_spectrum=False, **kwargs):
    """Plot a tune spectrum based on turn-by-turn data

    :param ax: Axis to plot onto
    :param libera_data: Instance of LiberaTBTData class
    :param xy: either 'x' or 'y'
    :param turn_range: tuple of (start_turn, stop_turn) for range to plot
    :param time_range: tuple of (start_time, stop_time) in seconds for range to plot
    :param tune_range: tuple of (start_tune, stop_tune) for range to plot (default: (0, 0.5))
    :param fit: fit function to determine the tune from the spectrum, or True for fit_lorenzian.
                fit_lorenzian and fit_gaussian yield a single tune, fit_double_lorenzian two tunes.
    :param smoothing: window length (in FFT bins) of a moving average applied to the spectrum.
                      The unsmoothed magnitude is drawn faintly in the background and a scale bar
                      indicates the window width.
    :param return_spectrum: if true, also return the (cropped and possibly smoothed) spectrum
    :param kwargs: passed on to ax.plot for the spectrum
    :returns: (freq, mag, phase, q) if fit and return_spectrum,
              q if only fit,
              (freq, mag, phase) if only return_spectrum,
              None otherwise.
              q is a pint Measurement (a tuple of two for fit_double_lorenzian), or None if the
              fit failed or the fit function is not one of the above.
    """
    assert isinstance(libera_data, LiberaTBTData), f'Expected LiberaTBTData but got {type(libera_data)}'

    turn_range = turn_or_time_range(libera_data.t, turn_range, time_range)
    tbt_data = getattr(libera_data, xy)[turn_range]

    fft = np.fft.rfft(tbt_data)
    freq, mag, phase = np.fft.rfftfreq(len(tbt_data), d=1), np.abs(fft), np.angle(fft)

    if tune_range is not None:
        mask = irng(freq, *tune_range)
        freq, mag, phase = freq[mask], mag[mask], phase[mask]
    else:
        tune_range = (0, 0.5)

    if smoothing is not None:
        # unsmoothed spectrum in the background; excluded from the legend so that a
        # user supplied label is not shown twice
        raw, = ax.plot(freq, mag, **dict(kwargs, zorder=-100, alpha=0.1, label='_nolegend_'))
        if 'c' not in kwargs and 'color' not in kwargs:
            kwargs.update(c=raw.get_color())
        add_scale(ax, smoothing*np.mean(np.diff(freq)))
        # savgol filter with order 0 is moving average
        freq, mag, phase = [scipy.signal.savgol_filter(_, smoothing, 0) for _ in (freq, mag, phase)]

    ax.plot(freq, mag, **kwargs)

    ax.set(xlim=tune_range, xlabel=f'Tune $q_{xy}$',
           ylabel='a.u.')

    if not fit:
        if return_spectrum:
            return freq, mag, phase
        return None

    if fit is True:
        fit = fit_lorenzian
    q = None  # stays None if the fit fails or the fit function is not one of the known ones
    try:
        # NOTE(review): assumes the fit functions return (params, errors, ..., (x, y) of the fitted
        # curve) with centre and width at params[2], params[3] (and params[5], params[6] for the
        # second peak) - confirm against the fitting module
        fitr = fit(freq, mag)
        fmin, fmax = np.min(freq), np.max(freq)

        def plausible(q0, w0):
            # written as positive conditions so that NaN fit results are rejected as well
            return fmin <= q0 <= fmax and w0 <= 0.1

        if fit in (fit_lorenzian, fit_gaussian):
            q0, w0 = fitr[0][2], fitr[0][3]
            if not plausible(q0, w0):
                raise RuntimeError('Fit failed')
            # conservative estimate of error including width of peak
            q = SI.Measurement(q0, (fitr[1][2]**2 + w0**2 + fitr[1][3]**2)**0.5, '')
            ax.plot(*fitr[-1], '--', label=f'Fit $q_{xy}={q:~L}$', zorder=50)

        elif fit in (fit_double_lorenzian, ):
            q1, w1, q2, w2 = fitr[0][2], fitr[0][3], fitr[0][5], fitr[0][6]
            if not (plausible(q1, w1) and plausible(q2, w2)):
                raise RuntimeError('Fit failed')
            # conservative estimate of error including width of peak
            m1 = SI.Measurement(q1, (fitr[1][2]**2 + w1**2 + fitr[1][3]**2)**0.5, '')
            m2 = SI.Measurement(q2, (fitr[1][5]**2 + w2**2 + fitr[1][6]**2)**0.5, '')
            ax.plot(*fitr[-1], '--', label=f'Fit $q_{{{xy},1}}={m1:~L}$\nFit $q_{{{xy},2}}={m2:~L}$', zorder=50)
            q = (m1, m2)

    except RuntimeError:
        print('Warning: fit failed')

    if return_spectrum:
        return freq, mag, phase, q
    return q
def plot_tune_spectrogram(ax, libera_data, xy, *, nperseg=2**12, noverlap=None, ninterpol=None, over_time=True, colorbar=False,
                          turn_range=None, time_range=None, tune_range=None, excitation=None, vmin=None, vmax=None,
                          cmap='gist_heat_r', bunches=None, show_nperseg=True):
    """Plot a tune spectrogram based on turn-by-turn data

    The spectrogram is the magnitude of a short-time Fourier transform (boxcar window)
    of the turn-by-turn position signal.

    :param ax: Axis to plot onto
    :param libera_data: Instance of LiberaTBTData or LiberaBBBData
                        When passing LiberaTBTData, the spectrogram for the single bunch is returned
                        When passing LiberaBBBData, the resulting spectrogram will be an interleave of the single bunch spectra
                        Hint: Use LiberaBBBData.to_tbt_data(h, b) to extract single bunch data
    :param xy: either 'x' or 'y'
    :param nperseg: number of turns per STFT segment
    :param noverlap: number of turns by which consecutive segments overlap
                     (default: nperseg - nperseg//ninterpol)
    :param ninterpol: overlap factor used if noverlap is None
                      (default: 4 for LiberaTBTData, 1 (no overlap) for LiberaBBBData)
    :param over_time: plot data as function of time (if true) or turn number (otherwise).
                         Note that plotting over time is not suited for interactive plots.
    :param colorbar: add colorbar to plot
    :param turn_range: tuple of (start_turn, stop_turn) for range to plot
    :param time_range: tuple of (start_time, stop_time) in seconds for range to plot
    :param tune_range: tuple of (start_tune, stop_tune) for range to plot (default: (0, 0.5))
    :param excitation: tuple of (type, tune, ...) to indicate excitation range. Supported types are:
                         Excitation frequency band: ('band', tune, bandwidth_in_Hz)
                         Sinusoidal excitation: ('sine', tune)
                         Chirp excitation: ('chirp', tune, bandwidth_in_Hz, chirp_frequency_in_Hz, chirp_phase_in_rad)
    :param vmin: minimum value for colorscale (default: low percentile of the magnitude)
    :param vmax: maximum value for colorscale (default: high percentile of the magnitude)
    :param cmap: colormap for the magnitude
    :param bunches: number of bunches in case libera_data is an instance of LiberaBBBData
    :param show_nperseg: indicate the segment length nperseg by a scale bar (only when plotting over turns)
    :returns: tuple (xdata, tune, mag, frev): segment times (or turns), tune axis, STFT magnitude of
              shape (len(tune), len(xdata)) and revolution frequency (1/turn duration, averaged
              over nperseg turns) at each segment
    :raises NotImplementedError: for multi-bunch data with overlapping segments or an unknown excitation type
    """
    if bunches is None:
        # STFT on turn by turn data
        assert isinstance(libera_data, LiberaTBTData), f'Either pass number of bunches or single bunch LiberaTBTData (but got {type(libera_data)}).'
        if ninterpol is None: ninterpol = 4
    else:
        # interleaved STFT on bunch-by-bunch data
        assert isinstance(libera_data, LiberaBBBData), f'Passing number of bunches requires LiberaBBBData (but got {type(libera_data)}).'
        if ninterpol is None: ninterpol = 1

    noverlap = (nperseg - nperseg//ninterpol) if noverlap is None else noverlap

    # Interleaving the STFTs of single bunches only works for non-overlapping segments
    # (we can't just pass fs=bunches to STFT because their phase is not related).
    # Checked up front so that no STFT is computed in vain.
    if bunches is not None and bunches > 1 and noverlap > 0:
        raise NotImplementedError('Multi-bunch spectra are currently not supported with non-zero overlap in STFT. Either pass TBTData or noverlap=0')

    # iterate bunches
    turns, values, times, frevs = [], [], [], []
    iterate_bunches = [0] if bunches is None else range(bunches)
    for b in iterate_bunches:
        # extract turn-by-turn data for single bunch
        _libera_data = libera_data if bunches is None else libera_data.to_tbt_data(bunches, b=b)
        _turn_range = turn_or_time_range(_libera_data.t, turn_range, time_range) # crop by turn or time
        _b_range = slice(None, None) if bunches is None else slice(nperseg*b//bunches, None) # to shift fft window
        _tbt_data = getattr(_libera_data, xy)[_turn_range][_b_range]

        # compute STFT; idx are the segment positions in turns relative to the cropped data
        tune, idx, _value = scipy.signal.stft(_tbt_data, fs=1, nperseg=nperseg, noverlap=noverlap, window='boxcar', boundary=None, padded=False)
        idx = idx.astype(int)
        _turn = np.arange(len(_libera_data.t))[_turn_range][_b_range][0] + idx
        _time = _libera_data.t[_turn_range][_b_range][idx]
        _frev = scipy.ndimage.uniform_filter1d(1/np.diff(_libera_data.t)[_turn_range][_b_range], size=nperseg)[idx]
        turns.append(_turn); values.append(_value); times.append(_time); frevs.append(_frev)

    # interleave STFTs from single bunches: bunch b fills every bunches-th segment starting at b
    turn = np.empty((sum([t.shape[0] for t in turns]),), dtype=turns[0].dtype)
    time = np.empty((sum([t.shape[0] for t in times]),), dtype=times[0].dtype)
    frev = np.empty((sum([t.shape[0] for t in frevs]),), dtype=frevs[0].dtype)
    value = np.empty((values[0].shape[0], sum([v.shape[1] for v in values])), dtype=values[0].dtype)
    for b in iterate_bunches:
        turn[b::bunches] = turns[b]
        time[b::bunches] = times[b]
        frev[b::bunches] = frevs[b]
        value[:,b::bunches] = values[b]

    # magnitude (the DC component is not suppressed)
    mag = np.abs(value)

    xdata = time if over_time else turn

    if tune_range is not None:
        mask = irng(tune, *tune_range)
        tune, mag = tune[mask], mag[mask, :]
    else:
        tune_range = (0, 0.5)

    def color_limit(limit, percentile):
        # an explicitly given limit (including 0) takes precedence over the percentile default
        return np.percentile(mag, percentile) if limit is None else limit

    if over_time:
        # pcolormesh can handle non-equidistant data, but is not well suited for interactive plots (slow!)
        cm = ax.pcolormesh(xdata, tune, mag, shading='nearest', cmap=cmap,
                           vmin=color_limit(vmin, 0.5), vmax=color_limit(vmax, 99.5))
    else:
        # imshow is much faster, but can only handle equidistant data
        cm = ax.imshow(mag, extent=(xdata[0], xdata[-1], tune[-1], tune[0]), aspect='auto', rasterized=True,
                       cmap=cmap, vmin=color_limit(vmin, 1), vmax=color_limit(vmax, 99))

    if colorbar:
        ax.get_figure().colorbar(cm, label='FFT magnitude', ax=ax)

    if show_nperseg and not over_time:
        add_scale(ax, nperseg)

    if excitation is not None:
        ex_type, ex_q, *ex_args = excitation
        ex_style = dict(lw=1, ls=':', c='k', alpha=0.5, zorder=50)
        if ex_type == 'band':
            # NOTE(review): the band is drawn at +/- bandwidth around the tune, whereas the chirp
            # below sweeps +/- bandwidth/2 - confirm which convention is intended
            ex_dq = ex_args[0]/frev
            ax.plot(xdata, ex_q+ex_dq, xdata, ex_q-ex_dq, **ex_style)
        elif ex_type == 'sine':
            ax.plot(xdata, ex_q*np.ones_like(xdata), **ex_style)
        elif ex_type == 'chirp':
            ex_dq = ex_args[0]/frev
            ex_fc, ex_pc = ex_args[1:3]
            ax.plot(xdata, ex_q + ex_dq/2*np.sin(2*np.pi*ex_fc*time + ex_pc), **ex_style)
        else:
            raise NotImplementedError(f'Excitation type {ex_type} not implemented')

    ax.set(ylim=tune_range, ylabel=f'Tune $q_{xy}$', xlabel='Time / s' if over_time else 'Turn')
    return xdata, tune, mag, frev