[Bugfix][Qwen3-Next] fixes the varlen issue in qwen3-next's MTP implementation. (#24957)
Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
This commit is contained in:
@@ -31,6 +31,7 @@ class GDNAttentionMetadata:
|
||||
num_decode_tokens: int
|
||||
num_spec_decodes: int
|
||||
num_spec_decode_tokens: int
|
||||
num_actual_tokens: int
|
||||
|
||||
has_initial_state: Optional[torch.Tensor] = None
|
||||
|
||||
@@ -74,8 +75,8 @@ class GDNAttentionMetadataBuilder(
|
||||
self.use_full_cuda_graph = \
|
||||
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||
self.decode_cudagraph_max_bs = min(
|
||||
self.vllm_config.scheduler_config.max_num_seqs,
|
||||
self.compilation_config.max_capture_size)
|
||||
self.vllm_config.scheduler_config.max_num_seqs *
|
||||
(self.num_spec + 1), self.compilation_config.max_capture_size)
|
||||
|
||||
self.spec_state_indices_tensor = torch.empty(
|
||||
(self.decode_cudagraph_max_bs, self.num_spec + 1),
|
||||
@@ -194,9 +195,8 @@ class GDNAttentionMetadataBuilder(
|
||||
dim=0,
|
||||
out=non_spec_query_start_loc[1:])
|
||||
|
||||
num_spec_decode_tokens = min(
|
||||
num_spec_decodes * (self.num_spec + 1),
|
||||
spec_token_masks.size(0))
|
||||
num_spec_decode_tokens = (query_lens.sum().item() -
|
||||
num_prefill_tokens - num_decode_tokens)
|
||||
assert num_accepted_tokens is not None
|
||||
num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]
|
||||
|
||||
@@ -206,14 +206,22 @@ class GDNAttentionMetadataBuilder(
|
||||
has_initial_state = has_initial_state[~spec_sequence_masks]
|
||||
else:
|
||||
has_initial_state = None
|
||||
num_actual_tokens = num_prefill_tokens + num_decode_tokens + \
|
||||
num_spec_decode_tokens
|
||||
|
||||
# prepare tensors for cudagraph
|
||||
#
|
||||
# With speculative decoding, the xgrammar backend may rollback tokens
|
||||
# and causing some sequences has less draft tokens than self.num_spec.
|
||||
#
|
||||
# In above cases, the max possible batch size for n tokens, can be
|
||||
# min(n, cudagraph_max_bs).
|
||||
if (self.use_full_cuda_graph and num_prefills == 0 and num_decodes == 0
|
||||
and num_spec_decodes <= self.decode_cudagraph_max_bs
|
||||
and m.num_actual_tokens <= self.decode_cudagraph_max_bs):
|
||||
num_total_tokens = self.vllm_config.pad_for_cudagraph(
|
||||
and num_spec_decode_tokens <= self.decode_cudagraph_max_bs):
|
||||
num_actual_tokens = self.vllm_config.pad_for_cudagraph(
|
||||
m.num_actual_tokens)
|
||||
batch_size = num_total_tokens // (self.num_spec + 1)
|
||||
batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens)
|
||||
|
||||
self.spec_state_indices_tensor[:num_spec_decodes].copy_(
|
||||
spec_state_indices_tensor, non_blocking=True)
|
||||
@@ -229,7 +237,7 @@ class GDNAttentionMetadataBuilder(
|
||||
assert spec_token_masks is not None
|
||||
self.spec_token_masks[:spec_token_masks.size(0)].copy_(
|
||||
spec_token_masks, non_blocking=True)
|
||||
spec_token_masks = self.spec_token_masks[:m.num_actual_tokens]
|
||||
spec_token_masks = self.spec_token_masks[:num_actual_tokens]
|
||||
spec_token_masks[spec_token_masks.size(0):].fill_(False)
|
||||
|
||||
self.spec_query_start_loc[:num_spec_decodes + 1].copy_(
|
||||
@@ -248,9 +256,9 @@ class GDNAttentionMetadataBuilder(
|
||||
if (self.use_full_cuda_graph and num_prefills == 0
|
||||
and num_spec_decodes == 0
|
||||
and num_decodes <= self.decode_cudagraph_max_bs):
|
||||
num_total_tokens = self.vllm_config.pad_for_cudagraph(
|
||||
num_actual_tokens = self.vllm_config.pad_for_cudagraph(
|
||||
m.num_actual_tokens)
|
||||
batch_size = num_total_tokens
|
||||
batch_size = num_actual_tokens
|
||||
|
||||
self.non_spec_state_indices_tensor[:num_decodes].copy_(
|
||||
non_spec_state_indices_tensor, non_blocking=True)
|
||||
@@ -274,6 +282,7 @@ class GDNAttentionMetadataBuilder(
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
num_spec_decodes=num_spec_decodes,
|
||||
num_spec_decode_tokens=num_spec_decode_tokens,
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
has_initial_state=has_initial_state,
|
||||
spec_query_start_loc=spec_query_start_loc,
|
||||
non_spec_query_start_loc=non_spec_query_start_loc,
|
||||
|
||||
Reference in New Issue
Block a user