[ROCm][FEAT] Enable Full Graph Mode in AITER MLA V1 Attn Backend (Decode Phase only) (#20254)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
vllmellm
2025-07-03 00:25:46 +08:00
committed by GitHub
parent 139508a418
commit a1aafc827a

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Optional from typing import Any, ClassVar, Optional
import torch import torch
@@ -63,6 +63,7 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
full_cudagraph_supported: ClassVar[bool] = True # decode only
def __init__(self, runner, kv_cache_spec: AttentionSpec, def __init__(self, runner, kv_cache_spec: AttentionSpec,
block_table: BlockTable): block_table: BlockTable):
@@ -70,56 +71,83 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
assert self.kv_cache_spec.block_size == 1, "AITER MLA" \ assert self.kv_cache_spec.block_size == 1, "AITER MLA" \
"only supports block size 1." "only supports block size 1."
def _get_paged_kv_tensors( # Preparing persistent buffers
self, block_table: torch.Tensor, if self.runner.full_cuda_graph:
seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]: device = self.runner.device
max_num_reqs = self.runner.max_num_reqs
self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
dtype=torch.int32,
device=device)
self.paged_kv_indices = torch.zeros(
block_table.get_device_tensor().numel(
), # max num pages possible
dtype=torch.int32,
device=device)
self.paged_kv_last_page_len = torch.zeros(max_num_reqs,
dtype=torch.int32,
device=device)
self.qo_indptr = torch.arange(0,
max_num_reqs + 1,
dtype=torch.int32,
device=device)
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor) -> AiterMLADecodeMetadata:
page_size = self.kv_cache_spec.block_size page_size = self.kv_cache_spec.block_size
block_table_bounds = (seq_lens + page_size - 1) // page_size block_table_bounds = (seq_lens + page_size - 1) // page_size
device = self.runner.device device = self.runner.device
mask = (torch.arange(block_table.size(1), mask = (torch.arange(block_table_tensor.size(1),
dtype=block_table.dtype, dtype=block_table_tensor.dtype,
device=device).unsqueeze(0) device=device).unsqueeze(0)
< block_table_bounds.unsqueeze(1)) < block_table_bounds.unsqueeze(1))
paged_kv_indices = block_table[mask] paged_kv_indices = block_table_tensor[mask]
paged_kv_last_page_len = seq_lens % page_size
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
page_size, paged_kv_last_page_len)
paged_kv_indptr = torch.cat([ paged_kv_indptr = torch.cat([
torch.zeros(1, dtype=block_table_bounds.dtype, device=device), torch.zeros(1, dtype=block_table_bounds.dtype, device=device),
block_table_bounds.cumsum(dim=0, dtype=torch.int32) block_table_bounds.cumsum(dim=0, dtype=torch.int32)
]) ])
paged_kv_last_page_len = seq_lens % page_size if self.runner.full_cuda_graph:
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, num_reqs = self._num_decodes
page_size, paged_kv_last_page_len)
qo_indptr = torch.arange(0,
self._num_decodes + 1,
step=1,
dtype=torch.int32,
device=device)
return ( num_actual_pages = paged_kv_indices.size(0)
paged_kv_indices,
paged_kv_indptr,
paged_kv_last_page_len,
qo_indptr,
)
def _build_decode(self, block_table_tensor: torch.Tensor, self.paged_kv_indices[:num_actual_pages].copy_(paged_kv_indices,
seq_lens: torch.Tensor) -> AiterMLADecodeMetadata: non_blocking=True)
self.paged_kv_indices[num_actual_pages:].fill_(-1)
paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
( self.paged_kv_indptr[:1 + num_reqs].copy_(paged_kv_indptr,
paged_kv_indices, non_blocking=True)
paged_kv_indptr, self.paged_kv_indptr[1 + num_reqs:].fill_(paged_kv_indptr[-1])
paged_last_page_len, paged_kv_indptr = self.paged_kv_indptr[:1 + num_reqs]
qo_indptr,
) = self._get_paged_kv_tensors(block_table_tensor, seq_lens) self.paged_kv_last_page_len[:num_reqs].copy_(
paged_kv_last_page_len, non_blocking=True)
self.paged_kv_last_page_len[num_reqs:].fill_(1)
paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]
qo_indptr = self.qo_indptr[:1 + num_reqs]
else:
qo_indptr = torch.arange(0,
self._num_decodes + 1,
step=1,
dtype=torch.int32,
device=device)
attn_metadata = AiterMLADecodeMetadata( attn_metadata = AiterMLADecodeMetadata(
block_table=block_table_tensor, block_table=block_table_tensor,
seq_lens=seq_lens, seq_lens=seq_lens,
paged_kv_indptr=paged_kv_indptr, paged_kv_indptr=paged_kv_indptr,
paged_kv_indices=paged_kv_indices, paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_last_page_len, paged_kv_last_page_len=paged_kv_last_page_len,
qo_indptr=qo_indptr) qo_indptr=qo_indptr)
return attn_metadata return attn_metadata