55from sahi .models .base import DetectionModel
66from sahi .prediction import ObjectPrediction
77from sahi .utils .compatibility import fix_full_shape_list , fix_shift_amount_list
8- from sahi .utils .import_utils import check_requirements
98
109
1110class RoboflowDetectionModel (DetectionModel ):
@@ -50,6 +49,11 @@ def __init__(
5049 self ._device = device
5150 self ._api_key = api_key
5251
52+ if self ._use_universe :
53+ self .required_packages = list (getattr (self , "required_packages" , [])) + ["inference" ]
54+ else :
55+ self .required_packages = list (getattr (self , "required_packages" , [])) + ["rfdetr" ]
56+
5357 super ().__init__ (
5458 model = model ,
5559 model_path = model_path ,
@@ -66,15 +70,6 @@ def __init__(
6670 if load_at_init :
6771 self .load_model ()
6872
69- def check_dependencies (self ) -> None :
70- """
71- This function can be implemented to ensure model dependencies are installed.
72- """
73- if self ._use_universe :
74- check_requirements (["inference" ])
75- else :
76- check_requirements (["rfdetr" ])
77-
7873 def set_model (self , model : Any , ** kwargs ):
7974 """
8075 This function should be implemented to instantiate a DetectionModel out of an already loaded model
@@ -107,10 +102,10 @@ def load_model(self):
107102 assert model .task_type == "object-detection" , "Roboflow model must be an object detection model."
108103
109104 else :
110- from rfdetr .detr import RFDETRBase , RFDETRLarge
105+ from rfdetr .detr import RFDETRBase , RFDETRLarge , RFDETRMedium , RFDETRNano , RFDETRSmall
111106
112107 model , model_path = self ._model , self .model_path
113- model_names = ("RFDETRBase" , "RFDETRLarge" )
108+ model_names = ("RFDETRBase" , "RFDETRNano" , "RFDETRSmall" , "RFDETRMedium" , " RFDETRLarge" )
114109 if hasattr (model , "__name__" ) and model .__name__ in model_names :
115110 model_params = dict (
116111 resolution = int (self .image_size ) if self .image_size else 560 ,
@@ -121,7 +116,7 @@ def load_model(self):
121116 model_params ["pretrain_weights" ] = model_path
122117
123118 model = model (** model_params )
124- elif isinstance (model , (RFDETRBase , RFDETRLarge )):
119+ elif isinstance (model , (RFDETRBase , RFDETRNano , RFDETRSmall , RFDETRMedium , RFDETRLarge )):
125120 model = model
126121 else :
127122 raise ValueError (
0 commit comments