[BUGFIX] Fix accuracy bugs in Qwen3-Next MTP (#34077)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
This commit is contained in:
Vadim Gimpelson
2026-02-10 19:57:11 +04:00
committed by GitHub
parent c5a66d1697
commit 000214c4bb

View File

@@ -208,7 +208,9 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
non_spec_query_lens = query_lens[~spec_sequence_masks]
num_decodes = (non_spec_query_lens == 1).sum().item()
num_prefills = non_spec_query_lens.size(0) - num_decodes
# Exclude zero-length padded sequences from prefill count.
num_zero_len = (non_spec_query_lens == 0).sum().item()
num_prefills = non_spec_query_lens.size(0) - num_decodes - num_zero_len
num_decode_tokens = num_decodes
num_prefill_tokens = non_spec_query_lens.sum().item() - num_decode_tokens
num_spec_decode_tokens = (
@@ -228,9 +230,15 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
non_spec_token_indx = torch.empty(
0, dtype=torch.int32, device=query_start_loc.device
)
spec_state_indices_tensor = block_table_tensor[:, : self.num_spec + 1]
# Filter by spec_sequence_masks to exclude padded sequences
spec_state_indices_tensor = block_table_tensor[
spec_sequence_masks, : self.num_spec + 1
]
non_spec_state_indices_tensor = None
spec_query_start_loc = query_start_loc
# Padded sequences are always at the back, so the first
# num_spec_decodes + 1 entries of query_start_loc already
# contain the correct cumulative token counts.
spec_query_start_loc = query_start_loc[: num_spec_decodes + 1]
non_spec_query_start_loc = None
non_spec_query_start_loc_cpu = None
else:
@@ -294,6 +302,12 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
else:
has_initial_state = None
# Function code counted on either presency non-spec decode or spec decode,
# but not both.
assert not (num_decodes > 0 and num_spec_decodes > 0), (
f"num_decodes: {num_decodes}, num_spec_decodes: {num_spec_decodes}"
)
# Prepare tensors for cudagraph
# Note: m.num_actual_tokens is already padded by the model runner for CUDAGraph
batch_size = m.num_actual_tokens
@@ -312,7 +326,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID)
self.spec_sequence_masks[:num_spec_decodes].copy_(
spec_sequence_masks, non_blocking=True
spec_sequence_masks[:num_spec_decodes], non_blocking=True
)
spec_sequence_masks = self.spec_sequence_masks[:batch_size]
spec_sequence_masks[num_spec_decodes:].fill_(False)