Add HuggingFaceCheckpointer option for only registering final checkpoint #1516

Merged 16 commits on Sep 12, 2024
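This PR adds a final_register_only flag to HuggingFaceCheckpointer. As a minimal usage sketch of how the new option might be enabled when constructing the callback directly; save_interval is assumed from the existing callback signature (it is not shown in this diff), and the folder and registry name are placeholders:

```python
from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer

# Placeholder values; only final_register_only is new in this PR.
checkpointer = HuggingFaceCheckpointer(
    save_folder='s3://my-bucket/hf-checkpoints',  # hypothetical URI
    save_interval='1ep',  # assumed pre-existing parameter
    mlflow_registered_model_name='my-registered-model',  # hypothetical registry name
    # On the final batch, register to MLflow only and skip uploading the
    # HuggingFace checkpoint; if registration fails, the callback falls back
    # to saving the checkpoint.
    final_register_only=True,
)
```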
Changes from 4 commits
136 changes: 92 additions & 44 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -161,6 +161,9 @@ class HuggingFaceCheckpointer(Callback):
keys ``input_example`` and ``signature``.
flatten_imports (Sequence[str]): A sequence of import prefixes that will
be flattened when editing MPT files.
final_register_only (bool): If true, only register the model in the MLFlow
registry on the last batch and do not save the HuggingFace checkpoint. If
registration fails, we will fall back to saving the HuggingFace checkpoint.
"""

def __init__(
@@ -173,6 +176,7 @@ def __init__(
mlflow_registered_model_name: Optional[str] = None,
mlflow_logging_config: Optional[dict] = None,
flatten_imports: Sequence[str] = ('llmfoundry',),
final_register_only: bool = False,
):
_, _, self.save_dir_format_str = parse_uri(save_folder)
self.overwrite = overwrite
@@ -184,6 +188,7 @@ def __init__(
}[precision]
self.flatten_imports = flatten_imports
self.using_peft = False
self.final_register_only = final_register_only

# mlflow config setup
self.mlflow_registered_model_name = mlflow_registered_model_name
@@ -249,7 +254,7 @@ def __init__(
self.last_checkpoint_batch: Optional[Time] = None
self.mlflow_loggers = []

self.child_processes: list[SpawnProcess] = []
self.register_processes: list[SpawnProcess] = []
# Temporary save directory used by child_processes.
self.temp_save_dir = None

@@ -259,7 +264,18 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
state,
event,
) and self.last_checkpoint_batch != state.timestamp.batch:
self._save_checkpoint(state, logger)
is_last_batch = self._is_last_batch(state)
self._save_checkpoint(
state,
logger,
register_to_mflow=(
self.mlflow_registered_model_name is not None and
is_last_batch
),
upload_to_save_folder=not (
self.final_register_only and is_last_batch
),
)
elif event == Event.INIT:
if not isinstance(state.model, HuggingFaceModel):
raise ValueError(
@@ -300,14 +316,26 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
# Wait for all child processes spawned by the callback to finish.
timeout = 3600
wait_start = time.time()
while not self._all_child_processes_done():
while not self._all_register_processes_done():
wait_time = time.time() - wait_start
if wait_time > timeout:
raise TimeoutError(
f'Waited {wait_time} seconds for child processes to complete. Exceeded timeout of {timeout} seconds.',
)
time.sleep(2)

if self._any_register_processes_error(
) and self.final_register_only:
log.error(
'An error occurred in one or more registration processes. Falling back to saving the HuggingFace checkpoint.',
)
self._save_checkpoint(
state,
logger,
upload_to_save_folder=True,
register_to_mflow=False,
)

# Clean up temporary save directory; all processes are done with it.
if self.temp_save_dir is not None:
shutil.rmtree(self.temp_save_dir)
@@ -339,12 +367,23 @@ def _is_last_batch(self, state: State):

return False

def _all_child_processes_done(self) -> bool:
not_done = any(process.is_alive() for process in self.child_processes)
def _all_register_processes_done(self) -> bool:
not_done = any(
process.is_alive() for process in self.register_processes
)
x = torch.tensor(1 if not_done else 0).to(device='cuda')
dist.all_reduce(x, reduce_operation='MAX')
return x.item() == 0

def _any_register_processes_error(self) -> bool:
has_errors = any(
process.exitcode is not None and process.exitcode != 0
for process in self.register_processes
)
x = torch.tensor(1 if has_errors else 0).to(device='cuda')
dist.all_reduce(x, reduce_operation='MAX')
return x.item() == 1

def transform_model_and_tokenizer(
self,
model: PreTrainedModel,
@@ -412,7 +451,13 @@ def transform_model_pre_registration(
"""
return model

def _save_checkpoint(self, state: State, logger: Logger):
def _save_checkpoint(
self,
state: State,
logger: Logger,
upload_to_save_folder: bool,
register_to_mflow: bool,
):
del logger # unused

self.last_checkpoint_batch = state.timestamp.batch
@@ -548,50 +593,53 @@ def tensor_hook(
].base_model_name_or_path = self.pretrained_model_name

log.debug('Saving Hugging Face checkpoint to disk')
# This context manager casts the TE extra state in io.BytesIO format to tensor format
# Needed for proper hf ckpt saving.
context_manager = te.onnx_export(
True,
) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext(
)
with context_manager:
new_model_instance.save_pretrained(temp_save_dir)
if original_tokenizer is not None:
assert isinstance(
original_tokenizer,
PreTrainedTokenizerBase,
)
original_tokenizer.save_pretrained(temp_save_dir)

# Only need to edit files for MPT because it has custom code
if new_model_instance.config.model_type == 'mpt':
log.debug('Editing MPT files for HuggingFace compatibility')
edit_files_for_hf_compatibility(
temp_save_dir,
self.flatten_imports,
)

if self.remote_ud is not None:
for filename in os.listdir(temp_save_dir):
remote_file_name = os.path.join(save_dir, filename)
remote_file_uri = self.remote_ud.remote_backend.get_uri(
remote_file_name,
)
log.info(
f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}',
if upload_to_save_folder:
# This context manager casts the TE extra state in io.BytesIO format to tensor format
# Needed for proper hf ckpt saving.
context_manager = te.onnx_export(
True,
) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext(
)
with context_manager:
new_model_instance.save_pretrained(temp_save_dir)
if original_tokenizer is not None:
assert isinstance(
original_tokenizer,
PreTrainedTokenizerBase,
)
self.remote_ud.upload_file(
state=state,
remote_file_name=remote_file_name,
file_path=Path(os.path.join(temp_save_dir, filename)),
overwrite=self.overwrite,
original_tokenizer.save_pretrained(temp_save_dir)

# Only need to edit files for MPT because it has custom code
if new_model_instance.config.model_type == 'mpt':
log.debug('Editing MPT files for HuggingFace compatibility')
edit_files_for_hf_compatibility(
temp_save_dir,
self.flatten_imports,
)

if self.remote_ud is not None:
for filename in os.listdir(temp_save_dir):
remote_file_name = os.path.join(save_dir, filename)
remote_file_uri = self.remote_ud.remote_backend.get_uri(
remote_file_name,
)
log.info(
f'Uploading HuggingFace formatted checkpoint to {remote_file_uri}',
)
self.remote_ud.upload_file(
state=state,
remote_file_name=remote_file_name,
file_path=Path(
os.path.join(temp_save_dir, filename),
),
overwrite=self.overwrite,
)

dist.barrier()

if dist.get_global_rank() == 0:
if self.mlflow_registered_model_name and self._is_last_batch(state):

if register_to_mflow:
new_model_instance = self.transform_model_pre_registration(
new_model_instance,
)
@@ -680,7 +728,7 @@ def tensor_hook(
# Restore the monitor process.
if monitor_process is not None:
mlflow_logger.monitor_process = monitor_process # type: ignore
self.child_processes.append(process)
self.register_processes.append(process)

# Save the temporary directory to be cleaned up later.
if use_temp_dir:
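The _all_register_processes_done and _any_register_processes_error helpers above follow the same cross-rank agreement pattern: each rank computes a local flag about the registration processes it spawned, then all ranks take a MAX all-reduce so that a failure (or an unfinished process) anywhere is seen everywhere. A standalone sketch of the error check using plain torch.distributed, assuming an initialized process group and a CUDA device; the callback itself goes through Composer's dist wrapper, and procs here is a hypothetical list of spawned processes:

```python
import torch
import torch.distributed as dist
from multiprocessing.context import SpawnProcess


def any_rank_has_failed_process(procs: list[SpawnProcess]) -> bool:
    """True on every rank if any rank owns a process that exited with a nonzero code."""
    local_failed = any(
        p.exitcode is not None and p.exitcode != 0 for p in procs
    )
    flag = torch.tensor(1 if local_failed else 0, device='cuda')
    # MAX all-reduce: the flag becomes 1 everywhere if it is 1 anywhere,
    # so every rank takes the same fallback path.
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return flag.item() == 1
```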
7 changes: 6 additions & 1 deletion llmfoundry/command_utils/train.py
@@ -584,7 +584,12 @@ def train(cfg: DictConfig) -> Trainer:
)

hf_checkpointer_callback = hf_checkpointer_callbacks[0]
hf_checkpointer_callback._save_checkpoint(trainer.state, trainer.logger)
hf_checkpointer_callback._save_checkpoint(
trainer.state,
trainer.logger,
upload_to_save_folder=True,
register_to_mflow=False,
)
return trainer

if train_cfg.only_composer_checkpoint:
1 change: 1 addition & 0 deletions tests/a_scripts/inference/test_convert_composer_to_hf.py
@@ -317,6 +317,7 @@ class MockSpawnProcess:
def __init__(self, target: Callable, kwargs: dict[str, Any]):
self.target = target
self.kwargs = kwargs
self.exitcode = 0

def start(self):
self.target(**self.kwargs)
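With exitcode set to 0, the mock covers both pieces of the SpawnProcess interface the callback now inspects: exitcode for _any_register_processes_error and is_alive() for _all_register_processes_done. A sketch of the full mock under that assumption (the is_alive body is outside this hunk and is assumed here):

```python
from typing import Any, Callable


class MockSpawnProcess:
    """Runs the target synchronously, then looks like a finished, successful process."""

    def __init__(self, target: Callable, kwargs: dict[str, Any]):
        self.target = target
        self.kwargs = kwargs
        self.exitcode = 0  # read as success by _any_register_processes_error

    def start(self) -> None:
        self.target(**self.kwargs)

    def is_alive(self) -> bool:
        return False  # assumed: read as done by _all_register_processes_done
```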