[V1][Spec Decode] EAGLE-3 Support (#16937)
Signed-off-by: Bryan Lu <yuzhelu@amazon.com> Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai> Co-authored-by: Bryan Lu <yuzhelu@amazon.com>
This commit is contained in:
committed by
GitHub
parent
70116459c3
commit
a0e619e62a
232
vllm/model_executor/models/llama_eagle3.py
Normal file
232
vllm/model_executor/models/llama_eagle3.py
Normal file
@@ -0,0 +1,232 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Iterable, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import LlamaConfig
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import QKVParallelLinear
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.llama import (LlamaDecoderLayer,
|
||||
LlamaForCausalLM)
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
|
||||
from .utils import AutoWeightsLoader, maybe_prefix
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class LlamaDecoderLayer(LlamaDecoderLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(config, quant_config=quant_config, prefix=prefix)
|
||||
|
||||
# override qkv
|
||||
self.self_attn.qkv_proj = QKVParallelLinear(
|
||||
2 * self.hidden_size,
|
||||
self.self_attn.head_dim,
|
||||
self.self_attn.total_num_heads,
|
||||
self.self_attn.total_num_kv_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "qkv_proj"),
|
||||
)
|
||||
|
||||
self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
embeds: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
residual = hidden_states
|
||||
embeds = self.input_layernorm(embeds)
|
||||
hidden_states = self.hidden_norm(hidden_states)
|
||||
|
||||
hidden_states = torch.cat([embeds, hidden_states], dim=-1)
|
||||
# Self Attention
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
class LlamaModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model_config: ModelConfig,
|
||||
start_layer_id: int = 0,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = model_config.hf_config
|
||||
self.vocab_size = self.config.vocab_size
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
self.config.vocab_size,
|
||||
self.config.hidden_size,
|
||||
prefix=maybe_prefix(prefix, "embed_tokens"),
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
LlamaDecoderLayer(
|
||||
self.config,
|
||||
prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
|
||||
)
|
||||
])
|
||||
if hasattr(self.config, "target_hidden_size"):
|
||||
self.fc = torch.nn.Linear(self.config.target_hidden_size * 3,
|
||||
self.config.hidden_size,
|
||||
bias=False)
|
||||
else:
|
||||
self.fc = torch.nn.Linear(self.config.hidden_size * 3,
|
||||
self.config.hidden_size,
|
||||
bias=False)
|
||||
self.norm = RMSNorm(
|
||||
self.config.hidden_size,
|
||||
eps=self.config.rms_norm_eps,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
input_embeds = self.embed_tokens(input_ids)
|
||||
if (hidden_states.shape[-1] != input_embeds.shape[-1]):
|
||||
hidden_states = self.fc(hidden_states)
|
||||
|
||||
residual = None
|
||||
hidden_states, residual = self.layers[0](
|
||||
positions,
|
||||
input_embeds,
|
||||
hidden_states,
|
||||
residual,
|
||||
)
|
||||
|
||||
hidden_states, hidden_prenorm = self.norm(hidden_states, residual)
|
||||
return hidden_states, hidden_prenorm
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
(".qkv_proj", ".k_proj", "k"),
|
||||
(".qkv_proj", ".v_proj", "v"),
|
||||
(".gate_up_proj", ".gate_proj", 0),
|
||||
(".gate_up_proj", ".up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if 'midlayer.' in name:
|
||||
name = name.replace('midlayer.', 'layers.0.')
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class Eagle3LlamaForCausalLM(LlamaForCausalLM):
|
||||
|
||||
def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0):
|
||||
nn.Module.__init__(self)
|
||||
self.config = model_config.hf_config
|
||||
self.model = LlamaModel(model_config=model_config,
|
||||
start_layer_id=start_layer_id,
|
||||
prefix="model")
|
||||
|
||||
logit_scale = getattr(self.config, "logit_scale", 1.0)
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.config.draft_vocab_size,
|
||||
self.config.hidden_size,
|
||||
org_num_embeddings=self.config.draft_vocab_size,
|
||||
padding_size=(DEFAULT_VOCAB_PADDING_SIZE),
|
||||
prefix="")
|
||||
self.logits_processor = LogitsProcessor(self.config.draft_vocab_size,
|
||||
scale=logit_scale)
|
||||
self.draft_id_to_target_id = nn.Parameter(
|
||||
torch.zeros((self.config.draft_vocab_size),
|
||||
dtype=torch.long).type(torch.LongTensor),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return self.model(input_ids, positions, hidden_states)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
base = torch.arange(self.config.draft_vocab_size, device=logits.device)
|
||||
targets = base + self.draft_id_to_target_id
|
||||
logits_new = logits.new_full((
|
||||
logits.shape[0],
|
||||
self.config.vocab_size,
|
||||
), float('-inf'))
|
||||
logits_new[:, targets] = logits
|
||||
return logits_new
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=None,
|
||||
)
|
||||
|
||||
model_weights = {}
|
||||
for name, loaded_weight in weights:
|
||||
if "t2d" in name:
|
||||
continue
|
||||
if "d2t" in name:
|
||||
name = name.replace("d2t", "draft_id_to_target_id")
|
||||
elif "lm_head" not in name:
|
||||
name = "model." + name
|
||||
model_weights[name] = loaded_weight
|
||||
|
||||
return loader.load_weights(model_weights.items())
|
||||
Reference in New Issue
Block a user