import rlxnix as rlx
import numpy as np
import matplotlib.pyplot as plt
import os
from scipy.signal import welch

# start with a clean slate: close every matplotlib figure that is still open
plt.close(fig='all')

'''FUNCTIONS'''
def vt_spikes(dataset, sam_index=2, t_max=0.1):
    '''Plots the voltage trace of one SAM run together with its spikes.

    Params
    ------
    dataset: rlx.Dataset
        opened dataset providing ``repro_runs('SAM')``
    sam_index: int
        index into ``dataset.repro_runs('SAM')`` of the run to plot
        (default 2, i.e. the third SAM run)
    t_max: float
        only samples and spikes with time < t_max are plotted
        (default 0.1; same time unit as the traces, presumably seconds)

    Returns
    -------
    sam:
        the selected SAM repro run
    '''
    sam = dataset.repro_runs('SAM')[sam_index]

    # voltage over time (V-t curve) and spike times of the selected run
    potential, time = sam.trace_data('V-1')
    spike_times, _ = sam.trace_data('Spikes-1')

    # restrict the plot to the first t_max of the recording
    in_window = time < t_max
    spikes_in_window = spike_times[spike_times < t_max]

    fig = plt.figure(figsize=(5, 2.5))
    ax = fig.add_subplot()
    ax.plot(time[in_window], potential[in_window])
    # mark every spike with a dot at the maximum voltage of the whole trace
    ax.scatter(spikes_in_window,
               np.full(len(spikes_in_window), np.max(potential)))
    ax.set_xlabel('Time [s]')
    ax.set_ylabel('Voltage')
    plt.show()

    return sam

def scatter_plot(sam1, title='Spikes of SAM 3'):
    '''Raster plot of the spike times of every stimulus (loop) of one SAM run.

    Params
    ------
    sam1:
        SAM repro run, e.g. the return value of vt_spikes
    title: str
        axes title; the default refers to the third SAM run

    Returns
    -------
    stim:
        the last stimulus of sam1
    stim_count: int
        number of stimuli in sam1
    time_stim: np.array
        time axis of the 'V-1' trace of the last stimulus

    Raises
    ------
    ValueError
        if sam1 contains no stimuli (there would be nothing to return)
    '''
    stim_count = sam1.stimulus_count
    if stim_count == 0:
        raise ValueError('SAM run contains no stimuli')

    # one color per stimulus / loop
    colors = plt.cm.prism(np.linspace(0, 1, stim_count))

    fig = plt.figure()
    ax = fig.add_subplot()

    stimuli = []
    for i in range(stim_count):
        # read from the run that was passed in, not from a global variable
        stim = sam1.stimuli[i]
        # only the time axis of the voltage trace is needed (it is returned)
        _, time_stim = stim.trace_data('V-1')
        spike_times_stim, _ = stim.trace_data('Spikes-1')
        stimuli.append(spike_times_stim)

    ax.eventplot(stimuli, colors=colors)
    # spike times are given in seconds, not milliseconds
    ax.set_xlabel('Spike Times [s]')
    ax.set_ylabel('Loop #')
    ax.set_yticks(range(stim_count))
    ax.set_title(title)
    plt.show()
    return stim, stim_count, time_stim

# create binary array with ones for spike times
def binary_spikes(spike_times, duration, dt):
    '''Converts spike times to binary representation

    Each spike time is mapped to its nearest sample index, ``round(t / dt)``.
    Spikes that fall outside ``[0, duration)`` are dropped. They would
    otherwise raise an IndexError, or wrap around silently for negative times.

    Params
    ------
    spike_times: np.array
        spike times, in the same time unit as `duration` and `dt`
    duration: float
        trial duration
    dt: float
        temporal resolution (sampling interval)

    Returns
    --------
    binary: np.array
        Float array of length ``round(duration / dt)``. It holds 1 at every
        sample containing a spike and 0 elsewhere.
    '''
    # Use true division plus rounding rather than floor division. With floats,
    # 1.0 // 0.1 == 9.0 and 0.3 // 0.1 == 2.0, which loses a sample and puts
    # spikes into the wrong bin.
    n_samples = int(np.round(duration / dt))
    binary = np.zeros(n_samples)
    spike_indices = np.round(np.asarray(spike_times, dtype=float) / dt).astype(int)
    # keep only spikes that lie inside the trial
    in_range = (spike_indices >= 0) & (spike_indices < n_samples)
    binary[spike_indices[in_range]] = 1
    return binary

# compute the firing rate (PSTH) by convolving a binary spike train with a box kernel
def firing_rates(binary_spikes, box_width=0.01, dt=0.000025):
    '''Estimates a firing rate by convolving a binary spike train with a box kernel.

    Params
    ------
    binary_spikes: np.array
        binary spike train (1 = spike), sampled with interval dt
    box_width: float
        width of the box kernel, same time unit as dt
    dt: float
        sampling interval of binary_spikes

    Returns
    -------
    rate: np.array
        firing rate in spikes per time unit. It has the same length as
        binary_spikes whenever that is at least as long as the kernel.
    '''
    # Round instead of floor-dividing, because float floor division can lose
    # a sample (0.3 // 0.1 == 2.0). Keep at least one sample so that the
    # kernel is never empty.
    n_box = max(1, int(np.round(box_width / dt)))
    # normalize the kernel so that it integrates to 1 (sum(box) * dt == 1)
    box = np.ones(n_box) / (n_box * dt)
    rate = np.convolve(binary_spikes, box, mode='same')
    return rate

def power_spectrum(rate, dt, nperseg=2**16, noverlap=None):
    '''Welch estimate of the power spectral density of a signal.

    Welch's method splits the signal into overlapping segments, computes a
    periodogram for each, and averages them. Averaging reduces the variance
    of the estimate compared with a single periodogram.

    Params
    ------
    rate: np.array
        signal to analyse (e.g. a binary spike train or a firing rate)
    dt: float
        sampling interval of `rate`; the sampling rate is 1/dt
    nperseg: int
        segment length in data points. It is clamped to the signal length,
        so that short signals still work.
    noverlap: int or None
        number of data points shared by consecutive segments
        (default: half the effective segment length)

    Returns
    -------
    f: np.array
        frequencies
    p: np.array
        power spectral density at each frequency
    '''
    # Without clamping, welch shrinks nperseg itself for short signals but keeps
    # noverlap = 2**15, and then raises "noverlap must be less than nperseg".
    nperseg = min(nperseg, np.shape(rate)[-1])
    if noverlap is None:
        noverlap = nperseg // 2
    f, p = welch(rate, fs=1. / dt, nperseg=nperseg, noverlap=noverlap)
    return f, p
    


'''IMPORT DATA'''
# data folder, relative to the current working directory
# (./ = current directory; ../ = one level up; ../../ = two levels up)
datafolder = '../data'

example_file = os.path.join(datafolder, '2024-10-16-ac-invivo-1.nix')

# open the recording as an rlxnix dataset
dataset = rlx.Dataset(example_file)

### plot
# timeline of the whole recording
dataset.plot_timeline()

# voltage and spikes of the selected SAM run
sam = vt_spikes(dataset)

# spike times of all loops
stim, stim_count, time_stim = scatter_plot(sam)


'''POWER SPECTRUM'''
# Inputs for the binary spike-train conversion, all taken from the last
# stimulus returned by scatter_plot.
spikes, _ = stim.trace_data('Spikes-1')
ti = stim.trace_info('V-1')
dt, duration = ti.sampling_interval, stim.duration

### spectrum
# binary spike train (1 = spike) covering the whole stimulus
binary = binary_spikes(spikes, duration, dt)

# firing rate, smoothed with a 10 ms box kernel
rate = firing_rates(binary, 0.01, dt)

# optional: plot the PSTH
# plt.plot(time_stim, rate)
# plt.show()

# spectrum of the binary spike train itself (not of the smoothed rate)
freq, power = power_spectrum(binary, dt)

# power spectrum, shown up to 1000 Hz
fig, ax = plt.subplots()
ax.plot(freq, power)
ax.set(xlabel='Frequency [Hz]', ylabel='Power [1/Hz]', xlim=(0, 1000))
plt.show()

# EODF and deltaf from the stimulus metadata; stimulus frequency = deltaf + EODF
metadata = stim.metadata
eodf = metadata[stim.name]['EODF'][0][0]
df = metadata['RePro-Info']['settings']['deltaf'][0][0]
stimulus_freq = df + eodf


### TODO:
    # then loop over sams/dfs, all stims, intensities
    # when / at which phase of the EOD does the stimulus start, and how does that influence our signal? --> alignment problem: irrelevant once we work with spectra
    # we want to see peaks at phase locking to own and stim frequency, and at amp modulation frequency
    # clean up current code (define variables outside of functions, plot spectrum in function)
    # git
    # tuning curve over stim intensities or over delta f?