
Commit d79bf29

SalmanMohammadi authored and mori360 committed
Fix quantization with generate (pytorch#1784)
1 parent ee3e703 commit d79bf29

2 files changed: +23 -5

recipes/generate.py

Lines changed: 22 additions & 4 deletions
```diff
@@ -15,6 +15,7 @@
 from torchtune import config, generation, training, utils
 from torchtune.config._utils import _get_component_from_path
 from torchtune.data import ChatFormat, InstructTemplate, Message
+from torchtune.training import FullModelTorchTuneCheckpointer
 
 logger = utils.get_logger("DEBUG")
 
```
```diff
@@ -44,12 +45,26 @@ def __init__(self, cfg: DictConfig) -> None:
 
     def setup(self, cfg: DictConfig) -> None:
         checkpointer = config.instantiate(cfg.checkpointer)
+
+        if self._quantization_mode is not None:
+            if not isinstance(checkpointer, FullModelTorchTuneCheckpointer):
+                raise ValueError(
+                    "Quantization is only supported for models quantized and saved with the "
+                    "FullModelTorchTuneCheckpointer - please ensure you have quantized your "
+                    "model and are using the quantized weights!"
+                )
+            if "qat" in self._quantization_mode:
+                raise ValueError(
+                    "You have specified a quantizer with 'QAT' - "
+                    "QAT quantizers should only be used during quantization aware training "
+                    "and when quantizing models. Please use the corresponding post-training "
+                    "quantizer e.g. Int8DynActInt4WeightQuantizer for Int8DynActInt4WeightQATQuantizer."
+                )
+
         if self._quantization_mode is None:
             ckpt_dict = checkpointer.load_checkpoint()
         else:
             # weights_only needs to be False when loading a quantized model
-            # currently loading a quantized model is only supported with the
-            # FullModelTorchTuneCheckpointer
             ckpt_dict = checkpointer.load_checkpoint(weights_only=False)
 
         self._model = self._setup_model(
```
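
The second guard points users who configured a QAT quantizer at its post-training counterpart. The correspondence named in the error message is just a suffix swap on the class name; a hypothetical helper (not part of torchtune) makes the pairing concrete:

```python
def post_training_equivalent(qat_quantizer_name: str) -> str:
    # Hypothetical helper (not part of torchtune): the pairing named in the
    # error message above is a suffix swap on the quantizer class name.
    return qat_quantizer_name.replace("QATQuantizer", "Quantizer")

# The example pairing from the error message:
assert (
    post_training_equivalent("Int8DynActInt4WeightQATQuantizer")
    == "Int8DynActInt4WeightQuantizer"
)
```
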
```diff
@@ -69,8 +84,11 @@ def _setup_model(
         if self._quantization_mode is not None:
             model = self._quantizer.quantize(model)
             model = model.to(device=self._device, dtype=self._dtype)
-
-        model.load_state_dict(model_state_dict)
+            for k, v in model_state_dict.items():
+                model_state_dict[k] = v.to(self._device)
+            model.load_state_dict(model_state_dict, assign=True)
+        else:
+            model.load_state_dict(model_state_dict)
 
         # Validate model was loaded in with the expected dtype.
         training.validate_expected_param_dtype(
```
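
The notable change in the quantized branch is `assign=True` plus the explicit device move. By default, `load_state_dict` copies checkpoint values into the module's existing parameters, which silently casts them to the parameters' dtype; with `assign=True` the checkpoint tensors replace the module's tensors outright, preserving their dtype and layout, which also means they must already live on the target device. A minimal sketch of the difference, using a plain `nn.Linear` as a stand-in for a quantized torchtune model:

```python
import torch
import torch.nn as nn

ckpt = {"weight": torch.randn(4, 4, dtype=torch.bfloat16)}

# Default load: values are copied *into* the existing fp32 parameter,
# so the checkpoint's dtype is silently cast away.
m1 = nn.Linear(4, 4, bias=False)
m1.load_state_dict(ckpt)
print(m1.weight.dtype)  # torch.float32

# assign=True (PyTorch >= 2.1): the checkpoint tensor itself becomes the
# parameter, keeping its dtype (and, for quantized checkpoints, any packed
# layout). It must already sit on the target device -- hence the
# v.to(self._device) loop in the diff above.
m2 = nn.Linear(4, 4, bias=False)
m2.load_state_dict(ckpt, assign=True)
print(m2.weight.dtype)  # torch.bfloat16
```
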

torchtune/generation/_generation.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -366,7 +366,7 @@ def generate(
         tokens, logits = custom_generate_next_token(
             model,
             input_pos=curr_input_pos,
-            x=tokens,
+            x=tokens.clone(),
             mask=curr_masks,
             temperature=temperature,
             top_k=top_k,
```
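
Passing `tokens` directly means the argument aliases the very buffer the decode loop mutates in place between steps; `clone()` hands the next-token function an independent copy each iteration. A toy demonstration of the aliasing with plain tensors (an assumption about the hazard this targets — the commit message doesn't spell it out):

```python
import torch

# Stand-in for the decode buffer that generate() updates in place each step.
tokens = torch.tensor([[7, 7, 0, 0]])

x_aliased = tokens         # what `x=tokens` passes: same underlying storage
x_cloned = tokens.clone()  # what `x=tokens.clone()` passes: independent copy

tokens[:, 2] = 42          # the in-place write the decode loop performs

print(x_aliased[0, 2].item())  # 42 -- the aliased "input" changed underneath
print(x_cloned[0, 2].item())   # 0  -- the cloned input is unaffected
```
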
