Skip to content

Commit 8c91a32

Browse files
committed
undo test changes
1 parent 206cc88 commit 8c91a32

File tree

1 file changed

+8
-35
lines changed

1 file changed

+8
-35
lines changed

tests/recipes/test_lora_finetune_single_device.py

Lines changed: 8 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -30,18 +30,12 @@
3030

3131

3232
class TestLoRAFinetuneSingleDeviceRecipe:
33-
def _get_test_config_overrides(
34-
self,
35-
device: str = "cpu",
36-
enable_ac: bool = False,
37-
dtype_str: str = "fp32",
38-
epochs: int = 2,
39-
):
33+
def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2):
4034
return [
4135
"batch_size=8",
42-
f"device={device}",
36+
"device=cpu",
4337
f"dtype={dtype_str}",
44-
f"enable_activation_checkpointing={enable_ac}",
38+
"enable_activation_checkpointing=False",
4539
"dataset.train_on_input=False",
4640
"seed=9",
4741
f"epochs={epochs}",
@@ -67,24 +61,13 @@ def _fetch_qlora_expected_loss_values(self, dtype):
6761
@pytest.mark.integration_test
6862
@pytest.mark.parametrize("compile", [True, False])
6963
@pytest.mark.parametrize(
70-
"config, model_type, ckpt_type, enable_activation_checkpointing, enable_activation_offloading",
64+
"config, model_type, ckpt_type",
7165
[
72-
("llama2/7B_lora_single_device", "llama2", "meta", False, False),
73-
("llama2/7B_lora_single_device", "llama2", "meta", True, True),
74-
("llama3/8B_lora_single_device", "llama3", "tune", True, False),
66+
("llama2/7B_lora_single_device", "llama2", "meta"),
67+
("llama3/8B_lora_single_device", "llama3", "tune"),
7568
],
7669
)
77-
def test_loss(
78-
self,
79-
compile,
80-
config,
81-
model_type,
82-
ckpt_type,
83-
enable_activation_checkpointing,
84-
enable_activation_offloading,
85-
tmpdir,
86-
monkeypatch,
87-
):
70+
def test_loss(self, compile, config, model_type, ckpt_type, tmpdir, monkeypatch):
8871
ckpt_component = CKPT_COMPONENT_MAP[ckpt_type]
8972
ckpt = model_type + "_" + ckpt_type
9073
ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
@@ -105,21 +88,11 @@ def test_loss(
10588
tokenizer.prompt_template=null \
10689
metric_logger.filename={log_file} \
10790
compile={compile} \
108-
enable_activation_checkpointing={enable_activation_checkpointing} \
109-
enable_activation_offloading={enable_activation_offloading} \
11091
""".split()
11192

11293
model_config = MODEL_TEST_CONFIGS[model_type + "_lora"]
11394

114-
cmd = (
115-
cmd
116-
+ self._get_test_config_overrides(
117-
device="cuda",
118-
enable_ac=enable_activation_checkpointing,
119-
dtype_str="fp32",
120-
)
121-
+ model_config
122-
)
95+
cmd = cmd + self._get_test_config_overrides(dtype_str="fp32") + model_config
12396
monkeypatch.setattr(sys, "argv", cmd)
12497
with pytest.raises(SystemExit, match=""):
12598
runpy.run_path(TUNE_PATH, run_name="__main__")

0 commit comments

Comments
 (0)