[Refactor] Make Renderer an abstract class (#33479)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-02-01 10:36:30 +08:00
committed by GitHub
parent 079781177a
commit a358e4dffe
12 changed files with 49 additions and 50 deletions

View File

@@ -11,7 +11,7 @@ from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import IOProcessor from vllm.plugins.io_processors import IOProcessor
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers import RendererLike from vllm.renderers import BaseRenderer
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
@@ -28,7 +28,7 @@ class EngineClient(ABC):
@property @property
@abstractmethod @abstractmethod
def renderer(self) -> RendererLike: ... def renderer(self) -> BaseRenderer: ...
@property @property
@abstractmethod @abstractmethod

View File

@@ -2,11 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .params import ChatParams, TokenizeParams, merge_kwargs from .params import ChatParams, TokenizeParams, merge_kwargs
from .protocol import RendererLike from .protocol import BaseRenderer
from .registry import RendererRegistry, renderer_from_config from .registry import RendererRegistry, renderer_from_config
__all__ = [ __all__ = [
"RendererLike", "BaseRenderer",
"RendererRegistry", "RendererRegistry",
"renderer_from_config", "renderer_from_config",
"ChatParams", "ChatParams",

View File

@@ -15,18 +15,18 @@ from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer
from .params import ChatParams from .params import ChatParams
from .protocol import RendererLike from .protocol import BaseRenderer
logger = init_logger(__name__) logger = init_logger(__name__)
class DeepseekV32Renderer(RendererLike): class DeepseekV32Renderer(BaseRenderer):
@classmethod @classmethod
def from_config( def from_config(
cls, cls,
config: ModelConfig, config: ModelConfig,
tokenizer_kwargs: dict[str, Any], tokenizer_kwargs: dict[str, Any],
) -> "RendererLike": ) -> "BaseRenderer":
return cls(config, tokenizer_kwargs) return cls(config, tokenizer_kwargs)
def __init__( def __init__(
@@ -34,9 +34,7 @@ class DeepseekV32Renderer(RendererLike):
config: ModelConfig, config: ModelConfig,
tokenizer_kwargs: dict[str, Any], tokenizer_kwargs: dict[str, Any],
) -> None: ) -> None:
super().__init__() super().__init__(config)
self.config = config
if config.skip_tokenizer_init: if config.skip_tokenizer_init:
tokenizer = None tokenizer = None

View File

@@ -15,18 +15,18 @@ from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.grok2 import Grok2Tokenizer from vllm.tokenizers.grok2 import Grok2Tokenizer
from .params import ChatParams from .params import ChatParams
from .protocol import RendererLike from .protocol import BaseRenderer
logger = init_logger(__name__) logger = init_logger(__name__)
class Grok2Renderer(RendererLike): class Grok2Renderer(BaseRenderer):
@classmethod @classmethod
def from_config( def from_config(
cls, cls,
config: ModelConfig, config: ModelConfig,
tokenizer_kwargs: dict[str, Any], tokenizer_kwargs: dict[str, Any],
) -> "RendererLike": ) -> "BaseRenderer":
return cls(config, tokenizer_kwargs) return cls(config, tokenizer_kwargs)
def __init__( def __init__(
@@ -34,9 +34,7 @@ class Grok2Renderer(RendererLike):
config: ModelConfig, config: ModelConfig,
tokenizer_kwargs: dict[str, Any], tokenizer_kwargs: dict[str, Any],
) -> None: ) -> None:
super().__init__() super().__init__(config)
self.config = config
if config.skip_tokenizer_init: if config.skip_tokenizer_init:
tokenizer = None tokenizer = None

View File

@@ -34,7 +34,7 @@ from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils.func_utils import supports_kw from vllm.utils.func_utils import supports_kw
from .params import ChatParams from .params import ChatParams
from .protocol import RendererLike from .protocol import BaseRenderer
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalUUIDDict from vllm.multimodal.inputs import MultiModalDataDict, MultiModalUUIDDict
@@ -584,13 +584,13 @@ def replace_vision_chunk_video_placeholder(
return prompt_raw return prompt_raw
class HfRenderer(RendererLike): class HfRenderer(BaseRenderer):
@classmethod @classmethod
def from_config( def from_config(
cls, cls,
config: ModelConfig, config: ModelConfig,
tokenizer_kwargs: dict[str, Any], tokenizer_kwargs: dict[str, Any],
) -> "RendererLike": ) -> "BaseRenderer":
return cls(config, tokenizer_kwargs) return cls(config, tokenizer_kwargs)
def __init__( def __init__(
@@ -598,9 +598,8 @@ class HfRenderer(RendererLike):
config: ModelConfig, config: ModelConfig,
tokenizer_kwargs: dict[str, Any], tokenizer_kwargs: dict[str, Any],
) -> None: ) -> None:
super().__init__() super().__init__(config)
self.config = config
self.use_unified_vision_chunk = getattr( self.use_unified_vision_chunk = getattr(
config.hf_config, "use_unified_vision_chunk", False config.hf_config, "use_unified_vision_chunk", False
) )

View File

@@ -17,7 +17,7 @@ from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils.async_utils import make_async from vllm.utils.async_utils import make_async
from .params import ChatParams from .params import ChatParams
from .protocol import RendererLike from .protocol import BaseRenderer
logger = init_logger(__name__) logger = init_logger(__name__)
@@ -49,13 +49,13 @@ def safe_apply_chat_template(
raise ValueError(str(e)) from e raise ValueError(str(e)) from e
class MistralRenderer(RendererLike): class MistralRenderer(BaseRenderer):
@classmethod @classmethod
def from_config( def from_config(
cls, cls,
config: ModelConfig, config: ModelConfig,
tokenizer_kwargs: dict[str, Any], tokenizer_kwargs: dict[str, Any],
) -> "RendererLike": ) -> "BaseRenderer":
return cls(config, tokenizer_kwargs) return cls(config, tokenizer_kwargs)
def __init__( def __init__(
@@ -63,9 +63,7 @@ class MistralRenderer(RendererLike):
config: ModelConfig, config: ModelConfig,
tokenizer_kwargs: dict[str, Any], tokenizer_kwargs: dict[str, Any],
) -> None: ) -> None:
super().__init__() super().__init__(config)
self.config = config
if config.skip_tokenizer_init: if config.skip_tokenizer_init:
tokenizer = None tokenizer = None

View File

@@ -1,7 +1,8 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio import asyncio
from typing import TYPE_CHECKING, Any, Protocol from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
@@ -19,19 +20,26 @@ if TYPE_CHECKING:
) )
class RendererLike(Protocol): class BaseRenderer(ABC):
config: "ModelConfig"
_async_tokenizer: AsyncMicrobatchTokenizer
@classmethod @classmethod
@abstractmethod
def from_config( def from_config(
cls, cls,
config: "ModelConfig", config: "ModelConfig",
tokenizer_kwargs: dict[str, Any], tokenizer_kwargs: dict[str, Any],
) -> "RendererLike": ) -> "BaseRenderer":
raise NotImplementedError raise NotImplementedError
def __init__(self, config: "ModelConfig") -> None:
super().__init__()
self.config = config
# Lazy initialization since offline LLM doesn't use async
self._async_tokenizer: AsyncMicrobatchTokenizer | None = None
@property @property
@abstractmethod
def tokenizer(self) -> TokenizerLike | None: def tokenizer(self) -> TokenizerLike | None:
raise NotImplementedError raise NotImplementedError
@@ -43,8 +51,7 @@ class RendererLike(Protocol):
return tokenizer return tokenizer
def get_async_tokenizer(self) -> AsyncMicrobatchTokenizer: def get_async_tokenizer(self) -> AsyncMicrobatchTokenizer:
# Lazy initialization since offline LLM doesn't use async if self._async_tokenizer is None:
if not hasattr(self, "_async_tokenizer"):
self._async_tokenizer = AsyncMicrobatchTokenizer(self.get_tokenizer()) self._async_tokenizer = AsyncMicrobatchTokenizer(self.get_tokenizer())
return self._async_tokenizer return self._async_tokenizer
@@ -104,6 +111,7 @@ class RendererLike(Protocol):
) -> list[TextPrompt | TokensPrompt | EmbedsPrompt]: ) -> list[TextPrompt | TokensPrompt | EmbedsPrompt]:
return self.render_completions(prompt_input, prompt_embeds) return self.render_completions(prompt_input, prompt_embeds)
@abstractmethod
def render_messages( def render_messages(
self, self,
messages: list["ChatCompletionMessageParam"], messages: list["ChatCompletionMessageParam"],

View File

@@ -7,7 +7,7 @@ from vllm.logger import init_logger
from vllm.tokenizers.registry import tokenizer_args_from_config from vllm.tokenizers.registry import tokenizer_args_from_config
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
from .protocol import RendererLike from .protocol import BaseRenderer
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig from vllm.config import ModelConfig
@@ -43,7 +43,7 @@ class RendererRegistry:
return None return None
def load_renderer_cls(self, renderer_mode: str) -> type[RendererLike]: def load_renderer_cls(self, renderer_mode: str) -> type[BaseRenderer]:
if renderer_mode not in self.renderers: if renderer_mode not in self.renderers:
raise ValueError(f"No renderer registered for {renderer_mode=!r}.") raise ValueError(f"No renderer registered for {renderer_mode=!r}.")
@@ -57,7 +57,7 @@ class RendererRegistry:
renderer_mode: str, renderer_mode: str,
config: "ModelConfig", config: "ModelConfig",
tokenizer_kwargs: dict[str, Any], tokenizer_kwargs: dict[str, Any],
) -> RendererLike: ) -> BaseRenderer:
renderer_cls = self.load_renderer_cls(renderer_mode) renderer_cls = self.load_renderer_cls(renderer_mode)
return renderer_cls.from_config(config, tokenizer_kwargs) return renderer_cls.from_config(config, tokenizer_kwargs)

View File

@@ -14,24 +14,22 @@ from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from .params import ChatParams from .params import ChatParams
from .protocol import RendererLike from .protocol import BaseRenderer
logger = init_logger(__name__) logger = init_logger(__name__)
class TerratorchRenderer(RendererLike): class TerratorchRenderer(BaseRenderer):
@classmethod @classmethod
def from_config( def from_config(
cls, cls,
config: "ModelConfig", config: "ModelConfig",
tokenizer_kwargs: dict[str, Any], tokenizer_kwargs: dict[str, Any],
) -> "RendererLike": ) -> "BaseRenderer":
return cls(config) return cls(config)
def __init__(self, config: ModelConfig) -> None: def __init__(self, config: ModelConfig) -> None:
super().__init__() super().__init__(config)
self.config = config
if not config.skip_tokenizer_init: if not config.skip_tokenizer_init:
raise ValueError("Terratorch renderer requires `skip_tokenizer_init=True`") raise ValueError("Terratorch renderer requires `skip_tokenizer_init=True`")

View File

@@ -24,7 +24,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers import RendererLike, merge_kwargs from vllm.renderers import BaseRenderer, merge_kwargs
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
@@ -844,7 +844,7 @@ class AsyncLLM(EngineClient):
return self.input_processor.get_tokenizer() return self.input_processor.get_tokenizer()
@property @property
def renderer(self) -> RendererLike: def renderer(self) -> BaseRenderer:
return self.input_processor.renderer return self.input_processor.renderer
async def is_tracing_enabled(self) -> bool: async def is_tracing_enabled(self) -> bool:

View File

@@ -29,7 +29,7 @@ from vllm.multimodal.parse import ModalityDataItems, MultiModalDataItems
from vllm.multimodal.processing.context import set_request_id from vllm.multimodal.processing.context import set_request_id
from vllm.multimodal.utils import argsort_mm_positions from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers import RendererLike from vllm.renderers import BaseRenderer
from vllm.sampling_params import _SAMPLING_EPS, SamplingParams from vllm.sampling_params import _SAMPLING_EPS, SamplingParams
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer from vllm.tokenizers.mistral import MistralTokenizer
@@ -96,7 +96,7 @@ class InputProcessor:
return self.input_preprocessor.get_tokenizer() return self.input_preprocessor.get_tokenizer()
@property @property
def renderer(self) -> RendererLike: def renderer(self) -> BaseRenderer:
return self.input_preprocessor.renderer return self.input_preprocessor.renderer
def _validate_logprobs( def _validate_logprobs(

View File

@@ -21,7 +21,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.renderers import RendererLike from vllm.renderers import BaseRenderer
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
@@ -367,7 +367,7 @@ class LLMEngine:
return self.input_processor.get_tokenizer() return self.input_processor.get_tokenizer()
@property @property
def renderer(self) -> RendererLike: def renderer(self) -> BaseRenderer:
return self.input_processor.renderer return self.input_processor.renderer
def do_log_stats(self) -> None: def do_log_stats(self) -> None: