import os
from parser.CellData import CellData
from parser.DataParserFactory import DatParser
import numpy as np
from fitting.ModelFit import ModelFit, get_best_fit
# from plottools.axes import labelaxes_params
import matplotlib.pyplot as plt
from run_Fitter import iget_start_parameters

# Colour cycle used for the spike rasters drawn in sam_tests().
colors = ["black", "red", "blue", "orange", "green"]


def main():
    """Print the fit-routine error of one stored fit and exit.

    NOTE(review): the ``quit()`` call below ends the interpreter, so
    ``sam_tests()`` is currently unreachable. It is kept as in the original
    (this looks like a scratch/debug entry point) -- remove the ``quit()``
    to run the SAM plots.
    """
    fit = get_best_fit("results/kraken_fit/2011-10-25-ad-invivo-1/")
    print(fit.get_fit_routine_error())
    quit()

    sam_tests()

    # cells = 40
    # number = len([i for i in iget_start_parameters()])
    # single_core = number * 1400 / 60 / 60
    # print("start parameters:", number)
    # print("single core time:", single_core, "h")
    # print("single core time:", single_core/24, "days")
    #
    # cores = 16
    # cells = 40
    #
    # print(cores, "core time:", single_core/cores, "h")
    # print(cores, "core time:", single_core / 24 / cores, "days")
    # print(cores, "core time all", cells, "cells:", single_core / 24 / cores * cells, "days")
    #
    # print("left over:", number%cores)
    # fit = get_best_fit("results/final_sam2/2012-12-20-ae-invivo-1/")
    # fit.generate_master_plot()


def sam_tests():
    """Plot SAM recordings together with the spike trains that fit them best.

    Walks the cell folders below ``./data/final_sam/``, skipping folders
    without a ``samspikes1.dat`` file and every cell except the hard-coded
    ``2018-05-08-aa-invivo-1``. For each delta frequency with an absolute
    value of at most 50, one figure is shown with the local EOD trace on top
    and the v1 trace below; the spike trains selected by ``get_x_best`` are
    drawn as rasters stacked above the v1 trace.
    """
    data_folder = "./data/final_sam/"
    for cell in sorted(os.listdir(data_folder)):
        print(cell)
        cell_folder = os.path.join(data_folder, cell)
        if not os.path.exists(os.path.join(cell_folder, "samspikes1.dat")):
            continue

        # Debug filter: only this one recording is inspected.
        if "2018-05-08-aa-invivo-1" not in cell:
            continue

        cell_data = CellData(cell_folder)
        sampling_rate = int(round(1 / cell_data.get_sampling_interval()))
        sam_spikes = cell_data.get_sam_spiketimes()
        delta_freqs = cell_data.get_sam_delta_frequencies()
        # Five trace lists are returned; the global EOD and the stimulus
        # traces are not needed here.
        time_traces, v1_traces, _, local_eod_traces, _ = cell_data.get_sam_traces()
        print(len(time_traces))

        for i, delta_freq in enumerate(delta_freqs):
            if abs(delta_freq) > 50:
                continue

            _, axes = plt.subplots(2, 1, sharex="all")
            axes[0].plot(time_traces[i], local_eod_traces[i])
            axes[0].set_title("Local EOD - dF {}".format(delta_freq))
            axes[1].plot(time_traces[i], v1_traces[i])
            axes[1].set_title("v1 trace")

            ah_spike = average_spike_height(sam_spikes, v1_traces[i], sampling_rate)

            # The peak of the trace is the same for every raster row, so it
            # is computed once; each row sits a further 1.5 above it.
            trace_peak = np.max(v1_traces[i])
            for j, idx in enumerate(get_x_best(ah_spike)):
                axes[1].eventplot(sam_spikes[idx],
                                  lineoffsets=trace_peak + 1.5 * (j + 1),
                                  colors=colors[j % len(colors)])
            plt.show()
            plt.close()


def average_spike_height(spike_trains, local_eod, sampling_rate):
    """Return the mean trace value at the spike times of each spike train.

    Args:
        spike_trains: Iterable of spike-train entries. Only element ``[0]``
            of each entry is used; it must be a sequence of spike times.
            NOTE(review): presumably in seconds, so that ``time *
            sampling_rate`` is a sample index -- confirm against CellData.
        local_eod: Sampled trace to read the values from (the caller in
            this file passes a v1 trace).
        sampling_rate: Samples per time unit of ``local_eod``.

    Returns:
        A list with one mean per spike train. Spike indices at or beyond
        the end of the trace are ignored; a train without any usable spike
        yields ``nan`` (NumPy emits a RuntimeWarning for the empty mean).
    """
    # Convert once instead of once per spike train.
    trace = np.asarray(local_eod)
    n_samples = len(trace)

    average_height = []
    for spikes_train in spike_trains:
        # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the
        # same dtype and likewise truncates towards zero.
        indices = (np.asarray(spikes_train[0], dtype=float) * sampling_rate).astype(int)
        spike_values = trace[indices[indices < n_samples]]
        average_height.append(np.mean(spike_values))
    return average_height


def get_x_best(average_heights, x=5):
    """Return the indices of the ``x`` largest values, largest first.

    Args:
        average_heights: Sequence of comparable values.
        x: Maximum number of indices to return.

    Returns:
        List of at most ``x`` indices into ``average_heights``, ordered by
        descending value (ties: higher index first). Empty if there is no
        input or ``x`` is not positive. The selected values are printed as
        a side effect.
    """
    if x <= 0:
        # Nothing can be selected; the replacement branch below would call
        # min() on an empty list.
        return []

    biggest_idx = []
    biggest_heights = []
    for i, height in enumerate(average_heights):
        if len(biggest_idx) < x:
            # Fill phase: the first x candidates are taken unconditionally.
            biggest_idx.append(i)
            biggest_heights.append(height)
        elif height > min(biggest_heights):
            # Replace the current smallest kept value.
            mini = np.argmin(biggest_heights)
            biggest_heights[mini] = height
            biggest_idx[mini] = i

    # Sorting the pairs directly (instead of unzipping via zip(*...)) also
    # works when nothing was collected.
    ranked = sorted(zip(biggest_heights, biggest_idx), reverse=True)
    print([height for height, _ in ranked])
    return [idx for _, idx in ranked]


if __name__ == '__main__':
    main()