
input_pos_maxp1 as a Python integer #2016

Merged 4 commits on May 6, 2025
litgpt/generate/base.py (9 changes: 3 additions & 6 deletions)
@@ -77,7 +77,7 @@ def next_token(
model: GPT,
input_pos: torch.Tensor,
x: torch.Tensor,
- input_pos_maxp1: Optional[torch.Tensor] = None,
+ input_pos_maxp1: Optional[int] = None,
**sample_kwargs: Dict[str, Any],
) -> torch.Tensor:
logits = model(x, input_pos, input_pos_maxp1=input_pos_maxp1)
@@ -180,10 +180,7 @@ def generate_fn(
input_pos = torch.arange(0, prompt_size, device=device, dtype=torch.int64)
# input_pos_maxp1 introduces data-dependent shapes and control flow.
# We want to skip if ThunderModules are involved, either directly or wrapped in LightningModule etc.
- if not any(m.__class__.__name__ == "ThunderModule" for m in model.modules()):
- input_pos_maxp1 = torch.tensor(prompt_size, device=device)
- else:
- input_pos_maxp1 = None
+ input_pos_maxp1 = prompt_size if all(m.__class__.__name__ != "ThunderModule" for m in model.modules()) else None
for current_idx in range(max_returned_tokens - prompt_size):
# Generate the token
token = next_token(
@@ -231,7 +228,7 @@ def generate_fn(
else:
input_pos.add_(1)
if input_pos_maxp1 is not None:
- input_pos_maxp1.add_(1)
+ input_pos_maxp1 += 1

# Yield any remaining tokens
if yielded_idx < len(tokens):
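For reference, the counter pattern this file now uses can be sketched in isolation. Below is a minimal, runnable version of the logic, with a plain nn.Linear standing in for litgpt's GPT model:

    import torch

    def skip_maxp1(model: torch.nn.Module) -> bool:
        # Same check as in generate_fn: a ThunderModule anywhere in the module
        # tree (directly or wrapped, e.g. in a LightningModule) means pass None.
        return any(m.__class__.__name__ == "ThunderModule" for m in model.modules())

    model = torch.nn.Linear(4, 4)  # illustrative stand-in for GPT
    prompt_size = 8

    # A plain Python int instead of a 0-d device tensor: no allocation and
    # no in-place add_() kernel launch on every decode step.
    input_pos_maxp1 = None if skip_maxp1(model) else prompt_size

    for _ in range(4):  # decode steps
        if input_pos_maxp1 is not None:
            input_pos_maxp1 += 1

    print(input_pos_maxp1)  # 12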
litgpt/generate/sequentially.py (2 changes: 1 addition & 1 deletion)
@@ -108,7 +108,7 @@ def layer_to_device(
def move_block_input(device: torch.device, module: torch.nn.Module, ins):
"""``forward_pre_hook`` to move a Block's input before forward."""
# during inference, none of the inputs are None: x, cos, sin, mask, input_pos
- return tuple(t.to(device) for t in ins)
+ return tuple(t.to(device) if torch.is_tensor(t) else t for t in ins)


def move_block_output(device: torch.device, module: torch.nn.Module, ins, outs) -> torch.Tensor:
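The torch.is_tensor guard is what lets this forward_pre_hook coexist with the new plain-int input_pos_maxp1, which has no .to() method. A self-contained sketch of the same hook pattern (the Linear block and its wiring are illustrative, not litgpt's):

    import functools
    import torch

    def move_block_input(device: torch.device, module: torch.nn.Module, ins):
        # Move tensor arguments to the block's device; pass ints/None through.
        return tuple(t.to(device) if torch.is_tensor(t) else t for t in ins)

    block = torch.nn.Linear(4, 4)  # stand-in for a transformer Block
    block.register_forward_pre_hook(functools.partial(move_block_input, torch.device("cpu")))

    out = block(torch.randn(2, 4))  # tensors are moved; an int argument would pass through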
litgpt/generate/speculative_decoding.py (17 changes: 10 additions & 7 deletions)
@@ -62,7 +62,7 @@ def speculative_decoding(
target_model: GPT,
token: torch.Tensor,
input_pos: torch.Tensor,
- input_pos_maxp1: torch.Tensor,
+ input_pos_maxp1: int,
speculative_k: int,
**sample_kwargs: Dict[str, Any],
) -> torch.Tensor:
@@ -100,7 +100,7 @@ def speculative_decoding(
# Step 1: Generate candidate tokens using draft model
# The draft model autoregressively generates k tokens, keeping track of probabilities
draft_input_pos = input_pos.clone()
- draft_input_pos_maxp1 = input_pos_maxp1.clone()
+ draft_input_pos_maxp1 = input_pos_maxp1
draft_tokens, draft_probs = [], []
draft_token = token
for idx in range(speculative_k):
logits = draft_model(
idx=draft_token.unsqueeze(0), input_pos=draft_input_pos, input_pos_maxp1=draft_input_pos_maxp1
)
draft_token, draft_prob = sample(logits, **sample_kwargs)
draft_input_pos.add_(1)
- draft_input_pos_maxp1.add_(1)
+ draft_input_pos_maxp1 += 1
draft_tokens.append(draft_token)
draft_probs.append(draft_prob)
draft_tokens = torch.cat(draft_tokens)
draft_probs = torch.cat(draft_probs)
# Feed both original token and draft tokens to get target probabilities
candidate_tokens = torch.cat((token, draft_tokens))
candidate_input_pos = input_pos + torch.arange(0, speculative_k + 1, device=input_pos.device)
- candidate_input_pos_maxp1 = input_pos_maxp1.add(speculative_k)
+ candidate_input_pos_maxp1 = input_pos_maxp1 + speculative_k
target_logits = target_model(
idx=candidate_tokens.unsqueeze(0), input_pos=candidate_input_pos, input_pos_maxp1=candidate_input_pos_maxp1
)
@@ -228,7 +228,10 @@ def generate(

# Step 1: Prefill draft and target models with the prompt.
input_pos = torch.arange(0, prompt_size, device=device, dtype=torch.int64)
- input_pos_maxp1 = torch.tensor(prompt_size, device=device)
+ # We want to skip if ThunderModules are involved, either directly or wrapped in LightningModule etc.
+ input_pos_maxp1 = (
+ prompt_size if all(m.__class__.__name__ != "ThunderModule" for m in target_model.modules()) else None
+ )
next_token(
draft_model,
input_pos,
...
)
# Update position trackers after prompt
input_pos = torch.tensor([prompt_size], device=device, dtype=torch.int64)
- input_pos_maxp1.add_(1)
+ input_pos_maxp1 += 1

# Step 2: Main generation loop.
tokens = []
Expand Down Expand Up @@ -289,7 +292,7 @@ def generate(

# Update positions for next iteration
input_pos.add_(accepted_tokens_len)
- input_pos_maxp1.add_(accepted_tokens_len)
+ input_pos_maxp1 += accepted_tokens_len
token = new_tokens[-1].unsqueeze(0)

# Finalize generated sequence
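Dropping .clone() here is sound because Python ints are immutable and assignment copies them by value, while the old 0-d tensor shared storage and needed an explicit clone before the draft loop's in-place updates. A quick sketch contrasting the two semantics:

    import torch

    # Tensor counter: clone() was required, or add_() would mutate the caller's value.
    maxp1_t = torch.tensor(8)
    draft_t = maxp1_t.clone()
    draft_t.add_(1)
    assert maxp1_t.item() == 8  # intact only because of the clone

    # Int counter: plain assignment already yields an independent value.
    maxp1 = 8
    draft = maxp1
    draft += 1
    assert maxp1 == 8 and draft == 9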
litgpt/model.py (6 changes: 3 additions & 3 deletions)
@@ -84,7 +84,7 @@ def forward(
self,
idx: torch.Tensor,
input_pos: Optional[torch.Tensor] = None,
- input_pos_maxp1: Optional[torch.Tensor] = None,
+ input_pos_maxp1: Optional[int] = None,
lm_head_chunk_size: int = 0,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""
@@ -291,7 +291,7 @@ def forward(
sin: torch.Tensor,
mask: Optional[torch.Tensor] = None,
input_pos: Optional[torch.Tensor] = None,
- input_pos_maxp1: Optional[torch.Tensor] = None,
+ input_pos_maxp1: Optional[int] = None,
) -> torch.Tensor:
"""
Non-parallel residual Parallel residual
@@ -361,7 +361,7 @@ def forward(
sin: torch.Tensor,
mask: Optional[torch.Tensor] = None,
input_pos: Optional[torch.Tensor] = None,
- input_pos_maxp1: Optional[torch.Tensor] = None,
+ input_pos_maxp1: Optional[int] = None,
) -> torch.Tensor:
# Notation:
# - B | batch size
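With the bound arriving as an int, the model can use it directly as a static slice limit. A plausible sketch of the idea, trimming a preallocated KV cache before attention (an illustration under assumed shapes, not litgpt's exact internals):

    import torch

    def trim_kv(k_cache: torch.Tensor, v_cache: torch.Tensor, input_pos_maxp1: int):
        # Caches are preallocated to max_seq_len; only the first input_pos_maxp1
        # positions hold real keys/values, so attention can ignore the tail.
        # Assumed shapes: (B, n_heads, max_seq_len, head_dim).
        return k_cache[:, :, :input_pos_maxp1, :], v_cache[:, :, :input_pos_maxp1, :]

    k_cache = torch.zeros(1, 4, 128, 16)
    v_cache = torch.zeros(1, 4, 128, 16)
    k, v = trim_kv(k_cache, v_cache, input_pos_maxp1=11)
    assert k.shape == (1, 4, 11, 16)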
tests/test_model.py (2 changes: 1 addition & 1 deletion)
@@ -1533,7 +1533,7 @@ def test_forward_with_without_input_pos_maxp1():
model.set_kv_cache(batch_size)
idx = torch.randint(0, config.padded_vocab_size, (1, 10))
input_pos = torch.arange(1, 11)
- input_pos_maxp1 = torch.tensor(11)
+ input_pos_maxp1 = 11
logits_with_maxp1 = model(idx, input_pos, input_pos_maxp1=input_pos_maxp1)
logits_no_maxp1 = model(idx, input_pos)
torch.testing.assert_close(logits_with_maxp1, logits_no_maxp1)
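This assertion pins down the contract behind the type change: input_pos_maxp1 is a performance hint only, so the logits must agree (within assert_close tolerance) whether the int bound is passed or omitted.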