 from torchtune import config, generation, training, utils
 from torchtune.config._utils import _get_component_from_path
 from torchtune.data import ChatFormat, InstructTemplate, Message
+from torchtune.training import FullModelTorchTuneCheckpointer

 logger = utils.get_logger("DEBUG")

@@ -44,12 +45,26 @@ def __init__(self, cfg: DictConfig) -> None:

     def setup(self, cfg: DictConfig) -> None:
         checkpointer = config.instantiate(cfg.checkpointer)
+
+        if self._quantization_mode is not None:
+            if not isinstance(checkpointer, FullModelTorchTuneCheckpointer):
+                raise ValueError(
+                    "Quantization is only supported for models quantized and saved with the "
+                    "FullModelTorchTuneCheckpointer - please ensure you have quantized your "
+                    "model and are using the quantized weights!"
+                )
+            if "qat" in self._quantization_mode:
+                raise ValueError(
+                    "You have specified a quantizer with 'QAT' - "
+                    "QAT quantizers should only be used during quantization aware training "
+                    "and when quantizing models. Please use the corresponding post-training "
+                    "quantizer e.g. Int8DynActInt4WeightQuantizer for Int8DynActInt4WeightQATQuantizer."
+                )
+
         if self._quantization_mode is None:
             ckpt_dict = checkpointer.load_checkpoint()
         else:
             # weights_only needs to be False when loading a quantized model
-            # currently loading a quantized model is only supported with the
-            # FullModelTorchTuneCheckpointer
             ckpt_dict = checkpointer.load_checkpoint(weights_only=False)

         self._model = self._setup_model(
@@ -69,8 +84,11 @@ def _setup_model(
         if self._quantization_mode is not None:
             model = self._quantizer.quantize(model)
             model = model.to(device=self._device, dtype=self._dtype)
-
-            model.load_state_dict(model_state_dict)
+            for k, v in model_state_dict.items():
+                model_state_dict[k] = v.to(self._device)
+            model.load_state_dict(model_state_dict, assign=True)
+        else:
+            model.load_state_dict(model_state_dict)

         # Validate model was loaded in with the expected dtype.
         training.validate_expected_param_dtype(
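
A note on the assign=True change in _setup_model: with the default assign=False, load_state_dict copies checkpoint values into the module's existing parameters and preserves those parameters' properties, whereas assign=True keeps the properties of the tensors in the state dict by swapping them in directly, which is what a quantized checkpoint needs. It also explains the explicit device loop: tensors swapped in by assignment are not covered by the earlier model.to(...) call, so they are moved to self._device beforehand. Below is a minimal standalone sketch of these semantics, using a toy nn.Linear and a dtype change as a stand-in for quantized tensors (illustrative only, not torchtune code):

import torch
import torch.nn as nn

model = nn.Linear(4, 4)  # parameters start out in float32

# A checkpoint whose tensors differ from the freshly initialized parameters,
# the way a quantized state dict differs from the un-quantized model.
state_dict = {k: v.to(torch.bfloat16) for k, v in model.state_dict().items()}

# The default (assign=False) would copy values into the float32 parameters,
# silently casting them back; assign=True swaps the loaded tensors in as-is.
model.load_state_dict(state_dict, assign=True)
print(model.weight.dtype)  # torch.bfloat16

Note that load_state_dict only accepts the assign argument on newer PyTorch releases.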