from parser.CellData import CellData import pyrelacs.DataLoader as Dl from thunderfish.eventdetection import detect_peaks import os import numpy as np import matplotlib.pyplot as plt TEST_SIMILARITY = True REDETECT_SPIKES = True TOP_PERCENTILE = 95 BOTTOM_PERCENTILE = 5 FACTOR = 0.5 # strange_cells: # 2012-07-12-ap-invivo-1 # cell with a few traces with max similarity < 0.1 # 2012-12-13-af-invivo-1 # cell with MANY traces with max similarity < 0.1 # 2012-12-21-ak-invivo-1 # a few # 2012-12-21-an-invivo-1 # a few # 2013-02-21-ae-invivo-1 # " # 2013-02-21-ag-invivo-1 # " # 2014-06-06-ac-invivo-1 # alot below 0.4 but a good bit above the 2nd max def main(): test_fi_trace() quit() # find_and_save_best_threshold() # quit() directory = "data/final/" skip_to = False skip_to_cell = "2012-12-13-af-invivo-1" threshold_file_path = "data/fi_thresholds.tsv" thresholds_dict = load_fi_thresholds(threshold_file_path) for cell in sorted(os.listdir(directory)): # if cell != "2014-01-10-ab-invivo-1": # continue if skip_to: if cell == skip_to_cell: skip_to = False else: continue cell_dir = directory + cell # "data/final/2012-04-20-af-invivo-1/" print(cell_dir) cell_data = CellData(cell_dir) before = cell_data.get_delay() after = cell_data.get_after_stimulus_duration() # parser = DatParser(cell_dir) if os.path.exists(cell_dir + "/redetected_spikes.npy") and not REDETECT_SPIKES: spikes = np.load(cell_dir + "/redetected_spikes.npy", allow_pickle=True) traces = np.load(cell_dir + "/fi_time_v1_traces.npy", allow_pickle=True) else: step = cell_data.get_sampling_interval() threshold_pair = thresholds_dict[cell] spikes, traces = get_redetected_spikes(cell_dir, before, after, step, threshold_pair) np.save(cell_dir + "/redetected_spikes.npy", spikes, allow_pickle=True) np.save(cell_dir + "/fi_time_v1_traces.npy", traces, allow_pickle=True) print("redetection finished") if os.path.exists(cell_dir + "/fi_traces_contrasts.npy") and not TEST_SIMILARITY: trace_contrasts = np.load(cell_dir + "/fi_traces_contrasts.npy", allow_pickle=True) trace_max_similarity = np.load(cell_dir + "/fi_traces_contrasts_similarity.npy", allow_pickle=True) else: cell_spiketrains = cell_data.get_fi_spiketimes() # plt.plot(traces[0][0], traces[0][1]) # plt.eventplot(cell_spiketrains[0][0], colors="black", lineoffsets=max(traces[0][1]) + 1) # plt.eventplot(spikes[0], colors="black", lineoffsets=max(traces[0][1]) + 2) # plt.show() # plt.close() # unsorted_cell_spiketimes = get_unsorted_spiketimes(cell_dir + "/fispikes1.dat") trace_contrasts = np.zeros(len(traces), dtype=np.int) - 1 trace_max_similarity = np.zeros((len(traces), 2)) - 1 for i, spiketrain in enumerate(spikes): similarity, max_idx, maxima = find_matching_spiketrain(spiketrain, cell_spiketrains, cell_data.get_sampling_interval()) trace_contrasts[i] = max_idx[0] trace_max_similarity[i] = maxima # if trace_max_similarity[i] <= 0.05: # step = cell_data.get_sampling_interval() # test_detected_spiketimes(traces[i], spiketrain, cell_spiketrains[max_idx[0]], step) np.save(cell_dir + "/fi_traces_contrasts.npy", trace_contrasts, allow_pickle=True) np.save(cell_dir + "/fi_traces_contrasts_similarity.npy", trace_max_similarity, allow_pickle=True) print("similarity test finished") # step_size = cell_data.get_sampling_interval() # steps = np.arange(0, 100.1, 0.5) # percentiles_arr = np.zeros((len(traces), len(steps))) # for i, trace_pair in enumerate(traces): # v1_part = trace_pair[1][-int(np.rint(0.6/step_size)):] # percentiles = np.percentile(np.array(v1_part) - np.median(v1_part), steps) # percentiles_arr[i, :] = percentiles # plt.plot(steps, percentiles) # mean_perc = np.mean(percentiles_arr, axis=0) # plt.plot(steps, mean_perc) # plt.show() # plt.close() # bins = np.arange(0, 1.001, 0.05) # plt.hist(trace_max_similarity, bins=bins) # plt.show() # plt.close() # # # step_size = cell_data.get_sampling_interval() # cell_spiketrains = cell_data.get_fi_spiketimes() # contrasts = cell_data.get_fi_contrasts() # tested_contrasts = [] # for i, redetected in enumerate(spikes): # idx = trace_contrasts[i] # if idx not in tested_contrasts: # print("Contrast: {:.3f}".format(contrasts[idx])) # test_detected_spiketimes(traces[i], redetected, cell_spiketrains[idx], step_size) # tested_contrasts.append(idx) def test_fi_trace(): # cell = "2012-12-13-af-invivo-1" # cell = "2012-07-12-ap-invivo-1" data_dir = "data/final/" full_count = 0 contrast_trials_below_three = 0 differences_max_second_max = [] for cell in sorted(os.listdir(data_dir)): cell_dir = data_dir + cell # print(cell) cell_data = CellData(cell_dir) step_size = cell_data.get_sampling_interval() spiketimes = cell_data.get_fi_spiketimes() # trials = [len(x) for x in spiketimes] # total = sum(trials) spikes = np.load(cell_dir + "/redetected_spikes.npy", allow_pickle=True) # print("Cell data total: {} vs {} # traces".format(total, len(spikes))) traces = np.load(cell_dir + "/fi_time_v1_traces.npy", allow_pickle=True) trace_contrasts = np.load(cell_dir + "/fi_traces_contrasts.npy", allow_pickle=True) trace_max_similarity = np.load(cell_dir + "/fi_traces_contrasts_similarity.npy", allow_pickle=True) count_good = 0 count_bad = 0 threshold_file_path = "data/fi_thresholds.tsv" # thresholds_dict = load_fi_thresholds(threshold_file_path) # spikes, traces = get_redetected_spikes(cell_dir, 0.2, 0.8, cell_data.get_sampling_interval(), thresholds_dict[cell]) # print("No preduration:", len(traces)) contrast_trials = {} for i in range(len(traces)): differences_max_second_max.append((trace_max_similarity[i][0] - trace_max_similarity[i][1])/ trace_max_similarity[i][0]) if trace_max_similarity[i][0] > trace_max_similarity[i][1] + 0.15 and trace_max_similarity[i][0] < trace_max_similarity[i][1] + 0.2: print("max sim: {:.2f}, {:.2f}".format(trace_max_similarity[i][0], trace_max_similarity[i][1])) if trace_max_similarity[i][0] > trace_max_similarity[i][1] + 0.15: count_good += 1 if trace_contrasts[i] not in contrast_trials: contrast_trials[trace_contrasts[i]] = 0 contrast_trials[trace_contrasts[i]] += 1 continue count_bad += 1 # count_bad += 1 # event_offset = max(traces[i][1]) + 0.5 # fig, axes = plt.subplots(2, 1, sharex="all") # axes[0].plot(traces[i][0], traces[i][1]) # axes[0].eventplot(spikes[i], lineoffsets=event_offset, colors="black") # # similarity, max_idx, maxima = find_matching_spiketrain(spikes[i], spiketimes, step_size) # axes[0].eventplot(spiketimes[max_idx[0]][max_idx[1]], lineoffsets=event_offset + 1, colors="orange") # # # for o, st in enumerate(spiketimes[trace_contrasts[i]]): # # axes[0].eventplot(st, lineoffsets=event_offset + 1 + o*1, colors="orange") # # time, v1, eod, local_eod, stimulus = get_ith_trace(cell_dir, i) # axes[1].plot(time, local_eod) # # plt.show() # plt.close() # t, f = hF.calculate_time_and_frequency_trace(spikes[-1], cell_data.get_sampling_interval()) # plt.plot(t, f) # plt.eventplot(spikes[-1], lineoffsets=max(traces[-1][1]) + 0.5) # plt.show() # plt.close() if count_bad > 0: over_seven = 0 below_three = 0 for key in contrast_trials.keys(): if contrast_trials[key] >= 7: over_seven += 1 if contrast_trials[key] < 3: below_three += 1 if over_seven < 7: full_count += 1 print(cell) print(contrast_trials) print("good:", count_good, "bad:", count_bad) if below_three > 1: contrast_trials_below_three += 1 # print("good:", count_good, "bad:", count_bad) print("Cells less than 7 trials in seven contrasts:", full_count) print("Cells less than 3 trials in a contrast:", contrast_trials_below_three) def get_ith_trace(cell_dir, i): count = 0 for info, key, time, x in Dl.iload_traces(cell_dir, repro="FICurve", before=0.2, after=0.8): if '----- Control --------------------------------------------------------' in info[0].keys(): pre_duration = float( info[0]["----- Pre-Intensities ------------------------------------------------"]["preduration"][:-2]) if pre_duration != 0: continue elif "preduration" in info[0].keys(): pre_duration = float(info[0]["preduration"][:-2]) if pre_duration != 0: continue elif len(info) == 2 and "preduration" in info[1].keys(): pre_duration = float(info[1]["preduration"][:-2]) if pre_duration != 0: continue if count < i: count += 1 continue # print(count) # time, v1, eod, local_eod, stimulus # print(info) # print(key) v1 = x[0] eod = x[1] local_eod = x[2] stimulus = x[3] return time, v1, eod, local_eod, stimulus def load_fi_thresholds(threshold_file_path): thresholds_dict = {} if os.path.exists(threshold_file_path): with open(threshold_file_path, "r") as threshold_file: for line in threshold_file: line = line.strip() line = line.split('\t') name = line[0] bottom_percentile = float(line[1]) top_percentile = float(line[2]) thresholds_dict[name] = [bottom_percentile, top_percentile] # print("Already done:", name) return thresholds_dict def find_and_save_best_threshold(): base_path = "data/final/" threshold_file_path = "data/fi_thresholds.tsv" re_choose_thresholds = False thresholds_dict = load_fi_thresholds(threshold_file_path) count = 0 for item in sorted(os.listdir(base_path)): if item in thresholds_dict.keys() and not re_choose_thresholds: continue count += 1 print("cells to do:", count) for item in sorted(os.listdir(base_path)): if item in thresholds_dict.keys() and not re_choose_thresholds and not thresholds_dict[item][0] < 10: print("Already done:", item) continue cell_dir = base_path + item # starting assumptions: standard_top_percentile = 95 threshold_pairs = [(40, 95), (50, 95), (60, 95)] colors = ["blue", "orange", "red"] if "thresholds" in item: continue print(item) item_path = base_path + item cell_data = CellData(item_path) step_size = cell_data.get_sampling_interval() trace_pairs = np.load(cell_dir + "/fi_time_v1_traces.npy", allow_pickle=True) trace_contrasts = np.load(cell_dir + "/fi_traces_contrasts.npy", allow_pickle=True) trace_max_similarity = np.load(cell_dir + "/fi_traces_contrasts_similarity.npy", allow_pickle=True) example_trace_pairs = [] example_contrasts = [] for i, trace_pair in enumerate(trace_pairs): if trace_contrasts[i] not in example_contrasts: example_contrasts.append(trace_contrasts[i]) example_trace_pairs.append(trace_pair) example_contrasts, example_trace_pairs = zip(*sorted(zip(example_contrasts, example_trace_pairs))) stop = False print("Thresholds are:\n ") for i in range(len(threshold_pairs)): print("{}: {} - {}".format(i, colors[i], threshold_pairs[i])) plot_test_thresholds(example_trace_pairs, threshold_pairs, colors, step_size) response = input("Choose: 'ok', 'stop', or a number (bottom threshold 0-100)") while True: if response == "stop": stop = True break elif response.lower().startswith("ok"): parts = response.split(" ") if len(parts) == 1: print("please specify an index:") response = input("Choose: 'ok', 'stop', or a number (bottom threshold 0-100)") continue try: threshold_idx = int(parts[1]) break except: print("{} could not be parsed as number or ok please try again.".format(response)) print("Thresholds are:\n ") for i in range(len(threshold_pairs)): print("{}: {} - {}".format(i, colors[i], threshold_pairs[i])) response = input("Choose: 'ok', 'stop', or a number (bottom threshold 0-100)") try: parts = response.strip().split(",") if len(parts) == 1: extra_pair = (float(parts[0]), standard_top_percentile) elif len(parts) == 2: extra_pair = (float(parts[0]), float(parts[1])) else: raise ValueError() except ValueError as e: print("{} could not be parsed as number or ok please try again.".format(response)) print("Thresholds are:\n ") for i in range(len(threshold_pairs)): print("{}: {} - {}".format(i, colors[i], threshold_pairs[i])) response = input("Choose: 'ok', 'stop', or a number (bottom threshold 0-100) or two numbers: bot, top") continue plot_test_thresholds(example_trace_pairs, threshold_pairs, colors, step_size, extra_pair=extra_pair) print("Thresholds are:\n ") for i in range(len(threshold_pairs)): print("{}: {} - {}".format(i, colors[i], threshold_pairs[i])) response = input("Choose: 'ok', 'stop', or a number (bottom threshold 0-100)") if stop: break if threshold_idx < len(threshold_pairs): thresholds_dict[item] = [threshold_pairs[threshold_idx][0], threshold_pairs[threshold_idx][1]] else: thresholds_dict[item] = [extra_pair[0], extra_pair[1]] with open(threshold_file_path, "w") as threshold_file: for name in sorted(thresholds_dict.keys()): line = name + "\t" line += str(thresholds_dict[name][0]) + "\t" line += str(thresholds_dict[name][1]) + "\t" threshold_file.write(line + "\n") def plot_test_thresholds(trace_pairs, threshold_pairs, colors, step_size, extra_pair=None): ncols = int(np.ceil(len(trace_pairs) / 4)) nrows = int(np.ceil(len(trace_pairs) / ncols)) fig, axes = plt.subplots(nrows, ncols, sharex="all", figsize=(12, 12)) for i, (time, v1) in enumerate(trace_pairs): line_offset = 0 c = i % ncols r = int(np.floor(i / ncols)) v1_max = np.max(v1) v1_median = np.median(v1) axes[r, c].plot(time, v1) axes[r, c].plot((time[0], time[-1]), (v1_median, v1_median), color="black") v1_part = v1[-int(0.6/step_size):] if extra_pair is not None: threshold = np.percentile(v1_part, extra_pair[1]) - np.percentile(v1_part, extra_pair[0]) axes[r, c].plot((time[0], time[-1]), (v1_median+threshold, v1_median+threshold), color="black") peaks, _ = detect_peaks(v1, threshold=threshold) spikes = [time[idx] for idx in peaks] axes[r, c].eventplot(spikes, colors="black", lineoffsets=v1_max + line_offset) line_offset += 1 for j, (bot_perc, top_perc) in enumerate(threshold_pairs): threshold = np.percentile(v1_part, top_perc) - np.percentile(v1_part, bot_perc) axes[r, c].plot((time[0], time[-1]), (v1_median + threshold, v1_median + threshold), color=colors[j]) peaks, _ = detect_peaks(v1, threshold=threshold) spikes = [time[idx] for idx in peaks] axes[r, c].eventplot(spikes, colors=colors[j], lineoffsets=v1_max + line_offset) line_offset += 1 plt.show() plt.close() def test_detected_spiketimes(traces, redetected, spiketimes, step): time = traces[0] v1 = traces[1] plt.plot(traces[0], traces[1]) plt.eventplot(redetected, colors="red", lineoffsets=max(traces[1]) + 1) median = np.median(traces[1]) last_600_ms = int(np.rint(0.6 / step)) threshold_last_600 = np.percentile(v1[-last_600_ms:], TOP_PERCENTILE) - np.percentile(v1[-last_600_ms:], BOTTOM_PERCENTILE) * FACTOR threshold_normal = np.percentile(v1, 94.5) - np.percentile(v1, 50) print("threshold full time : {:.2f}".format(threshold_normal)) print("threshold last 600 ms: {:.2f}".format(threshold_last_600)) peaks, _ = detect_peaks(v1, threshold=threshold_last_600) redetected_current_values = [time[idx] for idx in peaks] plt.eventplot(redetected_current_values, colors="green", lineoffsets=max(traces[1]) + 2) plt.plot((traces[0][0], traces[0][-1]), (median, median), color="black") plt.plot((traces[0][0], traces[0][-1]), (median+threshold_normal, median+threshold_normal), color="black") plt.plot((traces[0][0], traces[0][-1]), (median+threshold_last_600, median+threshold_last_600), color="grey") for i, spiketrain in enumerate(spiketimes): plt.eventplot(spiketrain, colors="black", lineoffsets=max(traces[1]) + 3 + i) plt.show() plt.close() def plot_percentiles(trace): steps = np.arange(0, 100.1, 0.5) percentiles = np.percentile(trace, steps) plt.plot(steps, percentiles) plt.show() plt.close() def get_unsorted_spiketimes(fi_file): spiketimes = [] for metadata, key, data in Dl.iload(fi_file): spike_time_data = data[:, 0] / 1000 spiketimes.append(spike_time_data) return spiketimes def find_matching_spiketrain(redetected, cell_spiketrains, step_size): # redetected_idices = [int(np.rint(s / step_size)) for s in redetected] spikes_dict = {} for s in redetected: idx = int(np.rint(s / step_size)) spikes_dict[idx] = True spikes_dict[idx+1] = True spikes_dict[idx-1] = True similarity = np.zeros((len(cell_spiketrains), max([len(contrast_list) for contrast_list in cell_spiketrains]))) maximum = -1 max_idx = (-1, -1) for i, contrast_list in enumerate(cell_spiketrains): for j, cell_spiketrain in enumerate(contrast_list): count = 0 cell_spike_indices = [int(np.rint(s / step_size)) for s in cell_spiketrain] # plt.plot(cell_spiketrain, cell_spike_indices, '.') # plt.plot(redetected, redetected_idices, '.') # plt.show() # plt.close() for spike in cell_spiketrain: idx = int(np.rint(spike / step_size)) if idx in spikes_dict: count += 1 similarity[i, j] = count / len(cell_spiketrain) if similarity[i, j] > maximum: maximum = similarity[i, j] max_idx = (i, j) # plt.imshow(similarity) # plt.show() # plt.close() flattened = similarity.flatten() sorted_flattened = sorted(flattened) second_max = sorted_flattened[-2] if maximum < 0.5: print("Identification: max_sim: {:.2f} vs {:.2f} second max; Diff: {} worked".format(maximum, second_max, maximum - second_max)) return similarity, max_idx, (maximum, second_max) def get_redetected_spikes(cell_dir, before, after, step, threshold_pair): spikes_list = [] traces = [] count = 1 for info, key, time, x in Dl.iload_traces(cell_dir, repro="FICurve", before=before, after=after): # print(count) if '----- Control --------------------------------------------------------' in info[0].keys(): pre_duration = float( info[0]["----- Pre-Intensities ------------------------------------------------"]["preduration"][:-2]) if pre_duration != 0: continue elif "preduration" in info[0].keys(): pre_duration = float(info[0]["preduration"][:-2]) if pre_duration != 0: continue elif len(info) == 2 and "preduration" in info[1].keys(): pre_duration = float(info[1]["preduration"][:-2]) if pre_duration != 0: continue count += 1 # time, v1, eod, local_eod, stimulus # print(key) # print(info) v1 = x[0] # percentiles = np.arange(0.0, 101, 1) # plt.plot(percentiles, np.percentile(v1, percentiles)) # plt.show() # plt.close() if len(v1) > 15/step: print("Skipping Fi-Curve trace longer than 15 seconds!") continue if len(v1) > 3/step: print("Warning: A FI-Curve trace is longer than 3 seconds.") if after < 0.8: print("Why the f is the after stimulus time shorter than 0.8s ???") raise ValueError("Safety error: check where the after stimulus time comes from.") last_about_600_ms = int(np.rint((after-0.2)/step)) top = np.percentile(v1[-last_about_600_ms:], threshold_pair[1]) bottom = np.percentile(v1[-last_about_600_ms:], threshold_pair[0]) threshold = (top - bottom) peaks, _ = detect_peaks(v1, threshold=threshold) spikes = [time[idx] for idx in peaks] spikes_list.append(np.array(spikes)) # eod = x[1] # local_eod = x[2] stimulus = x[3] # if count % 5 == 0: # plt.eventplot(spikes, colors="black", lineoffsets=max(v1) + 1) # plt.plot(time, v1) # median = np.median(v1) # plt.plot((time[0], time[-1]), (median, median), color="grey") # plt.plot((time[0], time[-1]), (median+threshold, median+threshold), color="grey") # plt.show() # plt.close() # print(key[5]) # if "rectangle" not in key[5] and "FICurve" not in key[5][35]: # raise ValueError("No value in key 5 is rectangle:") traces.append([np.array(time), np.array(v1)]) return np.array(spikes_list), traces if __name__ == '__main__': main()