[Bugfix] Validate custom logits processor xargs for online serving (#27560)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -18,6 +18,11 @@ In vLLM, logits processors operate at batch granularity. During a given engine s
|
||||
|
||||
Custom logits processors must subclass `vllm.v1.sample.logits_processor.LogitsProcessor` and define (at minimum) the following methods:
|
||||
|
||||
* `validate_params(cls, sampling_params: SamplingParams)`:
|
||||
* Raise `ValueError` if `SamplingParams` has invalid arguments (especially custom arguments) used by logits processor.
|
||||
* When request is sent to entrypoint, `validate_params()` will validate `SamplingParams` and refuse request with invalid arguments.
|
||||
* **Note:** it's important to implement `validate_params()` to prevent invalid parameters for custom logits processor. Otherwise requests with invalid parameters can cause unexpected behaviour in custom logits processor.
|
||||
|
||||
* `__init__(self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool)`
|
||||
* `vllm_config`: engine configuration data structure
|
||||
* `device`: hardware accelerator device info
|
||||
@@ -103,6 +108,14 @@ The contrived example below implements a custom logits processor which consumes
|
||||
class DummyLogitsProcessor(LogitsProcessor):
|
||||
"""Fake logit processor to support unit testing and examples"""
|
||||
|
||||
@classmethod
|
||||
def validate_params(cls, params: SamplingParams):
|
||||
target_token: int | None = params.extra_args and params.extra_args.get(
|
||||
"target_token"
|
||||
)
|
||||
if target_token is not None and not isinstance(target_token, int):
|
||||
raise ValueError(f"target_token value {target_token} is not int")
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig", device: torch.device,
|
||||
is_pin_memory: bool):
|
||||
self.req_info: dict[int, int] = {}
|
||||
@@ -118,6 +131,7 @@ The contrived example below implements a custom logits processor which consumes
|
||||
# Process added requests.
|
||||
for index, params, _, _ in batch_update.added:
|
||||
assert params is not None
|
||||
self.validate_params(params)
|
||||
if params.extra_args and (target_token :=
|
||||
params.extra_args.get("target_token")):
|
||||
self.req_info[index] = target_token
|
||||
@@ -157,6 +171,7 @@ The contrived example below implements a custom logits processor which consumes
|
||||
logits[rows, cols] = values_to_keep
|
||||
|
||||
return logits
|
||||
|
||||
```
|
||||
|
||||
In the rest of this document, we will use `DummyLogitsProcessor` as an example of a custom logits processor.
|
||||
@@ -180,7 +195,13 @@ RequestLogitsProcessor = Union[
|
||||
|
||||
While request-level logits processors are explicitly *not* supported in the vLLM engine, vLLM *does* provide a convenient process to wrap an existing `Callable` request-level logits processor and create a batch-level logits processor that is compatible with vLLM. The `Callable` must conform to the type annotation above; if your request-level logits processor has a different interface, then in order to wrap it, you may need to modify it or implement an additional wrapper layer to comply with the interface specification above.
|
||||
|
||||
You can wrap the request-level logits processor by subclassing `AdapterLogitsProcessor` as shown in the example below (in this example, `DummyPerReqLogitsProcessor` is a stand-in for your request-level logits processor which needs to be wrapped.) Override `AdapterLogitsProcessor.is_argmax_invariant(self)` to accurately reflect whether your request-level logits processor may impact which token has the highest-value logit. Override `AdapterLogitsProcessor.new_req_logits_processor(self,params)` to create a new request-level logits processor instance from a `SamplingParams` instance:
|
||||
You can wrap the request-level logits processor by subclassing `AdapterLogitsProcessor` as shown in the example below (in this example, `DummyPerReqLogitsProcessor` is a stand-in for your request-level logits processor which needs to be wrapped.):
|
||||
|
||||
* Override `AdapterLogitsProcessor.validate_params(cls,params)` to validate request's sampling parameters.
|
||||
|
||||
* Override `AdapterLogitsProcessor.is_argmax_invariant(self)` to accurately reflect whether your request-level logits processor may impact which token has the highest-value logit.
|
||||
|
||||
* Override `AdapterLogitsProcessor.new_req_logits_processor(self,params)` to create a new request-level logits processor instance from a `SamplingParams` instance:
|
||||
|
||||
??? code "Example of Wrapping a Request-Level Logits Processor"
|
||||
|
||||
@@ -220,6 +241,16 @@ You can wrap the request-level logits processor by subclassing `AdapterLogitsPro
|
||||
"""Example of wrapping a fake request-level logit processor to create a
|
||||
batch-level logits processor"""
|
||||
|
||||
@classmethod
|
||||
def validate_params(cls, params: SamplingParams):
|
||||
target_token: Any | None = params.extra_args and params.extra_args.get(
|
||||
"target_token"
|
||||
)
|
||||
if target_token is not None and not isinstance(target_token, int):
|
||||
raise ValueError(
|
||||
f"target_token value {target_token} is not int"
|
||||
)
|
||||
|
||||
def is_argmax_invariant(self) -> bool:
|
||||
return False
|
||||
|
||||
@@ -240,18 +271,11 @@ You can wrap the request-level logits processor by subclassing `AdapterLogitsPro
|
||||
Returns:
|
||||
`Callable` request logits processor, or None
|
||||
"""
|
||||
target_token: Optional[Any] = params.extra_args and params.extra_args.get(
|
||||
target_token: Any | None = params.extra_args and params.extra_args.get(
|
||||
"target_token"
|
||||
)
|
||||
if target_token is None:
|
||||
return None
|
||||
if not isinstance(target_token, int):
|
||||
logger.warning(
|
||||
"target_token value %s is not int; not applying logits"
|
||||
" processor to request.",
|
||||
target_token,
|
||||
)
|
||||
return None
|
||||
return DummyPerReqLogitsProcessor(target_token)
|
||||
```
|
||||
|
||||
|
||||
Reference in New Issue
Block a user