[V1][Spec Decode] Apply torch.compile & cudagraph to EAGLE (#17211)
Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
This commit is contained in:
@@ -6,7 +6,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from transformers import LlamaConfig
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import QKVParallelLinear
|
||||
@@ -167,8 +167,9 @@ class LlamaModel(nn.Module):
|
||||
|
||||
class Eagle3LlamaForCausalLM(LlamaForCausalLM):
|
||||
|
||||
def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0):
|
||||
def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
|
||||
nn.Module.__init__(self)
|
||||
model_config = vllm_config.speculative_config.draft_model_config
|
||||
self.config = model_config.hf_config
|
||||
self.model = LlamaModel(model_config=model_config,
|
||||
start_layer_id=start_layer_id,
|
||||
|
||||
Reference in New Issue
Block a user