diff --git a/code/plot_functions.py b/code/plot_functions.py
index 63be604..f8ea237 100644
--- a/code/plot_functions.py
+++ b/code/plot_functions.py
@@ -72,7 +72,7 @@ functions_path = r"C:\Users\diana\OneDrive - UT Cloud\Master\GPs\GP1_Grewe\Proje
 sys.path.append(functions_path)
 import useful_functions as u
 
-def plot_highlighted_integrals(frequency, power, points, color_mapping, points_categories, delta = 2.5):
+def plot_highlighted_integrals(frequency, power, points, color_mapping, points_categories, delta=2.5):
     """
     Plot the power spectrum and highlight integrals that exceed the threshold.
 
@@ -82,12 +82,10 @@ def plot_highlighted_integrals(frequency, power, points, color_mapping, points_c
         An array of frequencies corresponding to the power values.
     power : np.array
         An array of power spectral density values.
-    exceeding_points : list
-        A list of harmonic frequencies that exceed the threshold.
+    points : list
+        A list of harmonic frequencies to check and highlight.
     delta : float
         Half-width of the range for integration around each point.
-    threshold : float
-        Threshold value to compare integrals with local mean.
     color_mapping : dict
         A dictionary mapping each category to its color.
     points_categories : dict
@@ -111,17 +109,23 @@ def plot_highlighted_integrals(frequency, power, points, color_mapping, points_c
         if valid:
             # Define color based on the category of the point
             color = next((c for cat, c in color_mapping.items() if point in points_categories[cat]), 'gray')
+            # Shade the region around the point where the integral was calculated
             ax.axvspan(point - delta, point + delta, color=color, alpha=0.3, label=f'{point:.2f} Hz')
-            print(f"Integral around {point:.2f} Hz: {integral:.5e}")
-
+
+            # Print out point and color
+            print(f"Integral around {point:.2f} Hz: {integral:.5e}, Color: {color}")
+
+            # Annotate the plot with the point and its color
+            ax.text(point, max(power) * 0.9, f'{point:.2f}', color=color, fontsize=10, ha='center')
+
 
         # Define left and right boundaries of adjacent regions
         left_boundary = frequency[np.where((frequency >= point - 5 * delta) & (frequency < point - delta))[0][0]]
         right_boundary = frequency[np.where((frequency > point + delta) & (frequency <= point + 5 * delta))[0][-1]]
 
         # Add vertical dashed lines at the boundaries of the adjacent regions
-        ax.axvline(x=left_boundary, color="k", linestyle="--")
-        ax.axvline(x=right_boundary, color="k", linestyle="--")
+        #ax.axvline(x=left_boundary, color="k", linestyle="--")
+        #ax.axvline(x=right_boundary, color="k", linestyle="--")
 
     ax.set_xlim([0, 1200])
     ax.set_xlabel('Frequency (Hz)')
@@ -132,3 +136,4 @@ def plot_highlighted_integrals(frequency, power, points, color_mapping, points_c
 
     return fig
 
+
diff --git a/code/tuning_curve_max.py b/code/tuning_curve_max.py
index bea6610..64e669e 100644
--- a/code/tuning_curve_max.py
+++ b/code/tuning_curve_max.py
@@ -10,11 +10,9 @@ import useful_functions as f
 
 
-# variables
-delta = 2.5 # radius for peak detection
 
 # all files we want to use
-files = glob.glob("../data/2024-10-16-af*.nix")
+files = glob.glob("../data/2024-10-*.nix")
 
 # get only the good and fair filepaths
 new_files = f.remove_poor(files)
 
@@ -22,6 +20,9 @@ new_files = f.remove_poor(files)
 
 # loop over all the good files
 for file in new_files:
+
+    contrast_frequencies = []
+    contrast_powers = []
     # load a file
     dataset = rlx.Dataset(file)
     # extract sams
@@ -30,14 +31,40 @@
     stim_frequencies = np.zeros(len(sams))
     peak_powers = np.zeros_like(stim_frequencies)
     # loop over all sams
-    for i, sam in enumerate(sams):
-        # get sam frequency and stimuli
-        avg_dur, _, _, _, _, _, stim_frequency = f.sam_data(sam)
-        print(avg_dur)
-        if np.isnan(avg_dur):
+    # dictionary for the contrasts
+    contrast_sams = {20 : [],
+                     10 : [],
+                     5 : []}
+    # loop over all sams
+    for sam in sams:
+        # get the contrast
+        avg_dur, contrast, _, _, _, _, _ = f.sam_data(sam)
+        # check for valid trials
+        if np.isnan(contrast):
+            continue
+        elif sam.stimulus_count < 3: # aborted trials
+            continue
+        elif avg_dur < 1.7:
             continue
-        # use this to change lists basically and add the contrast somewhere else:
+        contrast = int(contrast) # get integer of contrast
+        # sort them accordingly
+        if contrast == 20:
+            contrast_sams[20].append(sam)
+        elif contrast == 10:
+            contrast_sams[10].append(sam)
+        elif contrast == 5:
+            contrast_sams[5].append(sam)
+        else:
+            continue
+    # loop over the contrasts
+    for key in contrast_sams:
+        stim_frequencies = np.zeros(len(contrast_sams[key]))
+        peak_powers = np.zeros_like(stim_frequencies)
+
+        for i, sam in enumerate(contrast_sams[key]):
+            # get stimulus frequency and stimuli
+            _, _, _, _, _, _, stim_frequency = f.sam_data(sam)
         stimuli = sam.stimuli
         # lists for the power spectra
         frequencies = []
         powers = []
@@ -52,20 +79,27 @@
         #average over the stimuli
         sam_frequency = np.mean(frequencies, axis = 0)
         sam_power = np.mean(powers, axis = 0)
-        # detect and validate peaks
+        # detect peaks
         integral, surroundings, peak_power = f.calculate_integral(sam_frequency, sam_power, stim_frequency)
-        valid = f.valid_integrals(integral, surroundings, stim_frequency)
-        #if there is a peak get the power in the peak powers
-        if valid == True:
-            peak_powers[i] = peak_power
+
+        peak_powers[i] = peak_power
         # add the current stimulus frequency
         stim_frequencies[i] = stim_frequency
+
+        # replace zeros with NaN
+        peak_powers = np.where(peak_powers == 0, np.nan, peak_powers)
+
+        contrast_frequencies.append(stim_frequencies)
+        contrast_powers.append(peak_powers)
 
 
-    # replae zeros with NaN
-    peak_powers = np.where(peak_powers == 0, np.nan, peak_powers)
-    
-plt.plot(stim_frequencies, peak_powers)
+    fig, ax = plt.subplots(layout = 'constrained')
+    ax.plot(contrast_frequencies[0], contrast_powers[0])
+    ax.plot(contrast_frequencies[1], contrast_powers[1])
+    ax.plot(contrast_frequencies[2], contrast_powers[2])
+    ax.set_xlabel('stimulus frequency [Hz]')
+    ax.set_ylabel(r'power [$\frac{\mathrm{mV^2}}{\mathrm{Hz}}$]')
+    ax.set_title(f"{file}")
 
 
diff --git a/code/useful_functions.py b/code/useful_functions.py
index de46615..a4a2d27 100644
--- a/code/useful_functions.py
+++ b/code/useful_functions.py
@@ -32,7 +32,7 @@ def all_coming_together(freq_array, power_array, points_list, categories, num_ha
     Returns
     -------
     valid_points : list
-        A list of valid points with their harmonics.
+        A continuous list of harmonics for all valid points.
     color_mapping : dict
         A dictionary mapping categories to corresponding colors.
     category_harmonics : dict
@@ -40,7 +40,7 @@ def all_coming_together(freq_array, power_array, points_list, categories, num_ha
     messages : list
         A list of messages for each point, stating whether it was valid or not.
     """
-    valid_points = []
+    valid_points = [] # A continuous list of harmonics for valid points
    color_mapping = {}
     category_harmonics = {}
     messages = []
@@ -58,7 +58,7 @@ def all_coming_together(freq_array, power_array, points_list, categories, num_ha
         if valid:
             # Step 3: Prepare harmonics if the point is valid
             harmonics, color_map, category_harm = prepare_harmonic(point, category, num_harmonics, color)
-            valid_points.append((point, harmonics))
+            valid_points.extend(harmonics) # Use extend() to append harmonics in a continuous manner
             color_mapping.update(color_map)
             category_harmonics.update(category_harm)
             messages.append(f"The point {point} is valid.")
@@ -67,6 +67,8 @@ def all_coming_together(freq_array, power_array, points_list, categories, num_ha
 
     return valid_points, color_mapping, category_harmonics, messages
 
+
+
 def AM(EODf, stimulus):
     """
     Calculates the Amplitude Modulation and Nyquist frequency
@@ -273,7 +275,7 @@ def power_spectrum(stimulus):
     # computes firing rates
     rate = firing_rate(binary, dt = dt)
     # creates power spectrum
-    freq, power = welch(rate, fs = 1/dt, nperseg = 2**16, noverlap = 2**15)
+    freq, power = welch(binary, fs = 1/dt, nperseg = 2**16, noverlap = 2**15)
     return freq, power
 
 def prepare_harmonic(frequency, category, num_harmonics, color):
@@ -397,6 +399,39 @@ def sam_data(sam):
     avg_dur = np.mean(durations)
     return avg_dur, sam_amp, sam_am, sam_df, sam_eodf, sam_nyquist, sam_stim
 
+def sam_spectrum(sam):
+    """
+    Creates a power spectrum for a ReproRun of a SAM.
+
+    Parameters
+    ----------
+    sam : ReproRun Object
+        The ReproRun the power spectrum should be generated from.
+
+    Returns
+    -------
+    sam_frequency : np.array
+        The frequencies of the power spectrum.
+    sam_power : np.array
+        The powers of the frequencies.
+
+    """
+    stimuli = sam.stimuli
+    # lists for the power spectra
+    frequencies = []
+    powers = []
+    # loop over the stimuli
+    for stimulus in stimuli:
+        # get the power spectrum for each stimulus
+        frequency, power = power_spectrum(stimulus)
+        # append the power spectrum data
+        frequencies.append(frequency)
+        powers.append(power)
+    # average over the stimuli
+    sam_frequency = np.mean(frequencies, axis = 0)
+    sam_power = np.mean(powers, axis = 0)
+    return sam_frequency, sam_power
+
 def spike_times(stim):
     """
     Reads out the spike times and other necessary parameters
@@ -425,8 +460,7 @@ def spike_times(stim):
         dt = ti.sampling_interval
     return spikes, stim_dur, dt # se changed spike_times to spikes so its not the same as name of function
 
-
-def valid_integrals(integral, local_mean, point, threshold = 0.3):
+def valid_integrals(integral, local_mean, point, threshold = 0.1):
     """
     Check if the integral exceeds the threshold compared to the local mean
     and provide feedback on whether the given point is valid or not.