[Multimodal][Speculative Decoding]Eagle Eagle3 mm support, enablement on qwen2.5vl (#22872)
Signed-off-by: Junhong <liujunhong11@huawei.com> Signed-off-by: Junhong Liu <98734602+LJH-LBJ@users.noreply.github.com> Co-authored-by: Junhong <liujunhong11@huawei.com> Co-authored-by: LJH-LBJ <98734602+LJH-LBJ@users.noreply.github.com>
This commit is contained in:
@@ -8,7 +8,6 @@ import torch
|
||||
import torch.nn as nn
|
||||
from transformers import LlamaConfig
|
||||
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
@@ -19,6 +18,7 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
|
||||
from vllm.model_executor.models.llama import (LlamaDecoderLayer,
|
||||
LlamaForCausalLM)
|
||||
|
||||
@@ -102,7 +102,6 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class LlamaModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
@@ -145,13 +144,21 @@ class LlamaModel(nn.Module):
|
||||
eps=self.config.rms_norm_eps,
|
||||
)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
input_embeds: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
input_embeds = self.embed_tokens(input_ids)
|
||||
if input_embeds is None:
|
||||
input_embeds = self.get_input_embeddings(input_ids)
|
||||
assert hidden_states.shape[-1] == input_embeds.shape[-1]
|
||||
|
||||
residual = None
|
||||
@@ -239,11 +246,7 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
|
||||
hidden_states: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if inputs_embeds is not None:
|
||||
raise NotImplementedError(
|
||||
f"{type(self).__name__} does not support multimodal inputs yet."
|
||||
)
|
||||
return self.model(input_ids, positions, hidden_states)
|
||||
return self.model(input_ids, positions, hidden_states, inputs_embeds)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
@@ -299,3 +302,11 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
|
||||
skip_substrs=skip_substrs,
|
||||
)
|
||||
loader.load_weights(model_weights.items())
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.model.get_input_embeddings(input_ids)
|
||||
return inputs_embeds
|
||||
|
||||
Reference in New Issue
Block a user