AITER MLA backend: Avoid CPU sync in _build_decode (#35765)
Signed-off-by: Patrick Schlangen <pschlan@amd.com>
This commit is contained in:
@@ -17,6 +17,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
|
||||
MLACommonMetadataBuilder,
|
||||
QueryLenSupport,
|
||||
)
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.attention.backend import AttentionCGSupport, AttentionLayer, MultipleOf
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
@@ -108,13 +109,16 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
max_num_reqs, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
# Persistent buffer for paged_kv_indices to avoid blocking boolean mask
|
||||
# indexing (block_table_tensor[mask]) which has data-dependent output size.
|
||||
self.paged_kv_indices = torch.zeros(
|
||||
max_num_pages, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
self.paged_kv_indptr = torch.zeros(
|
||||
max_num_reqs + 1, dtype=torch.int32, device=device
|
||||
)
|
||||
self.paged_kv_indices = torch.zeros(
|
||||
max_num_pages, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
self.qo_indptr = torch.zeros(
|
||||
max_num_reqs + 1, dtype=torch.int32, device=device
|
||||
@@ -134,11 +138,6 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
device = self.device
|
||||
num_reqs = seq_lens_device.size(0)
|
||||
|
||||
mask = torch.arange(
|
||||
block_table_tensor.size(1), dtype=block_table_tensor.dtype, device=device
|
||||
).unsqueeze(0) < seq_lens_device.unsqueeze(1)
|
||||
paged_kv_indices = block_table_tensor[mask]
|
||||
|
||||
# kernel block size is always 1, so each page has exactly 1 token.
|
||||
# last_page_len is always 1 - just slice the pre-initialized buffer.
|
||||
paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]
|
||||
@@ -153,14 +152,17 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
max_qo_len = qo_len.max().item()
|
||||
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
num_actual_pages = paged_kv_indices.size(0)
|
||||
|
||||
self.paged_kv_indices[:num_actual_pages].copy_(
|
||||
paged_kv_indices, non_blocking=True
|
||||
)
|
||||
self.paged_kv_indices[num_actual_pages:].fill_(-1)
|
||||
paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
|
||||
self.paged_kv_indices.fill_(-1)
|
||||
_copy_page_indices_kernel[(num_reqs,)](
|
||||
self.paged_kv_indices,
|
||||
block_table_tensor,
|
||||
block_table_tensor.stride(0),
|
||||
paged_kv_indptr,
|
||||
BLOCK_SIZE=1024,
|
||||
)
|
||||
paged_kv_indices = self.paged_kv_indices
|
||||
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
self.paged_kv_indptr[: 1 + num_reqs].copy_(
|
||||
paged_kv_indptr, non_blocking=True
|
||||
)
|
||||
@@ -196,6 +198,35 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
return attn_metadata
|
||||
|
||||
|
||||
@triton.jit
def _copy_page_indices_kernel(
    page_indices,
    block_table,
    block_table_stride,
    cu_num_blocks,
    BLOCK_SIZE: tl.constexpr,
):
    """Flatten one block-table row per program into ``page_indices``.

    Program ``i`` copies the first ``cu_num_blocks[i + 1] - cu_num_blocks[i]``
    entries of row ``i`` of ``block_table`` into ``page_indices`` starting at
    offset ``cu_num_blocks[i]``. This replaces boolean-mask indexing
    (``tensor[mask]``), whose data-dependent output size would force a
    host/device sync. Same kernel as in backends/flashinfer.py.
    """
    row = tl.program_id(0)
    src_row = block_table + row * block_table_stride
    dst_begin = tl.load(cu_num_blocks + row)
    dst_end = tl.load(cu_num_blocks + row + 1)
    row_len = dst_end - dst_begin

    lane = tl.arange(0, BLOCK_SIZE)
    # Tile across the row in chunks of BLOCK_SIZE, masking the ragged tail.
    for base in tl.range(0, row_len, BLOCK_SIZE):
        in_bounds = base + lane < row_len
        vals = tl.load(src_row + base + lane, mask=in_bounds)
        tl.store(page_indices + dst_begin + base + lane, vals, mask=in_bounds)
|
||||
|
||||
|
||||
class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user