Commit bf6cfdf: Enable passing in external position ids (#1493)
1 parent 0db4425

File tree (3 files changed: +50 -6 lines)

  llmfoundry/models/mpt/modeling_mpt.py
  tests/models/test_model.py
  tests/models/test_mpt_gen.py

llmfoundry/models/mpt/modeling_mpt.py (14 additions, 6 deletions)

@@ -760,6 +760,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         use_cache: Optional[bool] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
     ) -> BaseModelOutputWithPast:
         return_dict = (
             return_dict if return_dict is not None else self.config.return_dict

@@ -856,12 +857,16 @@ def forward(
             )

         if self.learned_pos_emb or (self.rope and self.rope_impl == 'hf'):
-            pos = torch.arange(
-                past_position,
-                S + past_position,
-                dtype=torch.long,
-                device=input_device,
-            ).unsqueeze(0)
+            if position_ids is None:
+                pos = torch.arange(
+                    past_position,
+                    S + past_position,
+                    dtype=torch.long,
+                    device=input_device,
+                ).unsqueeze(0)
+            else:
+                pos = position_ids
+
             if attention_mask is not None:
                 # adjust the position indices to account for padding tokens
                 pos = torch.clamp(
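
Note that externally supplied position_ids flow through the same padding adjustment as the default positions: when an attention_mask is present, the context lines above shift each index down by the number of padding tokens seen so far. A standalone illustration of that clamp logic, using plain torch only (this mirrors the unchanged lines above; it is not an API of the repo):

import torch

# With left padding, positions are shifted so the first real token
# sits at position 0 (mirrors the clamp in the context lines above).
attention_mask = torch.tensor([[0, 0, 1, 1, 1]], dtype=torch.bool)
S = attention_mask.size(1)
pos = torch.arange(0, S, dtype=torch.long).unsqueeze(0)

# Count padding tokens up to each index, subtract, and clamp at zero.
num_pads = torch.cumsum((~attention_mask).to(torch.int32), dim=1)
adjusted = torch.clamp(pos - num_pads, min=0)
print(adjusted)  # tensor([[0, 0, 0, 1, 2]])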
@@ -1128,6 +1133,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         use_cache: Optional[bool] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
     ) -> CausalLMOutputWithPast:
         return_dict = (
             return_dict if return_dict is not None else self.config.return_dict

@@ -1146,6 +1152,7 @@ def forward(
             output_hidden_states=output_hidden_states,
             use_cache=use_cache,
             inputs_embeds=inputs_embeds,
+            position_ids=position_ids,
         )

         if self.lm_head is not None:

@@ -1443,6 +1450,7 @@ def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast:
             attention_mask=batch.get('attention_mask', None),
             sequence_id=batch.get('sequence_id', None),
             inputs_embeds=batch.get('inputs_embeds', None),
+            position_ids=batch.get('position_ids', None),
         )

     def loss(self, outputs: CausalLMOutputWithPast,
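
Taken together, these hunks thread position_ids from the Composer wrapper through MPTForCausalLM down to the base MPTModel. A minimal end-to-end sketch of the new keyword (the tiny config values and the attn_impl choice are illustrative assumptions; only the position_ids argument itself comes from this commit):

import torch
from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM

# Illustrative tiny model; these config values are assumptions for the sketch.
config = MPTConfig(
    d_model=64,
    n_heads=4,
    n_layers=2,
    max_seq_len=32,
    vocab_size=128,
    attn_config={'attn_impl': 'torch'},  # avoids requiring flash-attn
)
model = MPTForCausalLM(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 8))

# Two packed length-4 sequences: restart positions at the boundary
# instead of relying on the default torch.arange(0, 8).
position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3]])

with torch.no_grad():
    out = model(input_ids=input_ids, position_ids=position_ids)
print(out.logits.shape)  # torch.Size([1, 8, 128])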

tests/models/test_model.py (34 additions, 0 deletions)

@@ -2941,3 +2941,37 @@ def test_hf_rotary_child_class_builds():

     assert torch.all(cos == cos_mp)
     assert torch.all(sin == sin_mp)
+
+
+@pytest.mark.parametrize(
+    'conf_path',
+    [
+        'scripts/train/yamls/pretrain/testing.yaml',
+    ],
+)
+def test_position_ids_fwd_pass(
+    request: pytest.FixtureRequest,
+    conf_path: str,
+    batch_size: int = 2,
+):
+    test_cfg, model, _ = _get_objs(request=request, conf_path=conf_path)
+    model.eval()
+
+    # run a forward where we do not pass the position_ids
+    batch = gen_random_batch(batch_size, test_cfg)
+    outputs = model(batch)
+    loss_no_ids = model.loss(outputs, batch)
+    assert isinstance(loss_no_ids, torch.Tensor)
+
+    # run a forward where we explicitly pass the position_ids
+    input_ids = batch['input_ids']
+    _, S = input_ids.size()
+    pos = torch.arange(0, S, dtype=torch.long,
+                       device=input_ids.device).unsqueeze(0)
+    batch['position_ids'] = pos
+
+    outputs = model(batch)
+    loss_ids = model.loss(outputs, batch)
+    assert isinstance(loss_ids, torch.Tensor)
+
+    assert torch.eq(loss_no_ids, loss_ids)
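
The test above only verifies that explicitly passing the default torch.arange positions reproduces the no-argument loss. Because the Composer wrapper pulls position_ids out of the batch dict (the batch.get('position_ids', None) hunk above), non-trivial positions can be injected the same way. A hypothetical companion test sketch, reusing this module's _get_objs and gen_random_batch helpers (nothing below is part of the commit):

import pytest
import torch

def test_packed_position_ids_fwd_pass(
    request: pytest.FixtureRequest,
    conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml',
    batch_size: int = 2,
):
    test_cfg, model, _ = _get_objs(request=request, conf_path=conf_path)
    model.eval()

    batch = gen_random_batch(batch_size, test_cfg)
    B, S = batch['input_ids'].size()

    # Restart positions halfway, as for two packed sequences per row.
    half = S // 2
    pos = torch.cat([
        torch.arange(half, dtype=torch.long),
        torch.arange(S - half, dtype=torch.long),
    ]).unsqueeze(0).expand(B, -1).to(batch['input_ids'].device)
    batch['position_ids'] = pos

    # The forward pass should accept the custom positions end to end.
    outputs = model(batch)
    loss = model.loss(outputs, batch)
    assert isinstance(loss, torch.Tensor)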

tests/models/test_mpt_gen.py (2 additions, 0 deletions)

@@ -37,6 +37,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         use_cache: Optional[bool] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
     ):
         result = super().forward(
             input_ids,

@@ -49,6 +50,7 @@ def forward(
             output_hidden_states,
             use_cache,
             inputs_embeds,
+            position_ids,
         )
         # Modify the logits to select the next token.
         if dist.get_global_rank() == 0:
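
One observation on this mock: it forwards every argument positionally, so it must list the parameters in exactly the parent's order and must grow whenever the parent signature grows, which is why this commit has to touch it at all. A keyword-forwarding variant would absorb such signature changes automatically; a sketch of that alternative shape (hypothetical class name, not the committed code):

from typing import Any, Optional

import torch

from llmfoundry.models.mpt import MPTForCausalLM

class KwargsMockMPTForCausalLM(MPTForCausalLM):
    """Sketch: signature-agnostic override of forward."""

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        **kwargs: Any,
    ):
        # New optional parameters such as position_ids pass through
        # without this override needing to be edited.
        result = super().forward(input_ids, **kwargs)
        # ...the logits could be modified here, as the mock above does...
        return result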
