[1/N] Initialize MM components in context managers (A-D) (#32632)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -15,9 +15,7 @@ from vllm.distributed import get_tensor_model_parallel_rank
|
||||
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
@@ -539,30 +537,22 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.config = config
|
||||
-        self.vision_tower = AriaVisionTransformer(
-            config.vision_config,
-            quant_config=quant_config,
-            prefix=f"{prefix}.vision_tower",
-        )
-        self.multi_modal_projector = AriaProjector(
-            config, prefix=maybe_prefix(prefix, "multi_modal_projector")
-        )
-        self.vocab_size = config.text_config.vocab_size
-        self.language_model = AriaTextModel(
-            vllm_config=vllm_config.with_hf_config(config.text_config),
-            prefix=maybe_prefix(prefix, "language_model.model"),
-        )
-        self.pad_token_id = (
-            self.config.pad_token_id if self.config.pad_token_id is not None else -1
-        )
-        self.lm_head = ParallelLMHead(
-            self.vocab_size,
-            config.text_config.hidden_size,
-            quant_config=quant_config,
-            prefix=maybe_prefix(prefix, "lm_head"),
-        )
-        logit_scale = getattr(config, "logit_scale", 1.0)
-        self.logits_processor = LogitsProcessor(self.vocab_size, scale=logit_scale)
+
+        with self._mark_tower_model(vllm_config, "image"):
+            self.vision_tower = AriaVisionTransformer(
+                config.vision_config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.vision_tower",
+            )
+            self.multi_modal_projector = AriaProjector(
+                config, prefix=maybe_prefix(prefix, "multi_modal_projector")
+            )
+
+        with self._mark_language_model(vllm_config):
+            self.language_model = AriaTextModel(
+                vllm_config=vllm_config.with_hf_config(config.text_config),
+                prefix=maybe_prefix(prefix, "language_model.model"),
+            )
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object
|
||||
@@ -618,9 +608,6 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
|
||||
return self.multi_modal_projector(image_outputs, image_attn_mask)
|
||||
|
||||
-    def get_language_model(self) -> torch.nn.Module:
-        return self.language_model
-
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
@@ -654,9 +641,11 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
|
||||
return hidden_states
|
||||
|
||||
-    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states)
-        return logits
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor | None:
+        return self.language_model.compute_logits(hidden_states)
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
loader = AutoWeightsLoader(self)
|
||||
|
||||
Reference in New Issue
Block a user