[Model] Extend collect_children and no_init_weights contexts (#32757)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -7,10 +7,10 @@ This is similar in concept to the `collections` module.
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Generator, Hashable, Iterable, Mapping
|
||||
from collections.abc import Callable, Generator, Hashable, Iterable, Mapping, Sequence
|
||||
from typing import Generic, Literal, TypeVar
|
||||
|
||||
from typing_extensions import TypeIs, assert_never
|
||||
from typing_extensions import TypeIs, assert_never, overload
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -74,6 +74,34 @@ def is_list_of(
|
||||
assert_never(check)
|
||||
|
||||
|
||||
@overload
|
||||
def common_prefix(items: Sequence[str]) -> str: ...
|
||||
|
||||
|
||||
@overload
|
||||
def common_prefix(items: Sequence[Sequence[T]]) -> Sequence[T]: ...
|
||||
|
||||
|
||||
def common_prefix(items: Sequence[Sequence[T] | str]) -> Sequence[T] | str:
|
||||
"""Find the longest prefix common to all items."""
|
||||
if len(items) == 0:
|
||||
return []
|
||||
if len(items) == 1:
|
||||
return items[0]
|
||||
|
||||
shortest = min(items, key=len)
|
||||
if not shortest:
|
||||
return shortest[:0]
|
||||
|
||||
for match_len in range(1, len(shortest) + 1):
|
||||
match = shortest[:match_len]
|
||||
for item in items:
|
||||
if item[:match_len] != match:
|
||||
return shortest[: match_len - 1]
|
||||
|
||||
return shortest
|
||||
|
||||
|
||||
def chunk_list(lst: list[T], chunk_size: int) -> Generator[list[T]]:
|
||||
"""Yield successive chunk_size chunks from lst."""
|
||||
for i in range(0, len(lst), chunk_size):
|
||||
|
||||
Reference in New Issue
Block a user