[Model] Extend collect_children and no_init_weights contexts (#32757)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -2,11 +2,27 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
|
||||
from vllm.utils.collection_utils import swap_dict_values
|
||||
from vllm.utils.collection_utils import common_prefix, swap_dict_values
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"obj,key1,key2",
|
||||
("inputs", "expected_output"),
|
||||
[
|
||||
([""], ""),
|
||||
(["a"], "a"),
|
||||
(["a", "b"], ""),
|
||||
(["a", "ab"], "a"),
|
||||
(["a", "ab", "b"], ""),
|
||||
(["abc", "a", "ab"], "a"),
|
||||
(["aba", "abc", "ab"], "ab"),
|
||||
],
|
||||
)
|
||||
def test_common_prefix(inputs, expected_output):
|
||||
assert common_prefix(inputs) == expected_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("obj", "key1", "key2"),
|
||||
[
|
||||
# Tests for both keys exist
|
||||
({1: "a", 2: "b"}, 1, 2),
|
||||
|
||||
@@ -18,7 +18,7 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.models.interfaces import SupportsQuant, supports_multimodal
|
||||
from vllm.model_executor.models.interfaces import SupportsQuant
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -165,11 +165,7 @@ _MODEL_ARCH_BY_HASH = dict[int, tuple[type[nn.Module], str]]()
|
||||
|
||||
|
||||
def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
|
||||
from vllm.model_executor.models.adapters import (
|
||||
as_embedding_model,
|
||||
as_seq_cls_model,
|
||||
try_create_mm_pooling_model_cls,
|
||||
)
|
||||
from vllm.model_executor.models.adapters import as_embedding_model, as_seq_cls_model
|
||||
|
||||
architectures = getattr(model_config.hf_config, "architectures", [])
|
||||
|
||||
@@ -189,15 +185,6 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module],
|
||||
)
|
||||
|
||||
convert_type = model_config.convert_type
|
||||
if convert_type != "none" and supports_multimodal(model_cls):
|
||||
logger.debug_once("Detected conversion of Multi Modal model.")
|
||||
converted = try_create_mm_pooling_model_cls(model_cls)
|
||||
if converted is not None:
|
||||
logger.debug_once("Creating wrapper class to forward pooler.")
|
||||
return converted, arch
|
||||
else:
|
||||
logger.debug_once("Attempting direct conversion.")
|
||||
|
||||
if convert_type == "none":
|
||||
pass
|
||||
elif convert_type == "embed":
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ast
|
||||
import inspect
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING, Any, TypeVar, cast
|
||||
|
||||
@@ -18,10 +16,12 @@ from vllm.transformers_utils.config import (
|
||||
)
|
||||
from vllm.transformers_utils.repo_utils import get_hf_file_bytes
|
||||
|
||||
from .interfaces import supports_multimodal
|
||||
from .interfaces_base import VllmModelForPooling, is_pooling_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.model_executor.layers.pooler import Pooler
|
||||
|
||||
_T = TypeVar("_T", bound=type[nn.Module])
|
||||
|
||||
@@ -124,41 +124,12 @@ def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
|
||||
return model_name + pooling_suffix
|
||||
|
||||
|
||||
def try_create_mm_pooling_model_cls(orig_cls: _T) -> _T:
|
||||
class CallVisitor(ast.NodeVisitor):
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
def visit_Call(self, node):
|
||||
if isinstance(node.func, ast.Name):
|
||||
self.calls.append(node.func.id)
|
||||
self.generic_visit(node)
|
||||
|
||||
visitor = CallVisitor()
|
||||
visitor.visit(ast.parse(inspect.getsource(orig_cls)))
|
||||
if "init_vllm_registered_model" not in visitor.calls:
|
||||
return None
|
||||
|
||||
class ModelForPooling(orig_cls, VllmModelForPooling):
|
||||
is_pooling_model = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vllm_config: "VllmConfig",
|
||||
prefix: str = "",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
|
||||
|
||||
self.pooler = self.get_language_model().pooler
|
||||
|
||||
return ModelForPooling # type: ignore
|
||||
|
||||
|
||||
def _create_pooling_model_cls(orig_cls: _T) -> _T:
|
||||
# Lazy import
|
||||
from .utils import AutoWeightsLoader, WeightsMapper
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
|
||||
from .utils import AutoWeightsLoader, StageMissingLayer, no_init_weights
|
||||
|
||||
class ModelForPooling(orig_cls, VllmModelForPooling):
|
||||
is_pooling_model = True
|
||||
@@ -170,69 +141,84 @@ def _create_pooling_model_cls(orig_cls: _T) -> _T:
|
||||
prefix: str = "",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
|
||||
with no_init_weights(
|
||||
self,
|
||||
lambda mod: StageMissingLayer("output", mod),
|
||||
targets=(LogitsProcessor, ParallelLMHead),
|
||||
):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
|
||||
|
||||
# Used by SEQ_CLS_LOAD_METHODS
|
||||
self.vllm_config = vllm_config
|
||||
|
||||
# These are not used in pooling models
|
||||
objects_to_clean = [self]
|
||||
if language_model := getattr(self, "language_model", None):
|
||||
objects_to_clean.append(language_model)
|
||||
|
||||
for obj in objects_to_clean:
|
||||
for attr in ("lm_head", "logits_processor"):
|
||||
if hasattr(obj, attr):
|
||||
delattr(obj, attr)
|
||||
|
||||
# If the model already defines a pooler instance, don't overwrite it
|
||||
if not getattr(self, "pooler", None):
|
||||
self._init_pooler(vllm_config, prefix=prefix)
|
||||
pooler = getattr(self, "pooler", None)
|
||||
if not pooler and supports_multimodal(self):
|
||||
# Try to get the pooler from the LM backbone
|
||||
language_model = self.get_language_model()
|
||||
if hasattr(language_model, "pooler"):
|
||||
pooler = language_model.pooler
|
||||
|
||||
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
|
||||
if not pooler:
|
||||
pooler = self._init_pooler(vllm_config, prefix=prefix)
|
||||
|
||||
self.pooler = pooler
|
||||
|
||||
def _init_pooler(
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
prefix: str = "",
|
||||
) -> "Pooler":
|
||||
raise NotImplementedError
|
||||
|
||||
def load_weights(
|
||||
self,
|
||||
weights: Iterable[tuple[str, torch.Tensor]],
|
||||
load_lm_head: bool = False,
|
||||
):
|
||||
# TODO: Support uninitialized params tracking
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
params_dict = dict(self.named_parameters())
|
||||
|
||||
# For most pooling models: We have deleted this attribute, so don't load it.
|
||||
# For converting an LLM into a seq cls model, we need the lm_head.
|
||||
if not load_lm_head:
|
||||
weights = (
|
||||
(name, data)
|
||||
for name, data in weights
|
||||
if not name.startswith("lm_head.")
|
||||
# We support loading from both `*ForCausalLM` and `*Model`
|
||||
candidate_prefixes = ["", "model."]
|
||||
target_prefix = ""
|
||||
|
||||
seen_weights = list[tuple[str, torch.Tensor]]()
|
||||
for name, loaded_weight in weights:
|
||||
seen_weights.append((name, loaded_weight))
|
||||
|
||||
try:
|
||||
target_prefix = next(
|
||||
prefix
|
||||
for prefix in candidate_prefixes
|
||||
if prefix + name in params_dict
|
||||
)
|
||||
break
|
||||
except StopIteration:
|
||||
# The weight might not exist on the model
|
||||
# (to be handled by AutoWeightsLoader)
|
||||
pass
|
||||
|
||||
if target_prefix:
|
||||
target_model = self
|
||||
for attr in target_prefix.split("."):
|
||||
if attr:
|
||||
target_model = getattr(self, attr)
|
||||
|
||||
logger.info(
|
||||
"Mapping weights to %s as they are "
|
||||
"relative to this model instead of %s.",
|
||||
target_model._get_name(),
|
||||
self._get_name(),
|
||||
)
|
||||
|
||||
# If `*ForCausalLM` defines `load_weights` on the inner model
|
||||
# and there are no other inner modules with parameters,
|
||||
# we support loading from both `*Model` and `*ForCausalLM`
|
||||
if hasattr(self, "model") and hasattr(self.model, "load_weights"):
|
||||
# Whether only `self.model` contains parameters
|
||||
model_is_only_param = all(
|
||||
name == "model" or next(child.parameters(), None) is None
|
||||
for name, child in self.named_children()
|
||||
)
|
||||
mapped_weights = (
|
||||
(target_prefix + name, weight)
|
||||
for name, weight in (*seen_weights, *weights)
|
||||
)
|
||||
|
||||
if model_is_only_param:
|
||||
mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
|
||||
weights = mapper.apply(weights)
|
||||
|
||||
loaded_params = self.model.load_weights(weights)
|
||||
loaded_params = {f"model.{name}" for name in loaded_params}
|
||||
return loaded_params
|
||||
|
||||
# For most other models
|
||||
if hasattr(orig_cls, "load_weights"):
|
||||
return orig_cls.load_weights(self, weights) # type: ignore
|
||||
# Fallback
|
||||
else:
|
||||
def default_load_weights(weights):
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
load_weights = getattr(super(), "load_weights", default_load_weights)
|
||||
return load_weights(mapped_weights)
|
||||
|
||||
return ModelForPooling # type: ignore
|
||||
|
||||
|
||||
@@ -255,11 +241,15 @@ def as_embedding_model(cls: _T) -> _T:
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler
|
||||
|
||||
class ModelForEmbedding(_create_pooling_model_cls(cls)):
|
||||
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
|
||||
def _init_pooler(
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
prefix: str = "",
|
||||
) -> "Pooler":
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler.for_embedding(pooler_config)
|
||||
return DispatchPooler.for_embedding(pooler_config)
|
||||
|
||||
ModelForEmbedding.__name__ = _get_pooling_model_name(cls.__name__, "ForEmbedding")
|
||||
|
||||
@@ -292,7 +282,11 @@ def as_seq_cls_model(cls: _T) -> _T:
|
||||
class ModelForSequenceClassification(
|
||||
_create_pooling_model_cls(cls), SupportsCrossEncoding
|
||||
):
|
||||
def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""):
|
||||
def _init_pooler(
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
prefix: str = "",
|
||||
) -> "Pooler":
|
||||
text_config = vllm_config.model_config.hf_config.get_text_config()
|
||||
model_config = vllm_config.model_config
|
||||
quant_config = vllm_config.quant_config
|
||||
@@ -310,9 +304,7 @@ def as_seq_cls_model(cls: _T) -> _T:
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler.for_seq_cls(
|
||||
pooler_config, classifier=self.score
|
||||
)
|
||||
return DispatchPooler.for_seq_cls(pooler_config, classifier=self.score)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
hf_config = self.config
|
||||
@@ -424,7 +416,7 @@ def load_weights_using_from_2_way_softmax(
|
||||
pooling_model_cls = next(
|
||||
x for x in type(model).__mro__ if x.__name__ == "ModelForPooling"
|
||||
)
|
||||
loaded_weights = pooling_model_cls.load_weights(model, weights, load_lm_head=True)
|
||||
loaded_weights = pooling_model_cls.load_weights(model, weights)
|
||||
|
||||
from vllm.tokenizers import get_tokenizer
|
||||
|
||||
|
||||
@@ -44,11 +44,11 @@ from .interfaces import (
|
||||
SupportsLoRA,
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
TowerMissingLayer,
|
||||
)
|
||||
from .siglip import SiglipVisionModel
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
StageMissingLayer,
|
||||
WeightsMapper,
|
||||
init_vllm_registered_model,
|
||||
maybe_prefix,
|
||||
@@ -426,9 +426,9 @@ class BagelForConditionalGeneration(
|
||||
hidden_size=llm_hidden_size,
|
||||
)
|
||||
else:
|
||||
self.vit_model = TowerMissingLayer("image")
|
||||
self.connector = TowerMissingLayer("image")
|
||||
self.vit_pos_embed = TowerMissingLayer("image")
|
||||
self.vit_model = StageMissingLayer("image_tower")
|
||||
self.connector = StageMissingLayer("image_tower")
|
||||
self.vit_pos_embed = StageMissingLayer("image_tower")
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors
|
||||
|
||||
@@ -935,9 +935,20 @@ class ChameleonForConditionalGeneration(
|
||||
multimodal_config = vllm_config.model_config.multimodal_config
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.model = ChameleonModel(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
)
|
||||
|
||||
with self._mark_composite_model(
|
||||
vllm_config,
|
||||
language_targets=(
|
||||
ChameleonDecoderLayer
|
||||
if not self.config.swin_norm
|
||||
else ChameleonSwinDecoderLayer
|
||||
),
|
||||
tower_targets={"image": ChameleonVQVAE},
|
||||
):
|
||||
self.model = ChameleonModel(
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"),
|
||||
)
|
||||
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
@@ -970,9 +981,6 @@ class ChameleonForConditionalGeneration(
|
||||
resolve_bindings={"h": expected_h, "w": expected_w},
|
||||
)
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.model
|
||||
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
|
||||
@@ -539,10 +539,7 @@ class Gemma3ForConditionalGeneration(
|
||||
)
|
||||
|
||||
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||
if hasattr(self.language_model, "logits_processor"):
|
||||
# The logits processor can be unset if we're using
|
||||
# automatic conversion to pooling model.
|
||||
self.language_model.logits_processor.scale *= logit_scale
|
||||
self.language_model.logits_processor.scale *= logit_scale
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors
|
||||
|
||||
@@ -53,7 +53,7 @@ from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs import ChatGLMConfig
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .chatglm import ChatGLMBaseModel, ChatGLMModel
|
||||
from .chatglm import ChatGLMBaseModel, ChatGLMModel, GLMTransformer
|
||||
from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsLoRA,
|
||||
@@ -591,11 +591,16 @@ class GLM4VForCausalLM(
|
||||
prefix: str = "",
|
||||
transformer_type: type[GLM4VModel] = GLM4VModel,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
transformer_type=transformer_type,
|
||||
)
|
||||
with self._mark_composite_model(
|
||||
vllm_config,
|
||||
language_targets=GLMTransformer,
|
||||
tower_targets={"image": EVA2CLIPModel},
|
||||
):
|
||||
super().__init__(
|
||||
vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
transformer_type=transformer_type,
|
||||
)
|
||||
|
||||
self.transformer: GLM4VModel
|
||||
|
||||
@@ -752,9 +757,6 @@ class GLM4VForCausalLM(
|
||||
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
|
||||
return llm_positions, mrope_position_delta
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.transformer
|
||||
|
||||
embed_input_ids = SupportsMultiModal.embed_input_ids
|
||||
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
|
||||
@@ -57,7 +57,11 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
from .idefics2_vision_model import (
|
||||
Idefics2VisionTransformer as Idefics3VisionTransformer,
|
||||
)
|
||||
from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
|
||||
from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsLoRA,
|
||||
SupportsMultiModal,
|
||||
)
|
||||
from .llama import LlamaModel
|
||||
from .utils import AutoWeightsLoader, maybe_prefix
|
||||
|
||||
@@ -604,9 +608,16 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLo
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
self.model = Idefics3Model(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
)
|
||||
with self._mark_composite_model(
|
||||
vllm_config,
|
||||
language_targets=LlamaModel,
|
||||
tower_targets={"image": (Idefics3VisionTransformer, Idefics3Connector)},
|
||||
):
|
||||
self.model = Idefics3Model(
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"),
|
||||
)
|
||||
|
||||
self.image_token_id = self.config.image_token_id
|
||||
|
||||
self.lm_head = ParallelLMHead(
|
||||
@@ -669,9 +680,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLo
|
||||
num_patches = image_input["num_patches"]
|
||||
return [e.flatten(0, 1) for e in image_features.split(num_patches.tolist())]
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.model
|
||||
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable, Iterable, Mapping, MutableSequence
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from contextlib import ExitStack, contextmanager, nullcontext
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
ClassVar,
|
||||
@@ -25,6 +25,7 @@ from vllm.inputs import TokensPrompt
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.utils.collection_utils import common_prefix
|
||||
from vllm.utils.func_utils import supports_kw
|
||||
|
||||
from .interfaces_base import VllmModel, is_pooling_model
|
||||
@@ -70,46 +71,8 @@ def _require_is_multimodal(is_multimodal: Tensor | None) -> Tensor:
|
||||
return is_multimodal
|
||||
|
||||
|
||||
class LMMissingLayer(nn.Module):
|
||||
def make_empty_intermediate_tensors(self, *args, **kwargs):
|
||||
raise RuntimeError("This module should not be called in MM encoder-only mode")
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise RuntimeError("This module should not be called in MM encoder-only mode")
|
||||
|
||||
|
||||
class TowerMissingLayer(nn.Module):
|
||||
def __init__(self, modalities: set[str] | str) -> None:
|
||||
if isinstance(modalities, str):
|
||||
modalities = {modalities}
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.modalities = modalities
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise RuntimeError(
|
||||
f"This module should not be called when the following "
|
||||
f"modalities are disabled: {self.modalities}"
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _no_init_weights(module: nn.Module, placeholder: Callable[[], nn.Module]):
|
||||
"""
|
||||
Within this context, prevent weight initialization from using device memory and
|
||||
replace direct child assignments to `module` with the result of `placeholder()`.
|
||||
"""
|
||||
|
||||
def callback(module_, name, submodule):
|
||||
if module_ is module:
|
||||
return placeholder()
|
||||
|
||||
return submodule
|
||||
|
||||
with torch.nn.modules.module.register_module_module_registration_hook(callback): # noqa: E501,SIM117
|
||||
with torch.device("meta"):
|
||||
yield
|
||||
# Cache results of `SupportsMultiModal.get_language_model`
|
||||
_language_model_by_module = dict[nn.Module, VllmModel]()
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@@ -187,31 +150,61 @@ class SupportsMultiModal(Protocol):
|
||||
Returns:
|
||||
torch.nn.Module: The core language model component.
|
||||
"""
|
||||
# Cached
|
||||
if self in _language_model_by_module:
|
||||
return _language_model_by_module[self]
|
||||
|
||||
if self._language_model_names:
|
||||
return getattr(self, self._language_model_names[0])
|
||||
mod = self
|
||||
for attr in common_prefix(
|
||||
[name.split(".") for name in self._language_model_names]
|
||||
):
|
||||
if attr:
|
||||
mod = getattr(mod, attr)
|
||||
|
||||
if mod is not self and hasattr(mod, "embed_input_ids"):
|
||||
_language_model_by_module[self] = mod
|
||||
return mod
|
||||
|
||||
# Fallback
|
||||
for mod in self.children():
|
||||
if hasattr(mod, "embed_input_ids"):
|
||||
_language_model_by_module[self] = mod
|
||||
return mod
|
||||
|
||||
raise NotImplementedError(
|
||||
f"No language model found in {type(self).__name__}! "
|
||||
"You should initialize it inside `_mark_language_model`."
|
||||
"You should initialize it via `_mark_language_model`."
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def _mark_language_model(self, vllm_config: VllmConfig):
|
||||
def _mark_language_model(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
*,
|
||||
targets: type[nn.Module] | tuple[type[nn.Module], ...] | None = None,
|
||||
):
|
||||
"""
|
||||
Mark each child module that was assigned to this model
|
||||
during this context as a language model component.
|
||||
Mark each child module that was assigned to this model during this context
|
||||
as a language model component.
|
||||
|
||||
Language model components are automatically skipped in `--mm-encoder-only`
|
||||
mode.
|
||||
|
||||
If `targets` is set, instead include descendants that are an instance
|
||||
of `targets`, even if they aren't direct children.
|
||||
"""
|
||||
from .utils import StageMissingLayer, collect_children, no_init_weights
|
||||
|
||||
mm_config = vllm_config.model_config.multimodal_config
|
||||
|
||||
children_names = list[str]()
|
||||
|
||||
def callback(module_, name, submodule):
|
||||
if module_ is self:
|
||||
children_names.append(name)
|
||||
|
||||
with torch.nn.modules.module.register_module_module_registration_hook(callback): # noqa: E501,SIM117
|
||||
with collect_children(self, targets=targets) as children_names: # noqa: SIM117
|
||||
with (
|
||||
_no_init_weights(self, LMMissingLayer)
|
||||
no_init_weights(
|
||||
self,
|
||||
lambda mod: StageMissingLayer("language_model", mod),
|
||||
targets=targets,
|
||||
)
|
||||
if mm_config.mm_encoder_only
|
||||
else nullcontext()
|
||||
):
|
||||
@@ -220,25 +213,42 @@ class SupportsMultiModal(Protocol):
|
||||
self._language_model_names = children_names
|
||||
|
||||
@contextmanager
|
||||
def _mark_tower_model(self, vllm_config: VllmConfig, modalities: set[str] | str):
|
||||
def _mark_tower_model(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
modalities: set[str] | str,
|
||||
*,
|
||||
targets: type[nn.Module] | tuple[type[nn.Module], ...] | None = None,
|
||||
):
|
||||
"""
|
||||
Mark each child module that was assigned to this model
|
||||
during this context as a tower model component.
|
||||
Mark each child module that was assigned to this model during this context
|
||||
as a tower model component.
|
||||
|
||||
Tower model components are automatically skipped when `--limit-mm-per-prompt`
|
||||
is set to zero for all of their modalities.
|
||||
|
||||
If `targets` is set, instead include descendants that are an instance
|
||||
of `targets`, even if they aren't direct children.
|
||||
"""
|
||||
from .utils import StageMissingLayer, collect_children, no_init_weights
|
||||
|
||||
if isinstance(modalities, str):
|
||||
modalities = {modalities}
|
||||
|
||||
if modalities == {"image", "video"}:
|
||||
stage_name = "vision_tower"
|
||||
else:
|
||||
stage_name = "_".join([*modalities, "tower"])
|
||||
|
||||
mm_config = vllm_config.model_config.multimodal_config
|
||||
|
||||
children_names = list[str]()
|
||||
|
||||
def callback(module_, name, submodule):
|
||||
if module_ is self:
|
||||
children_names.append(name)
|
||||
|
||||
with torch.nn.modules.module.register_module_module_registration_hook(callback): # noqa: E501,SIM117
|
||||
with collect_children(self, targets=targets) as children_names: # noqa: SIM117
|
||||
with (
|
||||
_no_init_weights(self, lambda: TowerMissingLayer(modalities))
|
||||
no_init_weights(
|
||||
self,
|
||||
lambda mod: StageMissingLayer(stage_name, mod),
|
||||
targets=targets,
|
||||
)
|
||||
if all(mm_config.get_limit_per_prompt(m) == 0 for m in modalities)
|
||||
else nullcontext()
|
||||
):
|
||||
@@ -246,6 +256,37 @@ class SupportsMultiModal(Protocol):
|
||||
|
||||
self._tower_model_names = children_names
|
||||
|
||||
@contextmanager
|
||||
def _mark_composite_model(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
*,
|
||||
language_targets: type[nn.Module] | tuple[type[nn.Module], ...],
|
||||
tower_targets: dict[str, type[nn.Module] | tuple[type[nn.Module], ...]],
|
||||
):
|
||||
"""
|
||||
Composite wrapper over `_mark_language_model` and
|
||||
`_mark_tower_model` by modality.
|
||||
"""
|
||||
with ExitStack() as stack:
|
||||
stack.enter_context(
|
||||
self._mark_language_model(
|
||||
vllm_config,
|
||||
targets=language_targets,
|
||||
)
|
||||
)
|
||||
|
||||
for modality, modality_targets in tower_targets.items():
|
||||
stack.enter_context(
|
||||
self._mark_tower_model(
|
||||
vllm_config,
|
||||
modality,
|
||||
targets=modality_targets,
|
||||
)
|
||||
)
|
||||
|
||||
yield
|
||||
|
||||
def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
|
||||
"""
|
||||
Implement this function to enable LoRA support
|
||||
|
||||
@@ -41,10 +41,12 @@ from vllm.sequence import IntermediateTensors
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .interfaces_base import default_pooling_type
|
||||
from .utils import (
|
||||
StageMissingLayer,
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory,
|
||||
make_layers,
|
||||
maybe_prefix,
|
||||
no_init_weights,
|
||||
)
|
||||
|
||||
|
||||
@@ -413,10 +415,16 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
|
||||
prefix: str = "",
|
||||
model_type: type[InternLM2Model] = InternLM2Model,
|
||||
):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix, model_type=model_type)
|
||||
|
||||
for attr in ("output", "logits_processor"):
|
||||
delattr(self, attr)
|
||||
with no_init_weights(
|
||||
self,
|
||||
lambda mod: StageMissingLayer("output", mod),
|
||||
targets=(LogitsProcessor, ParallelLMHead),
|
||||
):
|
||||
super().__init__(
|
||||
vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
model_type=model_type,
|
||||
)
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.head_dtype = vllm_config.model_config.head_dtype
|
||||
|
||||
@@ -1035,11 +1035,13 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
)
|
||||
|
||||
with self._mark_tower_model(vllm_config, {"image", "video"}):
|
||||
self.vpm = vpm = self.init_vision_module(
|
||||
self.vpm = self.init_vision_module(
|
||||
config, quant_config, prefix=maybe_prefix(prefix, "vpm")
|
||||
)
|
||||
self.vision_dim = (
|
||||
vpm.embed_dim if self.version == (2, 0) else vpm.embeddings.embed_dim
|
||||
self.vpm.embed_dim
|
||||
if self.version == (2, 0)
|
||||
else self.vpm.embeddings.embed_dim
|
||||
)
|
||||
self.embed_dim = self.config.hidden_size
|
||||
|
||||
|
||||
@@ -70,20 +70,15 @@ from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import (
|
||||
LMMissingLayer,
|
||||
MixtureOfExperts,
|
||||
MultiModalEmbeddings,
|
||||
SupportsEagle3,
|
||||
SupportsLoRA,
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
TowerMissingLayer,
|
||||
)
|
||||
from .llama4 import Llama4ForCausalLM
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
maybe_prefix,
|
||||
)
|
||||
from .utils import AutoWeightsLoader, StageMissingLayer, maybe_prefix
|
||||
from .vision import run_dp_sharded_vision_model
|
||||
|
||||
|
||||
@@ -1024,7 +1019,7 @@ class Llama4ForConditionalGeneration(
|
||||
renamed = self._rename_weight_for_modelopt_checkpoint(name)
|
||||
|
||||
attr = renamed.split(".", 1)[0]
|
||||
if isinstance(getattr(self, attr), (LMMissingLayer, TowerMissingLayer)):
|
||||
if isinstance(getattr(self, attr), StageMissingLayer):
|
||||
continue
|
||||
|
||||
if renamed.startswith("language_model."):
|
||||
|
||||
@@ -1513,7 +1513,7 @@ class NemotronH_Nano_VL_V2(
|
||||
self.video_pruning_rate = multimodal_config.video_pruning_rate
|
||||
|
||||
with self._mark_language_model(vllm_config):
|
||||
self.language_model = language_model = init_vllm_registered_model(
|
||||
self.language_model = init_vllm_registered_model(
|
||||
vllm_config=vllm_config,
|
||||
hf_config=config.text_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
@@ -1542,7 +1542,7 @@ class NemotronH_Nano_VL_V2(
|
||||
ReLUSquaredActivation(),
|
||||
nn.Linear(vision_projection_hidden_size, llm_hidden_size, bias=False),
|
||||
)
|
||||
self.mlp1 = mlp1.to(language_model.config.dtype)
|
||||
self.mlp1 = mlp1.to(self.language_model.config.dtype)
|
||||
|
||||
self.config = config
|
||||
self.model_config = vllm_config.model_config
|
||||
|
||||
@@ -1025,12 +1025,12 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support
|
||||
self.mlp_AR = Projector(config, config.vision_config)
|
||||
|
||||
with self._mark_language_model(vllm_config):
|
||||
self.language_model = language_model = Ernie4_5ForCausalLM(
|
||||
self.language_model = Ernie4_5ForCausalLM(
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
|
||||
for layer in language_model.model.layers:
|
||||
for layer in self.language_model.model.layers:
|
||||
if not isinstance(layer, PPMissingLayer):
|
||||
layer.self_attn.rotary_emb.is_neox_style = True
|
||||
|
||||
|
||||
@@ -314,13 +314,14 @@ class PaliGemmaForConditionalGeneration(
|
||||
config.text_config.architectures = ["Gemma2ForCausalLM"]
|
||||
|
||||
with self._mark_language_model(vllm_config):
|
||||
self.language_model = language_model = init_vllm_registered_model(
|
||||
self.language_model = init_vllm_registered_model(
|
||||
vllm_config=vllm_config,
|
||||
hf_config=config.text_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
|
||||
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||
language_model.logits_processor.scale *= logit_scale
|
||||
self.language_model.logits_processor.scale *= logit_scale
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors
|
||||
|
||||
@@ -461,15 +461,16 @@ class Qwen3VLMoeForConditionalGeneration(
|
||||
]
|
||||
|
||||
with self._mark_language_model(vllm_config):
|
||||
self.language_model = language_model = Qwen3MoeLLMForCausalLM(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model")
|
||||
self.language_model = Qwen3MoeLLMForCausalLM(
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
|
||||
# Whether to include the gate_up_proj mapping is determined by
|
||||
# the language model.
|
||||
self.packed_modules_mapping = (
|
||||
self.packed_modules_mapping | language_model.packed_modules_mapping
|
||||
)
|
||||
# Whether to include the gate_up_proj mapping is determined by
|
||||
# the language model.
|
||||
self.packed_modules_mapping = (
|
||||
self.packed_modules_mapping | self.language_model.packed_modules_mapping
|
||||
)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors
|
||||
|
||||
@@ -58,7 +58,7 @@ from .interfaces import (
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
)
|
||||
from .qwen import QWenBaseModel, QWenModel
|
||||
from .qwen import QWenBaseModel, QWenBlock, QWenModel
|
||||
|
||||
|
||||
class QwenImagePixelInputs(TensorSchema):
|
||||
@@ -757,11 +757,16 @@ class QwenVLForConditionalGeneration(
|
||||
prefix: str = "",
|
||||
transformer_type: type[QwenVLModel] = QwenVLModel,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
transformer_type=transformer_type,
|
||||
)
|
||||
with self._mark_composite_model(
|
||||
vllm_config,
|
||||
language_targets=QWenBlock,
|
||||
tower_targets={"image": VisionTransformer},
|
||||
):
|
||||
super().__init__(
|
||||
vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
transformer_type=transformer_type,
|
||||
)
|
||||
|
||||
self.transformer: QwenVLModel
|
||||
|
||||
@@ -795,9 +800,6 @@ class QwenVLForConditionalGeneration(
|
||||
|
||||
return self.transformer.visual(image_input["data"])
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.transformer
|
||||
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
|
||||
@@ -2,13 +2,15 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import itertools
|
||||
from collections.abc import Iterable, Mapping
|
||||
from collections.abc import Callable, Iterable, Mapping
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal, Protocol, overload
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.func import functional_call
|
||||
from torch.nn.modules.module import register_module_module_registration_hook
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
@@ -24,11 +26,7 @@ from vllm.model_executor.model_loader.online_quantization import (
|
||||
support_quantized_model_reload_from_hp_weights,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.interfaces import (
|
||||
LMMissingLayer,
|
||||
TowerMissingLayer,
|
||||
supports_any_eagle,
|
||||
)
|
||||
from vllm.model_executor.models.interfaces import supports_any_eagle
|
||||
from vllm.multimodal import NestedTensors
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.math_utils import cdiv
|
||||
@@ -214,8 +212,8 @@ class AutoWeightsLoader:
|
||||
continue
|
||||
|
||||
raise ValueError(
|
||||
f"Attempted to load nested weight '{weight_qualname}' "
|
||||
f"into a single parameter '{base_prefix}'"
|
||||
f"Attempted to load nested weight {weight_qualname!r} "
|
||||
f"into a single parameter {base_prefix!r}"
|
||||
)
|
||||
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
@@ -254,7 +252,7 @@ class AutoWeightsLoader:
|
||||
module: nn.Module,
|
||||
weights: Iterable[tuple[str, torch.Tensor]],
|
||||
) -> Iterable[str]:
|
||||
if isinstance(module, (LMMissingLayer, TowerMissingLayer, PPMissingLayer)):
|
||||
if isinstance(module, (StageMissingLayer, PPMissingLayer)):
|
||||
return
|
||||
|
||||
# Avoid infinite recursion since this function is typically
|
||||
@@ -316,9 +314,14 @@ class AutoWeightsLoader:
|
||||
|
||||
continue
|
||||
|
||||
desc_param_keys = {
|
||||
base_prefix + k for k, _ in module.named_parameters(recurse=True)
|
||||
}
|
||||
msg = (
|
||||
f"There is no module or parameter named '{prefix}' "
|
||||
f"in {type(self.module).__name__}"
|
||||
f"There is no module or parameter named {prefix!r} "
|
||||
f"in {self.module._get_name()}. "
|
||||
f"The available parameters belonging to {base_prefix} "
|
||||
f"({module._get_name()}) are: {desc_param_keys}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
@@ -496,6 +499,100 @@ def isin_list(
|
||||
return torch.isin(elements, test_elements)
|
||||
|
||||
|
||||
class StageMissingLayer(nn.Module):
|
||||
def __init__(self, stage_name: str, module: nn.Module | None = None) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.stage_name = stage_name
|
||||
|
||||
# Don't register this as a child module in order to
|
||||
# avoid missing keys when loading weights
|
||||
self.__dict__["module"] = module
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
return getattr(self.__dict__["module"], name)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise RuntimeError(f"{self} should not be called")
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return f"stage_name={self.stage_name!r}"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def collect_children(
|
||||
module: nn.Module,
|
||||
*,
|
||||
targets: type[nn.Module] | tuple[type[nn.Module], ...] | None = None,
|
||||
):
|
||||
"""
|
||||
Within this context, collect all direct child assignments to `module`,
|
||||
returning a list of children names that is internally updated until the
|
||||
context is exited.
|
||||
|
||||
If `targets` is set, instead collect descendents of `module`
|
||||
that are an instance of `targets`, even if they aren't direct children.
|
||||
"""
|
||||
children_names = list[str]()
|
||||
|
||||
if targets is None:
|
||||
|
||||
def hook(module_: nn.Module, name: str, submodule: nn.Module):
|
||||
if module_ is module:
|
||||
children_names.append(name)
|
||||
|
||||
with register_module_module_registration_hook(hook):
|
||||
yield children_names
|
||||
else:
|
||||
yield children_names
|
||||
|
||||
for name, module_ in module.named_modules():
|
||||
if isinstance(module_, targets):
|
||||
children_names.append(name)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def no_init_weights(
|
||||
module: nn.Module,
|
||||
placeholder: Callable[[nn.Module], nn.Module],
|
||||
*,
|
||||
targets: type[nn.Module] | tuple[type[nn.Module], ...] | None = None,
|
||||
):
|
||||
"""
|
||||
Within this context, prevent weight initialization from using device memory and
|
||||
replace direct child assignments to `module` with the result of `placeholder()`.
|
||||
|
||||
If `targets` is set, instead prevent weight initialization and
|
||||
replace assignments where the child is an instance of `targets`,
|
||||
even if they aren't direct children of `module`.
|
||||
"""
|
||||
if targets is None:
|
||||
|
||||
def hook(module_: nn.Module, name: str, submodule: nn.Module):
|
||||
if module_ is module:
|
||||
return placeholder(submodule)
|
||||
|
||||
return submodule
|
||||
|
||||
with register_module_module_registration_hook(hook), torch.device("meta"):
|
||||
yield
|
||||
else:
|
||||
|
||||
def hook(module_: nn.Module, name: str, submodule: nn.Module):
|
||||
if isinstance(module_, targets):
|
||||
submodule.to("meta") # Free memory
|
||||
if isinstance(submodule, targets):
|
||||
submodule.to("meta") # Free memory
|
||||
return placeholder(submodule)
|
||||
|
||||
return submodule
|
||||
|
||||
# Not all descendents are targeted, so we can't use a blanket
|
||||
# `torch.device("meta")` context
|
||||
with register_module_module_registration_hook(hook):
|
||||
yield
|
||||
|
||||
|
||||
class LayerFn(Protocol):
|
||||
def __call__(self, prefix: str) -> torch.nn.Module: ...
|
||||
|
||||
@@ -627,7 +724,7 @@ def get_pp_missing_layer_names(model: torch.nn.Module) -> list[str]:
|
||||
|
||||
missing_layer_names = []
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, PPMissingLayer):
|
||||
if isinstance(module, (StageMissingLayer, PPMissingLayer)):
|
||||
# NOTE: the trailing dot is used to match the prefix of the layer.
|
||||
# without the dot, we could match a layer that is not missing,
|
||||
# e.g., 'encoder.layer.1' would match 'encoder.layer.11'
|
||||
@@ -639,7 +736,7 @@ def get_pp_missing_layer_names(model: torch.nn.Module) -> list[str]:
|
||||
|
||||
def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
|
||||
"""Check if a parameter is missing in a pipeline parallel model."""
|
||||
if isinstance(model, PPMissingLayer):
|
||||
if isinstance(model, (StageMissingLayer, PPMissingLayer)):
|
||||
return True
|
||||
|
||||
return any(
|
||||
|
||||
@@ -909,7 +909,12 @@ class WhisperForConditionalGeneration(
|
||||
self.config = config
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
|
||||
self.model = WhisperModel(vllm_config=vllm_config, prefix=prefix)
|
||||
with self._mark_composite_model(
|
||||
vllm_config,
|
||||
language_targets=WhisperDecoder,
|
||||
tower_targets={"audio": WhisperEncoder},
|
||||
):
|
||||
self.model = WhisperModel(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
self.proj_out = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
@@ -937,9 +942,6 @@ class WhisperForConditionalGeneration(
|
||||
)
|
||||
return decoder_outputs
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.model.decoder
|
||||
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
# Required as part of SupportsMultiModal interface.
|
||||
audio_input = self._parse_and_validate_audio_input(**kwargs)
|
||||
|
||||
@@ -7,10 +7,10 @@ This is similar in concept to the `collections` module.
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Generator, Hashable, Iterable, Mapping
|
||||
from collections.abc import Callable, Generator, Hashable, Iterable, Mapping, Sequence
|
||||
from typing import Generic, Literal, TypeVar
|
||||
|
||||
from typing_extensions import TypeIs, assert_never
|
||||
from typing_extensions import TypeIs, assert_never, overload
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -74,6 +74,34 @@ def is_list_of(
|
||||
assert_never(check)
|
||||
|
||||
|
||||
@overload
|
||||
def common_prefix(items: Sequence[str]) -> str: ...
|
||||
|
||||
|
||||
@overload
|
||||
def common_prefix(items: Sequence[Sequence[T]]) -> Sequence[T]: ...
|
||||
|
||||
|
||||
def common_prefix(items: Sequence[Sequence[T] | str]) -> Sequence[T] | str:
|
||||
"""Find the longest prefix common to all items."""
|
||||
if len(items) == 0:
|
||||
return []
|
||||
if len(items) == 1:
|
||||
return items[0]
|
||||
|
||||
shortest = min(items, key=len)
|
||||
if not shortest:
|
||||
return shortest[:0]
|
||||
|
||||
for match_len in range(1, len(shortest) + 1):
|
||||
match = shortest[:match_len]
|
||||
for item in items:
|
||||
if item[:match_len] != match:
|
||||
return shortest[: match_len - 1]
|
||||
|
||||
return shortest
|
||||
|
||||
|
||||
def chunk_list(lst: list[T], chunk_size: int) -> Generator[list[T]]:
|
||||
"""Yield successive chunk_size chunks from lst."""
|
||||
for i in range(0, len(lst), chunk_size):
|
||||
|
||||
Reference in New Issue
Block a user