[Sampler] Support returning all prompt logprobs (#23868)

Signed-off-by: Xingyu Liu <charlotteliu12x@gmail.com>
Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Xingyu Liu
2025-09-07 19:34:31 -07:00
committed by GitHub
parent 67841317d1
commit b3d7e3c845
4 changed files with 38 additions and 18 deletions

View File

@@ -65,19 +65,27 @@ class Processor:
) -> None:
max_logprobs = self.model_config.max_logprobs
if max_logprobs == -1:
return
max_logprobs = self.model_config.get_vocab_size()
# Validate sample logprobs.
if params.logprobs and (params.logprobs == -1
or params.logprobs > max_logprobs):
raise ValueError(
f"Requested sample logprobs of {params.logprobs}, "
f"which is greater than max allowed: {max_logprobs}")
if params.logprobs:
num_logprobs = params.logprobs
if num_logprobs == -1:
num_logprobs = self.model_config.get_vocab_size()
if num_logprobs > max_logprobs:
raise ValueError(
f"Requested sample logprobs of {num_logprobs}, "
f"which is greater than max allowed: {max_logprobs}")
# Validate prompt logprobs.
if params.prompt_logprobs and params.prompt_logprobs > max_logprobs:
raise ValueError(
f"Requested prompt logprobs of {params.prompt_logprobs}, "
f"which is greater than max allowed: {max_logprobs}")
if params.prompt_logprobs:
num_prompt_logprobs = params.prompt_logprobs
if num_prompt_logprobs == -1:
num_prompt_logprobs = self.model_config.get_vocab_size()
if num_prompt_logprobs > max_logprobs:
raise ValueError(
f"Requested prompt logprobs of {num_prompt_logprobs}, "
f"which is greater than max allowed: {max_logprobs}")
def _validate_sampling_params(
self,