Compare commits
5 Commits
d8fe654ac8
...
15cee494f6
Author | SHA1 | Date | |
---|---|---|---|
15cee494f6 | |||
796e03ae7e | |||
801acc9547 | |||
66aa79e47a | |||
509405033a |
134
fixtracks/widgets/classifier.py
Normal file
134
fixtracks/widgets/classifier.py
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget,QPushButton
|
||||||
|
from PySide6.QtCore import Signal
|
||||||
|
from PySide6.QtGui import QBrush, QColor
|
||||||
|
|
||||||
|
import pyqtgraph as pg
|
||||||
|
|
||||||
|
|
||||||
|
class SizeClassifier(QWidget):
|
||||||
|
apply = Signal()
|
||||||
|
|
||||||
|
def __init__(self, parent=None):
|
||||||
|
super().__init__(parent)
|
||||||
|
self._t1_selection = None
|
||||||
|
self._t2_selection = None
|
||||||
|
self._coordinates = None
|
||||||
|
self._sizes = None
|
||||||
|
|
||||||
|
self._plot_widget = self.setupGraph()
|
||||||
|
self._apply_btn = QPushButton("apply")
|
||||||
|
self._apply_btn.clicked.connect(lambda: self.apply.emit())
|
||||||
|
|
||||||
|
layout = QVBoxLayout()
|
||||||
|
layout.addWidget(self._plot_widget)
|
||||||
|
layout.addWidget(self._apply_btn)
|
||||||
|
self.setLayout(layout)
|
||||||
|
|
||||||
|
def setupGraph(self):
|
||||||
|
track1_brush = QBrush(QColor.fromString("orange"))
|
||||||
|
track1_brush.color().setAlphaF(0.5)
|
||||||
|
track2_brush = QBrush(QColor.fromString("green"))
|
||||||
|
|
||||||
|
pg.setConfigOptions(antialias=True)
|
||||||
|
plot_widget = pg.GraphicsLayoutWidget(show=False)
|
||||||
|
|
||||||
|
self._t1_selection = pg.LinearRegionItem([100, 200])
|
||||||
|
self._t1_selection.setZValue(-10) # what is that?
|
||||||
|
self._t1_selection.setBrush(track1_brush)
|
||||||
|
self._t2_selection = pg.LinearRegionItem([300,400])
|
||||||
|
self._t2_selection.setZValue(-10) # what is that?
|
||||||
|
self._t2_selection.setBrush(track2_brush)
|
||||||
|
return plot_widget
|
||||||
|
|
||||||
|
def estimate_length(self, coords, bodyaxis =None):
|
||||||
|
if bodyaxis is None:
|
||||||
|
bodyaxis = [0, 1, 2, 5]
|
||||||
|
bodycoords = coords[:, bodyaxis, :]
|
||||||
|
dists = np.sum(np.sqrt(np.sum(np.diff(bodycoords, axis=1)**2, axis=2)), axis=1)
|
||||||
|
return dists
|
||||||
|
|
||||||
|
def estimate_histogram(self, dists, min_threshold=1., max_threshold=99.):
|
||||||
|
min_length = np.percentile(dists, min_threshold)
|
||||||
|
max_length = np.percentile(dists, max_threshold)
|
||||||
|
bins = np.linspace(0.5 * min_length, 1.5 * max_length, 100)
|
||||||
|
hist, edges = np.histogram(dists, bins=bins, density=True)
|
||||||
|
return hist, edges
|
||||||
|
|
||||||
|
def setCoordinates(self, coordinates):
|
||||||
|
self._coordinates = coordinates
|
||||||
|
self._sizes = self.estimate_length(coordinates)
|
||||||
|
n, e = self.estimate_histogram(self._sizes)
|
||||||
|
plot = self._plot_widget.addPlot()
|
||||||
|
bgi = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=(0,0,255,150))
|
||||||
|
plot.addItem(bgi)
|
||||||
|
plot.setLabel('left', "prob. density")
|
||||||
|
plot.setLabel('bottom', "bodylength", units="px")
|
||||||
|
plot.addItem(self._t1_selection)
|
||||||
|
plot.addItem(self._t2_selection)
|
||||||
|
|
||||||
|
def selections(self, track1=True):
|
||||||
|
if track1:
|
||||||
|
return self._t1_selection.getRegion()
|
||||||
|
else:
|
||||||
|
return self._t2_selection.getRegion()
|
||||||
|
|
||||||
|
def assignedTracks(self):
|
||||||
|
tracks = np.ones_like(self._sizes, dtype=int) * -1
|
||||||
|
t1lower, t1upper = self.selections()
|
||||||
|
t2lower, t2upper = self.selections(False)
|
||||||
|
tracks[(self._sizes >= t1lower) & (self._sizes < t1upper)] = 1
|
||||||
|
tracks[(self._sizes >= t2lower) & (self._sizes < t2upper)] = 2
|
||||||
|
return tracks
|
||||||
|
|
||||||
|
|
||||||
|
class ClassifierWidget(QTabWidget):
|
||||||
|
apply_sizeclassifier = Signal(np.ndarray)
|
||||||
|
|
||||||
|
def __init__(self, parent=None):
|
||||||
|
super().__init__(parent)
|
||||||
|
self._size_classifier = SizeClassifier()
|
||||||
|
self.addTab(self._size_classifier, "Size classifier")
|
||||||
|
self._size_classifier.apply.connect(self._on_applySizeClassifier)
|
||||||
|
|
||||||
|
def _on_applySizeClassifier(self):
|
||||||
|
tracks = self.size_classifier.assignedTracks()
|
||||||
|
self.apply_sizeclassifier.emit(tracks)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def size_classifier(self):
|
||||||
|
return self._size_classifier
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
import pickle
|
||||||
|
from fixtracks.info import PACKAGE_ROOT
|
||||||
|
from PySide6.QtWidgets import QApplication
|
||||||
|
|
||||||
|
datafile = PACKAGE_ROOT / "data/merged_small.pkl"
|
||||||
|
print(datafile)
|
||||||
|
with open(datafile, "rb") as f:
|
||||||
|
df = pickle.load(f)
|
||||||
|
|
||||||
|
coords = np.stack(df.keypoints.values,).astype(np.float32)[:,:,:]
|
||||||
|
|
||||||
|
app = QApplication([])
|
||||||
|
window = QWidget()
|
||||||
|
window.setMinimumSize(200, 200)
|
||||||
|
layout = QVBoxLayout()
|
||||||
|
win = SizeClassifier()
|
||||||
|
win.setCoordinates(coords)
|
||||||
|
|
||||||
|
btn = QPushButton("get bounds")
|
||||||
|
btn.clicked.connect(lambda: win.selections())
|
||||||
|
|
||||||
|
layout.addWidget(win)
|
||||||
|
layout.addWidget(btn)
|
||||||
|
window.setLayout(layout)
|
||||||
|
window.show()
|
||||||
|
app.exec()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -10,9 +10,11 @@ from fixtracks.widgets.detectionview import DetectionData
|
|||||||
|
|
||||||
class Skeleton(QGraphicsRectItem):
|
class Skeleton(QGraphicsRectItem):
|
||||||
skeleton_grid = [(0, 1), (1, 2), (1, 3), (1, 4), (2, 5)]
|
skeleton_grid = [(0, 1), (1, 2), (1, 3), (1, 4), (2, 5)]
|
||||||
|
bodyaxis = [0, 1, 2, 5]
|
||||||
|
|
||||||
def __init__(self, x, y, width, height, keypoint_coordinates, brush):
|
def __init__(self, x, y, width, height, keypoint_coordinates, brush):
|
||||||
super().__init__(x, y, width, height)
|
super().__init__(x, y, width, height)
|
||||||
|
self._keypoints = keypoint_coordinates
|
||||||
skeleton_pen = QPen(brush.color())
|
skeleton_pen = QPen(brush.color())
|
||||||
skeleton_pen.setWidthF(1.0)
|
skeleton_pen.setWidthF(1.0)
|
||||||
skeleton_marker = 5
|
skeleton_marker = 5
|
||||||
@ -37,6 +39,12 @@ class Skeleton(QGraphicsRectItem):
|
|||||||
# self.setAcceptHoverEvents(True) # Enable hover events if needed
|
# self.setAcceptHoverEvents(True) # Enable hover events if needed
|
||||||
self.setFlags(QGraphicsRectItem.ItemIsSelectable)
|
self.setFlags(QGraphicsRectItem.ItemIsSelectable)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def length(self):
|
||||||
|
bodykps = self._keypoints[self.bodyaxis, :]
|
||||||
|
dist = np.sum(np.sqrt(np.sum(np.diff(bodykps, axis=0)**2, axis=1)), axis=0)
|
||||||
|
return dist
|
||||||
|
|
||||||
# def mousePressEvent(self, event):
|
# def mousePressEvent(self, event):
|
||||||
# self.signals.clicked.emit(self.data(0), QPointF(event.scenePos().x(), event.scenePos().y()))
|
# self.signals.clicked.emit(self.data(0), QPointF(event.scenePos().x(), event.scenePos().y()))
|
||||||
|
|
||||||
@ -68,7 +76,7 @@ class SkeletonWidget(QWidget):
|
|||||||
font.setPointSize(9)
|
font.setPointSize(9)
|
||||||
self._info_label = QLabel("")
|
self._info_label = QLabel("")
|
||||||
self._info_label.setFont(font)
|
self._info_label.setFont(font)
|
||||||
|
|
||||||
lyt = QVBoxLayout()
|
lyt = QVBoxLayout()
|
||||||
lyt.addWidget(self._view)
|
lyt.addWidget(self._view)
|
||||||
lyt.addWidget(self._info_label)
|
lyt.addWidget(self._info_label)
|
||||||
@ -82,7 +90,11 @@ class SkeletonWidget(QWidget):
|
|||||||
def updateInfo(self, index):
|
def updateInfo(self, index):
|
||||||
if index > -1:
|
if index > -1:
|
||||||
s = self._skeletons[index]
|
s = self._skeletons[index]
|
||||||
self._info_label.setText(f"Detection id {s.data(DetectionData.ID.value)}, track {s.data(DetectionData.TRACK_ID.value)} on frame {s.data(DetectionData.FRAME.value)}")
|
l = s.length
|
||||||
|
i = s.data(DetectionData.ID.value)
|
||||||
|
t = s.data(DetectionData.TRACK_ID.value)
|
||||||
|
f = s.data(DetectionData.FRAME.value)
|
||||||
|
self._info_label.setText(f"Id {i}, track {t} on frame {f}, length {l:.1f} px")
|
||||||
else:
|
else:
|
||||||
self._info_label.setText("")
|
self._info_label.setText("")
|
||||||
|
|
||||||
@ -119,6 +131,8 @@ class SkeletonWidget(QWidget):
|
|||||||
|
|
||||||
def addSkeleton(self, coords, detection_id, frame, track, brush, update=True):
|
def addSkeleton(self, coords, detection_id, frame, track, brush, update=True):
|
||||||
def check_extent(x, y, w, h):
|
def check_extent(x, y, w, h):
|
||||||
|
if x == 0 and y == 0:
|
||||||
|
return
|
||||||
if len(self._skeletons) == 0:
|
if len(self._skeletons) == 0:
|
||||||
self._minx = x
|
self._minx = x
|
||||||
self._maxx = x + w
|
self._maxx = x + w
|
||||||
@ -186,14 +200,9 @@ def main():
|
|||||||
df = pickle.load(f)
|
df = pickle.load(f)
|
||||||
|
|
||||||
focus_brush = QBrush(QColor.fromString("red"))
|
focus_brush = QBrush(QColor.fromString("red"))
|
||||||
second_brush = QBrush(QColor.fromString("blue"))
|
|
||||||
|
|
||||||
scnd_coords = np.stack(df.keypoints[(df.track == 2)].values,).astype(np.float32)[:,:,:]
|
|
||||||
scnd_tracks = df.track[df.track == 2].values
|
|
||||||
scnd_ids = df.track[(df.track == 2)].index.values
|
|
||||||
|
|
||||||
focus_coords = np.stack(df.keypoints[df.track == 1].values,).astype(np.float32)[:,:,:]
|
focus_coords = np.stack(df.keypoints[df.track == 1].values,).astype(np.float32)[:,:,:]
|
||||||
focus_tracks = df.track[df.track == 1].values
|
focus_tracks = df.track[df.track == 1].values
|
||||||
|
focus_frames = df.track[df.track == 1].values
|
||||||
focus_ids = df.track[(df.track == 2)].index.values
|
focus_ids = df.track[(df.track == 2)].index.values
|
||||||
|
|
||||||
app = QApplication([])
|
app = QApplication([])
|
||||||
@ -209,7 +218,8 @@ def main():
|
|||||||
layout.addWidget(btn)
|
layout.addWidget(btn)
|
||||||
# view.addSkeleton(focus_coords[10,:,:], focus_ids[10], focus_brush)
|
# view.addSkeleton(focus_coords[10,:,:], focus_ids[10], focus_brush)
|
||||||
count = 100
|
count = 100
|
||||||
view.addSkeletons(focus_coords[:count,:,:], focus_ids[:count], focus_brush)
|
view.addSkeletons(focus_coords[:count,:,:], focus_ids[:count],
|
||||||
|
focus_frames[:count], focus_tracks[:count], focus_brush)
|
||||||
# view.addSkeletons(scnd_coords[:count,:,:], scnd_ids[:count], second_brush)
|
# view.addSkeletons(scnd_coords[:count,:,:], scnd_ids[:count], second_brush)
|
||||||
|
|
||||||
# view.addSkeletons(focus_coords[:10,:,:], focus_ids[:10], focus_brush)
|
# view.addSkeletons(focus_coords[:10,:,:], focus_ids[:10], focus_brush)
|
||||||
|
@ -14,6 +14,8 @@ from fixtracks.utils.writer import PickleWriter
|
|||||||
from fixtracks.widgets.detectionview import DetectionView, DetectionData
|
from fixtracks.widgets.detectionview import DetectionView, DetectionData
|
||||||
from fixtracks.widgets.detectiontimeline import DetectionTimeline
|
from fixtracks.widgets.detectiontimeline import DetectionTimeline
|
||||||
from fixtracks.widgets.skeleton import SkeletonWidget
|
from fixtracks.widgets.skeleton import SkeletonWidget
|
||||||
|
from fixtracks.widgets.classifier import ClassifierWidget
|
||||||
|
|
||||||
|
|
||||||
class PoseTableModel(QAbstractTableModel):
|
class PoseTableModel(QAbstractTableModel):
|
||||||
column_header = ["frame", "track"]
|
column_header = ["frame", "track"]
|
||||||
@ -259,6 +261,10 @@ class DataController(QObject):
|
|||||||
logging.error("Column %s not in dictionary", col)
|
logging.error("Column %s not in dictionary", col)
|
||||||
return np.nan
|
return np.nan
|
||||||
|
|
||||||
|
@property
|
||||||
|
def numDetections(self):
|
||||||
|
return self._data["track"].shape[0]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def selectionRange(self):
|
def selectionRange(self):
|
||||||
return self._start, self._stop
|
return self._start, self._stop
|
||||||
@ -286,6 +292,12 @@ class DataController(QObject):
|
|||||||
def assignUserSelection(self, track_id):
|
def assignUserSelection(self, track_id):
|
||||||
self._data["track"][self._user_selections] = track_id
|
self._data["track"][self._user_selections] = track_id
|
||||||
|
|
||||||
|
def assignTracks(self, tracks):
|
||||||
|
if len(tracks) != self.numDetections:
|
||||||
|
logging.error("DataController: Size of passed tracks does not match data!")
|
||||||
|
return
|
||||||
|
self._data["track"] = tracks
|
||||||
|
|
||||||
def save(self, filename):
|
def save(self, filename):
|
||||||
export_columns = self._columns.copy()
|
export_columns = self._columns.copy()
|
||||||
export_columns.remove("index")
|
export_columns.remove("index")
|
||||||
@ -299,6 +311,8 @@ class DataController(QObject):
|
|||||||
return 0
|
return 0
|
||||||
return self._data["keypoints"][0].shape[0]
|
return self._data["keypoints"][0].shape[0]
|
||||||
|
|
||||||
|
def coordinates(self):
|
||||||
|
return np.stack(self._data["keypoints"]).astype(np.float32)
|
||||||
|
|
||||||
class FixTracks(QWidget):
|
class FixTracks(QWidget):
|
||||||
back = Signal()
|
back = Signal()
|
||||||
@ -391,9 +405,13 @@ class FixTracks(QWidget):
|
|||||||
btnBox.addWidget(self._progress_bar)
|
btnBox.addWidget(self._progress_bar)
|
||||||
btnBox.addWidget(self._saveBtn)
|
btnBox.addWidget(self._saveBtn)
|
||||||
|
|
||||||
|
self._classifier = ClassifierWidget()
|
||||||
|
self._classifier.apply_sizeclassifier.connect(self.on_classifyBySize)
|
||||||
|
self._classifier.setMaximumWidth(500)
|
||||||
cntrlBox = QHBoxLayout()
|
cntrlBox = QHBoxLayout()
|
||||||
cntrlBox.addItem(QSpacerItem(200, 100, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Expanding))
|
cntrlBox.addWidget(self._classifier)
|
||||||
cntrlBox.addWidget(self._controls_widget, alignment=Qt.AlignmentFlag.AlignCenter)
|
cntrlBox.addWidget(self._controls_widget, alignment=Qt.AlignmentFlag.AlignCenter)
|
||||||
|
cntrlBox.addItem(QSpacerItem(300, 100, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Expanding))
|
||||||
|
|
||||||
vbox = QVBoxLayout()
|
vbox = QVBoxLayout()
|
||||||
vbox.addLayout(timelinebox)
|
vbox.addLayout(timelinebox)
|
||||||
@ -412,6 +430,12 @@ class FixTracks(QWidget):
|
|||||||
layout.addWidget(splitter)
|
layout.addWidget(splitter)
|
||||||
self.setLayout(layout)
|
self.setLayout(layout)
|
||||||
|
|
||||||
|
def on_classifyBySize(self, tracks):
|
||||||
|
self._data.setSelectionRange("index", 0, self._data.numDetections)
|
||||||
|
self._data.assignTracks(tracks)
|
||||||
|
self._timeline.setDetectionData(self._data.data)
|
||||||
|
self.update()
|
||||||
|
|
||||||
def on_dataSelection(self):
|
def on_dataSelection(self):
|
||||||
filename = self._data_combo.currentText()
|
filename = self._data_combo.currentText()
|
||||||
if "please select" in filename.lower() or len(filename.strip()) == 0:
|
if "please select" in filename.lower() or len(filename.strip()) == 0:
|
||||||
@ -509,6 +533,8 @@ class FixTracks(QWidget):
|
|||||||
maxframes = self._data.max("frame")
|
maxframes = self._data.max("frame")
|
||||||
rel_width = self._windowspinner.value() / maxframes
|
rel_width = self._windowspinner.value() / maxframes
|
||||||
self._timeline.setWindowWidth(rel_width)
|
self._timeline.setWindowWidth(rel_width)
|
||||||
|
coordinates = self._data.coordinates()
|
||||||
|
self._classifier.size_classifier.setCoordinates(coordinates)
|
||||||
self.update()
|
self.update()
|
||||||
self._saveBtn.setEnabled(True)
|
self._saveBtn.setEnabled(True)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user