diff --git a/fixtracks/utils/trackingdata.py b/fixtracks/utils/trackingdata.py index e3d8625..cf13de0 100644 --- a/fixtracks/utils/trackingdata.py +++ b/fixtracks/utils/trackingdata.py @@ -199,8 +199,11 @@ class TrackingData(QObject): scores[:, np.setdiff1d(np.arange(scores.shape[1]), nodes)] = 0.0 weighted_coords = self.coordinates(selection=selection) * scores[:, :, np.newaxis] sum_scores = np.sum(scores, axis=1, keepdims=True) - center_of_gravity = np.sum(weighted_coords, axis=1) / sum_scores - return center_of_gravity + + cogs = np.zeros((weighted_coords.shape[0], 2)) + val_ids = np.where(sum_scores > 0.0)[0] + cogs[val_ids] = np.sum(weighted_coords[val_ids], axis=1) / sum_scores[val_ids] + return cogs def animalLength(self, bodyaxis=None): if bodyaxis is None: