[Refactor] Make Renderer an abstract class (#33479)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -11,7 +11,7 @@ from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import PoolingRequestOutput, RequestOutput
|
||||
from vllm.plugins.io_processors import IOProcessor
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.renderers import RendererLike
|
||||
from vllm.renderers import BaseRenderer
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
@@ -28,7 +28,7 @@ class EngineClient(ABC):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def renderer(self) -> RendererLike: ...
|
||||
def renderer(self) -> BaseRenderer: ...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from .params import ChatParams, TokenizeParams, merge_kwargs
|
||||
from .protocol import RendererLike
|
||||
from .protocol import BaseRenderer
|
||||
from .registry import RendererRegistry, renderer_from_config
|
||||
|
||||
__all__ = [
|
||||
"RendererLike",
|
||||
"BaseRenderer",
|
||||
"RendererRegistry",
|
||||
"renderer_from_config",
|
||||
"ChatParams",
|
||||
|
||||
@@ -15,18 +15,18 @@ from vllm.tokenizers import cached_get_tokenizer
|
||||
from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer
|
||||
|
||||
from .params import ChatParams
|
||||
from .protocol import RendererLike
|
||||
from .protocol import BaseRenderer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DeepseekV32Renderer(RendererLike):
|
||||
class DeepseekV32Renderer(BaseRenderer):
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: ModelConfig,
|
||||
tokenizer_kwargs: dict[str, Any],
|
||||
) -> "RendererLike":
|
||||
) -> "BaseRenderer":
|
||||
return cls(config, tokenizer_kwargs)
|
||||
|
||||
def __init__(
|
||||
@@ -34,9 +34,7 @@ class DeepseekV32Renderer(RendererLike):
|
||||
config: ModelConfig,
|
||||
tokenizer_kwargs: dict[str, Any],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
super().__init__(config)
|
||||
|
||||
if config.skip_tokenizer_init:
|
||||
tokenizer = None
|
||||
|
||||
@@ -15,18 +15,18 @@ from vllm.tokenizers import cached_get_tokenizer
|
||||
from vllm.tokenizers.grok2 import Grok2Tokenizer
|
||||
|
||||
from .params import ChatParams
|
||||
from .protocol import RendererLike
|
||||
from .protocol import BaseRenderer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Grok2Renderer(RendererLike):
|
||||
class Grok2Renderer(BaseRenderer):
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: ModelConfig,
|
||||
tokenizer_kwargs: dict[str, Any],
|
||||
) -> "RendererLike":
|
||||
) -> "BaseRenderer":
|
||||
return cls(config, tokenizer_kwargs)
|
||||
|
||||
def __init__(
|
||||
@@ -34,9 +34,7 @@ class Grok2Renderer(RendererLike):
|
||||
config: ModelConfig,
|
||||
tokenizer_kwargs: dict[str, Any],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
super().__init__(config)
|
||||
|
||||
if config.skip_tokenizer_init:
|
||||
tokenizer = None
|
||||
|
||||
@@ -34,7 +34,7 @@ from vllm.transformers_utils.processor import cached_get_processor
|
||||
from vllm.utils.func_utils import supports_kw
|
||||
|
||||
from .params import ChatParams
|
||||
from .protocol import RendererLike
|
||||
from .protocol import BaseRenderer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalUUIDDict
|
||||
@@ -584,13 +584,13 @@ def replace_vision_chunk_video_placeholder(
|
||||
return prompt_raw
|
||||
|
||||
|
||||
class HfRenderer(RendererLike):
|
||||
class HfRenderer(BaseRenderer):
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: ModelConfig,
|
||||
tokenizer_kwargs: dict[str, Any],
|
||||
) -> "RendererLike":
|
||||
) -> "BaseRenderer":
|
||||
return cls(config, tokenizer_kwargs)
|
||||
|
||||
def __init__(
|
||||
@@ -598,9 +598,8 @@ class HfRenderer(RendererLike):
|
||||
config: ModelConfig,
|
||||
tokenizer_kwargs: dict[str, Any],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
super().__init__(config)
|
||||
|
||||
self.config = config
|
||||
self.use_unified_vision_chunk = getattr(
|
||||
config.hf_config, "use_unified_vision_chunk", False
|
||||
)
|
||||
|
||||
@@ -17,7 +17,7 @@ from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.utils.async_utils import make_async
|
||||
|
||||
from .params import ChatParams
|
||||
from .protocol import RendererLike
|
||||
from .protocol import BaseRenderer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -49,13 +49,13 @@ def safe_apply_chat_template(
|
||||
raise ValueError(str(e)) from e
|
||||
|
||||
|
||||
class MistralRenderer(RendererLike):
|
||||
class MistralRenderer(BaseRenderer):
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: ModelConfig,
|
||||
tokenizer_kwargs: dict[str, Any],
|
||||
) -> "RendererLike":
|
||||
) -> "BaseRenderer":
|
||||
return cls(config, tokenizer_kwargs)
|
||||
|
||||
def __init__(
|
||||
@@ -63,9 +63,7 @@ class MistralRenderer(RendererLike):
|
||||
config: ModelConfig,
|
||||
tokenizer_kwargs: dict[str, Any],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
super().__init__(config)
|
||||
|
||||
if config.skip_tokenizer_init:
|
||||
tokenizer = None
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, Any, Protocol
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
@@ -19,19 +20,26 @@ if TYPE_CHECKING:
|
||||
)
|
||||
|
||||
|
||||
class RendererLike(Protocol):
|
||||
config: "ModelConfig"
|
||||
_async_tokenizer: AsyncMicrobatchTokenizer
|
||||
|
||||
class BaseRenderer(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: "ModelConfig",
|
||||
tokenizer_kwargs: dict[str, Any],
|
||||
) -> "RendererLike":
|
||||
) -> "BaseRenderer":
|
||||
raise NotImplementedError
|
||||
|
||||
def __init__(self, config: "ModelConfig") -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
|
||||
# Lazy initialization since offline LLM doesn't use async
|
||||
self._async_tokenizer: AsyncMicrobatchTokenizer | None = None
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def tokenizer(self) -> TokenizerLike | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -43,8 +51,7 @@ class RendererLike(Protocol):
|
||||
return tokenizer
|
||||
|
||||
def get_async_tokenizer(self) -> AsyncMicrobatchTokenizer:
|
||||
# Lazy initialization since offline LLM doesn't use async
|
||||
if not hasattr(self, "_async_tokenizer"):
|
||||
if self._async_tokenizer is None:
|
||||
self._async_tokenizer = AsyncMicrobatchTokenizer(self.get_tokenizer())
|
||||
|
||||
return self._async_tokenizer
|
||||
@@ -104,6 +111,7 @@ class RendererLike(Protocol):
|
||||
) -> list[TextPrompt | TokensPrompt | EmbedsPrompt]:
|
||||
return self.render_completions(prompt_input, prompt_embeds)
|
||||
|
||||
@abstractmethod
|
||||
def render_messages(
|
||||
self,
|
||||
messages: list["ChatCompletionMessageParam"],
|
||||
|
||||
@@ -7,7 +7,7 @@ from vllm.logger import init_logger
|
||||
from vllm.tokenizers.registry import tokenizer_args_from_config
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
|
||||
from .protocol import RendererLike
|
||||
from .protocol import BaseRenderer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
@@ -43,7 +43,7 @@ class RendererRegistry:
|
||||
|
||||
return None
|
||||
|
||||
def load_renderer_cls(self, renderer_mode: str) -> type[RendererLike]:
|
||||
def load_renderer_cls(self, renderer_mode: str) -> type[BaseRenderer]:
|
||||
if renderer_mode not in self.renderers:
|
||||
raise ValueError(f"No renderer registered for {renderer_mode=!r}.")
|
||||
|
||||
@@ -57,7 +57,7 @@ class RendererRegistry:
|
||||
renderer_mode: str,
|
||||
config: "ModelConfig",
|
||||
tokenizer_kwargs: dict[str, Any],
|
||||
) -> RendererLike:
|
||||
) -> BaseRenderer:
|
||||
renderer_cls = self.load_renderer_cls(renderer_mode)
|
||||
return renderer_cls.from_config(config, tokenizer_kwargs)
|
||||
|
||||
|
||||
@@ -14,24 +14,22 @@ from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
from .params import ChatParams
|
||||
from .protocol import RendererLike
|
||||
from .protocol import BaseRenderer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TerratorchRenderer(RendererLike):
|
||||
class TerratorchRenderer(BaseRenderer):
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: "ModelConfig",
|
||||
tokenizer_kwargs: dict[str, Any],
|
||||
) -> "RendererLike":
|
||||
) -> "BaseRenderer":
|
||||
return cls(config)
|
||||
|
||||
def __init__(self, config: ModelConfig) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
super().__init__(config)
|
||||
|
||||
if not config.skip_tokenizer_init:
|
||||
raise ValueError("Terratorch renderer requires `skip_tokenizer_init=True`")
|
||||
|
||||
@@ -24,7 +24,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput
|
||||
from vllm.plugins.io_processors import get_io_processor
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.renderers import RendererLike, merge_kwargs
|
||||
from vllm.renderers import BaseRenderer, merge_kwargs
|
||||
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
@@ -844,7 +844,7 @@ class AsyncLLM(EngineClient):
|
||||
return self.input_processor.get_tokenizer()
|
||||
|
||||
@property
|
||||
def renderer(self) -> RendererLike:
|
||||
def renderer(self) -> BaseRenderer:
|
||||
return self.input_processor.renderer
|
||||
|
||||
async def is_tracing_enabled(self) -> bool:
|
||||
|
||||
@@ -29,7 +29,7 @@ from vllm.multimodal.parse import ModalityDataItems, MultiModalDataItems
|
||||
from vllm.multimodal.processing.context import set_request_id
|
||||
from vllm.multimodal.utils import argsort_mm_positions
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.renderers import RendererLike
|
||||
from vllm.renderers import BaseRenderer
|
||||
from vllm.sampling_params import _SAMPLING_EPS, SamplingParams
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
@@ -96,7 +96,7 @@ class InputProcessor:
|
||||
return self.input_preprocessor.get_tokenizer()
|
||||
|
||||
@property
|
||||
def renderer(self) -> RendererLike:
|
||||
def renderer(self) -> BaseRenderer:
|
||||
return self.input_preprocessor.renderer
|
||||
|
||||
def _validate_logprobs(
|
||||
|
||||
@@ -21,7 +21,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.outputs import PoolingRequestOutput, RequestOutput
|
||||
from vllm.plugins.io_processors import get_io_processor
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.renderers import RendererLike
|
||||
from vllm.renderers import BaseRenderer
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
@@ -367,7 +367,7 @@ class LLMEngine:
|
||||
return self.input_processor.get_tokenizer()
|
||||
|
||||
@property
|
||||
def renderer(self) -> RendererLike:
|
||||
def renderer(self) -> BaseRenderer:
|
||||
return self.input_processor.renderer
|
||||
|
||||
def do_log_stats(self) -> None:
|
||||
|
||||
Reference in New Issue
Block a user