diff --git a/llm/merge_lora_params.py b/llm/merge_lora_params.py index 50ae4a797f34..065a2585ebc0 100644 --- a/llm/merge_lora_params.py +++ b/llm/merge_lora_params.py @@ -125,7 +125,7 @@ def merge(): model = AutoModelForCausalLM.from_pretrained( lora_config.base_model_name_or_path, config=config, - low_cpu_mem_usage=True, + low_cpu_mem_usage=args.low_gpu_mem, ) lora_config.merge_weights = True model = LoRAModel.from_pretrained(model=model, lora_path=args.lora_path, lora_config=lora_config)