[Bugfix] Validate logit biases to prevent out of vocab ids crashing engine (#16529)
Signed-off-by: Ryan McConville <ryan@ryanmcconville.com>
This commit is contained in:
@@ -77,6 +77,7 @@ class Processor:
|
||||
params: SamplingParams,
|
||||
) -> None:
|
||||
self._validate_structured_output(params)
|
||||
self._validate_logit_bias(params)
|
||||
|
||||
if params.allowed_token_ids is None:
|
||||
return
|
||||
@@ -87,6 +88,26 @@ class Processor:
|
||||
raise ValueError(
|
||||
"allowed_token_ids contains out-of-vocab token id!")
|
||||
|
||||
def _validate_logit_bias(
|
||||
self,
|
||||
params: SamplingParams,
|
||||
) -> None:
|
||||
"""Validate logit_bias token IDs are within vocabulary range."""
|
||||
if not params.logit_bias:
|
||||
return
|
||||
|
||||
vocab_size = self.model_config.get_vocab_size()
|
||||
invalid_token_ids = []
|
||||
|
||||
for token_id in params.logit_bias:
|
||||
if token_id < 0 or token_id >= vocab_size:
|
||||
invalid_token_ids.append(token_id)
|
||||
|
||||
if invalid_token_ids:
|
||||
raise ValueError(
|
||||
f"token_id(s) {invalid_token_ids} in logit_bias contain "
|
||||
f"out-of-vocab token ids. Vocabulary size: {vocab_size}")
|
||||
|
||||
def _validate_supported_sampling_params(
|
||||
self,
|
||||
params: SamplingParams,
|
||||
|
||||
Reference in New Issue
Block a user