Merge branch 'master' of github.com:weygoldt/grid-chirpdetection

This commit is contained in:
weygoldt
2023-04-14 21:02:11 +02:00
28 changed files with 2361 additions and 747 deletions

View File

@@ -10,73 +10,84 @@ from modules.filters import bandpass_filter
def main(folder):
file = os.path.join(folder, 'traces-grid.raw')
file = os.path.join(folder, "traces-grid.raw")
data = open_data(folder, 60.0, 0, channel=-1)
time = np.load(folder + 'times.npy', allow_pickle=True)
freq = np.load(folder + 'fund_v.npy', allow_pickle=True)
ident = np.load(folder + 'ident_v.npy', allow_pickle=True)
idx = np.load(folder + 'idx_v.npy', allow_pickle=True)
time = np.load(folder + "times.npy", allow_pickle=True)
freq = np.load(folder + "fund_v.npy", allow_pickle=True)
ident = np.load(folder + "ident_v.npy", allow_pickle=True)
idx = np.load(folder + "idx_v.npy", allow_pickle=True)
t0 = 3*60*60 + 6*60 + 43.5
t0 = 3 * 60 * 60 + 6 * 60 + 43.5
dt = 60
data_oi = data[t0 * data.samplerate: (t0+dt)*data.samplerate, :]
data_oi = data[t0 * data.samplerate : (t0 + dt) * data.samplerate, :]
for i in [10]:
# getting the spectogramm
spec_power, spec_freqs, spec_times = spectrogram(
data_oi[:, i], ratetime=data.samplerate, freq_resolution=50, overlap_frac=0.0)
fig, ax = plt.subplots(figsize=(20/2.54, 12/2.54))
ax.pcolormesh(spec_times, spec_freqs, decibel(
spec_power), vmin=-100, vmax=-50)
data_oi[:, i],
ratetime=data.samplerate,
freq_resolution=50,
overlap_frac=0.0,
)
fig, ax = plt.subplots(figsize=(20 / 2.54, 12 / 2.54))
ax.pcolormesh(
spec_times, spec_freqs, decibel(spec_power), vmin=-100, vmax=-50
)
for track_id in np.unique(ident):
# window_index for time array in time window
window_index = np.arange(len(idx))[(ident == track_id) &
(time[idx] >= t0) &
(time[idx] <= (t0+dt))]
window_index = np.arange(len(idx))[
(ident == track_id)
& (time[idx] >= t0)
& (time[idx] <= (t0 + dt))
]
freq_temp = freq[window_index]
time_temp = time[idx[window_index]]
#mean_freq = np.mean(freq_temp)
#fdata = bandpass_filter(data_oi[:, track_id], data.samplerate, mean_freq-5, mean_freq+200)
# mean_freq = np.mean(freq_temp)
# fdata = bandpass_filter(data_oi[:, track_id], data.samplerate, mean_freq-5, mean_freq+200)
ax.plot(time_temp - t0, freq_temp)
ax.set_ylim(500, 1000)
plt.show()
# filter plot
id = 10.
id = 10.0
i = 10
window_index = np.arange(len(idx))[(ident == id) &
(time[idx] >= t0) &
(time[idx] <= (t0+dt))]
window_index = np.arange(len(idx))[
(ident == id) & (time[idx] >= t0) & (time[idx] <= (t0 + dt))
]
freq_temp = freq[window_index]
time_temp = time[idx[window_index]]
mean_freq = np.mean(freq_temp)
fdata = bandpass_filter(
data_oi[:, i], rate=data.samplerate, lowf=mean_freq-5, highf=mean_freq+200)
data_oi[:, i],
rate=data.samplerate,
lowf=mean_freq - 5,
highf=mean_freq + 200,
)
fig, ax = plt.subplots()
ax.plot(np.arange(len(fdata))/data.samplerate, fdata, marker='*')
ax.plot(np.arange(len(fdata)) / data.samplerate, fdata, marker="*")
# plt.show()
# freqency analyis of filtered data
time_fdata = np.arange(len(fdata))/data.samplerate
time_fdata = np.arange(len(fdata)) / data.samplerate
roll_fdata = np.roll(fdata, shift=1)
period_index = np.arange(len(fdata))[(roll_fdata < 0) & (fdata >= 0)]
plt.plot(time_fdata, fdata)
plt.scatter(time_fdata[period_index], fdata[period_index], c='r')
plt.scatter(time_fdata[period_index-1], fdata[period_index-1], c='r')
plt.scatter(time_fdata[period_index], fdata[period_index], c="r")
plt.scatter(time_fdata[period_index - 1], fdata[period_index - 1], c="r")
upper_bound = np.abs(fdata[period_index])
lower_bound = np.abs(fdata[period_index-1])
lower_bound = np.abs(fdata[period_index - 1])
upper_times = np.abs(time_fdata[period_index])
lower_times = np.abs(time_fdata[period_index-1])
lower_times = np.abs(time_fdata[period_index - 1])
lower_ratio = lower_bound/(lower_bound+upper_bound)
upper_ratio = upper_bound/(lower_bound+upper_bound)
lower_ratio = lower_bound / (lower_bound + upper_bound)
upper_ratio = upper_bound / (lower_bound + upper_bound)
time_delta = upper_times-lower_times
true_zero = lower_times + time_delta*lower_ratio
time_delta = upper_times - lower_times
true_zero = lower_times + time_delta * lower_ratio
plt.scatter(true_zero, np.zeros(len(true_zero)))
@@ -84,7 +95,7 @@ def main(folder):
inst_freq = 1 / np.diff(true_zero)
filtered_inst_freq = gaussian_filter1d(inst_freq, 0.005)
fig, ax = plt.subplots()
ax.plot(filtered_inst_freq, marker='.')
ax.plot(filtered_inst_freq, marker=".")
# in 5 sekunden welcher fisch auf einer elektrode am
embed()
@@ -99,5 +110,7 @@ def main(folder):
pass
if __name__ == '__main__':
main('/Users/acfw/Documents/uni_tuebingen/chirpdetection/gp_benda/data/2022-06-02-10_00/')
if __name__ == "__main__":
main(
"/Users/acfw/Documents/uni_tuebingen/chirpdetection/gp_benda/data/2022-06-02-10_00/"
)

View File

@@ -12,25 +12,27 @@ from modules.filehandling import LoadData
def main(folder):
data = LoadData(folder)
t0 = 3*60*60 + 6*60 + 43.5
t0 = 3 * 60 * 60 + 6 * 60 + 43.5
dt = 60
data_oi = data.raw[t0 * data.raw_rate: (t0+dt)*data.raw_rate, :]
# good electrode
electrode = 10
data_oi = data.raw[t0 * data.raw_rate : (t0 + dt) * data.raw_rate, :]
# good electrode
electrode = 10
data_oi = data_oi[:, electrode]
fig, axs = plt.subplots(2,1)
axs[0].plot( np.arange(data_oi.shape[0]) / data.raw_rate, data_oi)
fig, axs = plt.subplots(2, 1)
axs[0].plot(np.arange(data_oi.shape[0]) / data.raw_rate, data_oi)
for tr, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])):
rack_window_index = np.arange(len(data.idx))[
(data.ident == track_id) &
(data.time[data.idx] >= t0) &
(data.time[data.idx] <= (t0+dt))]
(data.ident == track_id)
& (data.time[data.idx] >= t0)
& (data.time[data.idx] <= (t0 + dt))
]
freq_fish = data.freq[rack_window_index]
axs[1].plot(np.arange(freq_fish.shape[0]) / data.raw_rate, freq_fish)
plt.show()
if __name__ == '__main__':
main('/Users/acfw/Documents/uni_tuebingen/chirpdetection/GP2023_chirp_detection/data/2022-06-02-10_00/')
if __name__ == "__main__":
main(
"/Users/acfw/Documents/uni_tuebingen/chirpdetection/GP2023_chirp_detection/data/2022-06-02-10_00/"
)

View File

@@ -1,8 +1,8 @@
import os
import os
import os
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
from IPython import embed
from pandas import read_csv
@@ -11,51 +11,65 @@ from scipy.ndimage import gaussian_filter1d
logger = makeLogger(__name__)
class Behavior:
"""Load behavior data from csv file as class attributes
Attributes
----------
behavior: 0: chasing onset, 1: chasing offset, 2: physical contact
behavior_type:
behavioral_category:
comment_start:
comment_stop:
dataframe: pandas dataframe with all the data
duration_s:
media_file:
observation_date:
observation_id:
start_s: start time of the event in seconds
stop_s: stop time of the event in seconds
total_length:
behavior_type:
behavioral_category:
comment_start:
comment_stop:
dataframe: pandas dataframe with all the data
duration_s:
media_file:
observation_date:
observation_id:
start_s: start time of the event in seconds
stop_s: stop time of the event in seconds
total_length:
"""
def __init__(self, folder_path: str) -> None:
LED_on_time_BORIS = np.load(os.path.join(folder_path, 'LED_on_time.npy'), allow_pickle=True)
self.time = np.load(os.path.join(folder_path, "times.npy"), allow_pickle=True)
csv_filename = [f for f in os.listdir(folder_path) if f.endswith('.csv')][0] # check if there are more than one csv file
LED_on_time_BORIS = np.load(
os.path.join(folder_path, "LED_on_time.npy"), allow_pickle=True
)
self.time = np.load(
os.path.join(folder_path, "times.npy"), allow_pickle=True
)
csv_filename = [
f for f in os.listdir(folder_path) if f.endswith(".csv")
][
0
] # check if there are more than one csv file
self.dataframe = read_csv(os.path.join(folder_path, csv_filename))
self.chirps = np.load(os.path.join(folder_path, 'chirps.npy'), allow_pickle=True)
self.chirps_ids = np.load(os.path.join(folder_path, 'chirps_ids.npy'), allow_pickle=True)
self.chirps = np.load(
os.path.join(folder_path, "chirps.npy"), allow_pickle=True
)
self.chirps_ids = np.load(
os.path.join(folder_path, "chirps_ids.npy"), allow_pickle=True
)
for k, key in enumerate(self.dataframe.keys()):
key = key.lower()
if ' ' in key:
key = key.replace(' ', '_')
if '(' in key:
key = key.replace('(', '')
key = key.replace(')', '')
setattr(self, key, np.array(self.dataframe[self.dataframe.keys()[k]]))
key = key.lower()
if " " in key:
key = key.replace(" ", "_")
if "(" in key:
key = key.replace("(", "")
key = key.replace(")", "")
setattr(
self, key, np.array(self.dataframe[self.dataframe.keys()[k]])
)
last_LED_t_BORIS = LED_on_time_BORIS[-1]
real_time_range = self.time[-1] - self.time[0]
factor = 1.034141
shift = last_LED_t_BORIS - real_time_range * factor
self.start_s = (self.start_s - shift) / factor
self.stop_s = (self.stop_s - shift) / factor
"""
1 - chasing onset
2 - chasing offset
@@ -83,77 +97,77 @@ temporal encpding needs to be corrected ... not exactly 25FPS.
behavior = data['Behavior']
"""
def correct_chasing_events(
category: np.ndarray,
timestamps: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
onset_ids = np.arange(
len(category))[category == 0]
offset_ids = np.arange(
len(category))[category == 1]
def correct_chasing_events(
category: np.ndarray, timestamps: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
onset_ids = np.arange(len(category))[category == 0]
offset_ids = np.arange(len(category))[category == 1]
# Check whether on- or offset is longer and calculate length difference
if len(onset_ids) > len(offset_ids):
len_diff = len(onset_ids) - len(offset_ids)
longer_array = onset_ids
shorter_array = offset_ids
logger.info(f'Onsets are greater than offsets by {len_diff}')
logger.info(f"Onsets are greater than offsets by {len_diff}")
elif len(onset_ids) < len(offset_ids):
len_diff = len(offset_ids) - len(onset_ids)
longer_array = offset_ids
shorter_array = onset_ids
logger.info(f'Offsets are greater than offsets by {len_diff}')
logger.info(f"Offsets are greater than offsets by {len_diff}")
elif len(onset_ids) == len(offset_ids):
logger.info('Chasing events are equal')
logger.info("Chasing events are equal")
return category, timestamps
# Correct the wrong chasing events; delete double events
wrong_ids = []
for i in range(len(longer_array)-(len_diff+1)):
if (shorter_array[i] > longer_array[i]) & (shorter_array[i] < longer_array[i+1]):
for i in range(len(longer_array) - (len_diff + 1)):
if (shorter_array[i] > longer_array[i]) & (
shorter_array[i] < longer_array[i + 1]
):
pass
else:
wrong_ids.append(longer_array[i])
longer_array = np.delete(longer_array, i)
category = np.delete(
category, wrong_ids)
timestamps = np.delete(
timestamps, wrong_ids)
category = np.delete(category, wrong_ids)
timestamps = np.delete(timestamps, wrong_ids)
return category, timestamps
def event_triggered_chirps(
event: np.ndarray,
chirps:np.ndarray,
event: np.ndarray,
chirps: np.ndarray,
time_before_event: int,
time_after_event: int
)-> tuple[np.ndarray, np.ndarray]:
event_chirps = [] # chirps that are in specified window around event
centered_chirps = [] # timestamps of chirps around event centered on the event timepoint
time_after_event: int,
) -> tuple[np.ndarray, np.ndarray]:
event_chirps = [] # chirps that are in specified window around event
centered_chirps = (
[]
) # timestamps of chirps around event centered on the event timepoint
for event_timestamp in event:
start = event_timestamp - time_before_event # timepoint of window start
stop = event_timestamp + time_after_event # timepoint of window ending
chirps_around_event = [c for c in chirps if (c >= start) & (c <= stop)] # get chirps that are in a -5 to +5 sec window around event
start = event_timestamp - time_before_event # timepoint of window start
stop = event_timestamp + time_after_event # timepoint of window ending
chirps_around_event = [
c for c in chirps if (c >= start) & (c <= stop)
] # get chirps that are in a -5 to +5 sec window around event
event_chirps.append(chirps_around_event)
if len(chirps_around_event) == 0:
continue
else:
else:
centered_chirps.append(chirps_around_event - event_timestamp)
centered_chirps = np.concatenate(centered_chirps, axis=0) # convert list of arrays to one array for plotting
centered_chirps = np.concatenate(
centered_chirps, axis=0
) # convert list of arrays to one array for plotting
return event_chirps, centered_chirps
def main(datapath: str):
# behavior is pandas dataframe with all the data
bh = Behavior(datapath)
# chirps are not sorted in time (presumably due to prior groupings)
# get and sort chirps and corresponding fish_ids of the chirps
chirps = bh.chirps[np.argsort(bh.chirps)]
@@ -172,10 +186,34 @@ def main(datapath: str):
# First overview plot
fig1, ax1 = plt.subplots()
ax1.scatter(chirps, np.ones_like(chirps), marker='*', color='royalblue', label='Chirps')
ax1.scatter(chasing_onset, np.ones_like(chasing_onset)*2, marker='.', color='forestgreen', label='Chasing onset')
ax1.scatter(chasing_offset, np.ones_like(chasing_offset)*2.5, marker='.', color='firebrick', label='Chasing offset')
ax1.scatter(physical_contact, np.ones_like(physical_contact)*3, marker='x', color='black', label='Physical contact')
ax1.scatter(
chirps,
np.ones_like(chirps),
marker="*",
color="royalblue",
label="Chirps",
)
ax1.scatter(
chasing_onset,
np.ones_like(chasing_onset) * 2,
marker=".",
color="forestgreen",
label="Chasing onset",
)
ax1.scatter(
chasing_offset,
np.ones_like(chasing_offset) * 2.5,
marker=".",
color="firebrick",
label="Chasing offset",
)
ax1.scatter(
physical_contact,
np.ones_like(physical_contact) * 3,
marker="x",
color="black",
label="Physical contact",
)
plt.legend()
# plt.show()
plt.close()
@@ -187,29 +225,40 @@ def main(datapath: str):
# Evaluate how many chirps were emitted in specific time window around the chasing onset events
# Iterate over chasing onsets (later over fish)
time_around_event = 5 # time window around the event in which chirps are counted, 5 = -5 to +5 sec around event
time_around_event = 5 # time window around the event in which chirps are counted, 5 = -5 to +5 sec around event
#### Loop crashes at concatenate in function ####
# for i in range(len(fish_ids)):
# fish = fish_ids[i]
# chirps = chirps[chirps_fish_ids == fish]
# print(fish)
chasing_chirps, centered_chasing_chirps = event_triggered_chirps(chasing_onset, chirps, time_around_event, time_around_event)
physical_chirps, centered_physical_chirps = event_triggered_chirps(physical_contact, chirps, time_around_event, time_around_event)
chasing_chirps, centered_chasing_chirps = event_triggered_chirps(
chasing_onset, chirps, time_around_event, time_around_event
)
physical_chirps, centered_physical_chirps = event_triggered_chirps(
physical_contact, chirps, time_around_event, time_around_event
)
# Kernel density estimation ???
# centered_chasing_chirps_convolved = gaussian_filter1d(centered_chasing_chirps, 5)
# centered_chasing = chasing_onset[0] - chasing_onset[0] ## get the 0 timepoint for plotting; set one chasing event to 0
offsets = [0.5, 1]
fig4, ax4 = plt.subplots(figsize=(20 / 2.54, 12 / 2.54), constrained_layout=True)
ax4.eventplot(np.array([centered_chasing_chirps, centered_physical_chirps]), lineoffsets=offsets, linelengths=0.25, colors=['g', 'r'])
ax4.vlines(0, 0, 1.5, 'tab:grey', 'dashed', 'Timepoint of event')
fig4, ax4 = plt.subplots(
figsize=(20 / 2.54, 12 / 2.54), constrained_layout=True
)
ax4.eventplot(
np.array([centered_chasing_chirps, centered_physical_chirps]),
lineoffsets=offsets,
linelengths=0.25,
colors=["g", "r"],
)
ax4.vlines(0, 0, 1.5, "tab:grey", "dashed", "Timepoint of event")
# ax4.plot(centered_chasing_chirps_convolved)
ax4.set_yticks(offsets)
ax4.set_yticklabels(['Chasings', 'Physical \n contacts'])
ax4.set_xlabel('Time[s]')
ax4.set_ylabel('Type of event')
ax4.set_yticklabels(["Chasings", "Physical \n contacts"])
ax4.set_xlabel("Time[s]")
ax4.set_ylabel("Type of event")
plt.show()
# Associate chirps to inidividual fish
@@ -219,22 +268,21 @@ def main(datapath: str):
### Plots:
# 1. All recordings, all fish, all chirps
# One CTC, one PTC
# One CTC, one PTC
# 2. All recordings, only winners
# One CTC, one PTC
# One CTC, one PTC
# 3. All recordings, all losers
# One CTC, one PTC
# One CTC, one PTC
#### Chirp counts per fish general #####
fig2, ax2 = plt.subplots()
x = ['Fish1', 'Fish2']
x = ["Fish1", "Fish2"]
width = 0.35
ax2.bar(x, fish, width=width)
ax2.set_ylabel('Chirp count')
ax2.set_ylabel("Chirp count")
# plt.show()
plt.close()
##### Count chirps emitted during chasing events and chirps emitted out of chasing events #####
chirps_in_chasings = []
for onset, offset in zip(chasing_onset, chasing_offset):
@@ -251,23 +299,24 @@ def main(datapath: str):
counts_chirps_chasings += 1
# chirps in chasing events
fig3 , ax3 = plt.subplots()
ax3.bar(['Chirps in chasing events', 'Chasing events without Chirps'], [counts_chirps_chasings, chasings_without_chirps], width=width)
plt.ylabel('Count')
fig3, ax3 = plt.subplots()
ax3.bar(
["Chirps in chasing events", "Chasing events without Chirps"],
[counts_chirps_chasings, chasings_without_chirps],
width=width,
)
plt.ylabel("Count")
# plt.show()
plt.close()
plt.close()
# comparison between chasing events with and without chirps
embed()
exit()
if __name__ == '__main__':
if __name__ == "__main__":
# Path to the data
datapath = '../data/mount_data/2020-05-13-10_00/'
datapath = '../data/mount_data/2020-05-13-10_00/'
datapath = "../data/mount_data/2020-05-13-10_00/"
datapath = "../data/mount_data/2020-05-13-10_00/"
main(datapath)

View File

@@ -8,30 +8,27 @@ from modules.datahandling import instantaneous_frequency
from modules.simulations import create_chirp
# trying thunderfish fakefish chirp simulation ---------------------------------
samplerate = 44100
freq, ampl = fakefish.chirps(eodf=500, chirp_contrast=0.2)
data = fakefish.wavefish_eods(fish='Alepto', frequency=freq, phase0=3, samplerate=samplerate)
data = fakefish.wavefish_eods(
fish="Alepto", frequency=freq, phase0=3, samplerate=samplerate
)
# filter signal with bandpass_filter
data_filterd = bandpass_filter(data*ampl+1, samplerate, 0.01, 1.99)
data_filterd = bandpass_filter(data * ampl + 1, samplerate, 0.01, 1.99)
embed()
data_freq_time, data_freq = instantaneous_frequency(data, samplerate, 5)
fig, ax = plt.subplots(4, 1, figsize=(20 / 2.54, 12 / 2.54), sharex=True)
ax[0].plot(np.arange(len(data))/samplerate, data*ampl)
#ax[0].scatter(true_zero, np.zeros_like(true_zero), color='red')
ax[1].plot(np.arange(len(data_filterd))/samplerate, data_filterd)
ax[2].plot(np.arange(len(freq))/samplerate, freq)
ax[0].plot(np.arange(len(data)) / samplerate, data * ampl)
# ax[0].scatter(true_zero, np.zeros_like(true_zero), color='red')
ax[1].plot(np.arange(len(data_filterd)) / samplerate, data_filterd)
ax[2].plot(np.arange(len(freq)) / samplerate, freq)
ax[3].plot(data_freq_time, data_freq)
plt.show()
embed()

View File

@@ -1,6 +1,8 @@
from itertools import compress
from dataclasses import dataclass
from itertools import compress
import matplotlib.gridspec as gr
import matplotlib.pyplot as plt
import numpy as np
from IPython import embed
import matplotlib.pyplot as plt
@@ -15,11 +17,17 @@ from modules.plotstyle import PlotStyle
from modules.logger import makeLogger
from modules.datahandling import (
flatten,
purge_duplicates,
group_timestamps,
instantaneous_frequency,
minmaxnorm,
purge_duplicates,
)
from modules.filehandling import ConfLoader, LoadData, make_outputdir
from modules.filters import bandpass_filter, envelope, highpass_filter
from modules.logger import makeLogger
from modules.plotstyle import PlotStyle
from scipy.signal import find_peaks
from thunderfish.powerspectrum import decibel, spectrogram
logger = makeLogger(__name__)
@@ -58,7 +66,6 @@ class ChirpPlotBuffer:
frequency_peaks: np.ndarray
def plot_buffer(self, chirps: np.ndarray, plot: str) -> None:
logger.debug("Starting plotting")
# make data for plotting
@@ -134,7 +141,6 @@ class ChirpPlotBuffer:
ax0.set_ylim(np.min(self.frequency) - 100, np.max(self.frequency) + 200)
for track_id in self.data.ids:
t0_track = self.t0_old - 5
dt_track = self.dt + 10
window_idx = np.arange(len(self.data.idx))[
@@ -175,10 +181,16 @@ class ChirpPlotBuffer:
# )
ax0.axhline(
q50 - self.config.minimal_bandwidth / 2, color=ps.gblue1, lw=1, ls="dashed"
q50 - self.config.minimal_bandwidth / 2,
color=ps.gblue1,
lw=1,
ls="dashed",
)
ax0.axhline(
q50 + self.config.minimal_bandwidth / 2, color=ps.gblue1, lw=1, ls="dashed"
q50 + self.config.minimal_bandwidth / 2,
color=ps.gblue1,
lw=1,
ls="dashed",
)
ax0.axhline(search_lower, color=ps.gblue2, lw=1, ls="dashed")
ax0.axhline(search_upper, color=ps.gblue2, lw=1, ls="dashed")
@@ -204,7 +216,11 @@ class ChirpPlotBuffer:
# plot waveform of filtered signal
ax1.plot(
self.time, self.baseline * waveform_scaler, c=ps.gray, lw=lw, alpha=0.5
self.time,
self.baseline * waveform_scaler,
c=ps.gray,
lw=lw,
alpha=0.5,
)
ax1.plot(
self.time,
@@ -215,7 +231,13 @@ class ChirpPlotBuffer:
)
# plot waveform of filtered search signal
ax2.plot(self.time, self.search * waveform_scaler, c=ps.gray, lw=lw, alpha=0.5)
ax2.plot(
self.time,
self.search * waveform_scaler,
c=ps.gray,
lw=lw,
alpha=0.5,
)
ax2.plot(
self.time,
self.search_envelope_unfiltered * waveform_scaler,
@@ -237,9 +259,7 @@ class ChirpPlotBuffer:
# ax4.plot(
# self.time, self.baseline_envelope * waveform_scaler, c=ps.gblue1, lw=lw
# )
ax4.plot(
self.time, self.baseline_envelope, c=ps.gblue1, lw=lw
)
ax4.plot(self.time, self.baseline_envelope, c=ps.gblue1, lw=lw)
ax4.scatter(
(self.time)[self.baseline_peaks],
# (self.baseline_envelope * waveform_scaler)[self.baseline_peaks],
@@ -268,7 +288,9 @@ class ChirpPlotBuffer:
)
# plot filtered instantaneous frequency
ax6.plot(self.frequency_time, self.frequency_filtered, c=ps.gblue3, lw=lw)
ax6.plot(
self.frequency_time, self.frequency_filtered, c=ps.gblue3, lw=lw
)
ax6.scatter(
self.frequency_time[self.frequency_peaks],
self.frequency_filtered[self.frequency_peaks],
@@ -302,7 +324,9 @@ class ChirpPlotBuffer:
# ax7.spines.bottom.set_bounds((0, 5))
ax0.set_xlim(0, self.config.window)
plt.subplots_adjust(left=0.165, right=0.975, top=0.98, bottom=0.074, hspace=0.2)
plt.subplots_adjust(
left=0.165, right=0.975, top=0.98, bottom=0.074, hspace=0.2
)
fig.align_labels()
if plot == "show":
@@ -407,7 +431,9 @@ def extract_frequency_bands(
q25, q75 = q50 - minimal_bandwidth / 2, q50 + minimal_bandwidth / 2
# filter baseline
filtered_baseline = bandpass_filter(raw_data, samplerate, lowf=q25, highf=q75)
filtered_baseline = bandpass_filter(
raw_data, samplerate, lowf=q25, highf=q75
)
# filter search area
filtered_search_freq = bandpass_filter(
@@ -452,12 +478,14 @@ def window_median_all_track_ids(
track_ids = []
for _, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])):
# the window index combines the track id and the time window
window_idx = np.arange(len(data.idx))[
(data.ident == track_id)
& (data.time[data.idx] >= window_start_seconds)
& (data.time[data.idx] <= (window_start_seconds + window_duration_seconds))
& (
data.time[data.idx]
<= (window_start_seconds + window_duration_seconds)
)
]
if len(data.freq[window_idx]) > 0:
@@ -594,15 +622,15 @@ def find_searchband(
# iterate through theses tracks
if check_track_ids.size != 0:
for j, check_track_id in enumerate(check_track_ids):
q25_temp = q25[percentiles_ids == check_track_id]
q75_temp = q75[percentiles_ids == check_track_id]
bool_lower[search_window > q25_temp - config.search_res] = False
bool_upper[search_window < q75_temp + config.search_res] = False
search_window_bool[(bool_lower == False) & (bool_upper == False)] = False
search_window_bool[
(bool_lower == False) & (bool_upper == False)
] = False
# find gaps in search window
search_window_indices = np.arange(len(search_window))
@@ -621,7 +649,9 @@ def find_searchband(
# if the first value is -1, the array starst with true, so a gap
if nonzeros[0] == -1:
stops = search_window_indices[search_window_gaps == -1]
starts = np.append(0, search_window_indices[search_window_gaps == 1])
starts = np.append(
0, search_window_indices[search_window_gaps == 1]
)
# if the last value is -1, the array ends with true, so a gap
if nonzeros[-1] == 1:
@@ -658,7 +688,6 @@ def find_searchband(
def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
assert plot in [
"save",
"show",
@@ -728,7 +757,6 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
multiwindow_ids = []
for st, window_start_index in enumerate(window_start_indices):
logger.info(f"Processing window {st+1} of {len(window_start_indices)}")
window_start_seconds = window_start_index / data.raw_rate
@@ -743,8 +771,9 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
)
# iterate through all fish
for tr, track_id in enumerate(np.unique(data.ident[~np.isnan(data.ident)])):
for tr, track_id in enumerate(
np.unique(data.ident[~np.isnan(data.ident)])
):
logger.debug(f"Processing track {tr} of {len(data.ids)}")
# get index of track data in this time window
@@ -772,16 +801,17 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
nanchecker = np.unique(np.isnan(current_powers))
if (len(nanchecker) == 1) and nanchecker[0] is True:
logger.warning(
f"No powers available for track {track_id} window {st}," "skipping."
f"No powers available for track {track_id} window {st},"
"skipping."
)
continue
# find the strongest electrodes for the current fish in the current
# window
best_electrode_index = np.argsort(np.nanmean(current_powers, axis=0))[
-config.number_electrodes :
]
best_electrode_index = np.argsort(
np.nanmean(current_powers, axis=0)
)[-config.number_electrodes :]
# find a frequency above the baseline of the current fish in which
# no other fish is active to search for chirps there
@@ -801,9 +831,9 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
# iterate through electrodes
for el, electrode_index in enumerate(best_electrode_index):
logger.debug(
f"Processing electrode {el+1} of " f"{len(best_electrode_index)}"
f"Processing electrode {el+1} of "
f"{len(best_electrode_index)}"
)
# LOAD DATA FOR CURRENT ELECTRODE AND CURRENT FISH ------------
@@ -812,7 +842,9 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
current_raw_data = data.raw[
window_start_index:window_stop_index, electrode_index
]
current_raw_time = raw_time[window_start_index:window_stop_index]
current_raw_time = raw_time[
window_start_index:window_stop_index
]
# EXTRACT FEATURES --------------------------------------------
@@ -838,8 +870,7 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
# because the instantaneous frequency is not reliable there
amplitude_mask = mask_low_amplitudes(
baseline_envelope_unfiltered,
config.baseline_min_amplitude
baseline_envelope_unfiltered, config.baseline_min_amplitude
)
# highpass filter baseline envelope to remove slower
@@ -876,27 +907,30 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
# filtered baseline such as the one we are working with.
baseline_frequency = instantaneous_frequency(
baselineband,
data.raw_rate,
config.baseline_frequency_smoothing
baselineband,
data.raw_rate,
config.baseline_frequency_smoothing,
)
# Take the absolute of the instantaneous frequency to invert
# troughs into peaks. This is nessecary since the narrow
# troughs into peaks. This is nessecary since the narrow
# pass band introduces these anomalies. Also substract by the
# median to set it to 0.
baseline_frequency_filtered = np.abs(
baseline_frequency - np.median(baseline_frequency)
)
# check if there is at least one superthreshold peak on the
# instantaneous and exit the loop if not. This is used to
# prevent windows that do definetely not include a chirp
# to enter normalization, where small changes due to noise
# would be amplified
# check if there is at least one superthreshold peak on the
# instantaneous and exit the loop if not. This is used to
# prevent windows that do definetely not include a chirp
# to enter normalization, where small changes due to noise
# would be amplified
if not has_chirp(baseline_frequency_filtered[amplitude_mask], config.baseline_frequency_peakheight):
if not has_chirp(
baseline_frequency_filtered[amplitude_mask],
config.baseline_frequency_peakheight,
):
continue
# CUT OFF OVERLAP ---------------------------------------------
@@ -911,14 +945,20 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
current_raw_time = current_raw_time[no_edges]
baselineband = baselineband[no_edges]
baseline_envelope_unfiltered = baseline_envelope_unfiltered[no_edges]
baseline_envelope_unfiltered = baseline_envelope_unfiltered[
no_edges
]
searchband = searchband[no_edges]
baseline_envelope = baseline_envelope[no_edges]
search_envelope_unfiltered = search_envelope_unfiltered[no_edges]
search_envelope_unfiltered = search_envelope_unfiltered[
no_edges
]
search_envelope = search_envelope[no_edges]
baseline_frequency = baseline_frequency[no_edges]
baseline_frequency_filtered = baseline_frequency_filtered[no_edges]
baseline_frequency_filtered = baseline_frequency_filtered[
no_edges
]
baseline_frequency_time = current_raw_time
# # get instantaneous frequency withoup edges
@@ -959,13 +999,16 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
)
# detect peaks inst_freq_filtered
frequency_peak_indices, _ = find_peaks(
baseline_frequency_filtered, prominence=config.frequency_prominence
baseline_frequency_filtered,
prominence=config.frequency_prominence,
)
# DETECT CHIRPS IN SEARCH WINDOW ------------------------------
# get the peak timestamps from the peak indices
baseline_peak_timestamps = current_raw_time[baseline_peak_indices]
baseline_peak_timestamps = current_raw_time[
baseline_peak_indices
]
search_peak_timestamps = current_raw_time[search_peak_indices]
frequency_peak_timestamps = baseline_frequency_time[
@@ -1014,7 +1057,6 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
)
if chirp_detected or (debug != "elecrode"):
logger.debug("Detected chirp, ititialize buffer ...")
# save data to Buffer
@@ -1106,7 +1148,6 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
multiwindow_chirps_flat = []
multiwindow_ids_flat = []
for track_id in np.unique(multiwindow_ids):
# get chirps for this fish and flatten the list
current_track_bool = np.asarray(multiwindow_ids) == track_id
current_track_chirps = flatten(
@@ -1115,7 +1156,9 @@ def chirpdetection(datapath: str, plot: str, debug: str = "false") -> None:
# add flattened chirps to the list
multiwindow_chirps_flat.extend(current_track_chirps)
multiwindow_ids_flat.extend(list(np.ones_like(current_track_chirps) * track_id))
multiwindow_ids_flat.extend(
list(np.ones_like(current_track_chirps) * track_id)
)
# purge duplicates, i.e. chirps that are very close to each other
# duplites arise due to overlapping windows

View File

@@ -1,37 +1,37 @@
# Path setup ------------------------------------------------------------------
dataroot: "../data/" # path to data
outputdir: "../output/" # path to save plots to
dataroot: "../data/" # path to data
outputdir: "../output/" # path to save plots to
# Rolling window parameters ---------------------------------------------------
window: 5 # rolling window length in seconds
window: 5 # rolling window length in seconds
overlap: 1 # window overlap in seconds
edge: 0.25 # window edge cufoffs to mitigate filter edge effects
# Electrode iteration parameters ----------------------------------------------
number_electrodes: 2 # number of electrodes to go over
minimum_electrodes: 1 # mimumun number of electrodes a chirp must be on
number_electrodes: 2 # number of electrodes to go over
minimum_electrodes: 1 # mimumun number of electrodes a chirp must be on
# Feature extraction parameters -----------------------------------------------
search_df_lower: 20 # start searching this far above the baseline
search_df_upper: 100 # stop searching this far above the baseline
search_res: 1 # search window resolution
default_search_freq: 60 # search here if no need for a search frequency
minimal_bandwidth: 10 # minimal bandpass filter width for baseline
search_bandwidth: 10 # minimal bandpass filter width for search frequency
baseline_frequency_smoothing: 3 # instantaneous frequency smoothing
search_df_lower: 20 # start searching this far above the baseline
search_df_upper: 100 # stop searching this far above the baseline
search_res: 1 # search window resolution
default_search_freq: 60 # search here if no need for a search frequency
minimal_bandwidth: 10 # minimal bandpass filter width for baseline
search_bandwidth: 10 # minimal bandpass filter width for search frequency
baseline_frequency_smoothing: 3 # instantaneous frequency smoothing
# Feature processing parameters -----------------------------------------------
baseline_frequency_peakheight: 5 # the min peak height of the baseline instfreq
baseline_min_amplitude: 0.0001 # the minimal value of the baseline envelope
baseline_envelope_cutoff: 25 # envelope estimation cutoff
baseline_envelope_bandpass_lowf: 2 # envelope badpass lower cutoff
baseline_envelope_bandpass_highf: 100 # envelope bandbass higher cutoff
search_envelope_cutoff: 10 # search envelope estimation cufoff
baseline_frequency_peakheight: 5 # the min peak height of the baseline instfreq
baseline_min_amplitude: 0.0001 # the minimal value of the baseline envelope
baseline_envelope_cutoff: 25 # envelope estimation cutoff
baseline_envelope_bandpass_lowf: 2 # envelope badpass lower cutoff
baseline_envelope_bandpass_highf: 100 # envelope bandbass higher cutoff
search_envelope_cutoff: 10 # search envelope estimation cufoff
# Peak detecion parameters ----------------------------------------------------
# baseline_prominence: 0.00005 # peak prominence threshold for baseline envelope
@@ -39,9 +39,8 @@ search_envelope_cutoff: 10 # search envelope estimation cufoff
# frequency_prominence: 2 # peak prominence threshold for baseline freq
baseline_prominence: 0.3 # peak prominence threshold for baseline envelope
search_prominence: 0.3 # peak prominence threshold for search envelope
frequency_prominence: 0.3 # peak prominence threshold for baseline freq
search_prominence: 0.3 # peak prominence threshold for search envelope
frequency_prominence: 0.3 # peak prominence threshold for baseline freq
# Classify events as chirps if they are less than this time apart
chirp_window_threshold: 0.02

View File

@@ -35,28 +35,36 @@ class Behavior:
"""
def __init__(self, folder_path: str) -> None:
print(f'{folder_path}')
LED_on_time_BORIS = np.load(os.path.join(
folder_path, 'LED_on_time.npy'), allow_pickle=True)
self.time = np.load(os.path.join(
folder_path, "times.npy"), allow_pickle=True)
csv_filename = [f for f in os.listdir(folder_path) if f.endswith(
'.csv')][0] # check if there are more than one csv file
print(f"{folder_path}")
LED_on_time_BORIS = np.load(
os.path.join(folder_path, "LED_on_time.npy"), allow_pickle=True
)
self.time = np.load(
os.path.join(folder_path, "times.npy"), allow_pickle=True
)
csv_filename = [
f for f in os.listdir(folder_path) if f.endswith(".csv")
][
0
] # check if there are more than one csv file
self.dataframe = read_csv(os.path.join(folder_path, csv_filename))
self.chirps = np.load(os.path.join(
folder_path, 'chirps.npy'), allow_pickle=True)
self.chirps_ids = np.load(os.path.join(
folder_path, 'chirp_ids.npy'), allow_pickle=True)
self.chirps = np.load(
os.path.join(folder_path, "chirps.npy"), allow_pickle=True
)
self.chirps_ids = np.load(
os.path.join(folder_path, "chirp_ids.npy"), allow_pickle=True
)
for k, key in enumerate(self.dataframe.keys()):
key = key.lower()
if ' ' in key:
key = key.replace(' ', '_')
if '(' in key:
key = key.replace('(', '')
key = key.replace(')', '')
setattr(self, key, np.array(
self.dataframe[self.dataframe.keys()[k]]))
if " " in key:
key = key.replace(" ", "_")
if "(" in key:
key = key.replace("(", "")
key = key.replace(")", "")
setattr(
self, key, np.array(self.dataframe[self.dataframe.keys()[k]])
)
last_LED_t_BORIS = LED_on_time_BORIS[-1]
real_time_range = self.time[-1] - self.time[0]
@@ -95,17 +103,14 @@ temporal encpding needs to be corrected ... not exactly 25FPS.
def correct_chasing_events(
category: np.ndarray,
timestamps: np.ndarray
category: np.ndarray, timestamps: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
onset_ids = np.arange(len(category))[category == 0]
offset_ids = np.arange(len(category))[category == 1]
onset_ids = np.arange(
len(category))[category == 0]
offset_ids = np.arange(
len(category))[category == 1]
wrong_bh = np.arange(len(category))[
category != 2][:-1][np.diff(category[category != 2]) == 0]
wrong_bh = np.arange(len(category))[category != 2][:-1][
np.diff(category[category != 2]) == 0
]
if onset_ids[0] > offset_ids[0]:
offset_ids = np.delete(offset_ids, 0)
help_index = offset_ids[0]
@@ -117,12 +122,12 @@ def correct_chasing_events(
# Check whether on- or offset is longer and calculate length difference
if len(onset_ids) > len(offset_ids):
len_diff = len(onset_ids) - len(offset_ids)
logger.info(f'Onsets are greater than offsets by {len_diff}')
logger.info(f"Onsets are greater than offsets by {len_diff}")
elif len(onset_ids) < len(offset_ids):
len_diff = len(offset_ids) - len(onset_ids)
logger.info(f'Offsets are greater than onsets by {len_diff}')
logger.info(f"Offsets are greater than onsets by {len_diff}")
elif len(onset_ids) == len(offset_ids):
logger.info('Chasing events are equal')
logger.info("Chasing events are equal")
return category, timestamps
@@ -135,8 +140,7 @@ def event_triggered_chirps(
dt: float,
width: float,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
event_chirps = [] # chirps that are in specified window around event
event_chirps = [] # chirps that are in specified window around event
# timestamps of chirps around event centered on the event timepoint
centered_chirps = []
@@ -159,16 +163,19 @@ def event_triggered_chirps(
else:
# convert list of arrays to one array for plotting
centered_chirps = np.concatenate(centered_chirps, axis=0)
centered_chirps_convolved = (acausal_kde1d(
centered_chirps, time, width)) / len(event)
centered_chirps_convolved = (
acausal_kde1d(centered_chirps, time, width)
) / len(event)
return event_chirps, centered_chirps, centered_chirps_convolved
def main(datapath: str):
foldernames = [
datapath + x + '/' for x in os.listdir(datapath) if os.path.isdir(datapath + x)]
datapath + x + "/"
for x in os.listdir(datapath)
if os.path.isdir(datapath + x)
]
nrecording_chirps = []
nrecording_chirps_fish_ids = []
@@ -179,7 +186,7 @@ def main(datapath: str):
# Iterate over all recordings and save chirp- and event-timestamps
for folder in foldernames:
# exclude folder with empty LED_on_time.npy
if folder == '../data/mount_data/2020-05-12-10_00/':
if folder == "../data/mount_data/2020-05-12-10_00/":
continue
bh = Behavior(folder)
@@ -209,7 +216,7 @@ def main(datapath: str):
time_before_event = 30
time_after_event = 60
dt = 0.01
width = 1.5 # width of kernel for all recordings, currently gaussian kernel
width = 1.5 # width of kernel for all recordings, currently gaussian kernel
recording_width = 2 # width of kernel for each recording
time = np.arange(-time_before_event, time_after_event, dt)
@@ -232,18 +239,47 @@ def main(datapath: str):
physical_contacts = nrecording_physicals[i]
# Chirps around chasing onsets
_, centered_chasing_onset_chirps, cc_chasing_onset_chirps = event_triggered_chirps(
chasing_onsets, chirps, time_before_event, time_after_event, dt, recording_width)
(
_,
centered_chasing_onset_chirps,
cc_chasing_onset_chirps,
) = event_triggered_chirps(
chasing_onsets,
chirps,
time_before_event,
time_after_event,
dt,
recording_width,
)
# Chirps around chasing offsets
_, centered_chasing_offset_chirps, cc_chasing_offset_chirps = event_triggered_chirps(
chasing_offsets, chirps, time_before_event, time_after_event, dt, recording_width)
(
_,
centered_chasing_offset_chirps,
cc_chasing_offset_chirps,
) = event_triggered_chirps(
chasing_offsets,
chirps,
time_before_event,
time_after_event,
dt,
recording_width,
)
# Chirps around physical contacts
_, centered_physical_chirps, cc_physical_chirps = event_triggered_chirps(
physical_contacts, chirps, time_before_event, time_after_event, dt, recording_width)
(
_,
centered_physical_chirps,
cc_physical_chirps,
) = event_triggered_chirps(
physical_contacts,
chirps,
time_before_event,
time_after_event,
dt,
recording_width,
)
nrecording_centered_onset_chirps.append(centered_chasing_onset_chirps)
nrecording_centered_offset_chirps.append(
centered_chasing_offset_chirps)
nrecording_centered_offset_chirps.append(centered_chasing_offset_chirps)
nrecording_centered_physical_chirps.append(centered_physical_chirps)
## Shuffled chirps ##
@@ -331,12 +367,13 @@ def main(datapath: str):
# New bootstrapping approach
for n in range(nbootstrapping):
diff_onset = np.diff(
np.sort(flatten(nrecording_centered_onset_chirps)))
diff_onset = np.diff(np.sort(flatten(nrecording_centered_onset_chirps)))
diff_offset = np.diff(
np.sort(flatten(nrecording_centered_offset_chirps)))
np.sort(flatten(nrecording_centered_offset_chirps))
)
diff_physical = np.diff(
np.sort(flatten(nrecording_centered_physical_chirps)))
np.sort(flatten(nrecording_centered_physical_chirps))
)
np.random.shuffle(diff_onset)
shuffled_onset = np.cumsum(diff_onset)
@@ -345,9 +382,11 @@ def main(datapath: str):
np.random.shuffle(diff_physical)
shuffled_physical = np.cumsum(diff_physical)
kde_onset (acausal_kde1d(shuffled_onset, time, width))/(27*100)
kde_offset = (acausal_kde1d(shuffled_offset, time, width))/(27*100)
kde_physical = (acausal_kde1d(shuffled_physical, time, width))/(27*100)
kde_onset(acausal_kde1d(shuffled_onset, time, width)) / (27 * 100)
kde_offset = (acausal_kde1d(shuffled_offset, time, width)) / (27 * 100)
kde_physical = (acausal_kde1d(shuffled_physical, time, width)) / (
27 * 100
)
bootstrap_onset.append(kde_onset)
bootstrap_offset.append(kde_offset)
@@ -355,11 +394,14 @@ def main(datapath: str):
# New shuffle approach q5, q50, q95
onset_q5, onset_median, onset_q95 = np.percentile(
bootstrap_onset, [5, 50, 95], axis=0)
bootstrap_onset, [5, 50, 95], axis=0
)
offset_q5, offset_median, offset_q95 = np.percentile(
bootstrap_offset, [5, 50, 95], axis=0)
bootstrap_offset, [5, 50, 95], axis=0
)
physical_q5, physical_median, physical_q95 = np.percentile(
bootstrap_physical, [5, 50, 95], axis=0)
bootstrap_physical, [5, 50, 95], axis=0
)
# vstack um 1. Dim zu cutten
# nrecording_shuffled_convolved_onset_chirps = np.vstack(nrecording_shuffled_convolved_onset_chirps)
@@ -378,45 +420,66 @@ def main(datapath: str):
# Flatten event timestamps
all_onsets = np.concatenate(
nrecording_chasing_onsets).ravel() # not centered
nrecording_chasing_onsets
).ravel() # not centered
all_offsets = np.concatenate(
nrecording_chasing_offsets).ravel() # not centered
all_physicals = np.concatenate(
nrecording_physicals).ravel() # not centered
nrecording_chasing_offsets
).ravel() # not centered
all_physicals = np.concatenate(nrecording_physicals).ravel() # not centered
# Flatten all chirps around events
all_onset_chirps = np.concatenate(
nrecording_centered_onset_chirps).ravel() # centered
nrecording_centered_onset_chirps
).ravel() # centered
all_offset_chirps = np.concatenate(
nrecording_centered_offset_chirps).ravel() # centered
nrecording_centered_offset_chirps
).ravel() # centered
all_physical_chirps = np.concatenate(
nrecording_centered_physical_chirps).ravel() # centered
nrecording_centered_physical_chirps
).ravel() # centered
# Convolute all chirps
# Divide by total number of each event over all recordings
all_onset_chirps_convolved = (acausal_kde1d(
all_onset_chirps, time, width)) / len(all_onsets)
all_offset_chirps_convolved = (acausal_kde1d(
all_offset_chirps, time, width)) / len(all_offsets)
all_physical_chirps_convolved = (acausal_kde1d(
all_physical_chirps, time, width)) / len(all_physicals)
all_onset_chirps_convolved = (
acausal_kde1d(all_onset_chirps, time, width)
) / len(all_onsets)
all_offset_chirps_convolved = (
acausal_kde1d(all_offset_chirps, time, width)
) / len(all_offsets)
all_physical_chirps_convolved = (
acausal_kde1d(all_physical_chirps, time, width)
) / len(all_physicals)
# Plot all events with all shuffled
fig, ax = plt.subplots(1, 3, figsize=(
28*ps.cm, 16*ps.cm, ), constrained_layout=True, sharey='all')
fig, ax = plt.subplots(
1,
3,
figsize=(
28 * ps.cm,
16 * ps.cm,
),
constrained_layout=True,
sharey="all",
)
# offsets = np.arange(1,28,1)
ax[0].set_xlabel('Time[s]')
ax[0].set_xlabel("Time[s]")
# Plot chasing onsets
ax[0].set_ylabel('Chirp rate [Hz]')
ax[0].set_ylabel("Chirp rate [Hz]")
ax[0].plot(time, all_onset_chirps_convolved, color=ps.yellow, zorder=2)
ax0 = ax[0].twinx()
nrecording_centered_onset_chirps = np.asarray(
nrecording_centered_onset_chirps, dtype=object)
ax0.eventplot(np.array(nrecording_centered_onset_chirps),
linelengths=0.5, colors=ps.gray, alpha=0.25, zorder=1)
ax0.vlines(0, 0, 1.5, ps.white, 'dashed')
ax[0].set_zorder(ax0.get_zorder()+1)
nrecording_centered_onset_chirps, dtype=object
)
ax0.eventplot(
np.array(nrecording_centered_onset_chirps),
linelengths=0.5,
colors=ps.gray,
alpha=0.25,
zorder=1,
)
ax0.vlines(0, 0, 1.5, ps.white, "dashed")
ax[0].set_zorder(ax0.get_zorder() + 1)
ax[0].patch.set_visible(False)
ax0.set_yticklabels([])
ax0.set_yticks([])
@@ -426,15 +489,21 @@ def main(datapath: str):
ax[0].plot(time, onset_median, color=ps.black)
# Plot chasing offets
ax[1].set_xlabel('Time[s]')
ax[1].set_xlabel("Time[s]")
ax[1].plot(time, all_offset_chirps_convolved, color=ps.orange, zorder=2)
ax1 = ax[1].twinx()
nrecording_centered_offset_chirps = np.asarray(
nrecording_centered_offset_chirps, dtype=object)
ax1.eventplot(np.array(nrecording_centered_offset_chirps),
linelengths=0.5, colors=ps.gray, alpha=0.25, zorder=1)
ax1.vlines(0, 0, 1.5, ps.white, 'dashed')
ax[1].set_zorder(ax1.get_zorder()+1)
nrecording_centered_offset_chirps, dtype=object
)
ax1.eventplot(
np.array(nrecording_centered_offset_chirps),
linelengths=0.5,
colors=ps.gray,
alpha=0.25,
zorder=1,
)
ax1.vlines(0, 0, 1.5, ps.white, "dashed")
ax[1].set_zorder(ax1.get_zorder() + 1)
ax[1].patch.set_visible(False)
ax1.set_yticklabels([])
ax1.set_yticks([])
@@ -444,24 +513,31 @@ def main(datapath: str):
ax[1].plot(time, offset_median, color=ps.black)
# Plot physical contacts
ax[2].set_xlabel('Time[s]')
ax[2].set_xlabel("Time[s]")
ax[2].plot(time, all_physical_chirps_convolved, color=ps.maroon, zorder=2)
ax2 = ax[2].twinx()
nrecording_centered_physical_chirps = np.asarray(
nrecording_centered_physical_chirps, dtype=object)
ax2.eventplot(np.array(nrecording_centered_physical_chirps),
linelengths=0.5, colors=ps.gray, alpha=0.25, zorder=1)
ax2.vlines(0, 0, 1.5, ps.white, 'dashed')
ax[2].set_zorder(ax2.get_zorder()+1)
nrecording_centered_physical_chirps, dtype=object
)
ax2.eventplot(
np.array(nrecording_centered_physical_chirps),
linelengths=0.5,
colors=ps.gray,
alpha=0.25,
zorder=1,
)
ax2.vlines(0, 0, 1.5, ps.white, "dashed")
ax[2].set_zorder(ax2.get_zorder() + 1)
ax[2].patch.set_visible(False)
ax2.set_yticklabels([])
ax2.set_yticks([])
# ax[2].fill_between(time, shuffled_q5_physical, shuffled_q95_physical, color=ps.gray, alpha=0.5)
# ax[2].plot(time, shuffled_median_physical, ps.black)
ax[2].fill_between(time, physical_q5, physical_q95,
color=ps.gray, alpha=0.5)
ax[2].fill_between(
time, physical_q5, physical_q95, color=ps.gray, alpha=0.5
)
ax[2].plot(time, physical_median, ps.black)
fig.suptitle('All recordings')
fig.suptitle("All recordings")
plt.show()
plt.close()
@@ -587,7 +663,7 @@ def main(datapath: str):
#### Chirps around events, only losers, one recording ####
if __name__ == '__main__':
if __name__ == "__main__":
# Path to the data
datapath = '../data/mount_data/'
datapath = "../data/mount_data/"
main(datapath)

View File

@@ -8,50 +8,51 @@ from IPython import embed
def get_valid_datasets(dataroot):
datasets = sorted([name for name in os.listdir(dataroot) if os.path.isdir(
os.path.join(dataroot, name))])
datasets = sorted(
[
name
for name in os.listdir(dataroot)
if os.path.isdir(os.path.join(dataroot, name))
]
)
valid_datasets = []
for dataset in datasets:
path = os.path.join(dataroot, dataset)
csv_name = '-'.join(dataset.split('-')[:3]) + '.csv'
csv_name = "-".join(dataset.split("-")[:3]) + ".csv"
if os.path.exists(os.path.join(path, csv_name)) is False:
continue
if os.path.exists(os.path.join(path, 'ident_v.npy')) is False:
if os.path.exists(os.path.join(path, "ident_v.npy")) is False:
continue
ident = np.load(os.path.join(path, 'ident_v.npy'))
ident = np.load(os.path.join(path, "ident_v.npy"))
number_of_fish = len(np.unique(ident[~np.isnan(ident)]))
if number_of_fish != 2:
continue
valid_datasets.append(dataset)
datapaths = [os.path.join(dataroot, dataset) +
'/' for dataset in valid_datasets]
datapaths = [
os.path.join(dataroot, dataset) + "/" for dataset in valid_datasets
]
return datapaths, valid_datasets
def main(datapaths):
for path in datapaths:
chirpdetection(path, plot='show')
chirpdetection(path, plot="show")
if __name__ == '__main__':
if __name__ == "__main__":
dataroot = "../data/mount_data/"
dataroot = '../data/mount_data/'
datapaths, valid_datasets = get_valid_datasets(dataroot)
datapaths, valid_datasets= get_valid_datasets(dataroot)
recs = pd.DataFrame(columns=['recording'], data=valid_datasets)
recs.to_csv('../recs.csv', index=False)
recs = pd.DataFrame(columns=["recording"], data=valid_datasets)
recs.to_csv("../recs.csv", index=False)
# datapaths = ['../data/mount_data/2020-03-25-10_00/']
main(datapaths)

View File

@@ -1,4 +1,4 @@
import os
import os
from paramiko import SSHClient
from scp import SCPClient
from IPython import embed
@@ -7,29 +7,41 @@ from pandas import read_csv
ssh = SSHClient()
ssh.load_system_host_keys()
ssh.connect(hostname='kraken',
username='efish',
password='fwNix4U',
)
ssh.connect(
hostname="kraken",
username="efish",
password="fwNix4U",
)
# SCPCLient takes a paramiko transport as its only argument
scp = SCPClient(ssh.get_transport())
data = read_csv('../recs.csv')
foldernames = data['recording'].values
data = read_csv("../recs.csv")
foldernames = data["recording"].values
directory = f'/Users/acfw/Documents/uni_tuebingen/chirpdetection/GP2023_chirp_detection/data/mount_data/'
directory = f"/Users/acfw/Documents/uni_tuebingen/chirpdetection/GP2023_chirp_detection/data/mount_data/"
for foldername in foldernames:
if not os.path.exists(directory + foldername):
os.makedirs(directory + foldername)
if not os.path.exists(directory+foldername):
os.makedirs(directory+foldername)
files = [('-').join(foldername.split('-')[:3])+'.csv','chirp_ids.npy', 'chirps.npy', 'fund_v.npy', 'ident_v.npy', 'idx_v.npy', 'times.npy', 'spec.npy', 'LED_on_time.npy', 'sign_v.npy']
files = [
("-").join(foldername.split("-")[:3]) + ".csv",
"chirp_ids.npy",
"chirps.npy",
"fund_v.npy",
"ident_v.npy",
"idx_v.npy",
"times.npy",
"spec.npy",
"LED_on_time.npy",
"sign_v.npy",
]
for f in files:
scp.get(f'/home/efish/behavior/2019_tube_competition/{foldername}/{f}',
directory+foldername)
scp.get(
f"/home/efish/behavior/2019_tube_competition/{foldername}/{f}",
directory + foldername,
)
scp.close()

View File

@@ -30,12 +30,12 @@ class Behavior:
"""
def __init__(self, folder_path: str) -> None:
LED_on_time_BORIS = np.load(os.path.join(
folder_path, 'LED_on_time.npy'), allow_pickle=True)
LED_on_time_BORIS = np.load(
os.path.join(folder_path, "LED_on_time.npy"), allow_pickle=True
)
csv_filename = os.path.split(folder_path[:-1])[-1]
csv_filename = '-'.join(csv_filename.split('-')[:-1]) + '.csv'
csv_filename = "-".join(csv_filename.split("-")[:-1]) + ".csv"
# embed()
# csv_filename = [f for f in os.listdir(
@@ -43,31 +43,39 @@ class Behavior:
# logger.info(f'CSV file: {csv_filename}')
self.dataframe = read_csv(os.path.join(folder_path, csv_filename))
self.chirps = np.load(os.path.join(
folder_path, 'chirps.npy'), allow_pickle=True)
self.chirps_ids = np.load(os.path.join(
folder_path, 'chirp_ids.npy'), allow_pickle=True)
self.chirps = np.load(
os.path.join(folder_path, "chirps.npy"), allow_pickle=True
)
self.chirps_ids = np.load(
os.path.join(folder_path, "chirp_ids.npy"), allow_pickle=True
)
self.ident = np.load(os.path.join(
folder_path, 'ident_v.npy'), allow_pickle=True)
self.idx = np.load(os.path.join(
folder_path, 'idx_v.npy'), allow_pickle=True)
self.freq = np.load(os.path.join(
folder_path, 'fund_v.npy'), allow_pickle=True)
self.time = np.load(os.path.join(
folder_path, "times.npy"), allow_pickle=True)
self.spec = np.load(os.path.join(
folder_path, "spec.npy"), allow_pickle=True)
self.ident = np.load(
os.path.join(folder_path, "ident_v.npy"), allow_pickle=True
)
self.idx = np.load(
os.path.join(folder_path, "idx_v.npy"), allow_pickle=True
)
self.freq = np.load(
os.path.join(folder_path, "fund_v.npy"), allow_pickle=True
)
self.time = np.load(
os.path.join(folder_path, "times.npy"), allow_pickle=True
)
self.spec = np.load(
os.path.join(folder_path, "spec.npy"), allow_pickle=True
)
for k, key in enumerate(self.dataframe.keys()):
key = key.lower()
if ' ' in key:
key = key.replace(' ', '_')
if '(' in key:
key = key.replace('(', '')
key = key.replace(')', '')
setattr(self, key, np.array(
self.dataframe[self.dataframe.keys()[k]]))
if " " in key:
key = key.replace(" ", "_")
if "(" in key:
key = key.replace("(", "")
key = key.replace(")", "")
setattr(
self, key, np.array(self.dataframe[self.dataframe.keys()[k]])
)
last_LED_t_BORIS = LED_on_time_BORIS[-1]
real_time_range = self.time[-1] - self.time[0]
@@ -78,22 +86,19 @@ class Behavior:
def correct_chasing_events(
category: np.ndarray,
timestamps: np.ndarray
category: np.ndarray, timestamps: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
onset_ids = np.arange(len(category))[category == 0]
offset_ids = np.arange(len(category))[category == 1]
onset_ids = np.arange(
len(category))[category == 0]
offset_ids = np.arange(
len(category))[category == 1]
wrong_bh = np.arange(len(category))[
category != 2][:-1][np.diff(category[category != 2]) == 0]
wrong_bh = np.arange(len(category))[category != 2][:-1][
np.diff(category[category != 2]) == 0
]
if category[category != 2][-1] == 0:
wrong_bh = np.append(
wrong_bh,
np.arange(len(category))[category != 2][-1])
wrong_bh, np.arange(len(category))[category != 2][-1]
)
if onset_ids[0] > offset_ids[0]:
offset_ids = np.delete(offset_ids, 0)
@@ -103,18 +108,16 @@ def correct_chasing_events(
category = np.delete(category, wrong_bh)
timestamps = np.delete(timestamps, wrong_bh)
new_onset_ids = np.arange(
len(category))[category == 0]
new_offset_ids = np.arange(
len(category))[category == 1]
new_onset_ids = np.arange(len(category))[category == 0]
new_offset_ids = np.arange(len(category))[category == 1]
# Check whether on- or offset is longer and calculate length difference
if len(new_onset_ids) > len(new_offset_ids):
embed()
logger.warning('Onsets are greater than offsets')
logger.warning("Onsets are greater than offsets")
elif len(new_onset_ids) < len(new_offset_ids):
logger.warning('Offsets are greater than onsets')
logger.warning("Offsets are greater than onsets")
elif len(new_onset_ids) == len(new_offset_ids):
# logger.info('Chasing events are equal')
pass
@@ -130,13 +133,11 @@ def center_chirps(
# dt: float,
# width: float,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
event_chirps = [] # chirps that are in specified window around event
event_chirps = [] # chirps that are in specified window around event
# timestamps of chirps around event centered on the event timepoint
centered_chirps = []
for event_timestamp in events:
start = event_timestamp - time_before_event
stop = event_timestamp + time_after_event
chirps_around_event = [c for c in chirps if (c >= start) & (c <= stop)]
@@ -152,7 +153,8 @@ def center_chirps(
if len(centered_chirps) != len(event_chirps):
raise ValueError(
'Non centered chirps and centered chirps are not equal')
"Non centered chirps and centered chirps are not equal"
)
# time = np.arange(-time_before_event, time_after_event, dt)

View File

@@ -23,7 +23,9 @@ def minmaxnorm(data):
return (data - np.min(data)) / (np.max(data) - np.min(data))
def instantaneous_frequency2(signal: np.ndarray, fs: float, interpolation: str = 'linear') -> np.ndarray:
def instantaneous_frequency2(
signal: np.ndarray, fs: float, interpolation: str = "linear"
) -> np.ndarray:
"""
Compute the instantaneous frequency of a periodic signal using zero crossings and resample the frequency using linear
or cubic interpolation to match the dimensions of the input array.
@@ -55,10 +57,10 @@ def instantaneous_frequency2(signal: np.ndarray, fs: float, interpolation: str =
orig_len = len(signal)
freq = resample(freq, orig_len)
if interpolation == 'linear':
if interpolation == "linear":
freq = np.interp(np.arange(0, orig_len), np.arange(0, orig_len), freq)
elif interpolation == 'cubic':
freq = resample(freq, orig_len, window='cubic')
elif interpolation == "cubic":
freq = resample(freq, orig_len, window="cubic")
return freq
@@ -67,7 +69,7 @@ def instantaneous_frequency(
signal: np.ndarray,
samplerate: int,
smoothing_window: int,
interpolation: str = 'linear',
interpolation: str = "linear",
) -> np.ndarray:
"""
Compute the instantaneous frequency of a signal that is approximately
@@ -120,11 +122,10 @@ def instantaneous_frequency(
orig_len = len(signal)
freq = resample(instantaneous_frequency, orig_len)
if interpolation == 'linear':
if interpolation == "linear":
freq = np.interp(np.arange(0, orig_len), np.arange(0, orig_len), freq)
elif interpolation == 'cubic':
freq = resample(freq, orig_len, window='cubic')
elif interpolation == "cubic":
freq = resample(freq, orig_len, window="cubic")
return freq
@@ -160,7 +161,6 @@ def purge_duplicates(
group = [timestamps[0]]
for i in range(1, len(timestamps)):
# check the difference between current timestamp and previous
# timestamp is less than the threshold
if timestamps[i] - timestamps[i - 1] < threshold:
@@ -379,7 +379,6 @@ def acausal_kde1d(spikes, time, width):
if __name__ == "__main__":
timestamps = [
[1.2, 1.5, 1.3],
[],

View File

@@ -35,7 +35,6 @@ class LoadData:
"""
def __init__(self, datapath: str) -> None:
# load raw data
self.datapath = datapath
self.file = os.path.join(datapath, "traces-grid1.raw")

View File

@@ -3,10 +3,10 @@ import numpy as np
def bandpass_filter(
signal: np.ndarray,
samplerate: float,
lowf: float,
highf: float,
signal: np.ndarray,
samplerate: float,
lowf: float,
highf: float,
) -> np.ndarray:
"""Bandpass filter a signal.
@@ -60,9 +60,7 @@ def highpass_filter(
def lowpass_filter(
signal: np.ndarray,
samplerate: float,
cutoff: float
signal: np.ndarray, samplerate: float, cutoff: float
) -> np.ndarray:
"""Lowpass filter a signal.
@@ -86,10 +84,9 @@ def lowpass_filter(
return filtered_signal
def envelope(signal: np.ndarray,
samplerate: float,
cutoff_frequency: float
) -> np.ndarray:
def envelope(
signal: np.ndarray, samplerate: float, cutoff_frequency: float
) -> np.ndarray:
"""Calculate the envelope of a signal using a lowpass filter.
Parameters

View File

@@ -2,12 +2,13 @@ import logging
def makeLogger(name: str):
# create logger formats for file and terminal
file_formatter = logging.Formatter(
"[ %(levelname)s ] ~ %(asctime)s ~ %(module)s.%(funcName)s: %(message)s")
"[ %(levelname)s ] ~ %(asctime)s ~ %(module)s.%(funcName)s: %(message)s"
)
console_formatter = logging.Formatter(
"[ %(levelname)s ] in %(module)s.%(funcName)s: %(message)s")
"[ %(levelname)s ] in %(module)s.%(funcName)s: %(message)s"
)
# create logging file if loglevel is debug
file_handler = logging.FileHandler(f"gridtools_log.log", mode="w")
@@ -29,7 +30,6 @@ def makeLogger(name: str):
if __name__ == "__main__":
# initiate logger
mylogger = makeLogger(__name__)

View File

@@ -7,7 +7,6 @@ from matplotlib.colors import ListedColormap
def PlotStyle() -> None:
class style:
# lightcmap = cmocean.tools.lighten(cmocean.cm.haline, 0.8)
# units
@@ -76,13 +75,15 @@ def PlotStyle() -> None:
va="center",
zorder=1000,
bbox=dict(
boxstyle=f"circle, pad={padding}", fc="white", ec="black", lw=1
boxstyle=f"circle, pad={padding}",
fc="white",
ec="black",
lw=1,
),
)
@classmethod
def fade_cmap(cls, cmap):
my_cmap = cmap(np.arange(cmap.N))
my_cmap[:, -1] = np.linspace(0, 1, cmap.N)
my_cmap = ListedColormap(my_cmap)
@@ -295,7 +296,6 @@ def PlotStyle() -> None:
if __name__ == "__main__":
s = PlotStyle()
import matplotlib.cbook as cbook
@@ -347,7 +347,8 @@ if __name__ == "__main__":
for ax in axs:
ax.yaxis.grid(True)
ax.set_xticks(
[y + 1 for y in range(len(all_data))], labels=["x1", "x2", "x3", "x4"]
[y + 1 for y in range(len(all_data))],
labels=["x1", "x2", "x3", "x4"],
)
ax.set_xlabel("Four separate samples")
ax.set_ylabel("Observed values")
@@ -396,7 +397,10 @@ if __name__ == "__main__":
grid = np.random.rand(4, 4)
fig, axs = plt.subplots(
nrows=3, ncols=6, figsize=(9, 6), subplot_kw={"xticks": [], "yticks": []}
nrows=3,
ncols=6,
figsize=(9, 6),
subplot_kw={"xticks": [], "yticks": []},
)
for ax, interp_method in zip(axs.flat, methods):

View File

@@ -7,7 +7,6 @@ from matplotlib.colors import ListedColormap
def PlotStyle() -> None:
class style:
# lightcmap = cmocean.tools.lighten(cmocean.cm.haline, 0.8)
# units
@@ -76,13 +75,15 @@ def PlotStyle() -> None:
va="center",
zorder=1000,
bbox=dict(
boxstyle=f"circle, pad={padding}", fc="white", ec="black", lw=1
boxstyle=f"circle, pad={padding}",
fc="white",
ec="black",
lw=1,
),
)
@classmethod
def fade_cmap(cls, cmap):
my_cmap = cmap(np.arange(cmap.N))
my_cmap[:, -1] = np.linspace(0, 1, cmap.N)
my_cmap = ListedColormap(my_cmap)
@@ -295,7 +296,6 @@ def PlotStyle() -> None:
if __name__ == "__main__":
s = PlotStyle()
import matplotlib.cbook as cbook
@@ -347,7 +347,8 @@ if __name__ == "__main__":
for ax in axs:
ax.yaxis.grid(True)
ax.set_xticks(
[y + 1 for y in range(len(all_data))], labels=["x1", "x2", "x3", "x4"]
[y + 1 for y in range(len(all_data))],
labels=["x1", "x2", "x3", "x4"],
)
ax.set_xlabel("Four separate samples")
ax.set_ylabel("Observed values")
@@ -396,7 +397,10 @@ if __name__ == "__main__":
grid = np.random.rand(4, 4)
fig, axs = plt.subplots(
nrows=3, ncols=6, figsize=(9, 6), subplot_kw={"xticks": [], "yticks": []}
nrows=3,
ncols=6,
figsize=(9, 6),
subplot_kw={"xticks": [], "yticks": []},
)
for ax, interp_method in zip(axs.flat, methods):

View File

@@ -7,7 +7,6 @@ from matplotlib.colors import ListedColormap
def PlotStyle() -> None:
class style:
# lightcmap = cmocean.tools.lighten(cmocean.cm.haline, 0.8)
# units
@@ -76,13 +75,15 @@ def PlotStyle() -> None:
va="center",
zorder=1000,
bbox=dict(
boxstyle=f"circle, pad={padding}", fc="white", ec="black", lw=1
boxstyle=f"circle, pad={padding}",
fc="white",
ec="black",
lw=1,
),
)
@classmethod
def fade_cmap(cls, cmap):
my_cmap = cmap(np.arange(cmap.N))
my_cmap[:, -1] = np.linspace(0, 1, cmap.N)
my_cmap = ListedColormap(my_cmap)
@@ -295,7 +296,6 @@ def PlotStyle() -> None:
if __name__ == "__main__":
s = PlotStyle()
import matplotlib.cbook as cbook
@@ -347,7 +347,8 @@ if __name__ == "__main__":
for ax in axs:
ax.yaxis.grid(True)
ax.set_xticks(
[y + 1 for y in range(len(all_data))], labels=["x1", "x2", "x3", "x4"]
[y + 1 for y in range(len(all_data))],
labels=["x1", "x2", "x3", "x4"],
)
ax.set_xlabel("Four separate samples")
ax.set_ylabel("Observed values")
@@ -396,7 +397,10 @@ if __name__ == "__main__":
grid = np.random.rand(4, 4)
fig, axs = plt.subplots(
nrows=3, ncols=6, figsize=(9, 6), subplot_kw={"xticks": [], "yticks": []}
nrows=3,
ncols=6,
figsize=(9, 6),
subplot_kw={"xticks": [], "yticks": []},
)
for ax, interp_method in zip(axs.flat, methods):

View File

@@ -37,7 +37,7 @@ def create_chirp(
ck = 0
csig = 0.5 * chirpduration / np.power(2.0 * np.log(10.0), 0.5 / kurtosis)
#csig = csig*-1
# csig = csig*-1
for k, t in enumerate(time):
a = 1.0
f = eodf

View File

@@ -16,26 +16,25 @@ logger = makeLogger(__name__)
def get_chirp_winner_loser(folder_name, Behavior, order_meta_df):
foldername = folder_name.split('/')[-2]
winner_row = order_meta_df[order_meta_df['recording'] == foldername]
winner = winner_row['winner'].values[0].astype(int)
winner_fish1 = winner_row['fish1'].values[0].astype(int)
winner_fish2 = winner_row['fish2'].values[0].astype(int)
foldername = folder_name.split("/")[-2]
winner_row = order_meta_df[order_meta_df["recording"] == foldername]
winner = winner_row["winner"].values[0].astype(int)
winner_fish1 = winner_row["fish1"].values[0].astype(int)
winner_fish2 = winner_row["fish2"].values[0].astype(int)
if winner > 0:
if winner == winner_fish1:
winner_fish_id = winner_row['rec_id1'].values[0]
loser_fish_id = winner_row['rec_id2'].values[0]
winner_fish_id = winner_row["rec_id1"].values[0]
loser_fish_id = winner_row["rec_id2"].values[0]
elif winner == winner_fish2:
winner_fish_id = winner_row['rec_id2'].values[0]
loser_fish_id = winner_row['rec_id1'].values[0]
winner_fish_id = winner_row["rec_id2"].values[0]
loser_fish_id = winner_row["rec_id1"].values[0]
chirp_winner = len(
Behavior.chirps[Behavior.chirps_ids == winner_fish_id])
chirp_loser = len(
Behavior.chirps[Behavior.chirps_ids == loser_fish_id])
Behavior.chirps[Behavior.chirps_ids == winner_fish_id]
)
chirp_loser = len(Behavior.chirps[Behavior.chirps_ids == loser_fish_id])
return chirp_winner, chirp_loser
else:
@@ -43,24 +42,24 @@ def get_chirp_winner_loser(folder_name, Behavior, order_meta_df):
def get_chirp_size(folder_name, Behavior, order_meta_df, id_meta_df):
foldername = folder_name.split("/")[-2]
folder_row = order_meta_df[order_meta_df["recording"] == foldername]
fish1 = folder_row["fish1"].values[0].astype(int)
fish2 = folder_row["fish2"].values[0].astype(int)
winner = folder_row["winner"].values[0].astype(int)
foldername = folder_name.split('/')[-2]
folder_row = order_meta_df[order_meta_df['recording'] == foldername]
fish1 = folder_row['fish1'].values[0].astype(int)
fish2 = folder_row['fish2'].values[0].astype(int)
winner = folder_row['winner'].values[0].astype(int)
groub = folder_row["group"].values[0].astype(int)
size_fish1_row = id_meta_df[
(id_meta_df["group"] == groub) & (id_meta_df["fish"] == fish1)
]
size_fish2_row = id_meta_df[
(id_meta_df["group"] == groub) & (id_meta_df["fish"] == fish2)
]
groub = folder_row['group'].values[0].astype(int)
size_fish1_row = id_meta_df[(id_meta_df['group'] == groub) & (
id_meta_df['fish'] == fish1)]
size_fish2_row = id_meta_df[(id_meta_df['group'] == groub) & (
id_meta_df['fish'] == fish2)]
size_winners = [size_fish1_row[col].values[0]
for col in ['l1', 'l2', 'l3']]
size_winners = [size_fish1_row[col].values[0] for col in ["l1", "l2", "l3"]]
size_fish1 = np.nanmean(size_winners)
size_losers = [size_fish2_row[col].values[0] for col in ['l1', 'l2', 'l3']]
size_losers = [size_fish2_row[col].values[0] for col in ["l1", "l2", "l3"]]
size_fish2 = np.nanmean(size_losers)
if winner == fish1:
@@ -75,8 +74,8 @@ def get_chirp_size(folder_name, Behavior, order_meta_df, id_meta_df):
size_diff_bigger = 0
size_diff_smaller = 0
winner_fish_id = folder_row['rec_id1'].values[0]
loser_fish_id = folder_row['rec_id2'].values[0]
winner_fish_id = folder_row["rec_id1"].values[0]
loser_fish_id = folder_row["rec_id2"].values[0]
elif winner == fish2:
if size_fish2 > size_fish1:
@@ -90,39 +89,39 @@ def get_chirp_size(folder_name, Behavior, order_meta_df, id_meta_df):
size_diff_bigger = 0
size_diff_smaller = 0
winner_fish_id = folder_row['rec_id2'].values[0]
loser_fish_id = folder_row['rec_id1'].values[0]
winner_fish_id = folder_row["rec_id2"].values[0]
loser_fish_id = folder_row["rec_id1"].values[0]
else:
size_diff_bigger = np.nan
size_diff_smaller = np.nan
winner_fish_id = np.nan
loser_fish_id = np.nan
return size_diff_bigger, size_diff_smaller, winner_fish_id, loser_fish_id
return (
size_diff_bigger,
size_diff_smaller,
winner_fish_id,
loser_fish_id,
)
chirp_winner = len(
Behavior.chirps[Behavior.chirps_ids == winner_fish_id])
chirp_loser = len(
Behavior.chirps[Behavior.chirps_ids == loser_fish_id])
chirp_winner = len(Behavior.chirps[Behavior.chirps_ids == winner_fish_id])
chirp_loser = len(Behavior.chirps[Behavior.chirps_ids == loser_fish_id])
return size_diff_bigger, chirp_winner, size_diff_smaller, chirp_loser
return size_diff_bigger, chirp_winner, size_diff_smaller, chirp_loser
def get_chirp_freq(folder_name, Behavior, order_meta_df):
foldername = folder_name.split("/")[-2]
folder_row = order_meta_df[order_meta_df["recording"] == foldername]
fish1 = folder_row["fish1"].values[0].astype(int)
fish2 = folder_row["fish2"].values[0].astype(int)
foldername = folder_name.split('/')[-2]
folder_row = order_meta_df[order_meta_df['recording'] == foldername]
fish1 = folder_row['fish1'].values[0].astype(int)
fish2 = folder_row['fish2'].values[0].astype(int)
fish1_freq = folder_row["rec_id1"].values[0].astype(int)
fish2_freq = folder_row["rec_id2"].values[0].astype(int)
fish1_freq = folder_row['rec_id1'].values[0].astype(int)
fish2_freq = folder_row['rec_id2'].values[0].astype(int)
chirp_freq_fish1 = np.nanmedian(
Behavior.freq[Behavior.ident == fish1_freq])
chirp_freq_fish2 = np.nanmedian(
Behavior.freq[Behavior.ident == fish2_freq])
winner = folder_row['winner'].values[0].astype(int)
chirp_freq_fish1 = np.nanmedian(Behavior.freq[Behavior.ident == fish1_freq])
chirp_freq_fish2 = np.nanmedian(Behavior.freq[Behavior.ident == fish2_freq])
winner = folder_row["winner"].values[0].astype(int)
if winner == fish1:
# if chirp_freq_fish1 > chirp_freq_fish2:
@@ -138,9 +137,9 @@ def get_chirp_freq(folder_name, Behavior, order_meta_df):
# winner_fish_id = np.nan
# loser_fish_id = np.nan
winner_fish_id = folder_row['rec_id1'].values[0]
winner_fish_id = folder_row["rec_id1"].values[0]
winner_fish_freq = chirp_freq_fish1
loser_fish_id = folder_row['rec_id2'].values[0]
loser_fish_id = folder_row["rec_id2"].values[0]
loser_fish_freq = chirp_freq_fish2
elif winner == fish2:
@@ -157,9 +156,9 @@ def get_chirp_freq(folder_name, Behavior, order_meta_df):
# winner_fish_id = np.nan
# loser_fish_id = np.nan
winner_fish_id = folder_row['rec_id2'].values[0]
winner_fish_id = folder_row["rec_id2"].values[0]
winner_fish_freq = chirp_freq_fish2
loser_fish_id = folder_row['rec_id1'].values[0]
loser_fish_id = folder_row["rec_id1"].values[0]
loser_fish_freq = chirp_freq_fish1
else:
winner_fish_freq = np.nan
@@ -168,25 +167,25 @@ def get_chirp_freq(folder_name, Behavior, order_meta_df):
loser_fish_id = np.nan
return winner_fish_freq, winner_fish_id, loser_fish_freq, loser_fish_id
chirp_winner = len(
Behavior.chirps[Behavior.chirps_ids == winner_fish_id])
chirp_loser = len(
Behavior.chirps[Behavior.chirps_ids == loser_fish_id])
chirp_winner = len(Behavior.chirps[Behavior.chirps_ids == winner_fish_id])
chirp_loser = len(Behavior.chirps[Behavior.chirps_ids == loser_fish_id])
return winner_fish_freq, chirp_winner, loser_fish_freq, chirp_loser
def main(datapath: str):
foldernames = [
datapath + x + '/' for x in os.listdir(datapath) if os.path.isdir(datapath+x)]
datapath + x + "/"
for x in os.listdir(datapath)
if os.path.isdir(datapath + x)
]
foldernames, _ = get_valid_datasets(datapath)
path_order_meta = (
'/').join(foldernames[0].split('/')[:-2]) + '/order_meta.csv'
path_order_meta = ("/").join(
foldernames[0].split("/")[:-2]
) + "/order_meta.csv"
order_meta_df = read_csv(path_order_meta)
order_meta_df['recording'] = order_meta_df['recording'].str[1:-1]
path_id_meta = (
'/').join(foldernames[0].split('/')[:-2]) + '/id_meta.csv'
order_meta_df["recording"] = order_meta_df["recording"].str[1:-1]
path_id_meta = ("/").join(foldernames[0].split("/")[:-2]) + "/id_meta.csv"
id_meta_df = read_csv(path_id_meta)
chirps_winner = []
@@ -202,10 +201,9 @@ def main(datapath: str):
freq_chirps_winner = []
freq_chirps_loser = []
for foldername in foldernames:
# behabvior is pandas dataframe with all the data
if foldername == '../data/mount_data/2020-05-12-10_00/':
if foldername == "../data/mount_data/2020-05-12-10_00/":
continue
bh = Behavior(foldername)
# chirps are not sorted in time (presumably due to prior groupings)
@@ -217,15 +215,24 @@ def main(datapath: str):
category, timestamps = correct_chasing_events(category, timestamps)
winner_chirp, loser_chirp = get_chirp_winner_loser(
foldername, bh, order_meta_df)
foldername, bh, order_meta_df
)
chirps_winner.append(winner_chirp)
chirps_loser.append(loser_chirp)
size_diff_bigger, chirp_winner, size_diff_smaller, chirp_loser = get_chirp_size(
foldername, bh, order_meta_df, id_meta_df)
(
size_diff_bigger,
chirp_winner,
size_diff_smaller,
chirp_loser,
) = get_chirp_size(foldername, bh, order_meta_df, id_meta_df)
freq_winner, chirp_freq_winner, freq_loser, chirp_freq_loser = get_chirp_freq(
foldername, bh, order_meta_df)
(
freq_winner,
chirp_freq_winner,
freq_loser,
chirp_freq_loser,
) = get_chirp_freq(foldername, bh, order_meta_df)
freq_diffs_higher.append(freq_winner)
freq_diffs_lower.append(freq_loser)
@@ -242,82 +249,124 @@ def main(datapath: str):
pearsonr(size_diffs_winner, size_chirps_winner)
pearsonr(size_diffs_loser, size_chirps_loser)
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(
21*ps.cm, 7*ps.cm), width_ratios=[1, 0.8, 0.8], sharey=True)
plt.subplots_adjust(left=0.11, right=0.948, top=0.86,
wspace=0.343, bottom=0.198)
fig, (ax1, ax2, ax3) = plt.subplots(
1,
3,
figsize=(21 * ps.cm, 7 * ps.cm),
width_ratios=[1, 0.8, 0.8],
sharey=True,
)
plt.subplots_adjust(
left=0.11, right=0.948, top=0.86, wspace=0.343, bottom=0.198
)
scatterwinner = 1.15
scatterloser = 1.85
chirps_winner = np.asarray(chirps_winner)[~np.isnan(chirps_winner)]
chirps_loser = np.asarray(chirps_loser)[~np.isnan(chirps_loser)]
embed()
exit()
freq_diffs_higher = np.asarray(
freq_diffs_higher)[~np.isnan(freq_diffs_higher)]
freq_diffs_lower = np.asarray(freq_diffs_lower)[
~np.isnan(freq_diffs_lower)]
freq_chirps_winner = np.asarray(
freq_chirps_winner)[~np.isnan(freq_chirps_winner)]
freq_chirps_loser = np.asarray(
freq_chirps_loser)[~np.isnan(freq_chirps_loser)]
freq_diffs_higher = np.asarray(freq_diffs_higher)[
~np.isnan(freq_diffs_higher)
]
freq_diffs_lower = np.asarray(freq_diffs_lower)[~np.isnan(freq_diffs_lower)]
freq_chirps_winner = np.asarray(freq_chirps_winner)[
~np.isnan(freq_chirps_winner)
]
freq_chirps_loser = np.asarray(freq_chirps_loser)[
~np.isnan(freq_chirps_loser)
]
stat = wilcoxon(chirps_winner, chirps_loser)
print(stat)
winner_color = ps.gblue2
loser_color = ps.gblue1
bplot1 = ax1.boxplot(chirps_winner, positions=[
0.9], showfliers=False, patch_artist=True)
bplot1 = ax1.boxplot(
chirps_winner, positions=[0.9], showfliers=False, patch_artist=True
)
bplot2 = ax1.boxplot(chirps_loser, positions=[
2.1], showfliers=False, patch_artist=True)
bplot2 = ax1.boxplot(
chirps_loser, positions=[2.1], showfliers=False, patch_artist=True
)
ax1.scatter(np.ones(len(chirps_winner)) *
scatterwinner, chirps_winner, color=winner_color)
ax1.scatter(np.ones(len(chirps_loser)) *
scatterloser, chirps_loser, color=loser_color)
ax1.set_xticklabels(['Winner', 'Loser'])
ax1.scatter(
np.ones(len(chirps_winner)) * scatterwinner,
chirps_winner,
color=winner_color,
)
ax1.scatter(
np.ones(len(chirps_loser)) * scatterloser,
chirps_loser,
color=loser_color,
)
ax1.set_xticklabels(["Winner", "Loser"])
ax1.text(0.1, 0.85, f'n={len(chirps_loser)}',
transform=ax1.transAxes, color=ps.white)
ax1.text(
0.1,
0.85,
f"n={len(chirps_loser)}",
transform=ax1.transAxes,
color=ps.white,
)
for w, l in zip(chirps_winner, chirps_loser):
ax1.plot([scatterwinner, scatterloser], [w, l],
color=ps.white, alpha=0.6, linewidth=1, zorder=-1)
ax1.set_ylabel('Chirp counts', color=ps.white)
ax1.set_xlabel('Competition outcome', color=ps.white)
ax1.plot(
[scatterwinner, scatterloser],
[w, l],
color=ps.white,
alpha=0.6,
linewidth=1,
zorder=-1,
)
ax1.set_ylabel("Chirp counts", color=ps.white)
ax1.set_xlabel("Competition outcome", color=ps.white)
ps.set_boxplot_color(bplot1, winner_color)
ps.set_boxplot_color(bplot2, loser_color)
ax2.scatter(size_diffs_winner, size_chirps_winner,
color=winner_color, label='Winner')
ax2.scatter(size_diffs_loser, size_chirps_loser,
color=loser_color, label='Loser')
ax2.scatter(
size_diffs_winner,
size_chirps_winner,
color=winner_color,
label="Winner",
)
ax2.scatter(
size_diffs_loser, size_chirps_loser, color=loser_color, label="Loser"
)
ax2.text(0.05, 0.85, f'n={len(size_chirps_loser)}',
transform=ax2.transAxes, color=ps.white)
ax2.text(
0.05,
0.85,
f"n={len(size_chirps_loser)}",
transform=ax2.transAxes,
color=ps.white,
)
ax2.set_xlabel('Size difference [cm]')
ax2.set_xlabel("Size difference [cm]")
# ax2.set_xticks(np.arange(-10, 10.1, 2))
ax3.scatter(freq_diffs_higher, freq_chirps_winner, color=winner_color)
ax3.scatter(freq_diffs_lower, freq_chirps_loser, color=loser_color)
ax3.text(0.1, 0.85, f'n={len(np.asarray(freq_chirps_winner)[~np.isnan(freq_chirps_loser)])}',
transform=ax3.transAxes, color=ps.white)
ax3.text(
0.1,
0.85,
f"n={len(np.asarray(freq_chirps_winner)[~np.isnan(freq_chirps_loser)])}",
transform=ax3.transAxes,
color=ps.white,
)
ax3.set_xlabel('EODf [Hz]')
ax3.set_xlabel("EODf [Hz]")
handles, labels = ax2.get_legend_handles_labels()
fig.legend(handles, labels, loc='upper center',
ncol=2, bbox_to_anchor=(0.5, 1.04))
fig.legend(
handles, labels, loc="upper center", ncol=2, bbox_to_anchor=(0.5, 1.04)
)
# pearson r
plt.savefig('../poster/figs/chirps_winner_loser.pdf')
plt.savefig("../poster/figs/chirps_winner_loser.pdf")
plt.show()
if __name__ == '__main__':
if __name__ == "__main__":
# Path to the data
datapath = '../data/mount_data/'
datapath = "../data/mount_data/"
main(datapath)

View File

@@ -21,14 +21,16 @@ logger = makeLogger(__name__)
def main(datapath: str):
foldernames = [
datapath + x + '/' for x in os.listdir(datapath) if os.path.isdir(datapath+x)]
datapath + x + "/"
for x in os.listdir(datapath)
if os.path.isdir(datapath + x)
]
time_precents = []
chirps_percents = []
for foldername in foldernames:
# behabvior is pandas dataframe with all the data
if foldername == '../data/mount_data/2020-05-12-10_00/':
if foldername == "../data/mount_data/2020-05-12-10_00/":
continue
bh = Behavior(foldername)
@@ -46,50 +48,70 @@ def main(datapath: str):
chirps_in_chasings = []
for onset, offset in zip(chasing_onset, chasing_offset):
chirps_in_chasing = [
c for c in bh.chirps if (c > onset) & (c < offset)]
c for c in bh.chirps if (c > onset) & (c < offset)
]
chirps_in_chasings.append(chirps_in_chasing)
try:
time_chasing = np.sum(
chasing_offset[chasing_offset < 3*60*60] - chasing_onset[chasing_onset < 3*60*60])
chasing_offset[chasing_offset < 3 * 60 * 60]
- chasing_onset[chasing_onset < 3 * 60 * 60]
)
except:
time_chasing = np.sum(
chasing_offset[chasing_offset < 3*60*60] - chasing_onset[chasing_onset < 3*60*60][:-1])
chasing_offset[chasing_offset < 3 * 60 * 60]
- chasing_onset[chasing_onset < 3 * 60 * 60][:-1]
)
time_chasing_percent = (time_chasing/(3*60*60))*100
time_chasing_percent = (time_chasing / (3 * 60 * 60)) * 100
chirps_chasing = np.asarray(flatten(chirps_in_chasings))
chirps_chasing_new = chirps_chasing[chirps_chasing < 3*60*60]
chirps_percent = (len(chirps_chasing_new) /
len(bh.chirps[bh.chirps < 3*60*60]))*100
chirps_chasing_new = chirps_chasing[chirps_chasing < 3 * 60 * 60]
chirps_percent = (
len(chirps_chasing_new) / len(bh.chirps[bh.chirps < 3 * 60 * 60])
) * 100
time_precents.append(time_chasing_percent)
chirps_percents.append(chirps_percent)
fig, ax = plt.subplots(1, 1, figsize=(7*ps.cm, 7*ps.cm))
fig, ax = plt.subplots(1, 1, figsize=(7 * ps.cm, 7 * ps.cm))
scatter_time = 1.20
scatter_chirps = 1.80
size = 10
bplot1 = ax.boxplot([time_precents, chirps_percents],
showfliers=False, patch_artist=True)
bplot1 = ax.boxplot(
[time_precents, chirps_percents], showfliers=False, patch_artist=True
)
ps.set_boxplot_color(bplot1, ps.gray)
ax.set_xticklabels(['Time \nchasing', 'Chirps \nin chasing'])
ax.set_ylabel('Percent')
ax.scatter(np.ones(len(time_precents))*scatter_time, time_precents,
facecolor=ps.white, s=size)
ax.scatter(np.ones(len(chirps_percents))*scatter_chirps, chirps_percents,
facecolor=ps.white, s=size)
ax.set_xticklabels(["Time \nchasing", "Chirps \nin chasing"])
ax.set_ylabel("Percent")
ax.scatter(
np.ones(len(time_precents)) * scatter_time,
time_precents,
facecolor=ps.white,
s=size,
)
ax.scatter(
np.ones(len(chirps_percents)) * scatter_chirps,
chirps_percents,
facecolor=ps.white,
s=size,
)
for i in range(len(time_precents)):
ax.plot([scatter_time, scatter_chirps], [time_precents[i],
chirps_percents[i]], alpha=0.6, linewidth=1, color=ps.white)
ax.plot(
[scatter_time, scatter_chirps],
[time_precents[i], chirps_percents[i]],
alpha=0.6,
linewidth=1,
color=ps.white,
)
ax.text(0.1, 0.9, f'n={len(time_precents)}', transform=ax.transAxes)
ax.text(0.1, 0.9, f"n={len(time_precents)}", transform=ax.transAxes)
plt.subplots_adjust(left=0.221, bottom=0.186, right=0.97, top=0.967)
plt.savefig('../poster/figs/chirps_in_chasing.pdf')
plt.savefig("../poster/figs/chirps_in_chasing.pdf")
plt.show()
if __name__ == '__main__':
if __name__ == "__main__":
# Path to the data
datapath = '../data/mount_data/'
datapath = "../data/mount_data/"
main(datapath)

View File

@@ -13,6 +13,7 @@ from modules.plotstyle import PlotStyle
from modules.behaviour_handling import Behavior, correct_chasing_events
from extract_chirps import get_valid_datasets
ps = PlotStyle()
logger = makeLogger(__name__)
@@ -20,13 +21,16 @@ logger = makeLogger(__name__)
def main(datapath: str):
foldernames = [
datapath + x + '/' for x in os.listdir(datapath) if os.path.isdir(datapath+x)]
datapath + x + "/"
for x in os.listdir(datapath)
if os.path.isdir(datapath + x)
]
foldernames, _ = get_valid_datasets(datapath)
for foldername in foldernames[3:4]:
print(foldername)
# foldername = foldernames[0]
if foldername == '../data/mount_data/2020-05-12-10_00/':
if foldername == "../data/mount_data/2020-05-12-10_00/":
continue
# behabvior is pandas dataframe with all the data
bh = Behavior(foldername)
@@ -52,18 +56,43 @@ def main(datapath: str):
exit()
fish1_color = ps.gblue2
fish2_color = ps.gblue1
fig, ax = plt.subplots(5, 1, figsize=(
21*ps.cm, 10*ps.cm), height_ratios=[0.5, 0.5, 0.5, 0.2, 6], sharex=True)
fig, ax = plt.subplots(
5,
1,
figsize=(21 * ps.cm, 10 * ps.cm),
height_ratios=[0.5, 0.5, 0.5, 0.2, 6],
sharex=True,
)
# marker size
s = 80
ax[0].scatter(physical_contact, np.ones(
len(physical_contact)), color=ps.gray, marker='|', s=s)
ax[1].scatter(chasing_onset, np.ones(len(chasing_onset)),
color=ps.gray, marker='|', s=s)
ax[2].scatter(fish1, np.ones(len(fish1))-0.25,
color=fish1_color, marker='|', s=s)
ax[2].scatter(fish2, np.zeros(len(fish2))+0.25,
color=fish2_color, marker='|', s=s)
ax[0].scatter(
physical_contact,
np.ones(len(physical_contact)),
color=ps.gray,
marker="|",
s=s,
)
ax[1].scatter(
chasing_onset,
np.ones(len(chasing_onset)),
color=ps.gray,
marker="|",
s=s,
)
ax[2].scatter(
fish1,
np.ones(len(fish1)) - 0.25,
color=fish1_color,
marker="|",
s=s,
)
ax[2].scatter(
fish2,
np.zeros(len(fish2)) + 0.25,
color=fish2_color,
marker="|",
s=s,
)
freq_temp = bh.freq[bh.ident == fish1_id]
time_temp = bh.time[bh.idx[bh.ident == fish1_id]]
@@ -94,35 +123,38 @@ def main(datapath: str):
ax[2].set_xticks([])
ps.hide_ax(ax[2])
ax[4].axvspan(0, 3, 0, 5, facecolor='grey', alpha=0.5)
ax[4].axvspan(0, 3, 0, 5, facecolor="grey", alpha=0.5)
ax[4].set_xticks(np.arange(0, 6.1, 0.5))
ps.hide_ax(ax[3])
labelpad = 30
fsize = 12
ax[0].set_ylabel('Contact', rotation=0,
labelpad=labelpad, fontsize=fsize)
ax[0].set_ylabel(
"Contact", rotation=0, labelpad=labelpad, fontsize=fsize
)
ax[0].yaxis.set_label_coords(-0.062, -0.08)
ax[1].set_ylabel('Chasing', rotation=0,
labelpad=labelpad, fontsize=fsize)
ax[1].set_ylabel(
"Chasing", rotation=0, labelpad=labelpad, fontsize=fsize
)
ax[1].yaxis.set_label_coords(-0.06, -0.08)
ax[2].set_ylabel('Chirps', rotation=0,
labelpad=labelpad, fontsize=fsize)
ax[2].set_ylabel(
"Chirps", rotation=0, labelpad=labelpad, fontsize=fsize
)
ax[2].yaxis.set_label_coords(-0.07, -0.08)
ax[4].set_ylabel('EODf')
ax[4].set_ylabel("EODf")
ax[4].set_xlabel('Time [h]')
ax[4].set_xlabel("Time [h]")
# ax[0].set_title(foldername.split('/')[-2])
# 2020-03-31-9_59
plt.subplots_adjust(left=0.158, right=0.987, top=0.918, bottom=0.136)
plt.savefig('../poster/figs/timeline.svg')
plt.savefig("../poster/figs/timeline.svg")
plt.show()
# plot chirps
if __name__ == '__main__':
if __name__ == "__main__":
# Path to the data
datapath = '../data/mount_data/'
datapath = "../data/mount_data/"
main(datapath)

View File

@@ -11,7 +11,6 @@ ps = PlotStyle()
def main():
# Load data
datapath = "../data/2022-06-02-10_00/"
data = LoadData(datapath)
@@ -24,26 +23,31 @@ def main():
timescaler = 1000
raw = data.raw[window_start_index:window_start_index +
window_duration_index, 10]
raw = data.raw[
window_start_index : window_start_index + window_duration_index, 10
]
fig, (ax1, ax2) = plt.subplots(
1, 2, figsize=(21 * ps.cm, 8*ps.cm), sharex=True, sharey=True)
1, 2, figsize=(21 * ps.cm, 8 * ps.cm), sharex=True, sharey=True
)
# plot instantaneous frequency
filtered1 = bandpass_filter(
signal=raw, lowf=750, highf=1200, samplerate=data.raw_rate)
signal=raw, lowf=750, highf=1200, samplerate=data.raw_rate
)
filtered2 = bandpass_filter(
signal=raw, lowf=550, highf=700, samplerate=data.raw_rate)
signal=raw, lowf=550, highf=700, samplerate=data.raw_rate
)
freqtime1, freq1 = instantaneous_frequency(
filtered1, data.raw_rate, smoothing_window=3)
filtered1, data.raw_rate, smoothing_window=3
)
freqtime2, freq2 = instantaneous_frequency(
filtered2, data.raw_rate, smoothing_window=3)
filtered2, data.raw_rate, smoothing_window=3
)
ax1.plot(freqtime1*timescaler, freq1, color=ps.g, lw=2, label="Fish 1")
ax1.plot(freqtime2*timescaler, freq2, color=ps.gray,
lw=2, label="Fish 2")
ax1.plot(freqtime1 * timescaler, freq1, color=ps.g, lw=2, label="Fish 1")
ax1.plot(freqtime2 * timescaler, freq2, color=ps.gray, lw=2, label="Fish 2")
# ax.legend(bbox_to_anchor=(1.04, 1), borderaxespad=0)
# # ps.hide_xax(ax1)
@@ -62,8 +66,8 @@ def main():
ax1.imshow(
decibel(spec_power[fmask, :]),
extent=[
spec_times[0]*timescaler,
spec_times[-1]*timescaler,
spec_times[0] * timescaler,
spec_times[-1] * timescaler,
spec_freqs[fmask][0],
spec_freqs[fmask][-1],
],
@@ -87,8 +91,8 @@ def main():
ax2.imshow(
decibel(spec_power[fmask, :]),
extent=[
spec_times[0]*timescaler,
spec_times[-1]*timescaler,
spec_times[0] * timescaler,
spec_times[-1] * timescaler,
spec_freqs[fmask][0],
spec_freqs[fmask][-1],
],
@@ -98,9 +102,8 @@ def main():
alpha=1,
)
# ps.hide_xax(ax3)
ax2.plot(freqtime1*timescaler, freq1, color=ps.g, lw=2, label="_")
ax2.plot(freqtime2*timescaler, freq2, color=ps.gray,
lw=2, label="_")
ax2.plot(freqtime1 * timescaler, freq1, color=ps.g, lw=2, label="_")
ax2.plot(freqtime2 * timescaler, freq2, color=ps.gray, lw=2, label="_")
ax2.set_xlim(75, 200)
ax1.set_ylim(400, 1200)
@@ -109,15 +112,22 @@ def main():
fig.supylabel("Frequency [Hz]", fontsize=14)
handles, labels = ax1.get_legend_handles_labels()
ax2.legend(handles, labels, bbox_to_anchor=(1.04, 1), loc="upper left", ncol=1,)
ax2.legend(
handles,
labels,
bbox_to_anchor=(1.04, 1),
loc="upper left",
ncol=1,
)
ps.letter_subplots(xoffset=[-0.27, -0.1], yoffset=1.05)
plt.subplots_adjust(left=0.12, right=0.85, top=0.89,
bottom=0.18, hspace=0.35)
plt.subplots_adjust(
left=0.12, right=0.85, top=0.89, bottom=0.18, hspace=0.35
)
plt.savefig('../poster/figs/introplot.pdf')
plt.savefig("../poster/figs/introplot.pdf")
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@@ -1,7 +1,9 @@
from modules.plotstyle import PlotStyle
from modules.behaviour_handling import (
Behavior, correct_chasing_events, center_chirps)
Behavior,
correct_chasing_events,
center_chirps,
)
from modules.datahandling import flatten, causal_kde1d, acausal_kde1d
from modules.logger import makeLogger
from pandas import read_csv
@@ -18,80 +20,93 @@ logger = makeLogger(__name__)
ps = PlotStyle()
def bootstrap(data, nresamples, kde_time, kernel_width, event_times, time_before, time_after):
def bootstrap(
data,
nresamples,
kde_time,
kernel_width,
event_times,
time_before,
time_after,
):
bootstrapped_kdes = []
data = data[data <= 3*60*60] # only night time
data = data[data <= 3 * 60 * 60] # only night time
diff_data = np.diff(np.sort(data), prepend=0)
# if len(data) != 0:
# mean_chirprate = (len(data) - 1) / (data[-1] - data[0])
for i in tqdm(range(nresamples)):
np.random.shuffle(diff_data)
bootstrapped_data = np.cumsum(diff_data)
# bootstrapped_data = data + np.random.randn(len(data)) * 10
bootstrap_data_centered = center_chirps(
bootstrapped_data, event_times, time_before, time_after)
bootstrapped_data, event_times, time_before, time_after
)
bootstrapped_kde = acausal_kde1d(
bootstrap_data_centered, time=kde_time, width=kernel_width)
bootstrap_data_centered, time=kde_time, width=kernel_width
)
bootstrapped_kde = list(np.asarray(
bootstrapped_kde) / len(event_times))
bootstrapped_kde = list(np.asarray(bootstrapped_kde) / len(event_times))
bootstrapped_kdes.append(bootstrapped_kde)
return bootstrapped_kdes
def jackknife(data, nresamples, subsetsize, kde_time, kernel_width, event_times, time_before, time_after):
def jackknife(
data,
nresamples,
subsetsize,
kde_time,
kernel_width,
event_times,
time_before,
time_after,
):
jackknife_kdes = []
data = data[data <= 3*60*60] # only night time
data = data[data <= 3 * 60 * 60] # only night time
subsetsize = int(len(data) * subsetsize)
diff_data = np.diff(np.sort(data), prepend=0)
for i in tqdm(range(nresamples)):
jackknifed_data = np.random.choice(
diff_data, subsetsize, replace=False)
jackknifed_data = np.random.choice(diff_data, subsetsize, replace=False)
jackknifed_data = np.cumsum(jackknifed_data)
jackknifed_data_centered = center_chirps(
jackknifed_data, event_times, time_before, time_after)
jackknifed_data, event_times, time_before, time_after
)
jackknifed_kde = acausal_kde1d(
jackknifed_data_centered, time=kde_time, width=kernel_width)
jackknifed_data_centered, time=kde_time, width=kernel_width
)
jackknifed_kde = list(np.asarray(
jackknifed_kde) / len(event_times))
jackknifed_kde = list(np.asarray(jackknifed_kde) / len(event_times))
jackknife_kdes.append(jackknifed_kde)
return jackknife_kdes
def get_chirp_winner_loser(folder_name, Behavior, order_meta_df):
foldername = folder_name.split('/')[-2]
winner_row = order_meta_df[order_meta_df['recording'] == foldername]
winner = winner_row['winner'].values[0].astype(int)
winner_fish1 = winner_row['fish1'].values[0].astype(int)
winner_fish2 = winner_row['fish2'].values[0].astype(int)
foldername = folder_name.split("/")[-2]
winner_row = order_meta_df[order_meta_df["recording"] == foldername]
winner = winner_row["winner"].values[0].astype(int)
winner_fish1 = winner_row["fish1"].values[0].astype(int)
winner_fish2 = winner_row["fish2"].values[0].astype(int)
if winner > 0:
if winner == winner_fish1:
winner_fish_id = winner_row['rec_id1'].values[0]
loser_fish_id = winner_row['rec_id2'].values[0]
winner_fish_id = winner_row["rec_id1"].values[0]
loser_fish_id = winner_row["rec_id2"].values[0]
elif winner == winner_fish2:
winner_fish_id = winner_row['rec_id2'].values[0]
loser_fish_id = winner_row['rec_id1'].values[0]
winner_fish_id = winner_row["rec_id2"].values[0]
loser_fish_id = winner_row["rec_id1"].values[0]
chirp_winner = Behavior.chirps[Behavior.chirps_ids == winner_fish_id]
chirp_loser = Behavior.chirps[Behavior.chirps_ids == loser_fish_id]
@@ -101,7 +116,6 @@ def get_chirp_winner_loser(folder_name, Behavior, order_meta_df):
def main(dataroot):
foldernames, _ = np.asarray(get_valid_datasets(dataroot))
plot_all = True
time_before = 90
@@ -111,10 +125,9 @@ def main(dataroot):
kde_time = np.arange(-time_before, time_after, dt)
nbootstraps = 50
meta_path = (
'/').join(foldernames[0].split('/')[:-2]) + '/order_meta.csv'
meta_path = ("/").join(foldernames[0].split("/")[:-2]) + "/order_meta.csv"
meta = pd.read_csv(meta_path)
meta['recording'] = meta['recording'].str[1:-1]
meta["recording"] = meta["recording"].str[1:-1]
winner_onsets = []
winner_offsets = []
@@ -143,24 +156,24 @@ def main(dataroot):
# loser_onset_chirpcount = 0
# loser_offset_chirpcount = 0
# loser_physical_chirpcount = 0
fig, ax = plt.subplots(1, 2, figsize=(
14 * ps.cm, 7*ps.cm), sharey=True, sharex=True)
fig, ax = plt.subplots(
1, 2, figsize=(14 * ps.cm, 7 * ps.cm), sharey=True, sharex=True
)
# Iterate over all recordings and save chirp- and event-timestamps
good_recs = np.asarray([0, 15])
for i, folder in tqdm(enumerate(foldernames[good_recs])):
foldername = folder.split('/')[-2]
foldername = folder.split("/")[-2]
# logger.info('Loading data from folder: {}'.format(foldername))
broken_folders = ['../data/mount_data/2020-05-12-10_00/']
broken_folders = ["../data/mount_data/2020-05-12-10_00/"]
if folder in broken_folders:
continue
bh = Behavior(folder)
category, timestamps = correct_chasing_events(bh.behavior, bh.start_s)
category = category[timestamps < 3*60*60] # only night time
timestamps = timestamps[timestamps < 3*60*60] # only night time
category = category[timestamps < 3 * 60 * 60] # only night time
timestamps = timestamps[timestamps < 3 * 60 * 60] # only night time
winner, loser = get_chirp_winner_loser(folder, bh, meta)
if winner is None:
@@ -168,27 +181,33 @@ def main(dataroot):
# winner_count += len(winner)
# loser_count += len(loser)
onsets = (timestamps[category == 0])
offsets = (timestamps[category == 1])
physicals = (timestamps[category == 2])
onsets = timestamps[category == 0]
offsets = timestamps[category == 1]
physicals = timestamps[category == 2]
onset_count += len(onsets)
offset_count += len(offsets)
physical_count += len(physicals)
winner_onsets.append(center_chirps(
winner, onsets, time_before, time_after))
winner_offsets.append(center_chirps(
winner, offsets, time_before, time_after))
winner_physicals.append(center_chirps(
winner, physicals, time_before, time_after))
winner_onsets.append(
center_chirps(winner, onsets, time_before, time_after)
)
winner_offsets.append(
center_chirps(winner, offsets, time_before, time_after)
)
winner_physicals.append(
center_chirps(winner, physicals, time_before, time_after)
)
loser_onsets.append(center_chirps(
loser, onsets, time_before, time_after))
loser_offsets.append(center_chirps(
loser, offsets, time_before, time_after))
loser_physicals.append(center_chirps(
loser, physicals, time_before, time_after))
loser_onsets.append(
center_chirps(loser, onsets, time_before, time_after)
)
loser_offsets.append(
center_chirps(loser, offsets, time_before, time_after)
)
loser_physicals.append(
center_chirps(loser, physicals, time_before, time_after)
)
# winner_onset_chirpcount += len(winner_onsets[-1])
# winner_offset_chirpcount += len(winner_offsets[-1])
@@ -232,14 +251,17 @@ def main(dataroot):
# event_times=onsets,
# time_before=time_before,
# time_after=time_after))
loser_offsets_boot.append(bootstrap(
loser,
nresamples=nbootstraps,
kde_time=kde_time,
kernel_width=kernel_width,
event_times=offsets,
time_before=time_before,
time_after=time_after))
loser_offsets_boot.append(
bootstrap(
loser,
nresamples=nbootstraps,
kde_time=kde_time,
kernel_width=kernel_width,
event_times=offsets,
time_before=time_before,
time_after=time_after,
)
)
# loser_physicals_boot.append(bootstrap(
# loser,
# nresamples=nbootstraps,
@@ -249,18 +271,17 @@ def main(dataroot):
# time_before=time_before,
# time_after=time_after))
# loser_offsets_jackknife = jackknife(
# loser,
# nresamples=nbootstraps,
# subsetsize=0.9,
# kde_time=kde_time,
# kernel_width=kernel_width,
# event_times=offsets,
# time_before=time_before,
# time_after=time_after)
# loser_offsets_jackknife = jackknife(
# loser,
# nresamples=nbootstraps,
# subsetsize=0.9,
# kde_time=kde_time,
# kernel_width=kernel_width,
# event_times=offsets,
# time_before=time_before,
# time_after=time_after)
if plot_all:
# winner_onsets_conv = acausal_kde1d(
# winner_onsets[-1], kde_time, kernel_width)
# winner_offsets_conv = acausal_kde1d(
@@ -271,24 +292,35 @@ def main(dataroot):
# loser_onsets_conv = acausal_kde1d(
# loser_onsets[-1], kde_time, kernel_width)
loser_offsets_conv = acausal_kde1d(
loser_offsets[-1], kde_time, kernel_width)
loser_offsets[-1], kde_time, kernel_width
)
# loser_physicals_conv = acausal_kde1d(
# loser_physicals[-1], kde_time, kernel_width)
ax[i].plot(kde_time, loser_offsets_conv /
len(offsets), lw=2, zorder=100, c=ps.gblue1)
ax[i].plot(
kde_time,
loser_offsets_conv / len(offsets),
lw=2,
zorder=100,
c=ps.gblue1,
)
ax[i].fill_between(
kde_time,
np.percentile(loser_offsets_boot[-1], 1, axis=0),
np.percentile(loser_offsets_boot[-1], 99, axis=0),
color='gray',
alpha=0.8)
color="gray",
alpha=0.8,
)
ax[i].plot(kde_time, np.median(loser_offsets_boot[-1], axis=0),
color=ps.black, linewidth=2)
ax[i].plot(
kde_time,
np.median(loser_offsets_boot[-1], axis=0),
color=ps.black,
linewidth=2,
)
ax[i].axvline(0, color=ps.gray, linestyle='--')
ax[i].axvline(0, color=ps.gray, linestyle="--")
# ax[i].fill_between(
# kde_time,
@@ -300,8 +332,8 @@ def main(dataroot):
# color=ps.white, linewidth=2)
ax[i].set_xlim(-60, 60)
fig.supylabel('Chirp rate (a.u.)', fontsize=14)
fig.supxlabel('Time (s)', fontsize=14)
fig.supylabel("Chirp rate (a.u.)", fontsize=14)
fig.supxlabel("Time (s)", fontsize=14)
# fig, ax = plt.subplots(2, 3, figsize=(
# 21*ps.cm, 10*ps.cm), sharey=True, sharex=True)
@@ -521,9 +553,9 @@ def main(dataroot):
# color=ps.gray,
# alpha=0.5)
plt.subplots_adjust(bottom=0.21, top=0.93)
plt.savefig('../poster/figs/kde.pdf')
plt.savefig("../poster/figs/kde.pdf")
plt.show()
if __name__ == '__main__':
main('../data/mount_data/')
if __name__ == "__main__":
main("../data/mount_data/")