[Refactor] Consolidate SupportsEagle (#36063)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Benjamin Chislett
2026-03-13 19:22:40 -04:00
committed by GitHub
parent 54a6db827f
commit 8b346309a5
24 changed files with 229 additions and 235 deletions

@@ -61,6 +61,7 @@ from vllm.v1.attention.backend import AttentionType
 
 from .adapters import as_embedding_model, as_seq_cls_model
 from .interfaces import (
+    EagleModelMixin,
     SupportsEagle,
     SupportsEagle3,
     SupportsLoRA,
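
For orientation, here is a rough sketch of the interface surface that the new EagleModelMixin appears to consolidate, reconstructed from the methods this commit deletes from LlamaModel and LlamaForCausalLM further down. The actual definition lives in .interfaces and may differ; the attribute default and method bodies below are assumptions based on the removed code, and placing them on the inner model (so that self.layers is available) is likewise an assumption. The _maybe_add_hidden_state helper used in the forward pass is sketched separately after the forward-loop hunk.

# Hedged sketch only: reconstructed from the code removed in this diff,
# not the real EagleModelMixin from .interfaces.
class EagleModelMixin:
    # No auxiliary hidden states are collected unless layers are configured.
    aux_hidden_state_layers: tuple[int, ...] = ()

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        # Record which decoder-layer inputs should be captured for EAGLE-3.
        self.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        # Default selection (an early, a middle, and a late layer), matching
        # the removed LlamaForCausalLM override. Assumes the host model
        # provides self.layers; the GPU model runner may replace this with
        # layers taken from the speculative config.
        num_layers = len(self.layers)
        return (2, num_layers // 2, num_layers - 3)
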
@@ -351,7 +352,7 @@ def llama_model_invariants(
     # mark_unbacked_dims={"input_ids": 0},
     shape_invariants=llama_model_invariants
 )
-class LlamaModel(nn.Module):
+class LlamaModel(nn.Module, EagleModelMixin):
     def __init__(
         self,
         *,
@@ -389,8 +390,6 @@ class LlamaModel(nn.Module):
         else:
             self.norm = PPMissingLayer()
 
-        self.aux_hidden_state_layers = tuple[int, ...]()
-
         self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
             ["hidden_states", "residual"], config.hidden_size
         )
@@ -417,15 +416,16 @@ class LlamaModel(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        aux_hidden_states = []
+        aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
         for idx, layer in enumerate(
             islice(self.layers, self.start_layer, self.end_layer)
         ):
-            if idx in self.aux_hidden_state_layers:
-                aux_hidden_states.append(hidden_states + residual)
             hidden_states, residual = layer(
                 positions, hidden_states, residual, **extra_layer_kwargs
             )
+            self._maybe_add_hidden_state(
+                aux_hidden_states, idx + 1, hidden_states, residual
+            )
 
         if not get_pp_group().is_last_rank:
             return IntermediateTensors(
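
The helper call above replaces the inline membership check that used to live inside the loop. A minimal sketch of what _maybe_add_hidden_state presumably does, keeping the removed semantics (capture hidden_states + residual whenever the index is one of the configured auxiliary layers, and hand the list back so the pre-loop call can create it); the real mixin method may differ in signature details:

    def _maybe_add_hidden_state(
        self,
        aux_hidden_states: list,
        idx: int,
        hidden_states,
        residual,
    ) -> list:
        # Sketch reconstructed from the removed inline code in this diff.
        if idx in self.aux_hidden_state_layers:
            aux_hidden_states.append(hidden_states + residual)
        return aux_hidden_states

Calling it once with ([], 0, ...) before the loop and with idx + 1 after each layer visits exactly the same pre-layer indices that the old code checked at the top of each iteration.
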
@@ -556,18 +556,6 @@ class LlamaForCausalLM(
             self.model.make_empty_intermediate_tensors
         )
-
-    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
-        self.model.aux_hidden_state_layers = layers
-
-    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
-        """Override to return default layers for Llama
-
-        Note: The GPU model runner will override this with layers from
-        the speculative config if available, providing dynamic configuration.
-        """
-        num_layers = len(self.model.layers)
-        return (2, num_layers // 2, num_layers - 3)
 
     def _init_model(
         self,
         vllm_config: VllmConfig,
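
With the per-model overrides gone, any architecture that mixes in EagleModelMixin exposes the same two hooks. Below is a hedged usage sketch of how a caller such as the GPU model runner might wire them up, following the note in the removed docstring; configure_eagle3_aux_layers and the eagle_aux_layers attribute are illustrative names, not part of this diff.

def configure_eagle3_aux_layers(model, speculative_config=None):
    # Prefer explicit layer indices from the speculative config when present,
    # otherwise fall back to the model's default selection. "model" is any
    # object mixing in EagleModelMixin (e.g. the inner LlamaModel).
    if speculative_config is not None and getattr(speculative_config, "eagle_aux_layers", None):
        layers = tuple(speculative_config.eagle_aux_layers)
    else:
        layers = model.get_eagle3_aux_hidden_state_layers()
    model.set_aux_hidden_state_layers(layers)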