Skip to content

Commit 916ba03

Browse files
committed
[DNL] executorch export faster-rcnn
1 parent 6279faa commit 916ba03

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

torchvision/ops/boxes.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,11 @@ def batched_nms(
6969
_log_api_usage_once(batched_nms)
7070
# Benchmarks that drove the following thresholds are at
7171
# https://github.com/pytorch/vision/issues/1311#issuecomment-781329339
72-
if boxes.numel() > (4000 if boxes.device.type == "cpu" else 20000) and not torchvision._is_tracing():
73-
return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
74-
else:
75-
return _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)
72+
return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
73+
#if boxes.numel() > (4000 if boxes.device.type == "cpu" else 20000) and not torchvision._is_tracing():
74+
# return _batched_nms_vanilla(boxes, scores, idxs, iou_threshold)
75+
#else:
76+
# return _batched_nms_coordinate_trick(boxes, scores, idxs, iou_threshold)
7677

7778

7879
@torch.jit._script_if_tracing
@@ -104,7 +105,8 @@ def _batched_nms_vanilla(
104105
) -> Tensor:
105106
# Based on Detectron2 implementation, just manually call nms() on each class independently
106107
keep_mask = torch.zeros_like(scores, dtype=torch.bool)
107-
for class_id in torch.unique(idxs):
108+
#for class_id in torch.unique(idxs):
109+
for class_id in idxs:
108110
curr_indices = torch.where(idxs == class_id)[0]
109111
curr_keep_indices = nms(boxes[curr_indices], scores[curr_indices], iou_threshold)
110112
keep_mask[curr_indices[curr_keep_indices]] = True

0 commit comments

Comments
 (0)