From ccca6e030b08946fdc053d9d5c0f1119ab6ffaa7 Mon Sep 17 00:00:00 2001
From: alexanderott <a.ott@student.uni-tuebingen.de>
Date: Sat, 22 May 2021 13:10:15 +0200
Subject: [PATCH] tau calculation and more

---
 AnalysisMasterScript.py     |   2 +-
 Figure_constants.py         |   6 +-
 Figures_results.py          |   2 +-
 download_fits.py            |  46 ++++++++++++++
 experiments/FiCurve.py      | 120 +++++++++++++++++++++++++++++++++++-
 find_ram_stimulus_files.py  |  27 ++++++++
 parser/DataParserFactory.py |  11 ++--
 sam_experiments.py          |   2 +-
 save_model_fits_as_csv.py   |   6 +-
 test.py                     |  94 ++++++++++++++++++++++------
 10 files changed, 280 insertions(+), 36 deletions(-)
 create mode 100644 download_fits.py
 create mode 100644 find_ram_stimulus_files.py

diff --git a/AnalysisMasterScript.py b/AnalysisMasterScript.py
index 2605d75..d8dae80 100644
--- a/AnalysisMasterScript.py
+++ b/AnalysisMasterScript.py
@@ -6,7 +6,7 @@ import Figure_constants as Fc
 import my_util.save_load as sl
 
 
-fit_folder = "results/sam_cells"  # kraken fit
+fit_folder = "results/sam_cells_only_best"  # kraken fit
 # fit_folder = "results/final_sam_dend_noise_test/"  # noise in input fit
 # fit_folder = "results/final_sam2/"  # noise in input fit
 
diff --git a/Figure_constants.py b/Figure_constants.py
index 295b6eb..dd7261b 100644
--- a/Figure_constants.py
+++ b/Figure_constants.py
@@ -1,7 +1,7 @@
 
 import plottools.colors as ptc
-from plottools.axes import labelaxes_params
-
+#from plottools.axes import labelaxes_params
+from plottools.axes import axes_params
 
 SAVE_FOLDER = "./thesis/figures/"
 
@@ -35,5 +35,5 @@ finf_marker = (4, 0, 0)  # "s"
 
 
 def set_figure_labels(xoffset="auto", yoffset='auto'):
-    labelaxes_params(xoffs=xoffset, yoffs=yoffset, labels='A',
+    axes_params(xoffs=xoffset, yoffs=yoffset, label='A',
                      font=dict(size=16, family='serif'))
diff --git a/Figures_results.py b/Figures_results.py
index f01472d..b9c97ba 100644
--- a/Figures_results.py
+++ b/Figures_results.py
@@ -318,7 +318,7 @@ def create_parameter_distributions(par_values, prefix=""):
     plt.tight_layout()
 
     consts.set_figure_labels(xoffset=-2.5, yoffset=1.5)
-    fig.label_axes()
+    # fig.label_axes()
 
     plt.savefig(consts.SAVE_FOLDER + prefix + "parameter_distributions.pdf")
     plt.close()
diff --git a/download_fits.py b/download_fits.py
new file mode 100644
index 0000000..dc1b95a
--- /dev/null
+++ b/download_fits.py
@@ -0,0 +1,46 @@
+
+import os
+
+def main():
+    file_path = "./results/folders.txt"
+
+    cells = {}
+    with open(file_path) as f:
+        for l in f:
+            parts = l.strip().split("/")
+            cell_folder = parts[0]
+            start_para = parts[1]
+            score = float(start_para.split('_')[-1])
+            if cell_folder not in cells:
+                cells[cell_folder] = [score, start_para]
+            else:
+                if cells[cell_folder][0] > score:
+                    cells[cell_folder] = [score, start_para]
+
+    for k in sorted(cells.keys()):
+        print(k, cells[k][1])
+    remotehost = "alex@kraken.am28.uni-tuebingen.de"
+    remote_base = "P-unit_model/results/sam_cells/"
+    folders_to_copy = [remote_base + k + "/" + cells[k][1] + "/ " for k in sorted(cells.keys())]
+    remote_files = ""
+    for i in range(len(folders_to_copy)):
+        remote_files += folders_to_copy[i]
+
+    local_base = "./results/sam_cells_only_best/"
+    # os.system('scp -r  "%s:%s" "%s"' % (remotehost, remote_files, local_base))
+
+    # create folders
+
+    for k in sorted(cells.keys()):
+        cell_folder = "./results/sam_cells_only_best/" + k + "/"
+        os.makedirs(cell_folder)
+
+        os.rename("./results/sam_cells_only_best/best/" + cells[k][1], cell_folder + cells[k][1])
+
+
+def read_file(path):
+    pass
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/experiments/FiCurve.py b/experiments/FiCurve.py
index 2b42ae1..5c9424f 100644
--- a/experiments/FiCurve.py
+++ b/experiments/FiCurve.py
@@ -6,6 +6,7 @@ import numpy as np
 import matplotlib.pyplot as plt
 from warnings import warn
 from my_util import helperFunctions as hF, functions as fu
+from scipy.optimize import curve_fit
 from os.path import join, exists
 import pickle
 from sys import stderr
@@ -23,7 +24,7 @@ class FICurve:
         self.f_inf_frequencies = []
         self.indices_f_zero = []
         self.f_zero_frequencies = []
-
+        self.taus = []
         # increase, offset
         self.f_inf_fit = []
         # f_max, f_min, k, x_zero
@@ -45,9 +46,59 @@ class FICurve:
         self.f_inf_fit = hF.fit_clipped_line(self.stimulus_values, self.f_inf_frequencies)
         self.f_zero_fit = hF.fit_boltzmann(self.stimulus_values, self.f_zero_frequencies)
 
+    def __calculate_time_constant_internal__(self, contrast, mean_frequency, baseline_freq, sampling_interval, pre_duration, plot=False):
+        time_constant_fit_length = 0.05
+
+        if contrast > 0:
+            maximum_idx = np.argmax(mean_frequency)
+            maximum = mean_frequency[maximum_idx]
+            start_fit_idx = maximum_idx
+            while (mean_frequency[start_fit_idx]) > 0.80 * (maximum - baseline_freq) + baseline_freq:
+                start_fit_idx += 1
+
+        else:
+            minimum_idx = np.argmin(mean_frequency)
+            minimum = mean_frequency[minimum_idx]
+            start_fit_idx = minimum_idx
+            # print("Border: ", baseline_freq - (0.80 * (baseline_freq - minimum)))
+            while (mean_frequency[start_fit_idx]) < baseline_freq - (0.80 * (baseline_freq - minimum)):
+                start_fit_idx += 1
+
+        # print("start:", start_fit_idx * sampling_interval - pre_duration)
+        end_fit_idx = start_fit_idx + int(time_constant_fit_length / sampling_interval)
+
+        x_values = np.arange(end_fit_idx - start_fit_idx) * sampling_interval
+        y_values = mean_frequency[start_fit_idx:end_fit_idx]
+
+        try:
+            popt, pcov = curve_fit(fu.exponential_function, x_values, y_values,
+                                   p0=(1 / (np.power(1, 10)), 5, 50), maxfev=100000)
+
+            # print(popt)
+            if plot:
+                if contrast > 0:
+                    plt.title("c: {:.2f} Base_f: {:.2f}, f_zero: {:.2f}".format(contrast, baseline_freq, maximum))
+                else:
+                    plt.title("c: {:.2f} Base_f: {:.2f}, f_zero: {:.2f}".format(contrast, baseline_freq, minimum))
+                plt.plot(np.arange(len(mean_frequency)) * sampling_interval - pre_duration, mean_frequency,
+                         '.')
+                plt.plot(np.arange(start_fit_idx, end_fit_idx, 1) * sampling_interval - pre_duration, y_values,
+                         color="darkgreen")
+                plt.plot(np.arange(start_fit_idx, end_fit_idx, 1) * sampling_interval - pre_duration,
+                         fu.exponential_function(x_values, popt[0], popt[1], popt[2]), color="orange")
+                plt.show()
+                plt.close()
+            return popt, pcov
+        except RuntimeError:
+            print("RuntimeError happened in fit_exponential.")
+            return [], []
+
     def calculate_all_frequency_points(self):
         raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
 
+    def calculate_time_constant(self, contrast_idx, plot=False):
+        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
+
     def get_f_baseline_frequencies(self):
         return self.f_baseline_frequencies
 
@@ -143,7 +194,7 @@ class FICurve:
                 plt.savefig(save_path + "mean_frequency_contrast_{:.2f}.png".format(sv))
             plt.close()
 
-    def plot_fi_curve(self, save_path=None):
+    def plot_fi_curve(self, save_path=None, title=""):
         min_x = min(self.stimulus_values)
         max_x = max(self.stimulus_values)
         step = (max_x - min_x) / 5000
@@ -161,6 +212,7 @@ class FICurve:
                  color='red', label='f_0_fit')
 
         plt.legend()
+        plt.title(title)
         plt.ylabel("Frequency [Hz]")
         plt.xlabel("Stimulus value")
 
@@ -231,6 +283,33 @@ class FICurve:
             plt.savefig(save_path + "fi_curve_comparision.png")
         plt.close()
 
+    def write_detection_data_to_csv(self, save_path, name=""):
+        steady_state = self.get_f_inf_frequencies()
+        onset = self.get_f_zero_frequencies()
+        baseline = self.get_f_baseline_frequencies()
+        contrasts = self.stimulus_values
+
+        headers = ["contrasts", "f_baseline", "f_steady_state", "f_onset"]
+
+        if len(name) is not 0:
+            file_name = name
+        else:
+            file_name = "fi_data.csv"
+
+        with open(save_path + file_name, 'w') as f:
+            for i in range(len(headers)):
+                if i == 0:
+                    f.write(headers[i])
+                else:
+                    f.write("," + headers[i])
+            f.write("\n")
+
+            for i in range(len(contrasts)):
+                f.write(str(contrasts[i]) + ",")
+                f.write(str(baseline[i]) + ",")
+                f.write(str(steady_state[i]) + ",")
+                f.write(str(onset[i]) + "\n")
+
     def plot_f_point_detections(self, save_path=None):
         raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
 
@@ -243,6 +322,11 @@ class FICurve:
         values["f_zero_frequencies"] = self.f_zero_frequencies
         values["f_inf_fit"] = self.f_inf_fit
         values["f_zero_fit"] = self.f_zero_fit
+        taus = []
+
+        for i in range(len(self.stimulus_values)):
+            taus.append(self.calculate_time_constant(i))
+        values["time_constants"] = taus
 
         with open(join(save_directory, self.save_file_name), "wb") as file:
             pickle.dump(values, file)
@@ -267,6 +351,7 @@ class FICurve:
         self.f_zero_frequencies = values["f_zero_frequencies"]
         self.f_inf_fit = values["f_inf_fit"]
         self.f_zero_fit = values["f_zero_fit"]
+        self.taus = values["time_constants"]
         print("Fi-Curve: Values loaded!")
         return True
 
@@ -277,6 +362,21 @@ class FICurveCellData(FICurve):
         self.cell_data = cell_data
         super().__init__(stimulus_values, save_dir, recalculate)
 
+    def calculate_time_constant(self, contrast_idx, plot=False):
+        if len(self.taus) > 0:
+            return self.taus[contrast_idx]
+
+        mean_frequency = self.cell_data.get_mean_fi_curve_isi_frequencies()[contrast_idx]
+        baseline_freq = self.get_f_baseline_frequencies()[contrast_idx]
+        pre_duration = -1*self.cell_data.get_recording_times()[0]
+        sampling_interval = self.cell_data.get_sampling_interval()
+
+        # __calculate_time_constant_internal__(self, contrast, mean_frequency, baseline_freq, sampling_interval, pre_duration, plot=False):
+        popt, pcov = super().__calculate_time_constant_internal__(self.stimulus_values[contrast_idx], mean_frequency,
+                                                                  baseline_freq, sampling_interval, pre_duration, plot=plot)
+
+        return popt[1]
+
     def calculate_all_frequency_points(self):
         mean_frequencies = self.cell_data.get_mean_fi_curve_isi_frequencies()
         time_axes = self.cell_data.get_time_axes_fi_curve_mean_frequencies()
@@ -458,6 +558,22 @@ class FICurveModel(FICurve):
             self.f_baseline_frequencies.append(f_baseline)
             self.indices_f_baseline.append(f_base_idx)
 
+    def calculate_time_constant(self, contrast_idx, plot=False):
+        if len(self.taus) > 0:
+            return self.taus[contrast_idx]
+
+        mean_frequency = self.mean_frequency_traces[contrast_idx]
+        baseline_freq = self.get_f_baseline_frequencies()[contrast_idx]
+        pre_duration = 0
+        sampling_interval = self.model.get_sampling_interval()
+
+        popt, pcov = super().__calculate_time_constant_internal__(self.stimulus_values[contrast_idx], mean_frequency,
+                                                                  baseline_freq, sampling_interval, pre_duration, plot=plot)
+        if len(popt) > 0:
+            return popt[1]
+        else:
+            return -1
+
     def get_mean_time_and_freq_traces(self):
         return self.mean_time_traces, self.mean_frequency_traces
 
diff --git a/find_ram_stimulus_files.py b/find_ram_stimulus_files.py
new file mode 100644
index 0000000..8534154
--- /dev/null
+++ b/find_ram_stimulus_files.py
@@ -0,0 +1,27 @@
+
+import os
+import numpy as np
+import pyrelacs.DataLoader as Dl
+
+
+def main():
+    folder = "data/final/"
+    stim_files = []
+
+    for cell in sorted(os.listdir(folder)):
+
+        base_path = folder + cell
+
+        for info, key, time, x in Dl.iload_traces(base_path, repro="FileStimulus", before=0, after=0):
+            print(cell)
+            if len(info) == 2 and "file" in info[1].keys():
+                stim_files.append(info[1]["file"])
+            break
+
+
+    for file in np.unique(stim_files):
+        print(file)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/parser/DataParserFactory.py b/parser/DataParserFactory.py
index 0e670dd..9acad8f 100644
--- a/parser/DataParserFactory.py
+++ b/parser/DataParserFactory.py
@@ -187,7 +187,7 @@ class DatParser(AbstractParser):
         return self.fi_recording_times
 
     def get_baseline_traces(self):
-        return self.__get_traces__("BaselineActivity")
+        return self.get_traces("BaselineActivity")
 
     def get_baseline_spiketimes(self):
         # TODO change: reading from file -> detect from v1 trace
@@ -201,7 +201,7 @@ class DatParser(AbstractParser):
         return spiketimes
 
     def get_fi_curve_traces(self):
-        return self.__get_traces__("FICurve")
+        return self.get_traces("FICurve")
 
     def get_fi_frequency_traces(self):
         raise NotImplementedError("Not possible in .dat data type.\n"
@@ -290,7 +290,7 @@ class DatParser(AbstractParser):
         return trans_amplitudes, intensities, spiketimes
 
     def get_sam_traces(self):
-        return self.__get_traces__("SAM")
+        return self.get_traces("SAM")
 
     def get_sam_info(self):
         contrasts = []
@@ -351,7 +351,8 @@ class DatParser(AbstractParser):
 
         return spiketimes, contrasts, delta_fs, eod_freqs, durations, trans_amplitudes
 
-    def __get_traces__(self, repro):
+    def get_traces(self, repro, before=0, after=0):
+
         time_traces = []
         v1_traces = []
         eod_traces = []
@@ -360,7 +361,7 @@ class DatParser(AbstractParser):
 
         nothing = True
 
-        for info, key, time, x in Dl.iload_traces(self.base_path, repro=repro):
+        for info, key, time, x in Dl.iload_traces(self.base_path, repro=repro, before=before, after=after):
             nothing = False
             time_traces.append(time)
             v1_traces.append(x[0])
diff --git a/sam_experiments.py b/sam_experiments.py
index e671903..9671499 100644
--- a/sam_experiments.py
+++ b/sam_experiments.py
@@ -75,7 +75,7 @@ def plot_traces_with_spiketimes():
     modelfit = get_best_fit("results/final_2/2011-10-25-ad-invivo-1/")
     cell_data = modelfit.get_cell_data()
 
-    traces = cell_data.parser.__get_traces__("SAM")
+    traces = cell_data.parser.get_traces("SAM")
     # [time_traces, v1_traces, eod_traces, local_eod_traces, stimulus_traces]
     sam_spiketimes = cell_data.get_sam_spiketimes()
     for i in range(len(traces[0])):
diff --git a/save_model_fits_as_csv.py b/save_model_fits_as_csv.py
index 3d404af..bea3d6e 100644
--- a/save_model_fits_as_csv.py
+++ b/save_model_fits_as_csv.py
@@ -5,14 +5,14 @@ import pandas as pd
 import numpy as np
 import matplotlib.pyplot as plt
 
-SAVE_DIR = "results/lab_rotation/"
+SAVE_DIR = "results/sam_cells_only_best/"
 
 
 def main():
 
-    res_folder = "results/final_2/"
+    res_folder = "results/sam_cells_only_best/"
 
-    # save_model_parameters(res_folder)
+    save_model_parameters(res_folder)
     # save_cell_info(res_folder)
 
     # test_save_cell_info()
diff --git a/test.py b/test.py
index d93b68d..eecf842 100644
--- a/test.py
+++ b/test.py
@@ -6,34 +6,88 @@ from fitting.ModelFit import ModelFit, get_best_fit
 # from plottools.axes import labelaxes_params
 import matplotlib.pyplot as plt
 from run_Fitter import iget_start_parameters
+from experiments.FiCurve import FICurve, FICurveCellData, FICurveModel
+
 colors = ["black", "red", "blue", "orange", "green"]
 
 
 def main():
-
-    fit = get_best_fit("results/kraken_fit/2011-10-25-ad-invivo-1/")
-    print(fit.get_fit_routine_error())
-
-    quit()
-    sam_tests()
-    # cells = 40
-    # number = len([i for i in iget_start_parameters()])
-    # single_core = number * 1400 / 60 / 60
-    # print("start parameters:", number)
-    # print("single core time:", single_core, "h")
-    # print("single core time:", single_core/24, "days")
+    # results_dir = "data/final/"
+    # for folder in sorted(os.listdir(results_dir)):
+    #     folder_path = os.path.join(results_dir, folder)
+    #
+    #     if not os.path.isdir(folder_path):
+    #         continue
+    #
+    #     cell_data = CellData(folder_path)
+    #     cell_name = cell_data.get_cell_name()
     #
-    # cores = 16
-    # cells = 40
+    #     fi_cell = FICurveCellData(cell_data, cell_data.get_fi_contrasts(), cell_data.data_path)
     #
-    # print(cores, "core time:", single_core/cores, "h")
-    # print(cores, "core time:", single_core / 24 / cores, "days")
-    # print(cores, "core time all", cells, "cells:", single_core / 24 / cores * cells, "days")
+    #     fi_cell.plot_fi_curve(title=cell_name, save_path="temp/cell_fi_curves_images/" + cell_name + "_")
     #
-    # print("left over:", number%cores)
+    #     steady_state = fi_cell.get_f_inf_frequencies()
+    #     onset = fi_cell.get_f_zero_frequencies()
+    #     baseline = fi_cell.get_f_baseline_frequencies()
+    #     contrasts = fi_cell.stimulus_values
+    #
+    #     headers = ["contrasts", "f_baseline", "f_steady_state", "f_onset"]
+    #     with open("temp/cell_fi_curves_csvs/" + cell_name + ".csv", 'w') as f:
+    #         for i in range(len(headers)):
+    #             if i == 0:
+    #                 f.write(headers[i])
+    #             else:
+    #                 f.write("," + headers[i])
+    #         f.write("\n")
+    #
+    #         for i in range(len(contrasts)):
+    #             f.write(str(contrasts[i]) + ",")
+    #             f.write(str(baseline[i]) + ",")
+    #             f.write(str(steady_state[i]) + ",")
+    #             f.write(str(onset[i]) + "\n")
+    # quit()
+    cell_taus = []
+    model_taus = []
+
+    results_dir = "results/sam_cells_only_best/"
+    for folder in sorted(os.listdir(results_dir)):
+        folder_path = os.path.join(results_dir, folder)
+        if not os.path.isdir(folder_path):
+            continue
+
+        fit = get_best_fit(folder_path)
+        print(fit.get_fit_routine_error())
+        model = fit.get_model()
+        cell_data = fit.get_cell_data()
+
+        fi_model = FICurveModel(model, cell_data.get_fi_contrasts(), cell_data.get_eod_frequency())
+        tau_model = fi_model.calculate_time_constant(-2)
+        model_taus.append(tau_model)
+        fi_cell = FICurveCellData(cell_data, cell_data.get_fi_contrasts(), cell_data.data_path)
+
+        tau_cell = fi_cell.calculate_time_constant(-2)
+        cell_taus.append(tau_cell)
+    # model_taus = [0.008227050473746214, 339.82706244279075, 0.010807838358313856, 0.01115826226335211, 0.007413613528371537, 0.013213123673467943, 0.010808781901437248, 0.0014254019917934319, 0.015448860984264491, 0.014413888046967265, 0.029301687421672096, 255.82969629640462, 0.00457130444591641, 0.009463250852321902, 0.007755615618900141, 0.009110183466482135, 0.007225102891006319, 0.0024319255218167336, 0.017420779742227246, 0.027195130905873905, 0.00934661249103802, 0.07158177921097474, 0.004866423936911278, 0.0008792730042370866, 0.00820470663372859, 0.05135988132772797, -945.8805502129879, -625.3981095962032, 0.00045249542468299257, 0.10198296886109447, 0.02992101543230009, 715.8802825637086, 0.0074281010613263775, 0.002038042609377947, 0.0055331475878047445, 0.010965819934792512, 0.00916015878530846, -123.0502556160885, 0.013734214511572751, 0.004193114169578979, 0.011103783836162914, 0.018070119202374276]
+    # cell_taus = [0.0035588022114672975, 0.005541599918212267, 0.007848670525682807, 0.008147461940299978, 0.005948699597158819, 0.0024739217090879104, 0.0038303906688137847, 0.00300889313116284, 0.014167509501882801, 0.009459132581703281, 0.005226151863380407, 772.607757547133, 0.0016936075127979523, 0.008768601246126134, 0.0036987681597240958, 0.009306705661392982, 0.004808427175831087, 0.005419130192821167, 0.0028735071877832733, 0.005983916198767454, 0.004369124640159074, 0.020115307489662095, 468.1810372271939, 0.0012946259647070454, 0.0021810924044437753, 259.6701021041893, 2891.7659169677813, -2155.469810882238, 0.0027895996432137117, 0.01503608591999554, 1138.5941497875147, -0.009831620851536924, 0.004657794528111363, -0.007131468820451661, -0.0221455330638256, -589.1530734507537, -506.6077728634018, -0.0028166760486066605, 359.3395355603788, -0.003053762369811596, 0.00465946355831796, 0.01675427242298042]
+
+    model_taus_c = [v for v in model_taus if np.abs(v) < 0.15]
+    cell_taus_c = [v for v in cell_taus if np.abs(v) < 0.15]
+    print("model removed:", len(model_taus) - len(model_taus_c))
+    print("cell removed:", len(cell_taus) - len(cell_taus_c))
+
+    fig, axes = plt.subplots(1, 2, sharey="all", sharex="all")
+
+    axes[0].hist(model_taus_c)
+    axes[0].set_title("Model taus")
+
+    axes[1].hist(cell_taus_c)
+    axes[1].set_title("Cell taus")
 
-    # fit = get_best_fit("results/final_sam2/2012-12-20-ae-invivo-1/")
-    # fit.generate_master_plot()
+    plt.show()
+    plt.close()
+    print(model_taus)
+    print(cell_taus)
+    # sam_tests()
 
 
 def sam_tests():