diff --git a/tests/v1/tpu/test_sampler.py b/tests/v1/tpu/test_sampler.py
index 74ad81406b4..50d40aa9dec 100644
--- a/tests/v1/tpu/test_sampler.py
+++ b/tests/v1/tpu/test_sampler.py
@@ -42,7 +42,7 @@ def test_sampler_different(model_name: str):
     sampling_params = SamplingParams(temperature=0.3, seed=42)
     output2 = llm.generate(prompts, sampling_params)

-    # Batch-case with TopK
+    # Batch-case with TopK/P
     for B in [4, 16]:
         p = prompts * B
         sampling_params = [
@@ -51,9 +51,10 @@ def test_sampler_different(model_name: str):
                 min_p=0.8,
                 max_tokens=64,
                 # Vary number of ks
-                top_k=random.randint(4, 12)) for _ in range(B)
+                top_k=random.randint(4, 12),
+                top_p=random.random()) for _ in range(B)
         ]
-        # Make sure first two reqs have the same K
+        # Make sure first two reqs have the same K/P
         sampling_params[0] = sampling_params[1]
         output = llm.generate(p, sampling_params)
         assert output[0].outputs[0].text == output[1].outputs[0].text
diff --git a/vllm/v1/sample/tpu/metadata.py b/vllm/v1/sample/tpu/metadata.py
index 917d8baf60b..d4ea8c2dee0 100644
--- a/vllm/v1/sample/tpu/metadata.py
+++ b/vllm/v1/sample/tpu/metadata.py
@@ -11,7 +11,7 @@
     min_p=0.0,
     # strictly disabled for now
     top_k=0,
-    # top_p=0.0,
+    top_p=1.0,
     # frequency_penalties=0.0,
     # presence_penalties=0.0,
     # repetition_penalties=0.0,
@@ -26,11 +26,9 @@ class TPUSupportedSamplingMetadata:
     temperature: torch.Tensor = None

     min_p: torch.Tensor = None
-    # Still too slow on forward_native!
     top_k: torch.Tensor = None
     top_p: torch.Tensor = None

-    # Greedy sampling flag for compiling single xla graph.
     all_greedy: bool = True

     # unsupported, you need to return an extra tensor of static size BxV
@@ -103,9 +101,8 @@ def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor:
                    DEFAULT_SAMPLING_PARAMS["min_p"])
         fill_slice(input_batch.top_k_cpu_tensor,
                    DEFAULT_SAMPLING_PARAMS["top_k"])
-        # TODO Temporarily disabled until sampling options are enabled
-        # fill_slice(input_batch.top_p_cpu_tensor,
-        #            DEFAULT_SAMPLING_PARAMS["top_p"])
+        fill_slice(input_batch.top_p_cpu_tensor,
+                   DEFAULT_SAMPLING_PARAMS["top_p"])

         # Slice persistent device tensors to a fixed pre-compiled padded shape.
         return cls(
@@ -113,7 +110,8 @@ def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor:
             to(xla_device),
             all_greedy=input_batch.all_greedy,
             # TODO enable more and avoid returning None values
-            top_p=None,  # input_batch.top_p[:padded_num_reqs],
+            top_p=input_batch.top_p_cpu_tensor[:padded_num_reqs].to(
+                xla_device),
             top_k=input_batch.top_k_cpu_tensor[:padded_num_reqs].to(
                 xla_device),
             min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to(
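
For context on the change: the new `top_p=1.0` default is what keeps the "disabled" case a no-op, since top-p (nucleus) filtering with p = 1 retains the full distribution; and slicing the CPU tensors to `padded_num_reqs` before the `.to(xla_device)` copy keeps the device tensors at the fixed pre-compiled padded shape, avoiding XLA recompilation. Below is a minimal sketch of the per-request filtering this diff enables, assuming `[B, V]` logits and a `[B]` top-p vector; this is an illustration of the technique, not vLLM's actual TPU kernel.

```python
# Illustrative sketch of per-request top-p (nucleus) filtering;
# not the actual vLLM TPU implementation.
import torch


def apply_top_p(logits: torch.Tensor, top_p: torch.Tensor) -> torch.Tensor:
    """Mask out tokens outside each row's nucleus.

    logits: [B, V]; top_p: [B] with values in (0, 1]. top_p=1.0 is a
    no-op, which is why it works as the "disabled" default in the diff.
    """
    # Sort ascending so the lowest-probability tokens come first.
    sorted_logits, sorted_idx = logits.sort(dim=-1, descending=False)
    cum_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Tail tokens whose cumulative mass stays <= 1 - p fall outside the
    # nucleus and are masked to -inf.
    mask = cum_probs <= (1.0 - top_p.unsqueeze(-1))
    sorted_logits = sorted_logits.masked_fill(mask, float("-inf"))
    # Scatter the (masked) values back to their original vocab positions.
    return torch.empty_like(logits).scatter_(-1, sorted_idx, sorted_logits)


logits = torch.randn(2, 8)
out = apply_top_p(logits, torch.tensor([1.0, 0.9]))
assert torch.equal(out[0], logits[0])  # p=1.0 leaves the row untouched
```

Sorting ascending turns the nucleus test into a simple prefix check on the cumulative tail mass, and the same sort/scatter structure composes naturally with the per-request top-k masking that was already enabled.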