Skip to content
Snippets Groups Projects
plotting.py 10.1 KiB
Newer Older
import scipy
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import pint

from .data import *
from .fitting import *
# Unit handling via pint: one shared registry, registered with matplotlib so
# that pint quantities can be plotted directly. By default quantities are
# formatted with abbreviated ("~"), pretty-printed ("P") units.
SI = pint.UnitRegistry()
SI.setup_matplotlib()
SI.default_format = '~P'


# Colour scheme: Petroff 10-colour sequence, https://arxiv.org/abs/2107.02270
# It is installed as the default line colour cycle and is the base for the
# colormaps below.
petroff_colors = [
    "#3f90da", "#ffa90e", "#bd1f01", "#94a4a2", "#832db6",
    "#a96b59", "#e76300", "#b9ac70", "#717581", "#92dadd",
]
cmap_petroff_10 = mpl.colors.ListedColormap(petroff_colors, name='Petroff 10')
mpl.rcParams['axes.prop_cycle'] = mpl.cycler(color=petroff_colors)

# Gradient colormap, installed as default for images. Values below/above the
# colour limits are shown in distinct extra colours.
cmap_petroff_gradient = mpl.colors.LinearSegmentedColormap.from_list(
    'Petroff gradient', [petroff_colors[k] for k in (9, 0, 4, 2, 6, 1)])
cmap_petroff_gradient.set_under(petroff_colors[3])
cmap_petroff_gradient.set_over(petroff_colors[7])
mpl.rcParams['image.cmap'] = cmap_petroff_gradient

# Bipolar colormap, likewise with distinct under/over colours.
cmap_petroff_bipolar = mpl.colors.LinearSegmentedColormap.from_list(
    'Petroff bipolar', [petroff_colors[k] for k in (2, 6, 1, 3, 9, 0, 4)])
cmap_petroff_bipolar.set_under(petroff_colors[5])
cmap_petroff_bipolar.set_over(petroff_colors[8])



def add_vline(axes, r, color='k', text=None, label=None, order=None, lw=1, text_top=True, text_vertical=True, zorder=-100, **kwargs):
    """Draw a vertical line at x position ``r`` on each of several axes.

    :param axes: iterable of axes; ``None`` entries are skipped
    :param r: x position of the line (data coordinates)
    :param color: colour of line and text
    :param text: optional annotation written next to the line; it is placed
                 on the last non-None axis only (nothing if there is none)
    :param label: legend label; only attached on axes whose current x limits
                  contain ``r``, so legends do not list lines outside the view
    :param order: if given, use a dash pattern with ``order`` dots and an
                  alpha of ``1/order`` (overrides ``ls``/``alpha`` in kwargs)
    :param lw: line width
    :param text_top: place the text at the top (True) or bottom of the axis
    :param text_vertical: rotate the text by 90 degrees
    :param zorder: z-order of the text (the line itself is drawn at zorder 5)
    :param kwargs: further keyword arguments passed on to ``axvline``
    """
    if order:
        kwargs['ls'] = (0, [11, 2] + [1, 2] * order)
        kwargs['alpha'] = 1 / order

    last = None
    for a in axes:
        if a is None:
            continue
        # sorted() so that inverted x axes are handled as well
        lo, hi = sorted(a.get_xlim())
        a.axvline(r, c=color, label=label if lo <= r <= hi else None, zorder=5, lw=lw, **kwargs)
        last = a

    # NOTE(review): the text is only put on the last axis (as before) - confirm
    # that one annotation per call, rather than per axis, is intended.
    if text and last is not None:
        # x in data coordinates, y relative to the axis (0..1)
        last.text(r, .99 if text_top else 0.01, f' {text} ', fontsize='small', c=color, zorder=zorder,
                  alpha=1/order if order else 1, transform=last.get_xaxis_text1_transform(0)[0],
                  rotation=90 if text_vertical else 0, clip_on=True, ha='right' if text_vertical else 'left',
                  va='top' if text_top else 'bottom')

def to_half_intervall(q):
    """Returns the corresponding fractional tune value in the half interval [0, 0.5]

    The fractional part of ``q`` is taken first. Values above 0.5 are then
    mirrored to ``1 - q``, so e.g. 0.7 and 1.3 both map to 0.3. Note that 0.5
    maps to itself, i.e. the interval is closed.

    :param q: tune value(s), scalar or array-like
    :return: folded value(s), same shape as ``q``
    """
    # fractional part; never negative because floor rounds towards -inf
    frac = q - np.floor(q)
    return np.minimum(frac, 1 - frac)

def add_resonance_vlines(axes, max_order, color='r'):
    """Mark tune resonances n/m up to order ``max_order`` with vertical lines.

    For every order m = 1..max_order and n = 0..m//2, a line is drawn at the
    tune n/m. Each distinct value is drawn only once: the first (lowest-order)
    fraction wins and determines the label "n/m resonance". The order-1 line
    (tune 0) is unlabelled. Orders above 3 are drawn progressively thinner and
    more transparent.

    :param axes: iterable of axes to draw on
    :param max_order: highest resonance order m to draw
    :param color: line and text colour
    """
    seen = set()
    for order in range(1, max_order + 1):
        # emphasis factor: 1 for orders 1-3, growing for higher orders
        fade = max(1, order - 2)
        for num in range(order // 2 + 1):
            tune = num / order
            if tune in seen:
                continue
            seen.add(tune)
            caption = None if order == 1 else f'{num}/{order} resonance'
            add_vline(axes, to_half_intervall(tune), color, text=caption, lw=3 / fade**.5,
                      alpha=1 / fade, text_vertical=True, text_top=True)

def subplot_shared_labels(axes, xlabel=None, ylabel=None, clear=True):
    """Adds labels to shared axes as needed

    A y label is set on the first column, and on every axis whose y axis is
    not shared with the first axis of its row. An x label is set on the last
    row, and on every axis whose x axis is not shared with the last axis of
    its column. A label passed as ``None`` is not applied at all, so existing
    labels survive when ``clear`` is false.

    :param axes: 2D array of axes from subplots (pass squeeze=False to plt.subplots if required);
                 a nested list of axes is accepted as well
    :param xlabel: the shared xlabel (None: leave x labels untouched unless clearing)
    :param ylabel: the shared ylabel (None: leave y labels untouched unless clearing)
    :param clear: if true (default) any existing labels on the axes will get cleared
    """
    axes = np.asarray(axes, dtype=object)
    nrows, ncols = axes.shape
    for r, c in np.ndindex(nrows, ncols):
        ax = axes[r, c]
        if clear:
            ax.set(xlabel=None, ylabel=None)
        if ylabel is not None and (c == 0 or not ax.get_shared_y_axes().joined(ax, axes[r, 0])):
            ax.set(ylabel=ylabel)
        if xlabel is not None and (r == nrows - 1 or not ax.get_shared_x_axes().joined(ax, axes[-1, c])):
            ax.set(xlabel=xlabel)




# Libera tbt data
##################

def turn_or_time_range(time, turn_range=None, time_range=None):
    """Translate a turn or time range into an index for turn-by-turn arrays.

    :param time: array of turn time stamps; only used together with ``time_range``
    :param turn_range: ``(start_turn, stop_turn)`` tuple, converted to a slice
                       (``(start, stop, step)`` works too). Any other index
                       object (slice, index array, mask) is returned unchanged.
    :param time_range: ``(start_time, stop_time)`` in seconds; mapped onto an
                       index via ``irng(time, start_time, stop_time)``
    :return: index object; ``slice(None, None)`` (all turns) if neither range is given
    :raises ValueError: if both ``turn_range`` and ``time_range`` are given
    """
    if turn_range is not None and time_range is not None: 
        raise ValueError('Parameters turn_range or time_range are mutually exclusive')
    if time_range is not None:
        return irng(time, *time_range)
    if turn_range is None:
        return slice(None, None)
    if isinstance(turn_range, tuple):
        # numpy interprets a bare tuple as a multi-dimensional index (IndexError
        # on 1D data), so turn (start, stop[, step]) into a proper slice
        return slice(*turn_range)
    return turn_range


def plot_tbt(ax, libera_data, over_time=True, turn_range=None, time_range=None):
    """Plot turn-by-turn data: revolution frequency and pickup sum signal.

    The revolution frequency, computed from the spacing of consecutive turn
    time stamps, is drawn on ``ax`` in MHz with y range 0..1. The pickup sum
    signal is drawn on a twin y axis with y range 0..1e8. Both curves go
    through ``avg(x, y, n=500)`` before plotting. ``avg`` is a helper from the
    star imports; presumably it averages/reduces the data to about 500
    points - TODO confirm.

    :param ax: Axis to plot onto
    :param libera_data: Instance of LiberaTBTData class
    :param over_time: plot over time stamps (if true) or over turn number
    :param turn_range: selection of turns, see turn_or_time_range
    :param time_range: tuple of (start_time, stop_time) in seconds
    """
    assert isinstance(libera_data, LiberaTBTData), f'Expected LiberaTBTData but got {type(libera_data)}'

    selection = turn_or_time_range(libera_data.t, turn_range, time_range)
    t = libera_data.t[selection]
    s = libera_data.s[selection]

    # the frequency is defined between consecutive turns -> one point fewer
    if over_time:
        x_frev, x_sum = t[:-1], t
    else:
        x_frev, x_sum = np.arange(0, len(s) - 1), np.arange(0, len(s))
    frev_mhz = 1e-6 / np.diff(t)

    twin = ax.twinx()
    line_frev, = ax.plot(*avg(x_frev, frev_mhz, n=500), c=cmap_petroff_10(3))
    ax.set(ylabel='$f_\\mathrm{rev}$ / MHz', ylim=(0, 1))
    ax.grid(color='lightgray')
    line_sum, = twin.plot(*avg(x_sum, s, n=500), c=cmap_petroff_10(1))
    twin.set(ylabel='Pickup sum / a.u.', ylim=(0, 1e8))
    twin.legend([line_frev, line_sum], ['Revolution frequency', 'Pickup sum signal'],
                loc='center right', fontsize='small')



def plot_tune_spectrum(ax, libera_data, xy, turn_range=None, time_range=None, tune_range=None, fit=False, **kwargs):
    """Plot a tune spectrum based on turn-by-turn data

    The spectrum is the magnitude of the real FFT of the selected position
    data. It is plotted over frequency in units of one turn (``d=1``), i.e.
    over the fractional tune.

    :param ax: Axis to plot onto
    :param libera_data: Instance of LiberaTBTData class
    :param xy: either 'x' or 'y'
    :param turn_range: tuple of (start_turn, stop_turn) for range to plot
    :param time_range: tuple of (start_time, stop_time) in seconds for range to plot
    :param tune_range: tuple of (start_tune, stop_tune) to restrict spectrum and
                       fit to; the x limits default to (0, 0.5)
    :param fit: if truthy, fit the peak and plot the fit curve. A callable is
                used as fit function, otherwise fit_lorenzian. The fit is
                called as ``fit(freq, mag)``. Judging from its use here, its
                result ``r`` must provide ``r[0][2]`` (peak tune), ``r[0][3]``
                (peak width), ``r[1][2]`` and ``r[1][3]`` (used as their
                errors), and ``r[-1]`` (x, y of the fitted curve).
    :param kwargs: passed on to ``ax.plot`` for the spectrum
    """
    assert isinstance(libera_data, LiberaTBTData), f'Expected LiberaTBTData but got {type(libera_data)}'

    selection = turn_or_time_range(libera_data.t, turn_range, time_range)
    signal = getattr(libera_data, xy)[selection]

    freq = np.fft.rfftfreq(len(signal), d=1)
    mag = np.abs(np.fft.rfft(signal))

    if tune_range is None:
        tune_range = (0, 0.5)
    else:
        keep = irng(freq, *tune_range)
        freq, mag = freq[keep], mag[keep]

    ax.plot(freq, mag, **kwargs)
    ax.set(xlim=tune_range, xlabel=f'Tune $q_{xy}$', ylabel='a.u.')

    if not fit:
        return

    fit_function = fit if callable(fit) else fit_lorenzian
    try:
        fitr = fit_function(freq, mag)
        q, w = fitr[0][2], fitr[0][3]
        # treat a peak outside [0, 0.5] or a too broad peak as a failed fit
        if q<0 or q>0.5 or w > 0.01:
            raise RuntimeError('Fit failed')
    except RuntimeError:
        print('Warning: fit failed')
        return

    # conservative estimate of error including width of peak
    error = (fitr[1][2]**2 + fitr[0][3]**2 + fitr[1][3]**2)**0.5
    q = SI.Measurement(fitr[0][2], error, '')
    ax.plot(*fitr[-1], '--', label=f'Fit $q_{xy}={q:~L}$', zorder=50)

def plot_tune_spectrogram(ax, libera_data, xy, nperseg=2**12, noverlap=None, ninterpol=4, over_time=True, colorbar=False,
                          turn_range=None, time_range=None, tune_range=None, excitation=None,
                          cmap='gist_heat_r'):
    """Plot a tune spectrogram based on turn-by-turn data
    
    :param ax: Axis to plot onto
    :param libera_data: Instance of LiberaTBTData class
    :param xy: either 'x' or 'y'
    :param nperseg: number of turns per FFT segment (boxcar window)
    :param noverlap: number of turns by which consecutive segments overlap;
                     defaults to ``nperseg - nperseg//ninterpol``
    :param ninterpol: only used for the default of ``noverlap``: segments then
                      advance by ``nperseg//ninterpol`` turns
    :param over_time: plot data as function of time (if true) or turn number (otherwise).
                         Note that plotting over time is not suited for interactive plots.
    :param colorbar: include colorbar
    :param turn_range: tuple of (start_turn, stop_turn) for range to plot
    :param time_range: tuple of (start_time, stop_time) in seconds for range to plot
    :param tune_range: tuple of (start_tune, stop_tune) for range to plot
    :param excitation: tuple of (type, tune, ...) to indicate excitation range. Supported types are:
                         Excitation frequency band: ('band', tune, bandwidth_in_Hz)
                         Sinusoidal excitation: ('sine', tune)
                         Chirp excitation: ('chirp', tune, bandwidth_in_Hz, chirp_frequency_in_Hz, chirp_phase_in_rad)
    :param cmap: colormap for the spectrogram
    :return: tuple ``(xdata, tune, mag, frev)``. ``xdata`` holds the segment
             times (over_time) or the segment turn numbers from
             ``scipy.signal.stft``. ``tune`` and ``mag`` are the (possibly
             tune_range-restricted) tune axis and FFT magnitudes. ``frev`` is
             the revolution frequency (1/turn spacing of the time stamps),
             smoothed over ``nperseg`` turns, at each segment.
    :raises NotImplementedError: for an unknown excitation type
    """
    # Explicit subpackage imports: a plain ``import scipy`` does not load
    # scipy.signal / scipy.ndimage on older SciPy versions.
    import scipy.ndimage
    import scipy.signal

    noverlap = (nperseg - nperseg//ninterpol) if noverlap is None else noverlap
    
    assert isinstance(libera_data, LiberaTBTData), f'Expected LiberaTBTData but got {type(libera_data)}'
    
    turn_range = turn_or_time_range(libera_data.t, turn_range, time_range)
    tbt_data = getattr(libera_data, xy)[turn_range]
    t_sel = libera_data.t[turn_range]
    
    tune, turn, value = scipy.signal.stft(tbt_data, fs=1, nperseg=nperseg, noverlap=noverlap,
                                          window='boxcar', boundary=None, padded=False)
    idx = turn.astype(int)
    time = t_sel[idx]
    # Take the turn spacing from the *selected* time stamps. The former
    # np.diff(t)[turn_range] applied an index built for len(t) elements to an
    # array one element shorter, which fails for boolean masks.
    frev = scipy.ndimage.uniform_filter1d(1/np.diff(t_sel), size=nperseg)[idx]
    mag = np.abs(value)
    
    xdata = time if over_time else turn
    
    if tune_range is not None:
        mask = irng(tune, *tune_range)
        tune, mag = tune[mask], mag[mask, :]
    else:
        tune_range = (0, 0.5)
    
    if over_time:
        # pcolormesh can handle non-equidistant data, but is not well suited for interactive plots (slow!)
        cm = ax.pcolormesh(xdata, tune, mag, shading='nearest', cmap=cmap,
                           vmin=np.percentile(mag, 0.5), vmax=np.percentile(mag, 99.5))
    else:
        # imshow is much faster, but can only handle equidistant data
        cm = ax.imshow(mag, extent=(xdata[0], xdata[-1], tune[-1], tune[0]), aspect='auto', rasterized=True,
                       cmap=cmap, vmin=np.percentile(mag, 1), vmax=np.percentile(mag, 99))
    if colorbar:
        # there is no ``fig`` in scope here; use the figure owning the axis
        ax.figure.colorbar(cm, label='FFT magnitude', ax=ax)
    
    if excitation is not None:
        _plot_excitation(ax, excitation, xdata, time, frev)
    
    ax.set(ylim=tune_range, ylabel=f'Tune $q_{xy}$')
    return xdata, tune, mag, frev


def _plot_excitation(ax, excitation, xdata, time, frev):
    """Overlay the tune range covered by an excitation (types as in plot_tune_spectrogram)."""
    ex_type, ex_q, *ex_args = excitation
    ex_style = dict(lw=1, ls=':', c='k', alpha=0.5, zorder=50)
    if ex_type == 'band':
        # NOTE(review): the band edges are drawn at tune ± bandwidth/frev (total
        # width 2*bandwidth), whereas the chirp below sweeps ± bandwidth/2 -
        # confirm which convention "bandwidth_in_Hz" should follow.
        ex_dq = ex_args[0]/frev
        ax.plot(xdata, ex_q+ex_dq, xdata, ex_q-ex_dq, **ex_style)
    elif ex_type == 'sine':
        ax.plot(xdata, ex_q*np.ones_like(xdata), **ex_style)
    elif ex_type == 'chirp':
        ex_dq = ex_args[0]/frev
        ex_fc, ex_pc = ex_args[1:3]
        # the chirp evolves in absolute time, also when plotting over turn number
        ax.plot(xdata, ex_q + ex_dq/2*np.sin(2*np.pi*ex_fc*time + ex_pc), **ex_style)
    else:
        raise NotImplementedError(f'Excitation type {ex_type} not implemented')