@@ -44,7 +44,7 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, Processor
 from typing_extensions import Required, TypedDict
 
 from vllm import envs
-from vllm.config import ModelConfig, RendererConfig
+from vllm.config import ModelConfig
 from vllm.logger import init_logger
 from vllm.model_executor.models import SupportsMultiModal
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
@@ -452,10 +452,9 @@ This is needed because `lru_cache` does not cache when an exception happens.
 
 def _try_get_processor_chat_template(
     tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
-    *,
-    trust_remote_code: bool,
+    model_config: ModelConfig,
 ) -> str | None:
-    cache_key = (tokenizer.name_or_path, trust_remote_code)
+    cache_key = (tokenizer.name_or_path, model_config.trust_remote_code)
     if cache_key in _PROCESSOR_CHAT_TEMPLATES:
         return _PROCESSOR_CHAT_TEMPLATES[cache_key]
 
@@ -467,7 +466,7 @@ def _try_get_processor_chat_template(
                 PreTrainedTokenizerFast,
                 ProcessorMixin,
             ),
-            trust_remote_code=trust_remote_code,
+            trust_remote_code=model_config.trust_remote_code,
         )
         if (
             isinstance(processor, ProcessorMixin)
@@ -500,10 +499,7 @@ def resolve_hf_chat_template(
 
     # 2nd priority: AutoProcessor chat template, unless tool calling is enabled
     if tools is None:
-        chat_template = _try_get_processor_chat_template(
-            tokenizer,
-            trust_remote_code=model_config.trust_remote_code,
-        )
+        chat_template = _try_get_processor_chat_template(tokenizer, model_config)
         if chat_template is not None:
             return chat_template
 
@@ -517,10 +513,10 @@ def resolve_hf_chat_template(
             exc_info=True,
         )
 
-    # 4th priority: Predefined fallbacks]
+    # 4th priority: Predefined fallbacks
     path = get_chat_template_fallback_path(
         model_type=model_config.hf_config.model_type,
-        tokenizer_name_or_path=tokenizer.name_or_path,
+        tokenizer_name_or_path=model_config.tokenizer,
     )
     if path is not None:
         logger.info_once(
@@ -542,14 +538,14 @@ def _resolve_chat_template_content_format(
     tools: list[dict[str, Any]] | None,
     tokenizer: TokenizerLike | None,
     *,
-    renderer_config: RendererConfig,
+    model_config: ModelConfig,
 ) -> _ChatTemplateContentFormat:
     if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
         hf_chat_template = resolve_hf_chat_template(
             tokenizer,
             chat_template=chat_template,
             tools=tools,
-            model_config=renderer_config.model_config,
+            model_config=model_config,
         )
     else:
         hf_chat_template = None
@@ -599,7 +595,7 @@ def resolve_chat_template_content_format(
     given_format: ChatTemplateContentFormatOption,
     tokenizer: TokenizerLike | None,
     *,
-    renderer_config: RendererConfig,
+    model_config: ModelConfig,
 ) -> _ChatTemplateContentFormat:
     if given_format != "auto":
         return given_format
@@ -608,7 +604,7 @@ def resolve_chat_template_content_format(
         chat_template,
         tools,
         tokenizer,
-        renderer_config=renderer_config,
+        model_config=model_config,
     )
 
     _log_chat_template_content_format(
@@ -631,32 +627,32 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
     maximum per prompt.
     """
 
-    def __init__(self, renderer_config: RendererConfig):
+    def __init__(self, model_config: ModelConfig):
         super().__init__()
 
-        self._renderer_config = renderer_config
+        self._model_config = model_config
 
         self._items_by_modality = defaultdict[str, list[_T | None]](list)
         self._uuids_by_modality = defaultdict[str, list[str | None]](list)
 
     @property
-    def renderer_config(self) -> RendererConfig:
-        return self._renderer_config
+    def model_config(self) -> ModelConfig:
+        return self._model_config
 
     @cached_property
     def model_cls(self) -> type[SupportsMultiModal]:
         from vllm.model_executor.model_loader import get_model_cls
 
-        model_cls = get_model_cls(self.renderer_config.model_config)
+        model_cls = get_model_cls(self.model_config)
         return cast(type[SupportsMultiModal], model_cls)
 
     @property
     def allowed_local_media_path(self):
-        return self._renderer_config.allowed_local_media_path
+        return self._model_config.allowed_local_media_path
 
     @property
     def allowed_media_domains(self):
-        return self._renderer_config.allowed_media_domains
+        return self._model_config.allowed_media_domains
 
     @property
     def mm_registry(self):
@@ -664,7 +660,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
 
     @cached_property
     def mm_processor(self):
-        return self.mm_registry.create_processor(self.renderer_config)
+        return self.mm_registry.create_processor(self.model_config)
 
     def add(
         self,
@@ -855,20 +851,19 @@ class MultiModalContentParser(BaseMultiModalContentParser):
         super().__init__()
 
         self._tracker = tracker
+        multimodal_config = self._tracker.model_config.multimodal_config
+        media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)
+
         self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
             envs.VLLM_MEDIA_CONNECTOR,
-            media_io_kwargs=self.renderer_config.media_io_kwargs,
+            media_io_kwargs=media_io_kwargs,
             allowed_local_media_path=tracker.allowed_local_media_path,
             allowed_media_domains=tracker.allowed_media_domains,
         )
 
     @property
-    def renderer_config(self) -> RendererConfig:
-        return self._tracker.renderer_config
-
-    @property
     def model_config(self) -> ModelConfig:
-        return self.renderer_config.model_config
+        return self._tracker.model_config
 
     def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
         image = self._connector.fetch_image(image_url) if image_url else None
@@ -968,20 +963,18 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
         super().__init__()
 
         self._tracker = tracker
+        multimodal_config = self._tracker.model_config.multimodal_config
+        media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)
         self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
             envs.VLLM_MEDIA_CONNECTOR,
-            media_io_kwargs=self.renderer_config.media_io_kwargs,
+            media_io_kwargs=media_io_kwargs,
             allowed_local_media_path=tracker.allowed_local_media_path,
             allowed_media_domains=tracker.allowed_media_domains,
         )
 
     @property
-    def renderer_config(self) -> RendererConfig:
-        return self._tracker.renderer_config
-
-    @property
     def model_config(self) -> ModelConfig:
-        return self.renderer_config.model_config
+        return self._tracker.model_config
 
     def parse_image(self, image_url: str | None, uuid: str | None = None) -> None:
         image_coro = self._connector.fetch_image_async(image_url) if image_url else None
@@ -1611,17 +1604,15 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None:
 
 def parse_chat_messages(
     messages: list[ChatCompletionMessageParam],
-    renderer_config: RendererConfig,
+    model_config: ModelConfig,
     content_format: _ChatTemplateContentFormat,
 ) -> tuple[
     list[ConversationMessage],
     MultiModalDataDict | None,
     MultiModalUUIDDict | None,
 ]:
-    model_config = renderer_config.model_config
-
     conversation: list[ConversationMessage] = []
-    mm_tracker = MultiModalItemTracker(renderer_config)
+    mm_tracker = MultiModalItemTracker(model_config)
 
     for msg in messages:
         sub_messages = _parse_chat_message_content(
@@ -1644,17 +1635,15 @@ def parse_chat_messages(
 
 def parse_chat_messages_futures(
     messages: list[ChatCompletionMessageParam],
-    renderer_config: RendererConfig,
+    model_config: ModelConfig,
     content_format: _ChatTemplateContentFormat,
 ) -> tuple[
     list[ConversationMessage],
     Awaitable[MultiModalDataDict | None],
     MultiModalUUIDDict | None,
 ]:
-    model_config = renderer_config.model_config
-
     conversation: list[ConversationMessage] = []
-    mm_tracker = AsyncMultiModalItemTracker(renderer_config)
+    mm_tracker = AsyncMultiModalItemTracker(model_config)
 
     for msg in messages:
         sub_messages = _parse_chat_message_content(
@@ -1759,14 +1748,14 @@ def apply_hf_chat_template(
     chat_template: str | None,
     tools: list[dict[str, Any]] | None,
     *,
-    renderer_config: RendererConfig,
+    model_config: ModelConfig,
     **kwargs: Any,
 ) -> str:
     hf_chat_template = resolve_hf_chat_template(
         tokenizer,
         chat_template=chat_template,
         tools=tools,
-        model_config=renderer_config.model_config,
+        model_config=model_config,
     )
 
     if hf_chat_template is None:
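
Note on the `@@ -452` hunk: its context line records why a plain dict cache is used there at all ("`lru_cache` does not cache when an exception happens"). A minimal sketch of that pattern, separate from the diff; `load_template` is a hypothetical stand-in for the processor lookup:

```python
# Sketch: functools.lru_cache never memoizes a call that raises, so a template
# lookup that failed once would be retried on every request. Keying a plain
# dict and storing an explicit None makes the negative result sticky as well.
_TEMPLATES: dict[tuple[str, bool], str | None] = {}


def load_template(name_or_path: str, trust_remote_code: bool) -> str:
    # Hypothetical loader standing in for the cached_get_processor(...) lookup.
    raise FileNotFoundError(name_or_path)


def try_get_template(name_or_path: str, trust_remote_code: bool) -> str | None:
    key = (name_or_path, trust_remote_code)
    if key in _TEMPLATES:
        return _TEMPLATES[key]  # hit: either a template or a cached failure
    try:
        template: str | None = load_template(name_or_path, trust_remote_code)
    except Exception:
        template = None  # cache the failure too, unlike lru_cache
    _TEMPLATES[key] = template
    return template
```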
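For reference, a hedged usage sketch of the call signatures after this change, passing `ModelConfig` directly wherever a `RendererConfig` wrapper was previously threaded through. It assumes these helpers live in `vllm.entrypoints.chat_utils`, that the positional order of `resolve_chat_template_content_format` matches the hunks above, and a placeholder model name; it is not part of the commit:

```python
# Sketch: driving the updated chat-template helpers with a ModelConfig.
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
    parse_chat_messages,
    resolve_chat_template_content_format,
    resolve_hf_chat_template,
)
from vllm.transformers_utils.tokenizer import get_tokenizer

model_config = ModelConfig(model="Qwen/Qwen2.5-VL-3B-Instruct")  # placeholder
tokenizer = get_tokenizer(model_config.tokenizer)

# trust_remote_code is now read off model_config internally (hunks 452/467).
chat_template = resolve_hf_chat_template(
    tokenizer,
    chat_template=None,
    tools=None,
    model_config=model_config,
)

# Content-format detection and message parsing take model_config as well
# (hunks 599 and 1611).
content_format = resolve_chat_template_content_format(
    chat_template,
    None,    # tools
    "auto",  # given_format
    tokenizer,
    model_config=model_config,
)
messages = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
conversation, mm_data, mm_uuids = parse_chat_messages(
    messages,
    model_config,
    content_format,
)
```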