[classifier] kind of handling multiple detections in one frame
This commit is contained in:
parent 430ee4fac7
commit d6b91c25d2
@@ -1,7 +1,7 @@
import logging
import numpy as np

from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGraphicsView
from PySide6.QtWidgets import QWidget, QVBoxLayout, QTabWidget, QPushButton, QGraphicsView, QTextEdit
from PySide6.QtWidgets import QSpinBox, QProgressBar, QGridLayout, QLabel, QCheckBox, QProgressDialog
from PySide6.QtCore import Qt, Signal, Slot, QRunnable, QObject, QThreadPool
from PySide6.QtGui import QBrush, QColor
@@ -24,7 +24,7 @@ class Detection():
self.userlabeled = userlabeled

class WorkerSignals(QObject):
error = Signal(str)
message = Signal(str)
running = Signal(bool)
progress = Signal(int, int, int)
currentframe = Signal(int)
@@ -52,7 +52,7 @@ class ConsitencyDataLoader(QRunnable):
self.positions = self.data.centerOfGravity()
self.orientations = self.data.orientation()
self.lengths = self.data.animalLength()
self.bendedness = self.data.bendedness()
# self.bendedness = self.data.bendedness()
self.userlabeled = self.data["userlabeled"]
self.scores = self.data["confidence"] # ignore for now, let's see how far this carries.
self.frames = self.data["frame"]
@@ -94,25 +94,13 @@ class ConsistencyWorker(QRunnable):
detections.append(d)
return detections

def needs_checking(original, new):
res = False
for n, o in zip(new, original):
res = (o == 1 or o == 2) and n != o
if res:
print("inverted assignment, needs cross-checking?")
if not res:
res = len(new) > 1 and (np.all(new == 1) or np.all(new == 2))
if res:
print("all detections would be assigned to one track!")
return res

def assign_by_distance(d):
t1_step = d.frame - last_detections[1].frame
t2_step = d.frame - last_detections[2].frame
if t1_step == 0 or t2_step == 0:
print(f"framecount is zero! current frame {f}, last frame {last_detections[1].frame} and {last_detections[2].frame}")
distance_to_trackone = np.linalg.norm(d.position - last_detections[1].position)/t1_step
distance_to_tracktwo = np.linalg.norm(d.position - last_detections[2].position)/t2_step
distance_to_trackone = np.linalg.norm(d.position - last_detections[1].position) /t1_step
distance_to_tracktwo = np.linalg.norm(d.position - last_detections[2].position) /t2_step
most_likely_track = np.argmin([distance_to_trackone, distance_to_tracktwo]) + 1
distances = np.zeros(2)
distances[0] = distance_to_trackone
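Aside, not part of the patch: assign_by_distance above boils down to "distance to each track's most recent detection, divided by the number of frames elapsed; pick the track with the smaller per-frame distance". A minimal, self-contained sketch of that idea, where plain numpy arrays and made-up values stand in for the patch's Detection objects and last_detections state:

    import numpy as np

    def assign_by_distance(position, frame, last_positions, last_frames):
        # Per-frame-normalized distance to each track's last known detection;
        # the track with the smaller value is the more likely match.
        steps = np.array([frame - last_frames[1], frame - last_frames[2]], dtype=float)
        dists = np.array([np.linalg.norm(position - last_positions[1]),
                          np.linalg.norm(position - last_positions[2])]) / steps
        return int(np.argmin(dists)) + 1, dists

    last_positions = {1: np.array([10.0, 10.0]), 2: np.array([80.0, 40.0])}
    last_frames = {1: 98, 2: 95}
    track, dists = assign_by_distance(np.array([14.0, 12.0]), 100, last_positions, last_frames)
    print(track, dists)  # -> 1, because the candidate barely moved relative to track 1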
@@ -138,6 +126,15 @@ class ConsistencyWorker(QRunnable):
most_likely_track = np.argmin(length_differences) + 1
return most_likely_track, length_differences

def check_multiple_detections(detections):
distances = np.zeros((len(detections), len(detections)))
for i, d1 in enumerate(detections):
for j, d2 in enumerate(detections):
distances[i, j] = np.abs(np.linalg.norm(d2.position - d1.position))
lowest_dist = np.argmin(np.sum(distances, axis=1))
del detections[lowest_dist]
return detections

unique_frames = np.unique(self.frames)
steps = int((len(unique_frames) - self._startframe) // 100)
errors = 0
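Aside, not part of the patch: check_multiple_detections drops exactly one detection per call, namely the one whose summed distance to all other detections in the frame is smallest (i.e. one of a pair of near-duplicate detections), and the run loop below calls it repeatedly until at most two detections remain. A standalone sketch of that pruning, with bare 2-D positions standing in for Detection objects:

    import numpy as np

    def prune_surplus(positions, keep=2):
        # Repeatedly remove the most "redundant" detection: the one that sits
        # closest to all the others (smallest row sum of the pairwise distances).
        positions = list(positions)
        while len(positions) > keep:
            pts = np.asarray(positions)
            pairwise = np.linalg.norm(pts[:, None, :] - pts[None, :, :], axis=-1)
            del positions[int(np.argmin(pairwise.sum(axis=1)))]
        return positions

    frame_positions = [np.array([10.0, 10.0]),   # fish A
                       np.array([10.5, 9.8]),    # duplicate detection of fish A
                       np.array([80.0, 40.0])]   # fish B
    print(prune_surplus(frame_positions))  # one of the two near-duplicates is dropped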
@@ -150,19 +147,28 @@ class ConsistencyWorker(QRunnable):
if self._stoprequest:
break
error = False
message = ""
self.signals.currentframe.emit(f)
indices = np.where(self.frames == f)[0]
detections = get_detections(f, indices)
done = [False, False]
if len(detections) == 0:
continue
if len(detections) > 2:
message = f"Frame {f}: More than 2 detections ({len(detections)}) in the same frame!"
logging.info("ConsistencyTracker: %s", message)
self.signals.message.emit(message)
while len(detections) > 2:
detections = check_multiple_detections(detections)

if len(detections) > 1 and np.any([detections[0].userlabeled, detections[1].userlabeled]):
# more than one detection
if detections[0].userlabeled and detections[1].userlabeled:
if detections[0].track == detections[1].track:
error = True
logging.info("Classification error both detections in the same frame are assigned to the same track!")
message = f"Frame {f}: Classification error both detections in the same frame are assigned to the same track!"
logging.info("ConsistencyTracker: %s", message)
self.signals.message.emit(message)
elif detections[0].userlabeled and not detections[1].userlabeled:
detections[1].track = 1 if detections[0].track == 2 else 2
elif not detections[0].userlabeled and detections[1].userlabeled:
@@ -178,50 +184,52 @@ class ConsistencyWorker(QRunnable):
elif len(detections) == 1 and detections[0].userlabeled: # ony one detection and labeled
last_detections[detections[0].track] = detections[0]
done[0] = True

if np.sum(done) == len(detections):
continue
# if f == 2088:
# embed()
# return

if error and self._stoponerror:
self.signals.error.emit("Classification error both detections in the same frame are assigned to the same track!")
self.signals.message.emit("Tracking stopped at frame %i.", f)
break
elif error:
continue
dist_assignments = np.zeros(2, dtype=int)
orientation_assignments = np.zeros_like(dist_assignments)
length_assignments = np.zeros_like(dist_assignments)
distances = np.zeros((2, 2))
orientations = np.zeros_like(distances)
lengths = np.zeros_like(distances)
assignments = np.zeros((2, 2))
assignments = np.zeros(2)
for i, d in enumerate(detections):
dist_assignments[i], distances[i, :] = assign_by_distance(d)
orientation_assignments[i], orientations[i,:] = assign_by_orientation(d)
length_assignments[i], lengths[i, :] = assign_by_length(d)
assignments[i, :] = dist_assignments # (dist_assignments * 10 + orientation_assignments + length_assignments) / 3
assignments = dist_assignments # (dist_assignments * 10 + orientation_assignments + length_assignments) / 3

diffs = np.diff(assignments, axis=1)
error = False
temp = {}
message = ""
for i, d in enumerate(detections):
temp = {}
if diffs[i] == 0: # both are equally likely
if len(detections) > 1:
if assignments[0] == assignments[1]:
d.track = -1
error = True
message = "Classification error both detections in the same frame are assigned to the same track!"
message = f"Frame {f}: Classification error: both detections in the same frame are assigned to the same track!"
break
if diffs[i] < 0:
d.track = 1
else:
d.track = 2
self.tracks[d.id] = d.track
if d.track not in temp:
temp[d.track] = d
elif assignments[0] != assignments[1]:
detections[0].track = assignments[0]
detections[1].track = assignments[1]
temp[detections[0].track] = detections[0]
temp[detections[1].track] = detections[1]
self.tracks[detections[0].id] = detections[0].track
self.tracks[detections[1].id] = detections[1].track
else:
if np.abs(np.diff(distances[0,:])) > 50: # maybe include the time difference into this?
detections[0].track = assignments[0]
temp[detections[0].track] = detections[0]
self.tracks[detections[0].id] = detections[0].track
else:
self.tracks[detections[0].id] = -1
message = f"Frame: {f}: Decision based on distance not safe. Track set to unassigned."
error = True
message = "Double assignment to the same track!"
break

if not error:
for k in temp:
@@ -232,14 +240,14 @@ class ConsistencyWorker(QRunnable):
self.tracks[idx] = -1
errors += 1
if self._stoponerror:
self.signals.error.emit(message)
self.signals.message.emit(message)
break
processed += 1

if steps > 0 and f % steps == 0:
progress += 1
self.signals.progress.emit(progress, processed, errors)

self.signals.message.emit("Tracking stopped at frame %i.", f)
self.signals.stopped.emit(f)


@@ -487,6 +495,10 @@ class ConsistencyClassifier(QWidget):
self._stoponerror.setChecked(True)
self.threadpool = QThreadPool()

self._messagebox = QTextEdit()
self._messagebox.setFocusPolicy(Qt.NoFocus)
self._messagebox.setReadOnly(True)

lyt = QGridLayout()
lyt.addWidget(QLabel("Start frame:"), 0, 0 )
lyt.addWidget(self._startframe_spinner, 0, 1, 1, 2)
@@ -499,13 +511,14 @@ class ConsistencyClassifier(QWidget):
lyt.addWidget(self._assignedlabel, 4, 1)
lyt.addWidget(QLabel("errors/issues"), 5, 0)
lyt.addWidget(self._errorlabel, 5, 1)

lyt.addWidget(self._startbtn, 6, 0)
lyt.addWidget(self._stopbtn, 6, 1)
lyt.addWidget(self._proceedbtn, 6, 2)
lyt.addWidget(self._apply_btn, 7, 0, 1, 2)
lyt.addWidget(self._refreshbtn, 7, 2, 1, 1)
lyt.addWidget(self._progressbar, 8, 0, 1, 3)
lyt.addWidget(self._messagebox, 6, 0, 2, 3)

lyt.addWidget(self._startbtn, 8, 0)
lyt.addWidget(self._stopbtn, 8, 1)
lyt.addWidget(self._proceedbtn, 8, 2)
lyt.addWidget(self._apply_btn, 9, 0, 1, 2)
lyt.addWidget(self._refreshbtn, 9, 2, 1, 1)
lyt.addWidget(self._progressbar, 10, 0, 1, 3)
self.setLayout(lyt)

def setData(self, data:TrackingData):
@@ -575,12 +588,16 @@ class ConsistencyClassifier(QWidget):
self._startframe_spinner.value(), self._stoponerror.isChecked())
self._worker.signals.stopped.connect(self.worker_stopped)
self._worker.signals.progress.connect(self.worker_progress)
self._worker.signals.message.connect(self.worker_error)
self._worker.signals.currentframe.connect(self.worker_frame)
self.threadpool.start(self._worker)

def worker_frame(self, frame):
self._framelabel.setText(str(frame))

def worker_error(self, msg):
self._messagebox.append(msg)

def proceed(self):
self.start()

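Aside, not part of the patch: the new message signal, the read-only QTextEdit and the worker_error slot above amount to a simple "append every worker message to a log box" pattern. A minimal, self-contained PySide6 sketch of that wiring (the variable names and the example message below are made up for illustration):

    import sys
    from PySide6.QtCore import QObject, Signal
    from PySide6.QtWidgets import QApplication, QTextEdit

    class WorkerSignals(QObject):
        message = Signal(str)

    app = QApplication(sys.argv)
    box = QTextEdit()
    box.setReadOnly(True)

    signals = WorkerSignals()
    signals.message.connect(box.append)  # same idea as signals.message -> worker_error -> _messagebox.append
    signals.message.emit("Frame 42: More than 2 detections (3) in the same frame!")

    box.show()
    app.exec()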
@@ -666,7 +683,7 @@ def main():
import pickle
from fixtracks.info import PACKAGE_ROOT

datafile = PACKAGE_ROOT / "data/merged_small_beginning.pkl"
datafile = PACKAGE_ROOT / "data/merged_small_starter.pkl"

with open(datafile, "rb") as f:
df = pickle.load(f)