
Commit cc802af

moving dataset save fn to shared module
1 parent 037ae56 commit cc802af

File tree

    src/axolotl/utils/data/lock.py
    src/axolotl/utils/data/rl.py
    src/axolotl/utils/data/sft.py
    src/axolotl/utils/data/shared.py

4 files changed: +97 -99 lines changed

src/axolotl/utils/data/lock.py

Lines changed: 6 additions & 1 deletion
@@ -11,6 +11,7 @@
 
 LOCK_FILE_NAME = "datasets_prep.lock"
 READY_FILE_NAME = "datasets_ready.flag"
+PROCESS_COUNTER_FILE_NAME = "process_counter.txt"
 
 
 class FileLockLoader:
@@ -27,10 +28,14 @@ def __init__(self, cfg: DictDefault):
         )
         self.lock_file_path = Path(self.dataset_prepared_path) / LOCK_FILE_NAME
         self.ready_flag_path = Path(self.dataset_prepared_path) / READY_FILE_NAME
-        self.counter_path = Path(self.dataset_prepared_path) / "process_counter.txt"
+        self.counter_path = Path(self.dataset_prepared_path) / PROCESS_COUNTER_FILE_NAME
 
     def load(self, load_fn: Callable[[], Any]) -> Any:
+        import torch.distributed as dist
+
         with FileLock(str(self.lock_file_path)):
+            print(f"FileLock acquired by rank {dist.get_rank()}")
+
             # Increment process counter
             self._increment_counter()
 
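For context, `FileLockLoader` serializes dataset preparation across concurrent training processes through an on-disk lock, with a ready flag marking that the expensive work has already been done. The sketch below is a simplified, hypothetical illustration of that pattern, not the class's actual implementation: it assumes the `filelock` package (which the diff's `FileLock` usage points to), and the paths and the `load_with_lock` helper are placeholders.

```python
# Minimal sketch of the prepare-once / load-everywhere pattern (not FileLockLoader
# itself): the first process to take the lock does the expensive work and drops a
# ready flag; later processes see the flag and reuse the prepared dataset.
from pathlib import Path

from filelock import FileLock

PREPARED_PATH = Path("last_run_prepared")          # placeholder directory
LOCK_FILE = PREPARED_PATH / "datasets_prep.lock"
READY_FLAG = PREPARED_PATH / "datasets_ready.flag"


def load_with_lock(load_fn):
    """Serialize dataset preparation across processes (hypothetical helper)."""
    PREPARED_PATH.mkdir(parents=True, exist_ok=True)
    with FileLock(str(LOCK_FILE)):
        first_run = not READY_FLAG.exists()
        dataset = load_fn()      # load_fn itself checks the on-disk cache
        if first_run:
            READY_FLAG.touch()   # signal the other processes that prep is done
    return dataset
```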

src/axolotl/utils/data/rl.py

Lines changed: 4 additions & 21 deletions
@@ -1,7 +1,6 @@
 """Data handling specific to RL trainers."""
 
 import inspect
-import os
 from functools import partial
 from typing import Any, Callable, Literal
 
@@ -13,15 +12,15 @@
 from axolotl.prompt_strategies.kto import load as load_kto
 from axolotl.prompt_strategies.orpo import load as load_orpo
 from axolotl.utils.data.lock import FileLockLoader
-from axolotl.utils.data.sft import _try_load_from_hub
 from axolotl.utils.data.shared import (
     create_train_validation_split,
     datasets_with_name_generator,
     generate_dataset_hash_from_config,
-    get_prepared_dataset_path,
     load_dataset_with_config,
     load_preprocessed_dataset,
     merge_datasets,
+    save_preprocessed_dataset,
+    try_load_from_hub,
 )
 from axolotl.utils.data.utils import (
     deduplicate_and_log_datasets,
@@ -82,22 +81,6 @@ def _load_datasets():
         return train_dataset, eval_dataset
 
 
-def _save_preprocessed_dataset(
-    cfg: DictDefault, dataset: Dataset, dataset_hash: str
-) -> None:
-    """Save preprocessed dataset to disk.
-
-    Args:
-        cfg: Configuration object.
-        dataset: Dataset to save.
-        dataset_hash: Hash identifying the dataset configuration.
-    """
-    prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)
-    LOG.info(f"Saving prepared dataset to disk... {prepared_ds_path}")
-    os.makedirs(prepared_ds_path, exist_ok=True)
-    dataset.save_to_disk(str(prepared_ds_path))
-
-
 def _map_dataset(
     cfg: DictDefault,
     dataset: Dataset | DatasetDict,
@@ -265,7 +248,7 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
     dataset_hash = generate_dataset_hash_from_config(
         cfg, cfg.datasets, tokenizer.name_or_path
     )
-    _save_preprocessed_dataset(cfg, dataset, dataset_hash)
+    save_preprocessed_dataset(cfg, dataset, dataset_hash, split)
 
     return dataset
 
@@ -295,7 +278,7 @@ def _load_or_create_dataset_split(
     # Try loading from hub if push_dataset_to_hub is configured
     dataset = None
     if cfg.push_dataset_to_hub:
-        dataset = _try_load_from_hub(cfg, dataset_hash, split)
+        dataset = try_load_from_hub(cfg, dataset_hash, split)
 
     # Attempt to load preprocessed dataset
     if dataset is None:
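With this change, rl.py and sft.py follow the same load-or-create flow through the helpers exported from `axolotl.utils.data.shared`. The sketch below is a rough, simplified view of that flow under the signatures visible in this diff; `load_or_create_split` and the `build_dataset` callback are hypothetical stand-ins for the trainer-specific tokenization step, not functions in the repo.

```python
# Simplified view of the load-or-create flow shared by rl.py and sft.py.
from axolotl.utils.data.shared import (
    generate_dataset_hash_from_config,
    load_preprocessed_dataset,
    save_preprocessed_dataset,
    try_load_from_hub,
)


def load_or_create_split(cfg, tokenizer, split, build_dataset):
    dataset_hash = generate_dataset_hash_from_config(
        cfg, cfg.datasets, tokenizer.name_or_path
    )

    # 1. Try the HF Hub copy if push_dataset_to_hub is configured.
    dataset = None
    if cfg.push_dataset_to_hub:
        dataset = try_load_from_hub(cfg, dataset_hash, split)

    # 2. Fall back to the locally prepared copy on disk.
    if dataset is None:
        dataset = load_preprocessed_dataset(cfg, dataset_hash)

    # 3. Otherwise build from the raw datasets and cache the result.
    if dataset is None:
        dataset = build_dataset(cfg, tokenizer, split)
        save_preprocessed_dataset(cfg, dataset, dataset_hash, split)

    return dataset
```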

src/axolotl/utils/data/sft.py

Lines changed: 6 additions & 71 deletions
@@ -1,9 +1,8 @@
 """Data handling specific to SFT."""
 
 import functools
-import os
 import tempfile
-from typing import Any, Generator, Literal
+from typing import Literal
 
 from datasets import (
     Dataset,
@@ -20,10 +19,11 @@
     create_train_validation_split,
     datasets_with_name_generator,
     generate_dataset_hash_from_config,
-    get_prepared_dataset_path,
     load_dataset_with_config,
     load_preprocessed_dataset,
     merge_datasets,
+    save_preprocessed_dataset,
+    try_load_from_hub,
 )
 from axolotl.utils.data.utils import (
     deduplicate_and_log_datasets,
@@ -275,7 +275,7 @@ def _load_tokenized_prepared_datasets(
     # Try loading from hub if push_dataset_to_hub is configured
     dataset = None
     if cfg.push_dataset_to_hub:
-        dataset = _try_load_from_hub(cfg, dataset_hash, split)
+        dataset = try_load_from_hub(cfg, dataset_hash, split)
 
     # If not found on hub, try loading from disk
     if dataset is None:
@@ -296,71 +296,6 @@ def _load_tokenized_prepared_datasets(
     return dataset, prompters
 
 
-def _try_load_from_hub(
-    cfg: DictDefault, dataset_hash: str, split: str
-) -> Dataset | None:
-    """Try to load the prepared dataset from HuggingFace Hub."""
-    try:
-        LOG.info(
-            "Attempting to load prepared dataset from HuggingFace Hub at "
-            f"{cfg.push_dataset_to_hub} (version {dataset_hash})..."
-        )
-        dataset = load_dataset(
-            cfg.push_dataset_to_hub,
-            dataset_hash,
-            token=cfg.hf_use_auth_token,
-        )
-        return dataset[split]
-    except Exception:  # pylint: disable=broad-except  # nosec
-        LOG.info("Unable to find prepared dataset in HuggingFace Hub")
-        return None
-
-
-def _generate_from_iterable_dataset(
-    dataset: IterableDataset, worker_id: list[int], num_workers: list[int]
-) -> Generator[Any, None, None]:
-    """Generator function to correctly split the dataset for each worker"""
-    for i, item in enumerate(dataset):
-        if i % num_workers[0] == worker_id[0]:
-            yield item
-
-
-def _save_preprocessed_dataset(
-    cfg: DictDefault,
-    dataset: Dataset,
-    dataset_hash: str,
-    split: str,
-) -> None:
-    prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)
-    if isinstance(dataset, IterableDataset):
-        num_workers = cfg.dataset_processes
-
-        ds_from_iter = Dataset.from_generator(
-            functools.partial(_generate_from_iterable_dataset, dataset),
-            features=dataset.features,
-            num_proc=num_workers,
-            split=split,
-            gen_kwargs={
-                "worker_id": list(range(num_workers)),
-                "num_workers": [num_workers] * num_workers,
-            },
-        )
-        ds_from_iter.save_to_disk(str(prepared_ds_path))
-    else:
-        os.makedirs(prepared_ds_path, exist_ok=True)
-        dataset.save_to_disk(str(prepared_ds_path))
-    if cfg.push_dataset_to_hub:
-        LOG.info(
-            "Pushing merged prepared dataset to Huggingface hub at "
-            f"{cfg.push_dataset_to_hub} (version {dataset_hash})..."
-        )
-        dataset.push_to_hub(
-            cfg.push_dataset_to_hub,
-            dataset_hash,
-            private=True,
-        )
-
-
 def _load_raw_datasets(
     cfg: DictDefault,
     cfg_datasets: list,
@@ -370,7 +305,7 @@ def _load_raw_datasets(
     preprocess_iterable: bool = False,
 ) -> tuple[Dataset, list[Prompter | None]]:
     """Load, process, merge, and save raw datasets."""
-    LOG.info("Loading raw datasets...")
+    LOG.info("Loading raw datasets...", main_process_only=False)
     if not cfg.is_preprocess:
         LOG.warning(
             "Processing datasets during training can lead to VRAM instability. Please "
@@ -405,7 +340,7 @@ def _load_raw_datasets(
     dataset_hash = generate_dataset_hash_from_config(
         cfg, cfg.datasets, tokenizer.name_or_path
     )
-    _save_preprocessed_dataset(cfg, dataset, dataset_hash, split)
+    save_preprocessed_dataset(cfg, dataset, dataset_hash, split)
 
     return dataset, prompters
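The `_generate_from_iterable_dataset` helper that moves out of this file relies on `datasets` splitting list-valued `gen_kwargs` across `num_proc` workers, so each worker receives a one-element slice such as `worker_id=[2]` and keeps every `num_workers`-th item of the stream. Below is a minimal, dependency-free sketch of that round-robin split, with a plain Python range standing in for an `IterableDataset`.

```python
# Round-robin sharding as used by _generate_from_iterable_dataset: each worker
# sees one element of the worker_id / num_workers lists and keeps the items
# whose index falls on its residue class.
def generate_from_iterable_dataset(dataset, worker_id, num_workers):
    for i, item in enumerate(dataset):
        if i % num_workers[0] == worker_id[0]:
            yield item


if __name__ == "__main__":
    stream = range(10)  # stands in for an IterableDataset
    shards = [
        list(generate_from_iterable_dataset(stream, [w], [3])) for w in range(3)
    ]
    print(shards)  # [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]
```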

src/axolotl/utils/data/shared.py

Lines changed: 81 additions & 6 deletions
@@ -2,8 +2,10 @@
 
 from __future__ import annotations
 
+import functools
+import os
 from pathlib import Path
-from typing import TYPE_CHECKING, Generator
+from typing import TYPE_CHECKING, Any, Generator
 
 from datasets import (
     Dataset,
@@ -391,6 +393,53 @@ def create_train_validation_split(
     return split_dataset["train"], split_dataset["test"]
 
 
+def _generate_from_iterable_dataset(
+    dataset: IterableDataset, worker_id: list[int], num_workers: list[int]
+) -> Generator[Any, None, None]:
+    """Generator function to correctly split the dataset for each worker"""
+    for i, item in enumerate(dataset):
+        if i % num_workers[0] == worker_id[0]:
+            yield item
+
+
+def save_preprocessed_dataset(
+    cfg: DictDefault,
+    dataset: Dataset,
+    dataset_hash: str,
+    split: str,
+) -> None:
+    """Save preprocessed dataset to disk and optionally push to the HF Hub."""
+    prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)
+    if isinstance(dataset, IterableDataset):
+        num_workers = cfg.dataset_processes
+
+        ds_from_iter = Dataset.from_generator(
+            functools.partial(_generate_from_iterable_dataset, dataset),
+            features=dataset.features,
+            num_proc=num_workers,
+            split=split,
+            gen_kwargs={
+                "worker_id": list(range(num_workers)),
+                "num_workers": [num_workers] * num_workers,
+            },
+        )
+        ds_from_iter.save_to_disk(str(prepared_ds_path))
+    else:
+        os.makedirs(prepared_ds_path, exist_ok=True)
+        dataset.save_to_disk(str(prepared_ds_path))
+    if cfg.push_dataset_to_hub:
+        LOG.info(
+            "Pushing merged prepared dataset to Huggingface hub at "
+            f"{cfg.push_dataset_to_hub} (version {dataset_hash})...",
+            main_process_only=False,
+        )
+        dataset.push_to_hub(
+            cfg.push_dataset_to_hub,
+            dataset_hash,
+            private=True,
+        )
+
+
 def load_preprocessed_dataset(cfg: DictDefault, dataset_hash: str) -> Dataset | None:
     """Load preprocessed dataset from disk if available.
 
@@ -409,13 +458,39 @@ def load_preprocessed_dataset(cfg: DictDefault, dataset_hash: str) -> Dataset |
         and not cfg.skip_prepare_dataset
         and not cfg.is_preprocess
     ):
-        LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
+        LOG.info(
+            f"Loading prepared dataset from disk at {prepared_ds_path}...",
+            main_process_only=False,
+        )
         return load_from_disk(str(prepared_ds_path))
 
-    LOG.info(f"Unable to find prepared dataset in {prepared_ds_path}")
+    LOG.info(
+        f"Unable to find prepared dataset in {prepared_ds_path}",
+        main_process_only=False,
+    )
     return None
 
 
+def try_load_from_hub(
+    cfg: DictDefault, dataset_hash: str, split: str
+) -> Dataset | None:
+    """Try to load the prepared dataset from HuggingFace Hub."""
+    try:
+        LOG.info(
+            "Attempting to load prepared dataset from HuggingFace Hub at "
+            f"{cfg.push_dataset_to_hub} (version {dataset_hash})..."
+        )
+        dataset = load_dataset(
+            cfg.push_dataset_to_hub,
+            dataset_hash,
+            token=cfg.hf_use_auth_token,
+        )
+        return dataset[split]
+    except Exception:  # pylint: disable=broad-except  # nosec
+        LOG.info("Unable to find prepared dataset in HuggingFace Hub")
+        return None
+
+
 def generate_dataset_hash_from_config(
     cfg: DictDefault, cfg_datasets: list, tokenizer_name: str
 ) -> str:
@@ -451,13 +526,13 @@ def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
     if len(datasets) == 1:
         return datasets[0]
 
-    LOG.info("Merging datasets")
+    LOG.info("Merging datasets...")
     merged_dataset = concatenate_datasets(datasets)
 
     if cfg.shuffle_merged_datasets:
-        LOG.debug("Shuffle merged datasets")
+        LOG.debug("Shuffling merged datasets...")
         merged_dataset = merged_dataset.shuffle(seed=cfg.seed)
     else:
-        LOG.debug("NOT shuffling merged datasets")
+        LOG.debug("Not shuffling merged datasets.")
 
     return merged_dataset
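One detail worth noting in `save_preprocessed_dataset` and `try_load_from_hub`: the dataset hash is passed as the second positional argument to `push_to_hub` and `load_dataset`, i.e. as the Hub config name, so several preprocessing variants can coexist in one dataset repo. The sketch below shows that round trip; the repo id and hash value are purely illustrative, not taken from the commit.

```python
# Illustrative only: repo id and hash are made up; in axolotl they come from
# cfg.push_dataset_to_hub and generate_dataset_hash_from_config.
from datasets import Dataset, load_dataset

ds = Dataset.from_dict({"input_ids": [[1, 2, 3]], "labels": [[1, 2, 3]]})

repo_id = "my-org/prepared-datasets"   # hypothetical cfg.push_dataset_to_hub
dataset_hash = "d41d8cd98f00b204"      # hypothetical preprocessing hash

# The hash becomes the dataset's config name on the Hub...
ds.push_to_hub(repo_id, dataset_hash, private=True)

# ...so the exact same preprocessing variant can be fetched back later.
reloaded = load_dataset(repo_id, dataset_hash, split="train")
```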
