diff --git a/inference/huggingface/text-generation/arguments.py b/inference/huggingface/text-generation/arguments.py
index a6dade23f..23722cf63 100644
--- a/inference/huggingface/text-generation/arguments.py
+++ b/inference/huggingface/text-generation/arguments.py
@@ -18,4 +18,5 @@
 parser.add_argument("--local_rank", type=int, default=int(os.getenv("LOCAL_RANK", "0")), help="local rank")
 parser.add_argument("--world_size", type=int, default=int(os.getenv("WORLD_SIZE", "1")), help="world_size")
 parser.add_argument("--test_hybrid_engine", action='store_true', help="enable hybrid engine testing")
-parser.add_argument("--trust_remote_code", action='store_true', help="Trust remote code for hugging face models")
\ No newline at end of file
+parser.add_argument("--trust_remote_code", action='store_true', help="Trust remote code for hugging face models")
+parser.add_argument("--quantize_groups", type=int, required=False, default=0, help="number of weight quantization groups to use")
\ No newline at end of file
diff --git a/inference/huggingface/text-generation/inference-test.py b/inference/huggingface/text-generation/inference-test.py
index 0ba3b20cd..0789209e9 100644
--- a/inference/huggingface/text-generation/inference-test.py
+++ b/inference/huggingface/text-generation/inference-test.py
@@ -51,6 +51,7 @@
     replace_with_kernel_inject=args.use_kernel,
     max_tokens=args.max_tokens,
     save_mp_checkpoint_path=args.save_mp_checkpoint_path,
+    quantize_groups=args.quantize_groups,
     **ds_kwargs
 )
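
For context, a minimal end-to-end sketch of how the new flag travels from the command line into DeepSpeed inference. This is not part of the patch: the `from arguments import parser` import, the example model name, and the pipeline setup are assumptions standing in for the unchanged parts of inference-test.py; the keyword arguments to deepspeed.init_inference mirror the call shown in the diff.

# Sketch only; assumes deepspeed, transformers, and torch are installed and a GPU is available.
import torch
import deepspeed
from transformers import pipeline
from arguments import parser   # module-level parser extended by this patch (assumed import path)

args = parser.parse_args()     # e.g. invoked with: --use_kernel --quantize_groups 8

# Build a Hugging Face text-generation pipeline; the model name is only an example.
pipe = pipeline("text-generation",
                model="facebook/opt-1.3b",
                torch_dtype=torch.float16,
                trust_remote_code=args.trust_remote_code)

# Wrap the model with DeepSpeed inference; quantize_groups is the value added by this patch
# (default 0, i.e. no group count specified on the command line).
pipe.model = deepspeed.init_inference(
    pipe.model,
    dtype=torch.float16,
    replace_with_kernel_inject=args.use_kernel,
    max_tokens=args.max_tokens,
    save_mp_checkpoint_path=args.save_mp_checkpoint_path,
    quantize_groups=args.quantize_groups,
)

print(pipe("DeepSpeed is", max_new_tokens=32)[0]["generated_text"])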