
Commit 596dd9d

irenedea and Saaketh Narayan authored
Do dtype conversion in torch hook to save memory (#1384)
* Do dtype conversion in torch hook to save memory

* update code comment

Co-authored-by: Saaketh Narayan <[email protected]>

---------

Co-authored-by: Saaketh Narayan <[email protected]>
1 parent 0bed4ff commit 596dd9d
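
The core idea of the change: rather than gathering the full model state dict and then converting every tensor to the target precision in a second pass, the conversion now happens inside the module state-dict hook, so each full-precision tensor can be freed as soon as its converted copy exists. Below is a minimal sketch of the pattern, not the repository's code, assuming plain torch with no FSDP or DTensor and using `torch.bfloat16` in place of the checkpointer's configured `self.dtype`:

```python
import torch
import torch.nn as nn

def downcast_hook(module, state_dict, prefix, *args):
    # Runs while the state dict is being assembled: each tensor is converted
    # and its full-precision reference dropped one at a time, instead of the
    # whole original-precision state dict being held and converted afterwards.
    for fqn in list(state_dict.keys()):
        tensor = state_dict[fqn]
        if isinstance(tensor, torch.Tensor) and tensor.is_floating_point():
            state_dict[fqn] = tensor.to(dtype=torch.bfloat16)
        del tensor
    return state_dict

model = nn.Linear(8, 8)
# _register_state_dict_hook is the same private torch API this diff uses.
handle = model._register_state_dict_hook(downcast_hook)
print(model.state_dict()['weight'].dtype)  # torch.bfloat16
handle.remove()
```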

File tree

2 files changed: +14 −14 lines changed

llmfoundry/callbacks/hf_checkpointer.py

Lines changed: 12 additions & 14 deletions
@@ -435,8 +435,8 @@ def _save_checkpoint(self, state: State, logger: Logger):
 
         cpu_offload = True
 
-        # Add a dtensor->cpu tensor hook to avoid CUDA OOM
-        def dtensor_to_tensor_hook(
+        # Add hook to move tensors to cpu to avoid CUDA OOM
+        def tensor_hook(
             module: nn.Module,
             state_dict: Dict[str, Any],
             prefix: str,
@@ -449,20 +449,23 @@ def dtensor_to_tensor_hook(
                     dtensor_fqns.append(fqn)
                     tensor = tensor.full_tensor()  # type: ignore
                     if dist.get_global_rank() == 0:
+                        # Offload any DTensors to CPU
                         if cpu_offload:
                             tensor = tensor.cpu()
                         state_dict[fqn] = tensor
+                    else:
+                        state_dict[fqn] = None
+                # Convert the state dict to the requested precision
+                if isinstance(tensor, torch.Tensor):
+                    state_dict[fqn] = tensor.to(dtype=self.dtype)
+                del tensor
             if dist.get_global_rank() != 0:
-                for fqn in dtensor_fqns:
-                    del state_dict[fqn]
+                state_dict = {}
             return state_dict
 
         hooks = []
         for _, module in state_dict_model.named_modules():
-            if isinstance(module, FSDP):
-                hooks.append(
-                    module._register_state_dict_hook(dtensor_to_tensor_hook),
-                )
+            hooks.append(module._register_state_dict_hook(tensor_hook),)
 
         state_dict = get_model_state_dict(
             state_dict_model,
@@ -474,11 +477,6 @@ def dtensor_to_tensor_hook(
         for hook in hooks:
             hook.remove()
 
-        # Convert the state dict to the requested precision
-        for k, v in state_dict.items():
-            if isinstance(v, torch.Tensor):
-                state_dict[k] = v.to(dtype=self.dtype)
-
         new_model_instance = None  # Need this for pyright because variable could be unbound
 
         if dist.get_global_rank() == 0:
@@ -537,7 +535,7 @@ def dtensor_to_tensor_hook(
                 original_tokenizer.save_pretrained(temp_save_dir)
 
             # Only need to edit files for MPT because it has custom code
-            if original_model.config.model_type == 'mpt':
+            if new_model_instance.config.model_type == 'mpt':
                 log.debug('Editing MPT files for HuggingFace compatibility')
                 edit_files_for_hf_compatibility(
                     temp_save_dir,
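
A note on the registration change in the second hunk: the hook was previously attached only to FSDP-wrapped modules, but is now registered on every submodule from `named_modules()` and removed once the state dict has been collected. Here is a self-contained sketch of that lifecycle, assuming a hypothetical toy model and plain `state_dict()` in place of `get_model_state_dict`:

```python
import torch
import torch.nn as nn

def tensor_hook(module, state_dict, prefix, *args):
    # Stand-in for the checkpointer's hook: downcast floating-point tensors
    # in place. Mutating (rather than rebuilding) the dict keeps nested hook
    # invocations consistent, since submodules share one destination dict.
    for fqn in list(state_dict.keys()):
        t = state_dict[fqn]
        if isinstance(t, torch.Tensor) and t.is_floating_point():
            state_dict[fqn] = t.to(dtype=torch.bfloat16)
        del t
    return state_dict

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4))

# Register on every submodule (as the diff now does), collect, then clean up.
hooks = [m._register_state_dict_hook(tensor_hook) for _, m in model.named_modules()]
state_dict = model.state_dict()
for hook in hooks:
    hook.remove()

assert all(v.dtype == torch.bfloat16 for v in state_dict.values())
```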

tests/a_scripts/inference/test_convert_composer_to_hf.py

Lines changed: 2 additions & 0 deletions
@@ -383,6 +383,8 @@ def test_huggingface_conversion_callback_interval(
     mlflow_logger_mock.model_registry_prefix = ''
     mlflow_logger_mock._experiment_id = 'mlflow-experiment-id'
     mlflow_logger_mock._run_id = 'mlflow-run-id'
+    mlflow_logger_mock._enabled = True
+    mlflow_logger_mock.run_url = 'fake-url'
     checkpointer_callback.transform_model_pre_registration = MagicMock(
         wraps=checkpointer_callback.transform_model_pre_registration,
     )

0 commit comments