highbeats_pdf/functionssimulation.py
2020-12-01 12:00:50 +01:00

296 lines
14 KiB
Python

import matplotlib.pyplot as plt
import numpy as np
from IPython import embed
import matplotlib as matplotlib
import math
import scipy.integrate as integrate
from scipy import signal
from scipy.interpolate import interp1d
from scipy.interpolate import CubicSpline
import scipy as sp
import pickle
from scipy.spatial import distance
from myfunctions import *
import time
from matplotlib import gridspec
#from matplotlib_scalebar.scalebar import ScaleBar
import matplotlib.mlab as ml
import scipy.integrate as si
import pandas as pd
def remove_all_spines(ax, nr):
    """Hide all four spines (right, top, left, bottom) of the axes ``ax[nr]``."""
    target = ax[nr]
    for side in ('right', 'top', 'left', 'bottom'):
        target.spines[side].set_visible(False)
def find_beats(start, end, step, eod_fr):
    """Build a range of emitter frequencies and the beats they form with ``eod_fr``.

    Returns
    -------
    eod_fe : ``np.arange(start, end, step)``, the emitter frequencies.
    beat_corr : ``eod_fe % eod_fr`` folded into ``[0, eod_fr / 2]`` (remainders above
        ``eod_fr / 2`` are mirrored to ``eod_fr - remainder``).
    beats : plain difference ``eod_fe - eod_fr``.
    """
    eod_fe = np.arange(start, end, step)
    remainder = eod_fe % eod_fr
    beat_corr = np.where(remainder > eod_fr / 2, eod_fr - remainder, remainder)
    return eod_fe, beat_corr, eod_fe - eod_fr
def snip(left_c, right_c, e, g, sampling, deviation_s, d, eod_fr, a_fr, eod_fe, phase_zero, p, size, s, sigma,
         a_fe, deviation, beat_corr, chirp=True):
    """Simulate time window ``g`` of a two-fish stimulus and derive its envelope traces.

    The window is ``left_c[g]`` .. ``right_c[g]`` (sample units, converted by ``find_times``).
    When ``chirp == True`` the emitter EOD ``eod_fe[e]`` carries a Gaussian frequency excursion
    (``integrate_chirp`` with ``size[s]``, ``sigma``, ``phase_zero[p]``); otherwise the plain
    emitter wave from ``find_periods`` is used.

    Returns
    -------
    tuple
        ``(cut, threshold_cube, time_cut, eod_conv_up, am_corr_full, peaks_interp, maxima_interp,
        am_corr_ds, am_df_ds, eod_fish_both, eod_overlayed_chirp)`` where ``threshold_cube`` is the
        cubed rectified sum, ``eod_conv_up`` the Gaussian-smoothed rectified sum, ``peaks_interp`` /
        ``maxima_interp`` the local- / per-period-maxima envelopes of the trimmed rectified sum,
        ``am_*`` beat waveforms at the full and at the ``eod_fr`` sampling rate, and
        ``eod_overlayed_chirp`` the trimmed receiver + emitter signal.
    """
    start, stop = left_c[g], right_c[g]
    phi0, chirp_size = phase_zero[p], size[s]
    difference = eod_fe[e] - eod_fr

    time, time_cut, cut = find_times(start, stop, sampling, deviation_s[d])
    plain_emitter, receiver, period_r, period_e = find_periods(a_fe, time, eod_fr, a_fr, eod_fe, e)
    if chirp == True:
        emitter = integrate_chirp(a_fe, time, eod_fe[e], phi0, chirp_size, sigma)
    else:
        emitter = plain_emitter

    rect_down, rect_up = rectify(receiver, emitter)
    overlay = (receiver + emitter)[cut:-cut]
    threshold_cube = rect_up ** 3
    _, _, maxima_interp = global_maxima(period_e, period_r, rect_up[cut:-cut])
    _, _, peaks_interp = find_lm(rect_up[cut:-cut])
    _, _, conv_up, _ = conv(eod_fr, sampling, cut, deviation[d], rect_up, rect_down)

    fish_both = integrate_chirp(a_fe, time, difference, phi0, chirp_size, sigma)
    # Indirect AM: the beat itself simulated as a sine at the folded beat frequency.
    am_corr_full = integrate_chirp(a_fe, time_cut, beat_corr[e], phi0, chirp_size, sigma)
    # Same window re-sampled at eod_fr (one sample per receiver EOD cycle).
    _, time_fish, _ = find_times(start, stop, eod_fr, deviation_s[d])
    am_corr_ds = integrate_chirp(a_fe, time_fish, beat_corr[e], phi0, chirp_size, sigma)
    am_df_ds = integrate_chirp(a_fe, time_fish, difference, phi0, chirp_size, sigma)
    return (cut, threshold_cube, time_cut, conv_up, am_corr_full, peaks_interp, maxima_interp,
            am_corr_ds, am_df_ds, fish_both, overlay)
def single_stim(ax, colors, row, col, eod_fr, eod_fe, e, lower, s=0, p=0, d=0, labels=True, col_basic='silver',
                add='simple', df_col='blue', factor=200, beat_corr_col='gold', col_hline='no', nfft=4096,
                minus_bef=-30, delta_t=0.014, sampling=100000, deviation=(150,), plus_bef=-10, a_fr=1,
                phase_zero=(0,), shift_phase=0, size=(120,), a_fe=0.8, ax_nr='no', lw_whole=0.5, y='yes'):
    """Simulate the stimulus for emitter frequency ``eod_fe[e]``, plot it into ``ax[e]`` and return
    the dominant frequency of its amplitude modulation.

    A receiver EOD (``eod_fr``, amplitude ``a_fr``) is summed with an emitter EOD (``eod_fe[e]``,
    amplitude ``a_fe``) carrying a Gaussian frequency excursion of height ``size[s]`` centred at
    time 0 (see ``integrate_chirp``). The simulated window runs from ``minus_bef * delta_t`` to
    ``plus_bef * delta_t`` seconds. The AM is estimated from per-period maxima of the rectified sum
    (``global_maxima``); ``f_max`` is the peak of its power spectrum below ``eod_fr / 2``.

    Parameters (selection)
    ----------------------
    ax : mutable indexable container; ``ax[e]`` is replaced by the axes created by ``pl_eods``.
    colors : ``colors[e]`` is the colour of the AM trace.
    row, col : panel-grid shape, used only to decide which panel receives axis labels.
    eod_fe : array of emitter frequencies; ``e`` selects the one simulated here.
    lower : grid-spec-like object indexed by ``pl_eods`` to create the subplot.
    s, p, d : indices into ``size``, ``phase_zero`` and ``deviation``.
    nfft : FFT length for the AM power spectrum (50 % overlap).

    Returns
    -------
    f_max : frequency of the AM power-spectrum peak below ``eod_fr / 2``.
    eod_overlayed_chirp : full-length receiver + emitter signal.
    ax : the (modified) container passed in.
    """
    # Beat frequency folded into [0, eod_fr / 2].
    beat_corr = eod_fe % eod_fr
    beat_corr[beat_corr > eod_fr / 2] = eod_fr - beat_corr[beat_corr > eod_fr / 2]
    sigma = delta_t / math.sqrt(2 * math.log(10))

    # Window edges in samples; find_times converts them to seconds.
    left_c = minus_bef * delta_t * sampling
    right_c = plus_bef * delta_t * sampling
    time, time_cut, cut = find_times(left_c, right_c, sampling, deviation[d] / (1000 * sampling))

    _, eod_fish_r, period_fish_r, period_fish_e = find_periods(a_fe, time, eod_fr, a_fr, eod_fe, e)
    eod_fish_both = integrate_chirp(a_fe, time, eod_fe[e] - eod_fr, phase_zero[p] + shift_phase, size[s], sigma)
    eod_fe_chirp = integrate_chirp(a_fe, time, eod_fe[e], phase_zero[p], size[s], sigma)
    eod_overlayed_chirp = eod_fish_r + eod_fe_chirp

    eod_rectified_down, eod_rectified_up = rectify(eod_fish_r, eod_fe_chirp)
    _, maxima_index, maxima_interp = global_maxima(period_fish_e, period_fish_r, eod_rectified_up)
    index_peaks, value_peaks, _ = find_lm(eod_rectified_up)
    try:
        _, eod_convolved_down, eod_convolved_up, _ = conv(eod_fr, sampling, cut, deviation[d],
                                                          eod_rectified_up, eod_rectified_down)
    except Exception:
        # Best effort: plot without the convolved envelopes if smoothing fails for these parameters.
        eod_convolved_down = []
        eod_convolved_up = []

    # Indirect AM over a much wider window, sampled at eod_fr (one sample per receiver EOD cycle).
    wide_left = -200 * delta_t * sampling
    wide_right = 200 * delta_t * sampling
    _, time_fish, _ = find_times(wide_left, wide_right, eod_fr, deviation[d])
    am_fish = integrate_chirp(a_fe, time_fish, beat_corr[e], phase_zero[p], size[s], sigma)
    am_corr = integrate_chirp(a_fe, time_cut, beat_corr[e], phase_zero[p] + shift_phase, size[s], sigma)

    am = np.asarray(maxima_interp)
    # noverlap must be an integer: a float slice step raises inside mlab's spectral helper.
    power, freq = ml.psd(am - np.mean(am), Fs=sampling, NFFT=nfft, noverlap=nfft // 2)
    # freq is ascending from 0, so indexing the masked prefix maps back onto freq directly.
    f_max = freq[np.argmax(power[freq < 0.5 * eod_fr])]

    # NOTE(review): indentation was lost in this file; placing the axvlines inside the else
    # branch is inferred. They draw on the current axes, before pl_eods creates ax[e] — confirm.
    if plus_bef < 0:
        green_true = False
        ending = time[0] * 1000  # shift the time axis (no trailing comma: must be a scalar)
    else:
        ending = 0
        green_true = True
        plt.axvline(x=-7.5, color='black', linestyle='dotted', linewidth=1)
        plt.axvline(x=7.5, color='black', linestyle='dotted', linewidth=1)

    ax[e] = pl_eods(eod_fish_both, cut, maxima_interp, maxima_index,
                    maxima_interp, lower, e, e, time_cut, am_corr, eod_fe, eod_overlayed_chirp, deviation, d,
                    eod_fr, sampling, value_peaks, time_fish, am_fish, factor, eod_convolved_down, index_peaks,
                    eod_convolved_up, eod_rectified_up, add=add, green_true=green_true,
                    beat_corr_col=beat_corr_col, ending=ending, col_basic=col_basic, color_am=colors[e],
                    df_col=df_col, ax_nr=ax_nr, lw_whole=lw_whole)
    panel = ax[e]
    panel.spines['right'].set_visible(False)
    panel.spines['top'].set_visible(False)
    panel.spines['left'].set_visible(True)
    panel.spines['bottom'].set_visible(True)
    if col_hline != 'no':
        plt.axhline(y=0, color=col_hline, linewidth=0.5)

    # x label on the middle panel of the bottom row, y label on the left column.
    bottom_row = np.arange(row * col - col + 1, row * col + 1, 1)
    if e + 1 == bottom_row[int(len(bottom_row) / 2)] and labels == True:
        panel.set_xlabel('Time [ms]', labelpad=5)
    if (e + 1 in np.arange(1, row * col + 1, col)) and (y == 'yes') and (labels == True):
        panel.set_ylabel('[mv]', labelpad=5)
    if (beat_corr_col != 'no') and (df_col != 'steelblue'):
        panel.set_yticks([])
    plt.subplots_adjust(wspace=0.2)
    return f_max, eod_overlayed_chirp, ax
def title_variation(add, ax, eod_fe, eod_fr, e):
    """Set the title of ``ax`` for stimulus ``e`` according to the ``add`` mode.

    * ``add == True``  -> difference, receiver and emitter frequency
    * ``'simple'``     -> beat (difference frequency) only
    * ``'no'``         -> title left untouched
    * anything else    -> beat plus ``eod_fe[e] / eod_fr`` truncated to two decimals
    """
    if add == True:
        difference = eod_fe[e] - eod_fr
        label = 'DF:' + str(difference) + 'Hz rf:' + str(eod_fr) + ' ef:' + str(eod_fe[e])
    elif add == 'simple':
        difference = eod_fe[e] - eod_fr
        label = 'Beat:' + str(difference) + 'Hz'
    elif add == 'no':
        return
    else:
        difference = eod_fe[e] - eod_fr
        multiple = int((difference / eod_fr + 1) * 100) / 100
        label = 'Beat:' + str(difference) + 'Hz, Mult:' + str(multiple)
    ax.title.set_text(label)
def pl_eods(eod_fish_both, cut, maxima_interp, maxima_index, maxima, gs0, i, e, time, am_corr, eod_fe,
            eod_overlayed_chirp, deviation, d, eod_fr, sampling, value_peaks, time_fish, am_fish, factor,
            eod_convolved_down, index_peaks, eod_convolved_up, eod_rectified_up, add=False, lw_red=1.2, lw=1,
            add1=False, share=False, green_true=True, beat_corr_col='orange', color_am='red', df_col='pink',
            ax_nr='no', col_basic='silver', ending=0, lw_whole=0.5):
    """Create one subplot in ``gs0`` and draw the traces of a single stimulus into it.

    The slot is ``gs0[int(e)]`` when ``ax_nr`` is exactly a ``str`` (default ``'no'``), otherwise
    ``gs0[ax_nr]``. Each main trace is drawn only when its colour argument is not ``'no'``:

    * ``col_basic``     -- both EODs overlaid, ``eod_overlayed_chirp[cut:-cut]``
    * ``beat_corr_col`` -- ``am_corr`` shifted up by 2.4
    * ``color_am``      -- AM envelope ``maxima_interp[cut:-cut]``
    * ``df_col``        -- difference-frequency wave ``eod_fish_both[cut:-cut]`` shifted up by 3.6

    The x values are ``time * 1000 - ending`` (presumably milliseconds — confirm). ``add1 == True``
    adds diagnostic maxima/peak markers and the downsampled AM; ``add == True`` adds the convolved
    envelopes. ``i``, ``lw_red``, ``share`` and ``green_true`` are accepted but unused.

    Returns the created axes.
    """
    if type(ax_nr) is str:
        ax = plt.subplot(gs0[int(e)])
    else:
        ax = plt.subplot(gs0[ax_nr])
    title_variation(add, ax, eod_fe, eod_fr, e)

    # Main traces.
    if col_basic != 'no':
        ax.plot(time * 1000 - ending, eod_overlayed_chirp[cut:-cut],
                color=col_basic, linewidth=lw_whole, label='EOD both fish')
    if beat_corr_col != 'no':
        ax.plot(time * 1000 - ending, am_corr + 2.4,
                color=beat_corr_col, linewidth=lw, label='EOD adjusted beat')
    if color_am != 'no':
        ax.plot(time * 1000 - ending, maxima_interp[cut:-cut],
                color=color_am, linewidth=lw, label='AM')
    if df_col != 'no':
        ax.plot(time * 1000 - ending, eod_fish_both[cut:-cut] + 3.60,
                color=df_col, linewidth=0.6, label='Difference frequency')

    # Optional diagnostic traces.
    # NOTE(review): indentation was lost; `add == True` is assumed independent of `add1`.
    if add1 == True:
        centre = 0.5 * len(eod_rectified_up)
        samples_per_ms = sampling / 1000
        ax.scatter((maxima_index - centre) / samples_per_ms, maxima, color='red', s=10)
        trim = int(3 * deviation[d] / factor)
        ax.plot(time_fish[trim:-trim] * 1000, am_fish[trim:-trim] + 0.4,
                color='purple', linewidth=lw, label='indirect am - downgesampled')
        ax.plot((index_peaks - centre) / samples_per_ms, value_peaks, color='green', label='all maxima')
    if add == True:
        ax.plot(time * 1000 - ending, eod_convolved_up, color='red', linewidth=lw)
        ax.plot(time * 1000 - ending, eod_convolved_down, color='red', linewidth=lw, label='convolved')
    return ax
def find_times(left_c, right_c, sampling, deviation_s):
    """Build a time axis between two sample positions and its edge-trimmed version.

    Parameters
    ----------
    left_c, right_c : window edges in samples (exclusive); every integer sample strictly inside
        is kept.
    sampling : samples per second used to convert sample indices to seconds.
    deviation_s : width in seconds; ``5 * deviation_s`` worth of samples is trimmed from each end.

    Returns
    -------
    time : sample indices divided by ``sampling``.
    time_cut : ``time[cut:-cut]``.
    cut : number of trimmed samples per side, at least 1 when it would otherwise be 0
        (a ``[0:-0]`` slice would be empty).
    """
    first = int(np.floor(left_c)) + 1   # smallest integer > left_c
    stop = int(np.ceil(right_c))        # one past the largest integer < right_c
    time = np.arange(first, stop, 1) / sampling
    cut = int(np.ceil(5 * deviation_s * sampling))
    if cut == 0:
        cut = 1
    return time, time[cut:-cut], cut
def conv(eod_fr, sampling, cut, deviation, eod_rectified_up, eod_rectified_down):
    """Smooth both rectified signals with a Gaussian kernel and trim the edges.

    Parameters
    ----------
    eod_fr : receiver EOD frequency; sets the downsampling step ``round(sampling / eod_fr)``.
    sampling : sampling rate of the input signals.
    cut : samples removed from each end after discarding the convolution overhang.
    deviation : standard deviation of the Gaussian kernel in samples. The kernel spans about
        ``5 * deviation`` samples, rounded to an integer and forced to an odd length.
    eod_rectified_up, eod_rectified_down : the rectified signals to smooth.

    Returns
    -------
    middle_conv : index of the middle sample of the smoothed signal.
    eod_convolved_down, eod_convolved_up : smoothed signals, ``len(input) - 2 * cut`` long.
    eod_conv_downsampled : ``eod_convolved_up`` taken every ``round(sampling / eod_fr)`` samples,
        excluding the final sample.
    """
    points = int(round(deviation * 5))
    if points % 2 == 0:
        points -= 1  # odd length -> symmetric kernel with a central sample
    # scipy.signal.gaussian was removed in SciPy 1.13; the windows submodule hosts it now.
    gaussian = signal.windows.gaussian(points, std=deviation, sym=True)
    # NOTE(review): the kernel sums to 2, not 1 — presumably compensating for half-wave
    # rectification; confirm.
    gaussian_normalised = (gaussian * 2) / np.sum(gaussian)
    half = len(gaussian_normalised) // 2
    trim = slice(half + cut, -half - cut)  # drop the 'full'-mode overhang plus `cut` per side
    eod_convolved_up = np.convolve(gaussian_normalised, eod_rectified_up)[trim]
    eod_convolved_down = np.convolve(gaussian_normalised, eod_rectified_down)[trim]
    middle_conv = len(eod_convolved_up) // 2
    eod_conv_downsampled = eod_convolved_up[0:-1:int(np.round(sampling / eod_fr))]
    return middle_conv, eod_convolved_down, eod_convolved_up, eod_conv_downsampled
def find_dev(x, sampling):
    """Express widths given in milliseconds in seconds and in whole data points.

    Returns
    -------
    deviation_ms : ``np.array(x)``.
    deviation_s : the same values divided by 1000.
    deviation_dp : list of ``int(sampling * deviation_s)`` per entry (truncated toward zero).
    """
    deviation_ms = np.array(x)
    deviation_s = deviation_ms / 1000
    deviation_dp = [int(value) for value in sampling * deviation_s]
    return deviation_ms, deviation_s, deviation_dp
def find_periods(a_fe, time, eod_fr, a_fr, eod_fe, e):
    """Build receiver and emitter sine EODs and one cycle of phase samples for each.

    Parameters
    ----------
    a_fe, a_fr : emitter and receiver amplitudes.
    time : array of time points (multiplied by 2*pi*frequency to get phases).
    eod_fr : receiver frequency.
    eod_fe, e : emitter frequencies and the index of the one to use.

    Returns
    -------
    eod_fish_e : ``a_fe * sin(2*pi*eod_fe[e]*time)``.
    eod_fish_r : ``a_fr * sin(2*pi*eod_fr*time)``.
    period_fish_r, period_fish_e : phase values in ``(mean, mean + 2*pi]`` of the respective
        phase array, i.e. the samples of one cycle; their length is the samples-per-cycle count.
    """
    time_fish_r = time * 2 * np.pi * eod_fr
    eod_fish_r = a_fr * np.sin(time_fish_r)
    centre_r = np.mean(time_fish_r)
    period_fish_r = time_fish_r[(time_fish_r <= centre_r + 2 * np.pi) & (time_fish_r > centre_r)]

    time_fish_e = time * 2 * np.pi * eod_fe[e]
    # Emitter wave must use the emitter phase (previously used the receiver phase by mistake).
    eod_fish_e = a_fe * np.sin(time_fish_e)
    centre_e = np.mean(time_fish_e)
    period_fish_e = time_fish_e[(time_fish_e <= centre_e + 2 * np.pi) & (time_fish_e > centre_e)]
    return eod_fish_e, eod_fish_r, period_fish_r, period_fish_e
def integrate_chirp(a_fe, time, beat, phase_zero, size, sigma):
    """Sine of amplitude ``a_fe`` at frequency ``beat`` with a Gaussian frequency excursion.

    The excursion ``size * exp(-(t / sigma)**2)`` is centred at ``t = 0``. Its integral from
    ``-inf`` to ``t`` is ``size * sigma * sqrt(pi)/2 * (erf(t/sigma) - erf(-inf))``, which is
    added (times 2*pi) to the phase of the carrier ``2*pi*beat*t + phase_zero``.
    """
    half_sqrt_pi = (np.pi ** 0.5) / 2
    integral = half_sqrt_pi * sp.special.erf(time / sigma) - half_sqrt_pi * sp.special.erf(-np.inf)
    carrier = time * 2 * np.pi * beat
    excursion = 2 * np.pi * size * sigma * integral
    return a_fe * np.sin(carrier + excursion + phase_zero)
def rectify(eod_fish_r, eod_fe_chirp):
    """Half-wave rectify the summed signal in both directions.

    Returns ``(down, up)``: ``down`` is the sum with positive samples set to 0, ``up`` the sum
    with negative samples set to 0. The inputs are not modified.
    """
    total = eod_fish_r + eod_fe_chirp
    upper = total.copy()
    upper[upper < 0] = 0
    lower = total.copy()
    lower[lower > 0] = 0
    return lower, upper
def find_lm(eod_rec_up):
    """Find the local maxima of ``eod_rec_up`` and linearly interpolate between them.

    Returns the peak indices (``scipy.signal.find_peaks``), the signal values at those indices,
    and the interpolated envelope over every sample index (held constant beyond the first and
    last peak).
    """
    peak_idx, _ = signal.find_peaks(eod_rec_up)
    peak_vals = eod_rec_up[peak_idx]
    envelope = np.interp(np.arange(len(eod_rec_up)), peak_idx, peak_vals)
    return peak_idx, peak_vals, envelope
def global_maxima(period_fish_e, period_fish_r, eod_rectified_up):
    """Take the maximum of consecutive one-period windows and interpolate between the maxima.

    Parameters
    ----------
    period_fish_e : unused; kept for call compatibility.
    period_fish_r : its length is used as the window length in samples.
    eod_rectified_up : 1-D signal to scan.

    Returns
    -------
    maxima_values : maximum of each complete window (a scalar in the single-window fallback).
    maxima_index : sample index of each maximum (a scalar in the single-window fallback).
    maxima_interp : ndarray of ``len(eod_rectified_up)`` with the maxima linearly interpolated
        over all sample indices (held constant outside the first/last maximum).

    When the window is longer than the signal, or its length is 0, the global maximum of the
    whole signal is used instead. Raises ``ValueError`` for an empty signal.
    """
    rectified = np.asarray(eod_rectified_up)
    n_samples = len(rectified)
    period_length = len(period_fish_r)
    if period_length == 0 or period_length > n_samples:
        maxima_values = np.max(rectified)
        maxima_index = np.argmax(rectified)
        maxima_interp = np.full(n_samples, maxima_values)
        return maxima_values, maxima_index, maxima_interp

    # Use every complete window, including the last one when n_samples is an exact multiple
    # of period_length; a trailing partial window is ignored.
    n_windows = n_samples // period_length
    windows = rectified[:n_windows * period_length].reshape(n_windows, period_length)
    maxima_values = np.max(windows, axis=1)
    maxima_index = np.argmax(windows, axis=1) + np.arange(n_windows) * period_length
    maxima_interp = np.interp(np.arange(n_samples), maxima_index, maxima_values)
    return maxima_values, maxima_index, maxima_interp