30
30
31
31
32
32
class TestLoRAFinetuneSingleDeviceRecipe :
33
- def _get_test_config_overrides (
34
- self ,
35
- device : str = "cpu" ,
36
- enable_ac : bool = False ,
37
- dtype_str : str = "fp32" ,
38
- epochs : int = 2 ,
39
- ):
33
+ def _get_test_config_overrides (self , dtype_str : str = "fp32" , epochs : int = 2 ):
40
34
return [
41
35
"batch_size=8" ,
42
- f "device={ device } " ,
36
+ "device=cpu " ,
43
37
f"dtype={ dtype_str } " ,
44
- f "enable_activation_checkpointing={ enable_ac } " ,
38
+ "enable_activation_checkpointing=False " ,
45
39
"dataset.train_on_input=False" ,
46
40
"seed=9" ,
47
41
f"epochs={ epochs } " ,
@@ -67,24 +61,13 @@ def _fetch_qlora_expected_loss_values(self, dtype):
67
61
@pytest .mark .integration_test
68
62
@pytest .mark .parametrize ("compile" , [True , False ])
69
63
@pytest .mark .parametrize (
70
- "config, model_type, ckpt_type, enable_activation_checkpointing, enable_activation_offloading " ,
64
+ "config, model_type, ckpt_type" ,
71
65
[
72
- ("llama2/7B_lora_single_device" , "llama2" , "meta" , False , False ),
73
- ("llama2/7B_lora_single_device" , "llama2" , "meta" , True , True ),
74
- ("llama3/8B_lora_single_device" , "llama3" , "tune" , True , False ),
66
+ ("llama2/7B_lora_single_device" , "llama2" , "meta" ),
67
+ ("llama3/8B_lora_single_device" , "llama3" , "tune" ),
75
68
],
76
69
)
77
- def test_loss (
78
- self ,
79
- compile ,
80
- config ,
81
- model_type ,
82
- ckpt_type ,
83
- enable_activation_checkpointing ,
84
- enable_activation_offloading ,
85
- tmpdir ,
86
- monkeypatch ,
87
- ):
70
+ def test_loss (self , compile , config , model_type , ckpt_type , tmpdir , monkeypatch ):
88
71
ckpt_component = CKPT_COMPONENT_MAP [ckpt_type ]
89
72
ckpt = model_type + "_" + ckpt_type
90
73
ckpt_path = Path (CKPT_MODEL_PATHS [ckpt ])
@@ -105,21 +88,11 @@ def test_loss(
105
88
tokenizer.prompt_template=null \
106
89
metric_logger.filename={ log_file } \
107
90
compile={ compile } \
108
- enable_activation_checkpointing={ enable_activation_checkpointing } \
109
- enable_activation_offloading={ enable_activation_offloading } \
110
91
""" .split ()
111
92
112
93
model_config = MODEL_TEST_CONFIGS [model_type + "_lora" ]
113
94
114
- cmd = (
115
- cmd
116
- + self ._get_test_config_overrides (
117
- device = "cuda" ,
118
- enable_ac = enable_activation_checkpointing ,
119
- dtype_str = "fp32" ,
120
- )
121
- + model_config
122
- )
95
+ cmd = cmd + self ._get_test_config_overrides (dtype_str = "fp32" ) + model_config
123
96
monkeypatch .setattr (sys , "argv" , cmd )
124
97
with pytest .raises (SystemExit , match = "" ):
125
98
runpy .run_path (TUNE_PATH , run_name = "__main__" )
0 commit comments