[ROCm][FEAT] Enable Full Graph Mode in AITER MLA V1 Attn Backend (Decode Phase only) (#20254)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
@@ -2,7 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Optional
|
from typing import Any, ClassVar, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -63,6 +63,7 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
|
|||||||
|
|
||||||
|
|
||||||
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||||
|
full_cudagraph_supported: ClassVar[bool] = True # decode only
|
||||||
|
|
||||||
def __init__(self, runner, kv_cache_spec: AttentionSpec,
|
def __init__(self, runner, kv_cache_spec: AttentionSpec,
|
||||||
block_table: BlockTable):
|
block_table: BlockTable):
|
||||||
@@ -70,56 +71,83 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
|||||||
assert self.kv_cache_spec.block_size == 1, "AITER MLA" \
|
assert self.kv_cache_spec.block_size == 1, "AITER MLA" \
|
||||||
"only supports block size 1."
|
"only supports block size 1."
|
||||||
|
|
||||||
def _get_paged_kv_tensors(
|
# Preparing persistent buffers
|
||||||
self, block_table: torch.Tensor,
|
if self.runner.full_cuda_graph:
|
||||||
seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
device = self.runner.device
|
||||||
|
max_num_reqs = self.runner.max_num_reqs
|
||||||
|
self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device)
|
||||||
|
self.paged_kv_indices = torch.zeros(
|
||||||
|
block_table.get_device_tensor().numel(
|
||||||
|
), # max num pages possible
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device)
|
||||||
|
self.paged_kv_last_page_len = torch.zeros(max_num_reqs,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device)
|
||||||
|
|
||||||
|
self.qo_indptr = torch.arange(0,
|
||||||
|
max_num_reqs + 1,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device)
|
||||||
|
|
||||||
|
def _build_decode(self, block_table_tensor: torch.Tensor,
                  seq_lens: torch.Tensor) -> AiterMLADecodeMetadata:
    """Build AITER MLA decode-phase metadata from the block table.

    Computes the paged-KV CSR-style index tensors (indptr, indices,
    last-page lengths) and the query offset pointer for the decode
    requests. When the runner is in full CUDA graph mode, results are
    copied into the persistent buffers allocated in ``__init__`` so the
    tensor addresses stay stable across graph replays; otherwise fresh
    tensors are returned directly.

    Args:
        block_table_tensor: (num_reqs, max_blocks) int tensor of
            physical block ids per request.
        seq_lens: (num_reqs,) int tensor of sequence lengths.

    Returns:
        AiterMLADecodeMetadata with paged-KV indexing tensors.
    """
    page_size = self.kv_cache_spec.block_size
    # Number of pages each request actually occupies (ceil division).
    block_table_bounds = (seq_lens + page_size - 1) // page_size
    device = self.runner.device

    # Mask selecting, per request, only the block-table columns that
    # hold real pages; flattening via boolean indexing yields the
    # concatenated page-id list (CSR "indices").
    mask = (torch.arange(block_table_tensor.size(1),
                         dtype=block_table_tensor.dtype,
                         device=device).unsqueeze(0)
            < block_table_bounds.unsqueeze(1))
    paged_kv_indices = block_table_tensor[mask]

    # Tokens used in each request's final page; a full final page shows
    # up as 0 from the modulo, so map 0 -> page_size.
    paged_kv_last_page_len = seq_lens % page_size
    paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
                                         page_size, paged_kv_last_page_len)

    # CSR "indptr": prefix sums of per-request page counts, with a
    # leading zero.
    paged_kv_indptr = torch.cat([
        torch.zeros(1, dtype=block_table_bounds.dtype, device=device),
        block_table_bounds.cumsum(dim=0, dtype=torch.int32)
    ])

    if self.runner.full_cuda_graph:
        # Full CUDA graph capture requires stable tensor addresses, so
        # copy into the persistent buffers and hand out views of them.
        num_reqs = self._num_decodes

        num_actual_pages = paged_kv_indices.size(0)

        self.paged_kv_indices[:num_actual_pages].copy_(paged_kv_indices,
                                                       non_blocking=True)
        # Pad unused slots with -1 so stale page ids are never read.
        self.paged_kv_indices[num_actual_pages:].fill_(-1)
        paged_kv_indices = self.paged_kv_indices[:num_actual_pages]

        self.paged_kv_indptr[:1 + num_reqs].copy_(paged_kv_indptr,
                                                  non_blocking=True)
        # Padding entries repeat the final offset, i.e. zero-length
        # padded requests.
        self.paged_kv_indptr[1 + num_reqs:].fill_(paged_kv_indptr[-1])
        paged_kv_indptr = self.paged_kv_indptr[:1 + num_reqs]

        self.paged_kv_last_page_len[:num_reqs].copy_(
            paged_kv_last_page_len, non_blocking=True)
        # NOTE(review): padding uses 1 (not 0) — presumably the kernel
        # requires a positive last-page length; confirm against AITER.
        self.paged_kv_last_page_len[num_reqs:].fill_(1)
        paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]

        # qo_indptr was pre-filled with arange in __init__; a view of
        # the right length is all that is needed.
        qo_indptr = self.qo_indptr[:1 + num_reqs]
    else:
        # Eager path: one query token per decode request, so the query
        # offsets are simply 0..num_decodes.
        qo_indptr = torch.arange(0,
                                 self._num_decodes + 1,
                                 step=1,
                                 dtype=torch.int32,
                                 device=device)

    attn_metadata = AiterMLADecodeMetadata(
        block_table=block_table_tensor,
        seq_lens=seq_lens,
        paged_kv_indptr=paged_kv_indptr,
        paged_kv_indices=paged_kv_indices,
        paged_kv_last_page_len=paged_kv_last_page_len,
        qo_indptr=qo_indptr)

    return attn_metadata
|
||||||
|
|||||||
Reference in New Issue
Block a user