-
Notifications
You must be signed in to change notification settings - Fork 570
Support rope scaling #1391
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Support rope scaling #1391
Changes from all commits
Commits
Show all changes
24 commits
Select commit
Hold shift + click to select a range
8fed63a
support rope scaling
milocress 8c21a1f
Merge branch 'main' into milo/support-llama3-rope
milocress aa4736f
use rope scaling
milocress 853bfa6
Merge branch 'main' into milo/support-llama3-rope
milocress 1e7d8d3
update to use rope config
milocress 2b4f21f
merged
milocress a032676
update config args
milocress 0d190a2
use allowlist for config to enforce hygeine
milocress 5c24b70
Merge branch 'main' into milo/support-llama3-rope
dakinggg 604f0b9
allow llama3 rope config
milocress 77fd401
merged
milocress 0682283
add unit test
milocress ff0d3de
documented allowed llama config keys
milocress ef6c8c2
Update llmfoundry/models/mpt/modeling_mpt.py
dakinggg dd1de37
Address comments 1
milocress 1515708
Apply suggestions from code review
milocress 8da6165
Apply suggestions from code review
milocress b0700c9
use same codepath for all the hf rotary embeddings
milocress 518a3a1
fix
milocress 44ce115
update
milocress d297395
test WIP but fix get/pop
milocress cf460ec
change the thing being popped
milocress d48bcb4
give up on testing hf
milocress c4051bc
Merge branch 'main' into milo/support-llama3-rope
milocress File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# Copyright 2024 MosaicML LLM Foundry authors | ||
# SPDX-License-Identifier: Apache-2.0 | ||
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding | ||
milocress marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding | ||
|
||
rope_config = { | ||
'rope_theta': 500000.0, | ||
'rope_impl': 'hf', | ||
'rope_hf_config': { | ||
'factor': 8.0, | ||
'low_freq_factor': 1.0, | ||
'high_freq_factor': 4.0, | ||
'original_max_position_embeddings': 8192, | ||
'type': 'llama3', | ||
}, | ||
} | ||
|
||
rope_dail_config = {} | ||
dakinggg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
def test_rope_scaling(): | ||
d_model = 128 | ||
n_heads = 32 | ||
max_seq_len = 65536 | ||
|
||
embedding = gen_rotary_embedding( | ||
d_model=d_model, | ||
n_heads=n_heads, | ||
rope_dail_config=rope_dail_config, | ||
max_seq_len=max_seq_len, | ||
**rope_config, | ||
) | ||
|
||
assert isinstance(embedding, LlamaRotaryEmbedding) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.