diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py
index d49cf8850..f62d793ef 100644
--- a/vllm/model_executor/models/aria.py
+++ b/vllm/model_executor/models/aria.py
@@ -621,14 +621,8 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
         inputs_embeds: torch.Tensor | None = None,
         **kwargs: object,
     ) -> torch.Tensor | IntermediateTensors:
-        if inputs_embeds is None:
-            multimodal_embeddings = self.embed_multimodal(**kwargs)
-            inputs_embeds = self.embed_input_ids(
-                input_ids,
-                multimodal_embeddings,
-                is_multimodal=input_ids == self.config.image_token_index,
-            )
-            input_ids = None
+        if intermediate_tensors is not None:
+            inputs_embeds = None
 
         hidden_states = self.language_model(
             input_ids,
diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py
index 90658445f..ca360210b 100644
--- a/vllm/model_executor/models/dots_ocr.py
+++ b/vllm/model_executor/models/dots_ocr.py
@@ -791,14 +791,6 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
     ) -> torch.Tensor | IntermediateTensors:
         if intermediate_tensors is not None:
             inputs_embeds = None
-        elif inputs_embeds is None:
-            vision_embeddings = self.embed_multimodal(**kwargs)
-            inputs_embeds = self.embed_input_ids(
-                input_ids,
-                vision_embeddings,
-                is_multimodal=input_ids == self.config.image_token_id,
-            )
-            input_ids = None
 
         hidden_states = self.language_model(
             input_ids=input_ids,
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index b0be74d24..809395cf3 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -71,8 +71,6 @@ def _require_is_multimodal(is_multimodal: Tensor | None) -> Tensor:
 
 
 class LMMissingLayer(nn.Module):
-    packed_modules_mapping: dict[str, list[str]] = {}
-
     def make_empty_intermediate_tensors(self, *args, **kwargs):
         raise RuntimeError("This module should not be called in MM encoder-only mode")
 
@@ -81,8 +79,6 @@ class LMMissingLayer(nn.Module):
 
 
 class TowerMissingLayer(nn.Module):
-    packed_modules_mapping: dict[str, list[str]] = {}
-
     def __init__(self, modalities: set[str] | str) -> None:
         if isinstance(modalities, str):
             modalities = {modalities}
@@ -92,7 +88,10 @@ class TowerMissingLayer(nn.Module):
         self.modalities = modalities
 
     def __call__(self, *args, **kwargs):
-        raise RuntimeError(f"The following modalities are disabled: {self.modalities}")
+        raise RuntimeError(
+            f"This module should not be called when the following "
+            f"modalities are disabled: {self.modalities}"
+        )
 
 
 @contextmanager
diff --git a/vllm/model_executor/models/interns1.py b/vllm/model_executor/models/interns1.py
index 441328f5e..13479e306 100644
--- a/vllm/model_executor/models/interns1.py
+++ b/vllm/model_executor/models/interns1.py
@@ -789,7 +789,6 @@ class InternS1ForConditionalGeneration(
         **kwargs: object,
     ) -> IntermediateTensors:
         if intermediate_tensors is not None:
-            input_ids = None
             inputs_embeds = None
 
         forward_kwargs = {
diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py
index 9b7289419..fedbae445 100644
--- a/vllm/model_executor/models/internvl.py
+++ b/vllm/model_executor/models/internvl.py
@@ -1379,7 +1379,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
         **kwargs: object,
     ) -> IntermediateTensors:
         if intermediate_tensors is not None:
-            input_ids = None
             inputs_embeds = None
 
         forward_kwargs = {
diff --git a/vllm/model_executor/models/midashenglm.py b/vllm/model_executor/models/midashenglm.py
index 9b243f832..8d89af52c 100644
--- a/vllm/model_executor/models/midashenglm.py
+++ b/vllm/model_executor/models/midashenglm.py
@@ -707,30 +707,30 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
         self.config = config
-
-        # Initialize audio components
-        self.audio_encoder = DashengAudioTransformer(
-            config.audio_encoder_config,
-            quant_config=quant_config,
-            prefix=maybe_prefix(prefix, "audio_encoder"),
-        )
-        self.audio_projector = AudioProjectorSubsample(
-            in_dim=config.audio_encoder_config.embed_dim,
-            out_dim=config.text_config.hidden_size,
-            downsample_rate=config.subsample_factor,
-            quant_config=quant_config,
-            prefix=maybe_prefix(prefix, "audio_projector"),
-        )
-
-        # Initialize language model (decoder)
-        self.decoder = init_vllm_registered_model(
-            vllm_config=vllm_config,
-            hf_config=config.text_config,
-            prefix=maybe_prefix(prefix, "decoder"),
-            architectures=["Qwen2ForCausalLM"],
-        )
-        self.quant_config = quant_config
+
+        with self._mark_tower_model(vllm_config, "audio"):
+            self.audio_encoder = DashengAudioTransformer(
+                config.audio_encoder_config,
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, "audio_encoder"),
+            )
+            self.audio_projector = AudioProjectorSubsample(
+                in_dim=config.audio_encoder_config.embed_dim,
+                out_dim=config.text_config.hidden_size,
+                downsample_rate=config.subsample_factor,
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, "audio_projector"),
+            )
+
+        with self._mark_language_model(vllm_config):
+            self.decoder = init_vllm_registered_model(
+                vllm_config=vllm_config,
+                hf_config=config.text_config,
+                prefix=maybe_prefix(prefix, "decoder"),
+                architectures=["Qwen2ForCausalLM"],
+            )
+
         self.make_empty_intermediate_tensors = (
             self.decoder.make_empty_intermediate_tensors
         )
@@ -787,9 +787,6 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
 
         return torch.split(masked_audio_features, audio_output_lengths.tolist())
 
-    def get_language_model(self) -> torch.nn.Module:
-        return self.decoder
-
     def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         audio_input = self._parse_and_validate_audio_input(**kwargs)
diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py
index fa2feb0ba..c77ccca0a 100644
--- a/vllm/model_executor/models/minicpmo.py
+++ b/vllm/model_executor/models/minicpmo.py
@@ -553,9 +553,11 @@ class MiniCPMO(MiniCPMV2_6):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__(vllm_config=vllm_config, prefix=prefix)
-        self.apm = self.init_audio_module(
-            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "apm")
-        )
+
+        with self._mark_tower_model(vllm_config, "audio"):
+            self.apm = self.init_audio_module(
+                vllm_config=vllm_config, prefix=maybe_prefix(prefix, "apm")
+            )
 
     def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""):
         # Do not use parameters temporarily
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index 499d247bc..de76e9abf 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ b/vllm/model_executor/models/minicpmv.py
@@ -1028,25 +1028,27 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
         self.multimodal_config = multimodal_config
 
         self.version = get_version_by_config(self.config)
-        self.llm = self.init_llm(
-            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "llm")
-        )
-        self.vpm = self.init_vision_module(
-            config, quant_config, prefix=maybe_prefix(prefix, "vpm")
-        )
-        self.vision_dim = (
-            self.vpm.embed_dim
-            if self.version == (2, 0)
-            else self.vpm.embeddings.embed_dim
-        )
-        self.embed_dim = self.config.hidden_size
-        self.resampler = self.init_resampler(
-            self.embed_dim,
-            self.vision_dim,
-            quant_config=quant_config,
-            prefix=maybe_prefix(prefix, "resampler"),
-        )
+        with self._mark_language_model(vllm_config):
+            self.llm = self.init_llm(
+                vllm_config=vllm_config, prefix=maybe_prefix(prefix, "llm")
+            )
+
+        with self._mark_tower_model(vllm_config, {"image", "video"}):
+            self.vpm = vpm = self.init_vision_module(
+                config, quant_config, prefix=maybe_prefix(prefix, "vpm")
+            )
+            self.vision_dim = (
+                vpm.embed_dim if self.version == (2, 0) else vpm.embeddings.embed_dim
+            )
+            self.embed_dim = self.config.hidden_size
+
+            self.resampler = self.init_resampler(
+                self.embed_dim,
+                self.vision_dim,
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, "resampler"),
+            )
 
         self.make_empty_intermediate_tensors = self.llm.make_empty_intermediate_tensors
 
@@ -1134,9 +1136,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
 
         return multimodal_embeddings
 
-    def get_language_model(self) -> torch.nn.Module:
-        return self.llm
-
     def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
         if not modalities:
diff --git a/vllm/model_executor/models/minimax_vl_01.py b/vllm/model_executor/models/minimax_vl_01.py
index b4a496dcb..513c46265 100644
--- a/vllm/model_executor/models/minimax_vl_01.py
+++ b/vllm/model_executor/models/minimax_vl_01.py
@@ -201,28 +201,33 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
         self.config = config
         self.multimodal_config = multimodal_config
 
-        # TODO: Optionally initializes this for supporting embeddings.
-        self.vision_tower = init_vision_tower_for_llava(
-            config,
-            quant_config=quant_config,
-            multimodal_config=multimodal_config,
-            require_post_norm=False,
-            prefix=maybe_prefix(prefix, "vision_tower"),
-        )
-        self.multi_modal_projector = MiniMaxVL01MultiModalProjector(
-            vision_hidden_size=config.vision_config.hidden_size,
-            text_hidden_size=config.text_config.hidden_size,
-            projector_hidden_act=config.projector_hidden_act,
-            multimodal_projector_bias=True,
-            quant_config=quant_config,
-            prefix=maybe_prefix(prefix, "multi_modal_projector"),
-        )
-        self.image_newline = nn.Parameter(torch.empty(config.text_config.hidden_size))
-        self.language_model = init_vllm_registered_model(
-            vllm_config=vllm_config,
-            hf_config=config.text_config,
-            prefix=maybe_prefix(prefix, "language_model"),
-        )
+        with self._mark_tower_model(vllm_config, "image"):
+            self.vision_tower = init_vision_tower_for_llava(
+                config,
+                quant_config=quant_config,
+                multimodal_config=multimodal_config,
+                require_post_norm=False,
+                prefix=maybe_prefix(prefix, "vision_tower"),
+            )
+            self.multi_modal_projector = MiniMaxVL01MultiModalProjector(
+                vision_hidden_size=config.vision_config.hidden_size,
+                text_hidden_size=config.text_config.hidden_size,
+                projector_hidden_act=config.projector_hidden_act,
+                multimodal_projector_bias=True,
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, "multi_modal_projector"),
+            )
+            self.image_newline = nn.Parameter(
+                torch.empty(config.text_config.hidden_size)
+            )
+
+        with self._mark_language_model(vllm_config):
+            self.language_model = init_vllm_registered_model(
+                vllm_config=vllm_config,
+                hf_config=config.text_config,
+                prefix=maybe_prefix(prefix, "language_model"),
+            )
+
         self.vision_feature_layer = config.vision_feature_layer
         self.vocab_size = config.text_config.vocab_size
         self.pad_token_id = -1
@@ -233,9 +238,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
             self.language_model.make_empty_intermediate_tensors
         )
 
-    def get_language_model(self) -> torch.nn.Module:
-        return self.language_model
-
     def _image_pixels_to_features(
         self,
         vision_tower: CLIPVisionModel | SiglipVisionModel | PixtralHFVisionModel,
@@ -302,8 +304,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
         self,
         inputs: MiniMaxVL01ImagePixelInputs,
     ) -> torch.Tensor | tuple[torch.Tensor, ...]:
-        assert self.vision_tower is not None
-
         pixel_values = inputs["pixel_values"]
         return self._image_pixels_to_features(self.vision_tower, pixel_values)
 
@@ -314,7 +314,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
         if image_input["type"] == "image_embeds":
             return image_input["data"]
 
-        assert self.vision_tower is not None
         image_features = self._process_image_pixels(image_input)
 
         if isinstance(image_features, torch.Tensor):
@@ -369,14 +368,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal, Support
     ) -> torch.Tensor | IntermediateTensors:
         if intermediate_tensors is not None:
             inputs_embeds = None
-        elif inputs_embeds is None:
-            vision_embeddings = self.embed_multimodal(**kwargs)
-            inputs_embeds = self.embed_input_ids(
-                input_ids,
-                vision_embeddings,
-                is_multimodal=input_ids == self.config.image_token_index,
-            )
-            input_ids = None
 
         hidden_states = self.language_model.model(
             input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py
index e30bc6889..6864279ed 100644
--- a/vllm/model_executor/models/molmo.py
+++ b/vllm/model_executor/models/molmo.py
@@ -1441,15 +1441,20 @@ class MolmoForCausalLM(
         self.multimodal_config = multimodal_config
 
         vision_config = VisionBackboneConfig()
-        self.vision_backbone = MolmoVisionBackbone(
-            config,
-            vision_config,
-            quant_config,
-            prefix=maybe_prefix(prefix, "vision_backbone"),
-        )
-        self.model = MolmoModel(
-            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
-        )
+
+        with self._mark_tower_model(vllm_config, "image"):
+            self.vision_backbone = MolmoVisionBackbone(
+                config,
+                vision_config,
+                quant_config,
+                prefix=maybe_prefix(prefix, "vision_backbone"),
+            )
+
+        with self._mark_language_model(vllm_config):
+            self.model = MolmoModel(
+                vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
+            )
+
         self.img_patch_id = None
 
         if self.config.weight_tying:
@@ -1525,9 +1530,6 @@ class MolmoForCausalLM(
             results.append(feats[is_valid][order])
         return results
 
-    def get_language_model(self) -> torch.nn.Module:
-        return self.model
-
     def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
diff --git a/vllm/model_executor/models/molmo2.py b/vllm/model_executor/models/molmo2.py
index 28bb3d94d..b725d023f 100644
--- a/vllm/model_executor/models/molmo2.py
+++ b/vllm/model_executor/models/molmo2.py
@@ -2514,16 +2514,19 @@ class Molmo2ForConditionalGeneration(
             kwargs[field.name] = getattr(config.adapter_config, field.name)
         adapter_config = AdapterConfig(**kwargs)
 
-        self.vision_backbone = Molmo2VisionBackbone(
-            vit_config,
-            adapter_config,
-            quant_config,
-            prefix=maybe_prefix(prefix, "vision_backbone"),
-        )
-        self.model = Molmo2TextModel(
-            vllm_config=vllm_config,
-            prefix=maybe_prefix(prefix, "model"),
-        )
+        with self._mark_tower_model(vllm_config, {"image", "video"}):
+            self.vision_backbone = Molmo2VisionBackbone(
+                vit_config,
+                adapter_config,
+                quant_config,
+                prefix=maybe_prefix(prefix, "vision_backbone"),
+            )
+
+        with self._mark_language_model(vllm_config):
+            self.model = Molmo2TextModel(
+                vllm_config=vllm_config,
+                prefix=maybe_prefix(prefix, "model"),
+            )
 
         self.img_patch_id = config.image_patch_id
 
@@ -2687,9 +2690,6 @@ class Molmo2ForConditionalGeneration(
             out.append(out_features)
         return tuple(out)
 
-    def get_language_model(self) -> torch.nn.Module:
-        return self.model
-
     def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
         modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
         if not modalities:
diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py
index 82a458426..c9f51db27 100644
--- a/vllm/model_executor/models/nano_nemotron_vl.py
+++ b/vllm/model_executor/models/nano_nemotron_vl.py
@@ -1511,34 +1511,38 @@ class NemotronH_Nano_VL_V2(
         self.ps_version = config.ps_version
         self.image_tag_type = config.image_tag_type
         self.video_pruning_rate = multimodal_config.video_pruning_rate
-        self.language_model = init_vllm_registered_model(
-            vllm_config=vllm_config,
-            hf_config=config.text_config,
-            prefix=maybe_prefix(prefix, "language_model"),
-        )
-        self.vision_model = self.get_vit_model_from_radio_config(config).to(
-            self.language_model.config.dtype
-        )
-
-        # Construct the vision projection.
-        vit_hidden_size = config.vit_hidden_size
-        vision_projection_hidden_size = config.projector_hidden_size
-        llm_hidden_size = config.text_config.hidden_size
+        with self._mark_language_model(vllm_config):
+            self.language_model = language_model = init_vllm_registered_model(
+                vllm_config=vllm_config,
+                hf_config=config.text_config,
+                prefix=maybe_prefix(prefix, "language_model"),
+            )
 
-        self.mlp1 = nn.Sequential(
-            RMSNorm(
-                hidden_size=vit_hidden_size * int(1 / self.downsample_ratio) ** 2,
-                eps=1e-5,
-            ),
-            nn.Linear(
-                vit_hidden_size * int(1 / self.downsample_ratio) ** 2,
-                vision_projection_hidden_size,
-                bias=False,
-            ),
-            ReLUSquaredActivation(),
-            nn.Linear(vision_projection_hidden_size, llm_hidden_size, bias=False),
-        )
-        self.mlp1 = self.mlp1.to(self.language_model.config.dtype)
+        with self._mark_tower_model(vllm_config, {"image", "video"}):
+            self.vision_model = self.get_vit_model_from_radio_config(config).to(
+                self.language_model.config.dtype
+            )
+
+            # Construct the vision projection.
+            vit_hidden_size = config.vit_hidden_size
+            vision_projection_hidden_size = config.projector_hidden_size
+            llm_hidden_size = config.text_config.hidden_size
+
+            mlp1 = nn.Sequential(
+                RMSNorm(
+                    hidden_size=vit_hidden_size * int(1 / self.downsample_ratio) ** 2,
+                    eps=1e-5,
+                ),
+                nn.Linear(
+                    vit_hidden_size * int(1 / self.downsample_ratio) ** 2,
+                    vision_projection_hidden_size,
+                    bias=False,
+                ),
+                ReLUSquaredActivation(),
+                nn.Linear(vision_projection_hidden_size, llm_hidden_size, bias=False),
+            )
+            self.mlp1 = mlp1.to(language_model.config.dtype)
 
         self.config = config
         self.model_config = vllm_config.model_config
@@ -1909,9 +1913,6 @@ class NemotronH_Nano_VL_V2(
 
         return multimodal_embeddings
 
-    def get_language_model(self) -> torch.nn.Module:
-        return self.language_model
-
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -1921,7 +1922,6 @@ class NemotronH_Nano_VL_V2(
         **kwargs: object,
     ) -> torch.Tensor | IntermediateTensors:
         if intermediate_tensors is not None:
-            input_ids = None
             inputs_embeds = None
 
         hidden_states = self.language_model(
diff --git a/vllm/model_executor/models/nemotron_parse.py b/vllm/model_executor/models/nemotron_parse.py
index 505ba35c6..8f66bb897 100644
--- a/vllm/model_executor/models/nemotron_parse.py
+++ b/vllm/model_executor/models/nemotron_parse.py
@@ -820,16 +820,18 @@ class NemotronParseForConditionalGeneration(nn.Module, SupportsMultiModal):
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
 
-        self.encoder = RadioWithNeck(
-            config=config, quant_config=quant_config, prefix=f"{prefix}.encoder"
-        )
+        with self._mark_tower_model(vllm_config, "image"):
+            self.encoder = RadioWithNeck(
+                config=config, quant_config=quant_config, prefix=f"{prefix}.encoder"
+            )
 
-        self.decoder = MBartDecoderNoPos(
-            config.decoder,
-            cache_config=cache_config,
-            quant_config=quant_config,
-            prefix=f"{prefix}.decoder",
-        )
+        with self._mark_language_model(vllm_config):
+            self.decoder = MBartDecoderNoPos(
+                config.decoder,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.decoder",
+            )
 
         self.vocab_size = config.decoder.vocab_size
         self.lm_head = ParallelLMHead(
@@ -883,9 +885,6 @@ class NemotronParseForConditionalGeneration(nn.Module, SupportsMultiModal):
             pixel_values = pixel_values.to(dtype)
         return self.encoder(pixel_values)
 
-    def get_language_model(self) -> torch.nn.Module:
-        return self.decoder
-
     def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
diff --git a/vllm/model_executor/models/nemotron_vl.py b/vllm/model_executor/models/nemotron_vl.py
index 391980fc6..620cee109 100644
--- a/vllm/model_executor/models/nemotron_vl.py
+++ b/vllm/model_executor/models/nemotron_vl.py
@@ -385,20 +385,20 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
         self.downsample_ratio = config.downsample_ratio
         self.ps_version = config.ps_version
-        self.llm_arch_name = config.text_config.architectures[0]
-        self.vision_model = self._init_vision_model(
-            config,
-            quant_config=quant_config,
-            prefix=maybe_prefix(prefix, "vision_model"),
-        )
+        with self._mark_tower_model(vllm_config, "image"):
+            self.vision_model = self._init_vision_model(
+                config,
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, "vision_model"),
+            )
+            self.mlp1 = self._init_mlp1(config)
 
-        self.language_model = init_vllm_registered_model(
-            vllm_config=vllm_config,
-            hf_config=config.text_config,
-            prefix=maybe_prefix(prefix, "language_model"),
-        )
-
-        self.mlp1 = self._init_mlp1(config)
+        with self._mark_language_model(vllm_config):
+            self.language_model = init_vllm_registered_model(
+                vllm_config=vllm_config,
+                hf_config=config.text_config,
+                prefix=maybe_prefix(prefix, "language_model"),
+            )
 
         self.img_context_token_id = None
 
@@ -520,8 +520,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
         if image_input["type"] == "image_embeds":
             return image_input["data"]
 
-        assert self.vision_model is not None
-
         image_embeds = self.extract_feature(image_input["pixel_values_flat"])
 
         num_patches = image_input["num_patches"]
@@ -556,9 +554,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
     def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
         self.visual_token_mask = None
 
-    def get_language_model(self) -> torch.nn.Module:
-        return self.language_model
-
     def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
         if not modalities:
@@ -609,7 +604,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
         **kwargs: object,
     ) -> IntermediateTensors:
         if intermediate_tensors is not None:
-            input_ids = None
             inputs_embeds = None
 
         forward_kwargs = {
diff --git a/vllm/model_executor/models/ovis.py b/vllm/model_executor/models/ovis.py
index fda848916..af3c6669d 100644
--- a/vllm/model_executor/models/ovis.py
+++ b/vllm/model_executor/models/ovis.py
@@ -417,7 +417,7 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> str | None:
         if modality.startswith("image"):
-            return "<image>"
+            return IMAGE_TOKEN
 
         raise ValueError("Only image modality is supported")
 
@@ -427,20 +427,22 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
         quant_config = vllm_config.quant_config
         self.config: PretrainedConfig = config
 
-        self.llm = init_vllm_registered_model(
-            vllm_config=vllm_config.with_hf_config(config.get_text_config()),
-            prefix=maybe_prefix(prefix, "llm"),
-        )
-        self.visual_tokenizer = VisualTokenizer(
-            config=config.visual_tokenizer_config,
-            quant_config=quant_config,
-            prefix=f"{prefix}.visual_tokenizer",
-        )
+        with self._mark_language_model(vllm_config):
+            self.llm = init_vllm_registered_model(
+                vllm_config=vllm_config.with_hf_config(config.get_text_config()),
+                prefix=maybe_prefix(prefix, "llm"),
+            )
 
-        self.vte = VisualEmbedding(
-            self.config.visual_tokenizer_config.vocab_size, self.config.hidden_size
-        )
+        with self._mark_tower_model(vllm_config, "image"):
+            self.visual_tokenizer = VisualTokenizer(
+                config=config.visual_tokenizer_config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.visual_tokenizer",
+            )
+            self.vte = VisualEmbedding(
+                self.config.visual_tokenizer_config.vocab_size, self.config.hidden_size
+            )
 
         text_model_type = self.config.get_text_config().model_type
         self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type]
@@ -546,12 +548,8 @@ class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
         self,
         hidden_states: torch.Tensor,
     ) -> torch.Tensor | None:
-        logits = self.llm.compute_logits(hidden_states)
-        return logits
+        return self.llm.compute_logits(hidden_states)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
-
-    def get_language_model(self) -> torch.nn.Module:
-        return self.llm
diff --git a/vllm/model_executor/models/ovis2_5.py b/vllm/model_executor/models/ovis2_5.py
index db59cb38e..501116252 100644
--- a/vllm/model_executor/models/ovis2_5.py
+++ b/vllm/model_executor/models/ovis2_5.py
@@ -451,6 +451,15 @@ class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo])
     dummy_inputs=Ovis2_5DummyInputsBuilder,
 )
 class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
+    @classmethod
+    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
+        if modality.startswith("image"):
+            return IMAGE_TOKEN
+        if modality.startswith("video"):
+            return VIDEO_TOKEN
+
+        raise ValueError("Only image or video modality is supported")
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -458,20 +467,22 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
         multimodal_config = vllm_config.model_config.multimodal_config
         self.config: PretrainedConfig = config
 
-        self.llm = init_vllm_registered_model(
-            vllm_config=vllm_config.with_hf_config(config.text_config),
-            prefix=maybe_prefix(prefix, "llm"),
-        )
-        self.visual_tokenizer = VisualTokenizer(
-            config=config.vit_config,
-            visual_vocab_size=config.visual_vocab_size,
-            multimodal_config=multimodal_config,
-            quant_config=quant_config,
-            prefix=f"{prefix}.visual_tokenizer",
-        )
+        with self._mark_language_model(vllm_config):
+            self.llm = init_vllm_registered_model(
+                vllm_config=vllm_config.with_hf_config(config.text_config),
+                prefix=maybe_prefix(prefix, "llm"),
+            )
 
-        self.vte = VisualEmbedding(config.visual_vocab_size, config.hidden_size)
+        with self._mark_tower_model(vllm_config, {"image", "video"}):
+            self.visual_tokenizer = VisualTokenizer(
+                config=config.vit_config,
+                visual_vocab_size=config.visual_vocab_size,
+                multimodal_config=multimodal_config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.visual_tokenizer",
+            )
+            self.vte = VisualEmbedding(config.visual_vocab_size, config.hidden_size)
 
         text_model_type = self.config.get_text_config().model_type
         self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type]
@@ -650,12 +661,8 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
         self,
         hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
-        logits = self.llm.compute_logits(hidden_states)
-        return logits
+        return self.llm.compute_logits(hidden_states)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
-
-    def get_language_model(self) -> torch.nn.Module:
-        return self.llm
diff --git a/vllm/model_executor/models/paddleocr_vl.py b/vllm/model_executor/models/paddleocr_vl.py
index 555073430..ebccef986 100644
--- a/vllm/model_executor/models/paddleocr_vl.py
+++ b/vllm/model_executor/models/paddleocr_vl.py
@@ -999,6 +999,13 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support
         }
     )
 
+    @classmethod
+    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
+        if modality.startswith("image"):
+            return "<|IMAGE_START|><|IMAGE_PLACEHOLDER|><|IMAGE_END|>"
+
+        raise ValueError("Only image modality is supported")
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -1008,22 +1015,24 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support
         self.config = config
         self.multimodal_config = multimodal_config
 
-        self.visual = SiglipVisionModel(
-            config=config.vision_config,
-            quant_config=quant_config,
-            multimodal_config=multimodal_config,
-            prefix=maybe_prefix(prefix, "visual"),
-        )
-        self.mlp_AR = Projector(config, config.vision_config)
+        with self._mark_tower_model(vllm_config, "image"):
+            self.visual = SiglipVisionModel(
+                config=config.vision_config,
+                quant_config=quant_config,
+                multimodal_config=multimodal_config,
+                prefix=maybe_prefix(prefix, "visual"),
+            )
+            self.mlp_AR = Projector(config, config.vision_config)
 
-        self.language_model = Ernie4_5ForCausalLM(
-            vllm_config=vllm_config,
-            prefix=maybe_prefix(prefix, "language_model"),
-        )
+        with self._mark_language_model(vllm_config):
+            self.language_model = language_model = Ernie4_5ForCausalLM(
+                vllm_config=vllm_config,
+                prefix=maybe_prefix(prefix, "language_model"),
+            )
 
-        for layer in self.language_model.model.layers:
-            if not isinstance(layer, PPMissingLayer):
-                layer.self_attn.rotary_emb.is_neox_style = True
+        for layer in language_model.model.layers:
+            if not isinstance(layer, PPMissingLayer):
+                layer.self_attn.rotary_emb.is_neox_style = True
 
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors
@@ -1151,9 +1160,6 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support
 
         return llm_positions, mrope_position_delta
 
-    def get_language_model(self) -> nn.Module:
-        return self.language_model
-
     def _parse_and_validate_image_input(
         self, **kwargs: object
     ) -> PaddleOCRImagePixelInputs | None:
@@ -1180,29 +1186,10 @@ class PaddleOCRVLForConditionalGeneration(nn.Module, SupportsMultiModal, Support
 
         if intermediate_tensors is not None:
             inputs_embeds = None
-        elif inputs_embeds is None:
-            vision_embeddings = self.embed_multimodal(**kwargs)
-            is_multimodal = kwargs.pop("is_multimodal", None)
-            handle_oov_mm_token = kwargs.pop("handle_oov_mm_token", False)
-            inputs_embeds = self.embed_input_ids(
-                input_ids,
-                vision_embeddings,
-                is_multimodal=is_multimodal,
-                handle_oov_mm_token=handle_oov_mm_token,
-            )
-            input_ids = None
-
         return self.language_model(
             input_ids, positions, intermediate_tensors, inputs_embeds
         )
 
-    @classmethod
-    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
-        if modality.startswith("image"):
-            return "<|IMAGE_START|><|IMAGE_PLACEHOLDER|><|IMAGE_END|>"
-
-        raise ValueError("Only image modality is supported")
-
     def encode_image(
         self, pixel_values: torch.Tensor, image_grid_thw: torch.Tensor
     ) -> torch.Tensor:
diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py
index 40c43827b..6a9cc5c03 100644
--- a/vllm/model_executor/models/paligemma.py
+++ b/vllm/model_executor/models/paligemma.py
@@ -295,30 +295,32 @@ class PaliGemmaForConditionalGeneration(
         multimodal_config = vllm_config.model_config.multimodal_config
         self.config = config
         self.multimodal_config = multimodal_config
-
-        self.vision_tower = SiglipVisionModel(
-            config.vision_config,
-            quant_config,
-            prefix=maybe_prefix(prefix, "vision_tower"),
-        )
-        self.multi_modal_projector = PaliGemmaMultiModalProjector(
-            vision_hidden_size=config.vision_config.hidden_size,
-            projection_dim=config.vision_config.projection_dim,
-        )
-        self.quant_config = quant_config
+
+        with self._mark_tower_model(vllm_config, "image"):
+            self.vision_tower = SiglipVisionModel(
+                config.vision_config,
+                quant_config,
+                prefix=maybe_prefix(prefix, "vision_tower"),
+            )
+            self.multi_modal_projector = PaliGemmaMultiModalProjector(
+                vision_hidden_size=config.vision_config.hidden_size,
+                projection_dim=config.vision_config.projection_dim,
+            )
+
         if config.text_config.model_type == "gemma":
             config.text_config.architectures = ["GemmaForCausalLM"]
         else:
             config.text_config.architectures = ["Gemma2ForCausalLM"]
-        self.language_model = init_vllm_registered_model(
-            vllm_config=vllm_config,
-            hf_config=config.text_config,
-            prefix=maybe_prefix(prefix, "language_model"),
-        )
-        logit_scale = getattr(config, "logit_scale", 1.0)
-        self.language_model.logits_processor.scale *= logit_scale
+
+        with self._mark_language_model(vllm_config):
+            self.language_model = language_model = init_vllm_registered_model(
+                vllm_config=vllm_config,
+                hf_config=config.text_config,
+                prefix=maybe_prefix(prefix, "language_model"),
+            )
+            logit_scale = getattr(config, "logit_scale", 1.0)
+            language_model.logits_processor.scale *= logit_scale
 
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors
@@ -367,7 +369,6 @@ class PaliGemmaForConditionalGeneration(
         if image_input["type"] == "image_embeds":
             return image_input["data"]
 
-        assert self.vision_tower is not None
         pixel_values = image_input["data"]
         image_features = self._image_pixels_to_features(
             self.vision_tower,
@@ -376,9 +377,6 @@ class PaliGemmaForConditionalGeneration(
 
         return self.multi_modal_projector(image_features)
 
-    def get_language_model(self) -> torch.nn.Module:
-        return self.language_model
-
     def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py
index 3de269309..169182cc1 100644
--- a/vllm/model_executor/models/phi3v.py
+++ b/vllm/model_executor/models/phi3v.py
@@ -586,31 +586,31 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
         self.multimodal_config = multimodal_config
         self.image_token_id = _IMAGE_TOKEN_ID
 
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            quant_config=quant_config,
-            prefix=maybe_prefix(prefix, "model.embed_tokens"),
-        )
+        with self._mark_tower_model(vllm_config, "image"):
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, "model.embed_tokens"),
+            )
+            self.vision_embed_tokens = Phi3HDImageEmbedding(
+                config,
+                quant_config=quant_config,
+                multimodal_config=multimodal_config,
+                prefix=maybe_prefix(prefix, "model.vision_embed_tokens"),
+            )
 
-        # TODO: Optionally initializes this for supporting input embeddings.
-        self.vision_embed_tokens = Phi3HDImageEmbedding(
-            config,
-            quant_config=quant_config,
-            multimodal_config=multimodal_config,
-            prefix=maybe_prefix(prefix, "model.vision_embed_tokens"),
-        )
-
-        self.language_model = init_vllm_registered_model(
-            vllm_config=vllm_config,
-            # The prefix is empty intentionally because default prefix of
-            # LlamaForCausalLM is "model"
-            prefix="",
-            # We don't directly initialize vLLM's LlamaForCausalLM so we
-            # can automatically apply embedding wrapper if this model is
-            # initialized as an embedding model
-            architectures=["LlamaForCausalLM"],
-        )
+        with self._mark_language_model(vllm_config):
+            self.language_model = init_vllm_registered_model(
+                vllm_config=vllm_config,
+                # The prefix is empty intentionally because default prefix of
+                # LlamaForCausalLM is "model"
+                prefix="",
+                # We don't directly initialize vLLM's LlamaForCausalLM so we
+                # can automatically apply embedding wrapper if this model is
+                # initialized as an embedding model
+                architectures=["LlamaForCausalLM"],
+            )
 
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors
@@ -652,17 +652,12 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant)
         if image_input["type"] == "image_embeds":
             return image_input["data"]
 
-        assert self.vision_embed_tokens is not None
-
         image_embeds = self.vision_embed_tokens(
             image_input["pixel_values"], image_input["image_sizes"]
         )
 
         return image_embeds
 
-    def get_language_model(self) -> torch.nn.Module:
-        return self.language_model
-
     def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py
index 2839c9304..06d551a4c 100644
--- a/vllm/model_executor/models/phi4mm.py
+++ b/vllm/model_executor/models/phi4mm.py
@@ -1027,12 +1027,13 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
         # Tensor/Pipeline parallel not supported for now.
         assert get_pp_group().world_size == 1, "pipeline parallel is not supported"
 
-        self.vision_encoder = Phi4MMImageEncoder(
-            config,
-            quant_config,
-            prefix="model.vision_embed_tokens",
-            model_dir=config._name_or_path,
-        )
+        with self._mark_tower_model(vllm_config, {"image", "video"}):
+            self.vision_encoder = Phi4MMImageEncoder(
+                config,
+                quant_config,
+                prefix="model.vision_embed_tokens",
+                model_dir=config._name_or_path,
+            )
 
         if isinstance(config.embd_layer["audio_embd_layer"], dict):
             embedding_config = {
@@ -1044,10 +1045,13 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
                 "embedding_cls": self.config.embd_layer["embedding_cls"]
             }
 
-        self.embed_tokens_extend = AudioEmbedding(config, **embedding_config)
-        self.model = LlamaModel(
-            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
-        )
+        with self._mark_tower_model(vllm_config, "audio"):
+            self.embed_tokens_extend = AudioEmbedding(config, **embedding_config)
+
+        with self._mark_language_model(vllm_config):
+            self.model = LlamaModel(
+                vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
+            )
 
         self.lm_head = ParallelLMHead(
             config.vocab_size,
@@ -1245,6 +1249,3 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
             connector=["audio_projection_for_vision", "audio_projection"],
             tower_model=["vision_encoder", "embed_tokens_extend"],
         )
-
-    def get_language_model(self) -> torch.nn.Module:
-        return self.model
diff --git a/vllm/model_executor/models/qwen3_vl_moe.py b/vllm/model_executor/models/qwen3_vl_moe.py
index 40f87a04a..8df1458e3 100644
--- a/vllm/model_executor/models/qwen3_vl_moe.py
+++ b/vllm/model_executor/models/qwen3_vl_moe.py
@@ -453,15 +453,15 @@ class Qwen3VLMoeForConditionalGeneration(
         ]
 
         with self._mark_language_model(vllm_config):
-            self.language_model = Qwen3MoeLLMForCausalLM(
+            self.language_model = language_model = Qwen3MoeLLMForCausalLM(
                 vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model")
             )
 
-        # Whether to include the gate_up_proj mapping is determined by
-        # the language model.
-        self.packed_modules_mapping = (
-            self.packed_modules_mapping | self.language_model.packed_modules_mapping
-        )
+            # Whether to include the gate_up_proj mapping is determined by
+            # the language model.
+            self.packed_modules_mapping = (
+                self.packed_modules_mapping | language_model.packed_modules_mapping
+            )
 
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors
diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py
index ec676805a..f54d603b7 100644
--- a/vllm/model_executor/models/skyworkr1v.py
+++ b/vllm/model_executor/models/skyworkr1v.py
@@ -908,7 +908,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         **kwargs: object,
     ) -> IntermediateTensors:
         if intermediate_tensors is not None:
-            input_ids = None
             inputs_embeds = None
 
         forward_kwargs = {
diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py
index 80f5568b4..f057a6e3f 100644
--- a/vllm/model_executor/models/step3_vl.py
+++ b/vllm/model_executor/models/step3_vl.py
@@ -1104,14 +1104,6 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
     ) -> torch.Tensor | IntermediateTensors:
         if intermediate_tensors is not None:
             inputs_embeds = None
-        elif inputs_embeds is None:
-            vision_embeddings = self.embed_multimodal(**kwargs)
-            inputs_embeds = self.embed_input_ids(
-                input_ids,
-                vision_embeddings,
-                is_multimodal=input_ids == self.config.image_token_id,
-            )
-            input_ids = None
 
         hidden_states = self.language_model(
             input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
diff --git a/vllm/model_executor/models/tarsier.py b/vllm/model_executor/models/tarsier.py
index 935d2575a..87ce679e6 100644
--- a/vllm/model_executor/models/tarsier.py
+++ b/vllm/model_executor/models/tarsier.py
@@ -597,14 +597,7 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
     ) -> torch.Tensor | IntermediateTensors:
         if intermediate_tensors is not None:
             inputs_embeds = None
-        elif inputs_embeds is None:
-            vision_embeddings = self.embed_multimodal(**kwargs)
-            inputs_embeds = self.embed_input_ids(
-                input_ids,
-                vision_embeddings,
-                is_multimodal=input_ids == self.config.image_token_index,
-            )
-            input_ids = None
+
         hidden_states = self.language_model.model(
             input_ids=input_ids,
             positions=positions,