@@ -184,11 +184,6 @@ def dispatch_b(
             (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
         )
 
-        # TODO
-        # masked_m = torch.empty(
-        #     (self.num_local_experts,), device=hidden_states.device, dtype=torch.int64
-        # )
-        # expected_m = 0
         masked_m = expected_m = None
 
         return (
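Note on this hunk: the dead TODO allocations are dropped in favor of None placeholders, which keeps the dispatch return tuple's arity stable while signaling that masked_m and expected_m are only produced by the low-latency path. A minimal runnable sketch of the pattern, with hypothetical names (dispatch_normal_sketch is illustrative, not part of the patch):

import torch

def dispatch_normal_sketch(hidden_states: torch.Tensor, num_experts: int):
    # Output the normal-mode path actually computes: per-expert segment pointers.
    seg_indptr = torch.zeros(
        (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
    )
    # Fields only the low-latency path produces: return None rather than
    # allocating empty placeholder tensors the caller must special-case.
    masked_m = expected_m = None
    return hidden_states, seg_indptr, masked_m, expected_m

x = torch.randn(4, 8)
_, seg, m, e = dispatch_normal_sketch(x, num_experts=2)
assert m is None and e is None and seg.shape == (3,)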
@@ -327,6 +322,8 @@ def combine_a(
     def combine_b(self, output, previous_event):
         hidden_states, event = self._combine_core(output, previous_event)
         event.current_stream_wait() if self.async_finish else ()
+        self.handle = None
+        self.src2dst = None
         return hidden_states
 
     def _combine_core(self, x: torch.Tensor, previous_event):
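Note on this hunk: clearing self.handle and self.src2dst once the combine has consumed them makes the dispatch/combine state one-shot, so a later call cannot silently reuse a stale routing handle, and the buffers those attributes reference become eligible for garbage collection between micro-batches. A minimal pure-Python sketch of that lifecycle (DispatcherSketch is a hypothetical stand-in, not the real class):

class DispatcherSketch:
    def __init__(self):
        self.handle = None   # opaque routing handle produced by dispatch
        self.src2dst = None  # per-token source-to-destination index map

    def dispatch(self, tokens):
        self.handle = object()  # stand-in for a communication handle
        self.src2dst = list(range(len(tokens)))
        return tokens

    def combine(self, output):
        assert self.handle is not None, "combine() called before dispatch()"
        result = output          # stand-in for the real combine step
        self.handle = None       # one-shot: drop state right after use
        self.src2dst = None
        return result

d = DispatcherSketch()
out = d.combine(d.dispatch([1, 2, 3]))
assert out == [1, 2, 3] and d.handle is None and d.src2dst is None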
@@ -402,13 +399,6 @@ def dispatch_b(
     ):
         hook() if self.return_recv_hook else event.current_stream_wait()
 
-        # TODO
-        # reorder_topk_ids = torch.empty(
-        #     (0,), device=hidden_states.device, dtype=torch.int64
-        # )
-        # seg_indptr = torch.zeros(
-        #     (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
-        # )
         reorder_topk_ids = seg_indptr = None
 
         return (
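For context on the unchanged line above, hook() if self.return_recv_hook else event.current_stream_wait(): the low-latency path either blocks the current stream immediately or hands the caller a receive hook to fire later, which lets the receive overlap with other compute. A runnable pure-Python sketch of that control flow (every name here is an illustrative stand-in):

class _EventSketch:
    def current_stream_wait(self):
        print("blocking on current stream")

def finish_recv(event, hook, return_recv_hook: bool):
    # Defer: hand the hook back so the caller can fire it after other work...
    if return_recv_hook:
        return hook
    # ...or synchronize on the current stream right away.
    event.current_stream_wait()
    return None

deferred = finish_recv(_EventSketch(), lambda: print("recv complete"), True)
deferred()  # caller overlaps work, then fires the hook when data is needed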
@@ -508,6 +498,7 @@ def _combine_core(
                 return_recv_hook=self.return_recv_hook,
             )
         )
+        self.handle = None
         return combined_hidden_states, event, hook
 