Commit c8425ad

krammnic authored and mori360 committed
Make AlpacaToMessage public. (pytorch#1785)
1 parent ea7080b commit c8425ad
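
In effect, the transform that previously lived only in torchtune/datasets/_alpaca.py is now re-exported from the package root. A minimal sketch, assuming a torchtune build that includes this commit:

# After this commit, AlpacaToMessages is part of torchtune.data's public API.
from torchtune.data import AlpacaToMessages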

File tree: 4 files changed, +73 −70 lines

docs/source/api_ref_data.rst

Lines changed: 1 addition & 0 deletions
@@ -64,6 +64,7 @@ Converts data from common schema and conversation JSON formats into a list of to
     ShareGPTToMessages
     OpenAIToMessages
     ChosenRejectedToMessages
+    AlpacaToMessages
 
 Collaters
 ---------

torchtune/data/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -17,6 +17,7 @@
 from torchtune.data._converters import get_openai_messages, get_sharegpt_messages
 from torchtune.data._instruct_templates import InstructTemplate
 from torchtune.data._messages import (
+    AlpacaToMessages,
     ChosenRejectedToMessages,
     InputOutputToMessages,
     Message,
@@ -43,6 +44,7 @@
     "SummarizeTemplate",
     "OpenAIToMessages",
     "ShareGPTToMessages",
+    "AlpacaToMessages",
     "truncate",
     "Message",
     "validate_messages",

torchtune/data/_messages.py

Lines changed: 67 additions & 0 deletions
@@ -621,3 +621,70 @@ def validate_messages(
                 f"System message at index {i} in messages, but system messages must come first"
             )
         last_turn = message.role
+
+
+class AlpacaToMessages(Transform):
+    """
+    Message transform class for Alpaca-style datasets with "instruction", "input", and "output"
+    (or equivalent fields specified in column_map) columns. User messages are formed from the
+    instruction + input columns and assistant messages are formed from the output column. Prompt
+    templating is conditional on the presence of the "input" column, and thus is handled directly
+    in this transform class instead of a dedicated :class:`~torchtune.data.PromptTemplate` class
+    due to this custom logic.
+
+    Args:
+        train_on_input (bool): Whether the model is trained on the user prompt or not.
+            Default is True.
+        column_map (Optional[Dict[str, str]]): a mapping to change the expected "instruction", "input",
+            and "output" column names to the actual column names in the dataset. Default is None,
+            keeping the default column names.
+    """
+
+    def __init__(
+        self, train_on_input: bool = True, column_map: Optional[Dict[str, str]] = None
+    ):
+        self.train_on_input = train_on_input
+        self.column_map = column_map
+        self.template = {
+            "prompt_input": (
+                "Below is an instruction that describes a task, paired with an input that provides further context. "
+                "Write a response that appropriately completes the request.\n\n"
+                "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
+            ),
+            "prompt_no_input": (
+                "Below is an instruction that describes a task. "
+                "Write a response that appropriately completes the request.\n\n"
+                "### Instruction:\n{instruction}\n\n### Response:\n"
+            ),
+        }
+
+    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
+        column_map = self.column_map or {}
+        key_input = column_map.get("input", "input")
+        key_instruction = column_map.get("instruction", "instruction")
+        key_output = column_map.get("output", "output")
+
+        if key_input in sample and sample[key_input]:
+            prompt = self.template["prompt_input"].format(
+                instruction=sample[key_instruction], input=sample[key_input]
+            )
+        else:
+            prompt = self.template["prompt_no_input"].format(
+                instruction=sample[key_instruction]
+            )
+
+        messages = [
+            Message(
+                role="user",
+                content=prompt,
+                masked=not self.train_on_input,
+                eot=True,
+            ),
+            Message(
+                role="assistant",
+                content=sample[key_output],
+                masked=False,
+                eot=True,
+            ),
+        ]
+        return {"messages": messages}
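
For context, a rough usage sketch of the transform added above. The sample values below are made up for illustration and are not part of this commit; the import path is the one made public here:

from torchtune.data import AlpacaToMessages

# Hypothetical Alpaca-style record using the default column names
# ("instruction", "input", "output").
sample = {
    "instruction": "Summarize the following text.",
    "input": "torchtune is a PyTorch library for fine-tuning LLMs.",
    "output": "A PyTorch library for LLM fine-tuning.",
}

# train_on_input=False masks the templated user prompt out of the loss.
transform = AlpacaToMessages(train_on_input=False)
messages = transform(sample)["messages"]
# messages[0] is the templated user prompt (masked=True),
# messages[1] is the assistant completion taken from the "output" column.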

torchtune/datasets/_alpaca.py

Lines changed: 3 additions & 70 deletions
@@ -6,80 +6,13 @@
 
 from functools import partial
 
-from typing import Any, Dict, Mapping, Optional, Union
+from typing import Any, Dict, Optional, Union
+
+from torchtune.data._messages import AlpacaToMessages
 
-from torchtune.data._messages import Message
 from torchtune.datasets._packed import PackedDataset
 from torchtune.datasets._sft import SFTDataset
 from torchtune.modules.tokenizers import ModelTokenizer
-from torchtune.modules.transforms import Transform
-
-
-class AlpacaToMessages(Transform):
-    """
-    Message transform class for Alpaca-style datasets with "instruction", "input", and "output"
-    (or equivalent fields specified in column_map) columns. User messages are formed from the
-    instruction + input columns and assistant messages are formed from the output column. Prompt
-    templating is conditional on the presence of the "input" column, and thus is handled directly
-    in this transform class instead of a dedicated :class:`~torchtune.data.PromptTemplate` class
-    due to this custom logic.
-
-    Args:
-        train_on_input (bool): Whether the model is trained on the user prompt or not.
-            Default is True.
-        column_map (Optional[Dict[str, str]]): a mapping to change the expected "instruction", "input",
-            and "output" column names to the actual column names in the dataset. Default is None,
-            keeping the default column names.
-    """
-
-    def __init__(
-        self, train_on_input: bool = True, column_map: Optional[Dict[str, str]] = None
-    ):
-        self.train_on_input = train_on_input
-        self.column_map = column_map
-        self.template = {
-            "prompt_input": (
-                "Below is an instruction that describes a task, paired with an input that provides further context. "
-                "Write a response that appropriately completes the request.\n\n"
-                "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
-            ),
-            "prompt_no_input": (
-                "Below is an instruction that describes a task. "
-                "Write a response that appropriately completes the request.\n\n"
-                "### Instruction:\n{instruction}\n\n### Response:\n"
-            ),
-        }
-
-    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
-        column_map = self.column_map or {}
-        key_input = column_map.get("input", "input")
-        key_instruction = column_map.get("instruction", "instruction")
-        key_output = column_map.get("output", "output")
-
-        if key_input in sample and sample[key_input]:
-            prompt = self.template["prompt_input"].format(
-                instruction=sample[key_instruction], input=sample[key_input]
-            )
-        else:
-            prompt = self.template["prompt_no_input"].format(
-                instruction=sample[key_instruction]
-            )
-
-        messages = [
-            Message(
-                role="user",
-                content=prompt,
-                masked=not self.train_on_input,
-                eot=True,
-            ),
-            Message(
-                role="assistant",
-                content=sample[key_output],
-                masked=False,
-                eot=True,
-            ),
-        ]
-        return {"messages": messages}
 
 
 def alpaca_dataset(
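
With the class definition removed from this file, the dataset builder can consume the shared transform instead. Below is a rough sketch of that wiring under stated assumptions: the builder name is hypothetical, and the SFTDataset keyword arguments and the "tatsu-lab/alpaca" source are assumptions not taken from this diff:

from torchtune.data import AlpacaToMessages
from torchtune.datasets import SFTDataset

# Assumed wiring: build the message transform once and hand it to the
# generic SFT dataset alongside the tokenizer (used as the model transform).
def build_alpaca_style_dataset(
    tokenizer, source="tatsu-lab/alpaca", train_on_input=True, column_map=None
):
    message_transform = AlpacaToMessages(
        train_on_input=train_on_input, column_map=column_map
    )
    return SFTDataset(
        source=source,
        message_transform=message_transform,
        model_transform=tokenizer,
    )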
