[3/N] Initialize MM components in context managers (I-L) (#32650)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2026-01-20 18:21:56 +08:00 (committed by GitHub)
Parent: 8be263c3fb
Commit: 7f1bcd18ff
10 changed files with 203 additions and 389 deletions
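
The same pattern is applied in every file below: vision towers and projectors are now constructed inside `self._mark_tower_model(vllm_config, modalities)`, and the text backbone inside `self._mark_language_model(vllm_config)`. Those helpers come from the shared multimodal interface introduced earlier in this series and are not part of this diff; the sketch below is only a hypothetical illustration of the context-manager shape they rely on, with all names and attributes assumed rather than quoted from vLLM.

```python
# Hypothetical sketch only: the real _mark_tower_model/_mark_language_model
# helpers live in vLLM's multimodal interfaces and are not shown in this diff.
# This illustrates how such marking context managers can be written.
from contextlib import contextmanager


class MarkMMComponentsSketch:
    @contextmanager
    def _mark_tower_model(self, vllm_config, modalities):
        # Accept a single modality ("image") or a set ({"image", "video"}).
        mods = {modalities} if isinstance(modalities, str) else set(modalities)
        previous = getattr(self, "_initializing_component", None)
        self._initializing_component = ("tower", frozenset(mods))
        try:
            yield  # submodules built here are attributed to the tower/projector
        finally:
            self._initializing_component = previous

    @contextmanager
    def _mark_language_model(self, vllm_config):
        previous = getattr(self, "_initializing_component", None)
        self._initializing_component = ("language_model", None)
        try:
            yield  # the text backbone is built here
        finally:
            self._initializing_component = previous
```

Usage in the diffs then reads, for example, `with self._mark_tower_model(vllm_config, {"image", "video"}): self.vision_tower = self._init_vision_model(...)`.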

View File

@@ -547,20 +547,20 @@ class InternS1ForConditionalGeneration(
)
self.downsample_ratio = config.downsample_ratio
self.llm_arch_name = config.text_config.architectures[0]
self.vision_tower = self._init_vision_model(
config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "vision_tower"),
)
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.vision_tower = self._init_vision_model(
config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "vision_tower"),
)
self.multi_modal_projector = self._init_mlp1(config)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.multi_modal_projector = self._init_mlp1(config)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.img_context_token_id = None
self.video_context_token_id = None
@@ -699,8 +699,6 @@ class InternS1ForConditionalGeneration(
):
return image_input["data"]
assert self.vision_tower is not None
image_embeds = self.extract_feature(image_input["pixel_values"])
num_patches = image_input["num_patches"]
@@ -737,9 +735,6 @@ class InternS1ForConditionalGeneration(
def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
self.visual_token_mask = None
def get_language_model(self) -> torch.nn.Module:
return self.language_model
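
The `get_language_model` override removed here is deleted in every touched model (along with Isaac's backward-compatible `get_multimodal_embeddings`). The sketch below assumes, as an illustration rather than a quotation of the interface, that the shared `SupportsMultiModal` base now supplies a default accessor returning `self.language_model`, which makes the per-model copies redundant.

```python
# Illustrative assumption: a default accessor on the shared interface makes
# per-model get_language_model overrides unnecessary once each model assigns
# self.language_model in __init__.
from torch import nn


class SupportsMultiModalSketch:
    language_model: nn.Module

    def get_language_model(self) -> nn.Module:
        # Subclasses only set self.language_model; no override required.
        return self.language_model
```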
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:

View File

@@ -1092,22 +1092,24 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
self.downsample_ratio = config.downsample_ratio
self.ps_version = config.ps_version
self.llm_arch_name = config.text_config.architectures[0]
self.is_mono = self.llm_arch_name == "InternLM2VEForCausalLM"
self.vision_model = self._init_vision_model(
config,
quant_config=quant_config,
is_mono=self.is_mono,
prefix=maybe_prefix(prefix, "vision_model"),
)
llm_arch_name = config.text_config.architectures[0]
self.is_mono = llm_arch_name == "InternLM2VEForCausalLM"
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.vision_model = self._init_vision_model(
config,
quant_config=quant_config,
is_mono=self.is_mono,
prefix=maybe_prefix(prefix, "vision_model"),
)
self.mlp1 = self._init_mlp1(config)
self.mlp1 = self._init_mlp1(config)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.img_context_token_id = None
self.video_context_token_id = None
@@ -1281,8 +1283,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
):
return image_input["data"]
assert self.vision_model is not None
image_embeds = self.extract_feature(image_input["pixel_values_flat"])
num_patches = image_input["num_patches"]
@@ -1325,9 +1325,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
else:
self.visual_token_mask = None
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:

View File

@@ -1342,11 +1342,14 @@ class IsaacForConditionalGeneration(
"mrope_interleaved", rope_scaling["mrope_interleaved"]
)
target_cfg.rope_parameters = rope_parameters
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
architectures=["Qwen3ForCausalLM"],
prefix=maybe_prefix(prefix, "language_model"),
)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
architectures=["Qwen3ForCausalLM"],
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)
@@ -1363,14 +1366,16 @@ class IsaacForConditionalGeneration(
vision_cfg._attn_implementation = attn_impl
hidden_dim = vision_cfg.hidden_size * (vision_cfg.pixel_shuffle_scale_factor**2)
self.vision_embedding = IsaacVisionEmbedding(
vision_cfg=vision_cfg,
hidden_dim=hidden_dim,
output_dim=config.hidden_size,
quant_config=quant_config,
multimodal_config=self.multimodal_config,
prefix=maybe_prefix(prefix, "vision_embedding"),
)
with self._mark_tower_model(vllm_config, "image"):
self.vision_embedding = IsaacVisionEmbedding(
vision_cfg=vision_cfg,
hidden_dim=hidden_dim,
output_dim=config.hidden_size,
quant_config=quant_config,
multimodal_config=self.multimodal_config,
prefix=maybe_prefix(prefix, "vision_embedding"),
)
def iter_mm_grid_hw(
self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec]
@@ -1457,18 +1462,6 @@ class IsaacForConditionalGeneration(
return ()
return self._process_image_input(image_input)
def get_multimodal_embeddings(
self, **kwargs: object
) -> MultiModalEmbeddings | None:
# Backward compatibility for older runners.
embeddings = self.embed_multimodal(**kwargs)
if not embeddings:
return []
return embeddings
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def forward(
self,
input_ids: torch.Tensor,

View File

@@ -586,16 +586,21 @@ class KananaVForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
config = vllm_config.model_config.hf_config
self.config = config
self.vision_model = CustomQwen2VLVE._from_config(config.vision_config)
self.abstractor = DynamicCAbstractor(
config.projector_config, num_input_tokens=self.vision_model.get_num_tokens()
)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "model"),
architectures=["LlamaForCausalLM"],
)
with self._mark_tower_model(vllm_config, "image"):
self.vision_model = CustomQwen2VLVE._from_config(config.vision_config)
self.abstractor = DynamicCAbstractor(
config.projector_config,
num_input_tokens=self.vision_model.get_num_tokens(),
)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "model"),
architectures=["LlamaForCausalLM"],
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)
@@ -718,9 +723,6 @@ class KananaVForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
visual_embeds = self.forward_projector(visual_features, image_metas=image_metas)
return visual_embeds
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:

View File

@@ -1242,7 +1242,7 @@ class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]):
return _keye_field_config(hf_inputs)
class BaseKeyeModule(nn.Module):
class BaseKeyeModule(nn.Module, SupportsMultiModal):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -1280,25 +1280,26 @@ class BaseKeyeModule(nn.Module):
self.config = config
self.multimodal_config = multimodal_config
self.visual = KeyeSiglipVisionModel(
config.vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
)
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual = KeyeSiglipVisionModel(
config.vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
)
self.mlp_AR = self._build_projector(
config,
config.vision_config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "mlp_AR"),
)
self.mlp_AR = self._build_projector(
config,
config.vision_config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "mlp_AR"),
)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["Qwen3ForCausalLM"],
)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["Qwen3ForCausalLM"],
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
@@ -1312,7 +1313,7 @@ class BaseKeyeModule(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> nn.Module:
raise ValueError("Need projector")
raise NotImplementedError("Need projector")
def _process_image_input(self, image_input: Any) -> tuple[torch.Tensor, ...]:
siglip_position_ids = list()
@@ -1429,9 +1430,6 @@ class BaseKeyeModule(nn.Module):
return modalities
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:

View File

@@ -42,7 +42,6 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import copy
import math
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
@@ -50,23 +49,12 @@ from typing import Annotated, Any, Literal
import torch
from torch import nn
from transformers import BatchFeature, DeepseekV2Config
from transformers import BatchFeature
from transformers.activations import GELUActivation
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.deepseek_v2 import DeepseekV2Model
from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP
from vllm.model_executor.models.moonvit import MoonVitPretrainedModel
from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -92,7 +80,7 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .utils import PPMissingLayer, is_pp_missing_parameter, maybe_prefix
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
from .vision import run_dp_sharded_mrope_vision_model
@@ -315,48 +303,41 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
super().__init__()
model_config = vllm_config.model_config
config: KimiVLConfig = model_config.hf_config
self.config = config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
assert isinstance(config.vision_config, MoonViTConfig)
self.use_data_parallel = (
model_config.multimodal_config.mm_encoder_tp_mode == "data"
)
self.hidden_size = config.text_config.hidden_size
self.vision_tower = MoonVitPretrainedModel(
config.vision_config,
multimodal_config=model_config.multimodal_config,
prefix=maybe_prefix(prefix, "vision_tower"),
)
self.multi_modal_projector = KimiVLMultiModalProjector(
config=config,
use_data_parallel=self.use_data_parallel,
prefix=maybe_prefix(prefix, "multi_modal_projector"),
)
self.quant_config = quant_config
sub_vllm_config = copy.deepcopy(vllm_config)
sub_vllm_config.model_config.hf_config = (
sub_vllm_config.model_config.hf_config.text_config
)
self.language_model = DeepseekV2Model(
vllm_config=sub_vllm_config,
prefix=maybe_prefix(prefix, "language_model"),
)
if get_pp_group().is_last_rank:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.text_config.hidden_size,
prefix=maybe_prefix(prefix, "lm_head"),
with self._mark_tower_model(vllm_config, "image"):
self.vision_tower = MoonVitPretrainedModel(
config.vision_config,
multimodal_config=model_config.multimodal_config,
prefix=maybe_prefix(prefix, "vision_tower"),
)
else:
self.lm_head = PPMissingLayer()
self.multi_modal_projector = KimiVLMultiModalProjector(
config=config,
use_data_parallel=self.use_data_parallel,
prefix=maybe_prefix(prefix, "multi_modal_projector"),
)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["DeepseekV2ForCausalLM"],
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale)
self.media_placeholder: int = self.config.media_placeholder_token_id
def _parse_and_validate_image_input(
@@ -378,8 +359,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
# Run the vision tower on the processed pixel_values
@torch.inference_mode()
def _process_image_pixels(self, inputs: KimiVLImagePixelInputs) -> torch.Tensor:
assert self.vision_tower is not None
pixel_values = inputs["pixel_values"]
image_grid_hws = inputs["image_grid_hws"]
if self.use_data_parallel:
@@ -399,9 +378,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
lengths = [x.shape[0] for x in image_features]
return self.multi_modal_projector(torch.cat(image_features)).split(lengths)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> NestedTensors | None:
# Validate the multimodal input keyword arguments
image_input = self._parse_and_validate_image_input(**kwargs)
@@ -433,145 +409,8 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states, **kwargs)
return logits
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
config = self.config.text_config
_KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head",
"language_model.model": "language_model",
}
# only doing this for language model part for now.
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
use_mha = (
config.model_type == "deepseek"
or config.qk_nope_head_dim + config.qk_rope_head_dim == 0
)
if use_mha:
stacked_params_mapping += [
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
]
if getattr(config, "n_routed_experts", None):
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
self,
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=config.n_routed_experts,
)
else:
expert_params_mapping = []
params_dict = dict(self.named_parameters())
for args in weights:
name, loaded_weight = args[:2]
kwargs = args[2] if len(args) > 2 else {}
if "rotary_emb.inv_freq" in name:
continue
spec_layer = get_spec_layer_idx_from_weight_name(config, name)
if spec_layer is not None:
continue # skip spec decode layers for main model
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
use_default_weight_loading = False
if "vision" in name:
if self.vision_tower is not None:
# We only do sharding for language model and
# not vision model for now.
use_default_weight_loading = True
else:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id, **kwargs)
break
else:
for idx, (
param_name,
weight_name,
expert_id,
shard_id,
) in enumerate(expert_params_mapping):
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name,
expert_id=expert_id,
shard_id=shard_id,
**kwargs,
)
break
else:
use_default_weight_loading = True
if use_default_weight_loading:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight, **kwargs)
def get_spec_layer_idx_from_weight_name(
config: DeepseekV2Config, weight_name: str
) -> int | None:
if hasattr(config, "num_nextn_predict_layers") and (
config.num_nextn_predict_layers > 0
):
layer_idx = config.num_hidden_layers
for i in range(config.num_nextn_predict_layers):
if weight_name.startswith(f"model.layers.{layer_idx + i}."):
return layer_idx + i
return None
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
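
The KimiVL change is the largest in the commit: the hand-rolled text backbone (a deep-copied sub-config feeding `DeepseekV2Model`, plus a local `ParallelLMHead`, `LogitsProcessor`, and the long bespoke `load_weights` above) is replaced by a registered `DeepseekV2ForCausalLM` built through `init_vllm_registered_model`, with weight loading handed to `AutoWeightsLoader`. The stand-in modules below show why the wrapper no longer needs its own head: once the language model is a full `*ForCausalLM`, logits computation is simply forwarded. These stubs are illustrative, not vLLM code.

```python
# Stand-in modules (not vLLM source) showing the delegation the diff moves to:
# the composite model forwards compute_logits to a language model that already
# owns its LM head, so no local lm_head/LogitsProcessor is required.
import torch
from torch import nn


class CausalLMStub(nn.Module):
    """Plays the role of a registered *ForCausalLM with its own LM head."""

    def __init__(self, hidden_size: int, vocab_size: int) -> None:
        super().__init__()
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.lm_head(hidden_states)


class MultiModalWrapperStub(nn.Module):
    """Composite model that owns no head and only delegates."""

    def __init__(self, hidden_size: int = 16, vocab_size: int = 32) -> None:
        super().__init__()
        self.language_model = CausalLMStub(hidden_size, vocab_size)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.language_model.compute_logits(hidden_states)
```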

View File

@@ -546,38 +546,37 @@ class Lfm2VLForConditionalGeneration(
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
if vision_config.model_type == "siglip2_vision_model":
self.vision_tower = Siglip2Model(
config=vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_tower"),
)
else:
raise ValueError(
f"Unsupported visual tokenizer model_type: {vision_config.model_type}"
with self._mark_tower_model(vllm_config, "image"):
if vision_config.model_type == "siglip2_vision_model":
self.vision_tower = Siglip2Model(
config=vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_tower"),
)
else:
raise ValueError(
f"Unsupported visual tokenizer type: {vision_config.model_type}"
)
self.multi_modal_projector = Lfm2VLMultiModalProjector(
config=config,
use_data_parallel=self.use_data_parallel,
prefix=maybe_prefix(prefix, "multi_modal_projector"),
)
self.multi_modal_projector = Lfm2VLMultiModalProjector(
config=config,
use_data_parallel=self.use_data_parallel,
prefix=f"{prefix}.multi_modal_projector",
)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language"),
architectures=config.text_config.architectures,
)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language"),
architectures=config.text_config.architectures,
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def _parse_and_validate_image_input(
self, **kwargs: object
) -> LFM2VLImageInputs | None:
@@ -714,8 +713,7 @@ class Lfm2VLForConditionalGeneration(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
logits = self.language_model.compute_logits(hidden_states)
return logits
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)

View File

@@ -268,27 +268,30 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
self.config = config
self.multimodal_config = multimodal_config
# TODO: Optionally initialize this to support embedding models.
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"),
)
self.image_newline = nn.Parameter(torch.empty(config.text_config.hidden_size))
self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=vision_hidden_size,
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act,
multimodal_projector_bias=config.multimodal_projector_bias,
)
with self._mark_tower_model(vllm_config, "image"):
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"),
)
self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size)
)
self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=vision_hidden_size,
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act,
multimodal_projector_bias=config.multimodal_projector_bias,
)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
@@ -427,8 +430,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
self,
inputs: LlavaNextImagePixelInputs,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
assert self.vision_tower is not None
pixel_values = inputs["pixel_values"]
if isinstance(pixel_values, torch.Tensor):
@@ -480,9 +481,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsP
for i, patch_features_batch in enumerate(patch_embeddings)
]
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:

View File

@@ -312,12 +312,10 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"):
return "<image>"
if modality.startswith("video"):
return "<video>"
raise ValueError("Only image or video modality is supported")
raise ValueError("Only video modality is supported")
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
@@ -329,26 +327,29 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
self.config = config
self.multimodal_config = multimodal_config
# Initialize the vision tower only up to the required feature layer
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"),
)
self.vision_resampler = LlavaNextVideoPooler(config)
self.multi_modal_projector = LlavaNextMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act,
multimodal_projector_bias=config.multimodal_projector_bias,
)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
with self._mark_tower_model(vllm_config, "video"):
# Initialize the vision tower only up to the required feature layer
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"),
)
self.vision_resampler = LlavaNextVideoPooler(config)
self.multi_modal_projector = LlavaNextMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act,
multimodal_projector_bias=config.multimodal_projector_bias,
)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = (
self.language_model.model.make_empty_intermediate_tensors
@@ -395,8 +396,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
return image_features
def _process_video_pixels(self, inputs: LlavaNextVideoPixelInputs):
assert self.vision_tower is not None
video_pixels = inputs["pixel_values_videos"]
if isinstance(video_pixels, torch.Tensor):
@@ -419,9 +418,6 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
return [e.flatten(0, 1) for e in embeds]
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
video_input = self._parse_and_validate_video_input(**kwargs)
if video_input is None:

View File

@@ -508,21 +508,26 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
self.config = config
self.multimodal_config = multimodal_config
# Initialize the vision tower only up to the required feature layer
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"),
)
self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.image_newline = nn.Parameter(torch.empty(config.text_config.hidden_size))
with self._mark_tower_model(vllm_config, {"image", "video"}):
# Initialize the vision tower only up to the required feature layer
self.vision_tower = init_vision_tower_for_llava(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"),
)
self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size)
)
self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = (
self.language_model.model.make_empty_intermediate_tensors
@@ -726,8 +731,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
self,
inputs: LlavaOnevisionImagePixelInputs,
) -> torch.Tensor | list[torch.Tensor]:
assert self.vision_tower is not None
pixel_values = inputs["pixel_values"]
if isinstance(pixel_values, torch.Tensor):
@@ -801,8 +804,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
return video_features
def _process_video_pixels(self, inputs: LlavaOnevisionVideoPixelInputs):
assert self.vision_tower is not None
video_pixels = inputs["pixel_values_videos"]
if isinstance(video_pixels, torch.Tensor):
@@ -862,9 +863,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, Supp
image_feature = image_feature.view(batch_frames, -1, dim)
return image_feature
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
if not mm_input_by_modality: