import os
from warnings import warn

import matplotlib.pyplot as plt
import numpy as np
import pyrelacs.DataLoader as dl


def get_subfolder_paths(basepath):
    """Return the sorted paths of all directories directly inside *basepath*.

    Args:
        basepath: Directory to scan. A trailing path separator is optional.

    Returns:
        Sorted list of ``basepath``/entry paths for every entry that is a
        directory.
    """
    subfolders = []
    for content in os.listdir(basepath):
        # os.path.join inserts the separator only when it is missing, so a
        # basepath with a trailing separator yields the same result as before.
        content_path = os.path.join(basepath, content)
        if os.path.isdir(content_path):
            subfolders.append(content_path)
    return sorted(subfolders)


def get_traces(directory, trace_type, repro):
    """Load one trace type for every matching repro run in *directory*.

    Args:
        directory: Passed through to ``dl.iload_traces``.
        trace_type: 1-based index of the trace to extract:
            1: Voltage p-unit
            2: EOD
            3: local EOD ~(EOD + stimulus)
            4: Stimulus
        repro: Repro name, passed through to ``dl.iload_traces``.

    Returns:
        Tuple ``(time_traces, value_traces)`` of two parallel lists with one
        entry per item yielded by ``dl.iload_traces``.
    """
    load_iter = dl.iload_traces(directory, repro=repro)

    time_traces = []
    value_traces = []
    for _info, _key, time, x in load_iter:
        time_traces.append(time)
        value_traces.append(x[trace_type - 1])

    if not time_traces:
        # Report the repro that was actually requested.
        print(f"iload_traces found nothing for the {repro} repro!")

    return time_traces, value_traces


def get_all_traces(directory, repro):
    """Load all four traces for every matching repro run in *directory*.

    Args:
        directory: Passed through to ``dl.iload_traces``.
        repro: Repro name, passed through to ``dl.iload_traces``.

    Returns:
        Tuple ``(time_traces, traces)`` where ``traces`` is the list
        ``[v1_traces, eod_traces, local_eod_traces, stimulus_traces]``; each
        inner list has one entry per item yielded by ``dl.iload_traces``.
    """
    load_iter = dl.iload_traces(directory, repro=repro)

    time_traces = []
    v1_traces = []
    eod_traces = []
    local_eod_traces = []
    stimulus_traces = []

    for info, _key, time, x in load_iter:
        time_traces.append(time)
        v1_traces.append(x[0])
        eod_traces.append(x[1])
        local_eod_traces.append(x[2])
        stimulus_traces.append(x[3])
        print(info)

    traces = [v1_traces, eod_traces, local_eod_traces, stimulus_traces]

    if not time_traces:
        # Report the repro that was actually requested.
        print(f"iload_traces found nothing for the {repro} repro!")

    return time_traces, traces


def merge_similar_intensities(intensities, spiketimes, trans_amplitudes):
    """Merge entries with nearly equal intensity and sort by intensity.

    Two intensities count as similar when they differ by less than 0.6666
    times the mean spacing of the sorted intensities. The three arguments are
    treated as parallel lists; merging deletes entries from them in place.

    Args:
        intensities: List of stimulus intensities.
        spiketimes: List with one list of trials per intensity.
        trans_amplitudes: List with one value per intensity.

    Returns:
        Tuple ``(intensities, spiketimes, trans_amplitudes)`` sorted by
        increasing intensity.

    Raises:
        RuntimeError: If entries that would be merged have differing
            trans amplitudes.
    """
    diffs = np.diff(sorted(intensities))
    # With fewer than two intensities there is no spacing to average; a zero
    # margin merges nothing and avoids taking the mean of an empty array.
    margin = np.mean(diffs) * 0.6666 if len(diffs) > 0 else 0.0

    i = 0
    # The lists shrink while merging, so the bound is re-evaluated each round.
    while i < len(intensities):
        intensities, spiketimes, trans_amplitudes = merge_intensities_similar_to_index(
            intensities, spiketimes, trans_amplitudes, i, margin)
        i += 1

    if not intensities:
        return intensities, spiketimes, trans_amplitudes

    # Sort all parallel lists with the same (stable) permutation so that
    # intensities are increasing and every list stays aligned with them.
    order = sorted(range(len(intensities)), key=lambda k: intensities[k])
    if len(trans_amplitudes) == len(intensities):
        trans_amplitudes = [trans_amplitudes[k] for k in order]
    spiketimes = [spiketimes[k] for k in order]
    intensities = [intensities[k] for k in order]

    return intensities, spiketimes, trans_amplitudes


def merge_intensities_similar_to_index(intensities, spiketimes, trans_amplitudes, index, margin):
    """Merge all later entries whose intensity is within *margin* of entry *index*.

    The trials of every merged entry are appended to ``spiketimes[index]`` and
    the merged entries are deleted from all three lists in place.

    Args:
        intensities: List of stimulus intensities.
        spiketimes: List with one list of trials per intensity.
        trans_amplitudes: List with one value per intensity.
        index: Position of the entry that absorbs the similar ones.
        margin: Maximum absolute intensity difference (exclusive) for a merge.

    Returns:
        The three (mutated) input lists as a tuple.

    Raises:
        RuntimeError: If the entries to merge do not all share the same
            trans amplitude.
    """
    intensity = intensities[index]

    indices_to_merge = [
        i for i in range(index + 1, len(intensities))
        if np.abs(intensities[i] - intensity) < margin
    ]

    if indices_to_merge:
        # Delete from the back so the remaining indices stay valid.
        indices_to_merge.reverse()

        # NOTE(review): only the merged entries are compared with each other;
        # trans_amplitudes[index] itself is not checked - confirm intended.
        trans_amplitude_values = [trans_amplitudes[k] for k in indices_to_merge]
        reference = trans_amplitude_values[0]
        all_the_same = all(reference == value for value in trans_amplitude_values[1:])

        if not all_the_same:
            raise RuntimeError("Trans_amplitudes not the same....")

        for idx in indices_to_merge:
            del trans_amplitudes[idx]

    for idx in indices_to_merge:
        spiketimes[index].extend(spiketimes[idx])
        del spiketimes[idx]
        del intensities[idx]

    return intensities, spiketimes, trans_amplitudes


def all_calculate_mean_isi_frequencies(spiketimes, time_start, sampling_interval):
    """Compute the trial-averaged ISI frequency trace for every intensity.

    Args:
        spiketimes: One list of trials per intensity; each trial is a
            sequence of spike times.
        time_start: Start time passed to ``calculate_isi_frequency``.
        sampling_interval: Sampling step passed to ``calculate_isi_frequency``.

    Returns:
        Tuple ``(times, mean_frequencies)`` with one time axis and one mean
        frequency trace per entry of *spiketimes*.
    """
    times = []
    mean_frequencies = []

    for trials in spiketimes:
        trial_times = []
        trial_means = []
        for trial in trials:
            time, isi_freq = calculate_isi_frequency(trial, time_start, sampling_interval)
            trial_means.append(isi_freq)
            trial_times.append(time)

        time, mean_freq = calculate_mean_frequency(trial_times, trial_means)
        times.append(time)
        mean_frequencies.append(mean_freq)

    return times, mean_frequencies


# TODO remove additional time vector calculation!
def calculate_isi_frequency(spiketimes, time_start, sampling_interval):
    """Build a step-wise ISI frequency trace for one trial.

    Every inter-spike interval (the first one is measured from *time_start*)
    contributes ``round(isi / sampling_interval)`` samples of value
    ``1 / isi``. Intervals of exactly zero are skipped with a warning.

    Args:
        spiketimes: Non-empty sequence of spike times of a single trial.
        time_start: Time at which the trace begins.
        sampling_interval: Step of the generated time axis.

    Returns:
        Tuple ``(time, full_frequency)`` of equal length; ``time`` is a numpy
        array, ``full_frequency`` a list.

    Raises:
        RuntimeError: If the frequency trace and the time axis differ in
            length by more than one sample.
    """
    intervals = [spiketimes[0] - time_start]
    intervals.extend(np.diff(spiketimes))
    time = np.arange(time_start, spiketimes[-1], sampling_interval)

    full_frequency = []
    for interval in intervals:
        if interval == 0:
            warn("An ISI was zero in FiCurve:__calculate_mean_isi_frequency__()")
            continue
        rate = 1 / interval
        n_samples = int(round(interval * (1 / sampling_interval)))
        full_frequency.extend(n_samples * [rate])

    n_freq = len(full_frequency)
    n_time = len(time)

    if n_freq == n_time:
        return time, full_frequency

    if abs(n_freq - n_time) != 1:
        print("ERROR PRINT:")
        print("freq:", n_freq, "time:", n_time, "diff:", n_freq - n_time)
        raise RuntimeError("FiCurve:__calculate_mean_isi_frequency__():\n"
                           "Frequency and time are not the same length!")

    # A single-sample mismatch is tolerated: trim the longer of the two.
    warn("FiCurve:__calculate_mean_isi_frequency__():\nFrequency and time were one of in length!")
    if n_freq < n_time:
        time = time[:n_freq]
    else:
        full_frequency = full_frequency[:n_time]

    return time, full_frequency


def calculate_mean_frequency(trial_times, trial_freqs):
    """Average frequency traces sample by sample across trials.

    All traces are cut to the length of the shortest time axis before
    averaging.

    Args:
        trial_times: One time axis per trial.
        trial_freqs: One frequency trace per trial.

    Returns:
        Tuple ``(time, mean_freq)``: the first trial's time axis cut to the
        shortest length, and the list of per-sample means.
    """
    shortest = min(len(axis) for axis in trial_times)

    time = trial_times[0][0:shortest]
    trimmed = [trace[0:shortest] for trace in trial_freqs]
    mean_freq = [sum(column) / len(column) for column in zip(*trimmed)]

    return time, mean_freq


def crappy_smoothing(signal:list, window_size:int = 5) -> list:
    """Smooth *signal* with a moving average that shrinks at the edges.

    Sample ``i`` becomes the mean of ``signal[i - k:i + j]``, where ``k`` and
    ``j`` equal *window_size* away from the edges and are reduced near the
    start and the end of the signal.

    Args:
        signal: Values to smooth.
        window_size: Number of samples taken on each side.

    Returns:
        List with one smoothed value per input sample.
    """
    total = len(signal)
    smoothed = []

    for i in range(total):
        before = min(i, window_size)
        after = min(window_size, total - i)
        smoothed.append(np.mean(signal[i - before:i + after]))

    return smoothed


def plot_frequency_curve(cell_data, save_path: str = None, indices: list = None):
    """Plot the mean ISI frequency traces of *cell_data*.

    Args:
        cell_data: Object providing ``get_fi_contrasts()``,
            ``get_time_axes_mean_frequencies()`` and
            ``get_mean_isi_frequencies()``.
        save_path: Prefix to which "mean_frequency_curves.png" is appended
            for saving; when ``None`` the figure is shown instead.
        indices: Indices of the traces to plot; all of them when ``None``.
    """
    contrast = cell_data.get_fi_contrasts()
    time_axes = cell_data.get_time_axes_mean_frequencies()
    mean_freqs = cell_data.get_mean_isi_frequencies()

    if indices is None:
        indices = np.arange(len(contrast))

    for idx in indices:
        curve_label = str(round(contrast[idx], 2))
        plt.plot(time_axes[idx], mean_freqs[idx], label=curve_label)

    if save_path is not None:
        plt.savefig(save_path + "mean_frequency_curves.png")
    else:
        plt.show()

    plt.close()