[V1][Spec Decode] Apply torch.compile & cudagraph to EAGLE3 (#17504)
Signed-off-by: qizixi <qizixi@meta.com>
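For context, this commit routes the EAGLE3 draft model through vLLM's torch.compile path so its forward can be compiled and captured into CUDA graphs instead of always running eagerly. A minimal usage sketch of EAGLE3 speculative decoding follows; the checkpoint names and the exact `speculative_config` keys are illustrative assumptions, not part of this commit:

```python
from vllm import LLM, SamplingParams

# Illustrative only: a target model plus an EAGLE3 draft head.
# Checkpoint names and config keys below are assumptions for this sketch.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    speculative_config={
        "method": "eagle3",
        "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
        "num_speculative_tokens": 3,
    },
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```

The intent of the change below is that the draft model's forward goes through the same compile / CUDA-graph machinery as the target model.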
@@ -6,7 +6,8 @@ import torch
 import torch.nn as nn
 from transformers import LlamaConfig
 
-from vllm.config import ModelConfig, VllmConfig
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import QKVParallelLinear
@@ -76,17 +77,19 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
         return hidden_states, residual
 
 
+@support_torch_compile
 class LlamaModel(nn.Module):
 
     def __init__(
         self,
         *,
-        model_config: ModelConfig,
+        vllm_config: VllmConfig,
         start_layer_id: int = 0,
         prefix: str = "",
     ) -> None:
         super().__init__()
-        self.config = model_config.hf_config
+        self.config = vllm_config. \
+            speculative_config.draft_model_config.hf_config
         self.vocab_size = self.config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(
             self.config.vocab_size,
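Note on the constructor change in this hunk: vLLM's `support_torch_compile` decorator reads the compilation settings from the module's `VllmConfig`, so the decorated class is expected to accept `vllm_config` (and a `prefix`) in `__init__`; that is why `model_config: ModelConfig` becomes `vllm_config: VllmConfig` and the HF config is now pulled from `speculative_config.draft_model_config`. A minimal sketch of that constructor shape, using an illustrative toy module rather than the real draft model:

```python
import torch
import torch.nn as nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig


@support_torch_compile
class ToyDraftBlock(nn.Module):
    # Illustrative only: shows the (vllm_config, prefix) constructor shape
    # the decorator expects so it can wire up the compilation settings.

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()
        hf_config = vllm_config.speculative_config.draft_model_config.hf_config
        self.proj = nn.Linear(hf_config.hidden_size, hf_config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The decorated forward is what gets compiled / captured in CUDA graphs.
        return torch.relu(self.proj(hidden_states))
```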
@@ -119,8 +122,7 @@ class LlamaModel(nn.Module):
         hidden_states: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         input_embeds = self.embed_tokens(input_ids)
-        if (hidden_states.shape[-1] != input_embeds.shape[-1]):
-            hidden_states = self.fc(hidden_states)
+        assert hidden_states.shape[-1] == input_embeds.shape[-1]
 
         residual = None
         hidden_states, residual = self.layers[0](
@@ -169,9 +171,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
 
     def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
         nn.Module.__init__(self)
-        model_config = vllm_config.speculative_config.draft_model_config
-        self.config = model_config.hf_config
-        self.model = LlamaModel(model_config=model_config,
+        self.config = vllm_config. \
+            speculative_config.draft_model_config.hf_config
+        self.model = LlamaModel(vllm_config=vllm_config,
                                 start_layer_id=start_layer_id,
                                 prefix="model")
 
@@ -214,6 +216,13 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
             logits_new[:, targets] = logits
         return logits_new
 
+    def combine_hidden_states(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        # combine multiple auxiliary hidden states returned by eagle3
+        return self.model.fc(hidden_states)
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(
             self,
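The new `combine_hidden_states` helper moves the auxiliary-hidden-state projection out of `LlamaModel.forward` (the removed `if ... self.fc(...)` branch in the earlier hunk becomes an `assert`), so the compiled forward only ever sees tensors of the draft hidden size and the shape-dependent branch stays outside the torch.compile / CUDA-graph region. A caller-side sketch of how that might be used; the function and variable names here are illustrative assumptions, not taken from this diff:

```python
import torch
import torch.nn as nn


def propose_draft_tokens(draft: nn.Module, input_ids: torch.Tensor,
                         positions: torch.Tensor,
                         aux_hidden_states: list[torch.Tensor]):
    # Illustrative only: EAGLE3 collects several auxiliary hidden states from
    # the target model; they are concatenated and projected down once, before
    # entering the compiled draft forward.
    combined = torch.cat(aux_hidden_states, dim=-1)
    combined = draft.combine_hidden_states(combined)  # self.model.fc projection
    # Inside the compiled forward, the assert now only checks that the hidden
    # size already matches the embedding size.
    return draft(input_ids, positions, combined)
```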