Simplify (and fix) passing of guided decoding backend options (#17008)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -17,12 +17,14 @@ from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
|
||||
from importlib.util import find_spec
|
||||
from pathlib import Path
|
||||
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
|
||||
Optional, Protocol, TypeVar, Union, get_args, get_origin)
|
||||
Optional, Protocol, TypeVar, Union, cast, get_args,
|
||||
get_origin)
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
from torch.distributed import ProcessGroup, ReduceOp
|
||||
from transformers import PretrainedConfig
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
|
||||
@@ -32,7 +34,6 @@ from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
|
||||
get_quantization_config)
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.platforms import CpuArchEnum, current_platform
|
||||
from vllm.sampling_params import GuidedDecodingParams
|
||||
from vllm.tracing import is_otel_available, otel_import_error_traceback
|
||||
from vllm.transformers_utils.config import (
|
||||
ConfigFormat, get_config, get_hf_image_processor_config,
|
||||
@@ -344,7 +345,7 @@ class ModelConfig:
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
task: Union[TaskOption, Literal["draft"]],
|
||||
task: Literal[TaskOption, Literal["draft"]],
|
||||
tokenizer: str,
|
||||
tokenizer_mode: str,
|
||||
trust_remote_code: bool,
|
||||
@@ -701,7 +702,7 @@ class ModelConfig:
|
||||
|
||||
def _resolve_task(
|
||||
self,
|
||||
task_option: Union[TaskOption, Literal["draft"]],
|
||||
task_option: Literal[TaskOption, Literal["draft"]],
|
||||
) -> tuple[set[_ResolvedTask], _ResolvedTask]:
|
||||
if task_option == "draft":
|
||||
return {"draft"}, "draft"
|
||||
@@ -3185,13 +3186,36 @@ GuidedDecodingBackend = Literal[GuidedDecodingBackendV0,
|
||||
class DecodingConfig:
|
||||
"""Dataclass which contains the decoding strategy of the engine."""
|
||||
|
||||
guided_decoding_backend: GuidedDecodingBackend = \
|
||||
"auto" if envs.VLLM_USE_V1 else "xgrammar"
|
||||
@property
|
||||
@deprecated(
|
||||
"`guided_decoding_backend` is deprecated and has been renamed to "
|
||||
"`backend`. This will be removed in v0.10.0. Please use the "
|
||||
"`backend` argument instead.")
|
||||
def guided_decoding_backend(self) -> GuidedDecodingBackend:
|
||||
return self.backend
|
||||
|
||||
@guided_decoding_backend.setter
|
||||
def guided_decoding_backend(self, value: GuidedDecodingBackend):
|
||||
self.backend = value
|
||||
|
||||
backend: GuidedDecodingBackend = "auto" if envs.VLLM_USE_V1 else "xgrammar"
|
||||
"""Which engine will be used for guided decoding (JSON schema / regex etc)
|
||||
by default. With "auto", we will make opinionated choices based on request
|
||||
contents and what the backend libraries currently support, so the behavior
|
||||
is subject to change in each release."""
|
||||
|
||||
disable_fallback: bool = False
|
||||
"""If `True`, vLLM will not fallback to a different backend on error."""
|
||||
|
||||
disable_any_whitespace: bool = False
|
||||
"""If `True`, the model will not generate any whitespace during guided
|
||||
decoding. This is only supported for xgrammar and guidance backends."""
|
||||
|
||||
disable_additional_properties: bool = False
|
||||
"""If `True`, the `guidance` backend will not use `additionalProperties`
|
||||
in the JSON schema. This is only supported for the `guidance` backend and
|
||||
is used to better align its behaviour with `outlines` and `xgrammar`."""
|
||||
|
||||
reasoning_backend: Optional[str] = None
|
||||
"""Select the reasoning parser depending on the model that you're using.
|
||||
This is used to parse the reasoning content into OpenAI API format.
|
||||
@@ -3217,15 +3241,41 @@ class DecodingConfig:
|
||||
return hash_str
|
||||
|
||||
def __post_init__(self):
|
||||
backend = GuidedDecodingParams(
|
||||
backend=self.guided_decoding_backend).backend_name
|
||||
if ":" in self.backend:
|
||||
self._extract_backend_options()
|
||||
|
||||
if envs.VLLM_USE_V1:
|
||||
valid_guided_backends = get_args(GuidedDecodingBackendV1)
|
||||
else:
|
||||
valid_guided_backends = get_args(GuidedDecodingBackendV0)
|
||||
if backend not in valid_guided_backends:
|
||||
raise ValueError(f"Invalid guided_decoding_backend '{backend}',"
|
||||
if self.backend not in valid_guided_backends:
|
||||
raise ValueError(f"Invalid backend '{self.backend}',"
|
||||
f" must be one of {valid_guided_backends}")
|
||||
if (self.disable_any_whitespace
|
||||
and self.backend not in ("xgrammar", "guidance")):
|
||||
raise ValueError("disable_any_whitespace is only supported for "
|
||||
"xgrammar and guidance backends.")
|
||||
if (self.disable_additional_properties and self.backend != "guidance"):
|
||||
raise ValueError("disable_additional_properties is only supported "
|
||||
"for the guidance backend.")
|
||||
|
||||
@deprecated(
|
||||
"Passing guided decoding backend options inside backend in the format "
|
||||
"'backend:...' is deprecated. This will be removed in v0.10.0. Please "
|
||||
"use the dedicated arguments '--disable-fallback', "
|
||||
"'--disable-any-whitespace' and '--disable-additional-properties' "
|
||||
"instead.")
|
||||
def _extract_backend_options(self):
|
||||
"""Extract backend options from the backend string."""
|
||||
backend, options = self.backend.split(":")
|
||||
self.backend = cast(GuidedDecodingBackend, backend)
|
||||
options_set = set(options.strip().split(","))
|
||||
if "no-fallback" in options_set:
|
||||
self.disable_fallback = True
|
||||
if "disable-any-whitespace" in options_set:
|
||||
self.disable_any_whitespace = True
|
||||
if "no-additional-properties" in options_set:
|
||||
self.disable_additional_properties = True
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
Reference in New Issue
Block a user