[BUGFIX] Fix accuracy bugs in Qwen3-Next MTP (#34077)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
This commit is contained in:
@@ -208,7 +208,9 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
|||||||
|
|
||||||
non_spec_query_lens = query_lens[~spec_sequence_masks]
|
non_spec_query_lens = query_lens[~spec_sequence_masks]
|
||||||
num_decodes = (non_spec_query_lens == 1).sum().item()
|
num_decodes = (non_spec_query_lens == 1).sum().item()
|
||||||
num_prefills = non_spec_query_lens.size(0) - num_decodes
|
# Exclude zero-length padded sequences from prefill count.
|
||||||
|
num_zero_len = (non_spec_query_lens == 0).sum().item()
|
||||||
|
num_prefills = non_spec_query_lens.size(0) - num_decodes - num_zero_len
|
||||||
num_decode_tokens = num_decodes
|
num_decode_tokens = num_decodes
|
||||||
num_prefill_tokens = non_spec_query_lens.sum().item() - num_decode_tokens
|
num_prefill_tokens = non_spec_query_lens.sum().item() - num_decode_tokens
|
||||||
num_spec_decode_tokens = (
|
num_spec_decode_tokens = (
|
||||||
@@ -228,9 +230,15 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
|||||||
non_spec_token_indx = torch.empty(
|
non_spec_token_indx = torch.empty(
|
||||||
0, dtype=torch.int32, device=query_start_loc.device
|
0, dtype=torch.int32, device=query_start_loc.device
|
||||||
)
|
)
|
||||||
spec_state_indices_tensor = block_table_tensor[:, : self.num_spec + 1]
|
# Filter by spec_sequence_masks to exclude padded sequences
|
||||||
|
spec_state_indices_tensor = block_table_tensor[
|
||||||
|
spec_sequence_masks, : self.num_spec + 1
|
||||||
|
]
|
||||||
non_spec_state_indices_tensor = None
|
non_spec_state_indices_tensor = None
|
||||||
spec_query_start_loc = query_start_loc
|
# Padded sequences are always at the back, so the first
|
||||||
|
# num_spec_decodes + 1 entries of query_start_loc already
|
||||||
|
# contain the correct cumulative token counts.
|
||||||
|
spec_query_start_loc = query_start_loc[: num_spec_decodes + 1]
|
||||||
non_spec_query_start_loc = None
|
non_spec_query_start_loc = None
|
||||||
non_spec_query_start_loc_cpu = None
|
non_spec_query_start_loc_cpu = None
|
||||||
else:
|
else:
|
||||||
@@ -294,6 +302,12 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
|||||||
else:
|
else:
|
||||||
has_initial_state = None
|
has_initial_state = None
|
||||||
|
|
||||||
|
# Function code counted on either presency non-spec decode or spec decode,
|
||||||
|
# but not both.
|
||||||
|
assert not (num_decodes > 0 and num_spec_decodes > 0), (
|
||||||
|
f"num_decodes: {num_decodes}, num_spec_decodes: {num_spec_decodes}"
|
||||||
|
)
|
||||||
|
|
||||||
# Prepare tensors for cudagraph
|
# Prepare tensors for cudagraph
|
||||||
# Note: m.num_actual_tokens is already padded by the model runner for CUDAGraph
|
# Note: m.num_actual_tokens is already padded by the model runner for CUDAGraph
|
||||||
batch_size = m.num_actual_tokens
|
batch_size = m.num_actual_tokens
|
||||||
@@ -312,7 +326,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
|
|||||||
spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID)
|
spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID)
|
||||||
|
|
||||||
self.spec_sequence_masks[:num_spec_decodes].copy_(
|
self.spec_sequence_masks[:num_spec_decodes].copy_(
|
||||||
spec_sequence_masks, non_blocking=True
|
spec_sequence_masks[:num_spec_decodes], non_blocking=True
|
||||||
)
|
)
|
||||||
spec_sequence_masks = self.spec_sequence_masks[:batch_size]
|
spec_sequence_masks = self.spec_sequence_masks[:batch_size]
|
||||||
spec_sequence_masks[num_spec_decodes:].fill_(False)
|
spec_sequence_masks[num_spec_decodes:].fill_(False)
|
||||||
|
|||||||
Reference in New Issue
Block a user