Skip to content

Commit aa11358

Browse files
committed
Fix S2S pipeline edge case error (#243)
1 parent a7749e5 commit aa11358

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

src/seamless_communication/streaming/agents/online_text_decoder.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -325,10 +325,7 @@ def policy(self, states: DecoderAgentStates) -> Action:
325325
blocked_ngrams = self.get_blocked_ngrams(states.target_indices)
326326
decoder_features_out = None
327327

328-
while (
329-
len(states.target_indices + pred_indices) < self.max_len(states)
330-
and len(pred_indices) < self.max_consecutive_writes
331-
):
328+
while True:
332329
index, prob, decoder_features = self.run_decoder(states, pred_indices)
333330

334331
if decoder_features_out is None:
@@ -361,6 +358,12 @@ def policy(self, states: DecoderAgentStates) -> Action:
361358
if prob < self.decision_threshold and not states.source_finished:
362359
break
363360

361+
if (
362+
len(states.target_indices + pred_indices) >= self.max_len(states)
363+
or len(pred_indices) >= self.max_consecutive_writes
364+
):
365+
break
366+
364367
pred_indices.append(index)
365368
if self.state_bag.step_nr == 0:
366369
self.state_bag.increment_step_nr(

0 commit comments

Comments
 (0)