Skip to content
Snippets Groups Projects
plotting.py 6.86 KiB
Newer Older
import scipy
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

from .data import *


# Accessible colour cycle by M. A. Petroff, https://arxiv.org/abs/2107.02270
petroff_colors = [
    "#3f90da",  # 0
    "#ffa90e",  # 1
    "#bd1f01",  # 2
    "#94a4a2",  # 3
    "#832db6",  # 4
    "#a96b59",  # 5
    "#e76300",  # 6
    "#b9ac70",  # 7
    "#717581",  # 8
    "#92dadd",  # 9
]

# Qualitative colormap holding the ten colours as discrete entries
cmap_petroff_10 = mpl.colors.ListedColormap(petroff_colors, name='Petroff 10')

# Sequential colormap through colours 9, 0, 4, 2, 6, 1;
# values below / above the norm range are drawn in colour 3 / colour 7
cmap_petroff_gradient = mpl.colors.LinearSegmentedColormap.from_list(
    'Petroff gradient', [petroff_colors[idx] for idx in (9, 0, 4, 2, 6, 1)])
cmap_petroff_gradient.set_under(petroff_colors[3])
cmap_petroff_gradient.set_over(petroff_colors[7])

# Bipolar colormap with colour 3 at its centre;
# values below / above the norm range are drawn in colour 5 / colour 8
cmap_petroff_bipolar = mpl.colors.LinearSegmentedColormap.from_list(
    'Petroff bipolar', [petroff_colors[idx] for idx in (2, 6, 1, 3, 9, 0, 4)])
cmap_petroff_bipolar.set_under(petroff_colors[5])
cmap_petroff_bipolar.set_over(petroff_colors[8])

# Install the palette as matplotlib's default line colour cycle and image colormap
mpl.rcParams['axes.prop_cycle'] = mpl.cycler(color=petroff_colors)
mpl.rcParams['image.cmap'] = cmap_petroff_gradient



def add_vline(axes, r, color='k', text=None, label=None, order=None, lw=1, text_top=True, text_vertical=True, zorder=-100, **kwargs):
    """Draw a vertical line at x position ``r`` on each of the given axes.

    :param axes: iterable of matplotlib axes (``None`` entries are skipped); a single
        axes object is also accepted
    :param r: x position of the line in data coordinates
    :param color: colour of the line and of the text
    :param text: optional annotation, drawn next to the line on the last non-None axis only
    :param label: legend label; attached only on axes whose current x limits contain ``r``
    :param order: if truthy, draw a dashed pattern (one long dash followed by ``order`` dots)
        with alpha ``1/order`` for both line and text
    :param lw: line width
    :param text_top: place the text at the top (True) or at the bottom (False) of the axis
    :param text_vertical: rotate the text by 90 degrees
    :param zorder: z-order of the text (the line itself is always drawn at zorder 5)
    :param kwargs: further keyword arguments passed to ``axvline``
    """
    if hasattr(axes, 'axvline'):  # a single axes object instead of an iterable of axes
        axes = [axes]
    if order:
        # Higher orders get more dots and are drawn fainter
        kwargs['ls'] = (0, [11, 2] + [1, 2] * order)
        kwargs['alpha'] = 1 / order
    last_ax = None
    for ax in axes:
        if ax is None:
            continue
        # sorted() so that inverted x axes (xlim[0] > xlim[1]) are handled too
        lo, hi = sorted(ax.get_xlim())
        ax.axvline(r, c=color, label=label if lo <= r <= hi else None, zorder=5, lw=lw, **kwargs)
        last_ax = ax
    if text and last_ax is not None:
        # Blended transform: x in data coordinates, y in axes coordinates (0 bottom, 1 top)
        last_ax.text(r, .99 if text_top else 0.01, f' {text} ', fontsize='small', c=color, zorder=zorder,
                     alpha=1 / order if order else 1, transform=last_ax.get_xaxis_text1_transform(0)[0],
                     rotation=90 if text_vertical else 0, clip_on=True, ha='right' if text_vertical else 'left',
                     va='top' if text_top else 'bottom')

def to_half_intervall(q):
    """Returns the corresponding fractional tune value in the half interval [0, 0.5]

    The fractional part of ``q`` is folded about 0.5, so a tune ``q`` and its alias
    ``1 - q`` map to the same value. Works element-wise on array-like input.

    :param q: tune value(s), scalar or array-like
    :return: folded fractional tune(s) in [0, 0.5]
    """
    # q - floor(q) is already non-negative for finite q; no extra wrap-around needed
    frac = q - np.floor(q)
    return np.minimum(frac, 1 - frac)

def add_resonance_vlines(axes, max_order, color='r'):
    """Mark all resonance tunes n/m (0 <= n/m <= 1/2) up to order ``max_order``.

    Each distinct ratio is drawn once, at the lowest order m at which it occurs.
    Lines of higher order are drawn thinner and more transparent. Every ratio with
    m > 1 gets an 'n/m resonance' label.

    :param axes: iterable of axes passed on to ``add_vline``
    :param max_order: highest resonance order m to include
    :param color: line and text colour
    """
    # dict keeps insertion order, so lines are drawn in the same order they are discovered
    first_occurrence = {}
    for denom in range(1, max_order + 1):
        for numer in range(denom // 2 + 1):
            first_occurrence.setdefault(numer / denom, (numer, denom))

    for ratio, (numer, denom) in first_occurrence.items():
        fade = max(1, denom - 2)
        label_text = f'{numer}/{denom} resonance' if denom > 1 else None
        add_vline(axes, to_half_intervall(ratio), color, text=label_text, lw=3 / fade**.5,
                  alpha=1 / fade, text_vertical=True, text_top=True)





# Libera tbt data
##################

def turn_or_time_range(time, turn_range=None, time_range=None):
    """Translate a turn range or a time range into an index for the turn-by-turn data arrays.

    :param time: array of time stamps, one per turn (only used together with ``time_range``)
    :param turn_range: ``(start_turn, stop_turn[, step])`` tuple, or any object that can index
        the data arrays directly (e.g. a ``slice``)
    :param time_range: ``(start_time, stop_time)`` tuple, resolved via ``irng`` on ``time``
    :return: an index selecting the requested turns; ``slice(None, None)`` (all turns) if
        neither range is given
    :raises ValueError: if both ``turn_range`` and ``time_range`` are given
    """
    if turn_range is not None and time_range is not None: 
        raise ValueError('Parameters turn_range or time_range are mutually exclusive')
    if time_range is not None:
        return irng(time, *time_range)
    if turn_range is None:
        return slice(None, None)
    if isinstance(turn_range, tuple):
        # Indexing a 1-D array with a plain tuple is multi-dimensional indexing (IndexError),
        # so convert the documented (start, stop) form into a slice.
        return slice(*turn_range)
    return turn_range


def plot_tbt(ax, libera_data, over_time=True, turn_range=None, time_range=None):
    """Plot revolution frequency and pickup sum signal of turn-by-turn data.

    The revolution frequency, derived from the spacing of consecutive turn time stamps,
    is drawn on ``ax``. The pickup sum signal is drawn on a twin y axis. Both curves pass
    through ``avg(..., n=500)`` before plotting; presumably this averages or downsamples
    (defined in the data module — verify).

    :param ax: Axis to plot onto (a twin axis is created for the sum signal)
    :param libera_data: Instance of LiberaTBTData class
    :param over_time: if True, use the time stamps as x values, otherwise the turn index
        within the selected range
    :param turn_range: tuple of (start_turn, stop_turn) for range to plot
    :param time_range: tuple of (start_time, stop_time) in seconds for range to plot
    """
    assert isinstance(libera_data, LiberaTBTData), f'Expected LiberaTBTData but got {type(libera_data)}'

    selection = turn_or_time_range(libera_data.t, turn_range, time_range)
    times = libera_data.t[selection]
    pickup_sum = libera_data.s[selection]
    n_turns = len(pickup_sum)

    if over_time:
        x_sum, x_freq = times, times[:-1]
    else:
        x_sum, x_freq = np.arange(0, n_turns), np.arange(0, n_turns - 1)

    ax_sum = ax.twinx()

    # 1/dt between consecutive turns, scaled by 1e-6 for display in MHz
    f_rev_mhz = 1e-6 / np.diff(times)
    line_freq, = ax.plot(*avg(x_freq, f_rev_mhz, n=500), c=cmap_petroff_10(3))
    ax.set(ylabel='$f_\\mathrm{rev}$ / MHz', ylim=(0, 1))
    ax.grid(color='lightgray')

    line_sum, = ax_sum.plot(*avg(x_sum, pickup_sum, n=500), c=cmap_petroff_10(1))
    ax_sum.set(ylabel='Pickup sum / a.u.', ylim=(0, 1e8))
    ax_sum.legend([line_freq, line_sum], ['Revolution frequency', 'Pickup sum signal'],
                  loc='center right', fontsize='small')


def plot_tune_spectrum(ax, libera_data, xy, turn_range=None, time_range=None, tune_range=None, **kwargs):
    """Plot a tune spectrum based on turn-by-turn data

    The magnitude of the real FFT of the selected turn-by-turn positions is plotted
    against the tune. The sample spacing is one turn, so the frequency axis is in units
    of the revolution frequency.

    :param ax: Axis to plot onto
    :param libera_data: Instance of LiberaTBTData class
    :param xy: either 'x' or 'y'
    :param turn_range: tuple of (start_turn, stop_turn) for range to plot
    :param time_range: tuple of (start_time, stop_time) in seconds for range to plot
        (mutually exclusive with turn_range)
    :param tune_range: tuple of (min_tune, max_tune) to restrict the spectrum to;
        the x limits default to (0, 0.5)
    :param kwargs: further keyword arguments passed to ``ax.plot``
    :return: tuple (tune, magnitude) of the plotted data
    """
    assert isinstance(libera_data, LiberaTBTData), f'Expected LiberaTBTData but got {type(libera_data)}'

    turn_range = turn_or_time_range(libera_data.t, turn_range, time_range)
    tbt_data = getattr(libera_data, xy)[turn_range]

    # d=1: one sample per turn, so the FFT frequencies are directly tunes
    freq = np.fft.rfftfreq(len(tbt_data), d=1)
    mag = np.abs(np.fft.rfft(tbt_data))

    if tune_range is not None:
        mask = irng(freq, *tune_range)
        freq, mag = freq[mask], mag[mask]
    else:
        tune_range = (0, 0.5)

    ax.plot(freq, mag, **kwargs)

    ax.set(xlim=tune_range, xlabel=f'Tune $q_{xy}$',
           ylabel='Magnitude / a.u.')
    return freq, mag


def plot_tune_spectrogram(ax, libera_data, xy, nperseg=2**12, noverlap=None, ninterpol=4, smoothing=0, over_time=True, colorbar=False,
                          turn_range=None, time_range=None, tune_range=None):
    """Plot a tune spectrogram based on turn-by-turn data

    A short-time Fourier transform (boxcar window) of the selected turn-by-turn positions
    is plotted as tune versus time (or versus turn).

    :param ax: Axis to plot onto
    :param libera_data: Instance of LiberaTBTData class
    :param xy: either 'x' or 'y'
    :param nperseg: number of turns per FFT segment
    :param noverlap: number of turns shared by consecutive segments;
        defaults to ``nperseg - nperseg // ninterpol``
    :param ninterpol: only used for the default ``noverlap`` (consecutive segments then
        start ``nperseg // ninterpol`` turns apart)
    :param smoothing: currently unused
    :param over_time: if True, plot against time with ``pcolormesh``; otherwise plot
        against turn number with ``imshow``
    :param colorbar: if True, add a colorbar for ``ax`` to the figure containing ``ax``
    :param turn_range: tuple of (start_turn, stop_turn) for range to plot
    :param time_range: tuple of (start_time, stop_time) in seconds for range to plot
        (mutually exclusive with turn_range)
    :param tune_range: tuple of (min_tune, max_tune) to restrict the plot to;
        the y limits default to (0, 0.5)
    :return: tuple (time or turn, tune, magnitude) of the plotted data; turn numbers
        count from the start of the selected range
    """
    # `import scipy` alone does not guarantee that the signal submodule is loaded
    import scipy.signal

    noverlap = (nperseg - nperseg//ninterpol) if noverlap is None else noverlap

    assert isinstance(libera_data, LiberaTBTData), f'Expected LiberaTBTData but got {type(libera_data)}'

    turn_range = turn_or_time_range(libera_data.t, turn_range, time_range)
    tbt_data = getattr(libera_data, xy)[turn_range]
    selected_time = libera_data.t[turn_range]

    tune, turn, value = scipy.signal.stft(tbt_data, fs=1, nperseg=nperseg, noverlap=noverlap, window='boxcar', boundary=None, padded=False)
    # Segment positions are indices into the *selected* data, so look them up in the
    # selected time stamps (not in the full time array, which would offset the time axis)
    time = selected_time[turn.astype(int)]
    mag = np.abs(value)
    # NOTE: the DC bin (tune 0) is not suppressed; choose tune_range to exclude it if it
    # dominates the colour scale.

    if tune_range is not None:
        mask = irng(tune, *tune_range)
        tune, mag = tune[mask], mag[mask, :]
    else:
        tune_range = (0, 0.5)

    if over_time:
        cm = ax.pcolormesh(time, tune, mag, shading='nearest',
                           cmap='gist_heat_r',
                           vmin=np.percentile(mag, 0.5), vmax=np.percentile(mag, 99.5),
        )
    else:
        # Row 0 of mag is the lowest tune; with the default 'upper' origin it is placed at
        # y=tune[0], and the ylim set below puts low tunes at the bottom.
        cm = ax.imshow(mag, extent=(turn[0], turn[-1], tune[-1], tune[0]), aspect='auto', rasterized=True,
                       cmap='plasma_r',
                       vmin=np.percentile(mag, 1), vmax=np.percentile(mag, 99),
        )
    if colorbar:
        ax.figure.colorbar(cm, label='FFT magnitude', ax=ax)
    ax.set(ylim=tune_range, ylabel=f'Tune $q_{xy}$') #  or $1-q_{xy}$')
    return time if over_time else turn, tune, mag