This commit is contained in:
weygoldt
2023-04-11 15:33:07 +02:00
parent 12b1fdccae
commit 282c846b05
26 changed files with 1165 additions and 798 deletions

View File

@@ -30,12 +30,12 @@ class Behavior:
"""
def __init__(self, folder_path: str) -> None:
LED_on_time_BORIS = np.load(os.path.join(
folder_path, 'LED_on_time.npy'), allow_pickle=True)
LED_on_time_BORIS = np.load(
os.path.join(folder_path, "LED_on_time.npy"), allow_pickle=True
)
csv_filename = os.path.split(folder_path[:-1])[-1]
csv_filename = '-'.join(csv_filename.split('-')[:-1]) + '.csv'
csv_filename = "-".join(csv_filename.split("-")[:-1]) + ".csv"
# embed()
# csv_filename = [f for f in os.listdir(
@@ -43,31 +43,39 @@ class Behavior:
# logger.info(f'CSV file: {csv_filename}')
self.dataframe = read_csv(os.path.join(folder_path, csv_filename))
self.chirps = np.load(os.path.join(
folder_path, 'chirps.npy'), allow_pickle=True)
self.chirps_ids = np.load(os.path.join(
folder_path, 'chirp_ids.npy'), allow_pickle=True)
self.chirps = np.load(
os.path.join(folder_path, "chirps.npy"), allow_pickle=True
)
self.chirps_ids = np.load(
os.path.join(folder_path, "chirp_ids.npy"), allow_pickle=True
)
self.ident = np.load(os.path.join(
folder_path, 'ident_v.npy'), allow_pickle=True)
self.idx = np.load(os.path.join(
folder_path, 'idx_v.npy'), allow_pickle=True)
self.freq = np.load(os.path.join(
folder_path, 'fund_v.npy'), allow_pickle=True)
self.time = np.load(os.path.join(
folder_path, "times.npy"), allow_pickle=True)
self.spec = np.load(os.path.join(
folder_path, "spec.npy"), allow_pickle=True)
self.ident = np.load(
os.path.join(folder_path, "ident_v.npy"), allow_pickle=True
)
self.idx = np.load(
os.path.join(folder_path, "idx_v.npy"), allow_pickle=True
)
self.freq = np.load(
os.path.join(folder_path, "fund_v.npy"), allow_pickle=True
)
self.time = np.load(
os.path.join(folder_path, "times.npy"), allow_pickle=True
)
self.spec = np.load(
os.path.join(folder_path, "spec.npy"), allow_pickle=True
)
for k, key in enumerate(self.dataframe.keys()):
key = key.lower()
if ' ' in key:
key = key.replace(' ', '_')
if '(' in key:
key = key.replace('(', '')
key = key.replace(')', '')
setattr(self, key, np.array(
self.dataframe[self.dataframe.keys()[k]]))
if " " in key:
key = key.replace(" ", "_")
if "(" in key:
key = key.replace("(", "")
key = key.replace(")", "")
setattr(
self, key, np.array(self.dataframe[self.dataframe.keys()[k]])
)
last_LED_t_BORIS = LED_on_time_BORIS[-1]
real_time_range = self.time[-1] - self.time[0]
@@ -78,22 +86,19 @@ class Behavior:
def correct_chasing_events(
category: np.ndarray,
timestamps: np.ndarray
category: np.ndarray, timestamps: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
onset_ids = np.arange(len(category))[category == 0]
offset_ids = np.arange(len(category))[category == 1]
onset_ids = np.arange(
len(category))[category == 0]
offset_ids = np.arange(
len(category))[category == 1]
wrong_bh = np.arange(len(category))[
category != 2][:-1][np.diff(category[category != 2]) == 0]
wrong_bh = np.arange(len(category))[category != 2][:-1][
np.diff(category[category != 2]) == 0
]
if category[category != 2][-1] == 0:
wrong_bh = np.append(
wrong_bh,
np.arange(len(category))[category != 2][-1])
wrong_bh, np.arange(len(category))[category != 2][-1]
)
if onset_ids[0] > offset_ids[0]:
offset_ids = np.delete(offset_ids, 0)
@@ -103,18 +108,16 @@ def correct_chasing_events(
category = np.delete(category, wrong_bh)
timestamps = np.delete(timestamps, wrong_bh)
new_onset_ids = np.arange(
len(category))[category == 0]
new_offset_ids = np.arange(
len(category))[category == 1]
new_onset_ids = np.arange(len(category))[category == 0]
new_offset_ids = np.arange(len(category))[category == 1]
# Check whether on- or offset is longer and calculate length difference
if len(new_onset_ids) > len(new_offset_ids):
embed()
logger.warning('Onsets are greater than offsets')
logger.warning("Onsets are greater than offsets")
elif len(new_onset_ids) < len(new_offset_ids):
logger.warning('Offsets are greater than onsets')
logger.warning("Offsets are greater than onsets")
elif len(new_onset_ids) == len(new_offset_ids):
# logger.info('Chasing events are equal')
pass
@@ -130,13 +133,11 @@ def center_chirps(
# dt: float,
# width: float,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
event_chirps = [] # chirps that are in specified window around event
event_chirps = [] # chirps that are in specified window around event
# timestamps of chirps around event centered on the event timepoint
centered_chirps = []
for event_timestamp in events:
start = event_timestamp - time_before_event
stop = event_timestamp + time_after_event
chirps_around_event = [c for c in chirps if (c >= start) & (c <= stop)]
@@ -152,7 +153,8 @@ def center_chirps(
if len(centered_chirps) != len(event_chirps):
raise ValueError(
'Non centered chirps and centered chirps are not equal')
"Non centered chirps and centered chirps are not equal"
)
# time = np.arange(-time_before_event, time_after_event, dt)

View File

@@ -23,7 +23,9 @@ def minmaxnorm(data):
return (data - np.min(data)) / (np.max(data) - np.min(data))
def instantaneous_frequency2(signal: np.ndarray, fs: float, interpolation: str = 'linear') -> np.ndarray:
def instantaneous_frequency2(
signal: np.ndarray, fs: float, interpolation: str = "linear"
) -> np.ndarray:
"""
Compute the instantaneous frequency of a periodic signal using zero crossings and resample the frequency using linear
or cubic interpolation to match the dimensions of the input array.
@@ -55,10 +57,10 @@ def instantaneous_frequency2(signal: np.ndarray, fs: float, interpolation: str =
orig_len = len(signal)
freq = resample(freq, orig_len)
if interpolation == 'linear':
if interpolation == "linear":
freq = np.interp(np.arange(0, orig_len), np.arange(0, orig_len), freq)
elif interpolation == 'cubic':
freq = resample(freq, orig_len, window='cubic')
elif interpolation == "cubic":
freq = resample(freq, orig_len, window="cubic")
return freq
@@ -67,7 +69,7 @@ def instantaneous_frequency(
signal: np.ndarray,
samplerate: int,
smoothing_window: int,
interpolation: str = 'linear',
interpolation: str = "linear",
) -> np.ndarray:
"""
Compute the instantaneous frequency of a signal that is approximately
@@ -120,11 +122,10 @@ def instantaneous_frequency(
orig_len = len(signal)
freq = resample(instantaneous_frequency, orig_len)
if interpolation == 'linear':
if interpolation == "linear":
freq = np.interp(np.arange(0, orig_len), np.arange(0, orig_len), freq)
elif interpolation == 'cubic':
freq = resample(freq, orig_len, window='cubic')
elif interpolation == "cubic":
freq = resample(freq, orig_len, window="cubic")
return freq
@@ -160,7 +161,6 @@ def purge_duplicates(
group = [timestamps[0]]
for i in range(1, len(timestamps)):
# check the difference between current timestamp and previous
# timestamp is less than the threshold
if timestamps[i] - timestamps[i - 1] < threshold:
@@ -379,7 +379,6 @@ def acausal_kde1d(spikes, time, width):
if __name__ == "__main__":
timestamps = [
[1.2, 1.5, 1.3],
[],

View File

@@ -35,7 +35,6 @@ class LoadData:
"""
def __init__(self, datapath: str) -> None:
# load raw data
self.datapath = datapath
self.file = os.path.join(datapath, "traces-grid1.raw")

View File

@@ -3,10 +3,10 @@ import numpy as np
def bandpass_filter(
signal: np.ndarray,
samplerate: float,
lowf: float,
highf: float,
signal: np.ndarray,
samplerate: float,
lowf: float,
highf: float,
) -> np.ndarray:
"""Bandpass filter a signal.
@@ -60,9 +60,7 @@ def highpass_filter(
def lowpass_filter(
signal: np.ndarray,
samplerate: float,
cutoff: float
signal: np.ndarray, samplerate: float, cutoff: float
) -> np.ndarray:
"""Lowpass filter a signal.
@@ -86,10 +84,9 @@ def lowpass_filter(
return filtered_signal
def envelope(signal: np.ndarray,
samplerate: float,
cutoff_frequency: float
) -> np.ndarray:
def envelope(
signal: np.ndarray, samplerate: float, cutoff_frequency: float
) -> np.ndarray:
"""Calculate the envelope of a signal using a lowpass filter.
Parameters

View File

@@ -2,12 +2,13 @@ import logging
def makeLogger(name: str):
# create logger formats for file and terminal
file_formatter = logging.Formatter(
"[ %(levelname)s ] ~ %(asctime)s ~ %(module)s.%(funcName)s: %(message)s")
"[ %(levelname)s ] ~ %(asctime)s ~ %(module)s.%(funcName)s: %(message)s"
)
console_formatter = logging.Formatter(
"[ %(levelname)s ] in %(module)s.%(funcName)s: %(message)s")
"[ %(levelname)s ] in %(module)s.%(funcName)s: %(message)s"
)
# create logging file if loglevel is debug
file_handler = logging.FileHandler(f"gridtools_log.log", mode="w")
@@ -29,7 +30,6 @@ def makeLogger(name: str):
if __name__ == "__main__":
# initiate logger
mylogger = makeLogger(__name__)

View File

@@ -7,7 +7,6 @@ from matplotlib.colors import ListedColormap
def PlotStyle() -> None:
class style:
# lightcmap = cmocean.tools.lighten(cmocean.cm.haline, 0.8)
# units
@@ -76,13 +75,15 @@ def PlotStyle() -> None:
va="center",
zorder=1000,
bbox=dict(
boxstyle=f"circle, pad={padding}", fc="white", ec="black", lw=1
boxstyle=f"circle, pad={padding}",
fc="white",
ec="black",
lw=1,
),
)
@classmethod
def fade_cmap(cls, cmap):
my_cmap = cmap(np.arange(cmap.N))
my_cmap[:, -1] = np.linspace(0, 1, cmap.N)
my_cmap = ListedColormap(my_cmap)
@@ -295,7 +296,6 @@ def PlotStyle() -> None:
if __name__ == "__main__":
s = PlotStyle()
import matplotlib.cbook as cbook
@@ -347,7 +347,8 @@ if __name__ == "__main__":
for ax in axs:
ax.yaxis.grid(True)
ax.set_xticks(
[y + 1 for y in range(len(all_data))], labels=["x1", "x2", "x3", "x4"]
[y + 1 for y in range(len(all_data))],
labels=["x1", "x2", "x3", "x4"],
)
ax.set_xlabel("Four separate samples")
ax.set_ylabel("Observed values")
@@ -396,7 +397,10 @@ if __name__ == "__main__":
grid = np.random.rand(4, 4)
fig, axs = plt.subplots(
nrows=3, ncols=6, figsize=(9, 6), subplot_kw={"xticks": [], "yticks": []}
nrows=3,
ncols=6,
figsize=(9, 6),
subplot_kw={"xticks": [], "yticks": []},
)
for ax, interp_method in zip(axs.flat, methods):

View File

@@ -7,7 +7,6 @@ from matplotlib.colors import ListedColormap
def PlotStyle() -> None:
class style:
# lightcmap = cmocean.tools.lighten(cmocean.cm.haline, 0.8)
# units
@@ -76,13 +75,15 @@ def PlotStyle() -> None:
va="center",
zorder=1000,
bbox=dict(
boxstyle=f"circle, pad={padding}", fc="white", ec="black", lw=1
boxstyle=f"circle, pad={padding}",
fc="white",
ec="black",
lw=1,
),
)
@classmethod
def fade_cmap(cls, cmap):
my_cmap = cmap(np.arange(cmap.N))
my_cmap[:, -1] = np.linspace(0, 1, cmap.N)
my_cmap = ListedColormap(my_cmap)
@@ -295,7 +296,6 @@ def PlotStyle() -> None:
if __name__ == "__main__":
s = PlotStyle()
import matplotlib.cbook as cbook
@@ -347,7 +347,8 @@ if __name__ == "__main__":
for ax in axs:
ax.yaxis.grid(True)
ax.set_xticks(
[y + 1 for y in range(len(all_data))], labels=["x1", "x2", "x3", "x4"]
[y + 1 for y in range(len(all_data))],
labels=["x1", "x2", "x3", "x4"],
)
ax.set_xlabel("Four separate samples")
ax.set_ylabel("Observed values")
@@ -396,7 +397,10 @@ if __name__ == "__main__":
grid = np.random.rand(4, 4)
fig, axs = plt.subplots(
nrows=3, ncols=6, figsize=(9, 6), subplot_kw={"xticks": [], "yticks": []}
nrows=3,
ncols=6,
figsize=(9, 6),
subplot_kw={"xticks": [], "yticks": []},
)
for ax, interp_method in zip(axs.flat, methods):

View File

@@ -7,7 +7,6 @@ from matplotlib.colors import ListedColormap
def PlotStyle() -> None:
class style:
# lightcmap = cmocean.tools.lighten(cmocean.cm.haline, 0.8)
# units
@@ -76,13 +75,15 @@ def PlotStyle() -> None:
va="center",
zorder=1000,
bbox=dict(
boxstyle=f"circle, pad={padding}", fc="white", ec="black", lw=1
boxstyle=f"circle, pad={padding}",
fc="white",
ec="black",
lw=1,
),
)
@classmethod
def fade_cmap(cls, cmap):
my_cmap = cmap(np.arange(cmap.N))
my_cmap[:, -1] = np.linspace(0, 1, cmap.N)
my_cmap = ListedColormap(my_cmap)
@@ -295,7 +296,6 @@ def PlotStyle() -> None:
if __name__ == "__main__":
s = PlotStyle()
import matplotlib.cbook as cbook
@@ -347,7 +347,8 @@ if __name__ == "__main__":
for ax in axs:
ax.yaxis.grid(True)
ax.set_xticks(
[y + 1 for y in range(len(all_data))], labels=["x1", "x2", "x3", "x4"]
[y + 1 for y in range(len(all_data))],
labels=["x1", "x2", "x3", "x4"],
)
ax.set_xlabel("Four separate samples")
ax.set_ylabel("Observed values")
@@ -396,7 +397,10 @@ if __name__ == "__main__":
grid = np.random.rand(4, 4)
fig, axs = plt.subplots(
nrows=3, ncols=6, figsize=(9, 6), subplot_kw={"xticks": [], "yticks": []}
nrows=3,
ncols=6,
figsize=(9, 6),
subplot_kw={"xticks": [], "yticks": []},
)
for ax, interp_method in zip(axs.flat, methods):

View File

@@ -37,7 +37,7 @@ def create_chirp(
ck = 0
csig = 0.5 * chirpduration / np.power(2.0 * np.log(10.0), 0.5 / kurtosis)
#csig = csig*-1
# csig = csig*-1
for k, t in enumerate(time):
a = 1.0
f = eodf