[Refactor] Clean up processor kwargs extraction (#35872)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -7,7 +7,8 @@ from transformers.processing_utils import ProcessingKwargs
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from vllm.transformers_utils.processor import (
|
||||
get_processor_kwargs_from_processor,
|
||||
get_processor_kwargs_keys,
|
||||
get_processor_kwargs_type,
|
||||
)
|
||||
|
||||
|
||||
@@ -35,7 +36,7 @@ def _assert_has_all_expected(keys: set[str]) -> None:
|
||||
assert k in keys
|
||||
|
||||
|
||||
# Path 1: __call__ method has kwargs: Unpack[*ProcessingKwargs]
|
||||
# Path 1: __call__ method has kwargs: Unpack[*ProcessorKwargs]
|
||||
class _ProcWithUnpack:
|
||||
def __call__(self, *args, **kwargs: Unpack[_FakeProcessorKwargs]): # type: ignore
|
||||
return None
|
||||
@@ -43,11 +44,11 @@ class _ProcWithUnpack:
|
||||
|
||||
def test_get_processor_kwargs_from_processor_unpack_path_returns_full_union():
|
||||
proc = _ProcWithUnpack()
|
||||
keys = get_processor_kwargs_from_processor(proc)
|
||||
keys = get_processor_kwargs_keys(get_processor_kwargs_type(proc))
|
||||
_assert_has_all_expected(keys)
|
||||
|
||||
|
||||
# ---- Path 2: No Unpack, fallback to scanning *ProcessingKwargs in module ----
|
||||
# ---- Path 2: No Unpack, fallback to scanning *ProcessorKwargs in module ----
|
||||
|
||||
|
||||
class _ProcWithoutUnpack:
|
||||
@@ -62,5 +63,5 @@ def test_get_processor_kwargs_from_processor_module_scan_returns_full_union():
|
||||
assert hasattr(mod, "_FakeProcessorKwargs")
|
||||
|
||||
proc = _ProcWithoutUnpack()
|
||||
keys = get_processor_kwargs_from_processor(proc)
|
||||
keys = get_processor_kwargs_keys(get_processor_kwargs_type(proc))
|
||||
_assert_has_all_expected(keys)
|
||||
|
||||
@@ -111,29 +111,6 @@ def _get_processor_factory_fn(processor_cls: type | tuple[type, ...]):
|
||||
return processor_cls
|
||||
|
||||
|
||||
@lru_cache
|
||||
def _collect_dynamic_keys_from_processing_kwargs(kwargs_cls: type) -> set[str]:
|
||||
dynamic_kwargs: set[str] = set()
|
||||
if kwargs_cls is None:
|
||||
return dynamic_kwargs
|
||||
# get kwargs annotations in processor
|
||||
# merge text_kwargs / images_kwargs / videos_kwargs / audio_kwargs
|
||||
kwargs_type_annotations = get_type_hints(kwargs_cls)
|
||||
for kw_type in ("text_kwargs", "images_kwargs", "videos_kwargs", "audio_kwargs"):
|
||||
if kw_type in kwargs_type_annotations:
|
||||
# Use __annotations__ instead of get_type_hints() to avoid
|
||||
# NameError from unresolved forward references (e.g.
|
||||
# PILImageResampling). We only need key names, not types.
|
||||
kw_cls = kwargs_type_annotations[kw_type]
|
||||
kw_annotations: dict[str, Any] = {}
|
||||
for base in reversed(kw_cls.__mro__):
|
||||
kw_annotations.update(getattr(base, "__annotations__", {}))
|
||||
for kw_name in kw_annotations:
|
||||
dynamic_kwargs.add(kw_name)
|
||||
dynamic_kwargs |= {"text_kwargs", "images_kwargs", "videos_kwargs", "audio_kwargs"}
|
||||
return dynamic_kwargs
|
||||
|
||||
|
||||
def _merge_mm_kwargs(
|
||||
model_config: "ModelConfig",
|
||||
processor_cls: type | tuple[type, ...],
|
||||
@@ -224,38 +201,63 @@ cached_get_processor = lru_cache(get_processor)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_processor_kwargs_from_processor(processor: _P) -> set[str]:
|
||||
def get_processor_kwargs_type(
|
||||
processor: ProcessorMixin,
|
||||
) -> type[processing_utils.ProcessingKwargs]:
|
||||
try:
|
||||
# get kwargs annotations in processor
|
||||
call_kwargs = inspect.signature(type(processor).__call__).parameters.get(
|
||||
"kwargs"
|
||||
)
|
||||
call_params = inspect.signature(type(processor).__call__).parameters
|
||||
call_kwargs = call_params.get("kwargs")
|
||||
call_kwargs_annotations = call_kwargs.annotation if call_kwargs else None
|
||||
|
||||
# if the processor has explicit kwargs annotation, use it
|
||||
if call_kwargs_annotations not in (None, inspect._empty):
|
||||
# get_type_hints will parse all type annotations at runtime,
|
||||
# and if an annotation refers to a type or
|
||||
# name that hasn’t been imported or defined, it will raise an error.
|
||||
# So we use __annotations__ to get the raw annotations directly.
|
||||
return _collect_dynamic_keys_from_processing_kwargs(
|
||||
get_args(call_kwargs_annotations)[0]
|
||||
)
|
||||
# otherwise, try to get from ProcessingKwargs
|
||||
else:
|
||||
module_name = type(processor).__module__
|
||||
mod = importlib.import_module(module_name)
|
||||
# find *ProcessingKwargs in the module
|
||||
processor_kwargs: set[str] = set()
|
||||
for name, obj in vars(mod).items():
|
||||
if name.endswith("ProcessingKwargs"):
|
||||
processor_kwargs = (
|
||||
processor_kwargs
|
||||
| _collect_dynamic_keys_from_processing_kwargs(obj)
|
||||
)
|
||||
return processor_kwargs
|
||||
return get_args(call_kwargs_annotations)[0]
|
||||
|
||||
# otherwise, try to get from ProcessorKwargs
|
||||
module_name = type(processor).__module__
|
||||
mod = importlib.import_module(module_name)
|
||||
for name, obj in vars(mod).items():
|
||||
if name.endswith("ProcessorKwargs"):
|
||||
return obj
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to collect processor kwargs")
|
||||
return set()
|
||||
|
||||
return processing_utils.ProcessingKwargs
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_processor_kwargs_keys(
|
||||
kwargs_cls: type[processing_utils.ProcessingKwargs],
|
||||
) -> set[str]:
|
||||
dynamic_kwargs: set[str] = set()
|
||||
modality_kwargs = {"text_kwargs", "images_kwargs", "videos_kwargs", "audio_kwargs"}
|
||||
|
||||
try:
|
||||
# get kwargs annotations in processor
|
||||
# merge text_kwargs / images_kwargs / videos_kwargs / audio_kwargs
|
||||
kwargs_type_annotations = get_type_hints(kwargs_cls)
|
||||
for kw_type in modality_kwargs:
|
||||
if kw_type in kwargs_type_annotations:
|
||||
# Use __annotations__ instead of get_type_hints() to avoid
|
||||
# NameError from unresolved forward references (e.g.
|
||||
# PILImageResampling). We only need key names, not types.
|
||||
kw_cls = kwargs_type_annotations[kw_type]
|
||||
kw_annotations: dict[str, Any] = {}
|
||||
for base in reversed(kw_cls.__mro__):
|
||||
kw_annotations.update(getattr(base, "__annotations__", {}))
|
||||
for kw_name in kw_annotations:
|
||||
dynamic_kwargs.add(kw_name)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to collect processor kwargs")
|
||||
|
||||
return dynamic_kwargs | modality_kwargs
|
||||
|
||||
|
||||
def cached_get_processor_without_dynamic_kwargs(
|
||||
@@ -275,7 +277,9 @@ def cached_get_processor_without_dynamic_kwargs(
|
||||
)
|
||||
|
||||
# Step 2: use temporary processor collect dynamic keys
|
||||
dynamic_keys = get_processor_kwargs_from_processor(processor)
|
||||
dynamic_keys = get_processor_kwargs_keys(
|
||||
get_processor_kwargs_type(processor) # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# Step 3: use dynamic_keys filter kwargs
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if k not in dynamic_keys}
|
||||
|
||||
Reference in New Issue
Block a user