Skip to content

Make AlpacaToMessages public. #1785

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/api_ref_data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ Converts data from common schema and conversation JSON formats into a list of to
ShareGPTToMessages
OpenAIToMessages
ChosenRejectedToMessages
AlpacaToMessages

Collaters
---------
Expand Down
2 changes: 2 additions & 0 deletions torchtune/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from torchtune.data._converters import get_openai_messages, get_sharegpt_messages
from torchtune.data._instruct_templates import InstructTemplate
from torchtune.data._messages import (
AlpacaToMessages,
ChosenRejectedToMessages,
InputOutputToMessages,
Message,
Expand All @@ -43,6 +44,7 @@
"SummarizeTemplate",
"OpenAIToMessages",
"ShareGPTToMessages",
"AlpacaToMessages",
"truncate",
"Message",
"validate_messages",
Expand Down
67 changes: 67 additions & 0 deletions torchtune/data/_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,3 +621,70 @@ def validate_messages(
f"System message at index {i} in messages, but system messages must come first"
)
last_turn = message.role


class AlpacaToMessages(Transform):
    """
    Turn an Alpaca-style sample into a two-turn conversation.

    The sample is expected to carry "instruction", "input", and "output" columns
    (or whatever names ``column_map`` redirects them to). The user turn is built
    by filling one of two prompt templates: the variant with an ``### Input:``
    section is used when the input column exists in the sample and holds a
    truthy value, otherwise the instruction-only variant is used. The assistant
    turn is the raw content of the output column. Because the template choice
    depends on the sample itself, the templating lives in this transform rather
    than in a separate :class:`~torchtune.data.PromptTemplate`.

    Args:
        train_on_input (bool): If False, the user message is marked as masked;
            the assistant message is never masked. Default is True.
        column_map (Optional[Dict[str, str]]): maps the expected column names
            "instruction", "input", and "output" to the names actually present
            in the dataset. Keys that are omitted fall back to the expected name.
            Default is None, which keeps all expected names.
    """

    def __init__(
        self, train_on_input: bool = True, column_map: Optional[Dict[str, str]] = None
    ):
        # Prompt used when the sample provides a non-empty input field.
        with_input = (
            "Below is an instruction that describes a task, paired with an input that provides further context. "
            "Write a response that appropriately completes the request.\n\n"
            "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
        )
        # Prompt used when there is only an instruction.
        without_input = (
            "Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request.\n\n"
            "### Instruction:\n{instruction}\n\n### Response:\n"
        )

        self.train_on_input = train_on_input
        self.column_map = column_map
        self.template = {
            "prompt_input": with_input,
            "prompt_no_input": without_input,
        }

    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        # Resolve the actual column names; unmapped names keep their defaults.
        names = self.column_map or {}
        input_col = names.get("input", "input")
        instruction_col = names.get("instruction", "instruction")
        output_col = names.get("output", "output")

        # A missing or falsy input column selects the instruction-only prompt.
        has_input = input_col in sample and bool(sample[input_col])
        if not has_input:
            user_text = self.template["prompt_no_input"].format(
                instruction=sample[instruction_col]
            )
        else:
            user_text = self.template["prompt_input"].format(
                instruction=sample[instruction_col], input=sample[input_col]
            )

        user_turn = Message(
            role="user",
            content=user_text,
            masked=not self.train_on_input,
            eot=True,
        )
        assistant_turn = Message(
            role="assistant",
            content=sample[output_col],
            masked=False,
            eot=True,
        )
        return {"messages": [user_turn, assistant_turn]}
73 changes: 3 additions & 70 deletions torchtune/datasets/_alpaca.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,80 +6,13 @@

from functools import partial

from typing import Any, Dict, Mapping, Optional, Union
from typing import Any, Dict, Optional, Union

from torchtune.data._messages import AlpacaToMessages

from torchtune.data._messages import Message
from torchtune.datasets._packed import PackedDataset
from torchtune.datasets._sft import SFTDataset
from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.modules.transforms import Transform


class AlpacaToMessages(Transform):
    """
    Turn an Alpaca-style sample into a two-turn conversation.

    The sample is expected to carry "instruction", "input", and "output" columns
    (or whatever names ``column_map`` redirects them to). The user turn is built
    by filling one of two prompt templates: the variant with an ``### Input:``
    section is used when the input column exists in the sample and holds a
    truthy value, otherwise the instruction-only variant is used. The assistant
    turn is the raw content of the output column. Because the template choice
    depends on the sample itself, the templating lives in this transform rather
    than in a separate :class:`~torchtune.data.PromptTemplate`.

    Args:
        train_on_input (bool): If False, the user message is marked as masked;
            the assistant message is never masked. Default is True.
        column_map (Optional[Dict[str, str]]): maps the expected column names
            "instruction", "input", and "output" to the names actually present
            in the dataset. Keys that are omitted fall back to the expected name.
            Default is None, which keeps all expected names.
    """

    def __init__(
        self, train_on_input: bool = True, column_map: Optional[Dict[str, str]] = None
    ):
        # Prompt used when the sample provides a non-empty input field.
        with_input = (
            "Below is an instruction that describes a task, paired with an input that provides further context. "
            "Write a response that appropriately completes the request.\n\n"
            "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
        )
        # Prompt used when there is only an instruction.
        without_input = (
            "Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request.\n\n"
            "### Instruction:\n{instruction}\n\n### Response:\n"
        )

        self.train_on_input = train_on_input
        self.column_map = column_map
        self.template = {
            "prompt_input": with_input,
            "prompt_no_input": without_input,
        }

    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        # Resolve the actual column names; unmapped names keep their defaults.
        names = self.column_map or {}
        input_col = names.get("input", "input")
        instruction_col = names.get("instruction", "instruction")
        output_col = names.get("output", "output")

        # A missing or falsy input column selects the instruction-only prompt.
        has_input = input_col in sample and bool(sample[input_col])
        if not has_input:
            user_text = self.template["prompt_no_input"].format(
                instruction=sample[instruction_col]
            )
        else:
            user_text = self.template["prompt_input"].format(
                instruction=sample[instruction_col], input=sample[input_col]
            )

        user_turn = Message(
            role="user",
            content=user_text,
            masked=not self.train_on_input,
            eot=True,
        )
        assistant_turn = Message(
            role="assistant",
            content=sample[output_col],
            masked=False,
            eot=True,
        )
        return {"messages": [user_turn, assistant_turn]}


def alpaca_dataset(
Expand Down
Loading