[Bugfix] Validate custom logits processor xargs for online serving (#27560)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py
2025-11-06 00:53:33 +08:00
committed by GitHub
parent 6cae1e5332
commit 3f5a4b6473
18 changed files with 239 additions and 56 deletions

View File

@@ -52,6 +52,16 @@ prompts = [
class DummyLogitsProcessor(LogitsProcessor):
"""Fake logit processor to support unit testing and examples"""
@classmethod
def validate_params(cls, params: SamplingParams):
target_token: int | None = params.extra_args and params.extra_args.get(
"target_token"
)
if target_token is not None and not isinstance(target_token, int):
raise ValueError(
f"target_token value {target_token} {type(target_token)} is not int"
)
def __init__(
self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
):
@@ -62,11 +72,14 @@ class DummyLogitsProcessor(LogitsProcessor):
return False
def update_state(self, batch_update: BatchUpdate | None):
def extract_extra_arg(params: SamplingParams) -> int | None:
self.validate_params(params)
return params.extra_args and params.extra_args.get("target_token")
process_dict_updates(
self.req_info,
batch_update,
lambda params, _, __: params.extra_args
and (params.extra_args.get("target_token")),
lambda params, _, __: extract_extra_arg(params),
)
def apply(self, logits: torch.Tensor) -> torch.Tensor: