intermediate state

This commit is contained in:
Jan Grewe 2025-02-12 17:14:21 +01:00
parent c7e482ffd1
commit bf7c37eb46
3 changed files with 137 additions and 33 deletions

View File

@ -157,6 +157,32 @@ class TrackingData(QObject):
center_of_gravity = np.sum(weighted_coords, axis=1) / sum_scores center_of_gravity = np.sum(weighted_coords, axis=1) / sum_scores
return center_of_gravity return center_of_gravity
def animalLength(self, bodyaxis=None):
if bodyaxis is None:
bodyaxis = [0, 1, 2, 5]
bodycoords = self.coordinates()[:, bodyaxis, :]
lengths = np.sum(np.sqrt(np.sum(np.diff(bodycoords, axis=1)**2, axis=2)), axis=1)
return lengths
def orientation(self, head_node=1, tail_node=5):
bodycoords = self.coordinates()[:, [head_node, tail_node], :]
vectors = bodycoords[:, 1, :] - bodycoords[:, 0, :]
orientations = np.arctan2(vectors[:, 1], vectors[:, 0])
return orientations
def bendedness(self, bodyaxis=None):
if bodyaxis is None:
bodyaxis = [0, 1, 2, 5]
bodycoords = self.coordinates()[:, bodyaxis, :]
head_tail_vector = bodycoords[:, -1, :] - bodycoords[:, 0, :]
head_tail_length = np.linalg.norm(head_tail_vector, axis=1, keepdims=True)
normalized_head_tail_vector = head_tail_vector / head_tail_length
projections = np.einsum('ijk,ij->ik', bodycoords - bodycoords[:, 0, np.newaxis, :], normalized_head_tail_vector)
distances = np.linalg.norm(bodycoords - (bodycoords[:, 0, np.newaxis, :] + projections[:, :, np.newaxis] * normalized_head_tail_vector[:, np.newaxis, :]), axis=2)
deviation = np.mean(distances, axis=1)
return deviation
def __getitem__(self, key): def __getitem__(self, key):
return self._data[key] return self._data[key]
@ -174,7 +200,6 @@ def main():
from IPython import embed from IPython import embed
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from fixtracks.info import PACKAGE_ROOT from fixtracks.info import PACKAGE_ROOT
from scipy.spatial.distance import cdist
def as_dict(df:pd.DataFrame): def as_dict(df:pd.DataFrame):
d = {c: df[c].values for c in df.columns} d = {c: df[c].values for c in df.columns}
@ -208,6 +233,13 @@ def main():
data = TrackingData() data = TrackingData()
data.setData(as_dict(df)) data.setData(as_dict(df))
all_cogs = data.centerOfGravity() all_cogs = data.centerOfGravity()
orientations = data.orientation()
lengths = data.animalLength()
frames = data["frame"]
tracks = data["track"]
bendedness = data.bendedness()
embed()
tracks = data["track"] tracks = data["track"]
cogs = all_cogs[tracks==1] cogs = all_cogs[tracks==1]
all_dists = neighborDistances(cogs, 2, False) all_dists = neighborDistances(cogs, 2, False)

View File

@ -1,10 +1,10 @@
import logging import logging
import numpy as np import numpy as np
import pyqtgraph as pg
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget,QPushButton, QGraphicsView from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget,QPushButton, QGraphicsView
from PySide6.QtCore import Signal from PySide6.QtCore import Signal
from PySide6.QtGui import QBrush, QColor from PySide6.QtGui import QBrush, QColor
import pyqtgraph as pg # needs to be imported after pyside to not import pyqt
from fixtracks.utils.trackingdata import TrackingData from fixtracks.utils.trackingdata import TrackingData
@ -38,10 +38,10 @@ class SizeClassifier(QWidget):
self._t1_selection = pg.LinearRegionItem([100, 200]) self._t1_selection = pg.LinearRegionItem([100, 200])
self._t1_selection.setZValue(-10) self._t1_selection.setZValue(-10)
self._t1_selection.setBrush(track1_brush) self._t1_selection.setBrush("orange")
self._t2_selection = pg.LinearRegionItem([300,400]) self._t2_selection = pg.LinearRegionItem([300,400])
self._t2_selection.setZValue(-10) self._t2_selection.setZValue(-10)
self._t2_selection.setBrush(track2_brush) self._t2_selection.setBrush("green")
return plot_widget return plot_widget
def estimate_length(self, coords, bodyaxis =None): def estimate_length(self, coords, bodyaxis =None):
@ -94,6 +94,7 @@ class NeighborhoodValidator(QWidget):
self._positions = None self._positions = None
self._distances = None self._distances = None
self._tracks = None self._tracks = None
self._frames = None
self._plot = None self._plot = None
self._plot_widget = self.setupGraph() self._plot_widget = self.setupGraph()
@ -116,8 +117,10 @@ class NeighborhoodValidator(QWidget):
def estimate_histogram(self, dists, min_threshold=1., max_threshold=99., bin_count=100, log=False): def estimate_histogram(self, dists, min_threshold=1., max_threshold=99., bin_count=100, log=False):
min_dist = np.percentile(dists, min_threshold) min_dist = np.percentile(dists, min_threshold)
max_dist = np.percentile(dists, max_threshold) max_dist = np.percentile(dists, max_threshold)
print(min_dist, max_dist)
if log: if log:
bins = np.logspace(min_dist, max_dist, bin_count, base=10) bins = np.logspace(min_dist, max_dist, bin_count, base=10)
else:
bins = np.linspace(min_dist, max_dist, bin_count) bins = np.linspace(min_dist, max_dist, bin_count)
hist, edges = np.histogram(dists, bins=bins, density=True) hist, edges = np.histogram(dists, bins=bins, density=True)
return hist, edges return hist, edges
@ -157,9 +160,12 @@ class NeighborhoodValidator(QWidget):
frames : np.ndarray frames : np.ndarray
respective frame. respective frame.
""" """
def mouseClicked(self, event): def mouseClicked(event):
print("mouse clicked at", event.pos()) pos = event.pos()
if self._plot.sceneBoundingRect().contains(pos):
mousePoint = vb.mapSceneToView(pos)
print("mouse clicked at", mousePoint)
vLine.setPos(mousePoint.x())
track2_brush = QBrush(QColor.fromString("green")) track2_brush = QBrush(QColor.fromString("green"))
track1_brush = QBrush(QColor.fromString("orange")) track1_brush = QBrush(QColor.fromString("orange"))
self._positions = positions self._positions = positions
@ -172,11 +178,12 @@ class NeighborhoodValidator(QWidget):
t2_frames = self._frames[self._tracks == 2] t2_frames = self._frames[self._tracks == 2]
t2_distances = self.neighborDistances(t2_positions, t2_frames, 1, False) t2_distances = self.neighborDistances(t2_positions, t2_frames, 1, False)
n, e = self.estimate_histogram(t1_distances[1:], bin_count=100, log=False)
self._plot = self._plot_widget.addPlot() self._plot = self._plot_widget.addPlot()
vb = self._plot.vb
n, e = self.estimate_histogram(t1_distances[1:], 1, 95, bin_count=100, log=False)
bgi1 = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=track1_brush) bgi1 = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=track1_brush)
self._plot.addItem(bgi1) self._plot.addItem(bgi1)
n, e = self.estimate_histogram(t2_distances[1:], bin_count=100, log=False) n, e = self.estimate_histogram(t2_distances[1:], 1, 95, bin_count=100, log=False)
bgi2 = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=track2_brush) bgi2 = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=track2_brush)
self._plot.addItem(bgi2) self._plot.addItem(bgi2)
self._plot.scene().sigMouseClicked.connect(mouseClicked) self._plot.scene().sigMouseClicked.connect(mouseClicked)
@ -187,7 +194,45 @@ class NeighborhoodValidator(QWidget):
# plot.addItem(self._threshold) # plot.addItem(self._threshold)
vLine = pg.InfiniteLine(pos=10, angle=90, movable=False) vLine = pg.InfiniteLine(pos=10, angle=90, movable=False)
self._plot.addItem(vLine, ignoreBounds=True) self._plot.addItem(vLine, ignoreBounds=True)
vb = self._plot.vb
class ConsistencyClassifier(QWidget):
apply = Signal()
name = "Consistency classifier"
def __init__(self, parent=None):
super().__init__(parent)
def setData(self, keypoints, tracks, frames):
"""Set the data, the classifier/should be working on.
Parameters
----------
positions : np.ndarray
The position estimates, e.g. the center of gravity for each detection
tracks : np.ndarray
The current track assignment.
frames : np.ndarray
respective frame.
"""
def mouseClicked(event):
pos = event.pos()
if self._plot.sceneBoundingRect().contains(pos):
mousePoint = vb.mapSceneToView(pos)
print("mouse clicked at", mousePoint)
vLine.setPos(mousePoint.x())
track2_brush = QBrush(QColor.fromString("green"))
track1_brush = QBrush(QColor.fromString("orange"))
self._positions = positions
self._tracks = tracks
self._frames = frames
t1_positions = self._positions[self._tracks == 1]
t1_frames = self._frames[self._tracks == 1]
t1_distances = self.neighborDistances(t1_positions, t1_frames, 1, False)
t2_positions = self._positions[self._tracks == 2]
t2_frames = self._frames[self._tracks == 2]
t2_distances = self.neighborDistances(t2_positions, t2_frames, 1, False)
@ -226,7 +271,7 @@ def main():
import pickle import pickle
from fixtracks.info import PACKAGE_ROOT from fixtracks.info import PACKAGE_ROOT
datafile = PACKAGE_ROOT / "data/merged_small.pkl" datafile = PACKAGE_ROOT / "data/merged_small_tracked.pkl"
with open(datafile, "rb") as f: with open(datafile, "rb") as f:
df = pickle.load(f) df = pickle.load(f)

View File

@ -32,15 +32,29 @@ class Window(QGraphicsRectItem):
self.signals.windowMoved.emit() self.signals.windowMoved.emit()
def setWindowWidth(self, newwidth): def setWindowWidth(self, newwidth):
logging.debug("timeline.window: update window width to %.3f", newwidth) logging.debug("timeline.window: update window width to %f", newwidth)
self._width = newwidth self._width = newwidth
r = self.rect() r = self.rect()
r.setWidth(newwidth) r.setWidth(newwidth)
self.setRect(r) self.setRect(r)
self.signals.windowMoved.emit() self.signals.windowMoved.emit()
def setWindow(self, newx, newwidth): def setWindow(self, newx:float, newwidth:float):
logging.debug("timeline.window: update window to range %.3f to %.3f", newx, newwidth) def setWindow(self, newx: float, newwidth: float):
"""
Update the window to the specified range.
Parameters
----------
newx : float
The new x-coordinate of the window.
newwidth : float
The new width of the window.
Returns
-------
None
"""
logging.debug("timeline.window: update window to range %.5f to %.5f", newx, newwidth)
self._width = newwidth self._width = newwidth
r = self.rect() r = self.rect()
self.setRect(newx, r.y(), self._width, r.height()) self.setRect(newx, r.y(), self._width, r.height())
@ -62,7 +76,6 @@ class Window(QGraphicsRectItem):
self.setX(self.scene().width() - self._width) self.setX(self.scene().width() - self._width)
if r.y() != self._y: if r.y() != self._y:
self.setY(self._y) self.setY(self._y)
print(self.sceneBoundingRect())
super().mouseReleaseEvent(event) super().mouseReleaseEvent(event)
self.signals.windowMoved.emit() self.signals.windowMoved.emit()
@ -80,7 +93,7 @@ class DetectionTimeline(QWidget):
self._data = detectiondata self._data = detectiondata
self._rangeStart = 0.0 self._rangeStart = 0.0
self._rangeStop = 0.005 self._rangeStop = 0.005
self.total_width = 2000 self._total_width = 2000
self._stepCount = 200 self._stepCount = 200
self._bg_brush = QBrush(QColor(20, 20, 20, 255)) self._bg_brush = QBrush(QColor(20, 20, 20, 255))
transparent_brush = QBrush(QColor(200, 200, 200, 64)) transparent_brush = QBrush(QColor(200, 200, 200, 64))
@ -101,7 +114,7 @@ class DetectionTimeline(QWidget):
self._window = Window(0, 0, 100, 60, axis_pen, transparent_brush) self._window = Window(0, 0, 100, 60, axis_pen, transparent_brush)
self._window.signals.windowMoved.connect(self.on_windowMoved) self._window.signals.windowMoved.connect(self.on_windowMoved)
self._scene = QGraphicsScene(QRectF(0, 0, self.total_width, 55.)) self._scene = QGraphicsScene(QRectF(0, 0, self._total_width, 55.))
self._scene.setBackgroundBrush(self._bg_brush) self._scene.setBackgroundBrush(self._bg_brush)
self._scene.addItem(self._window) self._scene.addItem(self._window)
@ -147,6 +160,7 @@ class DetectionTimeline(QWidget):
self.draw_coverage() self.draw_coverage()
def draw_coverage(self): def draw_coverage(self):
# FIXME this must be disentangled. timeline should not have to deal with two different ways of data storage
if isinstance(self._data, pd.DataFrame): if isinstance(self._data, pd.DataFrame):
maxframe = np.max(self._data.frame.values) maxframe = np.max(self._data.frame.values)
@ -178,9 +192,9 @@ class DetectionTimeline(QWidget):
if t2_coverage[i]: self._scene.addLine(pos[i], 17, pos[i], 32., pen=self._t2_pen) if t2_coverage[i]: self._scene.addLine(pos[i], 17, pos[i], 32., pen=self._t2_pen)
if other_coverage[i]: self._scene.addLine(pos[i], 34, pos[i], 49., pen=self._other_pen) if other_coverage[i]: self._scene.addLine(pos[i], 34, pos[i], 49., pen=self._other_pen)
def updatePosition(self): def updatePositionLabel(self):
start = np.round(self._rangeStart * 100, 1) start = np.round(self._rangeStart * 100, 4)
stop = np.round(self._rangeStop * 100, 1) stop = np.round(self._rangeStop * 100, 4)
self._position_label.setText(f"Current position: {start}% to {stop}% of data.") self._position_label.setText(f"Current position: {start}% to {stop}% of data.")
@property @property
@ -202,10 +216,10 @@ class DetectionTimeline(QWidget):
def on_windowMoved(self): def on_windowMoved(self):
scene_width = self._scene.width() scene_width = self._scene.width()
self._rangeStart = np.round(self._window.sceneBoundingRect().left() / scene_width, 3) self._rangeStart = self._window.sceneBoundingRect().left() / scene_width
self._rangeStop = np.round(self._window.sceneBoundingRect().right() / scene_width, 3) self._rangeStop = self._window.sceneBoundingRect().right() / scene_width
logging.debug("Timeline: WindowUpdated positions start: %.3f end: %.3f", self.rangeStart, self.rangeStop) logging.debug("Timeline: WindowUpdated positions start: %.3f end: %.3f", self.rangeStart, self.rangeStop)
self.updatePosition() self.updatePositionLabel()
self.signals.windowMoved.emit() self.signals.windowMoved.emit()
def setWindowPos(self, newx: float): def setWindowPos(self, newx: float):
@ -220,8 +234,8 @@ class DetectionTimeline(QWidget):
newx = 0.0 newx = 0.0
elif newx > 1.0: elif newx > 1.0:
newx = 1.0 newx = 1.0
logging.debug("Set window x tp new position %.3f", newx) logging.debug("Timeline:setWindow to new position %.4f", newx)
x_rel = np.round(newx * self.total_width) x_rel = np.round(newx * self._total_width)
self._window.setWindowX(x_rel) self._window.setWindowX(x_rel)
def setWindowWidth(self, width: float): def setWindowWidth(self, width: float):
@ -232,17 +246,30 @@ class DetectionTimeline(QWidget):
width : float width : float
The width in a range 0.0 to 1.0 (aka 0% to 100% of the span.) The width in a range 0.0 to 1.0 (aka 0% to 100% of the span.)
""" """
logging.debug("Set window width to new value %.3f of %i total width", width, self.total_width) logging.debug("Set window width to new value %.5f of %i total width", width, self._total_width)
span = np.round(width * self.total_width) span = np.round(width * self._total_width)
self._window.setWindowWidth(span) self._window.setWindowWidth(np.round(span))
def setWindow(self, xpos, width): def setWindow(self, xpos:float, width:float):
span = np.round(width * self.total_width) """
Set the window position and width.
Parameters
----------
xpos : float
The x position of the window as a fraction of the total data.
Must be between 0.0 and 1.0. Values outside this range will be clamped.
width : float
The width of the window as a fraction of the total data.
Returns
-------
None
"""
if xpos < 0.0: if xpos < 0.0:
xpos = 0.0 xpos = 0.0
elif xpos > 1.0: elif xpos > 1.0:
xpos = 1.0 xpos = 1.0
xstart = np.round(xpos * self.total_width) xstart = xpos * self._total_width
span = width * self._total_width
self._window.setWindow(xstart, span) self._window.setWindow(xstart, span)
def windowBounds(self): def windowBounds(self):