[bugfix] Aria model (#32727)

Signed-off-by: Divakar Verma <divakar.verma@amd.com>
This commit is contained in:
Divakar Verma
2026-01-21 07:11:31 -06:00
committed by GitHub
parent 7727ce35c2
commit e14467be43

View File

@@ -15,7 +15,9 @@ from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
@@ -554,6 +556,18 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
prefix=maybe_prefix(prefix, "language_model.model"),
)
self.lm_head = ParallelLMHead(
config.text_config.vocab_size,
config.text_config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(
config.text_config.vocab_size, scale=logit_scale
)
def _parse_and_validate_image_input(
self, **kwargs: object
) -> AriaImagePixelInputs | None:
@@ -637,7 +651,8 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states)
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)