"""Simulate and plot the superposition of two sinusoidal EOD signals.

The module builds the sum of a "receiver" sinusoid (frequency ``eod_fr``)
and an "emitter" sinusoid (frequency ``eod_fe[e]``) that may carry a
Gaussian frequency excursion (see ``integrate_chirp``), extracts envelope
estimates from the rectified sum and plots them.

NOTE(review): the original file had lost its line breaks and indentation.
Block nesting below was reconstructed from syntax and data flow; the few
places where the nesting was genuinely ambiguous carry a NOTE(review).
"""
# The import order is kept exactly as in the original on purpose: the
# wildcard import from ``myfunctions`` can shadow (or be shadowed by) the
# names around it, so regrouping the imports could change name resolution.
import matplotlib.pyplot as plt
import numpy as np
from IPython import embed
import matplotlib as matplotlib
import math
import scipy.integrate as integrate
from scipy import signal
from scipy.interpolate import interp1d
from scipy.interpolate import CubicSpline
import scipy as sp
import scipy.special  # explicit: integrate_chirp() relies on sp.special.erf
import pickle
from scipy.spatial import distance
from myfunctions import *
import time
from matplotlib import gridspec
#from matplotlib_scalebar.scalebar import ScaleBar
import matplotlib.mlab as ml
import scipy.integrate as si
import pandas as pd


def remove_all_spines(ax, nr):
    """Hide the right, top, left and bottom spines of ``ax[nr]``."""
    for side in ('right', 'top', 'left', 'bottom'):
        ax[nr].spines[side].set_visible(False)


def _fold_beat_frequency(eod_fe, eod_fr):
    """Return ``eod_fe % eod_fr`` mirrored at ``eod_fr / 2``.

    Values above ``eod_fr / 2`` are replaced by ``eod_fr - value`` so the
    result never exceeds half of ``eod_fr``.
    """
    folded = np.asarray(eod_fe) % eod_fr
    return np.where(folded > eod_fr / 2, eod_fr - folded, folded)


def find_beats(start, end, step, eod_fr):
    """Build a range of emitter frequencies and their beat frequencies.

    Returns:
        eod_fe: ``np.arange(start, end, step)``.
        beat_corr: ``eod_fe`` modulo ``eod_fr``, mirrored at ``eod_fr / 2``.
        beats: the plain difference ``eod_fe - eod_fr``.
    """
    eod_fe = np.arange(start, end, step)
    beats = eod_fe - eod_fr
    beat_corr = _fold_beat_frequency(eod_fe, eod_fr)
    return eod_fe, beat_corr, beats


def snip(left_c, right_c, e, g, sampling, deviation_s, d, eod_fr, a_fr,
         eod_fe, phase_zero, p, size, s, sigma, a_fe, deviation, beat_corr,
         chirp=True):
    """Simulate one stimulus window and derive its envelope estimates.

    The window borders are ``left_c[g]`` and ``right_c[g]``; ``e``, ``d``,
    ``p`` and ``s`` index into ``eod_fe``/``beat_corr``,
    ``deviation``/``deviation_s``, ``phase_zero`` and ``size``.

    With ``chirp == True`` the emitter signal comes from
    ``integrate_chirp``; otherwise the plain emitter sinusoid returned by
    ``find_periods`` is used.

    Returns:
        A tuple ``(cut, threshold_cube, time_cut, eod_conv_up,
        am_corr_full, peaks_interp, maxima_interp, am_corr_ds, am_df_ds,
        eod_fish_both, eod_overlayed_chirp)``.
    """
    t_full, time_cut, cut = find_times(left_c[g], right_c[g], sampling,
                                       deviation_s[d])
    eod_fish_e, eod_fish_r, period_fish_r, period_fish_e = find_periods(
        a_fe, t_full, eod_fr, a_fr, eod_fe, e)

    # ``== True`` is kept deliberately: truthy non-boolean flags must keep
    # selecting the plain sinusoid, exactly as before.
    if chirp == True:
        eod_fe_chirp = integrate_chirp(a_fe, t_full, eod_fe[e],
                                       phase_zero[p], size[s], sigma)
    else:
        eod_fe_chirp = eod_fish_e

    eod_rec_down, eod_rec_up = rectify(eod_fish_r, eod_fe_chirp)
    eod_overlayed_chirp = (eod_fish_r + eod_fe_chirp)[cut:-cut]
    threshold_cube = eod_rec_up ** 3

    # Envelope estimates: per-period maxima, all local maxima, and the
    # Gaussian-smoothed rectified signal.
    _, _, maxima_interp = global_maxima(period_fish_e, period_fish_r,
                                        eod_rec_up[cut:-cut])
    _, _, peaks_interp = find_lm(eod_rec_up[cut:-cut])
    _, _, eod_conv_up, _ = conv(eod_fr, sampling, cut, deviation[d],
                                eod_rec_up, eod_rec_down)

    eod_fish_both = integrate_chirp(a_fe, t_full, eod_fe[e] - eod_fr,
                                    phase_zero[p], size[s], sigma)
    am_corr_full = integrate_chirp(a_fe, time_cut, beat_corr[e],
                                   phase_zero[p], size[s], sigma)

    # Same window again, but with ``eod_fr`` passed as the sampling rate.
    _, time_fish, _ = find_times(left_c[g], right_c[g], eod_fr,
                                 deviation_s[d])
    am_corr_ds = integrate_chirp(a_fe, time_fish, beat_corr[e],
                                 phase_zero[p], size[s], sigma)
    am_df_ds = integrate_chirp(a_fe, time_fish, eod_fe[e] - eod_fr,
                               phase_zero[p], size[s], sigma)

    return (cut, threshold_cube, time_cut, eod_conv_up, am_corr_full,
            peaks_interp, maxima_interp, am_corr_ds, am_df_ds,
            eod_fish_both, eod_overlayed_chirp)


def single_stim(ax, colors, row, col, eod_fr, eod_fe, e, lower, s=0, p=0,
                d=0, labels=True, col_basic='silver', add='simple',
                df_col='blue', factor=200, beat_corr_col='gold',
                col_hline='no', nfft=4096, minus_bef=-30, delta_t=0.014,
                sampling=100000, deviation=(150,), plus_bef=-10, a_fr=1,
                phase_zero=(0,), shift_phase=0, size=(120,), a_fe=0.8,
                ax_nr='no', lw_whole=0.5, y='yes'):
    """Simulate one stimulus, plot it into ``ax[e]`` and return its peak.

    The window runs from ``minus_bef * delta_t * sampling`` to
    ``plus_bef * delta_t * sampling`` (passed to ``find_times``). The
    combined signal, its per-period maxima and the reference sinusoids are
    drawn by ``pl_eods`` into the grid cell selected by ``lower`` and
    ``ax_nr``/``e``.

    ``deviation``, ``phase_zero`` and ``size`` are sequences indexed by
    ``d``, ``p`` and ``s``; their defaults are tuples instead of lists so
    that no mutable default is shared between calls.

    Returns:
        f_max: frequency of the largest PSD value of the mean-subtracted
            per-period maxima, searched among frequencies below
            ``0.5 * eod_fr``.
        eod_overlayed_chirp: sum of receiver and emitter signal.
        ax: the container that was passed in, with ``ax[e]`` set.
    """
    beat_corr = _fold_beat_frequency(eod_fe, eod_fr)
    sigma = delta_t / math.sqrt(2 * math.log(10))

    left_c = minus_bef * delta_t * sampling
    right_c = plus_bef * delta_t * sampling
    t_full, time_cut, cut = find_times(left_c, right_c, sampling,
                                       deviation[d] / (1000 * sampling))

    _, eod_fish_r, period_fish_r, period_fish_e = find_periods(
        a_fe, t_full, eod_fr, a_fr, eod_fe, e)
    eod_fish_both = integrate_chirp(a_fe, t_full, eod_fe[e] - eod_fr,
                                    phase_zero[p] + shift_phase, size[s],
                                    sigma)
    eod_fe_chirp = integrate_chirp(a_fe, t_full, eod_fe[e], phase_zero[p],
                                   size[s], sigma)
    eod_overlayed_chirp = eod_fish_r + eod_fe_chirp

    eod_rectified_down, eod_rectified_up = rectify(eod_fish_r, eod_fe_chirp)
    maxima_values, maxima_index, maxima_interp = global_maxima(
        period_fish_e, period_fish_r, eod_rectified_up)
    index_peaks, value_peaks, _ = find_lm(eod_rectified_up)

    # Best effort: the smoothed traces are only drawn when ``add == True``,
    # so a failing convolution must not abort the plot. ``Exception``
    # (rather than a bare except) lets KeyboardInterrupt/SystemExit through.
    try:
        _, eod_convolved_down, eod_convolved_up, _ = conv(
            eod_fr, sampling, cut, deviation[d], eod_rectified_up,
            eod_rectified_down)
    except Exception:
        eod_convolved_down = []
        eod_convolved_up = []

    left_c = -200 * delta_t * sampling
    right_c = 200 * delta_t * sampling
    _, time_fish, _ = find_times(left_c, right_c, eod_fr, deviation[d])
    am_fish = integrate_chirp(a_fe, time_fish, beat_corr[e], phase_zero[p],
                              size[s], sigma)
    print(beat_corr[e])
    am_corr = integrate_chirp(a_fe, time_cut, beat_corr[e],
                              phase_zero[p] + shift_phase, size[s], sigma)

    # ``noverlap`` has to be an integer number of samples; the original
    # passed the float ``nfft / 2``.
    power, freq = ml.psd(maxima_interp - np.mean(maxima_interp),
                         Fs=sampling, NFFT=nfft, noverlap=int(nfft / 2))
    f_max = freq[np.argmax(power[freq < 0.5 * eod_fr])]

    if plus_bef < 0:
        green_true = False
        # Shift the x axis so that the window starts at zero. (The original
        # had a trailing comma here that turned this into a 1-tuple.)
        ending = t_full[0] * 1000
    else:
        ending = 0
        green_true = True
        # NOTE(review): the original indentation was lost; these two marker
        # lines are assumed to belong to the ``else`` branch - confirm.
        plt.axvline(x=-7.5, color='black', linestyle='dotted', linewidth=1)
        plt.axvline(x=7.5, color='black', linestyle='dotted', linewidth=1)
    print(colors[e])

    ax[e] = pl_eods(eod_fish_both, cut, maxima_interp, maxima_index,
                    maxima_values, lower, e, e, time_cut, am_corr, eod_fe,
                    eod_overlayed_chirp, deviation, d, eod_fr, sampling,
                    value_peaks, time_fish, am_fish, factor,
                    eod_convolved_down, index_peaks, eod_convolved_up,
                    eod_rectified_up, add=add, green_true=green_true,
                    beat_corr_col=beat_corr_col, ending=ending,
                    col_basic=col_basic, color_am=colors[e], df_col=df_col,
                    ax_nr=ax_nr, lw_whole=lw_whole)

    ax[e].spines['right'].set_visible(False)
    ax[e].spines['top'].set_visible(False)
    ax[e].spines['left'].set_visible(True)
    ax[e].spines['bottom'].set_visible(True)
    if col_hline != 'no':
        plt.axhline(y=0, color=col_hline, linewidth=0.5)

    # x label only on the middle panel of the last grid row, y label only
    # on the first panel of every row (panel numbers are 1-based: e + 1).
    last_row = np.arange(row * col - col + 1, row * col + 1, 1)
    if e + 1 == last_row[len(last_row) // 2] and (labels == True):
        ax[e].set_xlabel('Time [ms]', labelpad=5)
    if (e + 1 in np.arange(1, row * col + 1, col)) and (y == 'yes') \
            and (labels == True):
        ax[e].set_ylabel('[mv]', labelpad=5)
    # NOTE(review): assumed to be a sibling of the label checks above, not
    # nested inside them - confirm against the original layout.
    if (beat_corr_col != 'no') and (df_col != 'steelblue'):
        ax[e].set_yticks([])
    plt.subplots_adjust(wspace=0.2)
    return f_max, eod_overlayed_chirp, ax


def title_variation(add, ax, eod_fe, eod_fr, e):
    """Set the title of ``ax`` according to the ``add`` flag.

    ``add == True`` gives a long title (difference and both frequencies),
    ``'simple'`` only the difference, ``'no'`` leaves the title untouched,
    and any other value gives the difference plus the ratio
    ``eod_fe[e] / eod_fr`` truncated to two decimals.
    """
    if add == True:
        ax.title.set_text('DF:' + str(eod_fe[e] - eod_fr) + 'Hz ' + 'rf:'
                          + str(eod_fr) + ' ef:' + str(eod_fe[e]))
    elif add == 'simple':
        ax.title.set_text('Beat:' + str(eod_fe[e] - eod_fr) + 'Hz')
    elif add == 'no':
        pass
    else:
        ax.title.set_text(
            'Beat:' + str(eod_fe[e] - eod_fr) + 'Hz, Mult:'
            + str(int(((eod_fe[e] - eod_fr) / eod_fr + 1) * 100) / 100))


def pl_eods(eod_fish_both, cut, maxima_interp, maxima_index, maxima, gs0, i,
            e, time, am_corr, eod_fe, eod_overlayed_chirp, deviation, d,
            eod_fr, sampling, value_peaks, time_fish, am_fish, factor,
            eod_convolved_down, index_peaks, eod_convolved_up,
            eod_rectified_up, add=False, lw_red=1.2, lw=1, add1=False,
            share=False, green_true=True, beat_corr_col='orange',
            color_am='red', df_col='pink', ax_nr='no', col_basic='silver',
            ending=0, lw_whole=0.5):
    """Draw one stimulus panel and return its axes.

    The panel is ``plt.subplot(gs0[ax_nr])`` when ``ax_nr`` is not a
    string, otherwise ``plt.subplot(gs0[int(e)])``. Each of ``col_basic``,
    ``beat_corr_col``, ``color_am`` and ``df_col`` is a colour, or ``'no'``
    to skip the corresponding trace. The x axis is
    ``time * 1000 - ending``. ``am_corr`` and ``eod_fish_both`` are drawn
    with fixed vertical offsets (2.4 and 3.60) so they do not overlap the
    main trace.

    ``i``, ``share``, ``green_true`` and ``lw_red`` are accepted for
    interface compatibility but are not used here.
    """
    if not isinstance(ax_nr, str):
        ax = plt.subplot(gs0[ax_nr])
    else:
        ax = plt.subplot(gs0[int(e)])

    title_variation(add, ax, eod_fe, eod_fr, e)

    x_axis = time * 1000 - ending

    # main traces
    if col_basic != 'no':
        ax.plot(x_axis, eod_overlayed_chirp[cut:-cut], label='EOD both fish',
                color=col_basic, linewidth=lw_whole)
    if beat_corr_col != 'no':
        ax.plot(x_axis, am_corr + 2.4, color=beat_corr_col,
                label='EOD adjusted beat', linewidth=lw)
    if color_am != 'no':
        ax.plot(x_axis, maxima_interp[cut:-cut], color=color_am, label='AM',
                linewidth=lw)
    if df_col != 'no':
        ax.plot(x_axis, eod_fish_both[cut:-cut] + 3.60, color=df_col,
                label='Difference frequency', linewidth=0.6)

    # optional extra traces
    # NOTE(review): the extent of this branch was reconstructed - all three
    # calls are assumed to belong to ``add1``; confirm.
    if add1 == True:
        ax.scatter((maxima_index - 0.5 * len(eod_rectified_up))
                   / (sampling / 1000), maxima, color='red', s=10)
        # int() truncates towards zero, so one margin serves both ends.
        margin = int(3 * deviation[d] / factor)
        ax.plot(time_fish[margin:-margin] * 1000,
                am_fish[margin:-margin] + 0.4, color='purple',
                label='indirect am - downgesampled', linewidth=lw)
        ax.plot((index_peaks - 0.5 * len(eod_rectified_up))
                / (sampling / 1000), value_peaks, color='green',
                label='all maxima')
    if add == True:
        ax.plot(x_axis, eod_convolved_up, color='red', linewidth=lw)
        ax.plot(x_axis, eod_convolved_down, color='red', label='convolved',
                linewidth=lw)
    return ax


def find_times(left_c, right_c, sampling, deviation_s):
    """Build the time axis of a window plus a version trimmed at both ends.

    Integer sample numbers strictly between ``left_c`` and ``right_c`` are
    divided by ``sampling``. ``cut = ceil(5 * deviation_s * sampling)``
    samples (at least one) are removed from each end for ``time_cut``.

    Returns:
        ``(time, time_cut, cut)``.
    """
    for_conv = 5 * deviation_s
    # The +-1000 margin guarantees the candidates cover the whole window
    # before the strict comparison below trims them again.
    samples = np.arange(int(np.round(left_c)) - 1000,
                        int(np.round(right_c)) + 1000, 1)
    samples = samples[(samples > left_c) & (samples < right_c)]
    t_full = samples / sampling

    cut = int(np.ceil(for_conv * sampling))
    if cut == 0:
        # ``t_full[0:-0]`` would be empty, so always trim at least one.
        cut = 1
    time_cut = t_full[cut:-cut]
    return t_full, time_cut, cut


def conv(eod_fr, sampling, cut, deviation, eod_rectified_up,
         eod_rectified_down):
    """Smooth both rectified signals with a Gaussian kernel.

    The kernel has an odd number of points (``5 * deviation``, minus one if
    that is even), standard deviation ``deviation`` and is scaled so that
    it sums to 2. After the full convolution, half a kernel length plus
    ``cut`` samples are dropped from each end.

    Returns:
        ``(middle_conv, eod_convolved_down, eod_convolved_up,
        eod_conv_downsampled)`` where ``middle_conv`` is half the length of
        the smoothed upper signal and ``eod_conv_downsampled`` takes every
        ``round(sampling / eod_fr)``-th sample of it (last one excluded).
    """
    points = deviation * 5
    if not points % 2:
        points = points - 1

    # ``signal.gaussian`` is no longer available in recent SciPy releases;
    # the window functions live in ``scipy.signal.windows``.
    gaussian = signal.windows.gaussian(points, std=deviation, sym=True)
    gaussian_normalised = (gaussian * 2) / np.sum(gaussian)
    half_kernel = int(len(gaussian_normalised) / 2)

    def _smooth(rectified):
        full = np.convolve(gaussian_normalised, rectified)
        return full[half_kernel + cut:-half_kernel - cut]

    eod_convolved_up = _smooth(eod_rectified_up)
    eod_convolved_down = _smooth(eod_rectified_down)
    middle_conv = int(len(eod_convolved_up) / 2)
    eod_conv_downsampled = eod_convolved_up[
        0:-1:int(np.round(sampling / eod_fr))]
    return middle_conv, eod_convolved_down, eod_convolved_up, \
        eod_conv_downsampled


def find_dev(x, sampling):
    """Convert a sequence of deviations into three representations.

    Returns:
        deviation_ms: ``np.array(x)``.
        deviation_s: ``deviation_ms / 1000``.
        deviation_dp: ``sampling * deviation_s`` truncated to ints (list).
    """
    deviation_ms = np.array(x)
    deviation_s = deviation_ms / 1000
    deviation_dp = [int(v) for v in sampling * deviation_s]
    return deviation_ms, deviation_s, deviation_dp


def _phase_window(phase):
    """Return the phase values in ``(mean, mean + 2*pi]``."""
    centre = np.mean(phase)
    return phase[(phase <= centre + 2 * np.pi) & (phase > centre)]


def find_periods(a_fe, time, eod_fr, a_fr, eod_fe, e):
    """Create both sinusoids and one period worth of phase values each.

    Returns:
        eod_fish_e: ``a_fe * sin(2*pi*eod_fe[e]*time)``.
        eod_fish_r: ``a_fr * sin(2*pi*eod_fr*time)``.
        period_fish_r, period_fish_e: the phase values of the respective
            signal lying within ``2*pi`` above the mean phase.
    """
    time_fish_r = time * 2 * np.pi * eod_fr
    eod_fish_r = a_fr * np.sin(time_fish_r)
    period_fish_r = _phase_window(time_fish_r)

    time_fish_e = time * 2 * np.pi * eod_fe[e]
    # Fixed: the emitter signal was computed from the receiver phase
    # (``time_fish_r``), which made both signals share one frequency.
    eod_fish_e = a_fe * np.sin(time_fish_e)
    period_fish_e = _phase_window(time_fish_e)
    return eod_fish_e, eod_fish_r, period_fish_r, period_fish_e


def integrate_chirp(a_fe, time, beat, phase_zero, size, sigma):
    """Return a sinusoid whose phase includes an integrated Gaussian bump.

    The phase is ``2*pi*beat*time + 2*pi*size*sigma*I + phase_zero`` with
    ``I = sqrt(pi)/2 * (erf(time/sigma) - erf(-inf))``; the result is
    ``a_fe * sin(phase)``.
    """
    half_sqrt_pi = (np.pi ** 0.5) / 2
    gauss_integral = (half_sqrt_pi * sp.special.erf(time / sigma)
                      - half_sqrt_pi * sp.special.erf(-np.inf))
    phase = (time * 2 * np.pi * beat
             + 2 * np.pi * size * sigma * gauss_integral
             + phase_zero)
    return a_fe * np.sin(phase)


def rectify(eod_fish_r, eod_fe_chirp):
    """Split the sum of both signals into its negative and positive part.

    Returns:
        ``(eod_rectified_down, eod_rec_up)``: the sum with all positive
        values set to 0, and the sum with all negative values set to 0.
    """
    combined = eod_fish_r + eod_fe_chirp
    eod_rec_up = np.maximum(combined, 0)
    eod_rectified_down = np.minimum(combined, 0)
    return eod_rectified_down, eod_rec_up


def find_lm(eod_rec_up):
    """Find all local maxima and interpolate linearly between them.

    Returns:
        ``(index_peaks, value_peaks, peaks_interp)`` where ``peaks_interp``
        has one value per sample of ``eod_rec_up``.

    NOTE(review): ``np.interp`` presumably needs at least one peak -
    confirm that callers never pass a signal without local maxima.
    """
    index_peaks, _ = signal.find_peaks(eod_rec_up)
    value_peaks = eod_rec_up[index_peaks]
    peaks_interp = np.interp(np.arange(0, len(eod_rec_up), 1), index_peaks,
                             value_peaks)
    return index_peaks, value_peaks, peaks_interp


def global_maxima(period_fish_e, period_fish_r, eod_rectified_up):
    """Take the maximum of every period-long window and interpolate.

    The window length is ``len(period_fish_r)``; ``period_fish_e`` is not
    used. The last window is always discarded, because it is usually
    shorter than the others.

    If no complete window remains (window length 0, or not shorter than
    the signal), the overall maximum is used for every sample instead.

    Returns:
        ``(maxima_values, maxima_index, maxima_interp)``; ``maxima_interp``
        is an array with one value per input sample.
    """
    period_length = len(period_fish_r)
    n_samples = len(eod_rectified_up)

    if period_length < 1 or period_length >= n_samples:
        maxima_values = np.max(eod_rectified_up)
        maxima_index = np.argmax(eod_rectified_up)
        maxima_interp = np.full(n_samples, maxima_values)
    else:
        window_borders = np.arange(period_length, n_samples, period_length)
        windows = np.split(eod_rectified_up, window_borders)[0:-1]
        window_starts = np.arange(0, n_samples, period_length)[0:-1]
        maxima_values = np.max(windows, 1)
        # argmax is relative to each window; add the window start to get
        # indices into the full signal.
        maxima_index = np.argmax(windows, 1) + window_starts
        maxima_interp = np.interp(np.arange(0, n_samples, 1), maxima_index,
                                  maxima_values)
    return maxima_values, maxima_index, maxima_interp