[Model] Extend collect_children and no_init_weights contexts (#32757)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-22 16:20:27 +08:00
committed by GitHub
parent 1bf1a34b19
commit 2b8a38b6d6
20 changed files with 444 additions and 257 deletions

View File

@@ -7,10 +7,10 @@ This is similar in concept to the `collections` module.
"""
from collections import defaultdict
from collections.abc import Callable, Generator, Hashable, Iterable, Mapping
from collections.abc import Callable, Generator, Hashable, Iterable, Mapping, Sequence
from typing import Generic, Literal, TypeVar
from typing_extensions import TypeIs, assert_never
from typing_extensions import TypeIs, assert_never, overload
T = TypeVar("T")
@@ -74,6 +74,34 @@ def is_list_of(
assert_never(check)
@overload
def common_prefix(items: Sequence[str]) -> str: ...
@overload
def common_prefix(items: Sequence[Sequence[T]]) -> Sequence[T]: ...
def common_prefix(items: Sequence[Sequence[T] | str]) -> Sequence[T] | str:
"""Find the longest prefix common to all items."""
if len(items) == 0:
return []
if len(items) == 1:
return items[0]
shortest = min(items, key=len)
if not shortest:
return shortest[:0]
for match_len in range(1, len(shortest) + 1):
match = shortest[:match_len]
for item in items:
if item[:match_len] != match:
return shortest[: match_len - 1]
return shortest
def chunk_list(lst: list[T], chunk_size: int) -> Generator[list[T]]:
"""Yield successive chunk_size chunks from lst."""
for i in range(0, len(lst), chunk_size):