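[bugfix] Aria model: compute logits with the model's own lm_head and LogitsProcessor instead of delegating to language_model.compute_logits (#32727)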
Signed-off-by: Divakar Verma <divakar.verma@amd.com>
This commit is contained in:
@@ -15,7 +15,9 @@ from vllm.distributed import get_tensor_model_parallel_rank
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
maybe_remap_kv_scale_name,
|
||||
@@ -554,6 +556,18 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
prefix=maybe_prefix(prefix, "language_model.model"),
|
||||
)
|
||||
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.text_config.vocab_size,
|
||||
config.text_config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
|
||||
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||
self.logits_processor = LogitsProcessor(
|
||||
config.text_config.vocab_size, scale=logit_scale
|
||||
)
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object
|
||||
) -> AriaImagePixelInputs | None:
|
||||
@@ -637,7 +651,8 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor | None:
|
||||
return self.language_model.compute_logits(hidden_states)
|
||||
logits = self.logits_processor(self.lm_head, hidden_states)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
loader = AutoWeightsLoader(self)
|
||||
|
||||
Reference in New Issue
Block a user