[Renderer] Consolidate factory methods (#38218)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -16,7 +16,7 @@ from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.renderers.hf import HfRenderer
|
||||
from vllm.tokenizers.registry import tokenizer_args_from_config
|
||||
from vllm.tokenizers.registry import cached_tokenizer_from_config
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
|
||||
MODEL_NAME = "openai-community/gpt2"
|
||||
@@ -72,11 +72,9 @@ class MockVllmConfig:
|
||||
|
||||
|
||||
def _build_renderer(model_config: MockModelConfig):
|
||||
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
|
||||
|
||||
return HfRenderer.from_config(
|
||||
return HfRenderer(
|
||||
MockVllmConfig(model_config, parallel_config=MockParallelConfig()),
|
||||
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
|
||||
cached_tokenizer_from_config(model_config),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ from vllm.renderers.hf import HfRenderer
|
||||
from vllm.renderers.mistral import MistralRenderer
|
||||
from vllm.tokenizers import get_tokenizer
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.tokenizers.registry import tokenizer_args_from_config
|
||||
from vllm.tokenizers.registry import cached_tokenizer_from_config
|
||||
from vllm.tool_parsers import ToolParserManager
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
|
||||
@@ -553,11 +553,9 @@ class MockVllmConfig:
|
||||
|
||||
|
||||
def _build_renderer(model_config: MockModelConfig):
|
||||
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
|
||||
|
||||
return HfRenderer.from_config(
|
||||
return HfRenderer(
|
||||
MockVllmConfig(model_config, parallel_config=MockParallelConfig()),
|
||||
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
|
||||
cached_tokenizer_from_config(model_config),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.renderers.hf import HfRenderer
|
||||
from vllm.tokenizers.registry import tokenizer_args_from_config
|
||||
from vllm.tokenizers.registry import cached_tokenizer_from_config
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
|
||||
MODEL_NAME = "openai-community/gpt2"
|
||||
@@ -93,11 +93,9 @@ def _build_serving_completion(engine: AsyncLLM) -> OpenAIServingCompletion:
|
||||
|
||||
|
||||
def _build_renderer(model_config: MockModelConfig):
|
||||
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
|
||||
|
||||
return HfRenderer.from_config(
|
||||
return HfRenderer(
|
||||
MockVllmConfig(model_config, parallel_config=MockParallelConfig()),
|
||||
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
|
||||
cached_tokenizer_from_config(model_config),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from vllm.entrypoints.serve.render.serving import OpenAIServingRender
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
|
||||
from vllm.renderers.hf import HfRenderer
|
||||
from vllm.tokenizers.registry import tokenizer_args_from_config
|
||||
from vllm.tokenizers.registry import cached_tokenizer_from_config
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
|
||||
MODEL_NAME = "openai-community/gpt2"
|
||||
@@ -101,11 +101,9 @@ def register_mock_resolver():
|
||||
|
||||
|
||||
def _build_renderer(model_config: MockModelConfig):
|
||||
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
|
||||
|
||||
return HfRenderer.from_config(
|
||||
return HfRenderer(
|
||||
MockVllmConfig(model_config, parallel_config=MockParallelConfig()),
|
||||
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
|
||||
cached_tokenizer_from_config(model_config),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@ from vllm.inputs import SingletonPrompt
|
||||
from vllm.renderers import TokenizeParams
|
||||
from vllm.renderers.hf import HfRenderer
|
||||
from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq
|
||||
from vllm.tokenizers.registry import tokenizer_args_from_config
|
||||
|
||||
MODEL_NAME = "openai-community/gpt2"
|
||||
|
||||
@@ -81,8 +80,6 @@ def _build_renderer(
|
||||
truncation_side: str = "left",
|
||||
max_chars_per_token: int = 1,
|
||||
):
|
||||
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
|
||||
|
||||
renderer = HfRenderer(
|
||||
MockVllmConfig(model_config, parallel_config=MockParallelConfig()),
|
||||
tokenizer=(
|
||||
|
||||
@@ -8,7 +8,7 @@ from vllm.assets.video import VideoAsset
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.multimodal.parse import parse_mm_uuids
|
||||
from vllm.renderers.hf import HfRenderer
|
||||
from vllm.tokenizers.registry import tokenizer_args_from_config
|
||||
from vllm.tokenizers.registry import cached_tokenizer_from_config
|
||||
|
||||
cherry_pil_image = ImageAsset("cherry_blossom").pil_image
|
||||
stop_pil_image = ImageAsset("stop_sign").pil_image
|
||||
@@ -29,11 +29,9 @@ def _build_renderer(
|
||||
cache_config=CacheConfig(enable_prefix_caching=enable_prefix_caching),
|
||||
)
|
||||
|
||||
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
|
||||
|
||||
return HfRenderer.from_config(
|
||||
return HfRenderer(
|
||||
vllm_config,
|
||||
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
|
||||
cached_tokenizer_from_config(model_config),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -542,7 +542,9 @@ class ModelConfig:
|
||||
|
||||
# Set default tokenizer modes based on model architecture
|
||||
if self.tokenizer_mode == "auto":
|
||||
if arch == "Grok1ForCausalLM":
|
||||
if self.model_impl == "terratorch":
|
||||
self.tokenizer_mode = "terratorch"
|
||||
elif arch == "Grok1ForCausalLM":
|
||||
self.tokenizer_mode = "grok2"
|
||||
elif arch == "MoonshotKimiaForCausalLM":
|
||||
self.tokenizer_mode = "kimi_audio"
|
||||
|
||||
@@ -69,15 +69,6 @@ _T = TypeVar("_T", bound=TokenizerLike, default=TokenizerLike)
|
||||
|
||||
|
||||
class BaseRenderer(ABC, Generic[_T]):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: "VllmConfig",
|
||||
tokenizer_kwargs: dict[str, Any],
|
||||
) -> "BaseRenderer":
|
||||
raise NotImplementedError
|
||||
|
||||
def __init__(self, config: "VllmConfig", tokenizer: _T | None) -> None:
|
||||
super().__init__()
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionMessageParam,
|
||||
ConversationMessage,
|
||||
@@ -10,7 +8,6 @@ from vllm.entrypoints.chat_utils import (
|
||||
parse_chat_messages_async,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import cached_get_tokenizer
|
||||
from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer
|
||||
|
||||
from .base import BaseRenderer
|
||||
@@ -22,23 +19,6 @@ logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DeepseekV32Renderer(BaseRenderer[DeepseekV32Tokenizer]):
|
||||
@classmethod
|
||||
def from_config( # type: ignore[override]
|
||||
cls,
|
||||
config: VllmConfig,
|
||||
tokenizer_kwargs: dict[str, Any],
|
||||
) -> "DeepseekV32Renderer":
|
||||
model_config = config.model_config
|
||||
if model_config.skip_tokenizer_init:
|
||||
tokenizer = None
|
||||
else:
|
||||
tokenizer = cached_get_tokenizer(
|
||||
tokenizer_cls=DeepseekV32Tokenizer,
|
||||
**tokenizer_kwargs,
|
||||
)
|
||||
|
||||
return cls(config, tokenizer)
|
||||
|
||||
def render_messages(
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionMessageParam,
|
||||
ConversationMessage,
|
||||
@@ -10,7 +8,6 @@ from vllm.entrypoints.chat_utils import (
|
||||
parse_chat_messages_async,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import cached_get_tokenizer
|
||||
from vllm.tokenizers.grok2 import Grok2Tokenizer
|
||||
|
||||
from .base import BaseRenderer
|
||||
@@ -22,23 +19,6 @@ logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Grok2Renderer(BaseRenderer[Grok2Tokenizer]):
|
||||
@classmethod
|
||||
def from_config( # type: ignore[override]
|
||||
cls,
|
||||
config: VllmConfig,
|
||||
tokenizer_kwargs: dict[str, Any],
|
||||
) -> "Grok2Renderer":
|
||||
model_config = config.model_config
|
||||
if model_config.skip_tokenizer_init:
|
||||
tokenizer = None
|
||||
else:
|
||||
tokenizer = cached_get_tokenizer(
|
||||
tokenizer_cls=Grok2Tokenizer,
|
||||
**tokenizer_kwargs,
|
||||
)
|
||||
|
||||
return cls(config, tokenizer)
|
||||
|
||||
def render_messages(
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
|
||||
@@ -27,8 +27,7 @@ from vllm.entrypoints.chat_utils import (
|
||||
)
|
||||
from vllm.inputs import MultiModalDataDict, MultiModalUUIDDict
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import cached_get_tokenizer
|
||||
from vllm.tokenizers.hf import CachedHfTokenizer, HfTokenizer
|
||||
from vllm.tokenizers.hf import HfTokenizer
|
||||
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
|
||||
from vllm.transformers_utils.processor import cached_get_processor
|
||||
from vllm.utils.func_utils import supports_kw
|
||||
@@ -604,26 +603,6 @@ def replace_vision_chunk_video_placeholder(
|
||||
|
||||
|
||||
class HfRenderer(BaseRenderer[HfTokenizer]):
|
||||
@classmethod
|
||||
def from_config( # type: ignore[override]
|
||||
cls,
|
||||
config: VllmConfig,
|
||||
tokenizer_kwargs: dict[str, Any],
|
||||
) -> "HfRenderer":
|
||||
model_config = config.model_config
|
||||
if model_config.skip_tokenizer_init:
|
||||
tokenizer = None
|
||||
else:
|
||||
tokenizer = cast(
|
||||
HfTokenizer,
|
||||
cached_get_tokenizer(
|
||||
tokenizer_cls=CachedHfTokenizer, # type: ignore[type-abstract]
|
||||
**tokenizer_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
return cls(config, tokenizer)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: VllmConfig,
|
||||
|
||||
@@ -1,49 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any, cast
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.tokenizers.kimi_audio import KimiAudioTokenizer
|
||||
from vllm.tokenizers.registry import get_tokenizer
|
||||
|
||||
from .hf import HfRenderer, HfTokenizer
|
||||
|
||||
|
||||
class KimiAudioRenderer(HfRenderer):
|
||||
"""Renderer for Kimi-Audio models.
|
||||
|
||||
This renderer uses HfRenderer internally with a custom TikToken tokenizer.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_config( # type: ignore[override]
|
||||
cls,
|
||||
config: VllmConfig,
|
||||
tokenizer_kwargs: dict[str, Any],
|
||||
) -> "HfRenderer":
|
||||
"""Create an HfRenderer instance for Kimi-Audio models."""
|
||||
model_config = config.model_config
|
||||
if model_config.skip_tokenizer_init:
|
||||
tokenizer = None
|
||||
else:
|
||||
# Extract tokenizer_name from kwargs (already processed by
|
||||
# tokenizer_args_from_config for ModelScope/GGUF/etc)
|
||||
tokenizer_name = tokenizer_kwargs.pop(
|
||||
"tokenizer_name", model_config.tokenizer
|
||||
)
|
||||
# Remove tokenizer_cls from kwargs to avoid duplicate argument
|
||||
tokenizer_kwargs = {
|
||||
k: v for k, v in tokenizer_kwargs.items() if k != "tokenizer_cls"
|
||||
}
|
||||
# Use get_tokenizer directly instead of cached_get_tokenizer
|
||||
# (KimiAudioTokenizer doesn't work with get_cached_tokenizer)
|
||||
tokenizer = cast(
|
||||
HfTokenizer,
|
||||
get_tokenizer(
|
||||
tokenizer_name,
|
||||
tokenizer_cls=KimiAudioTokenizer, # type: ignore[arg-type]
|
||||
**tokenizer_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
return HfRenderer(config, tokenizer)
|
||||
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
@@ -11,7 +10,6 @@ from vllm.entrypoints.chat_utils import (
|
||||
parse_chat_messages_async,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import cached_get_tokenizer
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.utils.async_utils import make_async
|
||||
|
||||
@@ -51,23 +49,6 @@ def safe_apply_chat_template(
|
||||
|
||||
|
||||
class MistralRenderer(BaseRenderer[MistralTokenizer]):
|
||||
@classmethod
|
||||
def from_config( # type: ignore[override]
|
||||
cls,
|
||||
config: VllmConfig,
|
||||
tokenizer_kwargs: dict[str, Any],
|
||||
) -> "MistralRenderer":
|
||||
model_config = config.model_config
|
||||
if model_config.skip_tokenizer_init:
|
||||
tokenizer = None
|
||||
else:
|
||||
tokenizer = cached_get_tokenizer(
|
||||
tokenizer_cls=MistralTokenizer,
|
||||
**tokenizer_kwargs,
|
||||
)
|
||||
|
||||
return cls(config, tokenizer)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: VllmConfig,
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.tokenizers import cached_get_tokenizer
|
||||
from vllm.tokenizers.qwen_vl import QwenVLTokenizer
|
||||
|
||||
from .hf import HfRenderer
|
||||
|
||||
|
||||
class QwenVLRenderer(HfRenderer):
|
||||
@classmethod
|
||||
def from_config( # type: ignore[override]
|
||||
cls,
|
||||
config: VllmConfig,
|
||||
tokenizer_kwargs: dict[str, Any],
|
||||
) -> "HfRenderer":
|
||||
model_config = config.model_config
|
||||
if model_config.skip_tokenizer_init:
|
||||
tokenizer = None
|
||||
else:
|
||||
tokenizer = cached_get_tokenizer(
|
||||
tokenizer_cls=QwenVLTokenizer,
|
||||
**tokenizer_kwargs,
|
||||
)
|
||||
|
||||
return HfRenderer(config, tokenizer)
|
||||
@@ -1,10 +1,14 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers.registry import tokenizer_args_from_config
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tokenizers.registry import (
|
||||
cached_tokenizer_from_config,
|
||||
tokenizer_args_from_config,
|
||||
)
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
|
||||
from .base import BaseRenderer
|
||||
@@ -19,9 +23,9 @@ _VLLM_RENDERERS = {
|
||||
"deepseek_v32": ("deepseek_v32", "DeepseekV32Renderer"),
|
||||
"hf": ("hf", "HfRenderer"),
|
||||
"grok2": ("grok2", "Grok2Renderer"),
|
||||
"kimi_audio": ("kimi_audio", "KimiAudioRenderer"),
|
||||
"kimi_audio": ("hf", "HfRenderer"),
|
||||
"mistral": ("mistral", "MistralRenderer"),
|
||||
"qwen_vl": ("qwen_vl", "QwenVLRenderer"),
|
||||
"qwen_vl": ("hf", "HfRenderer"),
|
||||
"terratorch": ("terratorch", "TerratorchRenderer"),
|
||||
}
|
||||
|
||||
@@ -58,10 +62,10 @@ class RendererRegistry:
|
||||
self,
|
||||
renderer_mode: str,
|
||||
config: "VllmConfig",
|
||||
tokenizer_kwargs: dict[str, Any],
|
||||
tokenizer: TokenizerLike | None,
|
||||
) -> BaseRenderer:
|
||||
renderer_cls = self.load_renderer_cls(renderer_mode)
|
||||
return renderer_cls.from_config(config, tokenizer_kwargs)
|
||||
return renderer_cls(config, tokenizer)
|
||||
|
||||
|
||||
RENDERER_REGISTRY = RendererRegistry(
|
||||
@@ -76,20 +80,7 @@ RENDERER_REGISTRY = RendererRegistry(
|
||||
def renderer_from_config(config: "VllmConfig", **kwargs):
|
||||
model_config = config.model_config
|
||||
|
||||
tokenizer_mode, tokenizer_name, args, kwargs = tokenizer_args_from_config(
|
||||
model_config, **kwargs
|
||||
)
|
||||
tokenizer = cached_tokenizer_from_config(model_config, **kwargs)
|
||||
renderer_mode, *_ = tokenizer_args_from_config(model_config, **kwargs)
|
||||
|
||||
if (
|
||||
model_config.tokenizer_mode == "auto"
|
||||
and model_config.model_impl == "terratorch"
|
||||
):
|
||||
renderer_mode = "terratorch"
|
||||
else:
|
||||
renderer_mode = tokenizer_mode
|
||||
|
||||
return RENDERER_REGISTRY.load_renderer(
|
||||
renderer_mode,
|
||||
config,
|
||||
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
|
||||
)
|
||||
return RENDERER_REGISTRY.load_renderer(renderer_mode, config, tokenizer)
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionMessageParam,
|
||||
ConversationMessage,
|
||||
@@ -20,18 +18,6 @@ logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TerratorchRenderer(BaseRenderer):
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: VllmConfig, # type: ignore[override]
|
||||
tokenizer_kwargs: dict[str, Any],
|
||||
) -> "TerratorchRenderer":
|
||||
model_config = config.model_config
|
||||
if not model_config.skip_tokenizer_init:
|
||||
raise ValueError("Terratorch renderer requires `skip_tokenizer_init=True`")
|
||||
|
||||
return cls(config, None)
|
||||
|
||||
def render_messages(
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
|
||||
Reference in New Issue
Block a user