@@ -435,8 +435,8 @@ def _save_checkpoint(self, state: State, logger: Logger):
 
         cpu_offload = True
 
-        # Add a dtensor->cpu tensor hook to avoid CUDA OOM
-        def dtensor_to_tensor_hook(
+        # Add hook to move tensors to cpu to avoid CUDA OOM
+        def tensor_hook(
             module: nn.Module,
             state_dict: Dict[str, Any],
             prefix: str,
@@ -449,20 +449,23 @@ def dtensor_to_tensor_hook(
                     dtensor_fqns.append(fqn)
                     tensor = tensor.full_tensor()  # type: ignore
                     if dist.get_global_rank() == 0:
+                        # Offload any DTensors to CPU
                         if cpu_offload:
                             tensor = tensor.cpu()
                         state_dict[fqn] = tensor
+                    else:
+                        state_dict[fqn] = None
+                # Convert the state dict to the requested precision
+                if isinstance(tensor, torch.Tensor):
+                    state_dict[fqn] = tensor.to(dtype=self.dtype)
+                del tensor
             if dist.get_global_rank() != 0:
-                for fqn in dtensor_fqns:
-                    del state_dict[fqn]
+                state_dict = {}
             return state_dict
 
         hooks = []
         for _, module in state_dict_model.named_modules():
-            if isinstance(module, FSDP):
-                hooks.append(
-                    module._register_state_dict_hook(dtensor_to_tensor_hook),
-                )
+            hooks.append(module._register_state_dict_hook(tensor_hook),)
 
         state_dict = get_model_state_dict(
             state_dict_model,
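
The registration above uses PyTorch's private `nn.Module._register_state_dict_hook`: the hook fires whenever `state_dict()` is materialized, may edit entries in place or return a replacement dict, and registration returns a handle whose `remove()` detaches it (which this PR calls once saving is done). Below is a minimal sketch of that mechanism, assuming a recent PyTorch; `fp16_hook` is an illustrative stand-in, not the PR's `tensor_hook`:

    import torch
    from torch import nn

    def fp16_hook(module, state_dict, prefix, *args):
        # Rewrite entries as the state dict is built; returning the dict
        # hands back the (possibly replaced) mapping, as tensor_hook does.
        for key, value in state_dict.items():
            if isinstance(value, torch.Tensor) and value.is_floating_point():
                state_dict[key] = value.to(dtype=torch.float16)
        return state_dict

    model = nn.Linear(4, 4)
    handle = model._register_state_dict_hook(fp16_hook)  # private API
    sd = model.state_dict()  # hook fires here
    assert sd['weight'].dtype == torch.float16
    handle.remove()  # detach afterwards, mirroring hook.remove() below
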
@@ -474,11 +477,6 @@ def dtensor_to_tensor_hook(
         for hook in hooks:
             hook.remove()
 
-        # Convert the state dict to the requested precision
-        for k, v in state_dict.items():
-            if isinstance(v, torch.Tensor):
-                state_dict[k] = v.to(dtype=self.dtype)
-
         new_model_instance = None  # Need this for pyright because variable could be unbound
 
         if dist.get_global_rank() == 0:
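
The loop deleted above cast the state dict to `self.dtype` only after it had been fully gathered, so the full-precision and downcast copies could coexist and contribute to the CUDA OOM this PR targets. Moving the cast into the hook body converts one tensor at a time and `del`s the reference immediately. A standalone sketch of that per-key pattern, with `torch.bfloat16` standing in for `self.dtype`:

    import torch

    state_dict = {'w': torch.randn(8, 8), 'b': torch.randn(8)}

    # Cast key by key: once each key is overwritten and the local name is
    # dropped, nothing keeps the full-precision copy alive, unlike a
    # second pass over the completed state dict.
    for fqn in list(state_dict):
        tensor = state_dict[fqn]
        if isinstance(tensor, torch.Tensor):
            state_dict[fqn] = tensor.to(dtype=torch.bfloat16)
        del tensor

    assert all(t.dtype == torch.bfloat16 for t in state_dict.values())
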
@@ -537,7 +535,7 @@ def dtensor_to_tensor_hook(
                 original_tokenizer.save_pretrained(temp_save_dir)
 
             # Only need to edit files for MPT because it has custom code
-            if original_model.config.model_type == 'mpt':
+            if new_model_instance.config.model_type == 'mpt':
                 log.debug('Editing MPT files for HuggingFace compatibility')
                 edit_files_for_hf_compatibility(
                     temp_save_dir,