Compare commits

..

2 Commits

2 changed files with 45 additions and 29 deletions

View File

@ -166,10 +166,11 @@ class TrackingData(QObject):
lengths = np.sum(np.sqrt(np.sum(np.diff(bodycoords, axis=1)**2, axis=2)), axis=1) lengths = np.sum(np.sqrt(np.sum(np.diff(bodycoords, axis=1)**2, axis=2)), axis=1)
return lengths return lengths
def orientation(self, head_node=1, tail_node=5): def orientation(self, head_node=0, tail_node=5):
bodycoords = self.coordinates()[:, [head_node, tail_node], :] bodycoords = self.coordinates()[:, [head_node, tail_node], :]
vectors = bodycoords[:, 1, :] - bodycoords[:, 0, :] vectors = bodycoords[:, 1, :] - bodycoords[:, 0, :]
orientations = np.arctan2(vectors[:, 1], vectors[:, 0]) orientations = np.arctan2(vectors[:, 0], vectors[:, 1]) * 180 / np.pi
orientations[orientations < 0] += 360
return orientations return orientations
def bendedness(self, bodyaxis=None): def bendedness(self, bodyaxis=None):
@ -241,10 +242,20 @@ def main():
count += 1 count += 1
return dists return dists
datafile = PACKAGE_ROOT / "data/merged_small.pkl" def plot_skeleton(positions):
skeleton_grid = [(0, 1), (1, 2), (1, 3), (1, 4), (2, 5)]
colors = ["tab:red"]
colors.extend(["tab:blue"]*5)
plt.scatter(positions[:, 0], positions[:, 1], c=colors)
for si, ei in skeleton_grid:
plt.plot([positions[si, 0], positions[ei, 0]],
[positions[si, 1], positions[ei, 1]], color="tab:green")
datafile = PACKAGE_ROOT / "data/merged2.pkl"
with open(datafile, "rb") as f: with open(datafile, "rb") as f:
df = pickle.load(f) df = pickle.load(f)
data = TrackingData() data = TrackingData()
data.setData(as_dict(df)) data.setData(as_dict(df))
all_cogs = data.centerOfGravity() all_cogs = data.centerOfGravity()
@ -253,7 +264,7 @@ def main():
frames = data["frame"] frames = data["frame"]
tracks = data["track"] tracks = data["track"]
bendedness = data.bendedness() bendedness = data.bendedness()
positions = data.coordinates()[[160388, 160389]]
embed() embed()
tracks = data["track"] tracks = data["track"]

View File

@ -16,6 +16,7 @@ class WorkerSignals(QObject):
error = Signal(str) error = Signal(str)
running = Signal(bool) running = Signal(bool)
progress = Signal(int, int, int) progress = Signal(int, int, int)
currentframe = Signal(int)
stopped = Signal(int) stopped = Signal(int)
class ConsitencyDataLoader(QRunnable): class ConsitencyDataLoader(QRunnable):
@ -95,12 +96,15 @@ class ConsistencyWorker(QRunnable):
def assign_by_orientation(f, o): def assign_by_orientation(f, o):
t1_step = f - last_frame[0] t1_step = f - last_frame[0]
t2_step = f - last_frame[1] t2_step = f - last_frame[1]
orientationchange = np.unwrap((last_angle - o)/np.array([t1_step, t2_step])) orientationchange = (last_angle - o)
most_likely_track = np.argmin(orientationchange) + 1 orientationchange[orientationchange > 180] = 360 - orientationchange[orientationchange > 180]
orientationchange /= np.array([t1_step, t2_step])
# orientationchange = np.abs(np.unwrap((last_angle - o)/np.array([t1_step, t2_step])))
most_likely_track = np.argmin(np.abs(orientationchange)) + 1
return most_likely_track, orientationchange return most_likely_track, orientationchange
def assign_by_length(f, o): def assign_by_length(o):
length_difference = (last_length - o) length_difference = np.abs((last_length - o))
most_likely_track = np.argmin(length_difference) + 1 most_likely_track = np.argmin(length_difference) + 1
return most_likely_track, length_difference return most_likely_track, length_difference
@ -130,8 +134,8 @@ class ConsistencyWorker(QRunnable):
steps = int((maxframes - startframe) // 200) steps = int((maxframes - startframe) // 200)
for f in np.unique(self.frames[self.frames > startframe]): for f in np.unique(self.frames[self.frames > startframe]):
print(f)
processed += 1 processed += 1
self.signals.currentframe.emit(f)
if self._stoprequest: if self._stoprequest:
break break
indices = np.where(self.frames == f)[0] indices = np.where(self.frames == f)[0]
@ -146,7 +150,6 @@ class ConsistencyWorker(QRunnable):
lengths = np.zeros_like(distances) lengths = np.zeros_like(distances)
for i, (idx, p) in enumerate(zip(indices, pp)): for i, (idx, p) in enumerate(zip(indices, pp)):
print(i)
if self.userlabeled[idx]: if self.userlabeled[idx]:
print("user") print("user")
userlabeled[i] = True userlabeled[i] = True
@ -157,17 +160,13 @@ class ConsistencyWorker(QRunnable):
continue continue
dist_assignments[i], distances[i, :] = assign_by_distance(f, p) dist_assignments[i], distances[i, :] = assign_by_distance(f, p)
angle_assignments[i], orientations[i,:] = assign_by_orientation(f, self.orientations[idx]) angle_assignments[i], orientations[i,:] = assign_by_orientation(f, self.orientations[idx])
length_assignments[i], lengths[i, :] = assign_by_length(f, self.lengths[idx]) length_assignments[i], lengths[i, :] = assign_by_length(self.lengths[idx])
if np.any(userlabeled): if np.any(userlabeled):
continue continue
# check (re) assignment, update, and proceed # check (re) assignment, update, and proceed
if not needs_checking(originaltracks, dist_assignments): if not needs_checking(originaltracks, dist_assignments):
logging.info("frame %i: Decision based on distance")
do_assignment(f, indices, dist_assignments) do_assignment(f, indices, dist_assignments)
else: else:
print(distances)
print(orientations)
print(lengths)
if not (np.all(length_assignments == 1) or np.all(length_assignments == 2)): # if I find a solution by body length if not (np.all(length_assignments == 1) or np.all(length_assignments == 2)): # if I find a solution by body length
logging.debug("frame %i: Decision based on body length", f) logging.debug("frame %i: Decision based on body length", f)
do_assignment(f, indices, length_assignments) do_assignment(f, indices, length_assignments)
@ -180,7 +179,6 @@ class ConsistencyWorker(QRunnable):
self.tracks[idx] = -1 self.tracks[idx] = -1
errors += 1 errors += 1
if self._stoponerror: if self._stoponerror:
embed()
break break
if steps > 0 and f % steps == 0: if steps > 0 and f % steps == 0:
@ -398,6 +396,7 @@ class ConsistencyClassifier(QWidget):
self._processed_frames = 0 self._processed_frames = 0
self._errorlabel = QLabel() self._errorlabel = QLabel()
self._framelabel = QLabel()
self._errorlabel.setStyleSheet("QLabel { color : red; }") self._errorlabel.setStyleSheet("QLabel { color : red; }")
self._assignedlabel = QLabel() self._assignedlabel = QLabel()
self._maxframeslabel = QLabel() self._maxframeslabel = QLabel()
@ -439,17 +438,19 @@ class ConsistencyClassifier(QWidget):
lyt.addWidget(QLabel("of"), 1, 1, 1, 1) lyt.addWidget(QLabel("of"), 1, 1, 1, 1)
lyt.addWidget(self._maxframeslabel, 1, 2, 1, 1) lyt.addWidget(self._maxframeslabel, 1, 2, 1, 1)
lyt.addWidget(self._stoponerror, 2, 0, 1, 3) lyt.addWidget(self._stoponerror, 2, 0, 1, 3)
lyt.addWidget(QLabel("assigned"), 3, 0) lyt.addWidget(QLabel("Current frame"), 3,0)
lyt.addWidget(self._assignedlabel, 3, 1) lyt.addWidget(self._framelabel, 3,1)
lyt.addWidget(QLabel("errors/issues"), 4, 0) lyt.addWidget(QLabel("assigned"), 4, 0)
lyt.addWidget(self._errorlabel, 4, 1) lyt.addWidget(self._assignedlabel, 4, 1)
lyt.addWidget(QLabel("errors/issues"), 5, 0)
lyt.addWidget(self._startbtn, 5, 0) lyt.addWidget(self._errorlabel, 5, 1)
lyt.addWidget(self._stopbtn, 5, 1)
lyt.addWidget(self._proceedbtn, 5, 2) lyt.addWidget(self._startbtn, 6, 0)
lyt.addWidget(self._apply_btn, 6, 0, 1, 2) lyt.addWidget(self._stopbtn, 6, 1)
lyt.addWidget(self._refreshbtn, 6, 2, 1, 1) lyt.addWidget(self._proceedbtn, 6, 2)
lyt.addWidget(self._progressbar, 7, 0, 1, 3) lyt.addWidget(self._apply_btn, 7, 0, 1, 2)
lyt.addWidget(self._refreshbtn, 7, 2, 1, 1)
lyt.addWidget(self._progressbar, 8, 0, 1, 3)
self.setLayout(lyt) self.setLayout(lyt)
def setData(self, data:TrackingData): def setData(self, data:TrackingData):
@ -514,8 +515,12 @@ class ConsistencyClassifier(QWidget):
self._startframe_spinner.value(), self._stoponerror.isChecked()) self._startframe_spinner.value(), self._stoponerror.isChecked())
self._worker.signals.stopped.connect(self.worker_stopped) self._worker.signals.stopped.connect(self.worker_stopped)
self._worker.signals.progress.connect(self.worker_progress) self._worker.signals.progress.connect(self.worker_progress)
self._worker.signals.currentframe.connect(self.worker_frame)
self.threadpool.start(self._worker) self.threadpool.start(self._worker)
def worker_frame(self, frame):
self._framelabel.setText(str(frame))
def proceed(self): def proceed(self):
self.start() self.start()