diff --git a/examples/hf_gpt2/hf_gpt2.py b/examples/hf_gpt2/hf_gpt2.py
index 0ecf588..fc70df8 100644
--- a/examples/hf_gpt2/hf_gpt2.py
+++ b/examples/hf_gpt2/hf_gpt2.py
@@ -15,7 +15,8 @@
 from energonai.nn import VocabParallelEmbedding1D
 from torch.nn import Embedding
 from energonai.utils import get_current_device, is_using_pp
-from energonai.utils.checkpointing_hf_gpt2 import load_checkpoint
+from energonai.utils.checkpointing import load_checkpoint
+from energonai.utils.checkpointing_hf_gpt2 import processing_HF_GPT


 __all__ = [
@@ -479,7 +480,7 @@ def _create_gpt_pipeline_model(depth=48, num_chunks=1, layer_partitions=None, **
         assert "checkpoint_path" in model_kwargs.keys(), "You have to specify a file path to use checkpoint loading"
         print(model_kwargs["checkpoint_path"])
         assert os.path.exists(model_kwargs["checkpoint_path"]), "Checkpoint file not found"
-        load_checkpoint(model_kwargs["checkpoint_path"], model, **model_kwargs)
+        load_checkpoint(model_kwargs["checkpoint_path"], model, preprocess_fn=processing_HF_GPT, **model_kwargs)

     logger.info(f'Rank{rank}/{pipeline_rank} model size = {numel * 2 / 1e9} GB')
     return model
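
For context: the change swaps the GPT-2-specific `load_checkpoint` for the generic one in `energonai.utils.checkpointing`, and passes the HuggingFace-specific state-dict handling in as a `preprocess_fn` hook (`processing_HF_GPT`). The sketch below illustrates the general shape such a hook can take; the function name `preprocess_sketch` and the key-renaming rules in it are illustrative assumptions, not the actual logic of `processing_HF_GPT`.

```python
from collections import OrderedDict


def preprocess_sketch(state_dict):
    """Illustrative preprocess_fn: remap HuggingFace GPT-2 parameter names
    onto the names the target model expects, before load_checkpoint
    partitions and copies the tensors. The renaming rules are assumptions."""
    processed = OrderedDict()
    for name, param in state_dict.items():
        # Strip the HF top-level module prefix (assumed key layout:
        # 'transformer.h.0.attn.c_attn.weight').
        if name.startswith('transformer.'):
            name = name[len('transformer.'):]
        # Rename the HF block container to a hypothetical target name.
        if name.startswith('h.'):
            name = 'blocks.' + name[len('h.'):]
        processed[name] = param
    return processed


# Mirroring the new call site in the diff:
# load_checkpoint(model_kwargs["checkpoint_path"], model,
#                 preprocess_fn=preprocess_sketch, **model_kwargs)
```

Keeping the HF-specific remapping behind a hook lets `load_checkpoint` stay model-agnostic, with each example supplying its own translation function.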