Improve mm and pooler and decoding configs (#16789)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -17,7 +17,7 @@ from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
 from importlib.util import find_spec
 from pathlib import Path
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
-                    Optional, Protocol, TypeVar, Union)
+                    Optional, Protocol, TypeVar, Union, get_args)
 
 import torch
 from pydantic import BaseModel, Field, PrivateAttr
@@ -2725,6 +2725,7 @@ class PromptAdapterConfig:
                                                  self.prompt_adapter_dtype)
 
 
+@config
 @dataclass
 class MultiModalConfig:
     """Controls the behavior of multimodal models."""
@@ -2732,6 +2733,8 @@ class MultiModalConfig:
     limit_per_prompt: Mapping[str, int] = field(default_factory=dict)
+    """
+    The maximum number of input items allowed per prompt for each modality.
+    This should be a JSON string that will be parsed into a dictionary.
+    Defaults to 1 (V0) or 999 (V1) for each modality.
+    """
 
     def compute_hash(self) -> str:
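For context on the field docstring added above: `limit_per_prompt` is documented as arriving as a JSON string and ending up as a `Mapping[str, int]`. Below is a minimal sketch of that parse step with made-up modality names and limits; the `parse_limit_per_prompt` helper is illustrative only, not vLLM's actual parsing code.

import json
from typing import Mapping

def parse_limit_per_prompt(raw: str) -> Mapping[str, int]:
    # Parse a JSON object string such as '{"image": 3, "video": 1}' into the
    # per-modality limit mapping that limit_per_prompt stores.
    parsed = json.loads(raw)
    return {modality: int(limit) for modality, limit in parsed.items()}

print(parse_limit_per_prompt('{"image": 3, "video": 1}'))  # {'image': 3, 'video': 1}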
@@ -2753,24 +2756,20 @@ class MultiModalConfig:
                                usedforsecurity=False).hexdigest()
         return hash_str
 
-    def get_default_limit_per_prompt(self) -> int:
-        """
-        Return the default number of input items allowed per prompt
-        for any modality if not specified by the user.
-        """
-        return 999 if envs.VLLM_USE_V1 else 1
-
     def get_limit_per_prompt(self, modality: str) -> int:
         """
         Get the maximum number of input items allowed per prompt
         for the given modality.
         """
-        default = self.get_default_limit_per_prompt()
-        return self.limit_per_prompt.get(modality, default)
+        return self.limit_per_prompt.get(
+            modality,
+            999 if envs.VLLM_USE_V1 else 1,
+        )
 
     # TODO: Add configs to init vision tower or not.
 
 
+@config
 @dataclass
 class PoolerConfig:
     """Controls the behavior of output pooling in pooling models."""
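The hunk above folds `get_default_limit_per_prompt` into `get_limit_per_prompt`: an explicit per-modality limit wins, otherwise the default depends on the engine version (999 for V1, 1 for V0). A standalone sketch of that lookup, with a plain dict and a `use_v1` flag standing in for `envs.VLLM_USE_V1`:

from typing import Mapping

def get_limit_per_prompt(limit_per_prompt: Mapping[str, int], modality: str,
                         use_v1: bool) -> int:
    # Explicit per-modality limit if configured, else the version-dependent default.
    return limit_per_prompt.get(modality, 999 if use_v1 else 1)

limits = {"image": 3}
print(get_limit_per_prompt(limits, "image", use_v1=True))   # 3 (explicit)
print(get_limit_per_prompt(limits, "video", use_v1=True))   # 999 (V1 default)
print(get_limit_per_prompt(limits, "video", use_v1=False))  # 1 (V0 default)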
@@ -3095,15 +3094,28 @@ def get_served_model_name(model: str,
     return served_model_name
 
 
+GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer",
+                                  "xgrammar"]
+GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]
+
+
 @config
 @dataclass
 class DecodingConfig:
-    """Dataclass which contains the decoding strategy of the engine"""
+    """Dataclass which contains the decoding strategy of the engine."""
 
-    # Which guided decoding algo to use.
-    # 'outlines' / 'lm-format-enforcer' / 'xgrammar'
-    guided_decoding_backend: str = "auto" if envs.VLLM_USE_V1 else "xgrammar"
+    guided_decoding_backend: Union[
+        GuidedDecodingBackendV0,
+        GuidedDecodingBackendV1] = "auto" if envs.VLLM_USE_V1 else "xgrammar"
+    """Which engine will be used for guided decoding (JSON schema / regex etc)
+    by default. With "auto", we will make opinionated choices based on request
+    contents and what the backend libraries currently support, so the behavior
+    is subject to change in each release."""
+
     reasoning_backend: Optional[str] = None
+    """Select the reasoning parser depending on the model that you're using.
+    This is used to parse the reasoning content into OpenAI API format.
+    Required for `--enable-reasoning`."""
 
     def compute_hash(self) -> str:
         """
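Typing `guided_decoding_backend` with the two `Literal` aliases is what makes the `get_args` import from the first hunk useful: the allowed values can be recovered from the annotations instead of being repeated in hand-maintained lists. A short sketch of that mechanism:

from typing import Literal, get_args

GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer",
                                  "xgrammar"]
GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]

# get_args on a Literal alias returns its allowed values as a tuple.
print(get_args(GuidedDecodingBackendV0))  # ('auto', 'outlines', 'lm-format-enforcer', 'xgrammar')
print(get_args(GuidedDecodingBackendV1))  # ('auto', 'xgrammar', 'guidance')

This keeps the type checker's view of the field and the runtime validation in sync with one definition.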
@@ -3125,17 +3137,12 @@ class DecodingConfig:
         return hash_str
 
     def __post_init__(self):
-        v0_valid_guided_backends = [
-            'outlines', 'lm-format-enforcer', 'xgrammar', 'auto'
-        ]
-        v1_valid_guided_backends = ['xgrammar', 'guidance', 'auto']
-
         backend = GuidedDecodingParams(
             backend=self.guided_decoding_backend).backend_name
         if envs.VLLM_USE_V1:
-            valid_guided_backends = v1_valid_guided_backends
+            valid_guided_backends = get_args(GuidedDecodingBackendV1)
         else:
-            valid_guided_backends = v0_valid_guided_backends
+            valid_guided_backends = get_args(GuidedDecodingBackendV0)
         if backend not in valid_guided_backends:
             raise ValueError(f"Invalid guided_decoding_backend '{backend}',"
                              f" must be one of {valid_guided_backends}")
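The net effect of this last hunk: `__post_init__` derives its allow-list from the same `Literal` aliases that type the field, so the valid backends live in one place. A self-contained sketch of that validation shape, where `use_v1` is a stand-in for `envs.VLLM_USE_V1` and `backend` is assumed to be the already-extracted backend name:

from typing import Literal, get_args

GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer",
                                  "xgrammar"]
GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]

def validate_guided_backend(backend: str, use_v1: bool) -> None:
    # Pick the allow-list for the engine version, then reject anything outside it.
    valid = get_args(GuidedDecodingBackendV1 if use_v1 else GuidedDecodingBackendV0)
    if backend not in valid:
        raise ValueError(f"Invalid guided_decoding_backend '{backend}',"
                         f" must be one of {valid}")

validate_guided_backend("xgrammar", use_v1=True)    # passes
# validate_guided_backend("outlines", use_v1=True)  # would raise: not a V1 backend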