[BugFix][Spec Decode] Fix out-of-range index triggered by eagle3; re-enable test for LlamaForCausalLMEagle3 (#24392)

Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
This commit is contained in:
Wenlong Wang
2025-09-09 21:24:23 -07:00
committed by GitHub
parent 309d7aa401
commit 53b42f4102
7 changed files with 58 additions and 41 deletions

View File

@@ -171,7 +171,22 @@ class LlamaAttention(nn.Module):
sliding_window = None
if layer_types := getattr(config, "layer_types", None):
is_sliding = layer_types[layer_idx] == "sliding_attention"
# Fix for Eagle3 compatibility:
# for draft models, subtract target layer count
# to get draft-relative layer index starting from 0
if hasattr(config, 'target_layer_count'):
# This is a draft model,
# adjust layer_idx to be relative to draft layers
effective_layer_idx = layer_idx - config.target_layer_count
else:
# This is a target model, use layer_idx directly
effective_layer_idx = layer_idx
assert effective_layer_idx < len(layer_types), \
f"effective_layer_idx: {effective_layer_idx} \
is out of bounds for layer_types: {layer_types}"
is_sliding = layer_types[
effective_layer_idx] == "sliding_attention"
if is_sliding:
sliding_window = config.sliding_window