[Bugfix] Validate logit biases to prevent out of vocab ids crashing engine (#16529)

Signed-off-by: Ryan McConville <ryan@ryanmcconville.com>
This commit is contained in:
Ryan McConville
2025-04-12 21:19:19 +01:00
committed by GitHub
parent 93e5f3c5fb
commit 6c11ecf8d3
3 changed files with 119 additions and 0 deletions

View File

@@ -77,6 +77,7 @@ class Processor:
params: SamplingParams,
) -> None:
self._validate_structured_output(params)
self._validate_logit_bias(params)
if params.allowed_token_ids is None:
return
@@ -87,6 +88,26 @@ class Processor:
raise ValueError(
"allowed_token_ids contains out-of-vocab token id!")
def _validate_logit_bias(
self,
params: SamplingParams,
) -> None:
"""Validate logit_bias token IDs are within vocabulary range."""
if not params.logit_bias:
return
vocab_size = self.model_config.get_vocab_size()
invalid_token_ids = []
for token_id in params.logit_bias:
if token_id < 0 or token_id >= vocab_size:
invalid_token_ids.append(token_id)
if invalid_token_ids:
raise ValueError(
f"token_id(s) {invalid_token_ids} in logit_bias contain "
f"out-of-vocab token ids. Vocabulary size: {vocab_size}")
def _validate_supported_sampling_params(
self,
params: SamplingParams,