From f6cca8559d06fe0d248912c239ff6049abf9a748 Mon Sep 17 00:00:00 2001
From: ver217
Date: Mon, 22 Aug 2022 16:28:59 +0800
Subject: [PATCH] fix hf gpt2 example

---
 examples/hf_gpt2/hf_gpt2.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/examples/hf_gpt2/hf_gpt2.py b/examples/hf_gpt2/hf_gpt2.py
index 0ecf588..fc70df8 100644
--- a/examples/hf_gpt2/hf_gpt2.py
+++ b/examples/hf_gpt2/hf_gpt2.py
@@ -15,7 +15,8 @@
 from energonai.nn import VocabParallelEmbedding1D
 from torch.nn import Embedding
 from energonai.utils import get_current_device, is_using_pp
-from energonai.utils.checkpointing_hf_gpt2 import load_checkpoint
+from energonai.utils.checkpointing import load_checkpoint
+from energonai.utils.checkpointing_hf_gpt2 import processing_HF_GPT
 
 
 __all__ = [
@@ -479,7 +480,7 @@ def _create_gpt_pipeline_model(depth=48, num_chunks=1, layer_partitions=None, **
     assert "checkpoint_path" in model_kwargs.keys(), "You have to specify a file path to use checkpoint loading"
     print(model_kwargs["checkpoint_path"])
     assert os.path.exists(model_kwargs["checkpoint_path"]), "Checkpoint file not found"
-    load_checkpoint(model_kwargs["checkpoint_path"], model, **model_kwargs)
+    load_checkpoint(model_kwargs["checkpoint_path"], model, preprocess_fn=processing_HF_GPT, **model_kwargs)
     logger.info(f'Rank{rank}/{pipeline_rank} model size = {numel * 2 / 1e9} GB')
     return model
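
Note on the fix: the example previously imported a GPT-2-specific load_checkpoint from energonai.utils.checkpointing_hf_gpt2; the patch switches it to the generic loader in energonai.utils.checkpointing and passes the Hugging Face key mapping (processing_HF_GPT) as a preprocess_fn hook instead. Below is a minimal sketch of that hook pattern, assuming the loader applies preprocess_fn to the raw state dict before calling load_state_dict; load_checkpoint_sketch and strip_hf_prefix are hypothetical illustrations, not EnergonAI's actual implementation.

import os

import torch


def load_checkpoint_sketch(checkpoint_path, model, preprocess_fn=None, strict=False, **kwargs):
    # Sketch of a generic checkpoint loader with a preprocessing hook
    # (assumed behaviour, not the real energonai.utils.checkpointing.load_checkpoint).
    assert os.path.exists(checkpoint_path), "Checkpoint file not found"
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    if preprocess_fn is not None:
        # e.g. processing_HF_GPT in the patch: rename/reshape Hugging Face GPT-2
        # keys so they match the target module before loading.
        state_dict = preprocess_fn(state_dict)
    model.load_state_dict(state_dict, strict=strict)
    return model


def strip_hf_prefix(state_dict):
    # Hypothetical preprocess_fn showing the kind of remapping involved:
    # drop the leading "transformer." prefix from Hugging Face GPT-2 keys.
    return {k.replace("transformer.", "", 1): v for k, v in state_dict.items()}

With the generic loader, the patched call site is simply load_checkpoint(model_kwargs["checkpoint_path"], model, preprocess_fn=processing_HF_GPT, **model_kwargs), which keeps the model-specific checkpoint handling inside the preprocessing function.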