intermediate state

This commit is contained in:
Jan Grewe 2025-02-12 17:14:21 +01:00
parent c7e482ffd1
commit bf7c37eb46
3 changed files with 137 additions and 33 deletions

View File

@ -157,6 +157,32 @@ class TrackingData(QObject):
center_of_gravity = np.sum(weighted_coords, axis=1) / sum_scores
return center_of_gravity
def animalLength(self, bodyaxis=None):
if bodyaxis is None:
bodyaxis = [0, 1, 2, 5]
bodycoords = self.coordinates()[:, bodyaxis, :]
lengths = np.sum(np.sqrt(np.sum(np.diff(bodycoords, axis=1)**2, axis=2)), axis=1)
return lengths
def orientation(self, head_node=1, tail_node=5):
bodycoords = self.coordinates()[:, [head_node, tail_node], :]
vectors = bodycoords[:, 1, :] - bodycoords[:, 0, :]
orientations = np.arctan2(vectors[:, 1], vectors[:, 0])
return orientations
def bendedness(self, bodyaxis=None):
if bodyaxis is None:
bodyaxis = [0, 1, 2, 5]
bodycoords = self.coordinates()[:, bodyaxis, :]
head_tail_vector = bodycoords[:, -1, :] - bodycoords[:, 0, :]
head_tail_length = np.linalg.norm(head_tail_vector, axis=1, keepdims=True)
normalized_head_tail_vector = head_tail_vector / head_tail_length
projections = np.einsum('ijk,ij->ik', bodycoords - bodycoords[:, 0, np.newaxis, :], normalized_head_tail_vector)
distances = np.linalg.norm(bodycoords - (bodycoords[:, 0, np.newaxis, :] + projections[:, :, np.newaxis] * normalized_head_tail_vector[:, np.newaxis, :]), axis=2)
deviation = np.mean(distances, axis=1)
return deviation
def __getitem__(self, key):
return self._data[key]
@ -174,7 +200,6 @@ def main():
from IPython import embed
import matplotlib.pyplot as plt
from fixtracks.info import PACKAGE_ROOT
from scipy.spatial.distance import cdist
def as_dict(df:pd.DataFrame):
d = {c: df[c].values for c in df.columns}
@ -208,6 +233,13 @@ def main():
data = TrackingData()
data.setData(as_dict(df))
all_cogs = data.centerOfGravity()
orientations = data.orientation()
lengths = data.animalLength()
frames = data["frame"]
tracks = data["track"]
bendedness = data.bendedness()
embed()
tracks = data["track"]
cogs = all_cogs[tracks==1]
all_dists = neighborDistances(cogs, 2, False)

View File

@ -1,10 +1,10 @@
import logging
import numpy as np
import pyqtgraph as pg
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget,QPushButton, QGraphicsView
from PySide6.QtCore import Signal
from PySide6.QtGui import QBrush, QColor
import pyqtgraph as pg # needs to be imported after pyside to not import pyqt
from fixtracks.utils.trackingdata import TrackingData
@ -38,10 +38,10 @@ class SizeClassifier(QWidget):
self._t1_selection = pg.LinearRegionItem([100, 200])
self._t1_selection.setZValue(-10)
self._t1_selection.setBrush(track1_brush)
self._t1_selection.setBrush("orange")
self._t2_selection = pg.LinearRegionItem([300,400])
self._t2_selection.setZValue(-10)
self._t2_selection.setBrush(track2_brush)
self._t2_selection.setBrush("green")
return plot_widget
def estimate_length(self, coords, bodyaxis =None):
@ -94,6 +94,7 @@ class NeighborhoodValidator(QWidget):
self._positions = None
self._distances = None
self._tracks = None
self._frames = None
self._plot = None
self._plot_widget = self.setupGraph()
@ -116,9 +117,11 @@ class NeighborhoodValidator(QWidget):
def estimate_histogram(self, dists, min_threshold=1., max_threshold=99., bin_count=100, log=False):
min_dist = np.percentile(dists, min_threshold)
max_dist = np.percentile(dists, max_threshold)
print(min_dist, max_dist)
if log:
bins = np.logspace(min_dist, max_dist, bin_count, base=10)
bins = np.linspace(min_dist, max_dist, bin_count)
else:
bins = np.linspace(min_dist, max_dist, bin_count)
hist, edges = np.histogram(dists, bins=bins, density=True)
return hist, edges
@ -157,9 +160,12 @@ class NeighborhoodValidator(QWidget):
frames : np.ndarray
respective frame.
"""
def mouseClicked(self, event):
print("mouse clicked at", event.pos())
def mouseClicked(event):
pos = event.pos()
if self._plot.sceneBoundingRect().contains(pos):
mousePoint = vb.mapSceneToView(pos)
print("mouse clicked at", mousePoint)
vLine.setPos(mousePoint.x())
track2_brush = QBrush(QColor.fromString("green"))
track1_brush = QBrush(QColor.fromString("orange"))
self._positions = positions
@ -172,11 +178,12 @@ class NeighborhoodValidator(QWidget):
t2_frames = self._frames[self._tracks == 2]
t2_distances = self.neighborDistances(t2_positions, t2_frames, 1, False)
n, e = self.estimate_histogram(t1_distances[1:], bin_count=100, log=False)
self._plot = self._plot_widget.addPlot()
vb = self._plot.vb
n, e = self.estimate_histogram(t1_distances[1:], 1, 95, bin_count=100, log=False)
bgi1 = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=track1_brush)
self._plot.addItem(bgi1)
n, e = self.estimate_histogram(t2_distances[1:], bin_count=100, log=False)
n, e = self.estimate_histogram(t2_distances[1:], 1, 95, bin_count=100, log=False)
bgi2 = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=track2_brush)
self._plot.addItem(bgi2)
self._plot.scene().sigMouseClicked.connect(mouseClicked)
@ -187,9 +194,47 @@ class NeighborhoodValidator(QWidget):
# plot.addItem(self._threshold)
vLine = pg.InfiniteLine(pos=10, angle=90, movable=False)
self._plot.addItem(vLine, ignoreBounds=True)
vb = self._plot.vb
class ConsistencyClassifier(QWidget):
apply = Signal()
name = "Consistency classifier"
def __init__(self, parent=None):
super().__init__(parent)
def setData(self, keypoints, tracks, frames):
"""Set the data, the classifier/should be working on.
Parameters
----------
positions : np.ndarray
The position estimates, e.g. the center of gravity for each detection
tracks : np.ndarray
The current track assignment.
frames : np.ndarray
respective frame.
"""
def mouseClicked(event):
pos = event.pos()
if self._plot.sceneBoundingRect().contains(pos):
mousePoint = vb.mapSceneToView(pos)
print("mouse clicked at", mousePoint)
vLine.setPos(mousePoint.x())
track2_brush = QBrush(QColor.fromString("green"))
track1_brush = QBrush(QColor.fromString("orange"))
self._positions = positions
self._tracks = tracks
self._frames = frames
t1_positions = self._positions[self._tracks == 1]
t1_frames = self._frames[self._tracks == 1]
t1_distances = self.neighborDistances(t1_positions, t1_frames, 1, False)
t2_positions = self._positions[self._tracks == 2]
t2_frames = self._frames[self._tracks == 2]
t2_distances = self.neighborDistances(t2_positions, t2_frames, 1, False)
class ClassifierWidget(QTabWidget):
apply_sizeclassifier = Signal(np.ndarray)
@ -226,7 +271,7 @@ def main():
import pickle
from fixtracks.info import PACKAGE_ROOT
datafile = PACKAGE_ROOT / "data/merged_small.pkl"
datafile = PACKAGE_ROOT / "data/merged_small_tracked.pkl"
with open(datafile, "rb") as f:
df = pickle.load(f)

View File

@ -32,15 +32,29 @@ class Window(QGraphicsRectItem):
self.signals.windowMoved.emit()
def setWindowWidth(self, newwidth):
logging.debug("timeline.window: update window width to %.3f", newwidth)
logging.debug("timeline.window: update window width to %f", newwidth)
self._width = newwidth
r = self.rect()
r.setWidth(newwidth)
self.setRect(r)
self.signals.windowMoved.emit()
def setWindow(self, newx, newwidth):
logging.debug("timeline.window: update window to range %.3f to %.3f", newx, newwidth)
def setWindow(self, newx:float, newwidth:float):
def setWindow(self, newx: float, newwidth: float):
"""
Update the window to the specified range.
Parameters
----------
newx : float
The new x-coordinate of the window.
newwidth : float
The new width of the window.
Returns
-------
None
"""
logging.debug("timeline.window: update window to range %.5f to %.5f", newx, newwidth)
self._width = newwidth
r = self.rect()
self.setRect(newx, r.y(), self._width, r.height())
@ -62,7 +76,6 @@ class Window(QGraphicsRectItem):
self.setX(self.scene().width() - self._width)
if r.y() != self._y:
self.setY(self._y)
print(self.sceneBoundingRect())
super().mouseReleaseEvent(event)
self.signals.windowMoved.emit()
@ -80,7 +93,7 @@ class DetectionTimeline(QWidget):
self._data = detectiondata
self._rangeStart = 0.0
self._rangeStop = 0.005
self.total_width = 2000
self._total_width = 2000
self._stepCount = 200
self._bg_brush = QBrush(QColor(20, 20, 20, 255))
transparent_brush = QBrush(QColor(200, 200, 200, 64))
@ -101,7 +114,7 @@ class DetectionTimeline(QWidget):
self._window = Window(0, 0, 100, 60, axis_pen, transparent_brush)
self._window.signals.windowMoved.connect(self.on_windowMoved)
self._scene = QGraphicsScene(QRectF(0, 0, self.total_width, 55.))
self._scene = QGraphicsScene(QRectF(0, 0, self._total_width, 55.))
self._scene.setBackgroundBrush(self._bg_brush)
self._scene.addItem(self._window)
@ -147,6 +160,7 @@ class DetectionTimeline(QWidget):
self.draw_coverage()
def draw_coverage(self):
# FIXME this must be disentangled. timeline should not have to deal with two different ways of data storage
if isinstance(self._data, pd.DataFrame):
maxframe = np.max(self._data.frame.values)
@ -178,9 +192,9 @@ class DetectionTimeline(QWidget):
if t2_coverage[i]: self._scene.addLine(pos[i], 17, pos[i], 32., pen=self._t2_pen)
if other_coverage[i]: self._scene.addLine(pos[i], 34, pos[i], 49., pen=self._other_pen)
def updatePosition(self):
start = np.round(self._rangeStart * 100, 1)
stop = np.round(self._rangeStop * 100, 1)
def updatePositionLabel(self):
start = np.round(self._rangeStart * 100, 4)
stop = np.round(self._rangeStop * 100, 4)
self._position_label.setText(f"Current position: {start}% to {stop}% of data.")
@property
@ -202,10 +216,10 @@ class DetectionTimeline(QWidget):
def on_windowMoved(self):
scene_width = self._scene.width()
self._rangeStart = np.round(self._window.sceneBoundingRect().left() / scene_width, 3)
self._rangeStop = np.round(self._window.sceneBoundingRect().right() / scene_width, 3)
self._rangeStart = self._window.sceneBoundingRect().left() / scene_width
self._rangeStop = self._window.sceneBoundingRect().right() / scene_width
logging.debug("Timeline: WindowUpdated positions start: %.3f end: %.3f", self.rangeStart, self.rangeStop)
self.updatePosition()
self.updatePositionLabel()
self.signals.windowMoved.emit()
def setWindowPos(self, newx: float):
@ -220,8 +234,8 @@ class DetectionTimeline(QWidget):
newx = 0.0
elif newx > 1.0:
newx = 1.0
logging.debug("Set window x tp new position %.3f", newx)
x_rel = np.round(newx * self.total_width)
logging.debug("Timeline:setWindow to new position %.4f", newx)
x_rel = np.round(newx * self._total_width)
self._window.setWindowX(x_rel)
def setWindowWidth(self, width: float):
@ -232,17 +246,30 @@ class DetectionTimeline(QWidget):
width : float
The width in a range 0.0 to 1.0 (aka 0% to 100% of the span.)
"""
logging.debug("Set window width to new value %.3f of %i total width", width, self.total_width)
span = np.round(width * self.total_width)
self._window.setWindowWidth(span)
logging.debug("Set window width to new value %.5f of %i total width", width, self._total_width)
span = np.round(width * self._total_width)
self._window.setWindowWidth(np.round(span))
def setWindow(self, xpos, width):
span = np.round(width * self.total_width)
def setWindow(self, xpos:float, width:float):
"""
Set the window position and width.
Parameters
----------
xpos : float
The x position of the window as a fraction of the total data.
Must be between 0.0 and 1.0. Values outside this range will be clamped.
width : float
The width of the window as a fraction of the total data.
Returns
-------
None
"""
if xpos < 0.0:
xpos = 0.0
elif xpos > 1.0:
xpos = 1.0
xstart = np.round(xpos * self.total_width)
xstart = xpos * self._total_width
span = width * self._total_width
self._window.setWindow(xstart, span)
def windowBounds(self):