Skip to content

Commit 675ad46

Browse files
committed
feat: ✨ Enhance RoboflowDetectionModel to support additional RFDETR variants and remove check_dependencies method
Signed-off-by: Onuralp SEZER <onuralp@ultralytics.com>
1 parent 0aca3fc commit 675ad46

File tree

1 file changed

+8
-13
lines changed

1 file changed

+8
-13
lines changed

sahi/models/roboflow.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from sahi.models.base import DetectionModel
66
from sahi.prediction import ObjectPrediction
77
from sahi.utils.compatibility import fix_full_shape_list, fix_shift_amount_list
8-
from sahi.utils.import_utils import check_requirements
98

109

1110
class RoboflowDetectionModel(DetectionModel):
@@ -50,6 +49,11 @@ def __init__(
5049
self._device = device
5150
self._api_key = api_key
5251

52+
if self._use_universe:
53+
self.required_packages = list(getattr(self, "required_packages", [])) + ["inference"]
54+
else:
55+
self.required_packages = list(getattr(self, "required_packages", [])) + ["rfdetr"]
56+
5357
super().__init__(
5458
model=model,
5559
model_path=model_path,
@@ -66,15 +70,6 @@ def __init__(
6670
if load_at_init:
6771
self.load_model()
6872

69-
def check_dependencies(self) -> None:
70-
"""
71-
This function can be implemented to ensure model dependencies are installed.
72-
"""
73-
if self._use_universe:
74-
check_requirements(["inference"])
75-
else:
76-
check_requirements(["rfdetr"])
77-
7873
def set_model(self, model: Any, **kwargs):
7974
"""
8075
This function should be implemented to instantiate a DetectionModel out of an already loaded model
@@ -107,10 +102,10 @@ def load_model(self):
107102
assert model.task_type == "object-detection", "Roboflow model must be an object detection model."
108103

109104
else:
110-
from rfdetr.detr import RFDETRBase, RFDETRLarge
105+
from rfdetr.detr import RFDETRBase, RFDETRLarge, RFDETRMedium, RFDETRNano, RFDETRSmall
111106

112107
model, model_path = self._model, self.model_path
113-
model_names = ("RFDETRBase", "RFDETRLarge")
108+
model_names = ("RFDETRBase", "RFDETRNano", "RFDETRSmall", "RFDETRMedium", "RFDETRLarge")
114109
if hasattr(model, "__name__") and model.__name__ in model_names:
115110
model_params = dict(
116111
resolution=int(self.image_size) if self.image_size else 560,
@@ -121,7 +116,7 @@ def load_model(self):
121116
model_params["pretrain_weights"] = model_path
122117

123118
model = model(**model_params)
124-
elif isinstance(model, (RFDETRBase, RFDETRLarge)):
119+
elif isinstance(model, (RFDETRBase, RFDETRNano, RFDETRSmall, RFDETRMedium, RFDETRLarge)):
125120
model = model
126121
else:
127122
raise ValueError(

0 commit comments

Comments
 (0)