[1/N] Initialize MM components in context managers (A-D) (#32632)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-20 14:12:42 +08:00
committed by GitHub
parent 4753f3bf69
commit b75e85dede
11 changed files with 240 additions and 268 deletions

View File

@@ -15,9 +15,7 @@ from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
@@ -539,30 +537,22 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
quant_config = vllm_config.quant_config
self.config = config
self.vision_tower = AriaVisionTransformer(
config.vision_config,
quant_config=quant_config,
prefix=f"{prefix}.vision_tower",
)
self.multi_modal_projector = AriaProjector(
config, prefix=maybe_prefix(prefix, "multi_modal_projector")
)
self.vocab_size = config.text_config.vocab_size
self.language_model = AriaTextModel(
vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=maybe_prefix(prefix, "language_model.model"),
)
self.pad_token_id = (
self.config.pad_token_id if self.config.pad_token_id is not None else -1
)
self.lm_head = ParallelLMHead(
self.vocab_size,
config.text_config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.vocab_size, scale=logit_scale)
with self._mark_tower_model(vllm_config, "image"):
self.vision_tower = AriaVisionTransformer(
config.vision_config,
quant_config=quant_config,
prefix=f"{prefix}.vision_tower",
)
self.multi_modal_projector = AriaProjector(
config, prefix=maybe_prefix(prefix, "multi_modal_projector")
)
with self._mark_language_model(vllm_config):
self.language_model = AriaTextModel(
vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=maybe_prefix(prefix, "language_model.model"),
)
def _parse_and_validate_image_input(
self, **kwargs: object
@@ -618,9 +608,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
return self.multi_modal_projector(image_outputs, image_attn_mask)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
@@ -654,9 +641,11 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)