[5/N] Initialize MM components in context managers (Q-Z) (#32695)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-21 03:10:23 +08:00
committed by GitHub
parent f0feb1cf81
commit 193069d129
9 changed files with 178 additions and 168 deletions

View File

@@ -334,20 +334,21 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, Supports
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
self.audio_tower = Qwen2AudioEncoder(config.audio_config)
self.multi_modal_projector = Qwen2AudioMultiModalProjector(
config.audio_config.d_model, config.text_config.hidden_size
)
self.quant_config = quant_config
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["Qwen2ForCausalLM"],
)
with self._mark_tower_model(vllm_config, "audio"):
self.audio_tower = Qwen2AudioEncoder(config.audio_config)
self.multi_modal_projector = Qwen2AudioMultiModalProjector(
config.audio_config.d_model, config.text_config.hidden_size
)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
architectures=["Qwen2ForCausalLM"],
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
@@ -441,9 +442,6 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, Supports
masked_audio_features, audio_output_lengths.flatten().tolist()
)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is None:

View File

@@ -1612,32 +1612,14 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
multimodal_config = vllm_config.model_config.multimodal_config
self.config = thinker_config
self.multimodal_config = multimodal_config
self.audio_tower = Qwen3OmniMoeAudioEncoder(
thinker_config.audio_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "audio_tower"),
)
self.visual = Qwen3Omni_VisionTransformer(
vision_config=thinker_config.vision_config,
norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
multimodal_config=multimodal_config,
)
self.quant_config = quant_config
self.language_model = Qwen3MoeLLMForCausalLM(
vllm_config=vllm_config.with_hf_config(
thinker_config.text_config, architectures=["Qwen3MoeForCausalLM"]
),
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)
with self._mark_tower_model(vllm_config, "audio"):
self.audio_tower = Qwen3OmniMoeAudioEncoder(
thinker_config.audio_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "audio_tower"),
)
self.use_deepstack = hasattr(
thinker_config.vision_config, "deepstack_visual_indexes"
@@ -1647,22 +1629,48 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
if self.use_deepstack
else 0
)
# register buffer for deepstack
self.deepstack_input_embeds = (
[
torch.zeros(
vllm_config.scheduler_config.max_num_batched_tokens,
thinker_config.text_config.hidden_size,
)
for _ in range(self.deepstack_num_level)
]
if self.use_deepstack
else None
)
self.visual_dim = thinker_config.vision_config.out_hidden_size
self.multiscale_dim = self.visual_dim * self.deepstack_num_level
def _get_deepstack_input_embeds(self, num_tokens: int) -> IntermediateTensors:
with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual = Qwen3Omni_VisionTransformer(
vision_config=thinker_config.vision_config,
norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
multimodal_config=multimodal_config,
)
# register buffer for deepstack
if self.use_deepstack:
self.deepstack_input_embeds = [
torch.zeros(
vllm_config.scheduler_config.max_num_batched_tokens,
thinker_config.text_config.hidden_size,
)
for _ in range(self.deepstack_num_level)
]
with self._mark_language_model(vllm_config):
self.language_model = Qwen3MoeLLMForCausalLM(
vllm_config=vllm_config.with_hf_config(
thinker_config.text_config,
architectures=["Qwen3MoeForCausalLM"],
),
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)
def _get_deepstack_input_embeds(
self,
num_tokens: int,
) -> IntermediateTensors | None:
if not getattr(self, "deepstack_input_embeds", None):
return None # If vision tower is skipped
# get deepstack_input_embeds from buffer, and clear the buffer
return IntermediateTensors(
{
@@ -1674,6 +1682,9 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
)
def _set_deepstack_input_embeds(self, deepstack_input_embeds: torch.Tensor) -> None:
if not getattr(self, "deepstack_input_embeds", None):
return
# set deepstack_input_embeds to buffer
num_tokens = deepstack_input_embeds.size(1)
if num_tokens > self.deepstack_input_embeds[0].size(0):
@@ -1692,6 +1703,9 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
)
def _clear_deepstack_input_embeds(self, num_tokens: int) -> None:
if not getattr(self, "deepstack_input_embeds", None):
return
# clear deepstack_input_embeds in buffer
if num_tokens > 0:
for idx in range(self.deepstack_num_level):
@@ -1726,9 +1740,6 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
)
return mm_input_by_modality
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
if not mm_input_by_modality:
@@ -1844,11 +1855,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
if intermediate_tensors is not None:
inputs_embeds = None
if (
self.use_deepstack
and inputs_embeds is not None
and get_pp_group().is_first_rank
):
if inputs_embeds is not None and get_pp_group().is_first_rank:
deepstack_input_embeds = self._get_deepstack_input_embeds(
inputs_embeds.size(0)
)

View File

@@ -1321,7 +1321,13 @@ class Qwen3VLForConditionalGeneration(
num_layers = len(self.language_model.model.layers)
return (2, num_layers // 2, num_layers - 3)
def _get_deepstack_input_embeds(self, num_tokens: int) -> IntermediateTensors:
def _get_deepstack_input_embeds(
self,
num_tokens: int,
) -> IntermediateTensors | None:
if not getattr(self, "deepstack_input_embeds", None):
return None # If vision tower is skipped
# get deepstack_input_embeds from buffer, and clear the buffer
return IntermediateTensors(
{
@@ -1333,6 +1339,9 @@ class Qwen3VLForConditionalGeneration(
)
def _set_deepstack_input_embeds(self, deepstack_input_embeds: torch.Tensor) -> None:
if not getattr(self, "deepstack_input_embeds", None):
return
# set deepstack_input_embeds to buffer
num_tokens = deepstack_input_embeds.size(1)
if num_tokens > self.deepstack_input_embeds[0].size(0):
@@ -1351,6 +1360,9 @@ class Qwen3VLForConditionalGeneration(
)
def _clear_deepstack_input_embeds(self, num_tokens: int) -> None:
if not getattr(self, "deepstack_input_embeds", None):
return
# clear deepstack_input_embeds in buffer
if num_tokens > 0:
for idx in range(self.deepstack_num_level):
@@ -2037,11 +2049,7 @@ class Qwen3VLForConditionalGeneration(
if intermediate_tensors is not None:
inputs_embeds = None
if (
self.use_deepstack
and inputs_embeds is not None
and get_pp_group().is_first_rank
):
if inputs_embeds is not None and get_pp_group().is_first_rank:
deepstack_input_embeds = self._get_deepstack_input_embeds(
inputs_embeds.size(0)
)

View File

@@ -620,7 +620,6 @@ class RadioInternVisionModel(nn.Module):
x: torch.Tensor,
imgs_sizes: torch.Tensor | None = None,
) -> torch.FloatTensor:
assert self.patch_generator is not None
hidden_states = self.patch_generator(x, imgs_sizes=imgs_sizes)
attn_mask = None
if imgs_sizes is not None and len(imgs_sizes) > 1:

View File

@@ -1033,21 +1033,23 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
self.text_embed_dim = text_config.hidden_size
self.vision_embed_dim = vision_config.hidden_size
self.text_model = SiglipTextTransformer(
text_config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "text_model"),
)
self.vision_model = SiglipVisionTransformer(
vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_model"),
)
self.text_projection_size = text_config.projection_size
with self._mark_language_model(vllm_config):
self.text_model = SiglipTextTransformer(
text_config,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "text_model"),
)
with self._mark_tower_model(vllm_config, "image"):
self.vision_model = SiglipVisionTransformer(
vision_config,
quant_config=quant_config,
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "vision_model"),
)
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler_config = pooler_config
@@ -1155,9 +1157,6 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
return self.get_image_features(pixel_values)
def get_language_model(self) -> torch.nn.Module:
return self.text_model
def _embed_text_input_ids(
self,
input_ids: torch.Tensor,

View File

@@ -674,24 +674,26 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self.downsample_ratio = config.downsample_ratio
self.ps_version = config.ps_version
self.llm_arch_name = config.text_config.architectures[0]
self.is_mono = self.llm_arch_name == "SkyworkLM2VEForCausalLM"
self.vision_model = self._init_vision_model(
config,
quant_config=quant_config,
is_mono=self.is_mono,
prefix=maybe_prefix(prefix, "vision_model"),
)
llm_arch_name = config.text_config.architectures[0]
self.is_mono = llm_arch_name == "SkyworkLM2VEForCausalLM"
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
with self._mark_tower_model(vllm_config, "image"):
self.vision_model = self._init_vision_model(
config,
quant_config=quant_config,
is_mono=self.is_mono,
prefix=maybe_prefix(prefix, "vision_model"),
)
self.mlp1 = self._init_mlp1(
config, quant_config, prefix=maybe_prefix(prefix, "mlp1")
)
self.mlp1 = self._init_mlp1(
config, quant_config, prefix=maybe_prefix(prefix, "mlp1")
)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.img_context_token_id = None
self.visual_token_mask = None
@@ -838,8 +840,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_model is not None
image_embeds = self.extract_feature(image_input["pixel_values_flat"])
num_patches = image_input["num_patches"]
@@ -867,9 +867,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
else:
self.visual_token_mask = None
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:

View File

@@ -423,38 +423,43 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config # Storing the Tarsier-specific HF config
self.vision_tower = init_vision_tower_for_tarsier(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"),
)
projector_bias = getattr(config, "multimodal_projector_bias", True)
self.multi_modal_projector = TarsierMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act,
multimodal_projector_bias=projector_bias,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "multi_modal_projector"),
)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config, # Use text_config from Tarsier's main config
prefix=maybe_prefix(prefix, "language_model"),
)
self.register_buffer(
"image_newline_idx_tensor",
torch.tensor([config.image_newline_idx], dtype=torch.long),
persistent=False,
)
self.register_buffer(
"image_new_idx_tensor",
torch.tensor([config.image_new_idx], dtype=torch.long),
persistent=False,
)
with self._mark_tower_model(vllm_config, "image"):
self.vision_tower = init_vision_tower_for_tarsier(
config,
quant_config=quant_config,
multimodal_config=multimodal_config,
require_post_norm=False,
prefix=maybe_prefix(prefix, "vision_tower"),
)
projector_bias = getattr(config, "multimodal_projector_bias", True)
self.multi_modal_projector = TarsierMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act,
multimodal_projector_bias=projector_bias,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "multi_modal_projector"),
)
self.register_buffer(
"image_newline_idx_tensor",
torch.tensor([config.image_newline_idx], dtype=torch.long),
persistent=False,
)
self.register_buffer(
"image_new_idx_tensor",
torch.tensor([config.image_new_idx], dtype=torch.long),
persistent=False,
)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
# Use text_config from Tarsier's main config
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
@@ -547,7 +552,6 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
self,
inputs: TarsierImagePixelInputs,
) -> torch.Tensor | tuple[torch.Tensor, ...]:
assert self.vision_tower is not None
pixel_values = inputs["pixel_values"]
image_features_selected = self._image_pixels_to_features(
self.vision_tower, pixel_values
@@ -575,11 +579,8 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
"Incorrect type of image_embeds. "
f"Got type: {type(projected_features)}. "
)
assert self.vision_tower is not None
return self._process_image_pixels(image_input)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
return self._process_image_pixels(image_input)
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)

View File

@@ -543,7 +543,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
assert self.multi_modal_config
self.secondary_weights = []
self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
if config.audio_model_id is not None:
# this prefix is not for initialization, but for loading weights
# note the trailing dot
@@ -554,15 +553,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
prefix="audio_tower.",
)
)
if config.num_projector_layers > 0:
self.multi_modal_projector = UltravoxTransformerProjector(config)
else:
self.multi_modal_projector = UltravoxFeedForwardProjector(config)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.wrapped_model_config,
prefix=maybe_prefix(prefix, "language_model"),
)
if config.text_model_id is not None:
# this prefix is not for initialization, but for loading weights
# note the trailing dot
@@ -574,6 +564,20 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
)
)
with self._mark_tower_model(vllm_config, "audio"):
self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
if config.num_projector_layers > 0:
self.multi_modal_projector = UltravoxTransformerProjector(config)
else:
self.multi_modal_projector = UltravoxFeedForwardProjector(config)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.wrapped_model_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors
)
@@ -681,9 +685,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
]
return flattened_embeddings.split(embed_lens)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is None:

View File

@@ -366,22 +366,22 @@ class VoxtralForConditionalGeneration(
self.config = config
self.downsample_factor = self.config.audio_config.downsample_factor
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.whisper_encoder = VoxtralEncoderModel(
vllm_config.with_hf_config(config.audio_config),
prefix=maybe_prefix(prefix, "whisper_encoder"),
)
self.audio_language_adapter = AudioLanguageAdapter(
hidden_size=config.audio_config.d_model * self.downsample_factor,
dim=config.text_config.hidden_size,
)
with self._mark_language_model(vllm_config):
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
hf_config=config.text_config,
prefix=maybe_prefix(prefix, "language_model"),
)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
with self._mark_tower_model(vllm_config, "audio"):
self.whisper_encoder = VoxtralEncoderModel(
vllm_config.with_hf_config(config.audio_config),
prefix=maybe_prefix(prefix, "whisper_encoder"),
)
self.audio_language_adapter = AudioLanguageAdapter(
hidden_size=config.audio_config.d_model * self.downsample_factor,
dim=config.text_config.hidden_size,
)
def get_mm_mapping(self) -> MultiModelKeys:
"""Get module prefix for multimodal models to filter LoRA modules."""