
Commit daf5076

simplify dedup
1 parent 71b441b commit daf5076

4 files changed: 96 additions (+96), 115 deletions (−115)


src/axolotl/utils/data/rl.py

Lines changed: 2 additions & 2 deletions
@@ -92,8 +92,8 @@ def prepare_preference_datasets(cfg: DictDefault) -> tuple[Dataset, Dataset | No

     # Apply deduplication if configured
     if cfg.dataset_exact_deduplication:
-        train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
-            train_dataset=train_dataset, eval_dataset=eval_dataset
+        train_dataset, eval_dataset = deduplicate_and_log_datasets(
+            dataset=train_dataset, other_dataset=eval_dataset
         )

     return train_dataset, eval_dataset
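
For context, a minimal usage sketch of the simplified call above (toy datasets, illustrative only; it assumes `deduplicate_and_log_datasets` is importable from `axolotl.utils.data.utils`, the module changed in this commit):

from datasets import Dataset

from axolotl.utils.data.utils import deduplicate_and_log_datasets

train_dataset = Dataset.from_dict({"prompt": ["a", "a", "b"]})
eval_dataset = Dataset.from_dict({"prompt": ["b", "c", "c"]})

# The simplified API returns exactly two datasets; rows already seen in the
# primary dataset are also dropped from the secondary one.
train_dataset, eval_dataset = deduplicate_and_log_datasets(
    dataset=train_dataset, other_dataset=eval_dataset
)
print(len(train_dataset), len(eval_dataset))  # 2 1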

src/axolotl/utils/data/sft.py

Lines changed: 4 additions & 4 deletions
@@ -568,7 +568,7 @@ def _load_prepare_datasets(
         index=cfg.dataset_shard_idx,
     )

-    # Apply deduplication and create train/validation splits based on the split type
+    # Apply deduplication and create train / validation splits based on the split type
     if split == "train":
         train_dataset, eval_dataset = _handle_train_split(dataset, cfg)
     else:
@@ -619,7 +619,7 @@ def _handle_train_split(

    # No validation split - apply deduplication if needed and return as train dataset
    if cfg.dataset_exact_deduplication:
-        train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset)
+        train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
    else:
        train_dataset = dataset

@@ -631,7 +631,7 @@ def _handle_test_split(
 ) -> tuple[None, Dataset | None]:
     """Handle processing for test split."""
     if cfg.dataset_exact_deduplication:
-        _, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset)
+        eval_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
     else:
         eval_dataset = dataset

@@ -651,7 +651,7 @@ def _create_train_validation_split(

    # Apply deduplication before splitting if configured
    if cfg.dataset_exact_deduplication:
-        _, _, dataset = deduplicate_and_log_datasets(dataset=dataset)
+        dataset, _ = deduplicate_and_log_datasets(dataset=dataset)

    # Create the train/test split
    split_dataset = dataset.train_test_split(
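
A small hedged sketch of the single-dataset form used by the call sites above (toy data, illustrative only): with no `other_dataset`, the second element of the returned tuple is `None`, which is why these callers unpack with `train_dataset, _ = ...` or `eval_dataset, _ = ...`.

from datasets import Dataset

from axolotl.utils.data.utils import deduplicate_and_log_datasets

dataset = Dataset.from_dict({"text": ["x", "x", "y"]})

# Only the primary dataset is passed, so the second return value is None.
deduped, other = deduplicate_and_log_datasets(dataset=dataset)
assert other is None
assert len(deduped) == 2  # the duplicate "x" row is dropped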

src/axolotl/utils/data/utils.py

Lines changed: 71 additions & 72 deletions
@@ -5,6 +5,7 @@
 import hashlib
 import time
 from enum import Enum
+from typing import Callable

 import huggingface_hub
 import numpy as np
@@ -29,7 +30,18 @@ class RetryStrategy(Enum):

 def retry_on_request_exceptions(
     max_retries=3, delay=1, retry_strategy: RetryStrategy = RetryStrategy.LINEAR
-):
+) -> Callable:
+    """Decorator that retries function calls on specific request exceptions.
+
+    Args:
+        max_retries: Maximum number of retry attempts.
+        delay: Base delay between retries in seconds.
+        retry_strategy: Strategy for calculating retry delays.
+
+    Returns:
+        Decorated function with retry logic.
+    """
+
     def decorator(func):
         @functools.wraps(func)
         def wrapper(*args, **kwargs):  # pylint: disable=inconsistent-return-statements
@@ -58,106 +70,93 @@ def wrapper(*args, **kwargs):  # pylint: disable=inconsistent-return-statements


 def md5(to_hash: str, encoding: str = "utf-8") -> str:
+    """Generate MD5 hash of a string."""
     try:
         return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
     except TypeError:
         return hashlib.md5(to_hash.encode(encoding)).hexdigest()  # nosec


 def sha256(to_hash: str, encoding: str = "utf-8") -> str:
+    """Generate SHA256 hash of a string."""
     return hashlib.sha256(to_hash.encode(encoding)).hexdigest()


-def deduplicate_dataset(
-    dataset: Dataset, seen_hashes: dict[str, list[int]], other_dataset: Dataset = None
-) -> Dataset:
-    unique_indices = []
+def _deduplicate_dataset(
+    dataset: Dataset,
+    seen_rows: set[str] | None = None,
+) -> tuple[Dataset, set[str]]:
+    """Remove duplicate rows from a dataset by storing row content directly.
+
+    Args:
+        dataset: Dataset to deduplicate.
+        seen_rows: Set of previously seen row strings (for cross-deduplication).
+
+    Returns:
+        Tuple of deduplicated dataset and the set of seen rows.
+    """
+    if seen_rows is None:
+        seen_rows = set()
+
+    unique_indices = []
     for idx, row in enumerate(dataset):
-        row_hash = sha256(str(row))  # Using SHA256 for collision resistance.
-        if row_hash not in seen_hashes:
-            seen_hashes[row_hash] = [idx]
+        row_str = str(row)
+        if row_str not in seen_rows:
+            seen_rows.add(row_str)
             unique_indices.append(idx)
-        else:
-            # Check for collision by looking up the original dataset indices
-            original_indices = seen_hashes[row_hash]
-            is_duplicate = False
-            for original_idx in original_indices:
-                if (
-                    not idx == original_idx
-                    and original_idx < len(dataset)
-                    and str(dataset[original_idx]) == str(row)
-                ):
-                    is_duplicate = True
-                    break
-                # Check in the other dataset if provided
-                if other_dataset is not None:
-                    if original_idx < len(other_dataset) and str(
-                        other_dataset[original_idx]
-                    ) == str(row):
-                        is_duplicate = True
-                        break
-            if not is_duplicate:
-                seen_hashes[row_hash].append(idx)
-                unique_indices.append(idx)
-                continue
-    return dataset.select(unique_indices)
+
+    return dataset.select(unique_indices), seen_rows


 def deduplicate_and_log_datasets(
-    *,
-    train_dataset: Dataset | None = None,
-    eval_dataset: Dataset | None = None,
-    dataset: Dataset | None = None,
-) -> tuple[Dataset | None, Dataset | None, Dataset | None]:
-    """Deduplicates train, eval, and an optional dataset if provided, logging original
-    and new sizes.
+    dataset: Dataset,
+    other_dataset: Dataset | None = None,
+    dataset_name: str | None = "train",
+    other_name: str | None = "eval",
+) -> tuple[Dataset, Dataset | None]:
+    """Deduplicate datasets, with optional cross-dataset deduplication.
+
+    Args:
+        dataset: Primary dataset to deduplicate.
+        other_dataset: Optional second dataset to deduplicate against the first.
+        dataset_name: Name for the primary dataset (for logging).
+        other_name: Name for the second dataset (for logging).

     Returns:
-        Deduplicated train, eval, and additional datasets.
+        Tuple of (deduplicated_dataset, deduplicated_other_dataset).
     """
-    seen_hashes: dict[str, list[int]] = {}
+    # Deduplicate primary dataset
+    LOG.info(
+        f"Starting deduplication for {dataset_name} dataset. Original size: {len(dataset)}"
+    )
+    dataset, seen_rows = _deduplicate_dataset(dataset)
+    LOG.info(
+        f"Deduplication complete for {dataset_name} dataset. New size: {len(dataset)}"
+    )

-    # Handle cases where datasets are None
-    if train_dataset is not None:
-        LOG.info(
-            f"Starting deduplication for train dataset. Original size: {len(train_dataset)}"
-        )
-        train_dataset = deduplicate_dataset(
-            dataset=train_dataset, seen_hashes=seen_hashes
-        )
-        LOG.info(
-            f"Deduplication complete for train dataset. New size: {len(train_dataset)}"
-        )
-    else:
-        LOG.info("Train dataset is None. Skipping deduplication.")
-
-    if eval_dataset is not None:
-        LOG.info(
-            f"Starting deduplication for eval dataset. Original size: {len(eval_dataset)}"
-        )
-        eval_dataset = deduplicate_dataset(
-            dataset=eval_dataset, seen_hashes=seen_hashes, other_dataset=train_dataset
-        )
-        LOG.info(
-            f"Deduplication complete for eval dataset. New size: {len(eval_dataset)}"
-        )
-    else:
-        LOG.info("Eval dataset is None. Skipping deduplication.")
-
-    if dataset is not None and (eval_dataset is None and train_dataset is None):
-        LOG.info(
-            f"Starting deduplication for combined dataset. Original size: {len(dataset)}"
-        )
-        dataset = deduplicate_dataset(dataset=dataset, seen_hashes=seen_hashes)
-        LOG.info(
-            f"Deduplication complete for combined dataset. New size: {len(dataset)}"
-        )
-
-    return train_dataset, eval_dataset, dataset
+    # Deduplicate second dataset if provided
+    if other_dataset is not None:
+        LOG.info(
+            f"Starting deduplication for {other_name} dataset. Original size: {len(other_dataset)}"
+        )
+        other_dataset, _ = _deduplicate_dataset(other_dataset, seen_rows)
+        LOG.info(
+            f"Deduplication complete for {other_name} dataset. New size: {len(other_dataset)}"
+        )
+
+    return dataset, other_dataset


-def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
+def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault) -> Dataset:
+    """Remove sequences longer than configured maximum from dataset.
+
+    Args:
+        dataset: Dataset to filter.
+        cfg: Dictionary mapping `axolotl` config keys to values.
+
+    Returns:
+        Filtered dataset with long sequences removed.
+    """
     if "input_ids" not in dataset.column_names:
         LOG.warning(
             "Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
tests/test_exact_deduplication.py

Lines changed: 19 additions & 37 deletions
@@ -71,35 +71,13 @@ def setUp(self):
         self.expected_dataset = Dataset.from_dict(self.expected_data)

     def test_deduplication(self):
-        train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=self.dataset)
-        _, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=self.dataset)
-
-        verify_deduplication(train_dataset, self.expected_dataset, "train_dataset")
-        verify_deduplication(eval_dataset, self.expected_dataset, "eval_dataset")
-
-    def test_datasets_are_none(self):
-        # Test when both datasets are None
-        train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
-            train_dataset=None, eval_dataset=None
-        )
-        self.assertIsNone(train_dataset, "Expected train_dataset to be None")
-        self.assertIsNone(eval_dataset, "Expected eval_dataset to be None")
-
-    def test_only_train_is_none(self):
-        # Test when only train_dataset is None
-        train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
-            train_dataset=None, eval_dataset=self.dataset
-        )
-        self.assertIsNone(train_dataset, "Expected train_dataset to be None")
-        verify_deduplication(eval_dataset, self.expected_dataset, "eval_dataset")
-
-    def test_only_eval_is_none(self):
-        # Test when only eval_dataset is None
-        train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
-            train_dataset=self.dataset, eval_dataset=None
-        )
-        self.assertIsNone(eval_dataset, "Expected eval_dataset to be None")
+        train_dataset, _ = deduplicate_and_log_datasets(dataset=self.dataset)
+        eval_dataset, _ = deduplicate_and_log_datasets(
+            dataset=self.dataset, dataset_name="eval"
+        )
+
         verify_deduplication(train_dataset, self.expected_dataset, "train_dataset")
+        verify_deduplication(eval_dataset, self.expected_dataset, "eval_dataset")

     def test_exact_duplicates(self):
         # Test when datasets are exact duplicates
@@ -115,8 +93,10 @@ def test_exact_duplicates(self):
         expected_dataset = Dataset.from_dict(expected_data)

         # Run deduplication
-        train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset)
-        _, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset)
+        train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
+        eval_dataset, _ = deduplicate_and_log_datasets(
+            dataset=dataset, dataset_name="eval"
+        )

         verify_deduplication(train_dataset, expected_dataset, "train_dataset")
         verify_deduplication(eval_dataset, expected_dataset, "eval_dataset")
@@ -139,8 +119,10 @@ def test_partial_duplicates(self):
         expected_dataset = Dataset.from_dict(expected_data)

         # Run deduplication
-        train_dataset, _, _ = deduplicate_and_log_datasets(train_dataset=dataset)
-        _, eval_dataset, _ = deduplicate_and_log_datasets(eval_dataset=dataset)
+        train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
+        eval_dataset, _ = deduplicate_and_log_datasets(
+            dataset=dataset, dataset_name="eval"
+        )

         verify_deduplication(train_dataset, expected_dataset, "train_dataset")
         verify_deduplication(eval_dataset, expected_dataset, "eval_dataset")
@@ -169,8 +151,8 @@ def test_combined_duplicates_empty(self):
         expected_dataset_eval = Dataset.from_dict(expected_data_eval)

         # Run deduplication
-        train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
-            train_dataset=dataset, eval_dataset=dataset
+        train_dataset, eval_dataset = deduplicate_and_log_datasets(
+            dataset=dataset, other_dataset=dataset
         )

         verify_deduplication(train_dataset, expected_dataset_train, "train_dataset")
@@ -206,8 +188,8 @@ def test_combined_duplicates_one(self):
         expected_dataset_eval = Dataset.from_dict(expected_data_eval)

         # Run deduplication
-        train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
-            train_dataset=dataset_train, eval_dataset=dataset_eval
+        train_dataset, eval_dataset = deduplicate_and_log_datasets(
+            dataset=dataset_train, other_dataset=dataset_eval
         )

         verify_deduplication(train_dataset, expected_dataset_train, "train_dataset")
@@ -441,8 +423,8 @@ def setUp(self):
         ),
     )
     def test_deduplication_wrong_collision_train_eval(self, _mock_sha256):
-        dedup_train, dedup_eval, _ = deduplicate_and_log_datasets(
-            train_dataset=self.train_dataset, eval_dataset=self.eval_dataset
+        dedup_train, dedup_eval = deduplicate_and_log_datasets(
+            dataset=self.train_dataset, other_dataset=self.eval_dataset
         )
         self.assertEqual(
             len(dedup_train),
@@ -466,7 +448,7 @@ def test_deduplication_wrong_collision_train_eval(self, _mock_sha256):
         )

     def test_deduplication_dataset_only(self):
-        _, _, dedup_dataset = deduplicate_and_log_datasets(dataset=self.dataset)
+        dedup_dataset, _ = deduplicate_and_log_datasets(dataset=self.dataset)
         self.assertEqual(
             len(dedup_dataset), 3, "Dataset should have all original values"
         )
