[Refactor] Make Renderer an abstract class (#33479)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -11,7 +11,7 @@ from vllm.lora.request import LoRARequest
|
|||||||
from vllm.outputs import PoolingRequestOutput, RequestOutput
|
from vllm.outputs import PoolingRequestOutput, RequestOutput
|
||||||
from vllm.plugins.io_processors import IOProcessor
|
from vllm.plugins.io_processors import IOProcessor
|
||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingParams
|
||||||
from vllm.renderers import RendererLike
|
from vllm.renderers import BaseRenderer
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.tasks import SupportedTask
|
from vllm.tasks import SupportedTask
|
||||||
from vllm.v1.engine import EngineCoreRequest
|
from vllm.v1.engine import EngineCoreRequest
|
||||||
@@ -28,7 +28,7 @@ class EngineClient(ABC):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def renderer(self) -> RendererLike: ...
|
def renderer(self) -> BaseRenderer: ...
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|||||||
@@ -2,11 +2,11 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from .params import ChatParams, TokenizeParams, merge_kwargs
|
from .params import ChatParams, TokenizeParams, merge_kwargs
|
||||||
from .protocol import RendererLike
|
from .protocol import BaseRenderer
|
||||||
from .registry import RendererRegistry, renderer_from_config
|
from .registry import RendererRegistry, renderer_from_config
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"RendererLike",
|
"BaseRenderer",
|
||||||
"RendererRegistry",
|
"RendererRegistry",
|
||||||
"renderer_from_config",
|
"renderer_from_config",
|
||||||
"ChatParams",
|
"ChatParams",
|
||||||
|
|||||||
@@ -15,18 +15,18 @@ from vllm.tokenizers import cached_get_tokenizer
|
|||||||
from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer
|
from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer
|
||||||
|
|
||||||
from .params import ChatParams
|
from .params import ChatParams
|
||||||
from .protocol import RendererLike
|
from .protocol import BaseRenderer
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV32Renderer(RendererLike):
|
class DeepseekV32Renderer(BaseRenderer):
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(
|
def from_config(
|
||||||
cls,
|
cls,
|
||||||
config: ModelConfig,
|
config: ModelConfig,
|
||||||
tokenizer_kwargs: dict[str, Any],
|
tokenizer_kwargs: dict[str, Any],
|
||||||
) -> "RendererLike":
|
) -> "BaseRenderer":
|
||||||
return cls(config, tokenizer_kwargs)
|
return cls(config, tokenizer_kwargs)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -34,9 +34,7 @@ class DeepseekV32Renderer(RendererLike):
|
|||||||
config: ModelConfig,
|
config: ModelConfig,
|
||||||
tokenizer_kwargs: dict[str, Any],
|
tokenizer_kwargs: dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__(config)
|
||||||
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
if config.skip_tokenizer_init:
|
if config.skip_tokenizer_init:
|
||||||
tokenizer = None
|
tokenizer = None
|
||||||
|
|||||||
@@ -15,18 +15,18 @@ from vllm.tokenizers import cached_get_tokenizer
|
|||||||
from vllm.tokenizers.grok2 import Grok2Tokenizer
|
from vllm.tokenizers.grok2 import Grok2Tokenizer
|
||||||
|
|
||||||
from .params import ChatParams
|
from .params import ChatParams
|
||||||
from .protocol import RendererLike
|
from .protocol import BaseRenderer
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Grok2Renderer(RendererLike):
|
class Grok2Renderer(BaseRenderer):
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(
|
def from_config(
|
||||||
cls,
|
cls,
|
||||||
config: ModelConfig,
|
config: ModelConfig,
|
||||||
tokenizer_kwargs: dict[str, Any],
|
tokenizer_kwargs: dict[str, Any],
|
||||||
) -> "RendererLike":
|
) -> "BaseRenderer":
|
||||||
return cls(config, tokenizer_kwargs)
|
return cls(config, tokenizer_kwargs)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -34,9 +34,7 @@ class Grok2Renderer(RendererLike):
|
|||||||
config: ModelConfig,
|
config: ModelConfig,
|
||||||
tokenizer_kwargs: dict[str, Any],
|
tokenizer_kwargs: dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__(config)
|
||||||
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
if config.skip_tokenizer_init:
|
if config.skip_tokenizer_init:
|
||||||
tokenizer = None
|
tokenizer = None
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ from vllm.transformers_utils.processor import cached_get_processor
|
|||||||
from vllm.utils.func_utils import supports_kw
|
from vllm.utils.func_utils import supports_kw
|
||||||
|
|
||||||
from .params import ChatParams
|
from .params import ChatParams
|
||||||
from .protocol import RendererLike
|
from .protocol import BaseRenderer
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalUUIDDict
|
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalUUIDDict
|
||||||
@@ -584,13 +584,13 @@ def replace_vision_chunk_video_placeholder(
|
|||||||
return prompt_raw
|
return prompt_raw
|
||||||
|
|
||||||
|
|
||||||
class HfRenderer(RendererLike):
|
class HfRenderer(BaseRenderer):
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(
|
def from_config(
|
||||||
cls,
|
cls,
|
||||||
config: ModelConfig,
|
config: ModelConfig,
|
||||||
tokenizer_kwargs: dict[str, Any],
|
tokenizer_kwargs: dict[str, Any],
|
||||||
) -> "RendererLike":
|
) -> "BaseRenderer":
|
||||||
return cls(config, tokenizer_kwargs)
|
return cls(config, tokenizer_kwargs)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -598,9 +598,8 @@ class HfRenderer(RendererLike):
|
|||||||
config: ModelConfig,
|
config: ModelConfig,
|
||||||
tokenizer_kwargs: dict[str, Any],
|
tokenizer_kwargs: dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__(config)
|
||||||
|
|
||||||
self.config = config
|
|
||||||
self.use_unified_vision_chunk = getattr(
|
self.use_unified_vision_chunk = getattr(
|
||||||
config.hf_config, "use_unified_vision_chunk", False
|
config.hf_config, "use_unified_vision_chunk", False
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from vllm.tokenizers.mistral import MistralTokenizer
|
|||||||
from vllm.utils.async_utils import make_async
|
from vllm.utils.async_utils import make_async
|
||||||
|
|
||||||
from .params import ChatParams
|
from .params import ChatParams
|
||||||
from .protocol import RendererLike
|
from .protocol import BaseRenderer
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@@ -49,13 +49,13 @@ def safe_apply_chat_template(
|
|||||||
raise ValueError(str(e)) from e
|
raise ValueError(str(e)) from e
|
||||||
|
|
||||||
|
|
||||||
class MistralRenderer(RendererLike):
|
class MistralRenderer(BaseRenderer):
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(
|
def from_config(
|
||||||
cls,
|
cls,
|
||||||
config: ModelConfig,
|
config: ModelConfig,
|
||||||
tokenizer_kwargs: dict[str, Any],
|
tokenizer_kwargs: dict[str, Any],
|
||||||
) -> "RendererLike":
|
) -> "BaseRenderer":
|
||||||
return cls(config, tokenizer_kwargs)
|
return cls(config, tokenizer_kwargs)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -63,9 +63,7 @@ class MistralRenderer(RendererLike):
|
|||||||
config: ModelConfig,
|
config: ModelConfig,
|
||||||
tokenizer_kwargs: dict[str, Any],
|
tokenizer_kwargs: dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__(config)
|
||||||
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
if config.skip_tokenizer_init:
|
if config.skip_tokenizer_init:
|
||||||
tokenizer = None
|
tokenizer = None
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import TYPE_CHECKING, Any, Protocol
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
|
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
|
||||||
from vllm.tokenizers import TokenizerLike
|
from vllm.tokenizers import TokenizerLike
|
||||||
@@ -19,19 +20,26 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class RendererLike(Protocol):
|
class BaseRenderer(ABC):
|
||||||
config: "ModelConfig"
|
|
||||||
_async_tokenizer: AsyncMicrobatchTokenizer
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
def from_config(
|
def from_config(
|
||||||
cls,
|
cls,
|
||||||
config: "ModelConfig",
|
config: "ModelConfig",
|
||||||
tokenizer_kwargs: dict[str, Any],
|
tokenizer_kwargs: dict[str, Any],
|
||||||
) -> "RendererLike":
|
) -> "BaseRenderer":
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __init__(self, config: "ModelConfig") -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
# Lazy initialization since offline LLM doesn't use async
|
||||||
|
self._async_tokenizer: AsyncMicrobatchTokenizer | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@abstractmethod
|
||||||
def tokenizer(self) -> TokenizerLike | None:
|
def tokenizer(self) -> TokenizerLike | None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@@ -43,8 +51,7 @@ class RendererLike(Protocol):
|
|||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
def get_async_tokenizer(self) -> AsyncMicrobatchTokenizer:
|
def get_async_tokenizer(self) -> AsyncMicrobatchTokenizer:
|
||||||
# Lazy initialization since offline LLM doesn't use async
|
if self._async_tokenizer is None:
|
||||||
if not hasattr(self, "_async_tokenizer"):
|
|
||||||
self._async_tokenizer = AsyncMicrobatchTokenizer(self.get_tokenizer())
|
self._async_tokenizer = AsyncMicrobatchTokenizer(self.get_tokenizer())
|
||||||
|
|
||||||
return self._async_tokenizer
|
return self._async_tokenizer
|
||||||
@@ -104,6 +111,7 @@ class RendererLike(Protocol):
|
|||||||
) -> list[TextPrompt | TokensPrompt | EmbedsPrompt]:
|
) -> list[TextPrompt | TokensPrompt | EmbedsPrompt]:
|
||||||
return self.render_completions(prompt_input, prompt_embeds)
|
return self.render_completions(prompt_input, prompt_embeds)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
def render_messages(
|
def render_messages(
|
||||||
self,
|
self,
|
||||||
messages: list["ChatCompletionMessageParam"],
|
messages: list["ChatCompletionMessageParam"],
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from vllm.logger import init_logger
|
|||||||
from vllm.tokenizers.registry import tokenizer_args_from_config
|
from vllm.tokenizers.registry import tokenizer_args_from_config
|
||||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||||
|
|
||||||
from .protocol import RendererLike
|
from .protocol import BaseRenderer
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
@@ -43,7 +43,7 @@ class RendererRegistry:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def load_renderer_cls(self, renderer_mode: str) -> type[RendererLike]:
|
def load_renderer_cls(self, renderer_mode: str) -> type[BaseRenderer]:
|
||||||
if renderer_mode not in self.renderers:
|
if renderer_mode not in self.renderers:
|
||||||
raise ValueError(f"No renderer registered for {renderer_mode=!r}.")
|
raise ValueError(f"No renderer registered for {renderer_mode=!r}.")
|
||||||
|
|
||||||
@@ -57,7 +57,7 @@ class RendererRegistry:
|
|||||||
renderer_mode: str,
|
renderer_mode: str,
|
||||||
config: "ModelConfig",
|
config: "ModelConfig",
|
||||||
tokenizer_kwargs: dict[str, Any],
|
tokenizer_kwargs: dict[str, Any],
|
||||||
) -> RendererLike:
|
) -> BaseRenderer:
|
||||||
renderer_cls = self.load_renderer_cls(renderer_mode)
|
renderer_cls = self.load_renderer_cls(renderer_mode)
|
||||||
return renderer_cls.from_config(config, tokenizer_kwargs)
|
return renderer_cls.from_config(config, tokenizer_kwargs)
|
||||||
|
|
||||||
|
|||||||
@@ -14,24 +14,22 @@ from vllm.logger import init_logger
|
|||||||
from vllm.tokenizers import TokenizerLike
|
from vllm.tokenizers import TokenizerLike
|
||||||
|
|
||||||
from .params import ChatParams
|
from .params import ChatParams
|
||||||
from .protocol import RendererLike
|
from .protocol import BaseRenderer
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TerratorchRenderer(RendererLike):
|
class TerratorchRenderer(BaseRenderer):
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(
|
def from_config(
|
||||||
cls,
|
cls,
|
||||||
config: "ModelConfig",
|
config: "ModelConfig",
|
||||||
tokenizer_kwargs: dict[str, Any],
|
tokenizer_kwargs: dict[str, Any],
|
||||||
) -> "RendererLike":
|
) -> "BaseRenderer":
|
||||||
return cls(config)
|
return cls(config)
|
||||||
|
|
||||||
def __init__(self, config: ModelConfig) -> None:
|
def __init__(self, config: ModelConfig) -> None:
|
||||||
super().__init__()
|
super().__init__(config)
|
||||||
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
if not config.skip_tokenizer_init:
|
if not config.skip_tokenizer_init:
|
||||||
raise ValueError("Terratorch renderer requires `skip_tokenizer_init=True`")
|
raise ValueError("Terratorch renderer requires `skip_tokenizer_init=True`")
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
|||||||
from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput
|
from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput
|
||||||
from vllm.plugins.io_processors import get_io_processor
|
from vllm.plugins.io_processors import get_io_processor
|
||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingParams
|
||||||
from vllm.renderers import RendererLike, merge_kwargs
|
from vllm.renderers import BaseRenderer, merge_kwargs
|
||||||
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
||||||
from vllm.tasks import SupportedTask
|
from vllm.tasks import SupportedTask
|
||||||
from vllm.tokenizers import TokenizerLike
|
from vllm.tokenizers import TokenizerLike
|
||||||
@@ -844,7 +844,7 @@ class AsyncLLM(EngineClient):
|
|||||||
return self.input_processor.get_tokenizer()
|
return self.input_processor.get_tokenizer()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def renderer(self) -> RendererLike:
|
def renderer(self) -> BaseRenderer:
|
||||||
return self.input_processor.renderer
|
return self.input_processor.renderer
|
||||||
|
|
||||||
async def is_tracing_enabled(self) -> bool:
|
async def is_tracing_enabled(self) -> bool:
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from vllm.multimodal.parse import ModalityDataItems, MultiModalDataItems
|
|||||||
from vllm.multimodal.processing.context import set_request_id
|
from vllm.multimodal.processing.context import set_request_id
|
||||||
from vllm.multimodal.utils import argsort_mm_positions
|
from vllm.multimodal.utils import argsort_mm_positions
|
||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingParams
|
||||||
from vllm.renderers import RendererLike
|
from vllm.renderers import BaseRenderer
|
||||||
from vllm.sampling_params import _SAMPLING_EPS, SamplingParams
|
from vllm.sampling_params import _SAMPLING_EPS, SamplingParams
|
||||||
from vllm.tokenizers import TokenizerLike
|
from vllm.tokenizers import TokenizerLike
|
||||||
from vllm.tokenizers.mistral import MistralTokenizer
|
from vllm.tokenizers.mistral import MistralTokenizer
|
||||||
@@ -96,7 +96,7 @@ class InputProcessor:
|
|||||||
return self.input_preprocessor.get_tokenizer()
|
return self.input_preprocessor.get_tokenizer()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def renderer(self) -> RendererLike:
|
def renderer(self) -> BaseRenderer:
|
||||||
return self.input_preprocessor.renderer
|
return self.input_preprocessor.renderer
|
||||||
|
|
||||||
def _validate_logprobs(
|
def _validate_logprobs(
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
|||||||
from vllm.outputs import PoolingRequestOutput, RequestOutput
|
from vllm.outputs import PoolingRequestOutput, RequestOutput
|
||||||
from vllm.plugins.io_processors import get_io_processor
|
from vllm.plugins.io_processors import get_io_processor
|
||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingParams
|
||||||
from vllm.renderers import RendererLike
|
from vllm.renderers import BaseRenderer
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.tasks import SupportedTask
|
from vllm.tasks import SupportedTask
|
||||||
from vllm.tokenizers import TokenizerLike
|
from vllm.tokenizers import TokenizerLike
|
||||||
@@ -367,7 +367,7 @@ class LLMEngine:
|
|||||||
return self.input_processor.get_tokenizer()
|
return self.input_processor.get_tokenizer()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def renderer(self) -> RendererLike:
|
def renderer(self) -> BaseRenderer:
|
||||||
return self.input_processor.renderer
|
return self.input_processor.renderer
|
||||||
|
|
||||||
def do_log_stats(self) -> None:
|
def do_log_stats(self) -> None:
|
||||||
|
|||||||
Reference in New Issue
Block a user