[Bugfix] Validate custom logits processor xargs for online serving (#27560)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -33,6 +33,8 @@ Output: ' in the hands of the people.\n\nThe future of AI is in the'
|
||||
------------------------------------------------------------
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
@@ -48,6 +50,16 @@ from vllm.v1.sample.logits_processor.builtin import process_dict_updates
|
||||
class DummyLogitsProcessor(LogitsProcessor):
|
||||
"""Fake logit processor to support unit testing and examples"""
|
||||
|
||||
@classmethod
|
||||
def validate_params(cls, params: SamplingParams):
|
||||
target_token: Any | None = params.extra_args and params.extra_args.get(
|
||||
"target_token"
|
||||
)
|
||||
if target_token is not None and not isinstance(target_token, int):
|
||||
raise ValueError(
|
||||
f"target_token value {target_token} {type(target_token)} is not int"
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
|
||||
):
|
||||
@@ -57,14 +69,17 @@ class DummyLogitsProcessor(LogitsProcessor):
|
||||
return False
|
||||
|
||||
def update_state(self, batch_update: BatchUpdate | None):
|
||||
def extract_extra_arg(params: SamplingParams) -> int | None:
|
||||
self.validate_params(params)
|
||||
return params.extra_args and params.extra_args.get("target_token")
|
||||
|
||||
process_dict_updates(
|
||||
self.req_info,
|
||||
batch_update,
|
||||
# This function returns the LP's per-request state based on the
|
||||
# request details, or None if this LP does not apply to the
|
||||
# request.
|
||||
lambda params, _, __: params.extra_args
|
||||
and (params.extra_args.get("target_token")),
|
||||
lambda params, _, __: extract_extra_arg(params),
|
||||
)
|
||||
|
||||
def apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
@@ -76,6 +76,14 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
|
||||
"""Example of wrapping a fake request-level logit processor to create a
|
||||
batch-level logits processor"""
|
||||
|
||||
@classmethod
|
||||
def validate_params(cls, params: SamplingParams):
|
||||
target_token: Any | None = params.extra_args and params.extra_args.get(
|
||||
"target_token"
|
||||
)
|
||||
if target_token is not None and not isinstance(target_token, int):
|
||||
raise ValueError(f"target_token value {target_token} is not int")
|
||||
|
||||
def is_argmax_invariant(self) -> bool:
|
||||
return False
|
||||
|
||||
@@ -101,13 +109,6 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
|
||||
)
|
||||
if target_token is None:
|
||||
return None
|
||||
if not isinstance(target_token, int):
|
||||
logger.warning(
|
||||
"target_token value %s is not int; not applying logits"
|
||||
" processor to request.",
|
||||
target_token,
|
||||
)
|
||||
return None
|
||||
return DummyPerReqLogitsProcessor(target_token)
|
||||
|
||||
|
||||
|
||||
@@ -77,6 +77,14 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
|
||||
"""Example of overriding the wrapper class `__init__()` in order to utilize
|
||||
info about the device type"""
|
||||
|
||||
@classmethod
|
||||
def validate_params(cls, params: SamplingParams):
|
||||
target_token = params.extra_args and params.extra_args.get("target_token")
|
||||
if target_token is not None and not isinstance(target_token, int):
|
||||
raise ValueError(
|
||||
f"`target_token` has to be an integer, got {target_token}."
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
|
||||
):
|
||||
@@ -113,13 +121,6 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
|
||||
is None
|
||||
):
|
||||
return None
|
||||
if not isinstance(target_token, int):
|
||||
logger.warning(
|
||||
"target_token value %s is not int; not applying logits"
|
||||
" processor to request.",
|
||||
target_token,
|
||||
)
|
||||
return None
|
||||
return DummyPerReqLogitsProcessor(target_token)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user