
Commit d523857

coderabbit comments
1 parent 669579a commit d523857

File tree:
  src/axolotl/utils/data/rl.py
  src/axolotl/utils/data/sft.py
  tests/e2e/multigpu/test_locking.py

3 files changed, +19 -19 lines changed

src/axolotl/utils/data/rl.py

Lines changed: 8 additions & 6 deletions
@@ -69,8 +69,10 @@ def _load_datasets():
 
     # Prepare datasets (with file locking logic for multiple ranks)
     loader = FileLockLoader(cfg)
-    train_dataset, eval_dataset = loader.load(_load_datasets)
-    loader.cleanup()
+    try:
+        train_dataset, eval_dataset = loader.load(_load_datasets)
+    finally:
+        loader.cleanup()
 
     # Apply deduplication if configured
     if cfg.dataset_exact_deduplication:
@@ -187,10 +189,10 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
     Returns:
         Combined and processed dataset for the specified split.
     """
-    datasets = cfg.datasets if split == "train" else cfg.test_datasets
+    datasets_configs = cfg.datasets if split == "train" else cfg.test_datasets
     split_datasets: list[Dataset | DatasetDict] = []
 
-    for dataset_config in datasets_with_name_generator(datasets):
+    for dataset_config in datasets_with_name_generator(datasets_configs):
         dataset: Dataset | DatasetDict = load_dataset_with_config(
             dataset_config, cfg.hf_use_auth_token, streaming=False
         )
@@ -199,7 +201,7 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
     tokenizer = load_tokenizer(cfg)
 
     for i, data_set in enumerate(split_datasets):
-        _type = datasets[i]["type"]
+        _type = datasets_configs[i]["type"]
         if _type:
             if isinstance(_type, DictDefault):
                 _type = "user_defined.default"
@@ -246,7 +248,7 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
     if not cfg.skip_prepare_dataset:
         # Save preprocessed dataset
         dataset_hash = generate_dataset_hash_from_config(
-            cfg, cfg.datasets, tokenizer.name_or_path
+            cfg, datasets_configs, tokenizer.name_or_path
         )
         save_preprocessed_dataset(cfg, dataset, dataset_hash, split)
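
Both data modules receive the same fix: loader.cleanup() previously ran only when loader.load(...) returned normally, so an exception raised during dataset preparation would skip cleanup entirely. Wrapping the call in try/finally guarantees cleanup on every path. Below is a minimal, self-contained sketch of the pattern; DummyLoader is an illustrative stand-in, not axolotl's FileLockLoader.

# Sketch of the pattern this commit enforces: cleanup must run even if
# dataset loading raises. DummyLoader is illustrative only.
class DummyLoader:
    def __init__(self) -> None:
        self.cleaned_up = False

    def load(self, load_fn):
        return load_fn()

    def cleanup(self) -> None:
        self.cleaned_up = True


def failing_load():
    raise RuntimeError("simulated dataset preparation failure")


loader = DummyLoader()
try:
    try:
        loader.load(failing_load)
    finally:
        loader.cleanup()  # runs on success and on failure alike
except RuntimeError:
    pass

assert loader.cleaned_up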

src/axolotl/utils/data/sft.py

Lines changed: 10 additions & 8 deletions
@@ -99,8 +99,10 @@ def _load_datasets():
 
     # Prepare datasets (with file locking logic for multiple ranks)
     loader = FileLockLoader(cfg)
-    train_dataset, eval_dataset, prompters = loader.load(_load_datasets)
-    loader.cleanup()
+    try:
+        train_dataset, eval_dataset, prompters = loader.load(_load_datasets)
+    finally:
+        loader.cleanup()
 
     # Validate sample packing configuration for evaluation
     if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
@@ -265,11 +267,11 @@ def _load_tokenized_prepared_datasets(
         Tuple of (dataset, prompters list).
     """
     # Select correct dataset configuration based on split
-    datasets_config = cfg.datasets if split == "train" else cfg.test_datasets
+    datasets_configs = cfg.datasets if split == "train" else cfg.test_datasets
 
     # Generate dataset hash for caching
     dataset_hash = generate_dataset_hash_from_config(
-        cfg, datasets_config, tokenizer.name_or_path
+        cfg, datasets_configs, tokenizer.name_or_path
     )
 
     # Try loading from hub if push_dataset_to_hub is configured
@@ -286,7 +288,7 @@ def _load_tokenized_prepared_datasets(
     if dataset is None:
         dataset, prompters = _load_raw_datasets(
             cfg,
-            datasets_config,
+            datasets_configs,
             tokenizer,
             split,
             processor,
@@ -298,7 +300,7 @@ def _load_tokenized_prepared_datasets(
 
 def _load_raw_datasets(
     cfg: DictDefault,
-    cfg_datasets: list,
+    datasets_configs: list,
     tokenizer: PreTrainedTokenizer,
     split: str,
     processor: ProcessorMixin | None = None,
@@ -315,7 +317,7 @@ def _load_raw_datasets(
     # Load and process individual datasets
     datasets = []
    prompters = []
-    for dataset_config in datasets_with_name_generator(cfg_datasets):
+    for dataset_config in datasets_with_name_generator(datasets_configs):
         dataset_wrapper, dataset_prompter = _load_and_process_single_dataset(
             dataset_config=dataset_config,
             cfg=cfg,
@@ -338,7 +340,7 @@ def _load_raw_datasets(
 
     # Save the prepared dataset
     dataset_hash = generate_dataset_hash_from_config(
-        cfg, cfg.datasets, tokenizer.name_or_path
+        cfg, datasets_configs, tokenizer.name_or_path
     )
     save_preprocessed_dataset(cfg, dataset, dataset_hash, split)
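
One of the renames is more than cosmetic: when saving the prepared dataset, both files previously passed cfg.datasets to generate_dataset_hash_from_config even when the code was processing cfg.test_datasets, so the cache key seemingly did not reflect the configs actually loaded. Hashing the split's own datasets_configs fixes that. The toy illustration below shows why the cache key must derive from the configs in use; dataset_hash is a hypothetical stand-in, not the real generate_dataset_hash_from_config.

import hashlib
import json


def dataset_hash(configs: list[dict]) -> str:
    # Hypothetical stand-in: derive the cache key from the dataset configs
    # that are actually being prepared for this split.
    payload = json.dumps(configs, sort_keys=True).encode("utf-8")
    return hashlib.sha256(payload).hexdigest()[:16]


train_configs = [{"path": "data/train.jsonl", "type": "alpaca"}]
test_configs = [{"path": "data/test.jsonl", "type": "alpaca"}]

# Hashing cfg.datasets for both splits (the old behavior) would have given the
# test split the train split's cache key; hashing each split's own configs
# keeps the preprocessed caches distinct.
assert dataset_hash(train_configs) != dataset_hash(test_configs)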

tests/e2e/multigpu/test_locking.py

Lines changed: 1 addition & 5 deletions
@@ -90,11 +90,7 @@ def worker():
     def test_load_waiting_for_ready_flag(self, mock_sleep, loader):
         """Test that processes wait for the ready flag to appear."""
         mock_load_fn = Mock(return_value="waiting_data")
-
-        # Create a mock path object with controllable exists() behavior
-        mock_ready_flag_path = Mock()
-
-        # Track exists() calls
+        mock_ready_flag_path = Path(tempfile.mktemp())
         exists_call_count = 0
 
         def mock_exists():
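
The test previously used a bare Mock() for the ready-flag path; Path(tempfile.mktemp()) yields a real Path whose exists() starts out False, which the test then wraps to count polls. The sketch below shows the kind of polling loop the test exercises, assuming FileLockLoader waits on a ready-flag file; wait_for_ready is illustrative, not an axolotl API.

import tempfile
import time
from pathlib import Path

# A real Path whose exists() is False until some process creates the file.
ready_flag = Path(tempfile.mktemp())


def wait_for_ready(flag: Path, timeout: float = 5.0, poll: float = 0.1) -> None:
    # Poll flag.exists() until another rank creates the file or we time out.
    deadline = time.monotonic() + timeout
    while not flag.exists():
        if time.monotonic() > deadline:
            raise TimeoutError(f"ready flag {flag} never appeared")
        time.sleep(poll)


ready_flag.touch()          # simulate the preparing rank signaling readiness
wait_for_ready(ready_flag)  # returns immediately because the flag exists
ready_flag.unlink()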
