Improve mm and pooler and decoding configs (#16789)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor
2025-04-18 06:13:32 +01:00
committed by GitHub
parent 7eb4255628
commit e78587a64c
14 changed files with 84 additions and 78 deletions

@@ -17,7 +17,7 @@ from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
from importlib.util import find_spec
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
Optional, Protocol, TypeVar, Union)
Optional, Protocol, TypeVar, Union, get_args)
import torch
from pydantic import BaseModel, Field, PrivateAttr
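
Note on the new `get_args` import: `typing.get_args` returns the values a `Literal` alias was parameterised with, which is what lets the backend type annotations double as the runtime validation list later in this diff. A minimal standalone sketch (not vLLM code), reusing the alias defined further down:

```python
from typing import Literal, get_args

# Mirrors the Literal alias introduced later in this diff.
GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]

# get_args() recovers the literal values as a plain tuple, so the type
# annotation and the runtime validation list can no longer drift apart.
assert get_args(GuidedDecodingBackendV1) == ("auto", "xgrammar", "guidance")
```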
@@ -2725,6 +2725,7 @@ class PromptAdapterConfig:
self.prompt_adapter_dtype)
@config
@dataclass
class MultiModalConfig:
"""Controls the behavior of multimodal models."""
@@ -2732,6 +2733,8 @@ class MultiModalConfig:
limit_per_prompt: Mapping[str, int] = field(default_factory=dict)
"""
The maximum number of input items allowed per prompt for each modality.
This should be a JSON string that will be parsed into a dictionary.
Defaults to 1 (V0) or 999 (V1) for each modality.
"""
def compute_hash(self) -> str:
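
The updated docstring states that the per-modality limit is supplied as a JSON string and parsed into a dictionary. A small hedged example of what that mapping looks like once parsed (the actual CLI/engine-argument handling lives elsewhere in vLLM):

```python
import json

# Assumed example value; in practice the JSON string arrives via the
# engine arguments / CLI layer.
raw = '{"image": 2, "video": 1}'
limit_per_prompt = json.loads(raw)   # {'image': 2, 'video': 1}

# Modalities missing from the mapping fall back to the default described
# above: 1 item per prompt on V0, 999 on V1.
```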
@@ -2753,24 +2756,20 @@ class MultiModalConfig:
usedforsecurity=False).hexdigest()
return hash_str
def get_default_limit_per_prompt(self) -> int:
"""
Return the default number of input items allowed per prompt
for any modality if not specified by the user.
"""
return 999 if envs.VLLM_USE_V1 else 1
def get_limit_per_prompt(self, modality: str) -> int:
"""
Get the maximum number of input items allowed per prompt
for the given modality.
"""
default = self.get_default_limit_per_prompt()
return self.limit_per_prompt.get(modality, default)
return self.limit_per_prompt.get(
modality,
999 if envs.VLLM_USE_V1 else 1,
)
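
The `get_default_limit_per_prompt` helper is folded into `get_limit_per_prompt`; behaviour is unchanged. A self-contained sketch of the resulting lookup, with the V0/V1 switch reduced to a plain boolean for illustration:

```python
from collections.abc import Mapping

def get_limit_per_prompt(limit_per_prompt: Mapping[str, int],
                         modality: str,
                         use_v1: bool) -> int:
    # Per-modality override if the user set one, otherwise the
    # engine-version-dependent default (999 on V1, 1 on V0).
    return limit_per_prompt.get(modality, 999 if use_v1 else 1)

assert get_limit_per_prompt({"image": 4}, "image", use_v1=True) == 4
assert get_limit_per_prompt({"image": 4}, "video", use_v1=True) == 999
assert get_limit_per_prompt({}, "audio", use_v1=False) == 1
```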
# TODO: Add configs to init vision tower or not.
@config
@dataclass
class PoolerConfig:
"""Controls the behavior of output pooling in pooling models."""
@@ -3095,15 +3094,28 @@ def get_served_model_name(model: str,
return served_model_name
GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer",
"xgrammar"]
GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]
@config
@dataclass
class DecodingConfig:
"""Dataclass which contains the decoding strategy of the engine"""
"""Dataclass which contains the decoding strategy of the engine."""
# Which guided decoding algo to use.
# 'outlines' / 'lm-format-enforcer' / 'xgrammar'
guided_decoding_backend: str = "auto" if envs.VLLM_USE_V1 else "xgrammar"
guided_decoding_backend: Union[
GuidedDecodingBackendV0,
GuidedDecodingBackendV1] = "auto" if envs.VLLM_USE_V1 else "xgrammar"
"""Which engine will be used for guided decoding (JSON schema / regex etc)
by default. With "auto", we will make opinionated choices based on request
contents and what the backend libraries currently support, so the behavior
is subject to change in each release."""
reasoning_backend: Optional[str] = None
"""Select the reasoning parser depending on the model that you're using.
This is used to parse the reasoning content into OpenAI API format.
Required for `--enable-reasoning`."""
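
The backend field is now annotated with a `Union` of the two `Literal` aliases, so the type covers every value that is legal on either engine version, while `__post_init__` (below) narrows it to the set valid for the running version. For usage, constructing the config directly looks like the following; the concrete values are assumed examples, not defaults:

```python
from vllm.config import DecodingConfig

# Field names come from the diff above; the values are illustrative only.
decoding_config = DecodingConfig(
    guided_decoding_backend="xgrammar",  # valid on both V0 and V1
    reasoning_backend="deepseek_r1",     # only needed with --enable-reasoning
)
```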
def compute_hash(self) -> str:
"""
@@ -3125,17 +3137,12 @@ class DecodingConfig:
return hash_str
def __post_init__(self):
v0_valid_guided_backends = [
'outlines', 'lm-format-enforcer', 'xgrammar', 'auto'
]
v1_valid_guided_backends = ['xgrammar', 'guidance', 'auto']
backend = GuidedDecodingParams(
backend=self.guided_decoding_backend).backend_name
if envs.VLLM_USE_V1:
valid_guided_backends = v1_valid_guided_backends
valid_guided_backends = get_args(GuidedDecodingBackendV1)
else:
valid_guided_backends = v0_valid_guided_backends
valid_guided_backends = get_args(GuidedDecodingBackendV0)
if backend not in valid_guided_backends:
raise ValueError(f"Invalid guided_decoding_backend '{backend}',"
f" must be one of {valid_guided_backends}")