[Model] Remove unnecessary get_language_model (#37545)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -1704,6 +1704,12 @@ class ConformerEncoder(nn.Module):
|
||||
# ----- Encoder END -----
|
||||
|
||||
|
||||
# This subclass is specific to vLLM in order for
|
||||
# `_mark_composite_model` to target this module
|
||||
class CohereASRProjector(nn.Linear):
|
||||
pass
|
||||
|
||||
|
||||
class CohereASRModel(nn.Module):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
@@ -1714,7 +1720,7 @@ class CohereASRModel(nn.Module):
|
||||
)
|
||||
|
||||
if self.encoder.d_model != self.decoder.hidden_size:
|
||||
self.encoder_decoder_proj = torch.nn.Linear(
|
||||
self.encoder_decoder_proj = CohereASRProjector(
|
||||
self.encoder.d_model, self.decoder.hidden_size
|
||||
)
|
||||
|
||||
@@ -2096,18 +2102,25 @@ class CohereASRForConditionalGeneration(
|
||||
self.config = config
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
|
||||
self.model = CohereASRModel(vllm_config=vllm_config, prefix=prefix)
|
||||
lm_head_config = config.head
|
||||
self.unpadded_vocab_size = lm_head_config["num_classes"]
|
||||
with self._mark_composite_model(
|
||||
vllm_config,
|
||||
language_targets=CohereASRDecoder,
|
||||
tower_targets={"audio": (ConformerEncoder, CohereASRProjector)},
|
||||
):
|
||||
self.model = CohereASRModel(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
head_config = config.head
|
||||
|
||||
self.proj_out = ParallelLMHead(
|
||||
lm_head_config["num_classes"],
|
||||
lm_head_config["hidden_size"],
|
||||
head_config["num_classes"],
|
||||
head_config["hidden_size"],
|
||||
quant_config=quant_config,
|
||||
bias=True,
|
||||
) # NOTE: bias is True
|
||||
logit_scale = getattr(lm_head_config, "logit_scale", 1.0)
|
||||
|
||||
logit_scale = getattr(head_config, "logit_scale", 1.0)
|
||||
self.logits_processor = LogitsProcessor(
|
||||
self.unpadded_vocab_size, lm_head_config["num_classes"], logit_scale
|
||||
head_config["num_classes"], scale=logit_scale
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -1373,7 +1373,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor | None:
|
||||
"""compute logits"""
|
||||
return self.language_model.compute_logits(hidden_states)
|
||||
|
||||
def _vision_forward(
|
||||
|
||||
@@ -754,12 +754,17 @@ class FireRedASR2ForConditionalGeneration(
|
||||
self.config = config
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
|
||||
self.model = FireRedASR2Model(
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"),
|
||||
)
|
||||
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||
with self._mark_composite_model(
|
||||
vllm_config,
|
||||
language_targets=Qwen2ForCausalLM,
|
||||
tower_targets={"audio": (FireRedASR2Encoder, FireRedASR2Adapter)},
|
||||
):
|
||||
self.model = FireRedASR2Model(
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"),
|
||||
)
|
||||
|
||||
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale)
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -470,15 +470,6 @@ class HCXVisionV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
self.vision_config = vision_config
|
||||
self.text_config = text_config
|
||||
self.vllm_config = vllm_config
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
|
||||
# Initialize Qwen2.5 Vision Transformer
|
||||
self.visual = Qwen2_5_VisionTransformer(
|
||||
vision_config=vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
)
|
||||
|
||||
# Linear projector (vision_hidden_size -> text_hidden_size)
|
||||
# For V2 model: mm_projector_type is "linear"
|
||||
@@ -492,18 +483,21 @@ class HCXVisionV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
else:
|
||||
out_hidden = vision_hidden_size
|
||||
|
||||
# Always create Linear projector since HF checkpoint has mm_projector weights
|
||||
self.mm_projector = nn.Linear(out_hidden, text_hidden_size)
|
||||
with self._mark_tower_model(vllm_config, {"image", "video"}):
|
||||
self.visual = Qwen2_5_VisionTransformer(
|
||||
vision_config=vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
)
|
||||
self.mm_projector = nn.Linear(out_hidden, text_hidden_size)
|
||||
|
||||
# Language model
|
||||
self.lm_head_vocab_size = getattr(
|
||||
text_config, "padded_vocab_size", text_config.vocab_size
|
||||
)
|
||||
self.language_model = init_vllm_registered_model(
|
||||
vllm_config=vllm_config,
|
||||
hf_config=text_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
with self._mark_language_model(vllm_config):
|
||||
self.language_model = init_vllm_registered_model(
|
||||
vllm_config=vllm_config,
|
||||
hf_config=text_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors
|
||||
@@ -633,9 +627,6 @@ class HCXVisionV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
return modalities
|
||||
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def embed_multimodal(
|
||||
self,
|
||||
**kwargs: object,
|
||||
|
||||
@@ -576,20 +576,19 @@ class InternS1ProForConditionalGeneration(
|
||||
multimodal_config.is_multimodal_pruning_enabled()
|
||||
)
|
||||
|
||||
if not multimodal_config.get_limit_per_prompt(
|
||||
"image"
|
||||
) and not multimodal_config.get_limit_per_prompt("video"):
|
||||
self.visual = None
|
||||
else:
|
||||
with self._mark_tower_model(vllm_config, {"image", "video"}):
|
||||
self.visual = Qwen3_VisionTransformer(
|
||||
config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
)
|
||||
|
||||
self.language_model = InternS1ProMoeLLMForCausalLM(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "language_model")
|
||||
)
|
||||
with self._mark_language_model(vllm_config):
|
||||
self.language_model = InternS1ProMoeLLMForCausalLM(
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
|
||||
# Whether to include the gate_up_proj mapping is determined by
|
||||
# the language model.
|
||||
self.packed_modules_mapping = (
|
||||
|
||||
@@ -15,7 +15,6 @@ from transformers import WhisperConfig as HFWhisperConfig
|
||||
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.inputs.data import PromptType, TokensPrompt
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.model_loader import DefaultModelLoader
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.interfaces import (
|
||||
@@ -54,7 +53,6 @@ from vllm.tokenizers import cached_get_tokenizer
|
||||
from vllm.tokenizers.kimi_audio import KimiAudioTokenizer
|
||||
from vllm.transformers_utils.processor import cached_feature_extractor_from_config
|
||||
from vllm.transformers_utils.processors.kimi_audio import KimiAudioProcessor
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
|
||||
# Kimi-Audio constants
|
||||
KIMIA_WHISPER_SUBFOLDER = "whisper-large-v3"
|
||||
@@ -431,28 +429,24 @@ class KimiAudioForConditionalGeneration(
|
||||
)
|
||||
]
|
||||
|
||||
self.audio_tower = KimiAudioWhisperEncoder(
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "audio_tower"),
|
||||
)
|
||||
with self._mark_tower_model(vllm_config, "audio"):
|
||||
self.audio_tower = KimiAudioWhisperEncoder(
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "audio_tower"),
|
||||
)
|
||||
self.multi_modal_projector = KimiAudioMultiModalProjector(
|
||||
whisper_dim=getattr(self.config, "kimia_adaptor_input_dim", 5120),
|
||||
llm_dim=self.config.hidden_size,
|
||||
prefix=maybe_prefix(prefix, "multi_modal_projector"),
|
||||
)
|
||||
|
||||
self.multi_modal_projector = KimiAudioMultiModalProjector(
|
||||
whisper_dim=getattr(self.config, "kimia_adaptor_input_dim", 5120),
|
||||
llm_dim=self.config.hidden_size,
|
||||
prefix=maybe_prefix(prefix, "multi_modal_projector"),
|
||||
)
|
||||
|
||||
self.language_model = init_vllm_registered_model(
|
||||
vllm_config=vllm_config.with_hf_config(
|
||||
self.config, architectures=["Qwen2ForCausalLM"]
|
||||
),
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
|
||||
self.logits_processor = LogitsProcessor(
|
||||
self.config.vocab_size,
|
||||
self.config.vocab_size,
|
||||
)
|
||||
with self._mark_language_model(vllm_config):
|
||||
self.language_model = init_vllm_registered_model(
|
||||
vllm_config=vllm_config.with_hf_config(
|
||||
self.config, architectures=["Qwen2ForCausalLM"]
|
||||
),
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors
|
||||
@@ -595,12 +589,8 @@ class KimiAudioForConditionalGeneration(
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata | None = None,
|
||||
) -> torch.Tensor | None:
|
||||
logits = self.logits_processor(
|
||||
self.language_model.lm_head, hidden_states, sampling_metadata
|
||||
)
|
||||
return logits
|
||||
return self.language_model.compute_logits(hidden_states)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
"""Load weights, skipping MIMO layers (TTS-only) for ASR."""
|
||||
|
||||
@@ -163,29 +163,30 @@ class LightOnOCRForConditionalGeneration(Mistral3ForConditionalGeneration):
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
self.vision_tower = init_vision_tower_for_llava(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
require_post_norm=False,
|
||||
prefix=maybe_prefix(prefix, "vision_tower"),
|
||||
)
|
||||
with self._mark_tower_model(vllm_config, "image"):
|
||||
self.vision_tower = init_vision_tower_for_llava(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
require_post_norm=False,
|
||||
prefix=maybe_prefix(prefix, "vision_tower"),
|
||||
)
|
||||
self.multi_modal_projector = Mistral3MultiModalProjector(
|
||||
vision_hidden_size=config.vision_config.hidden_size,
|
||||
text_hidden_size=config.text_config.hidden_size,
|
||||
projector_hidden_act=config.projector_hidden_act,
|
||||
spatial_merge_size=config.spatial_merge_size,
|
||||
patch_size=config.vision_config.patch_size,
|
||||
multimodal_projector_bias=config.multimodal_projector_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "multi_modal_projector"),
|
||||
)
|
||||
|
||||
self.multi_modal_projector = Mistral3MultiModalProjector(
|
||||
vision_hidden_size=config.vision_config.hidden_size,
|
||||
text_hidden_size=config.text_config.hidden_size,
|
||||
projector_hidden_act=config.projector_hidden_act,
|
||||
spatial_merge_size=config.spatial_merge_size,
|
||||
patch_size=config.vision_config.patch_size,
|
||||
multimodal_projector_bias=config.multimodal_projector_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "multi_modal_projector"),
|
||||
)
|
||||
|
||||
self.language_model = init_vllm_registered_model(
|
||||
vllm_config=vllm_config,
|
||||
hf_config=config.text_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
with self._mark_language_model(vllm_config):
|
||||
self.language_model = init_vllm_registered_model(
|
||||
vllm_config=vllm_config,
|
||||
hf_config=config.text_config,
|
||||
prefix=maybe_prefix(prefix, "language_model"),
|
||||
)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors
|
||||
|
||||
Reference in New Issue
Block a user