 import hashlib
 import time
 from enum import Enum
+from typing import Callable

 import huggingface_hub
 import numpy as np
@@ -29,7 +30,18 @@ class RetryStrategy(Enum):

 def retry_on_request_exceptions(
     max_retries=3, delay=1, retry_strategy: RetryStrategy = RetryStrategy.LINEAR
-):
+) -> Callable:
+    """Decorator that retries function calls on specific request exceptions.
+
+    Args:
+        max_retries: Maximum number of retry attempts.
+        delay: Base delay between retries in seconds.
+        retry_strategy: Strategy for calculating retry delays.
+
+    Returns:
+        Decorated function with retry logic.
+    """
+
     def decorator(func):
         @functools.wraps(func)
         def wrapper(*args, **kwargs):  # pylint: disable=inconsistent-return-statements
@@ -58,106 +70,93 @@ def wrapper(*args, **kwargs):  # pylint: disable=inconsistent-return-statements


 def md5(to_hash: str, encoding: str = "utf-8") -> str:
+    """Generate MD5 hash of a string."""
     try:
         return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
     except TypeError:
         return hashlib.md5(to_hash.encode(encoding)).hexdigest()  # nosec


 def sha256(to_hash: str, encoding: str = "utf-8") -> str:
+    """Generate SHA256 hash of a string."""
     return hashlib.sha256(to_hash.encode(encoding)).hexdigest()


-def deduplicate_dataset(
-    dataset: Dataset, seen_hashes: dict[str, list[int]], other_dataset: Dataset = None
-) -> Dataset:
-    unique_indices = []
+def _deduplicate_dataset(
+    dataset: Dataset,
+    seen_rows: set[str] | None = None,
+) -> tuple[Dataset, set[str]]:
+    """Remove duplicate rows from a dataset by storing row content directly.
+
+    Args:
+        dataset: Dataset to deduplicate.
+        seen_rows: Set of previously seen row strings (for cross-deduplication).

+    Returns:
+        Tuple of deduplicated dataset and the set of seen rows.
+    """
+    if seen_rows is None:
+        seen_rows = set()
+
+    unique_indices = []
     for idx, row in enumerate(dataset):
-        row_hash = sha256(str(row))  # Using SHA256 for collision resistance.
-        if row_hash not in seen_hashes:
-            seen_hashes[row_hash] = [idx]
+        row_str = str(row)
+        if row_str not in seen_rows:
+            seen_rows.add(row_str)
             unique_indices.append(idx)
-        else:
-            # Check for collision by looking up the original dataset indices
-            original_indices = seen_hashes[row_hash]
-            is_duplicate = False
-            for original_idx in original_indices:
-                if (
-                    not idx == original_idx
-                    and original_idx < len(dataset)
-                    and str(dataset[original_idx]) == str(row)
-                ):
-                    is_duplicate = True
-                    break
-                # Check in the other dataset if provided
-                if other_dataset is not None:
-                    if original_idx < len(other_dataset) and str(
-                        other_dataset[original_idx]
-                    ) == str(row):
-                        is_duplicate = True
-                        break
-            if not is_duplicate:
-                seen_hashes[row_hash].append(idx)
-                unique_indices.append(idx)
-                continue
-    return dataset.select(unique_indices)
+
+    return dataset.select(unique_indices), seen_rows


 def deduplicate_and_log_datasets(
-    *,
-    train_dataset: Dataset | None = None,
-    eval_dataset: Dataset | None = None,
-    dataset: Dataset | None = None,
-) -> tuple[Dataset | None, Dataset | None, Dataset | None]:
-    """Deduplicates train, eval, and an optional dataset if provided, logging original
-    and new sizes.
+    dataset: Dataset,
+    other_dataset: Dataset | None = None,
+    dataset_name: str | None = "train",
+    other_name: str | None = "eval",
+) -> tuple[Dataset, Dataset | None]:
+    """Deduplicate datasets, with optional cross-dataset deduplication.
+
+    Args:
+        dataset: Primary dataset to deduplicate.
+        other_dataset: Optional second dataset to deduplicate against the first.
+        dataset_name: Name for the primary dataset (for logging).
+        other_name: Name for the second dataset (for logging).

     Returns:
-        Deduplicated train, eval, and additional datasets.
+        Tuple of (deduplicated_dataset, deduplicated_other_dataset).
     """
-    seen_hashes: dict[str, list[int]] = {}
+    # Deduplicate primary dataset
+    LOG.info(
+        f"Starting deduplication for {dataset_name} dataset. Original size: {len(dataset)}"
+    )
+    dataset, seen_rows = _deduplicate_dataset(dataset)
+    LOG.info(
+        f"Deduplication complete for {dataset_name} dataset. New size: {len(dataset)}"
+    )

-    # Handle cases where datasets are None
-    if train_dataset is not None:
+    # Deduplicate second dataset if provided
+    if other_dataset is not None:
         LOG.info(
-            f"Starting deduplication for train dataset. Original size: {len(train_dataset)}"
-        )
-        train_dataset = deduplicate_dataset(
-            dataset=train_dataset, seen_hashes=seen_hashes
+            f"Starting deduplication for {other_name} dataset. Original size: {len(other_dataset)}"
         )
+        other_dataset, _ = _deduplicate_dataset(other_dataset, seen_rows)
         LOG.info(
-            f"Deduplication complete for train dataset. New size: {len(train_dataset)}"
+            f"Deduplication complete for {other_name} dataset. New size: {len(other_dataset)}"
         )
-    else:
-        LOG.info("Train dataset is None. Skipping deduplication.")

-    if eval_dataset is not None:
-        LOG.info(
-            f"Starting deduplication for eval dataset. Original size: {len(eval_dataset)}"
-        )
-        eval_dataset = deduplicate_dataset(
-            dataset=eval_dataset, seen_hashes=seen_hashes, other_dataset=train_dataset
-        )
-        LOG.info(
-            f"Deduplication complete for eval dataset. New size: {len(eval_dataset)}"
-        )
-    else:
-        LOG.info("Eval dataset is None. Skipping deduplication.")
+    return dataset, other_dataset

-    if dataset is not None and (eval_dataset is None and train_dataset is None):
-        LOG.info(
-            f"Starting deduplication for combined dataset. Original size: {len(dataset)}"
-        )
-        dataset = deduplicate_dataset(dataset=dataset, seen_hashes=seen_hashes)
-        LOG.info(
-            f"Deduplication complete for combined dataset. New size: {len(dataset)}"
-        )

-    return train_dataset, eval_dataset, dataset
+def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault) -> Dataset:
+    """Remove sequences longer than configured maximum from dataset.

+    Args:
+        dataset: Dataset to filter.
+        cfg: Dictionary mapping `axolotl` config keys to values.

-def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
+    Returns:
+        Filtered dataset with long sequences removed.
+    """
     if "input_ids" not in dataset.column_names:
         LOG.warning(
             "Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
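A minimal usage sketch of the reworked deduplication helpers (not part of the diff above): it assumes the Hugging Face `datasets` package and that the helpers are importable from `axolotl.utils.data.utils` (assumed module path); the toy rows are hypothetical.

```python
# Hypothetical example data; assumes `datasets` is installed and the helpers
# from the diff are importable from axolotl.utils.data.utils (assumed path).
from datasets import Dataset

from axolotl.utils.data.utils import deduplicate_and_log_datasets

train = Dataset.from_list(
    [
        {"text": "hello"},
        {"text": "hello"},  # exact duplicate within the train split
        {"text": "world"},
    ]
)
evals = Dataset.from_list(
    [
        {"text": "world"},  # already present in train, so it gets dropped
        {"text": "unique"},
    ]
)

# Deduplicate train, then deduplicate eval against the rows seen in train.
train, evals = deduplicate_and_log_datasets(train, other_dataset=evals)
print(len(train), len(evals))  # 2 1
```

The `seen_rows` set returned from the first pass is fed into the second pass, which is what makes rows already present in the primary dataset disappear from the second one.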