[Qwen3-Next][GDN] Fix a CUDA graph capture bug in GDN metadata and a stride bug in causal_conv_1d (#25743)
Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
@@ -125,7 +125,7 @@ class GDNAttentionMetadataBuilder(
         common_prefix_len: int,
         common_attn_metadata: CommonAttentionMetadata,
         num_accepted_tokens: Optional[torch.Tensor] = None,
-        num_draft_tokens: Optional[torch.Tensor] = None,
+        num_decode_draft_tokens_cpu: Optional[torch.Tensor] = None,
         fast_build: bool = False,
     ) -> GDNAttentionMetadata:
         m = common_attn_metadata
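The signature change above swaps the device-resident `num_draft_tokens` tensor for a host-side `num_decode_draft_tokens_cpu` tensor. Judging from the `>= 0` filtering introduced in the next hunk, the convention is one entry per request: the draft-token count for spec-decode requests, and a negative sentinel (-1) for everything else. A minimal sketch of that convention (the values below are made up, not from the commit):

import torch

# Hypothetical per-request counts: requests 0, 2, and 3 are running
# speculative decode; request 1 is not, so it carries the -1 sentinel.
num_decode_draft_tokens_cpu = torch.tensor([2, -1, 0, 3], dtype=torch.int32)

# Mask of spec-decode requests (the sentinel entries drop out).
spec_mask = num_decode_draft_tokens_cpu >= 0
# tensor([ True, False,  True,  True])

# Total drafted tokens this step; filtering first keeps the -1 entries
# from skewing the sum, which mirrors the early-exit test in build().
total_drafts = num_decode_draft_tokens_cpu[spec_mask].sum().item()  # 5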
@@ -133,23 +133,25 @@ class GDNAttentionMetadataBuilder(
         query_start_loc = m.query_start_loc
         context_lens = m.num_computed_tokens_cpu
         context_lens_tensor = context_lens.to(query_start_loc.device)
-        seq_lens_tensor = m.seq_lens
         nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None

-        if (not self.use_spec_decode or num_draft_tokens is None
-                or num_draft_tokens.sum().item() == 0):
+        if (not self.use_spec_decode or num_decode_draft_tokens_cpu is None
+                or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >=
+                                               0].sum().item() == 0):
             spec_sequence_masks = None
             num_spec_decodes = 0
         else:
-            spec_sequence_masks = (num_draft_tokens > 0) & (
-                context_lens_tensor +
-                (num_draft_tokens + 1) == seq_lens_tensor)
-            if spec_sequence_masks.sum().item() == 0:
+            spec_sequence_masks = num_decode_draft_tokens_cpu >= 0
+            num_spec_decodes = spec_sequence_masks.sum().item()
+            if num_spec_decodes == 0:
                 spec_sequence_masks = None
+            else:
+                spec_sequence_masks = spec_sequence_masks.to(
+                    query_start_loc.device, non_blocking=True)

         if spec_sequence_masks is None:
             num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
                 split_decodes_and_prefills(m, decode_threshold=1))
             num_spec_decodes = 0
             num_spec_decode_tokens = 0
             spec_token_masks = None
             spec_state_indices_tensor = None
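The rewritten branch computes the spec-decode mask entirely on the host, and only pays for a host-to-device copy when at least one spec-decode request exists. The old code derived the mask from device tensors and then called `.sum().item()` on it, forcing a device synchronization, which is problematic inside CUDA graph capture. A minimal sketch of the new pattern (values made up):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Host-side inputs: one entry per request, -1 marking non-spec requests.
num_decode_draft_tokens_cpu = torch.tensor([1, -1, 2], dtype=torch.int32)

spec_sequence_masks = num_decode_draft_tokens_cpu >= 0  # CPU bool tensor
num_spec_decodes = spec_sequence_masks.sum().item()     # pure host arithmetic

if num_spec_decodes == 0:
    spec_sequence_masks = None  # nothing to ship to the device
else:
    # Asynchronous host-to-device copy; the host does not block on the GPU.
    spec_sequence_masks = spec_sequence_masks.to(device, non_blocking=True)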
@@ -158,7 +160,6 @@ class GDNAttentionMetadataBuilder(
             non_spec_query_start_loc = query_start_loc
             num_accepted_tokens = None
         else:
-            num_spec_decodes = spec_sequence_masks.sum().item()
             query_lens = query_start_loc[1:] - query_start_loc[:-1]

             non_spec_query_lens = query_lens[~spec_sequence_masks]
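With `num_spec_decodes` now computed once in the branch above, the duplicate `.sum().item()` in the `else:` arm is dead weight and is dropped. The surrounding lines recover per-request query lengths from the cumulative `query_start_loc` offsets; a small illustration with made-up values:

import torch

# query_start_loc holds cumulative token offsets: request i owns tokens
# in the half-open range [query_start_loc[i], query_start_loc[i + 1]).
query_start_loc = torch.tensor([0, 1, 4, 5], dtype=torch.int32)  # 3 requests
spec_sequence_masks = torch.tensor([False, True, False])

query_lens = query_start_loc[1:] - query_start_loc[:-1]  # tensor([1, 3, 1])
non_spec_query_lens = query_lens[~spec_sequence_masks]   # tensor([1, 1])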
@@ -314,28 +315,18 @@ class GDNAttentionMetadataBuilder(
         """
         m = common_attn_metadata

-        assert (m.num_reqs * (self.num_spec + 1) <= m.num_actual_tokens
-                and ((m.num_reqs + 1) * (self.num_spec + 1)
-                     >= m.num_actual_tokens)), \
-            "GDN only supports decode-only full CUDAGraph capture. " \
-            "Make sure all cudagraph capture sizes <= max_num_seq."
+        assert (
+            m.num_reqs <= self.decode_cudagraph_max_bs
+            and m.num_actual_tokens <= self.decode_cudagraph_max_bs), (
+                f"GDN only supports decode-only full CUDAGraph capture. "
+                f"Make sure batch size ({m.num_reqs}) <= "
+                f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}), "
+                f"and number of tokens ({m.num_actual_tokens}) <= "
+                f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}).")

-        num_accepted_tokens = torch.full((m.num_reqs, ),
-                                         m.max_query_len,
-                                         dtype=torch.int32,
-                                         device=m.query_start_loc.device)
-        num_drafted_tokens = torch.full((m.num_reqs, ),
-                                        self.num_spec,
-                                        dtype=torch.int32,
-                                        device=m.query_start_loc.device)
-
-        # Fixes query-start loc for spec-sequence-indices.
-        m.query_start_loc = torch.arange(0,
-                                         m.num_actual_tokens + 1,
-                                         step=m.max_query_len,
-                                         device=m.query_start_loc.device,
-                                         dtype=torch.int32)
-        m.num_computed_tokens_cpu = (m.seq_lens_cpu - torch.full(
-            (m.num_reqs, ), m.max_query_len, dtype=torch.int32, device='cpu'))
+        num_accepted_tokens = torch.diff(m.query_start_loc)
+        num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu()
+        m.num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens.cpu()

-        return self.build(0, m, num_accepted_tokens, num_drafted_tokens)
+        return self.build(0, m, num_accepted_tokens,
+                          num_decode_draft_tokens_cpu)
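This last hunk is the capture-path fix proper: instead of overwriting `m.query_start_loc` with a synthetic `torch.arange` ramp and fabricating `num_accepted_tokens` with `torch.full`, the builder now reads the per-request token counts straight off the metadata it is handed, so no replacement tensors are materialized during graph capture. During decode-only capture each request contributes `max_query_len` token slots, one accepted token plus drafts, which is what `torch.diff` and the `- 1` recover. A minimal sketch under that assumption (shapes made up; in the real path the tensors live on the GPU, hence the explicit `.cpu()`):

import torch

# Capture-time shape assumed by the new code: each of num_reqs requests
# occupies max_query_len token slots, so query_start_loc is a uniform ramp.
num_reqs, max_query_len = 4, 3
query_start_loc = torch.arange(0, num_reqs * max_query_len + 1,
                               step=max_query_len, dtype=torch.int32)
# tensor([ 0,  3,  6,  9, 12])

num_accepted_tokens = torch.diff(query_start_loc)  # tensor([3, 3, 3, 3])

# One accepted token per request per step; the remaining slots are drafts.
num_decode_draft_tokens_cpu = (num_accepted_tokens - 1).cpu()
# tensor([2, 2, 2, 2])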