
Commit 5198042

NickLucche and joennlae authored and committed
[Frontend] Add sampling params to v1/audio/transcriptions endpoint (vllm-project#16591)
Signed-off-by: Jannis Schönleber <[email protected]>
Signed-off-by: NickLucche <[email protected]>
Co-authored-by: Jannis Schönleber <[email protected]>
Signed-off-by: Mu Huai <[email protected]>
1 parent b3010fd commit 5198042

File tree

4 files changed: +122 −11 lines changed


docs/source/serving/openai_compatible_server.md

Lines changed: 18 additions & 1 deletion
@@ -402,9 +402,26 @@ you can use the [official OpenAI Python client](https://github.com/openai/openai
 To use the Transcriptions API, please install with extra audio dependencies using `pip install vllm[audio]`.
 :::
 
+Code example: <gh-file:examples/online_serving/openai_transcription_client.py>
 <!-- TODO: api enforced limits + uploading audios -->
 
-Code example: <gh-file:examples/online_serving/openai_transcription_client.py>
+#### Extra Parameters
+
+The following [sampling parameters](#sampling-params) are supported.
+
+:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
+:language: python
+:start-after: begin-transcription-sampling-params
+:end-before: end-transcription-sampling-params
+:::
+
+The following extra parameters are supported:
+
+:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
+:language: python
+:start-after: begin-transcription-extra-params
+:end-before: end-transcription-extra-params
+:::
 
 (tokenizer-api)=
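Since the transcription endpoint consumes multipart form data (the protocol change below notes that stream options are flattened "to simplify form data"), the new sampling parameters can also be exercised without the OpenAI SDK. A minimal sketch using `requests`; the server address and audio file name are placeholders, and sending the extra params as raw form fields is an assumption based on how the request model is parsed, not documented API:

```python
# Sketch: pass the new sampling params as plain multipart form fields,
# without the OpenAI SDK. Assumes a vLLM server is already serving a
# Whisper model on localhost:8000 and that "audio.wav" exists.
import requests

with open("audio.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/v1/audio/transcriptions",
        files={"file": f},
        data={
            "model": "openai/whisper-large-v3",
            "language": "en",
            "response_format": "json",
            "temperature": 0.8,
            # vLLM-only sampling params introduced by this commit.
            "seed": 42,
            "repetition_penalty": 1.3,
            "top_k": 12,
        },
    )
print(resp.json()["text"])
```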
examples/online_serving/openai_transcription_client.py

Lines changed: 6 additions & 1 deletion
@@ -26,7 +26,12 @@ def sync_openai():
             model="openai/whisper-large-v3",
             language="en",
             response_format="json",
-            temperature=0.0)
+            temperature=0.0,
+            # Additional sampling params not provided by OpenAI API.
+            extra_body=dict(
+                seed=4419,
+                repetition_penalty=1.3,
+            ))
     print("transcription result:", transcription.text)
 
 
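The same `extra_body` passthrough works with the SDK's async client, since the openai SDK forwards `extra_body` fields into the request body alongside the standard parameters. A small illustrative variant; the server URL, API key, and file name are placeholders, not part of this commit:

```python
# Async variant of the example above; "audio.wav" and the base_url
# are assumptions for illustration.
import asyncio

from openai import AsyncOpenAI


async def main():
    client = AsyncOpenAI(base_url="http://localhost:8000/v1",
                         api_key="EMPTY")
    with open("audio.wav", "rb") as f:
        transcription = await client.audio.transcriptions.create(
            model="openai/whisper-large-v3",
            file=f,
            language="en",
            response_format="json",
            temperature=0.0,
            # Additional sampling params not provided by the OpenAI API.
            extra_body=dict(seed=4419, repetition_penalty=1.3))
    print("transcription result:", transcription.text)


asyncio.run(main())
```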
tests/entrypoints/openai/test_transcription_validation.py

Lines changed: 33 additions & 0 deletions
@@ -192,3 +192,36 @@ async def post_with_stream(*args, **kwargs):
         else:
             continuous = continuous and hasattr(chunk, 'usage')
     assert final and continuous
+
+
+@pytest.mark.asyncio
+async def test_sampling_params(mary_had_lamb):
+    """
+    Compare sampling with params and greedy sampling to assert results
+    are different when extreme sampling parameter values are picked.
+    """
+    model_name = "openai/whisper-small"
+    server_args = ["--enforce-eager"]
+    with RemoteOpenAIServer(model_name, server_args) as remote_server:
+        client = remote_server.get_async_client()
+        transcription = await client.audio.transcriptions.create(
+            model=model_name,
+            file=mary_had_lamb,
+            language="en",
+            temperature=0.8,
+            extra_body=dict(seed=42,
+                            repetition_penalty=1.9,
+                            top_k=12,
+                            top_p=0.4,
+                            min_p=0.5,
+                            frequency_penalty=1.8,
+                            presence_penalty=2.0))
+
+        greedy_transcription = await client.audio.transcriptions.create(
+            model=model_name,
+            file=mary_had_lamb,
+            language="en",
+            temperature=0.0,
+            extra_body=dict(seed=42))
+
+        assert greedy_transcription.text != transcription.text
vllm/entrypoints/openai/protocol.py

Lines changed: 65 additions & 9 deletions
@@ -1577,14 +1577,6 @@ class TranscriptionRequest(OpenAIBaseModel):
     """
 
     ## TODO (varun) : Support if set to 0, certain thresholds are met !!
-    temperature: float = Field(default=0.0)
-    """The sampling temperature, between 0 and 1.
-
-    Higher values like 0.8 will make the output more random, while lower values
-    like 0.2 will make it more focused / deterministic. If set to 0, the model
-    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
-    to automatically increase the temperature until certain thresholds are hit.
-    """
 
     timestamp_granularities: list[Literal["word", "segment"]] = Field(
         alias="timestamp_granularities[]", default=[])
@@ -1596,6 +1588,7 @@ class TranscriptionRequest(OpenAIBaseModel):
     timestamps incurs additional latency.
     """
 
+    # doc: begin-transcription-extra-params
     stream: Optional[bool] = False
     """Custom field not present in the original OpenAI definition. When set,
     it will enable output to be streamed in a similar fashion as the Chat
@@ -1604,10 +1597,51 @@ class TranscriptionRequest(OpenAIBaseModel):
     # Flattened stream option to simplify form data.
     stream_include_usage: Optional[bool] = False
     stream_continuous_usage_stats: Optional[bool] = False
+    # doc: end-transcription-extra-params
+
+    # doc: begin-transcription-sampling-params
+    temperature: float = Field(default=0.0)
+    """The sampling temperature, between 0 and 1.
+
+    Higher values like 0.8 will make the output more random, while lower values
+    like 0.2 will make it more focused / deterministic. If set to 0, the model
+    will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
+    to automatically increase the temperature until certain thresholds are hit.
+    """
+
+    top_p: Optional[float] = None
+    """Enables nucleus (top-p) sampling, where tokens are selected from the
+    smallest possible set whose cumulative probability exceeds `p`.
+    """
+
+    top_k: Optional[int] = None
+    """Limits sampling to the `k` most probable tokens at each step."""
+
+    min_p: Optional[float] = None
+    """Filters out tokens with a probability lower than `min_p`, ensuring a
+    minimum likelihood threshold during sampling.
+    """
+
+    seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
+    """The seed to use for sampling."""
+
+    frequency_penalty: Optional[float] = 0.0
+    """The frequency penalty to use for sampling."""
+
+    repetition_penalty: Optional[float] = None
+    """The repetition penalty to use for sampling."""
+
+    presence_penalty: Optional[float] = 0.0
+    """The presence penalty to use for sampling."""
+    # doc: end-transcription-sampling-params
 
     # Default sampling parameters for transcription requests.
     _DEFAULT_SAMPLING_PARAMS: dict = {
-        "temperature": 0,
+        "repetition_penalty": 1.0,
+        "temperature": 1.0,
+        "top_p": 1.0,
+        "top_k": -1,
+        "min_p": 0.0,
     }
 
     def to_sampling_params(
@@ -1619,13 +1653,35 @@ def to_sampling_params(
 
         if default_sampling_params is None:
             default_sampling_params = {}
+
         # Default parameters
         if (temperature := self.temperature) is None:
             temperature = default_sampling_params.get(
                 "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
+        if (top_p := self.top_p) is None:
+            top_p = default_sampling_params.get(
+                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
+        if (top_k := self.top_k) is None:
+            top_k = default_sampling_params.get(
+                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"])
+        if (min_p := self.min_p) is None:
+            min_p = default_sampling_params.get(
+                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"])
+
+        if (repetition_penalty := self.repetition_penalty) is None:
+            repetition_penalty = default_sampling_params.get(
+                "repetition_penalty",
+                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"])
 
         return SamplingParams.from_optional(temperature=temperature,
                                             max_tokens=max_tokens,
+                                            seed=self.seed,
+                                            top_p=top_p,
+                                            top_k=top_k,
+                                            min_p=min_p,
+                                            frequency_penalty=self.frequency_penalty,
+                                            repetition_penalty=repetition_penalty,
+                                            presence_penalty=self.presence_penalty,
                                             output_kind=RequestOutputKind.DELTA
                                             if self.stream \
                                             else RequestOutputKind.FINAL_ONLY)
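To summarize the resolution order `to_sampling_params` implements: an explicit request value wins, then any server-side `default_sampling_params`, then the class-level `_DEFAULT_SAMPLING_PARAMS`. A standalone sketch of that three-level fallback, using simplified stand-ins rather than the actual vLLM classes:

```python
# Standalone illustration of the three-level fallback above; names are
# simplified stand-ins, not the actual vLLM classes.
from typing import Optional

# Mirrors _DEFAULT_SAMPLING_PARAMS from the diff.
_HARDCODED = {
    "repetition_penalty": 1.0,
    "temperature": 1.0,
    "top_p": 1.0,
    "top_k": -1,
    "min_p": 0.0,
}


def resolve(name: str, request_value, server_defaults: Optional[dict]):
    """Request value wins; else server default; else hard-coded default."""
    if request_value is not None:
        return request_value
    return (server_defaults or {}).get(name, _HARDCODED[name])


server_defaults = {"repetition_penalty": 1.1}
print(resolve("top_p", 0.4, server_defaults))                # 0.4 (request)
print(resolve("top_k", None, server_defaults))               # -1  (hard-coded)
print(resolve("repetition_penalty", None, server_defaults))  # 1.1 (server)
```

Note that `temperature` itself still defaults to 0.0 on the request model rather than None, so the 1.0 entry in `_DEFAULT_SAMPLING_PARAMS` only takes effect when a client explicitly sends a null temperature; greedy decoding remains the effective default.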
