diff --git a/fixtracks/widgets/classifier.py b/fixtracks/widgets/classifier.py index 506cd42..2bcf5b2 100644 --- a/fixtracks/widgets/classifier.py +++ b/fixtracks/widgets/classifier.py @@ -102,18 +102,8 @@ class ClassifierWidget(QTabWidget): return self._size_classifier -def main(): - import pickle - from fixtracks.info import PACKAGE_ROOT - from PySide6.QtWidgets import QApplication - - datafile = PACKAGE_ROOT / "data/merged_small.pkl" - print(datafile) - with open(datafile, "rb") as f: - df = pickle.load(f) - - coords = np.stack(df.keypoints.values,).astype(np.float32)[:,:,:] +def test_sizeClassifier(coords): app = QApplication([]) window = QWidget() window.setMinimumSize(200, 200) @@ -121,14 +111,38 @@ def main(): win = SizeClassifier() win.setCoordinates(coords) - btn = QPushButton("get bounds") - btn.clicked.connect(lambda: win.selections()) + layout.addWidget(win) + window.setLayout(layout) + window.show() + app.exec() + +def test_neighborhoodClassifier(coords): + app = QApplication([]) + window = QWidget() + window.setMinimumSize(200, 200) + layout = QVBoxLayout() + win = SizeClassifier() + win.setCoordinates(coords) layout.addWidget(win) - layout.addWidget(btn) window.setLayout(layout) window.show() app.exec() + +def main(): + import pickle + from fixtracks.info import PACKAGE_ROOT + from PySide6.QtWidgets import QApplication + datafile = PACKAGE_ROOT / "data/merged_small.pkl" + print(datafile) + with open(datafile, "rb") as f: + df = pickle.load(f) + + coords = np.stack(df.keypoints.values,).astype(np.float32) + frames = df.frame.values + test_sizeClassifier(coords) + + if __name__ == "__main__": main()