[Core][Frontend] Add Support for Inference Time mm_processor_kwargs (#9131)

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
This commit is contained in:
Alex Brooks
2024-10-08 08:12:56 -06:00
committed by GitHub
parent 8c746226c9
commit a3691b6b5e
21 changed files with 440 additions and 118 deletions

View File

@@ -1277,18 +1277,87 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
return await task(*args, **kwargs)
def supports_kw(callable: Callable[..., object], kw_name: str) -> bool:
def supports_kw(
callable: Callable[..., object],
kw_name: str,
requires_kw_only: bool = False,
allow_var_kwargs: bool = True,
) -> bool:
"""Check if a keyword is a valid kwarg for a callable; if requires_kw_only
disallows kwargs names that can also be positional arguments.
"""
params = inspect.signature(callable).parameters
if kw_name in params:
return True
if not params:
return False
return any(param.kind == inspect.Parameter.VAR_KEYWORD
for param in params.values())
param_val = params.get(kw_name)
# Types where the it may be valid, i.e., explicitly defined & nonvariadic
passable_kw_types = set((inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY))
if param_val:
is_sig_param = param_val.kind in passable_kw_types
# We want kwargs only, but this is passable as a positional arg
if (requires_kw_only and is_sig_param
and param_val.kind != inspect.Parameter.KEYWORD_ONLY):
return False
if ((requires_kw_only
and param_val.kind == inspect.Parameter.KEYWORD_ONLY)
or (not requires_kw_only and is_sig_param)):
return True
# If we're okay with var-kwargs, it's supported as long as
# the kw_name isn't something like *args, **kwargs
if allow_var_kwargs:
# Get the last param; type is ignored here because params is a proxy
# mapping, but it wraps an ordered dict, and they appear in order.
# Ref: https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters
last_param = params[next(reversed(params))] # type: ignore
return (last_param.kind == inspect.Parameter.VAR_KEYWORD
and last_param.name != kw_name)
return False
def resolve_mm_processor_kwargs(
init_kwargs: Optional[Dict[str, Any]],
inference_kwargs: Optional[Dict[str, Any]],
callable: Callable[..., object],
allow_var_kwargs: bool = False,
) -> Dict[str, Any]:
"""Applies filtering to eliminate invalid mm_processor_kwargs, i.e.,
those who are not explicit keywords to the given callable (of one is
given; otherwise no filtering is done), then merges the kwarg dicts,
giving priority to inference_kwargs if there are any collisions.
In the case that no kwarg overrides are provided, returns an empty
dict so that it can still be kwarg expanded into the callable later on.
If allow_var_kwargs=True, allows for things that can be expanded into
kwargs as long as they aren't naming collision for var_kwargs or potential
positional arguments.
"""
# Filter inference time multimodal processor kwargs provided
runtime_mm_kwargs = get_allowed_kwarg_only_overrides(
callable,
overrides=inference_kwargs,
allow_var_kwargs=allow_var_kwargs)
# Filter init time multimodal processor kwargs provided
init_mm_kwargs = get_allowed_kwarg_only_overrides(
callable, overrides=init_kwargs, allow_var_kwargs=allow_var_kwargs)
# Merge the final processor kwargs, prioritizing inference
# time values over the initialization time values.
mm_processor_kwargs = {**init_mm_kwargs, **runtime_mm_kwargs}
return mm_processor_kwargs
def get_allowed_kwarg_only_overrides(
callable: Callable[..., object],
overrides: Optional[Dict[str, Any]],
allow_var_kwargs: bool = False,
) -> Dict[str, Any]:
"""
Given a callable which has one or more keyword only params and a dict
@@ -1300,7 +1369,9 @@ def get_allowed_kwarg_only_overrides(
Args:
callable: Callable which takes 0 or more keyword only arguments.
If None is provided, all overrides names are allowed.
overrides: Potential overrides to be used when invoking the callable.
allow_var_kwargs: Allows overrides that are expandable for var kwargs.
Returns:
Dictionary containing the kwargs to be leveraged which may be used
@@ -1310,17 +1381,15 @@ def get_allowed_kwarg_only_overrides(
if not overrides:
return {}
allowed_override_names = [
name for name, param in inspect.signature(callable).parameters.items()
if param.kind == inspect.Parameter.KEYWORD_ONLY
]
# Drop any mm_processor_kwargs provided by the user that are
# not kwarg names accepted by the provided input processor.
# Drop any mm_processor_kwargs provided by the user that
# are not kwargs, unless it can fit it var_kwargs param
filtered_overrides = {
kwarg_name: val
for kwarg_name, val in overrides.items()
if kwarg_name in allowed_override_names
if supports_kw(callable,
kwarg_name,
requires_kw_only=True,
allow_var_kwargs=allow_var_kwargs)
}
# If anything is dropped, log a warning