diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py
index f62d793ef..2c192c7d9 100644
--- a/vllm/model_executor/models/aria.py
+++ b/vllm/model_executor/models/aria.py
@@ -15,7 +15,9 @@ from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
@@ -554,6 +556,18 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
             prefix=maybe_prefix(prefix, "language_model.model"),
         )
 
+        self.lm_head = ParallelLMHead(
+            config.text_config.vocab_size,
+            config.text_config.hidden_size,
+            quant_config=quant_config,
+            prefix=maybe_prefix(prefix, "lm_head"),
+        )
+
+        logit_scale = getattr(config, "logit_scale", 1.0)
+        self.logits_processor = LogitsProcessor(
+            config.text_config.vocab_size, scale=logit_scale
+        )
+
     def _parse_and_validate_image_input(
         self, **kwargs: object
     ) -> AriaImagePixelInputs | None:
@@ -637,7 +651,8 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
         self,
         hidden_states: torch.Tensor,
     ) -> torch.Tensor | None:
-        return self.language_model.compute_logits(hidden_states)
+        logits = self.logits_processor(self.lm_head, hidden_states)
+        return logits
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(self)