[Bugfix] Fix broadcasting logic for multi_modal_kwargs (#6836)
This commit is contained in:
@@ -17,7 +17,7 @@ from functools import lru_cache, partial, wraps
|
||||
from platform import uname
|
||||
from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
|
||||
Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar,
|
||||
Union)
|
||||
Union, overload)
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
@@ -53,6 +53,7 @@ TORCH_DTYPE_TO_NUMPY_DTYPE = {
|
||||
P = ParamSpec('P')
|
||||
K = TypeVar("K")
|
||||
T = TypeVar("T")
|
||||
U = TypeVar("U")
|
||||
|
||||
|
||||
class _Sentinel:
|
||||
@@ -712,6 +713,54 @@ def merge_dicts(dict1: Dict[K, List[T]],
|
||||
return dict(merged_dict)
|
||||
|
||||
|
||||
JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"],
|
||||
Tuple["JSONTree[T]", ...], T]
|
||||
"""A nested JSON structure where the leaves need not be JSON-serializable."""
|
||||
|
||||
|
||||
@overload
|
||||
def json_map_leaves(
|
||||
func: Callable[[T], U],
|
||||
value: Dict[str, JSONTree[T]],
|
||||
) -> Dict[str, JSONTree[U]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def json_map_leaves(
|
||||
func: Callable[[T], U],
|
||||
value: List[JSONTree[T]],
|
||||
) -> List[JSONTree[U]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def json_map_leaves(
|
||||
func: Callable[[T], U],
|
||||
value: Tuple[JSONTree[T], ...],
|
||||
) -> Tuple[JSONTree[U], ...]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def json_map_leaves(
|
||||
func: Callable[[T], U],
|
||||
value: JSONTree[T],
|
||||
) -> JSONTree[U]:
|
||||
...
|
||||
|
||||
|
||||
def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]:
|
||||
if isinstance(value, dict):
|
||||
return {k: json_map_leaves(func, v) for k, v in value.items()}
|
||||
elif isinstance(value, list):
|
||||
return [json_map_leaves(func, v) for v in value]
|
||||
elif isinstance(value, tuple):
|
||||
return tuple(json_map_leaves(func, v) for v in value)
|
||||
else:
|
||||
return func(value)
|
||||
|
||||
|
||||
def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
|
||||
"""Flatten a list of lists to a single list."""
|
||||
return [item for sublist in lists for item in sublist]
|
||||
|
||||
Reference in New Issue
Block a user