[Bugfix][Qwen3-Next] fixes the varlen issue in qwen3-next's MTP implementation. (#24957)

Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
Tao He
2025-09-17 21:59:09 +08:00
committed by GitHub
parent 1b962e2457
commit dd6a910aac
3 changed files with 139 additions and 34 deletions

View File

@@ -31,6 +31,7 @@ class GDNAttentionMetadata:
     num_decode_tokens: int
     num_spec_decodes: int
     num_spec_decode_tokens: int
+    num_actual_tokens: int
     has_initial_state: Optional[torch.Tensor] = None
@@ -74,8 +75,8 @@ class GDNAttentionMetadataBuilder(
         self.use_full_cuda_graph = \
             self.compilation_config.cudagraph_mode.has_full_cudagraphs()
         self.decode_cudagraph_max_bs = min(
-            self.vllm_config.scheduler_config.max_num_seqs,
-            self.compilation_config.max_capture_size)
+            self.vllm_config.scheduler_config.max_num_seqs *
+            (self.num_spec + 1), self.compilation_config.max_capture_size)
         self.spec_state_indices_tensor = torch.empty(
             (self.decode_cudagraph_max_bs, self.num_spec + 1),
@@ -194,9 +195,8 @@ class GDNAttentionMetadataBuilder(
                 dim=0,
                 out=non_spec_query_start_loc[1:])
-            num_spec_decode_tokens = min(
-                num_spec_decodes * (self.num_spec + 1),
-                spec_token_masks.size(0))
+            num_spec_decode_tokens = (query_lens.sum().item() -
+                                      num_prefill_tokens - num_decode_tokens)
             assert num_accepted_tokens is not None
             num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]
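The removed formula assumed every speculative sequence carries exactly num_spec + 1 tokens. After a grammar-induced rollback some sequences carry fewer (the varlen case this commit fixes), so the count is now derived from the actual per-sequence query lengths. A small numeric sketch with fabricated inputs, not taken from a real batch:

import torch

# Fabricated example: one prefill of 2 tokens, one plain decode token, and
# two spec-decode sequences; the second one rolled back a draft token.
query_lens = torch.tensor([2, 1, 3, 2])  # per-sequence query lengths
num_prefill_tokens = 2
num_decode_tokens = 1
num_spec = 2
num_spec_decodes = 2

# Old computation: assumes every spec sequence has num_spec + 1 tokens.
old_count = num_spec_decodes * (num_spec + 1)  # 6, one token too many

# New computation: whatever remains after prefill and plain decode tokens.
new_count = query_lens.sum().item() - num_prefill_tokens - num_decode_tokens
assert new_count == 5  # 3 + 2 actual spec-decode tokens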
@@ -206,14 +206,22 @@ class GDNAttentionMetadataBuilder(
                 has_initial_state = has_initial_state[~spec_sequence_masks]
         else:
             has_initial_state = None
+        num_actual_tokens = num_prefill_tokens + num_decode_tokens + \
+            num_spec_decode_tokens

         # prepare tensors for cudagraph
+        #
+        # With speculative decoding, the xgrammar backend may roll back tokens,
+        # causing some sequences to have fewer draft tokens than self.num_spec.
+        #
+        # In such cases, the max possible batch size for n tokens is
+        # min(n, cudagraph_max_bs).
         if (self.use_full_cuda_graph and num_prefills == 0 and num_decodes == 0
                 and num_spec_decodes <= self.decode_cudagraph_max_bs
-                and m.num_actual_tokens <= self.decode_cudagraph_max_bs):
-            num_total_tokens = self.vllm_config.pad_for_cudagraph(
+                and num_spec_decode_tokens <= self.decode_cudagraph_max_bs):
+            num_actual_tokens = self.vllm_config.pad_for_cudagraph(
                 m.num_actual_tokens)
-            batch_size = num_total_tokens // (self.num_spec + 1)
+            batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens)
             self.spec_state_indices_tensor[:num_spec_decodes].copy_(
                 spec_state_indices_tensor, non_blocking=True)
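With variable-length speculative sequences, the captured-graph batch size can no longer be recovered as padded_tokens // (num_spec + 1); instead the padded token count itself bounds how many sequence slots the graph must cover, since each sequence contributes at least one token. A sketch of that logic with a simplified stand-in for pad_for_cudagraph (the capture sizes and counts below are illustrative only):

import bisect

# Hypothetical captured sizes; the real config pads up to the nearest
# captured cudagraph size that fits the actual token count.
capture_sizes = [1, 2, 4, 8, 16, 32, 64]

def pad_for_cudagraph(num_tokens: int) -> int:
    """Simplified stand-in for VllmConfig.pad_for_cudagraph."""
    assert num_tokens <= capture_sizes[-1]
    return capture_sizes[bisect.bisect_left(capture_sizes, num_tokens)]

decode_cudagraph_max_bs = 24  # from the sizing sketch above
m_num_actual_tokens = 5       # varlen spec-decode tokens from the example above

num_actual_tokens = pad_for_cudagraph(m_num_actual_tokens)    # padded to 8
# n tokens can come from at most n sequences (each contributes >= 1 token),
# and the state buffers cap the batch at decode_cudagraph_max_bs.
batch_size = min(decode_cudagraph_max_bs, num_actual_tokens)  # 8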
@@ -229,7 +237,7 @@ class GDNAttentionMetadataBuilder(
             assert spec_token_masks is not None
             self.spec_token_masks[:spec_token_masks.size(0)].copy_(
                 spec_token_masks, non_blocking=True)
-            spec_token_masks = self.spec_token_masks[:m.num_actual_tokens]
+            spec_token_masks = self.spec_token_masks[:num_actual_tokens]
             spec_token_masks[spec_token_masks.size(0):].fill_(False)

             self.spec_query_start_loc[:num_spec_decodes + 1].copy_(
@@ -248,9 +256,9 @@ class GDNAttentionMetadataBuilder(
         if (self.use_full_cuda_graph and num_prefills == 0
                 and num_spec_decodes == 0
                 and num_decodes <= self.decode_cudagraph_max_bs):
-            num_total_tokens = self.vllm_config.pad_for_cudagraph(
+            num_actual_tokens = self.vllm_config.pad_for_cudagraph(
                 m.num_actual_tokens)
-            batch_size = num_total_tokens
+            batch_size = num_actual_tokens
             self.non_spec_state_indices_tensor[:num_decodes].copy_(
                 non_spec_state_indices_tensor, non_blocking=True)
@@ -274,6 +282,7 @@ class GDNAttentionMetadataBuilder(
             num_decode_tokens=num_decode_tokens,
             num_spec_decodes=num_spec_decodes,
             num_spec_decode_tokens=num_spec_decode_tokens,
+            num_actual_tokens=num_actual_tokens,
             has_initial_state=has_initial_state,
             spec_query_start_loc=spec_query_start_loc,
             non_spec_query_start_loc=non_spec_query_start_loc,
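For reference, a reduced sketch of how the new field slots into the metadata. Only the counters touched by this commit are mirrored here, and the values are fabricated; the real GDNAttentionMetadata carries many more fields:

from dataclasses import dataclass

@dataclass
class GDNAttentionMetadataSketch:
    """Hypothetical, reduced mirror of the counters touched by this commit."""
    num_decode_tokens: int
    num_spec_decodes: int
    num_spec_decode_tokens: int
    num_actual_tokens: int  # new: prefill + decode + spec-decode tokens

meta = GDNAttentionMetadataSketch(
    num_decode_tokens=1,
    num_spec_decodes=2,
    num_spec_decode_tokens=5,
    num_actual_tokens=2 + 1 + 5,  # num_prefill_tokens + decode + spec-decode
)
assert meta.num_actual_tokens == 8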