[VLM][Bugfix] Pass processor kwargs properly on init (#13516)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -1,25 +1,59 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import Any, cast
|
||||
from typing import TYPE_CHECKING, Any, Union, cast
|
||||
|
||||
from transformers.processing_utils import ProcessorMixin
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
|
||||
_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)
|
||||
|
||||
|
||||
class HashableDict(dict):
    """A ``dict`` subclass that supports ``hash()``.

    Plain dicts are unhashable and therefore cannot appear in the key of
    an ``lru_cache``-decorated call; subclassing and defining ``__hash__``
    directly keeps the wrapper trivial.
    """

    def __hash__(self) -> int:  # type: ignore[override]
        # Snapshot the (key, value) pairs as a frozenset so the hash is
        # independent of insertion order, matching dict equality.
        item_snapshot = frozenset(self.items())
        return hash(item_snapshot)
|
||||
|
||||
|
||||
def _merge_mm_kwargs(model_config: "ModelConfig", **kwargs):
|
||||
base_kwargs = model_config.mm_processor_kwargs
|
||||
if base_kwargs is None:
|
||||
base_kwargs = {}
|
||||
|
||||
merged_kwargs = {**base_kwargs, **kwargs}
|
||||
|
||||
# NOTE: Pythonic dict is not hashable and will raise unhashable type
|
||||
# error when calling `cached_get_processor`, therefore we need to
|
||||
# wrap it to a hashable dict.
|
||||
for key, value in merged_kwargs.items():
|
||||
if isinstance(value, dict):
|
||||
merged_kwargs[key] = HashableDict(value)
|
||||
|
||||
return merged_kwargs
|
||||
|
||||
|
||||
def get_processor(
|
||||
processor_name: str,
|
||||
*args: Any,
|
||||
trust_remote_code: bool = False,
|
||||
processor_cls: type[ProcessorMixin] = ProcessorMixin,
|
||||
processor_cls: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
|
||||
**kwargs: Any,
|
||||
):
|
||||
) -> _P:
|
||||
"""Load a processor for the given model name via HuggingFace."""
|
||||
# don't put this import at the top level
|
||||
# it will call torch.cuda.device_count()
|
||||
from transformers import AutoProcessor
|
||||
|
||||
processor_factory = (AutoProcessor
|
||||
if processor_cls == ProcessorMixin else processor_cls)
|
||||
processor_factory = (AutoProcessor if processor_cls == ProcessorMixin or
|
||||
isinstance(processor_cls, tuple) else processor_cls)
|
||||
|
||||
try:
|
||||
processor = processor_factory.from_pretrained(
|
||||
@@ -43,12 +77,30 @@ def get_processor(
|
||||
else:
|
||||
raise e
|
||||
|
||||
return cast(ProcessorMixin, processor)
|
||||
if not isinstance(processor, processor_cls):
|
||||
raise TypeError("Invalid type of HuggingFace processor. "
|
||||
f"Expected type: {processor_cls}, but "
|
||||
f"found type: {type(processor)}")
|
||||
|
||||
return processor
|
||||
|
||||
|
||||
cached_get_processor = lru_cache(get_processor)
|
||||
|
||||
|
||||
def cached_processor_from_config(
    model_config: "ModelConfig",
    processor_cls: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
    **kwargs: Any,
) -> _P:
    """Return the cached HF processor described by ``model_config``.

    Extra ``kwargs`` are merged with ``model_config.mm_processor_kwargs``
    before being forwarded, with the per-call values taking precedence.
    """
    processor_kwargs = _merge_mm_kwargs(model_config, **kwargs)
    return cached_get_processor(
        model_config.model,
        trust_remote_code=model_config.trust_remote_code,
        processor_cls=processor_cls,  # type: ignore[arg-type]
        **processor_kwargs,
    )
|
||||
|
||||
|
||||
def get_image_processor(
|
||||
processor_name: str,
|
||||
*args: Any,
|
||||
@@ -85,6 +137,20 @@ def get_image_processor(
|
||||
return cast(BaseImageProcessor, processor)
|
||||
|
||||
|
||||
cached_get_image_processor = lru_cache(get_image_processor)
|
||||
|
||||
|
||||
def cached_image_processor_from_config(
    model_config: "ModelConfig",
    **kwargs: Any,
):
    """Return the cached HF image processor for ``model_config``.

    Per-call ``kwargs`` override ``model_config.mm_processor_kwargs``.
    """
    processor_kwargs = _merge_mm_kwargs(model_config, **kwargs)
    return cached_get_image_processor(
        model_config.model,
        trust_remote_code=model_config.trust_remote_code,
        **processor_kwargs,
    )
|
||||
|
||||
|
||||
def get_video_processor(
|
||||
processor_name: str,
|
||||
*args: Any,
|
||||
@@ -104,3 +170,17 @@ def get_video_processor(
|
||||
)
|
||||
|
||||
return cast(BaseImageProcessor, processor.video_processor)
|
||||
|
||||
|
||||
cached_get_video_processor = lru_cache(get_video_processor)
|
||||
|
||||
|
||||
def cached_video_processor_from_config(
    model_config: "ModelConfig",
    **kwargs: Any,
):
    """Return the cached HF video processor for ``model_config``.

    Per-call ``kwargs`` override ``model_config.mm_processor_kwargs``.
    """
    processor_kwargs = _merge_mm_kwargs(model_config, **kwargs)
    return cached_get_video_processor(
        model_config.model,
        trust_remote_code=model_config.trust_remote_code,
        **processor_kwargs,
    )
|
||||
|
||||
Reference in New Issue
Block a user