[Speculators][Speculative Decoding] Add Qwen Eagle3 Support (#21835)

Signed-off-by: Dipika Sikka <dipikasikka1@gmail.com>
This commit is contained in:
Dipika Sikka
2025-08-01 22:43:37 -04:00
committed by GitHub
parent a65f46be5e
commit 9f9c38c392
4 changed files with 46 additions and 11 deletions

View File

@@ -330,6 +330,8 @@ class Qwen2Model(nn.Module):
else:
self.norm = PPMissingLayer()
self.aux_hidden_state_layers: tuple[int] = tuple()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
@@ -350,18 +352,25 @@ class Qwen2Model(nn.Module):
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
hidden_states, residual = layer(
positions,
hidden_states,
residual,
)
aux_hidden_states = []
for idx, layer in enumerate(
self.layers[self.start_layer:self.end_layer]):
if idx in self.aux_hidden_state_layers:
aux_hidden_states.append(hidden_states + residual)
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
if len(aux_hidden_states) > 0:
return hidden_states, aux_hidden_states
return hidden_states
def load_weights(self, weights: Iterable[tuple[str,