# jar_project/jar_functions.py
# last modified: 2020-08-26 18:13:49 +02:00
# 212 lines, 7.4 KiB, Python

import os  # compatibility with Windows (path handling)
from IPython import embed
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
def step_response(t, a1, a2, tau1, tau2):
    """Double-exponential step response, clamped to zero before the step.

    Parameters: t -- time array; a1, a2 -- amplitudes of the two
    exponential components; tau1, tau2 -- their time constants.
    Returns an array of the same shape as t.
    """
    first = a1 * (1 - np.exp(-t / tau1))
    second = a2 * (1 - np.exp(-t / tau2))
    response = first + second
    # no response before stimulus onset
    response[t < 0] = 0
    return response
def sin_response(t, f, p, A):
    """Sine wave with frequency f, phase p (radians) and amplitude A,
    evaluated at time(s) t."""
    return A * np.sin(2 * np.pi * f * t + p)
def parse_dataset(dataset_name):
    """Parse a relacs-style text dataset into metadata and per-trial traces.

    Metadata lines look like '# EODf: 800.0Hz'; the value is taken after the
    last ':' and the trailing unit ('Hz' -> 2 chars, 'sec' -> 3 chars) is cut
    off.  A line containing '#Key' starts a new trial; data lines hold three
    whitespace-separated columns: time, frequency, amplitude.

    Returns (frequencies, times, amplitudes, eodfs, deltafs, stimulusfs,
    duration, pause) where the first three are lists of numpy arrays (one per
    trial) and the rest are lists of floats.
    """
    assert(os.path.exists(dataset_name))  # see if data exists
    with open(dataset_name, 'r') as f:  # closed automatically, even on error
        lines = f.readlines()
    # metadata lists, one entry per loop/trial
    eodfs = []
    deltafs = []
    stimulusfs = []
    duration = []
    pause = []
    # data itself
    times = []
    frequencies = []
    amplitudes = []
    # temporary per-trial buffers, flushed at every '#Key' line
    time = []
    ampl = []
    freq = []
    for line in lines:
        l = line.strip()
        if "#" in l and "EODf" in l:
            eodfs.append(float(l.split(':')[-1].strip()[:-2]))  # drop 'Hz'
        if "#" in l and "Delta f" in l:
            deltafs.append(float(l.split(':')[-1].strip()[:-2]))
        if "#" in l and "StimulusFrequency" in l:
            stimulusfs.append(float(l.split(':')[-1].strip()[:-2]))
        if "#" in l and "Duration" in l:
            duration.append(float(l.split(':')[-1].strip()[:-3]))  # drop 'sec'
        if "#" in l and "Pause" in l:
            pause.append(float(l.split(':')[-1].strip()[:-3]))
        if '#Key' in l:
            if len(time) != 0:  # empty on the very first '#Key'
                times.append(np.array(time))
                amplitudes.append(np.array(ampl))
                frequencies.append(np.array(freq))
            # reset the per-trial buffers
            time = []
            ampl = []
            freq = []
        # data line: not empty and not a comment/metadata line.
        # BUG FIX: original used `l[0] is not '#'` -- an identity comparison
        # with a string literal that only works by CPython interning accident
        # (and is a SyntaxWarning on Python 3.8+); use `!=`.
        if len(l) > 0 and l[0] != '#':
            temporary = list(map(float, l.split()))
            time.append(temporary[0])
            freq.append(temporary[1])
            ampl.append(temporary[2])
    # flush the final trial (no trailing '#Key' after the last data block)
    times.append(np.array(time))
    amplitudes.append(np.array(ampl))
    frequencies.append(np.array(freq))
    return frequencies, times, amplitudes, eodfs, deltafs, stimulusfs, duration, pause
def parse_infodataset(dataset_name):
    """Extract fish identifier strings from an info file.

    Scans every line for '#' together with 'Identifier' and collects the
    text after the last ':' (whitespace stripped).

    Returns a list of identifier strings (possibly empty).
    """
    assert(os.path.exists(dataset_name))  # see if data exists
    with open(dataset_name, 'r') as f:  # closed automatically, even on error
        lines = f.readlines()
    identifier = []
    for line in lines:
        l = line.strip()
        if "#" in l and "Identifier" in l:
            identifier.append(l.split(':')[-1].strip())
    return identifier
def mean_traces(start, stop, timespan, frequencies, time):
    """Resample all trials onto a common time grid and average them.

    The grid step is timespan divided by the length of the shortest trial.
    Samples with frequency <= -5 are treated as invalid and dropped before
    interpolation (np.interp fills from the remaining points).

    Returns (mf, tnew): the mean trace and the common time axis.
    """
    shortest = min(len(trial) for trial in time)
    tnew = np.arange(start, stop, timespan / shortest)
    resampled = np.zeros((len(frequencies), len(tnew)))
    for row, (f, t) in enumerate(zip(frequencies, time)):
        valid = f > -5
        resampled[row, :] = np.interp(tnew, t[valid], f[valid])
    mf = resampled.mean(axis=0)
    return mf, tnew
def mean_noise_cut(frequencies, time, n):
    """Smooth `frequencies` with a running mean over windows of n samples.

    Each window mean is written at the window center; the leading n/2
    samples get the first window's mean and the trailing samples keep the
    last window's mean.  If len(frequencies) <= n the result is all zeros.
    `time` is accepted for signature compatibility but not used here.
    """
    cutf = np.zeros(len(frequencies))
    half = int(n / 2)
    for k in range(len(frequencies) - n):
        window_mean = np.mean(frequencies[k:k + n])
        center = k + half
        if k == 0:
            # fill the leading edge with the first window's mean
            cutf[:center] = window_mean
        # overwrite from the center onward; later iterations overwrite all
        # but the tail, which ends up holding the last window's mean
        cutf[center:] = window_mean
    return cutf
def norm_function(f, t, onset_point, offset_point):
    """Normalize each trace to its JAR response size.

    For every trace: subtract the median of the 10 s window just before
    onset_point (baseline), then divide by the median of the baseline-
    subtracted 10 s window just before offset_point (response plateau).

    Returns a list of normalized numpy arrays, one per trace.
    """
    onset_end = onset_point - 10
    offset_start = offset_point - 10
    norm = []
    for j, trace in enumerate(f):
        pre_window = (t[j] >= onset_end) & (t[j] < onset_point)
        baseline = np.median(trace[pre_window])
        shifted = trace - baseline
        plateau_window = (t[j] >= offset_start) & (t[j] < offset_point)
        plateau = np.median(shifted[plateau_window])
        norm.append(shifted / plateau)
    return norm
def base_eod(frequencies, time, onset_point):
    """Median EOD frequency in the 10 s window immediately before onset.

    Returns a one-element list containing the median (kept as a list for
    compatibility with the callers' expectations).
    """
    window = (time >= onset_point - 10) & (time < onset_point)
    return [np.median(frequencies[window])]
def JAR_eod(frequencies, time, offset_point):
    """Median EOD frequency in the 10 s window immediately before offset.

    Mirrors base_eod but for the end of the stimulus; returns a
    one-element list containing the median.
    """
    window = (time >= offset_point - 10) & (time < offset_point)
    return [np.median(frequencies[window])]
def sort_values(values):
    """Order fit parameters as [a1, a2, tau_small, tau_large].

    The first two entries (amplitudes) keep their order; the remaining
    entries (time constants) are sorted ascending.  Returns a flat
    numpy array.
    """
    amplitudes = np.asarray(values[:2])
    taus = np.sort(np.asarray(values[2:]))
    return np.concatenate((amplitudes, taus))
def average(freq_all, time_all, start, stop, timespan, dm):
    """Average all traces, plot the mean and its double-exponential fit.

    Fits step_response to the mean trace restricted to t < dm, plots both
    the average (dashed blue) and the fit (green) onto the current
    matplotlib axes, and prints the sorted fit parameters.

    Returns (mf_all, tnew_all, values_all): mean trace, time axis, and
    the fit parameters ordered by sort_values.
    """
    mf_all, tnew_all = mean_traces(start, stop, timespan, freq_all, time_all)
    plt.plot(tnew_all, mf_all, color='b', label='average', ls='dashed')
    # fit for average, restricted to times before dm
    mask = tnew_all < dm
    sv_all, sc_all = curve_fit(step_response, tnew_all[mask], mf_all[mask],
                               bounds=(0.0, np.inf))  # step_values and step_cov
    values_all = sort_values(sv_all)
    plt.plot(tnew_all[mask], step_response(tnew_all, *sv_all)[mask],
             color='g', lw=2,
             label='average_fit: a1=%.2f, a2=%.2f, tau1=%.2f, tau2=%.2f' % tuple(values_all))
    print('average: a1, a2, tau1, tau2', values_all)
    return mf_all, tnew_all, values_all
def import_data(dataset):
    """Load per-stimulus EOD traces and pre-stimulus windows from a nix file.

    Parameters: dataset -- path to a .nix recording.

    Returns (dat, pre_dat, dt): list of 'EOD-1' traces (one per multi-tag),
    list of windows of di samples preceding each stimulus onset, and the
    sampling interval dt.
    """
    import nixio as nix
    nf = nix.File.open(dataset, nix.FileMode.ReadOnly)
    try:
        b = nf.blocks[0]
        eod = b.data_arrays['EOD-1']
        dt = eod.dimensions[0].sampling_interval
        di = int(50.0 / dt)  # samples in a 50-unit window (presumably seconds -- TODO confirm)
        # NOTE(review): t/amfreq are unused below; kept so a recording without
        # the 'Beats_1' tag still raises exactly as before.
        t = b.tags['Beats_1']
        amfreq = t.metadata['RePro-Info']['settings']['amfreq']
        dat = []
        pre_dat = []
        for mt in b.multi_tags:
            data = mt.retrieve_data(0, 'EOD-1')[:]  # materialize as plain array
            dat.append(data)
            i0 = int(mt.positions[0][0] / dt)  # sample index of stimulus onset
            pre_data = eod[i0 - di:i0]  # window immediately before onset
            pre_dat.append(pre_data)
        return dat, pre_dat, dt
    finally:
        # BUG FIX: the original had `nf.close()` commented out AND placed
        # after `return` (unreachable), leaking the open nix file handle.
        nf.close()
def import_amfreq(dataset):
    """Read the AM-frequency setting ('amfreq') of the 'Beats_1' RePro
    from a nix recording and return it.

    Parameters: dataset -- path to a .nix recording.
    """
    import nixio as nix
    nf = nix.File.open(dataset, nix.FileMode.ReadOnly)
    try:
        b = nf.blocks[0]
        t = b.tags['Beats_1']
        amfreq = t.metadata['RePro-Info']['settings']['amfreq']
        return amfreq
    finally:
        # BUG FIX: the file handle was never closed in the original.
        # (Also dropped the unused eod/dt/di computations.)
        nf.close()
# Smoke test: load one example recording when this module is run as a script.
if __name__ == '__main__':
    import_data(os.path.join('JAR', '2020-07-21-ak', '2020-07-21-ak.nix'))