From e14467be432bf55493c21ffd9f7d1d4c32e14e19 Mon Sep 17 00:00:00 2001 From: Divakar Verma <137818590+divakar-amd@users.noreply.github.com> Date: Wed, 21 Jan 2026 07:11:31 -0600 Subject: [PATCH] [bugfix] Aria model (#32727) Signed-off-by: Divakar Verma --- vllm/model_executor/models/aria.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index f62d793ef..2c192c7d9 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -15,7 +15,9 @@ from vllm.distributed import get_tensor_model_parallel_rank from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name, @@ -554,6 +556,18 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): prefix=maybe_prefix(prefix, "language_model.model"), ) + self.lm_head = ParallelLMHead( + config.text_config.vocab_size, + config.text_config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor( + config.text_config.vocab_size, scale=logit_scale + ) + def _parse_and_validate_image_input( self, **kwargs: object ) -> AriaImagePixelInputs | None: @@ -637,7 +651,8 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): self, hidden_states: torch.Tensor, ) -> torch.Tensor | None: - return self.language_model.compute_logits(hidden_states) + logits = self.logits_processor(self.lm_head, hidden_states) + return logits def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self)