Skip to content

Fix eos_token problem in all required models #1806

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 229 additions & 0 deletions tests/torchtune/models/gemma/test_gemma_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,3 +242,232 @@ def test_tokenize_messages(self, tokenizer):
expected_mask = [True] * 75 + [False] * 125
assert expected_tokens == tokens
assert expected_mask == mask

def test_tokenize_messages_drop_eos(self, tokenizer):
    """Check ``tokenize_messages(..., add_eos=False)`` against a fixed fixture.

    A two-message conversation (a masked user prompt followed by an
    assistant reply) is tokenized with ``add_eos=False``. The expected ids
    are the reference sequence with its final element (the EOS id) removed,
    and the expected mask is 75 ``True`` entries followed by 124 ``False``
    entries.
    """
    instruction_text = (
        "Below is an instruction that describes a task. Write a response "
        "that appropriately completes the request.\n\n### Instruction:\nGenerate "
        "a realistic dating profile bio.\n\n### Response:\n"
    )
    reply_text = (
        "I'm an outgoing and friendly person who loves spending time with "
        "friends and family. I'm also a big-time foodie and love trying out new "
        "restaurants and different cuisines. I'm a big fan of the arts and enjoy "
        "going to museums and galleries. I'm looking for someone who shares my "
        "interest in exploring new places, as well as someone who appreciates a "
        "good conversation over coffee."
    )
    conversation = [
        Message(role="user", content=instruction_text, masked=True),
        Message(role="assistant", content=reply_text),
    ]

    actual_tokens, actual_mask = tokenizer.tokenize_messages(
        conversation, add_eos=False
    )

    # Reference ids including the trailing EOS id (last element); 10 per row.
    # fmt: off
    tokens_with_eos = [
        1, 323, 418, 202, 31, 128, 15, 120, 47, 88,
        584, 23, 1665, 182, 9, 434, 295, 85, 4, 780,
        47, 636, 9, 1094, 213, 23, 9, 69, 69, 164,
        1153, 299, 35, 961, 132, 237, 7, 5, 761, 4,
        12, 0, 313, 120, 47, 88, 584, 166, 493, 171,
        54, 299, 9, 906, 244, 19, 186, 767, 303, 671,
        92, 209, 24, 190, 52, 38, 4, 12, 0, 1243,
        7, 69, 135, 213, 166, 6, 21, 45, 128, 71,
        58, 38, 14, 10, 652, 35, 462, 101, 1306, 7,
        341, 171, 20, 14, 127, 26, 652, 7, 10, 1268,
        4, 6, 21, 45, 591, 9, 566, 22, 994, 913,
        38, 20, 52, 24, 10, 1306, 734, 14, 71, 365,
        1382, 7, 10, 801, 105, 88, 244, 985, 7, 4,
        6, 21, 45, 9, 566, 126, 180, 11, 5, 1137,
        7, 10, 1089, 151, 8, 1156, 213, 342, 7, 10,
        384, 104, 54, 470, 4, 6, 21, 45, 287, 14,
        33, 125, 135, 24, 101, 512, 66, 7, 28, 822,
        15, 542, 69, 59, 110, 14, 365, 229, 7, 3,
        36, 267, 36, 125, 135, 24, 101, 1503, 182, 9,
        222, 1661, 191, 332, 92, 92, 24, 24, 4, 2,
    ]
    # fmt: on

    # With add_eos=False the output should lack the final EOS id.
    expected_tokens = tokens_with_eos[:-1]
    # One entry fewer than the mask expected when EOS is appended.
    expected_mask = [True] * 75 + [False] * 124

    assert expected_tokens == actual_tokens
    assert expected_mask == actual_mask
Loading
Loading