P-unit_model/introduction/introductionBaseline.py
2019-12-20 13:33:34 +01:00

284 lines
8.0 KiB
Python

import pyrelacs.DataLoader as dl
import numpy as np
import matplotlib.pyplot as plt
from IPython import embed
import os
import helperFunctions as hf
from thunderfish.eventdetection import detect_peaks
# Path prefix that the plotting functions put in front of figure file names.
SAVEPATH = ""


def get_savepath():
    """Return the currently configured figure path prefix."""
    # reading a module-level name needs no ``global`` declaration
    return SAVEPATH


def set_savepath(new_path):
    """Replace the figure path prefix with ``new_path``."""
    global SAVEPATH
    SAVEPATH = new_path
def main():
    """Run the baseline analysis for every sub-folder of ``data/``.

    For each folder the first data column of every block in
    ``basespikes1.dat`` is collected as the spike times of one trial, an ISI
    histogram is saved and the vector strength is calculated from the
    'BaselineActivity' traces (which also produces the phase-locking plots).
    Figures go to ``figures/<second path component of the folder>/``.

    Folders from which no spike data could be loaded are reported and skipped.
    """
    for folder in hf.get_subfolder_paths("data/"):
        filepath = folder + "/basespikes1.dat"
        set_savepath("figures/" + folder.split('/')[1] + "/")
        print("Folder:", folder)

        # exist_ok makes a separate (race-prone) existence check unnecessary
        os.makedirs(get_savepath(), exist_ok=True)

        # one entry per loaded block; kept per trial for the vector strength
        spiketimes = [data[:, 0] for _, _, data in dl.iload(filepath)]

        if not spiketimes:
            print("------------ DIDN'T RUN")
            # without any spikes the histogram and the vector strength
            # cannot be computed, so go on with the next folder
            continue

        isi_histogram(spiketimes)

        # presumably trace 2 is the EOD and trace 1 the membrane voltage
        # (taken from the variable names) - confirm against hf.get_traces
        times, eods = hf.get_traces(folder, 2, 'BaselineActivity')
        times, v1s = hf.get_traces(folder, 1, 'BaselineActivity')

        calculate_vector_strength(times, eods, spiketimes, v1s)
def mean_firing_rate(spiketimes):
    """Return the mean firing rate: number of spikes per elapsed time.

    The spike count is divided by the last spike time and multiplied by 1000,
    which yields Hz if the spike times are given in ms (assumed from the
    factor 1000 - confirm against the data source).

    :param spiketimes: sequence of spike times in ascending order.
    :return: the mean firing rate, or 0 if ``spiketimes`` is empty.
    """
    # len() instead of truthiness so that numpy arrays are handled as well
    if len(spiketimes) == 0:
        return 0
    return len(spiketimes) / spiketimes[-1] * 1000
def calculate_coefficient_of_variation(spiketimes):
    """Return the coefficient of variation (CV) of the interspike intervals.

    The CV is the standard deviation of the ISIs (differences between
    successive spike times) divided by their mean.
    """
    intervals = np.diff(spiketimes)
    return np.std(intervals) / np.mean(intervals)
def isi_histogram(spiketimes, bin_width=0.1):
    """Save a histogram of the interspike intervals pooled over all trials.

    :param spiketimes: iterable of spike time sequences, one per trial. The
        intervals are computed per trial, so no interval spans two trials.
    :param bin_width: width of the histogram bins in the unit of the spike
        times (the x-axis label states ms). Defaults to the previously
        hard-coded 0.1.

    The figure is written to
    ``get_savepath() + 'phase_locking_without_stimulus.png'``. If there is not
    a single interval, a message is printed and nothing is plotted.
    """
    isi = []
    for spike_list in spiketimes:
        isi.extend(np.diff(spike_list))

    if not isi:
        # max() of an empty list would raise a ValueError
        print("No interspike intervals found - skipping ISI histogram.")
        return

    # Enough bins that the last edge lies at or above the largest ISI, so the
    # longest intervals are not left out; edges are multiples of bin_width.
    n_bins = max(int(np.ceil(max(isi) / bin_width)), 0) + 1
    bins = np.arange(n_bins + 1) * bin_width

    plt.title('Phase locking of ISI without stimulus')
    plt.xlabel('ISI in ms')
    plt.ylabel('Count')
    plt.hist(isi, bins=bins)
    plt.savefig(get_savepath() + 'phase_locking_without_stimulus.png')
    plt.close()
def calculate_vector_strength(times, eods, spiketimes, v1s):
    """Return the vector strength of the spikes pooled over all recordings.

    For every recording the spike times relative to the start of their EOD
    cycle and the cycle durations are determined. Per recording a polar phase
    histogram is saved, the vector strength is printed and the phase-locking
    test figure is shown.

    :param times: one time axis per recording.
    :param eods: one EOD trace per recording.
    :param spiketimes: one spike time sequence per recording.
    :param v1s: one voltage trace per recording (only used for the test figure).
    """
    # Vector strength (use EOD frequency from header (metadata)) VS > 0.8
    # dl.iload_traces(repro='BaselineActivity')
    if len(times) == 0:
        print("-----LENGTH OF TIMES = 0")

    pooled_relative_times = []
    pooled_durations = []
    for recording, time in enumerate(times):
        eod = eods[recording]
        spikes = spiketimes[recording]

        rel_spikes, eod_durs = eods_around_spikes(time, eod, spikes)
        pooled_relative_times.extend(rel_spikes)
        pooled_durations.extend(eod_durs)

        vs = __vector_strength__(rel_spikes, eod_durs)
        phases = calculate_phases(rel_spikes, eod_durs)
        figure_name = "test_phase_locking_" + str(recording) + "_with_vs:" + str(round(vs, 3)) + ".png"
        plot_polar(phases, figure_name)

        print("VS of recording", recording, ":", vs)

        plot_phaselocking_testfigures(time, eod, spikes, v1s[recording])

    return __vector_strength__(pooled_relative_times, pooled_durations)
def eods_around_spikes(time, eod, spiketimes, sampling_rate=20000):
    """Locate every spike within the EOD cycle it falls into.

    :param time: time axis of the recording in seconds.
    :param eod: EOD trace sampled on ``time``.
    :param spiketimes: spike time stamps in ms.
    :param sampling_rate: sampling rate of the traces in Hz. Defaults to the
        previously hard-coded 20 kHz.
    :return: two lists of equal length: the spike times relative to the start
        of their EOD cycle, and the durations of those cycles (both in s).

    Spikes whose time stamp does not correspond to a sample index, or whose
    EOD cycle is cut off at the start or end of the recording, are reported
    and skipped.
    """
    samples_per_ms = sampling_rate / 1000
    # Tolerance in samples: time stamp * rate is only an integer up to
    # floating point rounding, so an exact comparison drops valid spikes.
    tolerance = 1e-6

    eod_durations = []
    relative_spike_times = []
    for spike in spiketimes:
        index = spike * samples_per_ms
        rounded = np.round(index)
        # "not <=" so that nan/inf time stamps are skipped as well
        if not abs(index - rounded) <= tolerance:
            print("INDEX NOT AN INTEGER in eods_around_spikes! index:", index)
            continue

        try:
            start_time, end_time = search_eod_start_and_end_times(time, eod, int(rounded))
        except IndexError:
            # the search ran past the recording: spike lies in a cut off EOD
            print("NO COMPLETE EOD AROUND SPIKE in eods_around_spikes! index:", int(rounded))
            continue

        eod_durations.append(end_time - start_time)
        relative_spike_times.append(spike / 1000 - start_time)

    return relative_spike_times, eod_durations
def _upward_crossing_time(time, eod, lower, upper):
    """Linearly interpolate the time at which eod crosses zero between two samples."""
    first_value = eod[lower]
    second_value = eod[upper]
    dif = second_value - first_value
    # fraction of the sample interval that lies before the zero crossing
    part = np.abs(first_value / dif)
    time_dif = np.abs(time[upper] - time[lower])
    return time[lower] + time_dif * part


def search_eod_start_and_end_times(time, eod, index):
    """Return start and end time of the EOD cycle around sample ``index``.

    A cycle is delimited by upward zero crossings of ``eod`` (a negative
    sample followed by a sample >= 0). The start is the last such crossing at
    or before ``index``, the end the first one after it. The crossing times
    are linearly interpolated between the two neighbouring samples.

    :param time: time axis belonging to ``eod``.
    :param eod: EOD trace.
    :param index: sample index of the event (e.g. a spike).
    :return: tuple ``(start_time, end_time)``.
    :raises IndexError: if ``index`` is outside the trace or the trace starts
        or ends before the respective crossing is reached, i.e. the event
        lies in a cut off first or last EOD cycle.
    """
    n_samples = len(eod)
    if not 0 <= index < n_samples:
        raise IndexError("index %d is outside of the EOD trace" % index)

    # search start_time: walk backwards, but stop at the first sample so that
    # negative indices cannot wrap around to the end of the recording
    upper = index
    while upper > 0 and not eod[upper - 1] < 0 <= eod[upper]:
        upper -= 1
    if upper == 0:
        raise IndexError("no EOD zero crossing before sample %d" % index)
    start_time = _upward_crossing_time(time, eod, upper - 1, upper)

    # search end_time: walk forwards up to the last pair of samples
    lower = index
    while lower < n_samples - 1 and not eod[lower] < 0 <= eod[lower + 1]:
        lower += 1
    if lower == n_samples - 1:
        raise IndexError("no EOD zero crossing after sample %d" % index)
    # interpolate from the sample *before* the crossing, as for the start
    end_time = _upward_crossing_time(time, eod, lower, lower + 1)

    return start_time, end_time
def search_closest_index(array, value, start=0, end=-1):
    """Return the index of the entry of ``array`` that is closest to ``value``.

    Bisects the sorted ``array`` between the indices ``start`` and ``end``
    (``start`` has to be smaller than ``end``; ``end=-1`` stands for the last
    index). An exact match is returned as soon as it is hit. If both remaining
    candidates are equally close, the lower index wins.
    """
    low = start
    high = len(array) - 1 if end == -1 else end

    while high - low > 1:
        middle = int(low + (high - low) // 2)
        candidate = array[middle]
        if candidate == value:
            return middle
        if candidate > value:
            high = middle
        else:
            low = middle

    # at most two candidates are left: take the closer one
    if np.abs(array[high] - value) < np.abs(array[low] - value):
        return high
    return low
def __vector_strength__(relative_spike_times, eod_durations):
    """Return the vector strength of spikes relative to their EOD cycles.

    Every spike is turned into a phase (relative time / cycle duration * 2 pi);
    the result is the length of the mean of the corresponding unit vectors.
    Returns 0 if no spikes are given.
    """
    # adapted from Ramona
    count = len(relative_spike_times)
    if count == 0:
        return 0

    phases = np.array(
        [(relative_spike_times[k] / eod_durations[k]) * 2 * np.pi for k in range(count)],
        dtype=float,
    )
    mean_cos = 1 / count * sum(np.cos(phases))
    mean_sin = 1 / count * sum(np.sin(phases))
    return np.sqrt(mean_cos ** 2 + mean_sin ** 2)
def calculate_phases(relative_spike_times, eod_durations):
    """Return the phase of every spike within its EOD cycle as a numpy array.

    The phase is the relative spike time divided by the duration of the
    cycle, scaled to the range of a full circle (2 pi).
    """
    count = len(relative_spike_times)
    return np.array(
        [(relative_spike_times[k] / eod_durations[k]) * 2 * np.pi for k in range(count)],
        dtype=float,
    )
def plot_polar(phases, name=""):
    """Plot a polar histogram of the given phases.

    :param phases: phases in radians.
    :param name: file name of the figure. It is saved to
        ``get_savepath() + name``; with the default empty name the figure is
        shown instead.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, polar=True)

    # 126 bins of ~0.05 rad whose last edge is exactly 2 pi. The former
    # np.arange(0, 2 * np.pi, 0.05) ended at 6.25, so phases between 6.25 and
    # 2 pi were missing from the histogram.
    bins = np.linspace(0, 2 * np.pi, 127)
    ax.hist(phases, bins=bins)

    if name == "":
        plt.show()
    else:
        fig.savefig(get_savepath() + name)

    plt.close(fig)
def plot_phaselocking_testfigures(time, eod, spiketimes, v1, cutoff_in_sec=2, sampling=20000):
    """Show a test figure of the spikes and the EOD cycle borders found for them.

    Plots the first ``cutoff_in_sec`` seconds of ``v1``, the spikes within
    that window (at y = -20) and the start and end times of the EOD cycles
    found around these spikes (at y = 0).

    :param time: time axis of the recording in seconds.
    :param eod: EOD trace sampled on ``time``.
    :param spiketimes: spike time stamps in ms.
    :param v1: voltage trace sampled on ``time``.
    :param cutoff_in_sec: length of the plotted window in seconds.
    :param sampling: sampling rate of the traces in Hz.

    Spikes for which no EOD cycle can be determined are still plotted, but
    get no cycle markers.
    """
    max_idx = int(cutoff_in_sec * sampling)
    samples_per_ms = sampling / 1000
    # Tolerance in samples: time stamp * rate is only an integer up to
    # floating point rounding, so an exact comparison drops valid spikes.
    tolerance = 1e-6

    spikes_part = []
    eod_start_times = []
    eod_end_times = []
    for spike in spiketimes:
        # only spikes inside the plotted window are of interest
        if not spike / 1000 < cutoff_in_sec:
            continue
        spikes_part.append(spike / 1000)

        index = spike * samples_per_ms
        rounded = np.round(index)
        if not abs(index - rounded) <= tolerance:
            print("INDEX NOT AN INTEGER in plot_phaselocking_testfigures! index:", index)
            continue

        try:
            start_time, end_time = search_eod_start_and_end_times(time, eod, int(rounded))
        except IndexError:
            # the search ran past the recording: spike lies in a cut off EOD
            print("NO COMPLETE EOD AROUND SPIKE in plot_phaselocking_testfigures! index:", int(rounded))
            continue

        eod_start_times.append(start_time)
        eod_end_times.append(end_time)

    print(spiketimes)
    print(len(spikes_part))

    x_axis = time[0:max_idx]
    plt.plot(spikes_part, np.ones(len(spikes_part)) * -20, 'o')
    plt.plot(x_axis, v1[0:max_idx])
    # y values sized by the markers actually found, so skipped spikes cannot
    # cause an x/y length mismatch
    plt.plot(eod_start_times, np.zeros(len(eod_start_times)), 'o')
    plt.plot(eod_end_times, np.zeros(len(eod_end_times)), 'o')
    plt.show()
    plt.close()
# Run the analysis only when this file is executed as a script.
if __name__ == "__main__":
    main()