[V1][Spec Decode] Apply torch.compile & cudagraph to EAGLE (#17211)

Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
This commit is contained in:
Bryan Lu
2025-04-29 14:10:00 -07:00
committed by GitHub
parent c9c1b59e59
commit 70788bdbdc
6 changed files with 152 additions and 53 deletions

View File

@@ -6,7 +6,7 @@ import torch
import torch.nn as nn
from transformers import LlamaConfig
from vllm.config import ModelConfig
from vllm.config import ModelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear
@@ -167,8 +167,9 @@ class LlamaModel(nn.Module):
class Eagle3LlamaForCausalLM(LlamaForCausalLM):
def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0):
def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
nn.Module.__init__(self)
model_config = vllm_config.speculative_config.draft_model_config
self.config = model_config.hf_config
self.model = LlamaModel(model_config=model_config,
start_layer_id=start_layer_id,