Compare commits

..

No commits in common. "15cee494f69ff0f63b5fba13a9523a6e4bfd89c8" and "d8fe654ac87cd1ed7a8612954ea52e12fdaf222f" have entirely different histories.

3 changed files with 10 additions and 180 deletions

View File

@ -1,134 +0,0 @@
import logging
import numpy as np
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget,QPushButton
from PySide6.QtCore import Signal
from PySide6.QtGui import QBrush, QColor
import pyqtgraph as pg
class SizeClassifier(QWidget):
apply = Signal()
def __init__(self, parent=None):
super().__init__(parent)
self._t1_selection = None
self._t2_selection = None
self._coordinates = None
self._sizes = None
self._plot_widget = self.setupGraph()
self._apply_btn = QPushButton("apply")
self._apply_btn.clicked.connect(lambda: self.apply.emit())
layout = QVBoxLayout()
layout.addWidget(self._plot_widget)
layout.addWidget(self._apply_btn)
self.setLayout(layout)
def setupGraph(self):
track1_brush = QBrush(QColor.fromString("orange"))
track1_brush.color().setAlphaF(0.5)
track2_brush = QBrush(QColor.fromString("green"))
pg.setConfigOptions(antialias=True)
plot_widget = pg.GraphicsLayoutWidget(show=False)
self._t1_selection = pg.LinearRegionItem([100, 200])
self._t1_selection.setZValue(-10) # what is that?
self._t1_selection.setBrush(track1_brush)
self._t2_selection = pg.LinearRegionItem([300,400])
self._t2_selection.setZValue(-10) # what is that?
self._t2_selection.setBrush(track2_brush)
return plot_widget
def estimate_length(self, coords, bodyaxis =None):
if bodyaxis is None:
bodyaxis = [0, 1, 2, 5]
bodycoords = coords[:, bodyaxis, :]
dists = np.sum(np.sqrt(np.sum(np.diff(bodycoords, axis=1)**2, axis=2)), axis=1)
return dists
def estimate_histogram(self, dists, min_threshold=1., max_threshold=99.):
min_length = np.percentile(dists, min_threshold)
max_length = np.percentile(dists, max_threshold)
bins = np.linspace(0.5 * min_length, 1.5 * max_length, 100)
hist, edges = np.histogram(dists, bins=bins, density=True)
return hist, edges
def setCoordinates(self, coordinates):
self._coordinates = coordinates
self._sizes = self.estimate_length(coordinates)
n, e = self.estimate_histogram(self._sizes)
plot = self._plot_widget.addPlot()
bgi = pg.BarGraphItem(x0=e[:-1], x1=e[1:], height=n, pen='w', brush=(0,0,255,150))
plot.addItem(bgi)
plot.setLabel('left', "prob. density")
plot.setLabel('bottom', "bodylength", units="px")
plot.addItem(self._t1_selection)
plot.addItem(self._t2_selection)
def selections(self, track1=True):
if track1:
return self._t1_selection.getRegion()
else:
return self._t2_selection.getRegion()
def assignedTracks(self):
tracks = np.ones_like(self._sizes, dtype=int) * -1
t1lower, t1upper = self.selections()
t2lower, t2upper = self.selections(False)
tracks[(self._sizes >= t1lower) & (self._sizes < t1upper)] = 1
tracks[(self._sizes >= t2lower) & (self._sizes < t2upper)] = 2
return tracks
class ClassifierWidget(QTabWidget):
apply_sizeclassifier = Signal(np.ndarray)
def __init__(self, parent=None):
super().__init__(parent)
self._size_classifier = SizeClassifier()
self.addTab(self._size_classifier, "Size classifier")
self._size_classifier.apply.connect(self._on_applySizeClassifier)
def _on_applySizeClassifier(self):
tracks = self.size_classifier.assignedTracks()
self.apply_sizeclassifier.emit(tracks)
@property
def size_classifier(self):
return self._size_classifier
def main():
import pickle
from fixtracks.info import PACKAGE_ROOT
from PySide6.QtWidgets import QApplication
datafile = PACKAGE_ROOT / "data/merged_small.pkl"
print(datafile)
with open(datafile, "rb") as f:
df = pickle.load(f)
coords = np.stack(df.keypoints.values,).astype(np.float32)[:,:,:]
app = QApplication([])
window = QWidget()
window.setMinimumSize(200, 200)
layout = QVBoxLayout()
win = SizeClassifier()
win.setCoordinates(coords)
btn = QPushButton("get bounds")
btn.clicked.connect(lambda: win.selections())
layout.addWidget(win)
layout.addWidget(btn)
window.setLayout(layout)
window.show()
app.exec()
if __name__ == "__main__":
main()

View File

@ -10,11 +10,9 @@ from fixtracks.widgets.detectionview import DetectionData
class Skeleton(QGraphicsRectItem): class Skeleton(QGraphicsRectItem):
skeleton_grid = [(0, 1), (1, 2), (1, 3), (1, 4), (2, 5)] skeleton_grid = [(0, 1), (1, 2), (1, 3), (1, 4), (2, 5)]
bodyaxis = [0, 1, 2, 5]
def __init__(self, x, y, width, height, keypoint_coordinates, brush): def __init__(self, x, y, width, height, keypoint_coordinates, brush):
super().__init__(x, y, width, height) super().__init__(x, y, width, height)
self._keypoints = keypoint_coordinates
skeleton_pen = QPen(brush.color()) skeleton_pen = QPen(brush.color())
skeleton_pen.setWidthF(1.0) skeleton_pen.setWidthF(1.0)
skeleton_marker = 5 skeleton_marker = 5
@ -39,12 +37,6 @@ class Skeleton(QGraphicsRectItem):
# self.setAcceptHoverEvents(True) # Enable hover events if needed # self.setAcceptHoverEvents(True) # Enable hover events if needed
self.setFlags(QGraphicsRectItem.ItemIsSelectable) self.setFlags(QGraphicsRectItem.ItemIsSelectable)
@property
def length(self):
bodykps = self._keypoints[self.bodyaxis, :]
dist = np.sum(np.sqrt(np.sum(np.diff(bodykps, axis=0)**2, axis=1)), axis=0)
return dist
# def mousePressEvent(self, event): # def mousePressEvent(self, event):
# self.signals.clicked.emit(self.data(0), QPointF(event.scenePos().x(), event.scenePos().y())) # self.signals.clicked.emit(self.data(0), QPointF(event.scenePos().x(), event.scenePos().y()))
@ -76,7 +68,7 @@ class SkeletonWidget(QWidget):
font.setPointSize(9) font.setPointSize(9)
self._info_label = QLabel("") self._info_label = QLabel("")
self._info_label.setFont(font) self._info_label.setFont(font)
lyt = QVBoxLayout() lyt = QVBoxLayout()
lyt.addWidget(self._view) lyt.addWidget(self._view)
lyt.addWidget(self._info_label) lyt.addWidget(self._info_label)
@ -90,11 +82,7 @@ class SkeletonWidget(QWidget):
def updateInfo(self, index): def updateInfo(self, index):
if index > -1: if index > -1:
s = self._skeletons[index] s = self._skeletons[index]
l = s.length self._info_label.setText(f"Detection id {s.data(DetectionData.ID.value)}, track {s.data(DetectionData.TRACK_ID.value)} on frame {s.data(DetectionData.FRAME.value)}")
i = s.data(DetectionData.ID.value)
t = s.data(DetectionData.TRACK_ID.value)
f = s.data(DetectionData.FRAME.value)
self._info_label.setText(f"Id {i}, track {t} on frame {f}, length {l:.1f} px")
else: else:
self._info_label.setText("") self._info_label.setText("")
@ -131,8 +119,6 @@ class SkeletonWidget(QWidget):
def addSkeleton(self, coords, detection_id, frame, track, brush, update=True): def addSkeleton(self, coords, detection_id, frame, track, brush, update=True):
def check_extent(x, y, w, h): def check_extent(x, y, w, h):
if x == 0 and y == 0:
return
if len(self._skeletons) == 0: if len(self._skeletons) == 0:
self._minx = x self._minx = x
self._maxx = x + w self._maxx = x + w
@ -200,9 +186,14 @@ def main():
df = pickle.load(f) df = pickle.load(f)
focus_brush = QBrush(QColor.fromString("red")) focus_brush = QBrush(QColor.fromString("red"))
second_brush = QBrush(QColor.fromString("blue"))
scnd_coords = np.stack(df.keypoints[(df.track == 2)].values,).astype(np.float32)[:,:,:]
scnd_tracks = df.track[df.track == 2].values
scnd_ids = df.track[(df.track == 2)].index.values
focus_coords = np.stack(df.keypoints[df.track == 1].values,).astype(np.float32)[:,:,:] focus_coords = np.stack(df.keypoints[df.track == 1].values,).astype(np.float32)[:,:,:]
focus_tracks = df.track[df.track == 1].values focus_tracks = df.track[df.track == 1].values
focus_frames = df.track[df.track == 1].values
focus_ids = df.track[(df.track == 2)].index.values focus_ids = df.track[(df.track == 2)].index.values
app = QApplication([]) app = QApplication([])
@ -218,8 +209,7 @@ def main():
layout.addWidget(btn) layout.addWidget(btn)
# view.addSkeleton(focus_coords[10,:,:], focus_ids[10], focus_brush) # view.addSkeleton(focus_coords[10,:,:], focus_ids[10], focus_brush)
count = 100 count = 100
view.addSkeletons(focus_coords[:count,:,:], focus_ids[:count], view.addSkeletons(focus_coords[:count,:,:], focus_ids[:count], focus_brush)
focus_frames[:count], focus_tracks[:count], focus_brush)
# view.addSkeletons(scnd_coords[:count,:,:], scnd_ids[:count], second_brush) # view.addSkeletons(scnd_coords[:count,:,:], scnd_ids[:count], second_brush)
# view.addSkeletons(focus_coords[:10,:,:], focus_ids[:10], focus_brush) # view.addSkeletons(focus_coords[:10,:,:], focus_ids[:10], focus_brush)

View File

@ -14,8 +14,6 @@ from fixtracks.utils.writer import PickleWriter
from fixtracks.widgets.detectionview import DetectionView, DetectionData from fixtracks.widgets.detectionview import DetectionView, DetectionData
from fixtracks.widgets.detectiontimeline import DetectionTimeline from fixtracks.widgets.detectiontimeline import DetectionTimeline
from fixtracks.widgets.skeleton import SkeletonWidget from fixtracks.widgets.skeleton import SkeletonWidget
from fixtracks.widgets.classifier import ClassifierWidget
class PoseTableModel(QAbstractTableModel): class PoseTableModel(QAbstractTableModel):
column_header = ["frame", "track"] column_header = ["frame", "track"]
@ -261,10 +259,6 @@ class DataController(QObject):
logging.error("Column %s not in dictionary", col) logging.error("Column %s not in dictionary", col)
return np.nan return np.nan
@property
def numDetections(self):
return self._data["track"].shape[0]
@property @property
def selectionRange(self): def selectionRange(self):
return self._start, self._stop return self._start, self._stop
@ -292,12 +286,6 @@ class DataController(QObject):
def assignUserSelection(self, track_id): def assignUserSelection(self, track_id):
self._data["track"][self._user_selections] = track_id self._data["track"][self._user_selections] = track_id
def assignTracks(self, tracks):
if len(tracks) != self.numDetections:
logging.error("DataController: Size of passed tracks does not match data!")
return
self._data["track"] = tracks
def save(self, filename): def save(self, filename):
export_columns = self._columns.copy() export_columns = self._columns.copy()
export_columns.remove("index") export_columns.remove("index")
@ -311,8 +299,6 @@ class DataController(QObject):
return 0 return 0
return self._data["keypoints"][0].shape[0] return self._data["keypoints"][0].shape[0]
def coordinates(self):
return np.stack(self._data["keypoints"]).astype(np.float32)
class FixTracks(QWidget): class FixTracks(QWidget):
back = Signal() back = Signal()
@ -405,13 +391,9 @@ class FixTracks(QWidget):
btnBox.addWidget(self._progress_bar) btnBox.addWidget(self._progress_bar)
btnBox.addWidget(self._saveBtn) btnBox.addWidget(self._saveBtn)
self._classifier = ClassifierWidget()
self._classifier.apply_sizeclassifier.connect(self.on_classifyBySize)
self._classifier.setMaximumWidth(500)
cntrlBox = QHBoxLayout() cntrlBox = QHBoxLayout()
cntrlBox.addWidget(self._classifier) cntrlBox.addItem(QSpacerItem(200, 100, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Expanding))
cntrlBox.addWidget(self._controls_widget, alignment=Qt.AlignmentFlag.AlignCenter) cntrlBox.addWidget(self._controls_widget, alignment=Qt.AlignmentFlag.AlignCenter)
cntrlBox.addItem(QSpacerItem(300, 100, QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Expanding))
vbox = QVBoxLayout() vbox = QVBoxLayout()
vbox.addLayout(timelinebox) vbox.addLayout(timelinebox)
@ -430,12 +412,6 @@ class FixTracks(QWidget):
layout.addWidget(splitter) layout.addWidget(splitter)
self.setLayout(layout) self.setLayout(layout)
def on_classifyBySize(self, tracks):
self._data.setSelectionRange("index", 0, self._data.numDetections)
self._data.assignTracks(tracks)
self._timeline.setDetectionData(self._data.data)
self.update()
def on_dataSelection(self): def on_dataSelection(self):
filename = self._data_combo.currentText() filename = self._data_combo.currentText()
if "please select" in filename.lower() or len(filename.strip()) == 0: if "please select" in filename.lower() or len(filename.strip()) == 0:
@ -533,8 +509,6 @@ class FixTracks(QWidget):
maxframes = self._data.max("frame") maxframes = self._data.max("frame")
rel_width = self._windowspinner.value() / maxframes rel_width = self._windowspinner.value() / maxframes
self._timeline.setWindowWidth(rel_width) self._timeline.setWindowWidth(rel_width)
coordinates = self._data.coordinates()
self._classifier.size_classifier.setCoordinates(coordinates)
self.update() self.update()
self._saveBtn.setEnabled(True) self._saveBtn.setEnabled(True)