[V1][Spec Decode] Apply torch.compile & cudagraph to EAGLE (#17211)

Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
This commit is contained in:
Bryan Lu
2025-04-29 14:10:00 -07:00
committed by GitHub
parent c9c1b59e59
commit 70788bdbdc
6 changed files with 152 additions and 53 deletions

View File

@@ -6,7 +6,8 @@ import torch
import torch.nn as nn
from transformers import LlamaConfig
from vllm.config import ModelConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -37,17 +38,19 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
self.input_layernorm = nn.Identity()
@support_torch_compile
class LlamaModel(nn.Module):
def __init__(
self,
*,
model_config: ModelConfig,
start_layer_id: int = 0,
vllm_config: VllmConfig,
prefix: str = "",
start_layer_id: int = 0,
) -> None:
super().__init__()
self.config = model_config.hf_config
self.config = vllm_config. \
speculative_config.draft_model_config.hf_config
self.vocab_size = self.config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.config.vocab_size,
@@ -75,8 +78,7 @@ class LlamaModel(nn.Module):
hidden_states = self.fc(
torch.cat((input_embeds, hidden_states), dim=-1))
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
for layer in self.layers:
hidden_states, residual = layer(
positions,
hidden_states,
@@ -117,12 +119,13 @@ class LlamaModel(nn.Module):
class EagleLlamaForCausalLM(LlamaForCausalLM):
def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0):
def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
nn.Module.__init__(self)
self.config = model_config.hf_config
self.model = LlamaModel(model_config=model_config,
start_layer_id=start_layer_id,
prefix="model")
self.config = vllm_config. \
speculative_config.draft_model_config.hf_config
self.model = LlamaModel(vllm_config=vllm_config,
prefix="model",
start_layer_id=start_layer_id)
logit_scale = getattr(self.config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.config.vocab_size,