Skip to content

Commit 19cacbe

Browse files
authored
Consistent type checks for prepend and append tags. (#1824)
1 parent c70ad29 commit 19cacbe

File tree

1 file changed

+20
-12
lines changed

1 file changed

+20
-12
lines changed

torchtune/data/_prompt_templates.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -107,16 +107,17 @@ def __call__(
107107
"""
108108
formatted_dialogue = []
109109
for message in messages:
110+
content = message.content
110111
if message.role in self.template:
111112
prepend_tag = self.template[message.role][0]
112113
append_tag = self.template[message.role][1]
113-
content = (
114-
[{"type": "text", "content": prepend_tag}]
115-
+ message.content
116-
+ [{"type": "text", "content": append_tag}]
117-
)
118-
else:
119114
content = message.content
115+
116+
if isinstance(prepend_tag, str) and len(prepend_tag) > 0:
117+
content = [{"type": "text", "content": prepend_tag}] + content
118+
119+
if isinstance(append_tag, str) and len(append_tag) > 0:
120+
content = content + [{"type": "text", "content": append_tag}]
120121
formatted_dialogue.append(
121122
Message(
122123
role=message.role,
@@ -183,13 +184,20 @@ def __call__(
183184
and index == len(messages) - 1
184185
and len(message.text_content) == 0
185186
):
186-
content = [{"type": "text", "content": prepend_tag}] + message.content
187+
content = message.content
188+
if isinstance(prepend_tag, str) and len(prepend_tag) > 0:
189+
content = [
190+
{"type": "text", "content": prepend_tag}
191+
] + message.content
187192
else:
188-
content = (
189-
[{"type": "text", "content": prepend_tag}]
190-
+ message.content
191-
+ [{"type": "text", "content": append_tag}]
192-
)
193+
content = message.content
194+
195+
if isinstance(prepend_tag, str) and len(prepend_tag) > 0:
196+
content = [{"type": "text", "content": prepend_tag}] + content
197+
198+
if isinstance(append_tag, str) and len(append_tag) > 0:
199+
content = content + [{"type": "text", "content": append_tag}]
200+
193201
formatted_dialogue.append(
194202
Message(
195203
role=message.role,

0 commit comments

Comments
 (0)