[classifier] kind of handling mulitple detections in one frame
This commit is contained in:
parent
430ee4fac7
commit
d6b91c25d2
@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGraphicsView
|
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGraphicsView, QTextEdit
|
||||||
from PySide6.QtWidgets import QSpinBox, QProgressBar, QGridLayout, QLabel, QCheckBox, QProgressDialog
|
from PySide6.QtWidgets import QSpinBox, QProgressBar, QGridLayout, QLabel, QCheckBox, QProgressDialog
|
||||||
from PySide6.QtCore import Qt, Signal, Slot, QRunnable, QObject, QThreadPool
|
from PySide6.QtCore import Qt, Signal, Slot, QRunnable, QObject, QThreadPool
|
||||||
from PySide6.QtGui import QBrush, QColor
|
from PySide6.QtGui import QBrush, QColor
|
||||||
@ -24,7 +24,7 @@ class Detection():
|
|||||||
self.userlabeled = userlabeled
|
self.userlabeled = userlabeled
|
||||||
|
|
||||||
class WorkerSignals(QObject):
|
class WorkerSignals(QObject):
|
||||||
error = Signal(str)
|
message = Signal(str)
|
||||||
running = Signal(bool)
|
running = Signal(bool)
|
||||||
progress = Signal(int, int, int)
|
progress = Signal(int, int, int)
|
||||||
currentframe = Signal(int)
|
currentframe = Signal(int)
|
||||||
@ -52,7 +52,7 @@ class ConsitencyDataLoader(QRunnable):
|
|||||||
self.positions = self.data.centerOfGravity()
|
self.positions = self.data.centerOfGravity()
|
||||||
self.orientations = self.data.orientation()
|
self.orientations = self.data.orientation()
|
||||||
self.lengths = self.data.animalLength()
|
self.lengths = self.data.animalLength()
|
||||||
self.bendedness = self.data.bendedness()
|
# self.bendedness = self.data.bendedness()
|
||||||
self.userlabeled = self.data["userlabeled"]
|
self.userlabeled = self.data["userlabeled"]
|
||||||
self.scores = self.data["confidence"] # ignore for now, let's see how far this carries.
|
self.scores = self.data["confidence"] # ignore for now, let's see how far this carries.
|
||||||
self.frames = self.data["frame"]
|
self.frames = self.data["frame"]
|
||||||
@ -94,18 +94,6 @@ class ConsistencyWorker(QRunnable):
|
|||||||
detections.append(d)
|
detections.append(d)
|
||||||
return detections
|
return detections
|
||||||
|
|
||||||
def needs_checking(original, new):
|
|
||||||
res = False
|
|
||||||
for n, o in zip(new, original):
|
|
||||||
res = (o == 1 or o == 2) and n != o
|
|
||||||
if res:
|
|
||||||
print("inverted assignment, needs cross-checking?")
|
|
||||||
if not res:
|
|
||||||
res = len(new) > 1 and (np.all(new == 1) or np.all(new == 2))
|
|
||||||
if res:
|
|
||||||
print("all detections would be assigned to one track!")
|
|
||||||
return res
|
|
||||||
|
|
||||||
def assign_by_distance(d):
|
def assign_by_distance(d):
|
||||||
t1_step = d.frame - last_detections[1].frame
|
t1_step = d.frame - last_detections[1].frame
|
||||||
t2_step = d.frame - last_detections[2].frame
|
t2_step = d.frame - last_detections[2].frame
|
||||||
@ -138,6 +126,15 @@ class ConsistencyWorker(QRunnable):
|
|||||||
most_likely_track = np.argmin(length_differences) + 1
|
most_likely_track = np.argmin(length_differences) + 1
|
||||||
return most_likely_track, length_differences
|
return most_likely_track, length_differences
|
||||||
|
|
||||||
|
def check_multiple_detections(detections):
|
||||||
|
distances = np.zeros((len(detections), len(detections)))
|
||||||
|
for i, d1 in enumerate(detections):
|
||||||
|
for j, d2 in enumerate(detections):
|
||||||
|
distances[i, j] = np.abs(np.linalg.norm(d2.position - d1.position))
|
||||||
|
lowest_dist = np.argmin(np.sum(distances, axis=1))
|
||||||
|
del detections[lowest_dist]
|
||||||
|
return detections
|
||||||
|
|
||||||
unique_frames = np.unique(self.frames)
|
unique_frames = np.unique(self.frames)
|
||||||
steps = int((len(unique_frames) - self._startframe) // 100)
|
steps = int((len(unique_frames) - self._startframe) // 100)
|
||||||
errors = 0
|
errors = 0
|
||||||
@ -150,19 +147,28 @@ class ConsistencyWorker(QRunnable):
|
|||||||
if self._stoprequest:
|
if self._stoprequest:
|
||||||
break
|
break
|
||||||
error = False
|
error = False
|
||||||
|
message = ""
|
||||||
self.signals.currentframe.emit(f)
|
self.signals.currentframe.emit(f)
|
||||||
indices = np.where(self.frames == f)[0]
|
indices = np.where(self.frames == f)[0]
|
||||||
detections = get_detections(f, indices)
|
detections = get_detections(f, indices)
|
||||||
done = [False, False]
|
done = [False, False]
|
||||||
if len(detections) == 0:
|
if len(detections) == 0:
|
||||||
continue
|
continue
|
||||||
|
if len(detections) > 2:
|
||||||
|
message = f"Frame {f}: More than 2 detections ({len(detections)}) in the same frame!"
|
||||||
|
logging.info("ConsistencyTracker: %s", message)
|
||||||
|
self.signals.message.emit(message)
|
||||||
|
while len(detections) > 2:
|
||||||
|
detections = check_multiple_detections(detections)
|
||||||
|
|
||||||
if len(detections) > 1 and np.any([detections[0].userlabeled, detections[1].userlabeled]):
|
if len(detections) > 1 and np.any([detections[0].userlabeled, detections[1].userlabeled]):
|
||||||
# more than one detection
|
# more than one detection
|
||||||
if detections[0].userlabeled and detections[1].userlabeled:
|
if detections[0].userlabeled and detections[1].userlabeled:
|
||||||
if detections[0].track == detections[1].track:
|
if detections[0].track == detections[1].track:
|
||||||
error = True
|
error = True
|
||||||
logging.info("Classification error both detections in the same frame are assigned to the same track!")
|
message = f"Frame {f}: Classification error both detections in the same frame are assigned to the same track!"
|
||||||
|
logging.info("ConsistencyTracker: %s", message)
|
||||||
|
self.signals.message.emit(message)
|
||||||
elif detections[0].userlabeled and not detections[1].userlabeled:
|
elif detections[0].userlabeled and not detections[1].userlabeled:
|
||||||
detections[1].track = 1 if detections[0].track == 2 else 2
|
detections[1].track = 1 if detections[0].track == 2 else 2
|
||||||
elif not detections[0].userlabeled and detections[1].userlabeled:
|
elif not detections[0].userlabeled and detections[1].userlabeled:
|
||||||
@ -178,50 +184,52 @@ class ConsistencyWorker(QRunnable):
|
|||||||
elif len(detections) == 1 and detections[0].userlabeled: # ony one detection and labeled
|
elif len(detections) == 1 and detections[0].userlabeled: # ony one detection and labeled
|
||||||
last_detections[detections[0].track] = detections[0]
|
last_detections[detections[0].track] = detections[0]
|
||||||
done[0] = True
|
done[0] = True
|
||||||
|
|
||||||
if np.sum(done) == len(detections):
|
if np.sum(done) == len(detections):
|
||||||
continue
|
continue
|
||||||
# if f == 2088:
|
|
||||||
# embed()
|
|
||||||
# return
|
|
||||||
if error and self._stoponerror:
|
if error and self._stoponerror:
|
||||||
self.signals.error.emit("Classification error both detections in the same frame are assigned to the same track!")
|
self.signals.message.emit("Tracking stopped at frame %i.", f)
|
||||||
break
|
break
|
||||||
|
elif error:
|
||||||
|
continue
|
||||||
dist_assignments = np.zeros(2, dtype=int)
|
dist_assignments = np.zeros(2, dtype=int)
|
||||||
orientation_assignments = np.zeros_like(dist_assignments)
|
orientation_assignments = np.zeros_like(dist_assignments)
|
||||||
length_assignments = np.zeros_like(dist_assignments)
|
length_assignments = np.zeros_like(dist_assignments)
|
||||||
distances = np.zeros((2, 2))
|
distances = np.zeros((2, 2))
|
||||||
orientations = np.zeros_like(distances)
|
orientations = np.zeros_like(distances)
|
||||||
lengths = np.zeros_like(distances)
|
lengths = np.zeros_like(distances)
|
||||||
assignments = np.zeros((2, 2))
|
assignments = np.zeros(2)
|
||||||
for i, d in enumerate(detections):
|
for i, d in enumerate(detections):
|
||||||
dist_assignments[i], distances[i, :] = assign_by_distance(d)
|
dist_assignments[i], distances[i, :] = assign_by_distance(d)
|
||||||
orientation_assignments[i], orientations[i,:] = assign_by_orientation(d)
|
orientation_assignments[i], orientations[i,:] = assign_by_orientation(d)
|
||||||
length_assignments[i], lengths[i, :] = assign_by_length(d)
|
length_assignments[i], lengths[i, :] = assign_by_length(d)
|
||||||
assignments[i, :] = dist_assignments # (dist_assignments * 10 + orientation_assignments + length_assignments) / 3
|
assignments = dist_assignments # (dist_assignments * 10 + orientation_assignments + length_assignments) / 3
|
||||||
|
|
||||||
diffs = np.diff(assignments, axis=1)
|
|
||||||
error = False
|
error = False
|
||||||
temp = {}
|
temp = {}
|
||||||
message = ""
|
message = ""
|
||||||
for i, d in enumerate(detections):
|
if len(detections) > 1:
|
||||||
temp = {}
|
if assignments[0] == assignments[1]:
|
||||||
if diffs[i] == 0: # both are equally likely
|
|
||||||
d.track = -1
|
d.track = -1
|
||||||
error = True
|
error = True
|
||||||
message = "Classification error both detections in the same frame are assigned to the same track!"
|
message = f"Frame {f}: Classification error: both detections in the same frame are assigned to the same track!"
|
||||||
break
|
break
|
||||||
if diffs[i] < 0:
|
elif assignments[0] != assignments[1]:
|
||||||
d.track = 1
|
detections[0].track = assignments[0]
|
||||||
|
detections[1].track = assignments[1]
|
||||||
|
temp[detections[0].track] = detections[0]
|
||||||
|
temp[detections[1].track] = detections[1]
|
||||||
|
self.tracks[detections[0].id] = detections[0].track
|
||||||
|
self.tracks[detections[1].id] = detections[1].track
|
||||||
else:
|
else:
|
||||||
d.track = 2
|
if np.abs(np.diff(distances[0,:])) > 50: # maybe include the time difference into this?
|
||||||
self.tracks[d.id] = d.track
|
detections[0].track = assignments[0]
|
||||||
if d.track not in temp:
|
temp[detections[0].track] = detections[0]
|
||||||
temp[d.track] = d
|
self.tracks[detections[0].id] = detections[0].track
|
||||||
else:
|
else:
|
||||||
|
self.tracks[detections[0].id] = -1
|
||||||
|
message = f"Frame: {f}: Decision based on distance not safe. Track set to unassigned."
|
||||||
error = True
|
error = True
|
||||||
message = "Double assignment to the same track!"
|
|
||||||
break
|
|
||||||
|
|
||||||
if not error:
|
if not error:
|
||||||
for k in temp:
|
for k in temp:
|
||||||
@ -232,14 +240,14 @@ class ConsistencyWorker(QRunnable):
|
|||||||
self.tracks[idx] = -1
|
self.tracks[idx] = -1
|
||||||
errors += 1
|
errors += 1
|
||||||
if self._stoponerror:
|
if self._stoponerror:
|
||||||
self.signals.error.emit(message)
|
self.signals.message.emit(message)
|
||||||
break
|
break
|
||||||
processed += 1
|
processed += 1
|
||||||
|
|
||||||
if steps > 0 and f % steps == 0:
|
if steps > 0 and f % steps == 0:
|
||||||
progress += 1
|
progress += 1
|
||||||
self.signals.progress.emit(progress, processed, errors)
|
self.signals.progress.emit(progress, processed, errors)
|
||||||
|
self.signals.message.emit("Tracking stopped at frame %i.", f)
|
||||||
self.signals.stopped.emit(f)
|
self.signals.stopped.emit(f)
|
||||||
|
|
||||||
|
|
||||||
@ -487,6 +495,10 @@ class ConsistencyClassifier(QWidget):
|
|||||||
self._stoponerror.setChecked(True)
|
self._stoponerror.setChecked(True)
|
||||||
self.threadpool = QThreadPool()
|
self.threadpool = QThreadPool()
|
||||||
|
|
||||||
|
self._messagebox = QTextEdit()
|
||||||
|
self._messagebox.setFocusPolicy(Qt.NoFocus)
|
||||||
|
self._messagebox.setReadOnly(True)
|
||||||
|
|
||||||
lyt = QGridLayout()
|
lyt = QGridLayout()
|
||||||
lyt.addWidget(QLabel("Start frame:"), 0, 0 )
|
lyt.addWidget(QLabel("Start frame:"), 0, 0 )
|
||||||
lyt.addWidget(self._startframe_spinner, 0, 1, 1, 2)
|
lyt.addWidget(self._startframe_spinner, 0, 1, 1, 2)
|
||||||
@ -499,13 +511,14 @@ class ConsistencyClassifier(QWidget):
|
|||||||
lyt.addWidget(self._assignedlabel, 4, 1)
|
lyt.addWidget(self._assignedlabel, 4, 1)
|
||||||
lyt.addWidget(QLabel("errors/issues"), 5, 0)
|
lyt.addWidget(QLabel("errors/issues"), 5, 0)
|
||||||
lyt.addWidget(self._errorlabel, 5, 1)
|
lyt.addWidget(self._errorlabel, 5, 1)
|
||||||
|
lyt.addWidget(self._messagebox, 6, 0, 2, 3)
|
||||||
lyt.addWidget(self._startbtn, 6, 0)
|
|
||||||
lyt.addWidget(self._stopbtn, 6, 1)
|
lyt.addWidget(self._startbtn, 8, 0)
|
||||||
lyt.addWidget(self._proceedbtn, 6, 2)
|
lyt.addWidget(self._stopbtn, 8, 1)
|
||||||
lyt.addWidget(self._apply_btn, 7, 0, 1, 2)
|
lyt.addWidget(self._proceedbtn, 8, 2)
|
||||||
lyt.addWidget(self._refreshbtn, 7, 2, 1, 1)
|
lyt.addWidget(self._apply_btn, 9, 0, 1, 2)
|
||||||
lyt.addWidget(self._progressbar, 8, 0, 1, 3)
|
lyt.addWidget(self._refreshbtn, 9, 2, 1, 1)
|
||||||
|
lyt.addWidget(self._progressbar, 10, 0, 1, 3)
|
||||||
self.setLayout(lyt)
|
self.setLayout(lyt)
|
||||||
|
|
||||||
def setData(self, data:TrackingData):
|
def setData(self, data:TrackingData):
|
||||||
@ -575,12 +588,16 @@ class ConsistencyClassifier(QWidget):
|
|||||||
self._startframe_spinner.value(), self._stoponerror.isChecked())
|
self._startframe_spinner.value(), self._stoponerror.isChecked())
|
||||||
self._worker.signals.stopped.connect(self.worker_stopped)
|
self._worker.signals.stopped.connect(self.worker_stopped)
|
||||||
self._worker.signals.progress.connect(self.worker_progress)
|
self._worker.signals.progress.connect(self.worker_progress)
|
||||||
|
self._worker.signals.message.connect(self.worker_error)
|
||||||
self._worker.signals.currentframe.connect(self.worker_frame)
|
self._worker.signals.currentframe.connect(self.worker_frame)
|
||||||
self.threadpool.start(self._worker)
|
self.threadpool.start(self._worker)
|
||||||
|
|
||||||
def worker_frame(self, frame):
|
def worker_frame(self, frame):
|
||||||
self._framelabel.setText(str(frame))
|
self._framelabel.setText(str(frame))
|
||||||
|
|
||||||
|
def worker_error(self, msg):
|
||||||
|
self._messagebox.append(msg)
|
||||||
|
|
||||||
def proceed(self):
|
def proceed(self):
|
||||||
self.start()
|
self.start()
|
||||||
|
|
||||||
@ -666,7 +683,7 @@ def main():
|
|||||||
import pickle
|
import pickle
|
||||||
from fixtracks.info import PACKAGE_ROOT
|
from fixtracks.info import PACKAGE_ROOT
|
||||||
|
|
||||||
datafile = PACKAGE_ROOT / "data/merged_small_beginning.pkl"
|
datafile = PACKAGE_ROOT / "data/merged_small_starter.pkl"
|
||||||
|
|
||||||
with open(datafile, "rb") as f:
|
with open(datafile, "rb") as f:
|
||||||
df = pickle.load(f)
|
df = pickle.load(f)
|
||||||
|
Loading…
Reference in New Issue
Block a user