[V1][Spec Decode] EAGLE-3 Support (#16937)

Signed-off-by: Bryan Lu <yuzhelu@amazon.com>
Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai>
Co-authored-by: Bryan Lu <yuzhelu@amazon.com>
This commit is contained in:
Benjamin Chislett
2025-04-25 18:43:07 -04:00
committed by GitHub
parent 70116459c3
commit a0e619e62a
12 changed files with 358 additions and 34 deletions

View File

@@ -330,6 +330,8 @@ class LlamaModel(nn.Module):
else:
self.norm = PPMissingLayer()
self.aux_hidden_state_layers: tuple[int] = tuple()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
@@ -355,7 +357,11 @@ class LlamaModel(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
aux_hidden_states = []
for idx, layer in enumerate(
self.layers[self.start_layer:self.end_layer]):
if idx in self.aux_hidden_state_layers:
aux_hidden_states.append(hidden_states + residual)
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
@@ -365,6 +371,9 @@ class LlamaModel(nn.Module):
})
hidden_states, _ = self.norm(hidden_states, residual)
if len(aux_hidden_states) > 0:
return hidden_states, aux_hidden_states
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
@@ -517,6 +526,13 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None:
self.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)
def _init_model(self,
vllm_config: VllmConfig,
prefix: str = "",