[Bugfix] Fix broadcasting logic for multi_modal_kwargs (#6836)

This commit is contained in:
Cyrus Leung
2024-07-31 10:38:45 +08:00
committed by GitHub
parent da1f7cc12a
commit f230cc2ca6
16 changed files with 254 additions and 211 deletions

View File

@@ -17,7 +17,7 @@ from functools import lru_cache, partial, wraps
from platform import uname
from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar,
Union)
Union, overload)
import numpy as np
import numpy.typing as npt
@@ -53,6 +53,7 @@ TORCH_DTYPE_TO_NUMPY_DTYPE = {
P = ParamSpec('P')
K = TypeVar("K")
T = TypeVar("T")
U = TypeVar("U")
class _Sentinel:
@@ -712,6 +713,54 @@ def merge_dicts(dict1: Dict[K, List[T]],
return dict(merged_dict)
JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"],
Tuple["JSONTree[T]", ...], T]
"""A nested JSON structure where the leaves need not be JSON-serializable."""
@overload
def json_map_leaves(
func: Callable[[T], U],
value: Dict[str, JSONTree[T]],
) -> Dict[str, JSONTree[U]]:
...
@overload
def json_map_leaves(
func: Callable[[T], U],
value: List[JSONTree[T]],
) -> List[JSONTree[U]]:
...
@overload
def json_map_leaves(
func: Callable[[T], U],
value: Tuple[JSONTree[T], ...],
) -> Tuple[JSONTree[U], ...]:
...
@overload
def json_map_leaves(
func: Callable[[T], U],
value: JSONTree[T],
) -> JSONTree[U]:
...
def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]:
if isinstance(value, dict):
return {k: json_map_leaves(func, v) for k, v in value.items()}
elif isinstance(value, list):
return [json_map_leaves(func, v) for v in value]
elif isinstance(value, tuple):
return tuple(json_map_leaves(func, v) for v in value)
else:
return func(value)
def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
"""Flatten a list of lists to a single list."""
return [item for sublist in lists for item in sublist]