-
Notifications
You must be signed in to change notification settings - Fork 123
Expand file tree
/
Copy path test_model_processor.py
More file actions
74 lines (54 loc) · 2.66 KB
/
test_model_processor.py
File metadata and controls
74 lines (54 loc) · 2.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import unittest
from roboflow.config import TASK_CLS, TASK_DET, TASK_OBB, TASK_POSE, TASK_SEG
from roboflow.util.model_processor import (
_detect_rfdetr_task,
_detect_yolo_task,
task_of_model_type,
)
class _FakeModel:
"""Stand-in for an Ultralytics model_instance; only __class__.__name__ matters."""
def _make_fake(name: str) -> "_FakeModel":
    """Build an instance of a throwaway ``_FakeModel`` subclass named *name*.

    Only the resulting ``__class__.__name__`` is significant; the subclass has
    no attributes of its own.
    """
    fake_cls = type(name, (_FakeModel,), {})
    return fake_cls()
class TaskOfModelTypeTest(unittest.TestCase):
    """Checks that ``task_of_model_type`` maps model-type strings to task constants."""

    def _assert_task(self, expected, *model_types):
        # Assert each model type in order; the first mismatch fails the test.
        for model_type in model_types:
            self.assertEqual(task_of_model_type(model_type), expected)

    def test_detect_defaults(self):
        self._assert_task(TASK_DET, "yolov11", "rfdetr-base", "rfdetr-medium", "yolov8")

    def test_segment(self):
        self._assert_task(TASK_SEG, "yolov11-seg", "rfdetr-seg-medium", "yolov7-seg")

    def test_pose(self):
        self._assert_task(TASK_POSE, "yolov11-pose")

    def test_classify(self):
        self._assert_task(TASK_CLS, "yolov11-cls")

    def test_obb(self):
        self._assert_task(TASK_OBB, "yolov11-obb")
class DetectYoloTaskTest(unittest.TestCase):
    """Checks ``_detect_yolo_task`` against fake models named like Ultralytics classes."""

    def test_ultralytics_class_names(self):
        expectations = (
            ("SegmentationModel", TASK_SEG),
            ("PoseModel", TASK_POSE),
            ("ClassificationModel", TASK_CLS),
            ("OBBModel", TASK_OBB),
            ("DetectionModel", TASK_DET),
        )
        for class_name, task in expectations:
            fake = _make_fake(class_name)
            # The class name doubles as the failure message to identify the case.
            self.assertEqual(_detect_yolo_task(fake), task, class_name)

    def test_unrecognized_returns_none(self):
        # Neither an unknown class name nor a missing model maps to a task.
        for candidate in (_make_fake("SomeOtherModel"), None):
            self.assertIsNone(_detect_yolo_task(candidate))
class DetectRfdetrTaskTest(unittest.TestCase):
    """Checks ``_detect_rfdetr_task`` on dicts carrying a ``model_name`` entry."""

    def test_segmentation_model_names(self):
        seg_names = ["RFDETRSegNano", "RFDETRSegSmall", "RFDETRSegMedium", "RFDETRSegLarge"]
        for model_name in seg_names:
            payload = {"model_name": model_name}
            self.assertEqual(_detect_rfdetr_task(payload), TASK_SEG, model_name)

    def test_detection_model_names(self):
        det_names = ["RFDETRNano", "RFDETRSmall", "RFDETRMedium", "RFDETRLarge", "RFDETRXLarge"]
        for model_name in det_names:
            payload = {"model_name": model_name}
            self.assertEqual(_detect_rfdetr_task(payload), TASK_DET, model_name)

    def test_unrecognized_returns_none(self):
        # A missing name, a null name, and a dict whose only hint is a
        # segmentation_head flag under "args" must all yield None.
        unrecognized = (
            {},
            {"model_name": None},
            {"args": {"segmentation_head": True}},
        )
        for payload in unrecognized:
            self.assertIsNone(_detect_rfdetr_task(payload))
if __name__ == "__main__":
    # Discover and run every TestCase in this module when executed as a script.
    unittest.main(module="__main__")