[Core][Frontend] Add Support for Inference Time mm_processor_kwargs (#9131)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
This commit is contained in:
@@ -1277,18 +1277,87 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
|
||||
return await task(*args, **kwargs)
|
||||
|
||||
|
||||
def supports_kw(callable: Callable[..., object], kw_name: str) -> bool:
|
||||
def supports_kw(
|
||||
callable: Callable[..., object],
|
||||
kw_name: str,
|
||||
requires_kw_only: bool = False,
|
||||
allow_var_kwargs: bool = True,
|
||||
) -> bool:
|
||||
"""Check if a keyword is a valid kwarg for a callable; if requires_kw_only
|
||||
disallows kwargs names that can also be positional arguments.
|
||||
"""
|
||||
params = inspect.signature(callable).parameters
|
||||
if kw_name in params:
|
||||
return True
|
||||
if not params:
|
||||
return False
|
||||
|
||||
return any(param.kind == inspect.Parameter.VAR_KEYWORD
|
||||
for param in params.values())
|
||||
param_val = params.get(kw_name)
|
||||
|
||||
# Types where the it may be valid, i.e., explicitly defined & nonvariadic
|
||||
passable_kw_types = set((inspect.Parameter.POSITIONAL_ONLY,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
inspect.Parameter.KEYWORD_ONLY))
|
||||
|
||||
if param_val:
|
||||
is_sig_param = param_val.kind in passable_kw_types
|
||||
# We want kwargs only, but this is passable as a positional arg
|
||||
if (requires_kw_only and is_sig_param
|
||||
and param_val.kind != inspect.Parameter.KEYWORD_ONLY):
|
||||
return False
|
||||
if ((requires_kw_only
|
||||
and param_val.kind == inspect.Parameter.KEYWORD_ONLY)
|
||||
or (not requires_kw_only and is_sig_param)):
|
||||
return True
|
||||
|
||||
# If we're okay with var-kwargs, it's supported as long as
|
||||
# the kw_name isn't something like *args, **kwargs
|
||||
if allow_var_kwargs:
|
||||
# Get the last param; type is ignored here because params is a proxy
|
||||
# mapping, but it wraps an ordered dict, and they appear in order.
|
||||
# Ref: https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters
|
||||
last_param = params[next(reversed(params))] # type: ignore
|
||||
return (last_param.kind == inspect.Parameter.VAR_KEYWORD
|
||||
and last_param.name != kw_name)
|
||||
return False
|
||||
|
||||
|
||||
def resolve_mm_processor_kwargs(
|
||||
init_kwargs: Optional[Dict[str, Any]],
|
||||
inference_kwargs: Optional[Dict[str, Any]],
|
||||
callable: Callable[..., object],
|
||||
allow_var_kwargs: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""Applies filtering to eliminate invalid mm_processor_kwargs, i.e.,
|
||||
those who are not explicit keywords to the given callable (of one is
|
||||
given; otherwise no filtering is done), then merges the kwarg dicts,
|
||||
giving priority to inference_kwargs if there are any collisions.
|
||||
|
||||
In the case that no kwarg overrides are provided, returns an empty
|
||||
dict so that it can still be kwarg expanded into the callable later on.
|
||||
|
||||
If allow_var_kwargs=True, allows for things that can be expanded into
|
||||
kwargs as long as they aren't naming collision for var_kwargs or potential
|
||||
positional arguments.
|
||||
"""
|
||||
# Filter inference time multimodal processor kwargs provided
|
||||
runtime_mm_kwargs = get_allowed_kwarg_only_overrides(
|
||||
callable,
|
||||
overrides=inference_kwargs,
|
||||
allow_var_kwargs=allow_var_kwargs)
|
||||
|
||||
# Filter init time multimodal processor kwargs provided
|
||||
init_mm_kwargs = get_allowed_kwarg_only_overrides(
|
||||
callable, overrides=init_kwargs, allow_var_kwargs=allow_var_kwargs)
|
||||
|
||||
# Merge the final processor kwargs, prioritizing inference
|
||||
# time values over the initialization time values.
|
||||
mm_processor_kwargs = {**init_mm_kwargs, **runtime_mm_kwargs}
|
||||
return mm_processor_kwargs
|
||||
|
||||
|
||||
def get_allowed_kwarg_only_overrides(
|
||||
callable: Callable[..., object],
|
||||
overrides: Optional[Dict[str, Any]],
|
||||
allow_var_kwargs: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Given a callable which has one or more keyword only params and a dict
|
||||
@@ -1300,7 +1369,9 @@ def get_allowed_kwarg_only_overrides(
|
||||
|
||||
Args:
|
||||
callable: Callable which takes 0 or more keyword only arguments.
|
||||
If None is provided, all overrides names are allowed.
|
||||
overrides: Potential overrides to be used when invoking the callable.
|
||||
allow_var_kwargs: Allows overrides that are expandable for var kwargs.
|
||||
|
||||
Returns:
|
||||
Dictionary containing the kwargs to be leveraged which may be used
|
||||
@@ -1310,17 +1381,15 @@ def get_allowed_kwarg_only_overrides(
|
||||
if not overrides:
|
||||
return {}
|
||||
|
||||
allowed_override_names = [
|
||||
name for name, param in inspect.signature(callable).parameters.items()
|
||||
if param.kind == inspect.Parameter.KEYWORD_ONLY
|
||||
]
|
||||
|
||||
# Drop any mm_processor_kwargs provided by the user that are
|
||||
# not kwarg names accepted by the provided input processor.
|
||||
# Drop any mm_processor_kwargs provided by the user that
|
||||
# are not kwargs, unless it can fit it var_kwargs param
|
||||
filtered_overrides = {
|
||||
kwarg_name: val
|
||||
for kwarg_name, val in overrides.items()
|
||||
if kwarg_name in allowed_override_names
|
||||
if supports_kw(callable,
|
||||
kwarg_name,
|
||||
requires_kw_only=True,
|
||||
allow_var_kwargs=allow_var_kwargs)
|
||||
}
|
||||
|
||||
# If anything is dropped, log a warning
|
||||
|
||||
Reference in New Issue
Block a user