[V1][Spec Decode] Apply torch.compile & cudagraph to EAGLE3 (#17504)

Signed-off-by: qizixi <qizixi@meta.com>
This commit is contained in:
qizixi
2025-05-01 16:19:30 -07:00
committed by GitHub
parent 9b70e2b4c1
commit 39c0813a7f
2 changed files with 36 additions and 31 deletions

View File

@@ -6,7 +6,8 @@ import torch
import torch.nn as nn
from transformers import LlamaConfig
from vllm.config import ModelConfig, VllmConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear
@@ -76,17 +77,19 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
return hidden_states, residual
@support_torch_compile
class LlamaModel(nn.Module):
def __init__(
self,
*,
model_config: ModelConfig,
vllm_config: VllmConfig,
start_layer_id: int = 0,
prefix: str = "",
) -> None:
super().__init__()
self.config = model_config.hf_config
self.config = vllm_config. \
speculative_config.draft_model_config.hf_config
self.vocab_size = self.config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.config.vocab_size,
@@ -119,8 +122,7 @@ class LlamaModel(nn.Module):
hidden_states: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
input_embeds = self.embed_tokens(input_ids)
if (hidden_states.shape[-1] != input_embeds.shape[-1]):
hidden_states = self.fc(hidden_states)
assert hidden_states.shape[-1] == input_embeds.shape[-1]
residual = None
hidden_states, residual = self.layers[0](
@@ -169,9 +171,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
nn.Module.__init__(self)
model_config = vllm_config.speculative_config.draft_model_config
self.config = model_config.hf_config
self.model = LlamaModel(model_config=model_config,
self.config = vllm_config. \
speculative_config.draft_model_config.hf_config
self.model = LlamaModel(vllm_config=vllm_config,
start_layer_id=start_layer_id,
prefix="model")
@@ -214,6 +216,13 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
logits_new[:, targets] = logits
return logits_new
def combine_hidden_states(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
# combine multiple auxiliary hidden states returned by eagle3
return self.model.fc(hidden_states)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(
self,