# P-unit_model/redetect_fi_curve.py

from CellData import CellData
from DataParserFactory import DatParser
import pyrelacs.DataLoader as Dl
import helperFunctions as hF
from thunderfish.eventdetection import detect_peaks
import os
import numpy as np
import matplotlib.pyplot as plt

TEST_SIMILARITY = True
REDETECT_SPIKES = True
TOP_PERCENTILE = 95
BOTTOM_PERCENTILE = 5
FACTOR = 0.5

# strange_cells:
# 2012-07-12-ap-invivo-1 # cell with a few traces with max similarity < 0.1
# 2012-12-13-af-invivo-1 # cell with MANY traces with max similarity < 0.1
# 2012-12-21-ak-invivo-1 # a few
# 2012-12-21-an-invivo-1 # a few
# 2013-02-21-ae-invivo-1 # "
# 2013-02-21-ag-invivo-1 # "
# 2014-06-06-ac-invivo-1 # alot below 0.4 but a good bit above the 2nd max
def main():
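    """Redetect spikes in all FICurve traces and match them to the saved spike trains.

    As committed, main() only runs test_fi_trace() and quits; the pipeline below
    is reached by removing the quit() call. For every cell in data/final/ it then
    loads (or recomputes) the redetected spikes and the time/voltage traces,
    assigns each trace to a contrast via find_matching_spiketrain(), and caches
    all results as .npy files in the cell directory.
    """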
    test_fi_trace()
    quit()
    # find_and_save_best_threshold()
    # quit()
    directory = "data/final/"
    skip_to = False
    skip_to_cell = "2012-12-13-af-invivo-1"
    threshold_file_path = "data/fi_thresholds.tsv"
    thresholds_dict = load_fi_thresholds(threshold_file_path)

    for cell in sorted(os.listdir(directory)):
        # if cell != "2014-01-10-ab-invivo-1":
        #     continue
        if skip_to:
            if cell == skip_to_cell:
                skip_to = False
            else:
                continue
        cell_dir = directory + cell  # e.g. "data/final/2012-04-20-af-invivo-1/"
        print(cell_dir)
        cell_data = CellData(cell_dir)
        before = cell_data.get_delay()
        after = cell_data.get_after_stimulus_duration()
        # parser = DatParser(cell_dir)

        if os.path.exists(cell_dir + "/redetected_spikes.npy") and not REDETECT_SPIKES:
            spikes = np.load(cell_dir + "/redetected_spikes.npy", allow_pickle=True)
            traces = np.load(cell_dir + "/fi_time_v1_traces.npy", allow_pickle=True)
        else:
            step = cell_data.get_sampling_interval()
            threshold_pair = thresholds_dict[cell]
            spikes, traces = get_redetected_spikes(cell_dir, before, after, step, threshold_pair)
            np.save(cell_dir + "/redetected_spikes.npy", spikes, allow_pickle=True)
            np.save(cell_dir + "/fi_time_v1_traces.npy", traces, allow_pickle=True)
        print("redetection finished")

        if os.path.exists(cell_dir + "/fi_traces_contrasts.npy") and not TEST_SIMILARITY:
            trace_contrasts = np.load(cell_dir + "/fi_traces_contrasts.npy", allow_pickle=True)
            trace_max_similarity = np.load(cell_dir + "/fi_traces_contrasts_similarity.npy", allow_pickle=True)
        else:
            cell_spiketrains = cell_data.get_fi_spiketimes()
            # plt.plot(traces[0][0], traces[0][1])
            # plt.eventplot(cell_spiketrains[0][0], colors="black", lineoffsets=max(traces[0][1]) + 1)
            # plt.eventplot(spikes[0], colors="black", lineoffsets=max(traces[0][1]) + 2)
            # plt.show()
            # plt.close()
            # unsorted_cell_spiketimes = get_unsorted_spiketimes(cell_dir + "/fispikes1.dat")

            # np.int is deprecated; the builtin int dtype is equivalent here
            trace_contrasts = np.zeros(len(traces), dtype=int) - 1
            trace_max_similarity = np.zeros((len(traces), 2)) - 1
            for i, spiketrain in enumerate(spikes):
                similarity, max_idx, maxima = find_matching_spiketrain(spiketrain, cell_spiketrains,
                                                                       cell_data.get_sampling_interval())
                trace_contrasts[i] = max_idx[0]
                trace_max_similarity[i] = maxima
                # if trace_max_similarity[i] <= 0.05:
                #     step = cell_data.get_sampling_interval()
                #     test_detected_spiketimes(traces[i], spiketrain, cell_spiketrains[max_idx[0]], step)
            np.save(cell_dir + "/fi_traces_contrasts.npy", trace_contrasts, allow_pickle=True)
            np.save(cell_dir + "/fi_traces_contrasts_similarity.npy", trace_max_similarity, allow_pickle=True)
        print("similarity test finished")

        # step_size = cell_data.get_sampling_interval()
        # steps = np.arange(0, 100.1, 0.5)
        # percentiles_arr = np.zeros((len(traces), len(steps)))
        # for i, trace_pair in enumerate(traces):
        #     v1_part = trace_pair[1][-int(np.rint(0.6/step_size)):]
        #     percentiles = np.percentile(np.array(v1_part) - np.median(v1_part), steps)
        #     percentiles_arr[i, :] = percentiles
        #     plt.plot(steps, percentiles)
        # mean_perc = np.mean(percentiles_arr, axis=0)
        # plt.plot(steps, mean_perc)
        # plt.show()
        # plt.close()

        # bins = np.arange(0, 1.001, 0.05)
        # plt.hist(trace_max_similarity, bins=bins)
        # plt.show()
        # plt.close()

        # step_size = cell_data.get_sampling_interval()
        # cell_spiketrains = cell_data.get_fi_spiketimes()
        # contrasts = cell_data.get_fi_contrasts()
        # tested_contrasts = []
        # for i, redetected in enumerate(spikes):
        #     idx = trace_contrasts[i]
        #     if idx not in tested_contrasts:
        #         print("Contrast: {:.3f}".format(contrasts[idx]))
        #         test_detected_spiketimes(traces[i], redetected, cell_spiketrains[idx], step_size)
        #         tested_contrasts.append(idx)


def test_fi_trace():
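    """Summarize how well the redetected spike trains match the saved spike trains.

    For every cell this loads the cached redetection results and counts, per
    contrast, the trials whose best-match similarity clearly exceeds the
    second-best match (by more than 0.15), then reports cells that end up with
    too few well-identified trials per contrast.
    """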
# cell = "2012-12-13-af-invivo-1"
# cell = "2012-07-12-ap-invivo-1"
data_dir = "data/final/"
full_count = 0
contrast_trials_below_three = 0
differences_max_second_max = []
for cell in sorted(os.listdir(data_dir)):
cell_dir = data_dir + cell
# print(cell)
cell_data = CellData(cell_dir)
step_size = cell_data.get_sampling_interval()
spiketimes = cell_data.get_fi_spiketimes()
# trials = [len(x) for x in spiketimes]
# total = sum(trials)
spikes = np.load(cell_dir + "/redetected_spikes.npy", allow_pickle=True)
# print("Cell data total: {} vs {} # traces".format(total, len(spikes)))
traces = np.load(cell_dir + "/fi_time_v1_traces.npy", allow_pickle=True)
trace_contrasts = np.load(cell_dir + "/fi_traces_contrasts.npy", allow_pickle=True)
trace_max_similarity = np.load(cell_dir + "/fi_traces_contrasts_similarity.npy", allow_pickle=True)
count_good = 0
count_bad = 0
threshold_file_path = "data/fi_thresholds.tsv"
# thresholds_dict = load_fi_thresholds(threshold_file_path)
# spikes, traces = get_redetected_spikes(cell_dir, 0.2, 0.8, cell_data.get_sampling_interval(), thresholds_dict[cell])
# print("No preduration:", len(traces))
contrast_trials = {}
for i in range(len(traces)):
differences_max_second_max.append((trace_max_similarity[i][0] - trace_max_similarity[i][1])/ trace_max_similarity[i][0])
if trace_max_similarity[i][0] > trace_max_similarity[i][1] + 0.15 and trace_max_similarity[i][0] < trace_max_similarity[i][1] + 0.2:
print("max sim: {:.2f}, {:.2f}".format(trace_max_similarity[i][0], trace_max_similarity[i][1]))
if trace_max_similarity[i][0] > trace_max_similarity[i][1] + 0.15:
count_good += 1
if trace_contrasts[i] not in contrast_trials:
contrast_trials[trace_contrasts[i]] = 0
contrast_trials[trace_contrasts[i]] += 1
continue
count_bad += 1
# count_bad += 1
# event_offset = max(traces[i][1]) + 0.5
# fig, axes = plt.subplots(2, 1, sharex="all")
# axes[0].plot(traces[i][0], traces[i][1])
# axes[0].eventplot(spikes[i], lineoffsets=event_offset, colors="black")
#
# similarity, max_idx, maxima = find_matching_spiketrain(spikes[i], spiketimes, step_size)
# axes[0].eventplot(spiketimes[max_idx[0]][max_idx[1]], lineoffsets=event_offset + 1, colors="orange")
#
# # for o, st in enumerate(spiketimes[trace_contrasts[i]]):
# # axes[0].eventplot(st, lineoffsets=event_offset + 1 + o*1, colors="orange")
#
# time, v1, eod, local_eod, stimulus = get_ith_trace(cell_dir, i)
# axes[1].plot(time, local_eod)
#
# plt.show()
# plt.close()
# t, f = hF.calculate_time_and_frequency_trace(spikes[-1], cell_data.get_sampling_interval())
# plt.plot(t, f)
# plt.eventplot(spikes[-1], lineoffsets=max(traces[-1][1]) + 0.5)
# plt.show()
# plt.close()
if count_bad > 0:
over_seven = 0
below_three = 0
for key in contrast_trials.keys():
if contrast_trials[key] >= 7:
over_seven += 1
if contrast_trials[key] < 3:
below_three += 1
if over_seven < 7:
full_count += 1
print(cell)
print(contrast_trials)
print("good:", count_good, "bad:", count_bad)
if below_three > 1:
contrast_trials_below_three += 1
# print("good:", count_good, "bad:", count_bad)
print("Cells less than 7 trials in seven contrasts:", full_count)
print("Cells less than 3 trials in a contrast:", contrast_trials_below_three)
def get_ith_trace(cell_dir, i):
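    """Return the i-th FICurve trace (time, v1, eod, local_eod, stimulus) of a cell,
    counting only traces recorded without a pre-intensity (preduration == 0).

    Implicitly returns None if fewer than i such traces exist.
    """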
    count = 0
    for info, key, time, x in Dl.iload_traces(cell_dir, repro="FICurve", before=0.2, after=0.8):
        if '----- Control --------------------------------------------------------' in info[0].keys():
            pre_duration = float(
                info[0]["----- Pre-Intensities ------------------------------------------------"]["preduration"][:-2])
            if pre_duration != 0:
                continue
        elif "preduration" in info[0].keys():
            pre_duration = float(info[0]["preduration"][:-2])
            if pre_duration != 0:
                continue
        elif len(info) == 2 and "preduration" in info[1].keys():
            pre_duration = float(info[1]["preduration"][:-2])
            if pre_duration != 0:
                continue
        if count < i:
            count += 1
            continue
        # print(count)
        # time, v1, eod, local_eod, stimulus
        # print(info)
        # print(key)
        v1 = x[0]
        eod = x[1]
        local_eod = x[2]
        stimulus = x[3]
        return time, v1, eod, local_eod, stimulus


def load_fi_thresholds(threshold_file_path):
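    """Read the per-cell threshold percentiles from a tab-separated file.

    Each line is expected as: cell_name <tab> bottom_percentile <tab> top_percentile.
    Returns a dict mapping cell_name to [bottom_percentile, top_percentile];
    the dict is empty if the file does not exist.
    """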
    thresholds_dict = {}
    if os.path.exists(threshold_file_path):
        with open(threshold_file_path, "r") as threshold_file:
            for line in threshold_file:
                line = line.strip()
                line = line.split('\t')
                name = line[0]
                bottom_percentile = float(line[1])
                top_percentile = float(line[2])
                thresholds_dict[name] = [bottom_percentile, top_percentile]
                # print("Already done:", name)
    return thresholds_dict


def find_and_save_best_threshold():
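    """Interactively choose spike-detection threshold percentiles for every cell.

    For each cell a figure with candidate thresholds is shown. The user can
    accept a preset ('ok <index>'), test a custom bottom percentile (or a
    'bottom, top' pair), or type 'stop' to quit. Choices are written to
    data/fi_thresholds.tsv after every cell so progress is never lost.
    """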
    base_path = "data/final/"
    threshold_file_path = "data/fi_thresholds.tsv"
    re_choose_thresholds = False
    thresholds_dict = load_fi_thresholds(threshold_file_path)

    count = 0
    for item in sorted(os.listdir(base_path)):
        if item in thresholds_dict.keys() and not re_choose_thresholds:
            continue
        count += 1
    print("cells to do:", count)

    for item in sorted(os.listdir(base_path)):
        if item in thresholds_dict.keys() and not re_choose_thresholds and not thresholds_dict[item][0] < 10:
            print("Already done:", item)
            continue
        cell_dir = base_path + item
        # starting assumptions:
        standard_top_percentile = 95
        threshold_pairs = [(40, 95), (50, 95), (60, 95)]
        colors = ["blue", "orange", "red"]

        if "thresholds" in item:
            continue
        print(item)
        item_path = base_path + item
        cell_data = CellData(item_path)
        step_size = cell_data.get_sampling_interval()
        trace_pairs = np.load(cell_dir + "/fi_time_v1_traces.npy", allow_pickle=True)
        trace_contrasts = np.load(cell_dir + "/fi_traces_contrasts.npy", allow_pickle=True)
        trace_max_similarity = np.load(cell_dir + "/fi_traces_contrasts_similarity.npy", allow_pickle=True)

        # pick one example trace per contrast, sorted by contrast index
        example_trace_pairs = []
        example_contrasts = []
        for i, trace_pair in enumerate(trace_pairs):
            if trace_contrasts[i] not in example_contrasts:
                example_contrasts.append(trace_contrasts[i])
                example_trace_pairs.append(trace_pair)
        example_contrasts, example_trace_pairs = zip(*sorted(zip(example_contrasts, example_trace_pairs)))

        stop = False
        print("Thresholds are:\n ")
        for i in range(len(threshold_pairs)):
            print("{}: {} - {}".format(i, colors[i], threshold_pairs[i]))
        plot_test_thresholds(example_trace_pairs, threshold_pairs, colors, step_size)
        response = input("Choose: 'ok', 'stop', or a number (bottom threshold 0-100)")
        while True:
            if response == "stop":
                stop = True
                break
            elif response.lower().startswith("ok"):
                parts = response.split(" ")
                if len(parts) == 1:
                    print("please specify an index:")
                    response = input("Choose: 'ok', 'stop', or a number (bottom threshold 0-100)")
                    continue
                try:
                    threshold_idx = int(parts[1])
                    break
                except ValueError:
                    # no bare except here: only a failed int() parse should be caught
                    print("{} could not be parsed as a number or 'ok', please try again.".format(response))
                    print("Thresholds are:\n ")
                    for i in range(len(threshold_pairs)):
                        print("{}: {} - {}".format(i, colors[i], threshold_pairs[i]))
                    response = input("Choose: 'ok', 'stop', or a number (bottom threshold 0-100)")
            try:
                parts = response.strip().split(",")
                if len(parts) == 1:
                    extra_pair = (float(parts[0]), standard_top_percentile)
                elif len(parts) == 2:
                    extra_pair = (float(parts[0]), float(parts[1]))
                else:
                    raise ValueError()
            except ValueError:
                print("{} could not be parsed as a number or 'ok', please try again.".format(response))
                print("Thresholds are:\n ")
                for i in range(len(threshold_pairs)):
                    print("{}: {} - {}".format(i, colors[i], threshold_pairs[i]))
                response = input("Choose: 'ok', 'stop', or a number (bottom threshold 0-100) or two numbers: bot, top")
                continue
            plot_test_thresholds(example_trace_pairs, threshold_pairs, colors, step_size, extra_pair=extra_pair)
            print("Thresholds are:\n ")
            for i in range(len(threshold_pairs)):
                print("{}: {} - {}".format(i, colors[i], threshold_pairs[i]))
            response = input("Choose: 'ok', 'stop', or a number (bottom threshold 0-100)")

        if stop:
            break
        if threshold_idx < len(threshold_pairs):
            thresholds_dict[item] = [threshold_pairs[threshold_idx][0], threshold_pairs[threshold_idx][1]]
        else:
            thresholds_dict[item] = [extra_pair[0], extra_pair[1]]

        with open(threshold_file_path, "w") as threshold_file:
            for name in sorted(thresholds_dict.keys()):
                line = name + "\t"
                line += str(thresholds_dict[name][0]) + "\t"
                line += str(thresholds_dict[name][1]) + "\t"
                threshold_file.write(line + "\n")


def plot_test_thresholds(trace_pairs, threshold_pairs, colors, step_size, extra_pair=None):
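    """Plot each example trace with the spikes detected for several candidate
    percentile-threshold pairs (one color per pair, plus an optional
    user-supplied extra_pair in black) so the pairs can be judged visually.
    """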
    ncols = int(np.ceil(len(trace_pairs) / 4))
    nrows = int(np.ceil(len(trace_pairs) / ncols))
    # squeeze=False keeps axes two-dimensional even for a single row or column,
    # so the axes[r, c] indexing below stays valid
    fig, axes = plt.subplots(nrows, ncols, sharex="all", figsize=(12, 12), squeeze=False)

    for i, (time, v1) in enumerate(trace_pairs):
        line_offset = 0
        c = i % ncols
        r = int(np.floor(i / ncols))
        v1_max = np.max(v1)
        v1_median = np.median(v1)
        axes[r, c].plot(time, v1)
        axes[r, c].plot((time[0], time[-1]), (v1_median, v1_median), color="black")
        v1_part = v1[-int(0.6 / step_size):]

        if extra_pair is not None:
            threshold = np.percentile(v1_part, extra_pair[1]) - np.percentile(v1_part, extra_pair[0])
            axes[r, c].plot((time[0], time[-1]), (v1_median + threshold, v1_median + threshold), color="black")
            peaks, _ = detect_peaks(v1, threshold=threshold)
            spikes = [time[idx] for idx in peaks]
            axes[r, c].eventplot(spikes, colors="black", lineoffsets=v1_max + line_offset)
            line_offset += 1

        for j, (bot_perc, top_perc) in enumerate(threshold_pairs):
            threshold = np.percentile(v1_part, top_perc) - np.percentile(v1_part, bot_perc)
            axes[r, c].plot((time[0], time[-1]), (v1_median + threshold, v1_median + threshold), color=colors[j])
            peaks, _ = detect_peaks(v1, threshold=threshold)
            spikes = [time[idx] for idx in peaks]
            axes[r, c].eventplot(spikes, colors=colors[j], lineoffsets=v1_max + line_offset)
            line_offset += 1

    plt.show()
    plt.close()


def test_detected_spiketimes(traces, redetected, spiketimes, step):
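    """Debug plot for one trace: the redetected spikes (red), a detection with
    the module-level percentile threshold (green), and the saved spike trains
    (black), together with the threshold levels as horizontal lines.
    """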
    time = traces[0]
    v1 = traces[1]
    plt.plot(time, v1)
    plt.eventplot(redetected, colors="red", lineoffsets=max(v1) + 1)
    median = np.median(v1)
    last_600_ms = int(np.rint(0.6 / step))
    # parentheses added around the percentile difference: FACTOR is presumably
    # meant to scale the whole range, not only the bottom percentile
    threshold_last_600 = (np.percentile(v1[-last_600_ms:], TOP_PERCENTILE)
                          - np.percentile(v1[-last_600_ms:], BOTTOM_PERCENTILE)) * FACTOR
    threshold_normal = np.percentile(v1, 94.5) - np.percentile(v1, 50)
    print("threshold full time : {:.2f}".format(threshold_normal))
    print("threshold last 600 ms: {:.2f}".format(threshold_last_600))
    peaks, _ = detect_peaks(v1, threshold=threshold_last_600)
    redetected_current_values = [time[idx] for idx in peaks]
    plt.eventplot(redetected_current_values, colors="green", lineoffsets=max(v1) + 2)
    plt.plot((time[0], time[-1]), (median, median), color="black")
    plt.plot((time[0], time[-1]), (median + threshold_normal, median + threshold_normal), color="black")
    plt.plot((time[0], time[-1]), (median + threshold_last_600, median + threshold_last_600), color="grey")
    for i, spiketrain in enumerate(spiketimes):
        plt.eventplot(spiketrain, colors="black", lineoffsets=max(v1) + 3 + i)
    plt.show()
    plt.close()


def plot_percentiles(trace):
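    """Plot the trace value at every 0.5th percentile against the percentile
    rank; useful for eyeballing where a detection threshold could sit.
    """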
    steps = np.arange(0, 100.1, 0.5)
    percentiles = np.percentile(trace, steps)
    plt.plot(steps, percentiles)
    plt.show()
    plt.close()


def get_unsorted_spiketimes(fi_file):
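    """Load the spike times of all trials from a fispikes1.dat file in file
    order, i.e. without the sorting by contrast that CellData applies.
    """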
    spiketimes = []
    for metadata, key, data in Dl.iload(fi_file):
        spike_time_data = data[:, 0] / 1000  # spike times are stored in ms; convert to s
        spiketimes.append(spike_time_data)
    return spiketimes


def find_matching_spiketrain(redetected, cell_spiketrains, step_size):
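    """Compare a redetected spike train against all saved spike trains of a cell.

    Spike times are discretized to sample indices with a tolerance of +-1
    sample, and the fraction of matched spikes is the similarity score.
    Returns the full similarity matrix, the (contrast, trial) index of the
    best match, and a (maximum, second maximum) similarity pair.
    """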
    # redetected_indices = [int(np.rint(s / step_size)) for s in redetected]
    spikes_dict = {}
    for s in redetected:
        idx = int(np.rint(s / step_size))
        spikes_dict[idx] = True
        spikes_dict[idx + 1] = True
        spikes_dict[idx - 1] = True

    similarity = np.zeros((len(cell_spiketrains), max([len(contrast_list) for contrast_list in cell_spiketrains])))
    maximum = -1
    max_idx = (-1, -1)
    for i, contrast_list in enumerate(cell_spiketrains):
        for j, cell_spiketrain in enumerate(contrast_list):
            count = 0
            cell_spike_indices = [int(np.rint(s / step_size)) for s in cell_spiketrain]
            # plt.plot(cell_spiketrain, cell_spike_indices, '.')
            # plt.plot(redetected, redetected_indices, '.')
            # plt.show()
            # plt.close()
            for spike in cell_spiketrain:
                idx = int(np.rint(spike / step_size))
                if idx in spikes_dict:
                    count += 1
            similarity[i, j] = count / len(cell_spiketrain)
            if similarity[i, j] > maximum:
                maximum = similarity[i, j]
                max_idx = (i, j)
    # plt.imshow(similarity)
    # plt.show()
    # plt.close()

    flattened = similarity.flatten()
    sorted_flattened = sorted(flattened)
    second_max = sorted_flattened[-2]
    if maximum < 0.5:
        print("Identification: max_sim: {:.2f} vs {:.2f} second max; diff: {:.2f}".format(
            maximum, second_max, maximum - second_max))
    return similarity, max_idx, (maximum, second_max)


def get_redetected_spikes(cell_dir, before, after, step, threshold_pair):
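    """Redetect spikes in all FICurve traces of a cell with a percentile threshold.

    Traces recorded with a non-zero pre-intensity duration are skipped. The
    detection threshold is the difference between the threshold_pair
    percentiles of the last (after - 0.2) s of the voltage trace. Returns the
    redetected spike times per trace and the (time, v1) trace pairs.
    """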
    spikes_list = []
    traces = []
    count = 1
    for info, key, time, x in Dl.iload_traces(cell_dir, repro="FICurve", before=before, after=after):
        # print(count)
        if '----- Control --------------------------------------------------------' in info[0].keys():
            pre_duration = float(
                info[0]["----- Pre-Intensities ------------------------------------------------"]["preduration"][:-2])
            if pre_duration != 0:
                continue
        elif "preduration" in info[0].keys():
            pre_duration = float(info[0]["preduration"][:-2])
            if pre_duration != 0:
                continue
        elif len(info) == 2 and "preduration" in info[1].keys():
            pre_duration = float(info[1]["preduration"][:-2])
            if pre_duration != 0:
                continue
        count += 1
        # time, v1, eod, local_eod, stimulus
        # print(key)
        # print(info)
        v1 = x[0]
        # percentiles = np.arange(0.0, 101, 1)
        # plt.plot(percentiles, np.percentile(v1, percentiles))
        # plt.show()
        # plt.close()
        if len(v1) > 15 / step:
            print("Skipping FI-Curve trace longer than 15 seconds!")
            continue
        if len(v1) > 3 / step:
            print("Warning: An FI-Curve trace is longer than 3 seconds.")
        if after < 0.8:
            print("Unexpected after-stimulus time shorter than 0.8 s!")
            raise ValueError("Safety error: check where the after stimulus time comes from.")

        last_about_600_ms = int(np.rint((after - 0.2) / step))
        top = np.percentile(v1[-last_about_600_ms:], threshold_pair[1])
        bottom = np.percentile(v1[-last_about_600_ms:], threshold_pair[0])
        threshold = (top - bottom)
        peaks, _ = detect_peaks(v1, threshold=threshold)
        spikes = [time[idx] for idx in peaks]
        spikes_list.append(np.array(spikes))

        # eod = x[1]
        # local_eod = x[2]
        stimulus = x[3]

        # if count % 5 == 0:
        #     plt.eventplot(spikes, colors="black", lineoffsets=max(v1) + 1)
        #     plt.plot(time, v1)
        #     median = np.median(v1)
        #     plt.plot((time[0], time[-1]), (median, median), color="grey")
        #     plt.plot((time[0], time[-1]), (median+threshold, median+threshold), color="grey")
        #     plt.show()
        #     plt.close()
        # print(key[5])
        # if "rectangle" not in key[5] and "FICurve" not in key[5][35]:
        #     raise ValueError("No value in key 5 is rectangle:")
        traces.append([np.array(time), np.array(v1)])
    return np.array(spikes_list), traces


if __name__ == '__main__':
    main()