Skip to content

Commit 38df2dd

Browse files
SalmanMohammadimori360
authored and committed
Fixing quantization in eval recipe (pytorch#1777)
1 parent 7122361 commit 38df2dd

File tree

2 files changed

+98
-12
lines changed

2 files changed

+98
-12
lines changed

recipes/eleuther_eval.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from torchtune.modules.tokenizers import ModelTokenizer
2929
from torchtune.modules.transforms import Transform
3030
from torchtune.recipe_interfaces import EvalRecipeInterface
31+
from torchtune.training import FullModelTorchTuneCheckpointer
3132

3233
try:
3334
import lm_eval
@@ -487,28 +488,39 @@ def setup(self, cfg: DictConfig) -> None:
487488

488489
# Load checkpoint
489490
checkpointer = config.instantiate(cfg.checkpointer)
490-
if quantization_mode is None:
491-
ckpt_dict = checkpointer.load_checkpoint()
492-
else:
493-
# weights_only needs to be False when loading a quantized model
494-
# currently loading a quantized model is only supported with the
495-
# FullModelTorchTuneCheckpointer
496-
ckpt_dict = checkpointer.load_checkpoint(weights_only=False)
497491

498492
# Initialize model
499493
with training.set_default_dtype(self.dtype), self.device:
500494
model = config.instantiate(cfg.model)
501495

502496
# Quantize model if requested
503497
if quantization_mode is not None:
498+
if not isinstance(checkpointer, FullModelTorchTuneCheckpointer):
499+
raise ValueError(
500+
"Quantization is only supported for models quantized and saved with the "
501+
"FullModelTorchTuneCheckpointer - please ensure you have quantized your "
502+
"model and are using the quantized weights!"
503+
)
504+
if "qat" in quantization_mode:
505+
raise ValueError(
506+
"You have specified a quantizer with 'QAT' - "
507+
"QAT quantizers should only be used during quantization aware training "
508+
"and when quantizing models. Please use the corresponding post-training "
509+
"quantizer e.g. Int8DynActInt4WeightQuantizer for Int8DynActInt4WeightQATQuantizer."
510+
)
504511
model = quantizer.quantize(model)
505512
model = model.to(device=self.device, dtype=self.dtype)
506-
for k, v in model_state_dict.items():
507-
model_state_dict[k] = v.to(self._device)
508-
model.load_state_dict(model_state_dict, assign=True)
513+
ckpt_dict = checkpointer.load_checkpoint(weights_only=False)[
514+
training.MODEL_KEY
515+
]
516+
for k, v in ckpt_dict.items():
517+
ckpt_dict[k] = v.to(self.device)
518+
model.load_state_dict(ckpt_dict, assign=True)
519+
else:
520+
ckpt_dict = checkpointer.load_checkpoint()[training.MODEL_KEY]
521+
model.load_state_dict(ckpt_dict)
509522

510523
# Load model weights into initialized model
511-
model.load_state_dict(ckpt_dict[training.MODEL_KEY])
512524
self.logger.info(f"Model is initialized with precision {self.dtype}.")
513525

514526
# Put model in eval mode.

tests/recipes/test_eleuther_eval.py

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import pytest
1515

1616
from tests.common import TUNE_PATH
17-
from tests.recipes.utils import llama2_test_config
17+
from tests.recipes.utils import llama2_test_config, write_hf_ckpt_config
1818
from tests.test_utils import CKPT_MODEL_PATHS
1919

2020

@@ -126,6 +126,80 @@ def test_eval_recipe_errors_without_lm_eval(self, capsys, monkeypatch, tmpdir):
126126
in printed_err
127127
)
128128

129+
@pytest.mark.integration_test
130+
def test_eval_recipe_errors_with_quantization_hf_checkpointer(
131+
self, capsys, monkeypatch, tmpdir
132+
):
133+
ckpt = "llama2_hf"
134+
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
135+
ckpt_dir = ckpt_path.parent
136+
137+
# Config file needed for model conversion.
138+
write_hf_ckpt_config(ckpt_dir)
139+
140+
cmd = f"""
141+
tune run eleuther_eval \
142+
--config eleuther_evaluation \
143+
output_dir={tmpdir} \
144+
checkpointer=torchtune.training.FullModelHFCheckpointer \
145+
checkpointer.checkpoint_dir='{ckpt_dir}' \
146+
checkpointer.checkpoint_files=[{ckpt_path}]\
147+
checkpointer.output_dir={tmpdir} \
148+
checkpointer.model_type=LLAMA2 \
149+
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
150+
tokenizer.prompt_template=null \
151+
limit=1 \
152+
dtype=fp32 \
153+
device=cpu \
154+
quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQuantizer \
155+
quantizer.groupsize=256 \
156+
""".split()
157+
158+
model_config = llama2_test_config()
159+
cmd = cmd + model_config
160+
161+
monkeypatch.setattr(sys, "argv", cmd)
162+
with pytest.raises(
163+
ValueError,
164+
match="Quantization is only supported for models quantized and saved with the "
165+
"FullModelTorchTuneCheckpointer",
166+
):
167+
runpy.run_path(TUNE_PATH, run_name="__main__")
168+
169+
@pytest.mark.integration_test
170+
def test_eval_recipe_errors_with_qat_quantizer(self, capsys, monkeypatch, tmpdir):
171+
ckpt = "llama2_tune"
172+
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
173+
ckpt_dir = ckpt_path.parent
174+
175+
cmd = f"""
176+
tune run eleuther_eval \
177+
--config eleuther_evaluation \
178+
output_dir={tmpdir} \
179+
checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \
180+
checkpointer.checkpoint_dir='{ckpt_dir}' \
181+
checkpointer.checkpoint_files=[{ckpt_path}]\
182+
checkpointer.output_dir={tmpdir} \
183+
checkpointer.model_type=LLAMA2 \
184+
tokenizer.path=/tmp/test-artifacts/tokenizer.model \
185+
tokenizer.prompt_template=null \
186+
limit=1 \
187+
dtype=fp32 \
188+
device=cpu \
189+
quantizer._component_=torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer \
190+
quantizer.groupsize=32\
191+
""".split()
192+
193+
model_config = llama2_test_config()
194+
cmd = cmd + model_config
195+
196+
monkeypatch.setattr(sys, "argv", cmd)
197+
with pytest.raises(
198+
ValueError,
199+
match="QAT quantizers should only be used during quantization aware training",
200+
):
201+
runpy.run_path(TUNE_PATH, run_name="__main__")
202+
129203
@pytest.mark.integration_test
130204
def test_eval_recipe_errors_with_generate_until_and_mc_tasks(
131205
self, caplog, capsys, monkeypatch, tmpdir

0 commit comments

Comments (0)