[Misc] Automatically resolve HF processor init kwargs (#22005)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -4,9 +4,15 @@
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
|
||||
from transformers import (AutoFeatureExtractor, AutoImageProcessor,
|
||||
AutoProcessor)
|
||||
from transformers.feature_extraction_utils import FeatureExtractionMixin
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
from transformers.processing_utils import ProcessorMixin
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from vllm.utils import get_allowed_kwarg_only_overrides
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
|
||||
@@ -33,23 +39,42 @@ class HashableList(list):
|
||||
return hash(tuple(self))
|
||||
|
||||
|
||||
def _merge_mm_kwargs(model_config: "ModelConfig", **kwargs):
|
||||
mm_config = model_config.get_multimodal_config()
|
||||
base_kwargs = mm_config.mm_processor_kwargs
|
||||
if base_kwargs is None:
|
||||
base_kwargs = {}
|
||||
def _get_processor_factory_fn(processor_cls: Union[type, tuple[type, ...]]):
|
||||
if isinstance(processor_cls, tuple) or processor_cls == ProcessorMixin:
|
||||
return AutoProcessor.from_pretrained
|
||||
if hasattr(processor_cls, "from_pretrained"):
|
||||
return processor_cls.from_pretrained
|
||||
|
||||
merged_kwargs = {**base_kwargs, **kwargs}
|
||||
return processor_cls
|
||||
|
||||
|
||||
def _merge_mm_kwargs(
|
||||
model_config: "ModelConfig",
|
||||
processor_cls: Union[type, tuple[type, ...]],
|
||||
/,
|
||||
**kwargs,
|
||||
):
|
||||
mm_config = model_config.get_multimodal_config()
|
||||
merged_kwargs = mm_config.merge_mm_processor_kwargs(kwargs)
|
||||
|
||||
factory = _get_processor_factory_fn(processor_cls)
|
||||
allowed_kwargs = get_allowed_kwarg_only_overrides(
|
||||
factory,
|
||||
merged_kwargs,
|
||||
requires_kw_only=False,
|
||||
allow_var_kwargs=True,
|
||||
)
|
||||
|
||||
# NOTE: Pythonic dict is not hashable and will raise unhashable type
|
||||
# error when calling `cached_get_processor`, therefore we need to
|
||||
# wrap it to a hashable dict.
|
||||
for key, value in merged_kwargs.items():
|
||||
for key, value in allowed_kwargs.items():
|
||||
if isinstance(value, dict):
|
||||
merged_kwargs[key] = HashableDict(value)
|
||||
allowed_kwargs[key] = HashableDict(value)
|
||||
if isinstance(value, list):
|
||||
merged_kwargs[key] = HashableList(value)
|
||||
return merged_kwargs
|
||||
allowed_kwargs[key] = HashableList(value)
|
||||
|
||||
return allowed_kwargs
|
||||
|
||||
|
||||
def get_processor(
|
||||
@@ -61,21 +86,29 @@ def get_processor(
|
||||
**kwargs: Any,
|
||||
) -> _P:
|
||||
"""Load a processor for the given model name via HuggingFace."""
|
||||
# don't put this import at the top level
|
||||
# it will call torch.cuda.device_count()
|
||||
from transformers import AutoProcessor
|
||||
|
||||
processor_factory = (AutoProcessor if processor_cls == ProcessorMixin or
|
||||
isinstance(processor_cls, tuple) else processor_cls)
|
||||
if revision is None:
|
||||
revision = "main"
|
||||
|
||||
try:
|
||||
processor = processor_factory.from_pretrained(
|
||||
processor_name,
|
||||
*args,
|
||||
revision=revision,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**kwargs,
|
||||
)
|
||||
if isinstance(processor_cls, tuple) or processor_cls == ProcessorMixin:
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
processor_name,
|
||||
*args,
|
||||
revision=revision,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**kwargs,
|
||||
)
|
||||
elif issubclass(processor_cls, ProcessorMixin):
|
||||
processor = processor_cls.from_pretrained(
|
||||
processor_name,
|
||||
*args,
|
||||
revision=revision,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
# Processors that are standalone classes unrelated to HF
|
||||
processor = processor_cls(*args, **kwargs)
|
||||
except ValueError as e:
|
||||
# If the error pertains to the processor class not existing or not
|
||||
# currently being imported, suggest using the --trust-remote-code flag.
|
||||
@@ -112,7 +145,7 @@ def cached_processor_from_config(
|
||||
revision=model_config.revision,
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
processor_cls=processor_cls, # type: ignore[arg-type]
|
||||
**_merge_mm_kwargs(model_config, **kwargs),
|
||||
**_merge_mm_kwargs(model_config, processor_cls, **kwargs),
|
||||
)
|
||||
|
||||
|
||||
@@ -125,10 +158,6 @@ def get_feature_extractor(
|
||||
):
|
||||
"""Load an audio feature extractor for the given model name
|
||||
via HuggingFace."""
|
||||
# don't put this import at the top level
|
||||
# it will call torch.cuda.device_count()
|
||||
from transformers import AutoFeatureExtractor
|
||||
from transformers.feature_extraction_utils import FeatureExtractionMixin
|
||||
try:
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||
processor_name,
|
||||
@@ -164,7 +193,7 @@ def cached_feature_extractor_from_config(
|
||||
model_config.model,
|
||||
revision=model_config.revision,
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
**_merge_mm_kwargs(model_config, **kwargs),
|
||||
**_merge_mm_kwargs(model_config, AutoFeatureExtractor, **kwargs),
|
||||
)
|
||||
|
||||
|
||||
@@ -176,11 +205,6 @@ def get_image_processor(
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Load an image processor for the given model name via HuggingFace."""
|
||||
# don't put this import at the top level
|
||||
# it will call torch.cuda.device_count()
|
||||
from transformers import AutoImageProcessor
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
|
||||
try:
|
||||
processor = AutoImageProcessor.from_pretrained(
|
||||
processor_name,
|
||||
@@ -217,5 +241,5 @@ def cached_image_processor_from_config(
|
||||
model_config.model,
|
||||
revision=model_config.revision,
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
**_merge_mm_kwargs(model_config, **kwargs),
|
||||
**_merge_mm_kwargs(model_config, AutoImageProcessor, **kwargs),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user