diff --git a/litgpt/generate/base.py b/litgpt/generate/base.py
index 83d13fddbd..552916a667 100644
--- a/litgpt/generate/base.py
+++ b/litgpt/generate/base.py
@@ -77,7 +77,7 @@ def next_token(
     model: GPT,
     input_pos: torch.Tensor,
     x: torch.Tensor,
-    input_pos_maxp1: Optional[torch.Tensor] = None,
+    input_pos_maxp1: Optional[int] = None,
     **sample_kwargs: Dict[str, Any],
 ) -> torch.Tensor:
     logits = model(x, input_pos, input_pos_maxp1=input_pos_maxp1)
@@ -180,10 +180,7 @@ def generate_fn(
     input_pos = torch.arange(0, prompt_size, device=device, dtype=torch.int64)
     # input_pos_maxp1 introduces data-dependent shapes and control flow.
     # We want to skip if ThunderModules are involved, either directly or wrapped in LightningModule etc.
-    if not any(m.__class__.__name__ == "ThunderModule" for m in model.modules()):
-        input_pos_maxp1 = torch.tensor(prompt_size, device=device)
-    else:
-        input_pos_maxp1 = None
+    input_pos_maxp1 = prompt_size if all(m.__class__.__name__ != "ThunderModule" for m in model.modules()) else None
     for current_idx in range(max_returned_tokens - prompt_size):
         # Generate the token
         token = next_token(
@@ -231,7 +228,7 @@ def generate_fn(
         else:
             input_pos.add_(1)
         if input_pos_maxp1 is not None:
-            input_pos_maxp1.add_(1)
+            input_pos_maxp1 += 1
 
     # Yield any remaining tokens
     if yielded_idx < len(tokens):
diff --git a/litgpt/generate/sequentially.py b/litgpt/generate/sequentially.py
index 52ab467912..2ed1987ace 100644
--- a/litgpt/generate/sequentially.py
+++ b/litgpt/generate/sequentially.py
@@ -108,7 +108,7 @@ def layer_to_device(
 def move_block_input(device: torch.device, module: torch.nn.Module, ins):
     """``forward_pre_hook`` to move a Block's input before forward."""
     # during inference, none of the inputs are None: x, cos, sin, mask, input_pos
-    return tuple(t.to(device) for t in ins)
+    return tuple(t.to(device) if torch.is_tensor(t) else t for t in ins)
 
 
 def move_block_output(device: torch.device, module: torch.nn.Module, ins, outs) -> torch.Tensor:
diff --git a/litgpt/generate/speculative_decoding.py b/litgpt/generate/speculative_decoding.py
index 99ba92345c..814ee958ac 100644
--- a/litgpt/generate/speculative_decoding.py
+++ b/litgpt/generate/speculative_decoding.py
@@ -62,7 +62,7 @@ def speculative_decoding(
     target_model: GPT,
     token: torch.Tensor,
     input_pos: torch.Tensor,
-    input_pos_maxp1: torch.Tensor,
+    input_pos_maxp1: int,
     speculative_k: int,
     **sample_kwargs: Dict[str, Any],
 ) -> torch.Tensor:
@@ -100,7 +100,7 @@ def speculative_decoding(
     # Step 1: Generate candidate tokens using draft model
     # The draft model autoregressively generates k tokens, keeping track of probabilities
     draft_input_pos = input_pos.clone()
-    draft_input_pos_maxp1 = input_pos_maxp1.clone()
+    draft_input_pos_maxp1 = input_pos_maxp1
     draft_tokens, draft_probs = [], []
     draft_token = token
     for idx in range(speculative_k):
@@ -109,7 +109,7 @@ def speculative_decoding(
         )
         draft_token, draft_prob = sample(logits, **sample_kwargs)
         draft_input_pos.add_(1)
-        draft_input_pos_maxp1.add_(1)
+        draft_input_pos_maxp1 += 1
         draft_tokens.append(draft_token)
         draft_probs.append(draft_prob)
     draft_tokens = torch.cat(draft_tokens)
@@ -118,7 +118,7 @@ def speculative_decoding(
     # Feed both original token and draft tokens to get target probabilities
     candidate_tokens = torch.cat((token, draft_tokens))
     candidate_input_pos = input_pos + torch.arange(0, speculative_k + 1, device=input_pos.device)
-    candidate_input_pos_maxp1 = input_pos_maxp1.add(speculative_k)
+    candidate_input_pos_maxp1 = input_pos_maxp1 + speculative_k
     target_logits = target_model(
         idx=candidate_tokens.unsqueeze(0), input_pos=candidate_input_pos, input_pos_maxp1=candidate_input_pos_maxp1
     )
@@ -228,7 +228,10 @@ def generate(
 
     # Step 1: Prefill draft and target models with the prompt.
     input_pos = torch.arange(0, prompt_size, device=device, dtype=torch.int64)
-    input_pos_maxp1 = torch.tensor(prompt_size, device=device)
+    # We want to skip if ThunderModules are involved, either directly or wrapped in LightningModule etc.
+    input_pos_maxp1 = (
+        prompt_size if all(m.__class__.__name__ != "ThunderModule" for m in target_model.modules()) else None
+    )
     next_token(
         draft_model,
         input_pos,
@@ -249,7 +252,7 @@ def generate(
     )
     # Update position trackers after prompt
     input_pos = torch.tensor([prompt_size], device=device, dtype=torch.int64)
-    input_pos_maxp1.add_(1)
+    input_pos_maxp1 += 1
 
     # Step 2: Main generation loop.
     tokens = []
@@ -289,7 +292,7 @@ def generate(
 
         # Update positions for next iteration
         input_pos.add_(accepted_tokens_len)
-        input_pos_maxp1.add_(accepted_tokens_len)
+        input_pos_maxp1 += accepted_tokens_len
         token = new_tokens[-1].unsqueeze(0)
 
     # Finalize generated sequence
diff --git a/litgpt/model.py b/litgpt/model.py
index 5fcb04d4b9..db6aebe790 100644
--- a/litgpt/model.py
+++ b/litgpt/model.py
@@ -84,7 +84,7 @@ def forward(
         self,
         idx: torch.Tensor,
         input_pos: Optional[torch.Tensor] = None,
-        input_pos_maxp1: Optional[torch.Tensor] = None,
+        input_pos_maxp1: Optional[int] = None,
         lm_head_chunk_size: int = 0,
     ) -> Union[torch.Tensor, List[torch.Tensor]]:
         """
@@ -291,7 +291,7 @@ def forward(
         sin: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
         input_pos: Optional[torch.Tensor] = None,
-        input_pos_maxp1: Optional[torch.Tensor] = None,
+        input_pos_maxp1: Optional[int] = None,
     ) -> torch.Tensor:
         """
         Non-parallel residual       Parallel residual
@@ -361,7 +361,7 @@ def forward(
         sin: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
         input_pos: Optional[torch.Tensor] = None,
-        input_pos_maxp1: Optional[torch.Tensor] = None,
+        input_pos_maxp1: Optional[int] = None,
     ) -> torch.Tensor:
         # Notation:
         # - B | batch size
diff --git a/tests/test_model.py b/tests/test_model.py
index 8860a26614..39d946fb2d 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -1533,7 +1533,7 @@ def test_forward_with_without_input_pos_maxp1():
     model.set_kv_cache(batch_size)
     idx = torch.randint(0, config.padded_vocab_size, (1, 10))
     input_pos = torch.arange(1, 11)
-    input_pos_maxp1 = torch.tensor(11)
+    input_pos_maxp1 = 11
     logits_with_maxp1 = model(idx, input_pos, input_pos_maxp1=input_pos_maxp1)
     logits_no_maxp1 = model(idx, input_pos)
     torch.testing.assert_close(logits_with_maxp1, logits_no_maxp1)
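
The shape of the change, end to end: `input_pos` stays a tensor indexing into the KV cache, while `input_pos_maxp1` becomes a plain Python `int` (or `None` whenever a `ThunderModule` is anywhere in the module tree), so advancing it is ordinary integer arithmetic rather than an in-place tensor op. A minimal greedy-decoding sketch of that loop follows; `greedy_decode`, the argmax sampling, and the prompt/model arguments are illustrative stand-ins, not litgpt APIs.

```python
import torch

# Sketch only: `model` is assumed to be a litgpt GPT with its KV cache already set,
# `prompt` a 1-D tensor of token ids. Greedy argmax stands in for litgpt's sample().
@torch.inference_mode()
def greedy_decode(model, prompt, max_returned_tokens):
    device = prompt.device
    prompt_size = prompt.size(0)
    input_pos = torch.arange(0, prompt_size, device=device, dtype=torch.int64)
    # Plain Python int (or None when a ThunderModule is involved), not a 0-dim tensor.
    input_pos_maxp1 = (
        prompt_size if all(m.__class__.__name__ != "ThunderModule" for m in model.modules()) else None
    )
    token, tokens = prompt, []
    for i in range(max_returned_tokens - prompt_size):
        logits = model(token.view(1, -1), input_pos, input_pos_maxp1=input_pos_maxp1)
        token = logits[0, -1].argmax(dim=-1, keepdim=True)
        tokens.append(token)
        if i == 0:
            # After prefill, switch to single-position decoding.
            input_pos = torch.tensor([prompt_size], device=device, dtype=torch.int64)
        else:
            input_pos.add_(1)
        if input_pos_maxp1 is not None:
            input_pos_maxp1 += 1  # integer increment replaces the old tensor .add_(1)
    return torch.cat(tokens)
```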