[V1][Spec Decode] Apply torch.compile & cudagraph to EAGLE (#17211)
Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
This commit is contained in:
@@ -6,7 +6,8 @@ import torch
|
||||
import torch.nn as nn
|
||||
from transformers import LlamaConfig
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@@ -37,17 +38,19 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
|
||||
self.input_layernorm = nn.Identity()
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class LlamaModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model_config: ModelConfig,
|
||||
start_layer_id: int = 0,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
start_layer_id: int = 0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = model_config.hf_config
|
||||
self.config = vllm_config. \
|
||||
speculative_config.draft_model_config.hf_config
|
||||
self.vocab_size = self.config.vocab_size
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
self.config.vocab_size,
|
||||
@@ -75,8 +78,7 @@ class LlamaModel(nn.Module):
|
||||
hidden_states = self.fc(
|
||||
torch.cat((input_embeds, hidden_states), dim=-1))
|
||||
residual = None
|
||||
for i in range(len(self.layers)):
|
||||
layer = self.layers[i]
|
||||
for layer in self.layers:
|
||||
hidden_states, residual = layer(
|
||||
positions,
|
||||
hidden_states,
|
||||
@@ -117,12 +119,13 @@ class LlamaModel(nn.Module):
|
||||
|
||||
class EagleLlamaForCausalLM(LlamaForCausalLM):
|
||||
|
||||
def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0):
|
||||
def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
|
||||
nn.Module.__init__(self)
|
||||
self.config = model_config.hf_config
|
||||
self.model = LlamaModel(model_config=model_config,
|
||||
start_layer_id=start_layer_id,
|
||||
prefix="model")
|
||||
self.config = vllm_config. \
|
||||
speculative_config.draft_model_config.hf_config
|
||||
self.model = LlamaModel(vllm_config=vllm_config,
|
||||
prefix="model",
|
||||
start_layer_id=start_layer_id)
|
||||
|
||||
logit_scale = getattr(self.config, "logit_scale", 1.0)
|
||||
self.logits_processor = LogitsProcessor(self.config.vocab_size,
|
||||
|
||||
Reference in New Issue
Block a user