
[Draft] Token-weighted datasets: Control up/down-sampling of multiple datasets #2794


Status: Draft · wants to merge 14 commits into base: main
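
For orientation, each dataset's weight is interpreted as the share of tokens that dataset should contribute to the merged corpus, measured against the combined token count of all configured datasets (see _merge_datasets_with_token_weighting in shared.py below). A rough sketch of the intended arithmetic, with token counts invented purely for illustration:

# Token counts invented for illustration.
tokens_a, tokens_b = 900_000, 100_000
total_tokens = tokens_a + tokens_b    # 1,000,000 combined tokens
target_a = int(0.7 * total_tokens)    # 700,000 -> A is downsampled from 900,000
target_b = int(0.3 * total_tokens)    # 300,000 -> B is upsampled (~3x) from 100,000
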
2 changes: 1 addition & 1 deletion src/axolotl/utils/data/rl.py
@@ -243,7 +243,7 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
LOG.warning(f"Dropped {dropped} long samples from dataset index {i}")

# Merge datasets
-    dataset = merge_datasets(split_datasets, cfg)
+    dataset = merge_datasets(split_datasets, cfg, datasets_configs)

if not cfg.skip_prepare_dataset:
# Save preprocessed dataset
2 changes: 1 addition & 1 deletion src/axolotl/utils/data/sft.py
@@ -331,7 +331,7 @@ def _load_raw_datasets(
prompters.append(dataset_prompter)

# Merge datasets
-    dataset = merge_datasets(datasets, cfg)
+    dataset = merge_datasets(datasets, cfg, datasets_configs)

if not cfg.skip_prepare_dataset:
dataset = drop_long_seq_in_dataset(dataset, cfg)
166 changes: 164 additions & 2 deletions src/axolotl/utils/data/shared.py
@@ -513,19 +513,181 @@ def generate_dataset_hash_from_config(
return str(md5(config_str))


-def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
-    """Merge multiple datasets into one with optional shuffling.
def _count_tokens(ds: Dataset, sample_size: int = 2048) -> int:
"""
Return the *exact* number of tokens if the dataset is small enough,
otherwise estimate it from a random sample (saves RAM for huge corpora).
"""
if len(ds) <= sample_size:
return sum(len(ids) for ids in ds["input_ids"])

sample = ds.shuffle(seed=42).select(range(sample_size))
avg_len = sum(len(ids) for ids in sample["input_ids"]) / sample_size
return int(avg_len * len(ds))
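
# Illustration (not part of this diff): for a 1,000,000-row corpus the count is
# estimated from 2,048 shuffled rows; an average of 512 tokens per row yields
# int(512 * 1_000_000), i.e. roughly 512M tokens, without reading input_ids for
# every row.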


def _has_token_weighting(datasets_configs) -> bool:
"""Check if any dataset has non-default weight or weight_strategy."""
for d_cfg in datasets_configs:
weight = getattr(d_cfg, "weight", 1.0)
strategy = getattr(d_cfg, "weight_strategy", "upsample")
if weight != 1.0 or strategy != "upsample":
return True
return False


def _validate_weights(datasets_configs) -> None:
"""Validate that weights are between 0.0-1.0 and sum to 1.0."""
weights = []
for d_cfg in datasets_configs:
weight = getattr(d_cfg, "weight", 1.0)
if not 0.0 <= weight <= 1.0:
raise ValueError(
f"Dataset weight must be between 0.0 and 1.0, got {weight} "
f"for dataset {getattr(d_cfg, 'path', '<unknown>')}"
)
weights.append(weight)

weight_sum = sum(weights)
if abs(weight_sum - 1.0) > 1e-6: # Allow for small floating point errors
raise ValueError(
f"Dataset weights must sum to 1.0, got {weight_sum}. " f"Weights: {weights}"
)


def _merge_datasets_with_token_weighting(
datasets: list[Dataset],
datasets_configs: list,
cfg: DictDefault,
) -> Dataset:
"""
Merge several HF datasets into one, honouring per-dataset weights *in tokens*.

Weights represent the relative proportion each dataset should contribute
to the final merged dataset in terms of tokens.
"""
from math import floor

LOG.info("Merging datasets with token-based weighting...")

_validate_weights(datasets_configs)

total_original_tokens = 0
original_token_counts = []
for i, (ds, d_cfg) in enumerate(zip(datasets, datasets_configs)):
original_tokens = _count_tokens(ds)
original_token_counts.append(original_tokens)
total_original_tokens += original_tokens
dataset_name = getattr(d_cfg, "path", f"dataset_{i}")
LOG.info(f"Dataset '{dataset_name}': {original_tokens:,} tokens ({len(ds):,} samples)")

LOG.info(f"Total original tokens across all datasets: {total_original_tokens:,}")

weighted_parts: list[Dataset] = []

for i, (ds, d_cfg) in enumerate(zip(datasets, datasets_configs)):
weight = float(getattr(d_cfg, "weight", 1.0) or 1.0)
strategy = getattr(d_cfg, "weight_strategy", "upsample").lower()
dataset_name = getattr(d_cfg, "path", f"dataset_{i}")

if weight == 1.0 and len(datasets) == 1:
weighted_parts.append(ds)
continue

tok_cnt = original_token_counts[i]
target_tok = max(1, int(weight * total_original_tokens))

if target_tok > tok_cnt:
effective_operation = "upsampling"
elif target_tok < tok_cnt:
effective_operation = "downsampling"
else:
effective_operation = "unchanged"

LOG.info(f"Dataset '{dataset_name}': {tok_cnt:,} → {target_tok:,} tokens "
f"(weight={weight:.3f}, {effective_operation})")

if strategy == "upsample":
if target_tok <= tok_cnt:
avg_len = max(1, tok_cnt // len(ds))
n_keep = max(1, min(len(ds), int(target_tok / avg_len)))
sampled = ds.shuffle(seed=cfg.seed).select(range(n_keep))
weighted_parts.append(sampled)
else:
repeats = max(1, floor(target_tok / tok_cnt))
weighted_parts.extend([ds] * repeats)

remaining_tok = target_tok - repeats * tok_cnt
if remaining_tok > 0:
avg_len = max(1, tok_cnt // len(ds))
n_extra = min(len(ds), max(1, int(remaining_tok / avg_len)))
extra = ds.shuffle(seed=cfg.seed).select(range(n_extra))
weighted_parts.append(extra)

elif strategy == "downsample":
avg_len = max(1, tok_cnt // len(ds))
n_keep = max(1, min(len(ds), int(target_tok / avg_len)))
sampled = ds.shuffle(seed=cfg.seed).select(range(n_keep))
weighted_parts.append(sampled)
else:
LOG.warning(
f"Unknown weight_strategy '{strategy}' "
f"for dataset {getattr(d_cfg, 'path', '<unknown>')}. "
"Using dataset without weighting."
)
weighted_parts.append(ds)

LOG.info("Weighted dataset parts before concatenation:")
total_weighted_tokens = 0
for i, part in enumerate(weighted_parts):
part_tokens = _count_tokens(part)
total_weighted_tokens += part_tokens
LOG.info(f" Part {i+1}: {part_tokens:,} tokens ({len(part):,} samples)")
LOG.info(f"Total tokens in weighted parts: {total_weighted_tokens:,}")

merged = concatenate_datasets(weighted_parts)

final_tokens = _count_tokens(merged)
LOG.info(f"Final merged dataset: {final_tokens:,} tokens ({len(merged):,} samples)")
LOG.info(f"Token count change: {total_original_tokens:,} → {final_tokens:,} "
f"({final_tokens/total_original_tokens:.2f}x)")

LOG.info("Final weight verification:")
for i, (ds, d_cfg) in enumerate(zip(datasets, datasets_configs)):
dataset_name = getattr(d_cfg, "path", f"dataset_{i}")
original_weight = float(getattr(d_cfg, "weight", 1.0) or 1.0)
target_tokens = max(1, int(original_weight * total_original_tokens))
actual_weight = target_tokens / final_tokens if final_tokens > 0 else 0

LOG.info(f" {dataset_name}: requested={original_weight:.3f}, "
f"achieved≈{actual_weight:.3f} ({target_tokens:,}/{final_tokens:,} tokens)")

if cfg.shuffle_merged_datasets:
merged = merged.shuffle(seed=cfg.seed)

return merged
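
# Worked example (illustration, not part of this diff): two datasets totalling
# 500,000 tokens before weighting.
#   A: 150,000 tokens over 1,000 rows, weight=0.8 -> target 400,000 tokens;
#      upsampled as floor(400_000 / 150_000) = 2 full copies (300,000 tokens)
#      plus int(100_000 / 150) = 666 extra sampled rows (~99,900 tokens).
#   B: 350,000 tokens over 2,000 rows, weight=0.2 -> target 100,000 tokens;
#      downsampled to int(100_000 / 175) = 571 rows (~99,925 tokens).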


def merge_datasets(
datasets: list[Dataset], cfg: DictDefault, datasets_configs: list | None = None
) -> Dataset:
"""Merge multiple datasets into one with optional token-based weighting.

Args:
datasets: List of datasets to merge.
cfg: Configuration object containing shuffle settings.
datasets_configs: Optional list of dataset configurations for token weighting.

Returns:
Merged dataset.
"""
if len(datasets) == 1:
return datasets[0]

# Check if token weighting should be used
if datasets_configs and _has_token_weighting(datasets_configs):
return _merge_datasets_with_token_weighting(datasets, datasets_configs, cfg)

LOG.info("Merging datasets...")
merged_dataset = concatenate_datasets(datasets)

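For reference, a minimal usage sketch of the extended merge_datasets signature. The toy datasets, the SimpleNamespace stand-ins for the pydantic dataset configs, and the DictDefault import path are assumptions for illustration, not part of this pull request:

from types import SimpleNamespace

from datasets import Dataset

from axolotl.utils.data.shared import merge_datasets
from axolotl.utils.dict import DictDefault

# Two toy pre-tokenized datasets; only input_ids is inspected for token counting.
ds_a = Dataset.from_dict({"input_ids": [[1] * 100] * 50})  # ~5,000 tokens
ds_b = Dataset.from_dict({"input_ids": [[1] * 100] * 10})  # ~1,000 tokens

# Per-dataset weights must sum to 1.0; weight_strategy is "upsample" or "downsample".
dataset_configs = [
    SimpleNamespace(path="corpus_a", weight=0.5, weight_strategy="downsample"),
    SimpleNamespace(path="corpus_b", weight=0.5, weight_strategy="upsample"),
]

cfg = DictDefault({"seed": 42, "shuffle_merged_datasets": True})

# Targets ~3,000 tokens from each source: corpus_a is cut to 30 rows,
# corpus_b is repeated 3 times.
merged = merge_datasets([ds_a, ds_b], cfg, dataset_configs)
print(len(merged))  # 60 samples
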
8 changes: 8 additions & 0 deletions src/axolotl/utils/schemas/datasets.py
@@ -60,6 +60,8 @@ class SFTDataset(BaseModel):
drop_system_message: bool | None = None
trust_remote_code: bool | None = False
revision: str | None = None
weight: float | None = 1.0
weight_strategy: str | None = "upsample"

@model_validator(mode="before")
@classmethod
@@ -127,6 +129,8 @@ class DPODataset(BaseModel):
data_files: list[str] | None = None
revision: str | None = None
field_messages: str | None = None
weight: float | None = 1.0
weight_strategy: str | None = "upsample"


class StepwiseSupervisedDataset(BaseModel):
@@ -139,6 +143,8 @@ class StepwiseSupervisedDataset(BaseModel):
step_separator: str | None = None
max_completion_length: int | None = None
train_on_last_step_only: bool | None = None
weight: float | None = 1.0
weight_strategy: str | None = "upsample"


class UserDefinedKTOType(BaseModel):
@@ -161,6 +167,8 @@ class KTODataset(BaseModel):
data_files: list[str] | None = None
trust_remote_code: bool | None = False
revision: str | None = None
weight: float | None = 1.0
weight_strategy: str | None = "upsample"


DatasetConfig = SFTDataset | DPODataset | KTODataset | StepwiseSupervisedDataset
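
As an illustration of the schema additions, the new fields can be set on any of the dataset config models. A minimal sketch, with paths and values invented and assuming the models' remaining fields stay optional:

from axolotl.utils.schemas.datasets import SFTDataset

# weight: this dataset's share of the merged corpus, in tokens (all weights must sum to 1.0).
# weight_strategy: "upsample" (the default) or "downsample".
code_ds = SFTDataset(path="my-org/code-corpus", weight=0.25, weight_strategy="downsample")
chat_ds = SFTDataset(path="my-org/chat-corpus", weight=0.75)  # strategy defaults to "upsample"
assert abs(code_ds.weight + chat_ds.weight - 1.0) < 1e-6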