[Meta] Official Eagle mm support, first enablement on llama4 (#20788)
Signed-off-by: morgendave <morgendave@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.me>
This commit is contained in:
@@ -256,6 +256,7 @@ class Llama4DecoderLayer(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.layer_idx = extract_layer_index(prefix)
|
||||
self.global_layer = config.no_rope_layers[self.layer_idx] == 0
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = config.rope_theta
|
||||
rope_scaling = config.rope_scaling
|
||||
|
||||
@@ -37,8 +37,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.llama4 import (Llama4DecoderLayer,
|
||||
Llama4ForCausalLM)
|
||||
from vllm.model_executor.models.utils import extract_layer_index
|
||||
from vllm.multimodal.inputs import NestedTensors
|
||||
|
||||
from .utils import AutoWeightsLoader, maybe_prefix
|
||||
from .utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -78,15 +79,23 @@ class LlamaModel(nn.Module):
|
||||
self.norm = RMSNorm(self.config.hidden_size,
|
||||
eps=self.config.rms_norm_eps)
|
||||
|
||||
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
) -> torch.Tensor:
    """Embed ``input_ids`` with this model's ``embed_tokens`` module.

    Args:
        input_ids: Token ids, passed unchanged to ``self.embed_tokens``.

    Returns:
        The result of ``self.embed_tokens(input_ids)``.
    """
    return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
input_embeds = self.embed_tokens(input_ids)
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings(input_ids)
|
||||
hidden_states = self.fc(
|
||||
torch.cat((input_embeds, hidden_states), dim=-1))
|
||||
torch.cat((inputs_embeds, hidden_states), dim=-1))
|
||||
residual = None
|
||||
for layer in self.layers:
|
||||
hidden_states, residual = layer(
|
||||
@@ -190,8 +199,9 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return self.model(input_ids, positions, hidden_states)
|
||||
return self.model(input_ids, positions, hidden_states, inputs_embeds)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> None:
|
||||
@@ -212,3 +222,20 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
|
||||
model_weights[name] = loaded_weight
|
||||
|
||||
loader.load_weights(model_weights.items())
|
||||
|
||||
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
    """Return input embeddings, optionally merged with multimodal ones.

    Text embeddings come from ``self.model.get_input_embeddings``. When
    multimodal embeddings are supplied, they are combined with the text
    embeddings by ``merge_multimodal_embeddings``, which is given
    ``self.config.image_token_index`` as its last argument.
    NOTE(review): presumably that index marks the image placeholder
    positions in ``input_ids`` -- confirm against the helper.

    Args:
        input_ids: Token ids for the draft model input.
        multimodal_embeddings: Embeddings to merge, or ``None`` / an empty
            container when there is nothing to merge.

    Returns:
        The text embeddings alone, or the merged embeddings.
    """
    inputs_embeds = self.model.get_input_embeddings(input_ids)

    # Skip the merge when there is nothing to merge. An empty container is
    # treated like ``None`` so it is never forwarded to the merge helper.
    if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
        return inputs_embeds

    return merge_multimodal_embeddings(
        input_ids,
        inputs_embeds,
        multimodal_embeddings,
        self.config.image_token_index,
    )
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -148,7 +149,12 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if inputs_embeds is not None:
|
||||
raise NotImplementedError(
|
||||
f"{type(self).__name__} does not support multimodal inputs yet."
|
||||
)
|
||||
return self.model(input_ids, positions, hidden_states)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
|
||||
@@ -202,7 +202,12 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if inputs_embeds is not None:
|
||||
raise NotImplementedError(
|
||||
f"{type(self).__name__} does not support multimodal inputs yet."
|
||||
)
|
||||
return self.model(input_ids, positions, hidden_states)
|
||||
|
||||
def compute_logits(
|
||||
|
||||
Reference in New Issue
Block a user