[Qwen3-Next][GDN] fixes cuda graph capturing bug in GDN metadata and a stride bug in causal_conv_1d. (#25743)

Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
This commit is contained in:
Tao He
2025-09-26 16:18:58 +08:00
committed by GitHub
parent 6e30010d2f
commit 99b3a504c5
3 changed files with 48 additions and 43 deletions

View File

@@ -125,7 +125,7 @@ class GDNAttentionMetadataBuilder(
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
num_accepted_tokens: Optional[torch.Tensor] = None,
num_draft_tokens: Optional[torch.Tensor] = None,
num_decode_draft_tokens_cpu: Optional[torch.Tensor] = None,
fast_build: bool = False,
) -> GDNAttentionMetadata:
m = common_attn_metadata
@@ -133,23 +133,25 @@ class GDNAttentionMetadataBuilder(
query_start_loc = m.query_start_loc
context_lens = m.num_computed_tokens_cpu
context_lens_tensor = context_lens.to(query_start_loc.device)
seq_lens_tensor = m.seq_lens
nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
if (not self.use_spec_decode or num_draft_tokens is None
or num_draft_tokens.sum().item() == 0):
if (not self.use_spec_decode or num_decode_draft_tokens_cpu is None
or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >=
0].sum().item() == 0):
spec_sequence_masks = None
num_spec_decodes = 0
else:
spec_sequence_masks = (num_draft_tokens > 0) & (
context_lens_tensor +
(num_draft_tokens + 1) == seq_lens_tensor)
if spec_sequence_masks.sum().item() == 0:
spec_sequence_masks = num_decode_draft_tokens_cpu >= 0
num_spec_decodes = spec_sequence_masks.sum().item()
if num_spec_decodes == 0:
spec_sequence_masks = None
else:
spec_sequence_masks = spec_sequence_masks.to(
query_start_loc.device, non_blocking=True)
if spec_sequence_masks is None:
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(m, decode_threshold=1))
num_spec_decodes = 0
num_spec_decode_tokens = 0
spec_token_masks = None
spec_state_indices_tensor = None
@@ -158,7 +160,6 @@ class GDNAttentionMetadataBuilder(
non_spec_query_start_loc = query_start_loc
num_accepted_tokens = None
else:
num_spec_decodes = spec_sequence_masks.sum().item()
query_lens = query_start_loc[1:] - query_start_loc[:-1]
non_spec_query_lens = query_lens[~spec_sequence_masks]
@@ -314,28 +315,18 @@ class GDNAttentionMetadataBuilder(
"""
m = common_attn_metadata
assert (m.num_reqs * (self.num_spec + 1) <= m.num_actual_tokens
and ((m.num_reqs + 1) * (self.num_spec + 1)
>= m.num_actual_tokens)), \
"GDN only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."
assert (
m.num_reqs <= self.decode_cudagraph_max_bs
and m.num_actual_tokens <= self.decode_cudagraph_max_bs), (
f"GDN only supports decode-only full CUDAGraph capture. "
f"Make sure batch size ({m.num_reqs}) <= "
f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}), "
f"and number of tokens ({m.num_actual_tokens}) <= "
f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}).")
num_accepted_tokens = torch.full((m.num_reqs, ),
m.max_query_len,
dtype=torch.int32,
device=m.query_start_loc.device)
num_drafted_tokens = torch.full((m.num_reqs, ),
self.num_spec,
dtype=torch.int32,
device=m.query_start_loc.device)
num_accepted_tokens = torch.diff(m.query_start_loc)
num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu()
m.num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu()
# Fixes query-start loc for spec-sequence-indices.
m.query_start_loc = torch.arange(0,
m.num_actual_tokens + 1,
step=m.max_query_len,
device=m.query_start_loc.device,
dtype=torch.int32)
m.num_computed_tokens_cpu = (m.seq_lens_cpu - torch.full(
(m.num_reqs, ), m.max_query_len, dtype=torch.int32, device='cpu'))
return self.build(0, m, num_accepted_tokens, num_drafted_tokens)
return self.build(0, m, num_accepted_tokens,
num_decode_draft_tokens_cpu)