Make De-duplication Multi-threaded and Happen Only During Pre-processing #2747


Open · wants to merge 11 commits into base: main
18 changes: 8 additions & 10 deletions src/axolotl/utils/data/sft.py
@@ -340,6 +340,11 @@ def load_tokenized_prepared_datasets(
else:
LOG.debug("NOT shuffling merged datasets")

if cfg.dataset_exact_deduplication:
_, _, dataset = deduplicate_and_log_datasets(
dataset=dataset, num_proc=cfg.dataset_processes
)

if not cfg.skip_prepare_dataset:
dataset = drop_long_seq_in_dataset(dataset, cfg)

@@ -438,8 +443,7 @@ def load_prepare_datasets(
)
train_fingerprint = md5(to_hash_train)
test_fingerprint = md5(to_hash_test)
if cfg.dataset_exact_deduplication:
_, _, dataset = deduplicate_and_log_datasets(dataset=dataset)

dataset = dataset.train_test_split(
test_size=val_set_size,
shuffle=False,
@@ -451,16 +455,10 @@
train_dataset = dataset["train"]
eval_dataset = dataset["test"]
elif split == "test":
if cfg.dataset_exact_deduplication:
_, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset)
else:
eval_dataset = dataset
eval_dataset = dataset
train_dataset = None
else:
if cfg.dataset_exact_deduplication:
train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset)
else:
train_dataset = dataset
train_dataset = dataset
eval_dataset = None
return train_dataset, eval_dataset, prompters
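For context, a minimal standalone sketch of the new flow, using toy data and a hypothetical `num_proc` value rather than anything from the PR: with `dataset_exact_deduplication` enabled, the merged dataset is deduplicated once during pre-processing, and the later split simply consumes already-deduplicated rows.

```python
# Illustrative sketch only: dedupe the merged dataset once, then split.
from datasets import Dataset

from axolotl.utils.data.utils import deduplicate_and_log_datasets

dataset = Dataset.from_dict({"text": ["a", "b", "a", "c"]})

# One pre-processing pass, parallelized the way cfg.dataset_processes would be.
_, _, dataset = deduplicate_and_log_datasets(dataset=dataset, num_proc=2)

# The split no longer needs its own per-split deduplication calls.
split = dataset.train_test_split(test_size=0.25, shuffle=False)
train_dataset, eval_dataset = split["train"], split["test"]
```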

83 changes: 45 additions & 38 deletions src/axolotl/utils/data/utils.py
@@ -4,6 +4,7 @@
import hashlib
import time
from enum import Enum
from typing import Optional

import huggingface_hub
import numpy as np
@@ -69,39 +70,48 @@ def sha256(to_hash: str, encoding: str = "utf-8") -> str:
return hashlib.sha256(to_hash.encode(encoding)).hexdigest()


def compute_row_hash(example):
return {"row_hash": sha256(str(example))}


def deduplicate_dataset(
dataset: Dataset, seen_hashes: dict[str, list[int]], other_dataset: Dataset = None
dataset: Dataset,
other_dataset: Dataset = None,
num_proc: Optional[int] = None,
) -> Dataset:
unique_indices = []

for idx, row in enumerate(dataset):
row_hash = sha256(str(row)) # Using SHA256 for collision resistance.
if row_hash not in seen_hashes:
seen_hashes[row_hash] = [idx]
unique_indices.append(idx)
else:
# Check for collision by looking up the original dataset indices
original_indices = seen_hashes[row_hash]
is_duplicate = False
for original_idx in original_indices:
if (
not idx == original_idx
and original_idx < len(dataset)
and str(dataset[original_idx]) == str(row)
):
is_duplicate = True
break
# Check in the other dataset if provided
if other_dataset is not None:
if original_idx < len(other_dataset) and str(
other_dataset[original_idx]
) == str(row):
is_duplicate = True
break
if not is_duplicate:
seen_hashes[row_hash].append(idx)
unique_indices.append(idx)
continue
if dataset is None:
LOG.warning("dataset is None. De-duplication cannot be performed.")
return dataset

# Get SHA-256 hashes for all samples in dataset
hashed = dataset.map(
compute_row_hash, remove_columns=dataset.column_names, num_proc=num_proc
)
hashes = hashed["row_hash"]
del hashed

# Get SHA-256 hashes for all samples in other_dataset (if it exists)
other_hashes = set()
if other_dataset is not None:
other_hashed = other_dataset.map(
compute_row_hash,
remove_columns=other_dataset.column_names,
num_proc=num_proc,
)
other_hashes = set(other_hashed["row_hash"])
del other_hashed

# Find all non-duplicate samples based on the hashes, saving all unique indices
seen_hashes, unique_indices = set(), set()
for idx, row_hash in enumerate(hashes):
if row_hash in seen_hashes or row_hash in other_hashes:
continue
seen_hashes.add(row_hash)
unique_indices.add(idx)

Collaborator

It might be worth dropping the added row_hash column here now that it's done. The previous implementation did not modify the dataset.

Contributor Author

This shouldn't be modifying it either. The only change to dataset happens at the end, when it returns with .select. The row_hash column comes from a copy in hashed and other_hashed (which I should probably delete after grabbing the hashes).

del hashes, other_hashes, seen_hashes

# Return only non-duplicate samples based on the found unique indices
return dataset.select(unique_indices)
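Picking up the review thread above, a quick illustrative check with toy data (not from the PR) of why the original dataset never gains a row_hash column: Dataset.map returns a new dataset, and remove_columns only shapes that mapped copy.

```python
# Illustrative only: map() leaves the input dataset untouched.
from datasets import Dataset

ds = Dataset.from_dict({"text": ["a", "b", "b"]})
hashed = ds.map(
    lambda ex: {"row_hash": str(hash(ex["text"]))},  # stand-in for the sha256 helper
    remove_columns=ds.column_names,
)
print(ds.column_names)      # ['text']      -- original unchanged
print(hashed.column_names)  # ['row_hash']  -- hashes live only in the mapped copy
```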


@@ -110,23 +120,20 @@ def deduplicate_and_log_datasets(
train_dataset: Dataset = None,
eval_dataset: Dataset = None,
dataset: Dataset = None,
num_proc: Optional[int] = None,
) -> tuple[Dataset, Dataset, Dataset]:
"""
Deduplicates train, eval, and an optional dataset if provided, logging original and new sizes.

Returns:
tuple: Deduplicated train, eval, and additional datasets.
"""
seen_hashes: dict[str, list[int]] = {}

# Handle cases where datasets are None
if train_dataset is not None:
LOG.info(
f"Starting deduplication for train dataset. Original size: {len(train_dataset)}"
)
train_dataset = deduplicate_dataset(
dataset=train_dataset, seen_hashes=seen_hashes
)
train_dataset = deduplicate_dataset(dataset=train_dataset, num_proc=num_proc)
LOG.info(
f"Deduplication complete for train dataset. New size: {len(train_dataset)}"
)
@@ -138,7 +145,7 @@
f"Starting deduplication for eval dataset. Original size: {len(eval_dataset)}"
)
eval_dataset = deduplicate_dataset(
dataset=eval_dataset, seen_hashes=seen_hashes, other_dataset=train_dataset
dataset=eval_dataset, other_dataset=train_dataset, num_proc=num_proc
)
LOG.info(
f"Deduplication complete for eval dataset. New size: {len(eval_dataset)}"
@@ -150,7 +157,7 @@
LOG.info(
f"Starting deduplication for combined dataset. Original size: {len(dataset)}"
)
dataset = deduplicate_dataset(dataset=dataset, seen_hashes=seen_hashes)
dataset = deduplicate_dataset(dataset=dataset, num_proc=num_proc)
LOG.info(
f"Deduplication complete for combined dataset. New size: {len(dataset)}"
)
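A hedged usage sketch of the train/eval path, again with toy data and a hypothetical `num_proc`: the eval pass hashes the already-deduplicated train split as other_dataset, so eval rows that also appear in train are dropped.

```python
# Illustrative only: duplicates within each split and across train/eval are removed.
from datasets import Dataset

from axolotl.utils.data.utils import deduplicate_and_log_datasets

train = Dataset.from_dict({"text": ["a", "b", "b"]})
evals = Dataset.from_dict({"text": ["b", "c"]})

train, evals, _ = deduplicate_and_log_datasets(
    train_dataset=train, eval_dataset=evals, num_proc=2
)
print(train["text"])  # ['a', 'b']  -- in-split duplicate removed
print(evals["text"])  # ['c']       -- 'b' already appears in train, so it is dropped
```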