[BUGFIX] Fix accuracy bugs in Qwen3-Next MTP (#34077)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
This commit is contained in:
@@ -208,7 +208,9 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
|
||||
non_spec_query_lens = query_lens[~spec_sequence_masks]
|
||||
num_decodes = (non_spec_query_lens == 1).sum().item()
|
||||
num_prefills = non_spec_query_lens.size(0) - num_decodes
|
||||
# Exclude zero-length padded sequences from prefill count.
|
||||
num_zero_len = (non_spec_query_lens == 0).sum().item()
|
||||
num_prefills = non_spec_query_lens.size(0) - num_decodes - num_zero_len
|
||||
num_decode_tokens = num_decodes
|
||||
num_prefill_tokens = non_spec_query_lens.sum().item() - num_decode_tokens
|
||||
num_spec_decode_tokens = (
|
||||
@@ -228,9 +230,15 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
non_spec_token_indx = torch.empty(
|
||||
0, dtype=torch.int32, device=query_start_loc.device
|
||||
)
|
||||
spec_state_indices_tensor = block_table_tensor[:, : self.num_spec + 1]
|
||||
# Filter by spec_sequence_masks to exclude padded sequences
|
||||
spec_state_indices_tensor = block_table_tensor[
|
||||
spec_sequence_masks, : self.num_spec + 1
|
||||
]
|
||||
non_spec_state_indices_tensor = None
|
||||
spec_query_start_loc = query_start_loc
|
||||
# Padded sequences are always at the back, so the first
|
||||
# num_spec_decodes + 1 entries of query_start_loc already
|
||||
# contain the correct cumulative token counts.
|
||||
spec_query_start_loc = query_start_loc[: num_spec_decodes + 1]
|
||||
non_spec_query_start_loc = None
|
||||
non_spec_query_start_loc_cpu = None
|
||||
else:
|
||||
@@ -294,6 +302,12 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
else:
|
||||
has_initial_state = None
|
||||
|
||||
# Function code counted on either presency non-spec decode or spec decode,
|
||||
# but not both.
|
||||
assert not (num_decodes > 0 and num_spec_decodes > 0), (
|
||||
f"num_decodes: {num_decodes}, num_spec_decodes: {num_spec_decodes}"
|
||||
)
|
||||
|
||||
# Prepare tensors for cudagraph
|
||||
# Note: m.num_actual_tokens is already padded by the model runner for CUDAGraph
|
||||
batch_size = m.num_actual_tokens
|
||||
@@ -312,7 +326,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
||||
spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID)
|
||||
|
||||
self.spec_sequence_masks[:num_spec_decodes].copy_(
|
||||
spec_sequence_masks, non_blocking=True
|
||||
spec_sequence_masks[:num_spec_decodes], non_blocking=True
|
||||
)
|
||||
spec_sequence_masks = self.spec_sequence_masks[:batch_size]
|
||||
spec_sequence_masks[num_spec_decodes:].fill_(False)
|
||||
|
||||
Reference in New Issue
Block a user