[PERF] Qwen3-next MTP speedup (change bool mask indexing to index_select / index_copy to reduce d2h) (#26437)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
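The core change: gathering/scattering the MTP (spec-decode) tokens is driven by precomputed index tensors instead of boolean masks, because boolean-mask indexing has a data-dependent output shape and therefore forces a device-to-host sync. A minimal standalone sketch of the idea (tensor names and sizes are illustrative, not the vLLM model code):

import torch

hidden = torch.randn(8, 16)
mask = torch.tensor([0, 1, 0, 1, 1, 0, 0, 1], dtype=torch.bool)
num_spec = 4  # known on the host from scheduler metadata, not read back from the mask

# Old pattern: the size of hidden[mask] depends on mask.sum(), so the host
# has to wait for the GPU to report it (an implicit d2h sync).
spec_hidden = hidden[mask]
hidden[mask] = spec_hidden

# New pattern: build a fixed-size index tensor once (argsort has a
# data-independent output shape), then gather/scatter through it.
index = torch.argsort(mask.int())   # non-spec positions sort first, spec positions last
spec_indx = index[-num_spec:]
spec_hidden = hidden.index_select(0, spec_indx)
hidden.index_copy_(0, spec_indx, spec_hidden)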
@@ -47,9 +47,9 @@ class GDNAttentionMetadata:
         None  # shape: [batch - num_spec_decodes,]
     )
     spec_sequence_masks: torch.Tensor | None = None  # shape: [batch,]
-    spec_token_masks: torch.Tensor | None = (
-        None  # shape: [num_prefill_tokens + num_decode_tokens,]
-    )
+    spec_token_indx: torch.Tensor | None = None
+    non_spec_token_indx: torch.Tensor | None = None
+
     num_accepted_tokens: torch.Tensor | None = None  # shape: [batch,]
 
     # The following attributes are for triton implementation of causal_conv1d
@@ -105,9 +105,14 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
             dtype=torch.bool,
             device=device,
         )
-        self.spec_token_masks = torch.empty(
+        self.spec_token_indx = torch.empty(
             (self.decode_cudagraph_max_bs * (self.num_spec + 1),),
-            dtype=torch.bool,
+            dtype=torch.int32,
             device=device,
         )
+        self.non_spec_token_indx = torch.empty(
+            (self.decode_cudagraph_max_bs * (self.num_spec + 1),),
+            dtype=torch.int32,
+            device=device,
+        )
         self.spec_query_start_loc = torch.empty(
@@ -166,7 +171,8 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
                 split_decodes_and_prefills(m, decode_threshold=1)
             )
             num_spec_decode_tokens = 0
-            spec_token_masks = None
+            spec_token_indx = None
+            non_spec_token_indx = None
             spec_state_indices_tensor = None
             non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
             spec_query_start_loc = None
@@ -180,18 +186,23 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
             num_prefills = non_spec_query_lens.size(0) - num_decodes
             num_decode_tokens = num_decodes
             num_prefill_tokens = non_spec_query_lens.sum().item() - num_decode_tokens
+            num_spec_decode_tokens = (
+                query_lens.sum().item() - num_prefill_tokens - num_decode_tokens
+            )
 
             if num_prefills == 0 and num_decodes == 0:
-                spec_token_masks = torch.ones(
-                    (
-                        min(
-                            num_spec_decodes * (self.num_spec + 1),
-                            query_start_loc[-1].item(),
-                        )
-                    ),
-                    dtype=torch.bool,
+                spec_token_size = min(
+                    num_spec_decodes * (self.num_spec + 1),
+                    query_start_loc[-1].item(),
+                )
+                spec_token_indx = torch.arange(
+                    spec_token_size,
+                    dtype=torch.int32,
                     device=query_start_loc.device,
                 )
+                non_spec_token_indx = torch.empty(
+                    0, dtype=torch.int32, device=query_start_loc.device
+                )
                 spec_state_indices_tensor = m.block_table_tensor[:, : self.num_spec + 1]
                 non_spec_state_indices_tensor = None
                 spec_query_start_loc = query_start_loc
@@ -200,6 +211,11 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
                 spec_token_masks = torch.repeat_interleave(
                     spec_sequence_masks, query_lens
                 )
+                index = torch.argsort(spec_token_masks)
+                num_non_spec_tokens = num_prefill_tokens + num_decode_tokens
+                non_spec_token_indx = index[:num_non_spec_tokens]
+                spec_token_indx = index[num_non_spec_tokens:]
+
                 spec_state_indices_tensor = m.block_table_tensor[
                     spec_sequence_masks, : self.num_spec + 1
                 ]
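For a mixed batch, the hunk above turns the per-sequence spec mask into per-token index tensors entirely on the device: argsort of the token mask puts non-spec token positions first and spec token positions last, and the split point is a count already known on the host. A standalone restatement with made-up values (the int cast is only so this snippet runs on any backend):

import torch

spec_sequence_masks = torch.tensor([False, True, False, True])  # which sequences are spec-decoding
query_lens = torch.tensor([3, 1, 1, 2])                         # tokens per sequence in this batch
num_non_spec_tokens = 3 + 1                                     # num_prefill_tokens + num_decode_tokens

spec_token_masks = torch.repeat_interleave(spec_sequence_masks, query_lens)
index = torch.argsort(spec_token_masks.int())
non_spec_token_indx = index[:num_non_spec_tokens]   # positions of prefill/decode tokens
spec_token_indx = index[num_non_spec_tokens:]       # positions of spec-decode tokens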
@@ -226,9 +242,6 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
                 out=non_spec_query_start_loc[1:],
             )
 
-            num_spec_decode_tokens = (
-                query_lens.sum().item() - num_prefill_tokens - num_decode_tokens
-            )
             assert num_accepted_tokens is not None
             num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]
 
@@ -274,12 +287,18 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
             spec_sequence_masks = self.spec_sequence_masks[:batch_size]
             spec_sequence_masks[num_spec_decodes:].fill_(False)
 
-            assert spec_token_masks is not None
-            self.spec_token_masks[: spec_token_masks.size(0)].copy_(
-                spec_token_masks, non_blocking=True
+            assert non_spec_token_indx is not None and spec_token_indx is not None
+            self.non_spec_token_indx[: non_spec_token_indx.size(0)].copy_(
+                non_spec_token_indx, non_blocking=True
             )
-            spec_token_masks = self.spec_token_masks[:num_actual_tokens]
-            spec_token_masks[spec_token_masks.size(0) :].fill_(False)
+            non_spec_token_indx = self.non_spec_token_indx[
+                : non_spec_token_indx.size(0)
+            ]
+
+            self.spec_token_indx[: spec_token_indx.size(0)].copy_(
+                spec_token_indx, non_blocking=True
+            )
+            spec_token_indx = self.spec_token_indx[: spec_token_indx.size(0)]
 
             self.spec_query_start_loc[: num_spec_decodes + 1].copy_(
                 spec_query_start_loc, non_blocking=True
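In the CUDA-graph path above, the freshly built index tensors are staged into the preallocated int32 buffers from __init__ with non-blocking copies, and only prefix views are handed out, so the captured graph always replays against the same storage. A minimal sketch of that staging pattern (buffer size, device handling, and names here are illustrative):

import torch

MAX_INDEX_TOKENS = 1024
device = "cuda" if torch.cuda.is_available() else "cpu"
persistent_indx = torch.empty(MAX_INDEX_TOKENS, dtype=torch.int32, device=device)

def stage(indx: torch.Tensor) -> torch.Tensor:
    n = indx.size(0)                                      # host-side int, no GPU readback needed
    persistent_indx[:n].copy_(indx, non_blocking=True)    # async copy into the persistent buffer
    return persistent_indx[:n]                            # fixed storage, variable-length prefix view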
@@ -332,7 +351,8 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
             spec_state_indices_tensor=spec_state_indices_tensor,
             non_spec_state_indices_tensor=non_spec_state_indices_tensor,
             spec_sequence_masks=spec_sequence_masks,
-            spec_token_masks=spec_token_masks,
+            spec_token_indx=spec_token_indx,
+            non_spec_token_indx=non_spec_token_indx,
             num_accepted_tokens=num_accepted_tokens,
             nums_dict=nums_dict,
             batch_ptr=batch_ptr,