diff --git a/torchtune/data/_prompt_templates.py b/torchtune/data/_prompt_templates.py index 67167b3ed9..a7fa070a2e 100644 --- a/torchtune/data/_prompt_templates.py +++ b/torchtune/data/_prompt_templates.py @@ -107,16 +107,17 @@ def __call__( """ formatted_dialogue = [] for message in messages: + content = message.content if message.role in self.template: prepend_tag = self.template[message.role][0] append_tag = self.template[message.role][1] - content = ( - [{"type": "text", "content": prepend_tag}] - + message.content - + [{"type": "text", "content": append_tag}] - ) - else: content = message.content + + if isinstance(prepend_tag, str) and len(prepend_tag) > 0: + content = [{"type": "text", "content": prepend_tag}] + content + + if isinstance(append_tag, str) and len(append_tag) > 0: + content = content + [{"type": "text", "content": append_tag}] formatted_dialogue.append( Message( role=message.role, @@ -183,13 +184,20 @@ def __call__( and index == len(messages) - 1 and len(message.text_content) == 0 ): - content = [{"type": "text", "content": prepend_tag}] + message.content + content = message.content + if isinstance(prepend_tag, str) and len(prepend_tag) > 0: + content = [ + {"type": "text", "content": prepend_tag} + ] + message.content else: - content = ( - [{"type": "text", "content": prepend_tag}] - + message.content - + [{"type": "text", "content": append_tag}] - ) + content = message.content + + if isinstance(prepend_tag, str) and len(prepend_tag) > 0: + content = [{"type": "text", "content": prepend_tag}] + content + + if isinstance(append_tag, str) and len(append_tag) > 0: + content = content + [{"type": "text", "content": append_tag}] + formatted_dialogue.append( Message( role=message.role,