
Data loader refactor #2707

Merged · 36 commits · Jun 10, 2025

Changes from all commits (36):
0cbafb8
data loading refactor (wip)
djsaunde May 20, 2025
cf4f561
updates
djsaunde May 20, 2025
f81e023
progress
djsaunde May 22, 2025
cfa7c3f
pytest
djsaunde May 22, 2025
2ec912a
pytest fix
djsaunde May 22, 2025
718b519
lint
djsaunde May 23, 2025
b744d02
zero_first -> filelock, more simplifications
djsaunde May 23, 2025
ac920a5
small simplification
djsaunde May 23, 2025
42c6b8c
import change
djsaunde May 23, 2025
234cae4
nit
djsaunde May 23, 2025
ec4783d
lint
djsaunde May 29, 2025
69f83a5
simplify dedup
djsaunde May 29, 2025
250126e
couldnt resist
djsaunde May 29, 2025
8b4e7d0
review comments WIP
djsaunde May 30, 2025
cf4fa7f
continued wip
djsaunde May 30, 2025
e697ca3
minor changes
djsaunde Jun 2, 2025
fd5ed2b
fix; remove contrived test
djsaunde Jun 3, 2025
06c6baf
further refactor
djsaunde Jun 3, 2025
f96e641
set default seed in pydantic config
djsaunde Jun 3, 2025
3d7d9d2
lint
djsaunde Jun 3, 2025
0f4243f
continued simplication
djsaunde Jun 4, 2025
29a8e27
lint
djsaunde Jun 4, 2025
8598ca3
renaming and nits
djsaunde Jun 4, 2025
e74fc59
filelock tests
djsaunde Jun 5, 2025
403551c
fix
djsaunde Jun 5, 2025
bbcc108
fix
djsaunde Jun 5, 2025
b7e01ab
lint
djsaunde Jun 5, 2025
eeaa5ee
remove nullable arg
djsaunde Jun 5, 2025
c1b7eb1
remove unnecessary code
djsaunde Jun 10, 2025
8505f17
moving dataset save fn to shared module
djsaunde Jun 10, 2025
aa34452
remove debug print
djsaunde Jun 10, 2025
148490a
matching var naming
djsaunde Jun 10, 2025
669579a
fn name change
djsaunde Jun 10, 2025
d523857
coderabbit comments
djsaunde Jun 10, 2025
aa0c3ef
naming nit
djsaunde Jun 10, 2025
9b1b33d
fix test
djsaunde Jun 10, 2025
4 changes: 1 addition & 3 deletions src/axolotl/common/const.py
@@ -1,5 +1,3 @@
"""
Various shared constants
"""
"""Various shared constants"""

DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
77 changes: 29 additions & 48 deletions src/axolotl/common/datasets.py
@@ -3,15 +3,13 @@
import math
import random
from dataclasses import dataclass
from typing import Optional, Union

from datasets import Dataset

import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
from axolotl.loaders import load_processor, load_tokenizer
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.data import prepare_datasets, prepare_preference_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType
@@ -30,16 +28,7 @@ class TrainDatasetMeta:


def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
"""
Randomly sample `num_samples` samples from `dataset`.

Args:
dataset: Dataset.
num_samples: Number of samples to return.

Returns:
Random sample (with replacement) of examples in `dataset`.
"""
"""Randomly sample `num_samples` samples with replacement from `dataset`."""
return dataset.select(
[random.randrange(0, len(dataset) - 1) for _ in range(num_samples)] # nosec
)
@@ -51,44 +40,37 @@ def load_datasets(
cli_args: PreprocessCliArgs | TrainerCliArgs | None = None,
debug: bool = False,
) -> TrainDatasetMeta:
"""
Loads one or more training or evaluation datasets, calling
`axolotl.utils.data.prepare_dataset`. Optionally, logs out debug information.
"""Loads one or more training or evaluation datasets, calling
`axolotl.utils.data.prepare_datasets`. Optionally, logs out debug information.

Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Command-specific CLI arguments.
debug: Whether to print out tokenization of sample
debug: Whether to print out tokenization of sample. This is duplicated in
`cfg` and `cli_args`, but is kept due to use in our Colab notebooks.

Returns:
Dataclass with fields for training and evaluation datasets and the computed
`total_num_steps`.
`total_num_steps`.
"""
tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
preprocess_iterable = (
cli_args
and hasattr(cli_args, "iterable")
and cli_args.iterable is not None
and cli_args.iterable
)
preprocess_iterable = getattr(cli_args, "iterable", False)

train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
train_dataset, eval_dataset, total_num_steps, prompters = prepare_datasets(
cfg,
tokenizer,
processor=processor,
preprocess_iterable=preprocess_iterable,
)

if ( # pylint: disable=too-many-boolean-expressions
cli_args
and (
cli_args.debug
or cfg.debug
or cli_args.debug_text_only
or int(cli_args.debug_num_examples) > 0
)
) or debug:
if (
cfg.debug
or getattr(cli_args, "debug", False)
or getattr(cli_args, "debug_text_only", False)
or getattr(cli_args, "debug_num_examples", 0) > 0
or debug
):
LOG.info("check_dataset_labels...")

num_examples = cli_args.debug_num_examples if cli_args else 1
@@ -113,13 +95,10 @@ def load_datasets(


def load_preference_datasets(
*,
cfg: DictDefault,
cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
*, cfg: DictDefault, cli_args: PreprocessCliArgs | TrainerCliArgs
) -> TrainDatasetMeta:
"""
Loads one or more training or evaluation datasets for RL training using paired
preference data, calling `axolotl.utils.data.rl.load_prepare_preference_datasets`.
"""Loads one or more training or evaluation datasets for RL training using paired
preference data, calling `axolotl.utils.data.rl.prepare_preference_datasets`.
Optionally, logs out debug information.

Args:
@@ -130,21 +109,23 @@ def load_preference_datasets(
Dataclass with fields for training and evaluation datasets and the computed
`total_num_steps`.
"""
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
total_num_steps: Optional[int] = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
if cfg.rl is RLType.GRPO:
total_num_steps = None
tokenizer = load_tokenizer(cfg)
train_dataset, eval_dataset = prepare_preference_datasets(cfg, tokenizer)

total_num_steps: int | None = None
if cfg.rl is not RLType.GRPO:
total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)

if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...")

tokenizer = load_tokenizer(cfg)
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
check_dataset_labels(
train_samples,
tokenizer,
dataset=train_samples,
tokenizer=tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
rl_mode=True,
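For context on the non-GRPO step count in `load_preference_datasets` above, here is a minimal sketch of the arithmetic, using hypothetical dataset, epoch, and batch sizes that are not taken from this PR:

```python
import math

# Hypothetical values, for illustration only.
num_examples = 10_000  # len(train_dataset)
num_epochs = 3         # cfg.num_epochs
batch_size = 8         # cfg.batch_size

# Same expression as in load_preference_datasets above.
total_num_steps = int(math.ceil(num_examples * num_epochs / batch_size))
print(total_num_steps)  # 3750
```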
2 changes: 1 addition & 1 deletion src/axolotl/core/builders/causal.py
@@ -381,7 +381,7 @@ def build(self, total_num_steps):
elif "tokenizer" in sig.parameters:
trainer_kwargs["tokenizer"] = self.tokenizer
if (
not (trainer_cls in [AxolotlRewardTrainer, AxolotlPRMTrainer])
trainer_cls not in [AxolotlRewardTrainer, AxolotlPRMTrainer]
and self.cfg.datasets is not None
):
trainer_kwargs["dataset_tags"] = [
43 changes: 23 additions & 20 deletions src/axolotl/datasets.py
@@ -1,7 +1,6 @@
"""Module containing Dataset functionality"""

import os
from typing import List, Optional, Union

import torch
from datasets import Dataset, IterableDataset
@@ -20,21 +19,21 @@


class TokenizedPromptDataset(Dataset):
"""
Dataset that returns tokenized prompts from a stream of text files.
Args:
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data.
dataset (dataset.Dataset): Dataset with text files.
process_count (int): Number of processes to use for tokenizing.
keep_in_memory (bool): Whether to keep the tokenized dataset in memory.
"""Dataset that returns tokenized prompts from a stream of text files.

Args:
prompt_tokenizer: The prompt tokenizing method for processing the data.
dataset: Dataset with text files.
process_count: Number of processes to use for tokenizing.
keep_in_memory: Whether to keep the tokenized dataset in memory.
"""

def __init__( # pylint: disable=super-init-not-called
self,
prompt_tokenizer: PromptTokenizingStrategy,
dataset: Dataset,
process_count: Optional[int] = None,
keep_in_memory: Optional[bool] = False,
process_count: int | None = None,
keep_in_memory: bool | None = False,
**kwargs,
):
self.prompt_tokenizer = prompt_tokenizer
@@ -76,14 +75,14 @@ def process(self, dataset):

def wrap_dataset_for_tokenized_prompt(
prompt_tokenizer: PromptTokenizingStrategy,
dataset: Union[Dataset, IterableDataset],
dataset: Dataset | IterableDataset,
**kwargs,
):
if isinstance(dataset, IterableDataset):
map_kwargs = {}
if prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
features = dataset.features.keys()
features = list(dataset.features.keys())
return dataset.map(
prompt_tokenizer.tokenize_prompt,
remove_columns=features,
@@ -94,12 +93,13 @@ def wrap_dataset_for_tokenized_prompt(

# TODO this isn't the best since it can't interleave datasets
class ConstantLengthDataset(IterableDataset):
Review comment (Collaborator): I don't know that we actually use this anywhere; it's a good candidate for pruning.

"""
Iterable dataset that returns constant length chunks of tokens from stream of text files.
Args:
tokenizer (Tokenizer): The processor used for processing the data.
dataset (dataset.Dataset): Dataset with text files.
seq_length (int): Length of token sequences to return.
"""Iterable dataset that returns constant length chunks of tokens from stream of
text files.

Args:
tokenizer: The processor used for processing the data.
dataset: Dataset with text files.
seq_length: Length of token sequences to return.
"""

def __init__( # pylint: disable=super-init-not-called
Expand All @@ -110,7 +110,7 @@ def __init__( # pylint: disable=super-init-not-called
):
self.tokenizer = tokenizer
self.concat_token_id = tokenizer.eos_token_id
self.datasets: List[IterableDataset] = datasets
self.datasets: list[IterableDataset] = datasets
self.seq_length = seq_length

vocab_size = len(tokenizer.get_vocab())
@@ -174,7 +174,10 @@ def __iter__(self):
}
else:
LOG.warning(
f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
"Dropping batch due to tensor size mismatch "
f"input_ids: {input_ids.size()}, "
f"labels: {labels.size()}, "
f"attention_mask: {attention_mask.size()}"
)
buffer = {
"input_ids": [],
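A small but easy-to-miss fix above: for the `IterableDataset` path in `wrap_dataset_for_tokenized_prompt`, the column names are materialized with `list(dataset.features.keys())` before being passed to `remove_columns`. Below is a rough, self-contained sketch of that pattern; the toy data and the stand-in tokenize function are assumptions, and only the map/remove_columns usage mirrors the diff.

```python
from datasets import Dataset

# Toy text dataset converted to an iterable dataset (hypothetical data).
ds = Dataset.from_dict({"text": ["hello world", "axolotl"]}).to_iterable_dataset()


def toy_tokenize(example):
    # Stand-in for prompt_tokenizer.tokenize_prompt.
    return {"input_ids": [ord(c) for c in example["text"]]}


# Materialize the column names before handing them to remove_columns.
features = list(ds.features.keys()) if ds.features else None
tokenized = ds.map(toy_tokenize, remove_columns=features)

print(next(iter(tokenized)))  # {'input_ids': [...]}; the original 'text' column is dropped
```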
9 changes: 6 additions & 3 deletions src/axolotl/loaders/tokenizer.py
@@ -7,12 +7,14 @@
from transformers import (
AddedToken,
AutoTokenizer,
PreTrainedTokenizer,
)

from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import get_linear_embedding_layers, load_model_config
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import (
barrier,
is_local_main_process,
@@ -117,7 +119,7 @@ def modify_tokenizer_files(
return tokenizer_dir


def load_tokenizer(cfg):
def load_tokenizer(cfg: DictDefault) -> PreTrainedTokenizer:
"""Load and configure the tokenizer based on the provided config."""
model_config = load_model_config(cfg)
tokenizer_kwargs = {}
@@ -207,11 +209,12 @@ def load_tokenizer(cfg):
)
and k != "pad_token"
):
lora_modules_to_save = ", ".join(
lora_modules_to_save_str = ", ".join(
[f"`{x}`" for x in lora_modules_to_save]
)
raise ValueError(
f"Please set lora_modules_to_save to [{lora_modules_to_save}] when using an adapter and changing the special tokens."
f"Please set lora_modules_to_save to [{lora_modules_to_save_str}] "
"when using an adapter and changing the special tokens."
)

tokenizer.add_special_tokens(
1 change: 0 additions & 1 deletion src/axolotl/prompt_strategies/messages/__init__.py
@@ -32,4 +32,3 @@ def load(tokenizer, cfg, ds_cfg, processor=None):
except Exception as exc: # pylint: disable=broad-exception-caught
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
raise exc
return None
11 changes: 11 additions & 0 deletions src/axolotl/prompt_tokenizers.py
@@ -3,6 +3,7 @@
import abc
from typing import Callable, Dict, List, Optional, Tuple, Union

from datasets import Dataset
from transformers import BatchEncoding, PreTrainedTokenizer

from axolotl.prompters import Prompter
@@ -28,6 +29,16 @@ class DatasetWrappingStrategy(abc.ABC):
Abstract class for wrapping datasets for Chat Messages
"""

@abc.abstractmethod
def wrap_dataset(
self,
dataset,
process_count: int | None = None,
keep_in_memory: bool | None = False,
**kwargs,
) -> Dataset:
pass


class PromptTokenizingStrategy(abc.ABC):
"""
4 changes: 2 additions & 2 deletions src/axolotl/train.py
@@ -53,8 +53,8 @@ def setup_model_and_tokenizer(
) -> tuple[
PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, ProcessorMixin | None
]:
"""
Load the tokenizer, processor (for multimodal models), and model based on configuration.
"""Load the tokenizer, processor (for multimodal models), and model based on
configuration.

Args:
cfg: Dictionary mapping `axolotl` config keys to values.
25 changes: 15 additions & 10 deletions src/axolotl/utils/data/__init__.py
@@ -1,16 +1,21 @@
"""
Data processing modules
"""
"""Init for `axolotl.utils.data` module."""

from axolotl.utils.data.pretraining import ( # noqa: F401
from axolotl.utils.data.pretraining import (
encode_pretraining,
wrap_pretraining_dataset,
)
from axolotl.utils.data.rl import load_prepare_preference_datasets # noqa: F401
from axolotl.utils.data.sft import ( # noqa: F401
from axolotl.utils.data.rl import prepare_preference_datasets
from axolotl.utils.data.sft import (
get_dataset_wrapper,
load_prepare_datasets,
load_tokenized_prepared_datasets,
prepare_dataset,
prepare_datasets,
)
from axolotl.utils.data.utils import md5 # noqa: F401
from axolotl.utils.data.utils import md5

__all__ = [
"encode_pretraining",
"wrap_pretraining_dataset",
"prepare_preference_datasets",
"get_dataset_wrapper",
"prepare_datasets",
"md5",
]
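As a rough usage sketch of the renamed entry points exported here: the signatures and return shapes come from the call sites in `common/datasets.py` earlier in this diff, while the helper functions and the fully populated `cfg` are assumptions.

```python
from axolotl.loaders import load_tokenizer
from axolotl.utils.data import prepare_datasets, prepare_preference_datasets
from axolotl.utils.dict import DictDefault


def load_sft_datasets(cfg: DictDefault):
    """Hypothetical helper; `cfg` is assumed to be a complete axolotl config."""
    tokenizer = load_tokenizer(cfg)
    # Return tuple mirrors the call site in load_datasets above.
    train_ds, eval_ds, total_num_steps, prompters = prepare_datasets(
        cfg, tokenizer, processor=None, preprocess_iterable=False
    )
    return train_ds, eval_ds, total_num_steps, prompters


def load_rl_datasets(cfg: DictDefault):
    """Hypothetical helper; paired-preference path, per load_preference_datasets."""
    tokenizer = load_tokenizer(cfg)
    return prepare_preference_datasets(cfg, tokenizer)
```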