[V1][Spec Decode] EAGLE-3 Support (#16937)
Signed-off-by: Bryan Lu <yuzhelu@amazon.com> Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai> Co-authored-by: Bryan Lu <yuzhelu@amazon.com>
This commit is contained in:
committed by
GitHub
parent
70116459c3
commit
a0e619e62a
@@ -330,6 +330,8 @@ class LlamaModel(nn.Module):
|
||||
else:
|
||||
self.norm = PPMissingLayer()
|
||||
|
||||
self.aux_hidden_state_layers: tuple[int] = tuple()
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size))
|
||||
@@ -355,7 +357,11 @@ class LlamaModel(nn.Module):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for layer in self.layers[self.start_layer:self.end_layer]:
|
||||
aux_hidden_states = []
|
||||
for idx, layer in enumerate(
|
||||
self.layers[self.start_layer:self.end_layer]):
|
||||
if idx in self.aux_hidden_state_layers:
|
||||
aux_hidden_states.append(hidden_states + residual)
|
||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
@@ -365,6 +371,9 @@ class LlamaModel(nn.Module):
|
||||
})
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
if len(aux_hidden_states) > 0:
|
||||
return hidden_states, aux_hidden_states
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
@@ -517,6 +526,13 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None:
|
||||
self.model.aux_hidden_state_layers = layers
|
||||
|
||||
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]:
|
||||
num_layers = len(self.model.layers)
|
||||
return (2, num_layers // 2, num_layers - 3)
|
||||
|
||||
def _init_model(self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
|
||||
Reference in New Issue
Block a user