[Refactor] Consolidate SupportsEagle (#36063)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
committed by GitHub
parent 54a6db827f
commit 8b346309a5
@@ -13,15 +13,15 @@ import torch
 import torch.nn as nn
 
 from vllm.config import VllmConfig
+from vllm.model_executor.models.interfaces import EagleModelMixin
 from vllm.model_executor.models.llama import LlamaForCausalLM
 from vllm.sequence import IntermediateTensors
 
 
-class PredictableLlamaModel(nn.Module):
+class PredictableLlamaModel(nn.Module, EagleModelMixin):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         self.config = vllm_config.model_config.hf_config
-        self.aux_hidden_state_layers = tuple[int, ...]()
 
         # Create minimal embed_tokens for embedding
         from vllm.model_executor.layers.vocab_parallel_embedding import (
Reference in New Issue
Block a user