Skip to content

Commit 7d7e7fb

Browse files
CatherineSue authored and Layssy committed
[OAI] Support non-normalized logprobs in OpenAI server (sgl-project#5961)
1 parent 02ec708 commit 7d7e7fb

File tree

1 file changed

+4
-9
lines changed

1 file changed

+4
-9
lines changed

python/sglang/srt/layers/sampler.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,9 @@ def forward(
8686
# NOTE: the top_p_renorm_prob from flashinfer has numerical problems,
8787
# https://github.com/flashinfer-ai/flashinfer/issues/708
8888
# so we use the torch implementation.
89-
90-
# clamp to avoid -inf
91-
logprobs = torch.log(
92-
top_p_normalize_probs_torch(probs, sampling_info.top_ps)
93-
).clamp(min=torch.finfo(probs.dtype).min)
89+
# NOTE: OpenAI's logprobs is independent of top-p, we use the
90+
# same rule.
91+
logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
9492

9593
max_top_k_round, batch_size = 32, probs.shape[0]
9694
if sampling_info.need_min_p_sampling:
@@ -121,10 +119,7 @@ def forward(
121119
)
122120

123121
if return_logprob:
124-
# clamp to avoid -inf
125-
logprobs = torch.log(
126-
top_p_normalize_probs_torch(probs, sampling_info.top_ps)
127-
).clamp(min=torch.finfo(probs.dtype).min)
122+
logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
128123
else:
129124
raise ValueError(
130125
f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"

0 commit comments

Comments (0)