[5/N] Initialize MM components in context managers (Q-Z) (#32695)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -334,20 +334,21 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, Supports
         multimodal_config = vllm_config.model_config.multimodal_config
         self.config = config
         self.multimodal_config = multimodal_config
 
-        self.audio_tower = Qwen2AudioEncoder(config.audio_config)
-        self.multi_modal_projector = Qwen2AudioMultiModalProjector(
-            config.audio_config.d_model, config.text_config.hidden_size
-        )
-
         self.quant_config = quant_config
 
-        self.language_model = init_vllm_registered_model(
-            vllm_config=vllm_config,
-            hf_config=config.text_config,
-            prefix=maybe_prefix(prefix, "language_model"),
-            architectures=["Qwen2ForCausalLM"],
-        )
+        with self._mark_tower_model(vllm_config, "audio"):
+            self.audio_tower = Qwen2AudioEncoder(config.audio_config)
+            self.multi_modal_projector = Qwen2AudioMultiModalProjector(
+                config.audio_config.d_model, config.text_config.hidden_size
+            )
+
+        with self._mark_language_model(vllm_config):
+            self.language_model = init_vllm_registered_model(
+                vllm_config=vllm_config,
+                hf_config=config.text_config,
+                prefix=maybe_prefix(prefix, "language_model"),
+                architectures=["Qwen2ForCausalLM"],
+            )
 
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors
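The bodies of `_mark_tower_model` and `_mark_language_model` are not part of this diff; only their call sites are. As a rough mental model only (everything beyond the two method names above is an assumption, not the actual vLLM implementation), the helpers can be pictured as context managers that tag whichever submodules are constructed inside them:

from collections.abc import Iterator
from contextlib import contextmanager


class SupportsMultiModalSketch:
    """Hypothetical stand-in for the real interface; illustration only."""

    @contextmanager
    def _mark_tower_model(self, vllm_config, modality) -> Iterator[None]:
        # Assumption: record that modules built inside this block form the
        # encoder tower(s) for `modality`, so they can later be skipped,
        # quantized, or offloaded independently of the language model.
        self._current_component = ("tower", modality)
        try:
            yield
        finally:
            self._current_component = None

    @contextmanager
    def _mark_language_model(self, vllm_config) -> Iterator[None]:
        # Assumption: the same bookkeeping for the text backbone.
        self._current_component = ("language_model", None)
        try:
            yield
        finally:
            self._current_component = None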
@@ -441,9 +442,6 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, Supports
             masked_audio_features, audio_output_lengths.flatten().tolist()
         )
 
-    def get_language_model(self) -> torch.nn.Module:
-        return self.language_model
-
     def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         audio_input = self._parse_and_validate_audio_input(**kwargs)
         if audio_input is None:
@@ -1612,32 +1612,14 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
         multimodal_config = vllm_config.model_config.multimodal_config
         self.config = thinker_config
         self.multimodal_config = multimodal_config
 
-        self.audio_tower = Qwen3OmniMoeAudioEncoder(
-            thinker_config.audio_config,
-            multimodal_config=multimodal_config,
-            prefix=maybe_prefix(prefix, "audio_tower"),
-        )
-
-        self.visual = Qwen3Omni_VisionTransformer(
-            vision_config=thinker_config.vision_config,
-            norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
-            quant_config=quant_config,
-            prefix=maybe_prefix(prefix, "visual"),
-            multimodal_config=multimodal_config,
-        )
         self.quant_config = quant_config
 
-        self.language_model = Qwen3MoeLLMForCausalLM(
-            vllm_config=vllm_config.with_hf_config(
-                thinker_config.text_config, architectures=["Qwen3MoeForCausalLM"]
-            ),
-            prefix=maybe_prefix(prefix, "language_model"),
-        )
-
-        self.make_empty_intermediate_tensors = (
-            self.language_model.make_empty_intermediate_tensors
-        )
+        with self._mark_tower_model(vllm_config, "audio"):
+            self.audio_tower = Qwen3OmniMoeAudioEncoder(
+                thinker_config.audio_config,
+                multimodal_config=multimodal_config,
+                prefix=maybe_prefix(prefix, "audio_tower"),
+            )
 
         self.use_deepstack = hasattr(
             thinker_config.vision_config, "deepstack_visual_indexes"
@@ -1647,22 +1629,48 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
             if self.use_deepstack
             else 0
         )
-        # register buffer for deepstack
-        self.deepstack_input_embeds = (
-            [
-                torch.zeros(
-                    vllm_config.scheduler_config.max_num_batched_tokens,
-                    thinker_config.text_config.hidden_size,
-                )
-                for _ in range(self.deepstack_num_level)
-            ]
-            if self.use_deepstack
-            else None
-        )
         self.visual_dim = thinker_config.vision_config.out_hidden_size
         self.multiscale_dim = self.visual_dim * self.deepstack_num_level
 
-    def _get_deepstack_input_embeds(self, num_tokens: int) -> IntermediateTensors:
+        with self._mark_tower_model(vllm_config, {"image", "video"}):
+            self.visual = Qwen3Omni_VisionTransformer(
+                vision_config=thinker_config.vision_config,
+                norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, "visual"),
+                multimodal_config=multimodal_config,
+            )
+
+            # register buffer for deepstack
+            if self.use_deepstack:
+                self.deepstack_input_embeds = [
+                    torch.zeros(
+                        vllm_config.scheduler_config.max_num_batched_tokens,
+                        thinker_config.text_config.hidden_size,
+                    )
+                    for _ in range(self.deepstack_num_level)
+                ]
+
+        with self._mark_language_model(vllm_config):
+            self.language_model = Qwen3MoeLLMForCausalLM(
+                vllm_config=vllm_config.with_hf_config(
+                    thinker_config.text_config,
+                    architectures=["Qwen3MoeForCausalLM"],
+                ),
+                prefix=maybe_prefix(prefix, "language_model"),
+            )
+
+        self.make_empty_intermediate_tensors = (
+            self.language_model.make_empty_intermediate_tensors
+        )
+
+    def _get_deepstack_input_embeds(
+        self,
+        num_tokens: int,
+    ) -> IntermediateTensors | None:
+        if not getattr(self, "deepstack_input_embeds", None):
+            return None  # If vision tower is skipped
+
         # get deepstack_input_embeds from buffer, and clear the buffer
         return IntermediateTensors(
             {
@@ -1674,6 +1682,9 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
         )
 
     def _set_deepstack_input_embeds(self, deepstack_input_embeds: torch.Tensor) -> None:
+        if not getattr(self, "deepstack_input_embeds", None):
+            return
+
         # set deepstack_input_embeds to buffer
         num_tokens = deepstack_input_embeds.size(1)
         if num_tokens > self.deepstack_input_embeds[0].size(0):
@@ -1692,6 +1703,9 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
         )
 
     def _clear_deepstack_input_embeds(self, num_tokens: int) -> None:
+        if not getattr(self, "deepstack_input_embeds", None):
+            return
+
         # clear deepstack_input_embeds in buffer
         if num_tokens > 0:
             for idx in range(self.deepstack_num_level):
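Because `deepstack_input_embeds` is now assigned only inside the vision-tower context, and only when `use_deepstack` is true, each helper guards with `getattr` rather than assuming the attribute exists. A minimal, self-contained illustration of that guard (the class and names here are invented for the demo, not taken from vLLM):

class _DeepstackGuardDemo:
    # Mimics a model whose vision tower was skipped: the attribute
    # `deepstack_input_embeds` is never assigned at all.
    def _get_deepstack_input_embeds(self, num_tokens: int):
        if not getattr(self, "deepstack_input_embeds", None):
            return None  # missing attribute or empty buffer list
        return [buf[:num_tokens] for buf in self.deepstack_input_embeds]


demo = _DeepstackGuardDemo()
assert demo._get_deepstack_input_embeds(4) is None  # no AttributeError raised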
@@ -1726,9 +1740,6 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
         )
         return mm_input_by_modality
 
-    def get_language_model(self) -> torch.nn.Module:
-        return self.language_model
-
     def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
         mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
         if not mm_input_by_modality:
@@ -1844,11 +1855,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
         if intermediate_tensors is not None:
            inputs_embeds = None
 
-        if (
-            self.use_deepstack
-            and inputs_embeds is not None
-            and get_pp_group().is_first_rank
-        ):
+        if inputs_embeds is not None and get_pp_group().is_first_rank:
             deepstack_input_embeds = self._get_deepstack_input_embeds(
                 inputs_embeds.size(0)
             )
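Note that with the `None` guard moved inside `_get_deepstack_input_embeds`, the forward path no longer needs to consult `self.use_deepstack` before calling it; presumably the `None` return is handled wherever `deepstack_input_embeds` is consumed, which is not visible in this hunk.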
@@ -1321,7 +1321,13 @@ class Qwen3VLForConditionalGeneration(
         num_layers = len(self.language_model.model.layers)
         return (2, num_layers // 2, num_layers - 3)
 
-    def _get_deepstack_input_embeds(self, num_tokens: int) -> IntermediateTensors:
+    def _get_deepstack_input_embeds(
+        self,
+        num_tokens: int,
+    ) -> IntermediateTensors | None:
+        if not getattr(self, "deepstack_input_embeds", None):
+            return None  # If vision tower is skipped
+
         # get deepstack_input_embeds from buffer, and clear the buffer
         return IntermediateTensors(
             {
@@ -1333,6 +1339,9 @@ class Qwen3VLForConditionalGeneration(
         )
 
     def _set_deepstack_input_embeds(self, deepstack_input_embeds: torch.Tensor) -> None:
+        if not getattr(self, "deepstack_input_embeds", None):
+            return
+
         # set deepstack_input_embeds to buffer
         num_tokens = deepstack_input_embeds.size(1)
         if num_tokens > self.deepstack_input_embeds[0].size(0):
@@ -1351,6 +1360,9 @@ class Qwen3VLForConditionalGeneration(
         )
 
     def _clear_deepstack_input_embeds(self, num_tokens: int) -> None:
+        if not getattr(self, "deepstack_input_embeds", None):
+            return
+
         # clear deepstack_input_embeds in buffer
         if num_tokens > 0:
             for idx in range(self.deepstack_num_level):
@@ -2037,11 +2049,7 @@ class Qwen3VLForConditionalGeneration(
         if intermediate_tensors is not None:
             inputs_embeds = None
 
-        if (
-            self.use_deepstack
-            and inputs_embeds is not None
-            and get_pp_group().is_first_rank
-        ):
+        if inputs_embeds is not None and get_pp_group().is_first_rank:
             deepstack_input_embeds = self._get_deepstack_input_embeds(
                 inputs_embeds.size(0)
             )
@@ -620,7 +620,6 @@ class RadioInternVisionModel(nn.Module):
         x: torch.Tensor,
         imgs_sizes: torch.Tensor | None = None,
     ) -> torch.FloatTensor:
-        assert self.patch_generator is not None
         hidden_states = self.patch_generator(x, imgs_sizes=imgs_sizes)
         attn_mask = None
         if imgs_sizes is not None and len(imgs_sizes) > 1:
@@ -1033,21 +1033,23 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
 
         self.text_embed_dim = text_config.hidden_size
         self.vision_embed_dim = vision_config.hidden_size
 
-        self.text_model = SiglipTextTransformer(
-            text_config,
-            quant_config=quant_config,
-            prefix=maybe_prefix(prefix, "text_model"),
-        )
-        self.vision_model = SiglipVisionTransformer(
-            vision_config,
-            quant_config=quant_config,
-            multimodal_config=multimodal_config,
-            prefix=maybe_prefix(prefix, "vision_model"),
-        )
-
         self.text_projection_size = text_config.projection_size
 
+        with self._mark_language_model(vllm_config):
+            self.text_model = SiglipTextTransformer(
+                text_config,
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, "text_model"),
+            )
+
+        with self._mark_tower_model(vllm_config, "image"):
+            self.vision_model = SiglipVisionTransformer(
+                vision_config,
+                quant_config=quant_config,
+                multimodal_config=multimodal_config,
+                prefix=maybe_prefix(prefix, "vision_model"),
+            )
+
         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None
         self.pooler_config = pooler_config
@@ -1155,9 +1157,6 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
 
         return self.get_image_features(pixel_values)
 
-    def get_language_model(self) -> torch.nn.Module:
-        return self.text_model
-
     def _embed_text_input_ids(
         self,
         input_ids: torch.Tensor,
@@ -674,24 +674,26 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         self.downsample_ratio = config.downsample_ratio
         self.ps_version = config.ps_version
 
-        self.llm_arch_name = config.text_config.architectures[0]
-        self.is_mono = self.llm_arch_name == "SkyworkLM2VEForCausalLM"
-        self.vision_model = self._init_vision_model(
-            config,
-            quant_config=quant_config,
-            is_mono=self.is_mono,
-            prefix=maybe_prefix(prefix, "vision_model"),
-        )
+        llm_arch_name = config.text_config.architectures[0]
+        self.is_mono = llm_arch_name == "SkyworkLM2VEForCausalLM"
 
-        self.language_model = init_vllm_registered_model(
-            vllm_config=vllm_config,
-            hf_config=config.text_config,
-            prefix=maybe_prefix(prefix, "language_model"),
-        )
+        with self._mark_tower_model(vllm_config, "image"):
+            self.vision_model = self._init_vision_model(
+                config,
+                quant_config=quant_config,
+                is_mono=self.is_mono,
+                prefix=maybe_prefix(prefix, "vision_model"),
+            )
+            self.mlp1 = self._init_mlp1(
+                config, quant_config, prefix=maybe_prefix(prefix, "mlp1")
+            )
 
-        self.mlp1 = self._init_mlp1(
-            config, quant_config, prefix=maybe_prefix(prefix, "mlp1")
-        )
+        with self._mark_language_model(vllm_config):
+            self.language_model = init_vllm_registered_model(
+                vllm_config=vllm_config,
+                hf_config=config.text_config,
+                prefix=maybe_prefix(prefix, "language_model"),
+            )
 
         self.img_context_token_id = None
         self.visual_token_mask = None
@@ -838,8 +840,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         if image_input["type"] == "image_embeds":
             return image_input["data"]
 
-        assert self.vision_model is not None
-
         image_embeds = self.extract_feature(image_input["pixel_values_flat"])
 
         num_patches = image_input["num_patches"]
@@ -867,9 +867,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         else:
             self.visual_token_mask = None
 
-    def get_language_model(self) -> torch.nn.Module:
-        return self.language_model
-
     def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
@@ -423,38 +423,43 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
         multimodal_config = vllm_config.model_config.multimodal_config
 
         self.config = config  # Storing the Tarsier-specific HF config
-        self.vision_tower = init_vision_tower_for_tarsier(
-            config,
-            quant_config=quant_config,
-            multimodal_config=multimodal_config,
-            require_post_norm=False,
-            prefix=maybe_prefix(prefix, "vision_tower"),
-        )
-        projector_bias = getattr(config, "multimodal_projector_bias", True)
-
-        self.multi_modal_projector = TarsierMultiModalProjector(
-            vision_hidden_size=config.vision_config.hidden_size,
-            text_hidden_size=config.text_config.hidden_size,
-            projector_hidden_act=config.projector_hidden_act,
-            multimodal_projector_bias=projector_bias,
-            quant_config=quant_config,
-            prefix=maybe_prefix(prefix, "multi_modal_projector"),
-        )
-        self.language_model = init_vllm_registered_model(
-            vllm_config=vllm_config,
-            hf_config=config.text_config,  # Use text_config from Tarsier's main config
-            prefix=maybe_prefix(prefix, "language_model"),
-        )
-        self.register_buffer(
-            "image_newline_idx_tensor",
-            torch.tensor([config.image_newline_idx], dtype=torch.long),
-            persistent=False,
-        )
-        self.register_buffer(
-            "image_new_idx_tensor",
-            torch.tensor([config.image_new_idx], dtype=torch.long),
-            persistent=False,
-        )
+        with self._mark_tower_model(vllm_config, "image"):
+            self.vision_tower = init_vision_tower_for_tarsier(
+                config,
+                quant_config=quant_config,
+                multimodal_config=multimodal_config,
+                require_post_norm=False,
+                prefix=maybe_prefix(prefix, "vision_tower"),
+            )
+            projector_bias = getattr(config, "multimodal_projector_bias", True)
+
+            self.multi_modal_projector = TarsierMultiModalProjector(
+                vision_hidden_size=config.vision_config.hidden_size,
+                text_hidden_size=config.text_config.hidden_size,
+                projector_hidden_act=config.projector_hidden_act,
+                multimodal_projector_bias=projector_bias,
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, "multi_modal_projector"),
+            )
+            self.register_buffer(
+                "image_newline_idx_tensor",
+                torch.tensor([config.image_newline_idx], dtype=torch.long),
+                persistent=False,
+            )
+            self.register_buffer(
+                "image_new_idx_tensor",
+                torch.tensor([config.image_new_idx], dtype=torch.long),
+                persistent=False,
+            )
+
+        with self._mark_language_model(vllm_config):
+            self.language_model = init_vllm_registered_model(
+                vllm_config=vllm_config,
+                # Use text_config from Tarsier's main config
+                hf_config=config.text_config,
+                prefix=maybe_prefix(prefix, "language_model"),
+            )
 
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors
@@ -547,7 +552,6 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
         self,
         inputs: TarsierImagePixelInputs,
     ) -> torch.Tensor | tuple[torch.Tensor, ...]:
-        assert self.vision_tower is not None
         pixel_values = inputs["pixel_values"]
         image_features_selected = self._image_pixels_to_features(
             self.vision_tower, pixel_values
@@ -575,11 +579,8 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
                 "Incorrect type of image_embeds. "
                 f"Got type: {type(projected_features)}. "
             )
-        assert self.vision_tower is not None
-        return self._process_image_pixels(image_input)
-
-    def get_language_model(self) -> torch.nn.Module:
-        return self.language_model
+        return self._process_image_pixels(image_input)
 
     def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
@@ -543,7 +543,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
         assert self.multi_modal_config
 
         self.secondary_weights = []
-        self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
         if config.audio_model_id is not None:
             # this prefix is not for initialization, but for loading weights
             # note the trailing dot
@@ -554,15 +553,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
                     prefix="audio_tower.",
                 )
             )
-        if config.num_projector_layers > 0:
-            self.multi_modal_projector = UltravoxTransformerProjector(config)
-        else:
-            self.multi_modal_projector = UltravoxFeedForwardProjector(config)
-        self.language_model = init_vllm_registered_model(
-            vllm_config=vllm_config,
-            hf_config=config.wrapped_model_config,
-            prefix=maybe_prefix(prefix, "language_model"),
-        )
         if config.text_model_id is not None:
             # this prefix is not for initialization, but for loading weights
             # note the trailing dot
@@ -574,6 +564,20 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
                 )
             )
 
+        with self._mark_tower_model(vllm_config, "audio"):
+            self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
+            if config.num_projector_layers > 0:
+                self.multi_modal_projector = UltravoxTransformerProjector(config)
+            else:
+                self.multi_modal_projector = UltravoxFeedForwardProjector(config)
+
+        with self._mark_language_model(vllm_config):
+            self.language_model = init_vllm_registered_model(
+                vllm_config=vllm_config,
+                hf_config=config.wrapped_model_config,
+                prefix=maybe_prefix(prefix, "language_model"),
+            )
+
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors
         )
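One practical effect of grouping the whole audio stack (encoder plus projector) under a single tower context is that it can be treated as a unit. A hedged usage sketch follows; the model name is illustrative and the assumption that a zero per-prompt limit lets the tower be skipped is not something this diff shows:

from vllm import LLM

# Assumption: setting a modality's per-prompt limit to 0 disables it, which
# with tower-marked initialization could in principle skip building the
# audio encoder and projector entirely.
llm = LLM(
    model="fixie-ai/ultravox-v0_5-llama-3_2-1b",
    limit_mm_per_prompt={"audio": 0},
)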
@@ -681,9 +685,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
         ]
         return flattened_embeddings.split(embed_lens)
 
-    def get_language_model(self) -> torch.nn.Module:
-        return self.language_model
-
     def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         audio_input = self._parse_and_validate_audio_input(**kwargs)
         if audio_input is None:
@@ -366,22 +366,22 @@ class VoxtralForConditionalGeneration(
         self.config = config
         self.downsample_factor = self.config.audio_config.downsample_factor
 
-        self.language_model = init_vllm_registered_model(
-            vllm_config=vllm_config,
-            hf_config=config.text_config,
-            prefix=maybe_prefix(prefix, "language_model"),
-        )
-        self.whisper_encoder = VoxtralEncoderModel(
-            vllm_config.with_hf_config(config.audio_config),
-            prefix=maybe_prefix(prefix, "whisper_encoder"),
-        )
-        self.audio_language_adapter = AudioLanguageAdapter(
-            hidden_size=config.audio_config.d_model * self.downsample_factor,
-            dim=config.text_config.hidden_size,
-        )
+        with self._mark_language_model(vllm_config):
+            self.language_model = init_vllm_registered_model(
+                vllm_config=vllm_config,
+                hf_config=config.text_config,
+                prefix=maybe_prefix(prefix, "language_model"),
+            )
 
-    def get_language_model(self) -> torch.nn.Module:
-        return self.language_model
+        with self._mark_tower_model(vllm_config, "audio"):
+            self.whisper_encoder = VoxtralEncoderModel(
+                vllm_config.with_hf_config(config.audio_config),
+                prefix=maybe_prefix(prefix, "whisper_encoder"),
+            )
+            self.audio_language_adapter = AudioLanguageAdapter(
+                hidden_size=config.audio_config.d_model * self.downsample_factor,
+                dim=config.text_config.hidden_size,
+            )
 
     def get_mm_mapping(self) -> MultiModelKeys:
         """Get module prefix for multimodal models to filter LoRA modules."""