import numpy as np
import scipy.signal
from dataclasses import dataclass, field
import tarfile
import re

def avg(*data, n=100):
    """Downsample one or more arrays by averaging non-overlapping chunks of n samples

    Trailing samples that do not fill a complete chunk are discarded.
    Each input is converted with np.asanyarray, so plain sequences work and
    ndarray subclasses (e.g. masked arrays) are preserved.
    Multi-dimensional input is flattened by the reshape, so 1D input is expected.

    :param data: one or more array-likes to downsample
    :param n: number of consecutive samples averaged into one output value (>= 1)
    :return: list with one averaged array per input, each of length len(d) // n
    :raises ValueError: if n < 1
    """
    if n < 1:
        raise ValueError(f'n must be a positive integer, got {n}')
    arrays = [np.asanyarray(d) for d in data]
    # integer division: exact for any length (float division can round for huge arrays)
    return [d[:len(d) // n * n].reshape(-1, n).mean(axis=1) for d in arrays]
def irng(array, start, stop):
    """Return a slice selecting the part of ``array`` between two values

    The indices of the elements closest to ``start`` and to ``stop`` are looked up
    and put in ascending order. As usual for slices, the upper index is excluded,
    i.e. the element closest to the larger bound is not part of the selection.
    """
    first = np.argmin(np.abs(array - start))
    second = np.argmin(np.abs(array - stop))
    lo, hi = sorted((first, second))
    return slice(lo, hi)

class Trace:
    """Common base class of the measurement data containers in this module"""
    # NOTE(review): the class had no body at all (a syntax error); confirm that no
    # shared methods were lost from this base class.

# timing



# lassie spill data

from .bdiolib.bdio import bdio

@dataclass
class LassieSpillData(Trace):
    """Container for time series data from lassie spill or DCCT application

    t - time array in seconds
    v - value array in unit
    unit - unit of value array
    events - dict mapping events to time in seconds
    """
    t: np.ndarray
    v: np.ndarray
    unit: str = 'particles'
    events: dict = field(default_factory=dict)

    @classmethod
    def from_file(cls, fname, from_event=None, time_offset=0, verbose=0):
        """Read in time series data from a *.tdf file saved with lassiespill or DCCT application

        :param fname: path to filename
        :param from_event: return data from this event, time will also be relative to this event
        :param time_offset: constant added to the returned time array (in seconds)
        :param verbose: if truthy, print the parsed blocks and rows
        :return: instance of the class this method is called on
        :raises ValueError: if the file has no data block or no sampling frequency,
            if the sampling frequency has an unexpected unit, or if from_event is not in the file
        """
        b = bdio.BDIOReader(fname)

        appname = ''
        fs = None
        values = None
        unit = 'particles'
        events = {}
        for bi in b.get_directory():
            block = b.next_block()
            title = bi.get_title()
            if verbose: print(title, block)
            if bi.is_header_block():
                appname = block.get_app_name()
                device = block.get_device()
                if verbose: print(device, 'data saved from', appname)
            elif title == 'TraceInfo':
                freq_tag = 'intensity frequency' if appname == 'DCCT' else 'sampling frequency'
                for row in block.get_rows():
                    if verbose: print(row)
                    if row.get_tag() == freq_tag:
                        if row.get_unit().lower() != 'hz':
                            raise ValueError('sampling frequency has unexpected unit {}'.format(row.get_unit()))
                        fs = row.get_value()
            elif title == 'Integrals':
                for row in block.get_rows():
                    if verbose: print(row)
                    if row.get_tag() == 'Calibrated: beam in integral':
                        unit = row.get_unit()
            elif bi.is_double_array_block() and title == ('Intensity' if appname == 'DCCT' else 'Calibrated'):
                values = np.array(block.get_array())
            elif bi.is_event_table_block():
                rows = block.get_rows()
                # assume the first event is the start of data # TODO: is there a better way to do it?
                t0 = rows[0].get_time()
                for row in rows:
                    # event times are scaled by 1e-9 (presumably ns -> s)
                    events[row.get_event()] = (row.get_time() - t0)*1e-9

        if values is None:
            raise ValueError(f'No data block found in {fname}')
        if fs is None:
            raise ValueError(f'No sampling frequency found in {fname}')

        # sample k is taken at k/fs; np.linspace(0, N/fs, N) would stretch the spacing to N/((N-1)*fs)
        times = np.arange(len(values)) / fs
        if from_event is not None:
            if from_event not in events:
                raise ValueError(f'Requested data from event {from_event}, but event not found in data file')
            from_time = events[from_event]
            idx = np.argmin(np.abs(times - from_time))
            values = values[idx:]
            times = times[idx:] - times[idx]
            events = {e: t - from_time for e, t in events.items()}
        return cls(times + time_offset, values, unit=unit, events=events)

# IPM data

@dataclass
class IPMData(Trace):
    """Container for profile data from IPM

    t - time array in seconds
    w - positions of the profile channels (wires)
    w_unit - unit of displacement array values
    x - horizontal profile amplitudes, one row per profile
    y - vertical profile amplitudes, one row per profile
    unit - unit of profile amplitudes
    x_rms, y_rms - rms beam sizes as read from bwrms.dat
    beam_current, dipole, hf - scaler data as read from scl.dat
    """
    t: np.ndarray
    w: np.ndarray
    x: np.ndarray
    y: np.ndarray
    unit: str = 'a.u.'
    w_unit: str = 'mm'
    x_rms: np.ndarray = None
    y_rms: np.ndarray = None
    beam_current: np.ndarray = None
    dipole: np.ndarray = None
    hf: np.ndarray = None

    def apply_calibration(self, time_range):
        """Calibrate the profile data by subtracting the average profile measured in the given timespan

        Note that this does not affect rms beam sizes.
        The profiles are modified in place.

        :param time_range: tuple of (start, end) of calibration window in seconds
        :raises ValueError: if the window contains no profiles
        """
        window = irng(self.t, *time_range)
        if window.stop - window.start < 1:
            # an empty window would turn every profile into NaN
            raise ValueError(f'Calibration window {time_range} contains no profiles')
        self.x -= np.mean(self.x[window], axis=0)
        self.y -= np.mean(self.y[window], axis=0)

    @staticmethod
    def _event_offset(tar, from_event, standard_events, event_user, verbose=0):
        """Return the index of the first scaler sample at which from_event is flagged (0 if from_event is None)

        :param tar: open tar archive (needed to look up the user event in param.dat)
        :param from_event: requested event number or None
        :param standard_events: dict mapping standard event numbers to their flag columns
        :param event_user: flag column of the user event
        :param verbose: if truthy, print the content of param.dat
        :raises ValueError: if the event is unknown or not flagged in the data
        """
        if from_event is None:
            return 0
        if from_event in standard_events:
            flags = standard_events[from_event]
        else:
            # not a default event, maybe user event?
            ev = None
            for raw in tar.extractfile('param.dat'):
                # tar members yield bytes; decode before applying the str regex
                line = raw.decode('ascii', errors='replace').strip()
                if verbose: print(line)
                m = re.match(r'ev\s*=\s*(\d+)', line)
                if m:
                    ev = int(m.group(1))
            if from_event != ev:
                raise ValueError(f'Requested data from event {from_event} which is neither a standard event (40,45,51) nor the user event ({ev})')
            flags = event_user
        if np.max(flags) != 1:
            raise ValueError(f'Requested data from event {from_event}, but event {from_event} not found in data')
        return int(np.argmax(flags))

    @classmethod
    def from_file(cls, fname, clean=False, from_event=None, time_offset=0, verbose=0):
        """Read in profile data from a *.tgz file saved with IPM application

        :param fname: path to filename
        :param clean: use adc_clean.dat instead of adc.dat
        :param from_event: return data from this event, time will also be relative to this event.
            Supported are the standard events 40, 45, 51 and the user event configured in param.dat
        :param time_offset: constant added to the returned time array (in seconds)
        :param verbose: if truthy, print diagnostic output
        :return: instance of the class this method is called on
        :raises ValueError: if from_event is unknown or not found in the data

        References:
            Giacomini, Tino
                One beam profile consists of 64 values.
                The distance between two wires is 0.6 mm. The wire diameter is 1.5 mm.
                So the center to center distance of the wires is 2.1 mm.
            Handbook RGM_SIS
                At the start of the SIS cycle, event 32 occurs. The RGM reads this event and
                starts the measurement cycle (see "Measure cyclus").
                With the start of an SIS cycle the RGM is started as well and records data.
                When a predefined event occurs, the software stores the current profile number
                and the time in ms elapsed since spill start (so-called time stamp function).
                A beam profile is recorded every 10 ms.
        """
        with tarfile.open(fname, "r:gz") as tar:
            # profile data: column 1 (x) and 2 (y), 64 consecutive rows form one profile
            dat = np.loadtxt(tar.extractfile('adc_clean.dat' if clean else 'adc.dat'), skiprows=1, ndmin=2)
            n = dat.shape[0]//64*64
            if n != dat.shape[0]: print(f'WARNING: loaded incomplete data from {fname}')
            x = dat[:n, 1].reshape((-1, 64))
            y = dat[:n, 2].reshape((-1, 64))

            # rms beam sizes
            _, x_rms, y_rms = np.loadtxt(tar.extractfile('bwrms.dat'), skiprows=1, ndmin=2).T

            # scaler data
            #   File scl.dat contains the scalerdata.
            #   Col 0 = internal timing,
            #   Col 1 = DC (Current),
            #   Col 2 = Dipole signal,
            #   Col 3 = HF signal,
            #   Col 4 = Ev40 (MB Trigger)  / Injection
            #   Col 5 = Ev45 (Flattop) / Accel end
            #   Col 6 = Ev51 (Extraction End)
            #   Col 7 = EvXX see GUI->Parameter->Event
            scl = np.loadtxt(tar.extractfile('scl.dat'), skiprows=1, ndmin=2)
            _, beam_current, dipole, hf, event_40, event_45, event_51, event_user = scl.T

            # needs the open archive: the user event number is read from param.dat
            offset = cls._event_offset(tar, from_event, {40: event_40, 45: event_45, 51: event_51},
                                       event_user, verbose)

        w = 2.1 * (np.arange(64) - 31.5)  # wire positions, 2.1 mm center to center
        t = np.arange(x.shape[0]) * 10e-3  # one profile every 10 ms

        if verbose: print(f'Returning data with offset {offset} as requested by event {from_event}')
        x, y, t = x[offset:, :], y[offset:, :], t[offset:] - t[offset]
        x_rms, y_rms, beam_current, dipole, hf = [a[offset:] for a in (x_rms, y_rms, beam_current, dipole, hf)]
        # NOTE(review): one NaN is appended to each scaler array, presumably to match the number of profiles - confirm
        return cls(t + time_offset, w, x, y,
                   x_rms=x_rms,
                   y_rms=y_rms,
                   beam_current=np.concatenate((beam_current, [np.nan])),
                   dipole=np.concatenate((dipole, [np.nan])),
                   hf=np.concatenate((hf, [np.nan])))

# Libera data

@dataclass
class LiberaData(Trace):
    """Container for data from libera-ireg dump

    t - time array in seconds
    x - position array (X of the dump) in x_unit
    y - position array (Y of the dump) in y_unit
    s - signal array (S of the dump) in s_unit
    """
    t: np.ndarray
    x: np.ndarray
    y: np.ndarray
    s: np.ndarray
    x_unit: str = 'mm'
    y_unit: str = 'mm'
    s_unit: str = 'a.u.'
@dataclass
class LiberaBBBData(LiberaData):
    """Container for bunch-by-bunch data from libera-ireg dump"""

    @classmethod
    def from_file(cls, fname, time_offset=0, verbose=0):
        """Read in bunch-by-bunch data from a *.bin file saved with libera-ireg

        :param fname: path to filename
        :param time_offset: constant added to the returned time array (in seconds)
        :param verbose: currently unused
        :return: instance of the class this method is called on

        References: Libera_Hadron_User_Manual_1.04.pdf
        """
        # mode='r': the np.memmap default 'r+' would open the raw dump for writing
        bunch = np.memmap(fname, mode='r', dtype=np.dtype([
                ('S', '<i4'),
                ('r1', '<i4'),
                ('X', '<i4'),  # in nm
                ('Y', '<i4'),  # in nm
                ('TS', '<u8'), # in 4ns (ADC clock at 250MS/s)
                ('t2', '<u2'),
                ('t1', '<u2'),
                ('status', '<u4')]))
        time = (bunch['TS'] + bunch['t1'])*4e-9 # in s
        # copy S so the result does not keep a read-only view into the mapped file
        return cls(time + time_offset, bunch['X']*1e-6, bunch['Y']*1e-6, np.array(bunch['S']))

    def to_tbt_data(self, h, b=None):
        """Return the turn-by-turn data

        :param h: the harmonic number of the RF system (number of bunches per turn)
        :param b: the bunch index (0 <= b < h) or None for an average over all bunches of a turn
        :return: LiberaTBTData with the same units
        :raises ValueError: if b is outside 0 <= b < h
        """
        if b is None:
            def wrap(a):
                # average each group of h consecutive bunches; incomplete last turn is dropped
                return a[:(len(a)//h)*h].reshape(-1, h).mean(axis=1)
        else:
            if not 0 <= b < h:
                raise ValueError(f'Bunch index {b} out of range for harmonic number {h}')

            def wrap(a):
                return a[b::h]
        return LiberaTBTData(t=wrap(self.t), x=wrap(self.x), y=wrap(self.y), s=wrap(self.s),
                             x_unit=self.x_unit, y_unit=self.y_unit, s_unit=self.s_unit)
@dataclass
class LiberaTBTData(LiberaData):
    """Container for turn-by-turn data from libera-ireg dump"""

@dataclass
class NWAData(Trace):
    """Container for data from Network Analyser save

    f - frequency in Hz
    m - magnitude
    p - phase
    """
    f: np.ndarray
    m: np.ndarray
    p: np.ndarray
    f_unit: str = 'Hz'
    m_unit: str = 'a.u.'
    p_unit: str = 'a.u.'

    @classmethod
    def from_file(cls, magfile, phasefile=None, *, isdeg=True, unwrap=False, verbose=0):
        """Read in data from CSV file for magnitude and phase

        The first 3 lines of each file are skipped as header.

        :param magfile: path to filename with magnitude trace data
        :param phasefile: path to filename with phase trace data. If None, phase data is assumed to be the third column in magfile.
        :param unwrap: if true, relative phase is unwrapped to absolute phase centered around zero
        :param isdeg: if phase data is in degree (True) or radians (False)
        :param verbose: currently unused
        :return: instance of the class this method is called on
        :raises ValueError: if no phase data is available or the frequency columns of both files differ
        """
        f, m, *other = np.loadtxt(magfile, skiprows=3, delimiter=',', ndmin=2).T
        if phasefile is None:
            if not other:
                raise ValueError(f'No phase column found in {magfile}, please provide phasefile')
            p = other[0]
        else:
            fp, p, *_ = np.loadtxt(phasefile, skiprows=3, delimiter=',', ndmin=2).T
            # array_equal also handles files of different length (elementwise != would not)
            if not np.array_equal(f, fp):
                raise ValueError('Files provided do not have equal frequency components.')
        if unwrap:
            if isdeg: p = np.deg2rad(p)
            p = np.unwrap(p)
            p = p - np.mean(p)
            if isdeg: p = np.rad2deg(p)
        return cls(f, m, p, p_unit='deg' if isdeg else 'rad')