[V1][Frontend] Add Testing For V1 Runtime Parameters (#14159)

Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Robert Shaw authored 2025-03-05 14:18:55 +00:00, committed by GitHub
parent 47d4a7e004
commit 257e200a25
3 changed files with 201 additions and 17 deletions


@@ -55,11 +55,8 @@ class Processor:
 
     def _validate_logprobs(
         self,
-        params: Union[SamplingParams, PoolingParams],
+        params: SamplingParams,
     ) -> None:
-        if not isinstance(params, SamplingParams):
-            return
-
         max_logprobs = self.model_config.max_logprobs
         # Validate sample logprobs.
         if params.logprobs and params.logprobs > max_logprobs:
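The Union signature and the isinstance early-return move out of _validate_logprobs; the new _validate_params entry point (added below) rejects PoolingParams before any of these helpers run. The commit's new test file is not shown in this excerpt, so here is a minimal sketch of how the logprobs cap could be exercised, assuming pytest, the offline LLM entrypoint, and the V1 engine enabled via VLLM_USE_V1=1; the model name is illustrative, and max_logprobs=20 is the engine default at the time of this commit, not something this diff sets:

    import pytest

    from vllm import LLM, SamplingParams

    def test_logprobs_above_max_logprobs_rejected():
        # Illustrative model; any model served by the V1 engine works.
        llm = LLM(model="facebook/opt-125m")
        # Ask for more sample logprobs than model_config.max_logprobs
        # allows (20 by default at the time of this commit).
        params = SamplingParams(logprobs=25)
        with pytest.raises(ValueError):
            llm.generate("Hello", params)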
@@ -79,17 +76,10 @@ class Processor:
             raise ValueError("Prefix caching with prompt logprobs not yet "
                              "supported on VLLM V1.")
 
-    def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
-        if lora_request is not None and not self.lora_config:
-            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
-                             "not enabled!")
-
-    def _validate_allowed_token_ids(
+    def _validate_sampling_params(
        self,
-        params: Union[SamplingParams, PoolingParams],
+        params: SamplingParams,
     ) -> None:
-        if not isinstance(params, SamplingParams):
-            return
        if params.allowed_token_ids is None:
            return
        if not params.allowed_token_ids:
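_validate_allowed_token_ids is renamed to _validate_sampling_params and likewise drops its isinstance guard. A sketch of a test for the out-of-vocab branch, under the same assumptions as above:

    import pytest

    from vllm import LLM, SamplingParams

    def test_allowed_token_ids_out_of_vocab_rejected():
        llm = LLM(model="facebook/opt-125m")  # illustrative model
        # 10**9 is far beyond any realistic vocab size, so the
        # out-of-vocab check should trip.
        params = SamplingParams(allowed_token_ids=[10**9])
        with pytest.raises(ValueError):
            llm.generate("Hello", params)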
@@ -99,6 +89,42 @@ class Processor:
             raise ValueError(
                 "allowed_token_ids contains out-of-vocab token id!")
 
+    def _validate_supported_sampling_params(
+        self,
+        params: SamplingParams,
+    ) -> None:
+        # Best of not yet supported.
+        if params.best_of:
+            raise ValueError("VLLM V1 does not yet support best_of.")
+
+        # Bad words not yet supported.
+        if params.bad_words:
+            raise ValueError("VLLM V1 does not yet support bad_words.")
+
+        # Logits processors not supported.
+        if params.logits_processors:
+            raise ValueError("VLLM V1 does not support per request "
+                             "user provided logits processors.")
+
+    def _validate_params(
+        self,
+        params: Union[SamplingParams, PoolingParams],
+    ):
+        """
+        Validate supported SamplingParam.
+        Should raise ValueError if unsupported for API Server.
+        """
+        if not isinstance(params, SamplingParams):
+            raise ValueError("V1 does not yet support Pooling models.")
+
+        self._validate_logprobs(params)
+        self._validate_sampling_params(params)
+        self._validate_supported_sampling_params(params)
+
+    def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
+        if lora_request is not None and not self.lora_config:
+            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
+                             "not enabled!")
+
     def process_inputs(
         self,
         request_id: str,
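The new _validate_params is the single entry point: it rejects PoolingParams outright, then funnels SamplingParams through the three specific validators, so best_of, bad_words, and per-request logits processors all fail fast with ValueError. A sketch covering those three rejections in one parametrized test, same assumptions as above:

    import pytest

    from vllm import LLM, SamplingParams

    @pytest.mark.parametrize("params", [
        SamplingParams(best_of=2),           # best_of not yet supported
        SamplingParams(bad_words=["word"]),  # bad_words not yet supported
        SamplingParams(logits_processors=[lambda ids, logits: logits]),
    ])
    def test_unsupported_sampling_params_rejected(params):
        llm = LLM(model="facebook/opt-125m")  # illustrative model
        with pytest.raises(ValueError):
            llm.generate("Hello", params)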
@@ -114,14 +140,17 @@ class Processor:
         # TODO(woosuk): Support pooling models.
         # TODO(woosuk): Support encoder-decoder models.
 
-        self._validate_logprobs(params)
         self._validate_lora(lora_request)
-        self._validate_allowed_token_ids(params)
+        self._validate_params(params)
+        if priority != 0:
+            raise ValueError("V1 does not support priority yet.")
+        if trace_headers is not None:
+            raise ValueError("V1 does not support tracing yet.")
+        if prompt_adapter_request is not None:
+            raise ValueError("V1 does not support prompt_adapter_request.")
 
         if arrival_time is None:
             arrival_time = time.time()
-        assert priority == 0, "vLLM V1 does not support priority at the moment."
-        assert trace_headers is None, "vLLM V1 does not support tracing yet."
 
         # Process inputs, which includes:
         # 1. Tokenize text prompt, with LoRA request if one exists.
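Note the behavioral change here: the two trailing asserts become explicit ValueError raises that still run under python -O and surface to callers as a clean error rather than a failed assertion, and prompt_adapter_request gains the same treatment. A sketch for the priority check, assuming LLM.generate's per-prompt priority argument reaches V1's process_inputs:

    import pytest

    from vllm import LLM, SamplingParams

    def test_nonzero_priority_rejected():
        llm = LLM(model="facebook/opt-125m")  # illustrative model
        # One priority value per prompt; V1 only accepts priority == 0,
        # and now raises ValueError instead of failing a bare assert.
        with pytest.raises(ValueError):
            llm.generate("Hello", SamplingParams(), priority=[1])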