[Refactor] Clean up processor kwargs extraction (#35872)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-03-04 11:53:53 +08:00
committed by GitHub
parent 6e9f21e8a2
commit e379396167
2 changed files with 55 additions and 50 deletions

View File

@@ -7,7 +7,8 @@ from transformers.processing_utils import ProcessingKwargs
from typing_extensions import Unpack
from vllm.transformers_utils.processor import (
get_processor_kwargs_from_processor,
get_processor_kwargs_keys,
get_processor_kwargs_type,
)
@@ -35,7 +36,7 @@ def _assert_has_all_expected(keys: set[str]) -> None:
assert k in keys
# Path 1: __call__ method has kwargs: Unpack[*ProcessingKwargs]
# Path 1: __call__ method has kwargs: Unpack[*ProcessorKwargs]
class _ProcWithUnpack:
def __call__(self, *args, **kwargs: Unpack[_FakeProcessorKwargs]): # type: ignore
return None
@@ -43,11 +44,11 @@ class _ProcWithUnpack:
def test_get_processor_kwargs_from_processor_unpack_path_returns_full_union():
proc = _ProcWithUnpack()
keys = get_processor_kwargs_from_processor(proc)
keys = get_processor_kwargs_keys(get_processor_kwargs_type(proc))
_assert_has_all_expected(keys)
# ---- Path 2: No Unpack, fallback to scanning *ProcessingKwargs in module ----
# ---- Path 2: No Unpack, fallback to scanning *ProcessorKwargs in module ----
class _ProcWithoutUnpack:
@@ -62,5 +63,5 @@ def test_get_processor_kwargs_from_processor_module_scan_returns_full_union():
assert hasattr(mod, "_FakeProcessorKwargs")
proc = _ProcWithoutUnpack()
keys = get_processor_kwargs_from_processor(proc)
keys = get_processor_kwargs_keys(get_processor_kwargs_type(proc))
_assert_has_all_expected(keys)

View File

@@ -111,29 +111,6 @@ def _get_processor_factory_fn(processor_cls: type | tuple[type, ...]):
return processor_cls
@lru_cache
def _collect_dynamic_keys_from_processing_kwargs(kwargs_cls: type) -> set[str]:
dynamic_kwargs: set[str] = set()
if kwargs_cls is None:
return dynamic_kwargs
# get kwargs annotations in processor
# merge text_kwargs / images_kwargs / videos_kwargs / audio_kwargs
kwargs_type_annotations = get_type_hints(kwargs_cls)
for kw_type in ("text_kwargs", "images_kwargs", "videos_kwargs", "audio_kwargs"):
if kw_type in kwargs_type_annotations:
# Use __annotations__ instead of get_type_hints() to avoid
# NameError from unresolved forward references (e.g.
# PILImageResampling). We only need key names, not types.
kw_cls = kwargs_type_annotations[kw_type]
kw_annotations: dict[str, Any] = {}
for base in reversed(kw_cls.__mro__):
kw_annotations.update(getattr(base, "__annotations__", {}))
for kw_name in kw_annotations:
dynamic_kwargs.add(kw_name)
dynamic_kwargs |= {"text_kwargs", "images_kwargs", "videos_kwargs", "audio_kwargs"}
return dynamic_kwargs
def _merge_mm_kwargs(
model_config: "ModelConfig",
processor_cls: type | tuple[type, ...],
@@ -224,38 +201,63 @@ cached_get_processor = lru_cache(get_processor)
@lru_cache
def get_processor_kwargs_from_processor(processor: _P) -> set[str]:
def get_processor_kwargs_type(
processor: ProcessorMixin,
) -> type[processing_utils.ProcessingKwargs]:
try:
# get kwargs annotations in processor
call_kwargs = inspect.signature(type(processor).__call__).parameters.get(
"kwargs"
)
call_params = inspect.signature(type(processor).__call__).parameters
call_kwargs = call_params.get("kwargs")
call_kwargs_annotations = call_kwargs.annotation if call_kwargs else None
# if the processor has explicit kwargs annotation, use it
if call_kwargs_annotations not in (None, inspect._empty):
# get_type_hints will parse all type annotations at runtime,
# and if an annotation refers to a type or
# name that hasn't been imported or defined, it will raise an error.
# So we use __annotations__ to get the raw annotations directly.
return _collect_dynamic_keys_from_processing_kwargs(
get_args(call_kwargs_annotations)[0]
)
# otherwise, try to get from ProcessingKwargs
else:
module_name = type(processor).__module__
mod = importlib.import_module(module_name)
# find *ProcessingKwargs in the module
processor_kwargs: set[str] = set()
for name, obj in vars(mod).items():
if name.endswith("ProcessingKwargs"):
processor_kwargs = (
processor_kwargs
| _collect_dynamic_keys_from_processing_kwargs(obj)
)
return processor_kwargs
return get_args(call_kwargs_annotations)[0]
# otherwise, try to get from ProcessorKwargs
module_name = type(processor).__module__
mod = importlib.import_module(module_name)
for name, obj in vars(mod).items():
if name.endswith("ProcessorKwargs"):
return obj
except Exception:
logger.exception("Failed to collect processor kwargs")
return set()
return processing_utils.ProcessingKwargs
@lru_cache
def get_processor_kwargs_keys(
    kwargs_cls: type[processing_utils.ProcessingKwargs],
) -> set[str]:
    """Return the keyword names declared by a processing-kwargs class.

    For each of the modality fields ``text_kwargs``, ``images_kwargs``,
    ``videos_kwargs`` and ``audio_kwargs`` that is annotated on
    ``kwargs_cls``, the annotation names of that field's class (including
    names inherited through its MRO) are collected.

    Args:
        kwargs_cls: The kwargs class to inspect.

    Returns:
        The collected option names together with the four modality field
        names. If inspection raises, the exception is logged and whatever was
        collected up to that point is returned alongside the modality field
        names. The result is memoized per ``kwargs_cls`` by ``lru_cache``.
    """
    modality_fields = {
        "text_kwargs",
        "images_kwargs",
        "videos_kwargs",
        "audio_kwargs",
    }
    collected: set[str] = set()

    try:
        # Resolve the top-level hints so each modality field maps to its class.
        resolved_hints = get_type_hints(kwargs_cls)
        for field in modality_fields:
            if field not in resolved_hints:
                continue

            # Only key names are needed, so read the raw __annotations__ of
            # each class in the MRO rather than calling get_type_hints()
            # again; that avoids NameError from unresolved forward references
            # (e.g. PILImageResampling).
            modality_cls = resolved_hints[field]
            merged: dict[str, Any] = {}
            for klass in reversed(modality_cls.__mro__):
                merged.update(getattr(klass, "__annotations__", {}))
            collected.update(merged)
    except Exception:
        logger.exception("Failed to collect processor kwargs")

    return collected | modality_fields
def cached_get_processor_without_dynamic_kwargs(
@@ -275,7 +277,9 @@ def cached_get_processor_without_dynamic_kwargs(
)
# Step 2: use temporary processor collect dynamic keys
dynamic_keys = get_processor_kwargs_from_processor(processor)
dynamic_keys = get_processor_kwargs_keys(
get_processor_kwargs_type(processor) # type: ignore[arg-type]
)
# Step 3: use dynamic_keys filter kwargs
filtered_kwargs = {k: v for k, v in kwargs.items() if k not in dynamic_keys}