[V1] Refactor Structured Output for multiple backends (#14694)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
@@ -119,16 +119,21 @@ class Processor:
     def _validate_structured_output(self, params: SamplingParams) -> None:
         if not params.guided_decoding or not self.decoding_config:
             return
-        if self.decoding_config.guided_decoding_backend != "xgrammar":
-            raise ValueError(
-                "Only xgrammar structured output is supported in V1.")
-        if (params.guided_decoding.backend
-                and params.guided_decoding.backend != 'xgrammar'):
-            raise ValueError(
-                "Only xgrammar structured output is supported in V1.")
-        if self.vllm_config.speculative_config:
-            raise ValueError("Structured output is not supported with "
-                             "speculative decoding.")
+
+        supported_backends = ["xgrammar"]
+        engine_level_backend = self.decoding_config.guided_decoding_backend
+        if engine_level_backend not in supported_backends:
+            raise ValueError(f"Only {supported_backends} structured output is "
+                             "supported in V1.")
+        if params.guided_decoding.backend:
+            if params.guided_decoding.backend != engine_level_backend:
+                raise ValueError("Request-level structured output backend "
+                                 "must match engine-level backend. "
+                                 f"{params.guided_decoding.backend}"
+                                 f" != {engine_level_backend}")
+        else:
+            params.guided_decoding.backend = engine_level_backend
+
         if vllm.platforms.current_platform.is_tpu():
             raise ValueError("Structured output is not supported on TPU.")
 
Reference in New Issue
Block a user