From eac2dc2b410dc11af4b424802e86ef9d36bac28a Mon Sep 17 00:00:00 2001 From: pschlan-amd Date: Wed, 11 Mar 2026 08:25:00 +0100 Subject: [PATCH] AITER MLA backend: Avoid CPU sync in _build_decode (#35765) Signed-off-by: Patrick Schlangen --- .../attention/backends/mla/rocm_aiter_mla.py | 61 ++++++++++++++----- 1 file changed, 46 insertions(+), 15 deletions(-) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 7b465db44..9ded91162 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -17,6 +17,7 @@ from vllm.model_executor.layers.attention.mla_attention import ( MLACommonMetadataBuilder, QueryLenSupport, ) +from vllm.triton_utils import tl, triton from vllm.v1.attention.backend import AttentionCGSupport, AttentionLayer, MultipleOf from vllm.v1.kv_cache_interface import AttentionSpec @@ -108,13 +109,16 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): max_num_reqs, dtype=torch.int32, device=device ) + # Persistent buffer for paged_kv_indices to avoid blocking boolean mask + # indexing (block_table_tensor[mask]) which has data-dependent output size. + self.paged_kv_indices = torch.zeros( + max_num_pages, dtype=torch.int32, device=device + ) + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): self.paged_kv_indptr = torch.zeros( max_num_reqs + 1, dtype=torch.int32, device=device ) - self.paged_kv_indices = torch.zeros( - max_num_pages, dtype=torch.int32, device=device - ) self.qo_indptr = torch.zeros( max_num_reqs + 1, dtype=torch.int32, device=device @@ -134,11 +138,6 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): device = self.device num_reqs = seq_lens_device.size(0) - mask = torch.arange( - block_table_tensor.size(1), dtype=block_table_tensor.dtype, device=device - ).unsqueeze(0) < seq_lens_device.unsqueeze(1) - paged_kv_indices = block_table_tensor[mask] - # kernel block size is always 1, so each page has exactly 1 token. # last_page_len is always 1 - just slice the pre-initialized buffer. paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs] @@ -153,14 +152,17 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): max_qo_len = qo_len.max().item() if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): - num_actual_pages = paged_kv_indices.size(0) - - self.paged_kv_indices[:num_actual_pages].copy_( - paged_kv_indices, non_blocking=True - ) - self.paged_kv_indices[num_actual_pages:].fill_(-1) - paged_kv_indices = self.paged_kv_indices[:num_actual_pages] + self.paged_kv_indices.fill_(-1) + _copy_page_indices_kernel[(num_reqs,)]( + self.paged_kv_indices, + block_table_tensor, + block_table_tensor.stride(0), + paged_kv_indptr, + BLOCK_SIZE=1024, + ) + paged_kv_indices = self.paged_kv_indices + if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): self.paged_kv_indptr[: 1 + num_reqs].copy_( paged_kv_indptr, non_blocking=True ) @@ -196,6 +198,35 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): return attn_metadata +@triton.jit +def _copy_page_indices_kernel( + page_indices, + block_table, + block_table_stride, + cu_num_blocks, + BLOCK_SIZE: tl.constexpr, +): + """Copy block table rows into a flat page_indices buffer using indptr. + Avoids blocking boolean mask indexing (tensor[mask]) which has + data-dependent output size and forces sync. + This is the same kernel as introduced in backends/flashinfer.py. + """ + req_idx = tl.program_id(0) + row_ptr = block_table + req_idx * block_table_stride + start_idx = tl.load(cu_num_blocks + req_idx) + end_idx = tl.load(cu_num_blocks + req_idx + 1) + num_blocks = end_idx - start_idx + + offset = tl.arange(0, BLOCK_SIZE) + for i in tl.range(0, num_blocks, BLOCK_SIZE): + block_ids = tl.load(row_ptr + i + offset, mask=i + offset < num_blocks) + tl.store( + page_indices + start_idx + i + offset, + block_ids, + mask=i + offset < num_blocks, + ) + + class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): def __init__( self,