@@ -69,10 +69,11 @@ def batched_nms(
69
69
_log_api_usage_once (batched_nms )
70
70
# Benchmarks that drove the following thresholds are at
71
71
# https://github.com/pytorch/vision/issues/1311#issuecomment-781329339
72
- if boxes .numel () > (4000 if boxes .device .type == "cpu" else 20000 ) and not torchvision ._is_tracing ():
73
- return _batched_nms_vanilla (boxes , scores , idxs , iou_threshold )
74
- else :
75
- return _batched_nms_coordinate_trick (boxes , scores , idxs , iou_threshold )
72
+ return _batched_nms_vanilla (boxes , scores , idxs , iou_threshold )
73
+ #if boxes.numel() > (4000 if boxes.device.type == "cpu" else 20000) and not torchvision._is_tracing():
74
+ # return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
75
+ #else:
76
+ # return _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)
76
77
77
78
78
79
@torch .jit ._script_if_tracing
@@ -104,7 +105,8 @@ def _batched_nms_vanilla(
104
105
) -> Tensor :
105
106
# Based on Detectron2 implementation, just manually call nms() on each class independently
106
107
keep_mask = torch .zeros_like (scores , dtype = torch .bool )
107
- for class_id in torch .unique (idxs ):
108
+ #for class_id in torch.unique(idxs):
109
+ for class_id in idxs :
108
110
curr_indices = torch .where (idxs == class_id )[0 ]
109
111
curr_keep_indices = nms (boxes [curr_indices ], scores [curr_indices ], iou_threshold )
110
112
keep_mask [curr_indices [curr_keep_indices ]] = True
0 commit comments