[ROCm][Perf] Enable shuffle kv cache layout and assembly paged attention kernel for AiterFlashAttentionBackend (#29887)
Signed-off-by: ganyi <ygan@amd.com>
This commit is contained in:
@@ -833,6 +833,7 @@ class rocm_aiter_ops:
|
||||
_FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
|
||||
_MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
|
||||
_MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
|
||||
_SHUFFLE_KV_CACHE_ENABLED = envs.VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT
|
||||
_TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
|
||||
# TODO: Consolidate under _LINEAR_ENABLED
|
||||
_FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
|
||||
@@ -859,6 +860,7 @@ class rocm_aiter_ops:
|
||||
cls._FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
|
||||
cls._MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
|
||||
cls._MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
|
||||
cls._SHUFFLE_KV_CACHE_ENABLED = envs.VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT
|
||||
cls._TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
|
||||
cls._FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
|
||||
cls._FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
|
||||
@@ -906,6 +908,11 @@ class rocm_aiter_ops:
|
||||
def is_mha_enabled(cls) -> bool:
|
||||
return cls._AITER_ENABLED and cls._MHA_ENABLED
|
||||
|
||||
@classmethod
@if_aiter_supported
def is_shuffle_kv_cache_enabled(cls) -> bool:
    """Report whether the shuffled KV-cache layout is switched on.

    The result combines the class-level AITER flag with the flag that
    mirrors ``envs.VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT``; both must be set.
    """
    aiter_on = cls._AITER_ENABLED
    return aiter_on and cls._SHUFFLE_KV_CACHE_ENABLED
|
||||
|
||||
@classmethod
|
||||
@if_aiter_supported
|
||||
def is_triton_unified_attn_enabled(cls) -> bool:
|
||||
|
||||
@@ -128,6 +128,7 @@ if TYPE_CHECKING:
|
||||
VLLM_ROCM_FP8_PADDING: bool = True
|
||||
VLLM_ROCM_MOE_PADDING: bool = True
|
||||
VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True
|
||||
VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT: bool = False
|
||||
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
|
||||
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
|
||||
VLLM_DISABLE_COMPILE_CACHE: bool = False
|
||||
@@ -1018,6 +1019,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_ROCM_CUSTOM_PAGED_ATTN": lambda: (
|
||||
os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in ("true", "1")
|
||||
),
|
||||
# Whether to use the shuffled kv cache layout
|
||||
"VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT": lambda: (
|
||||
os.getenv("VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT", "False").lower() in ("true", "1")
|
||||
),
|
||||
# Custom quick allreduce kernel for MI3* cards
|
||||
# Choice of quantization level: FP, INT8, INT6, INT4 or NONE
|
||||
# Recommended for large models to get allreduce
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
@@ -30,7 +31,6 @@ from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
_PARTITION_SIZE_ROCM = 256
|
||||
_CP_TOKENS_PER_ITER_ROCM = 32 * 1024
|
||||
|
||||
if current_platform.is_rocm():
|
||||
import aiter
|
||||
|
||||
@@ -52,7 +52,7 @@ if current_platform.is_rocm():
|
||||
cu_seqlens_kv_ptr, # [num_batches + 1]
|
||||
token_to_batch_ptr, # [max_cum_tokens]
|
||||
seq_start_ptr, # [num_batches]
|
||||
k_scale_ptr,
|
||||
k_scale_ptr, # [1] / [num_blocks, num_kv_heads, page_size]
|
||||
v_scale_ptr,
|
||||
num_heads,
|
||||
head_size,
|
||||
@@ -64,13 +64,15 @@ if current_platform.is_rocm():
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
token_id = tl.program_id(0)
|
||||
head_id = tl.program_id(1)
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
if DEQUANT:
|
||||
k_scale = tl.load(k_scale_ptr)
|
||||
v_scale = tl.load(v_scale_ptr)
|
||||
|
||||
key_ptr_offset = key_ptr + token_id * head_size * num_heads
|
||||
value_ptr_offset = value_ptr + token_id * head_size * num_heads
|
||||
key_ptr_offset = (
|
||||
key_ptr + token_id * head_size * num_heads + head_id * head_size
|
||||
)
|
||||
value_ptr_offset = (
|
||||
value_ptr + token_id * head_size * num_heads + head_id * head_size
|
||||
)
|
||||
batch_idx = tl.load(token_to_batch_ptr + token_id)
|
||||
batch_start = tl.load(seq_start_ptr + batch_idx)
|
||||
token_start = tl.load(cu_seqlens_kv_ptr + batch_idx)
|
||||
@@ -89,24 +91,54 @@ if current_platform.is_rocm():
|
||||
key_cache_ptr
|
||||
+ block_id * num_heads * head_size * PAGE_SIZE
|
||||
+ slot_id * num_heads * head_size
|
||||
+ head_id * head_size
|
||||
)
|
||||
value_cache_ptr_offset = (
|
||||
value_cache_ptr
|
||||
+ block_id * num_heads * head_size * PAGE_SIZE
|
||||
+ slot_id * num_heads * head_size
|
||||
+ head_id * head_size
|
||||
)
|
||||
k_reg = tl.load(key_cache_ptr_offset + col_offsets)
|
||||
v_reg = tl.load(value_cache_ptr_offset + col_offsets)
|
||||
if DEQUANT:
|
||||
k_scale = tl.load(k_scale_ptr)
|
||||
v_scale = tl.load(v_scale_ptr)
|
||||
k_dtype = k_reg.dtype
|
||||
v_dtype = v_reg.dtype
|
||||
k_reg = (k_reg.to(tl.float32) * k_scale).to(k_dtype)
|
||||
v_reg = (v_reg.to(tl.float32) * v_scale).to(v_dtype)
|
||||
tl.store(key_ptr_offset + col_offsets, k_reg)
|
||||
tl.store(value_ptr_offset + col_offsets, v_reg)
|
||||
|
||||
for i in tl.range(0, head_size * num_heads, BLOCK_SIZE):
|
||||
mask = (col_offsets + i) < head_size * num_heads
|
||||
k_reg = tl.load(key_cache_ptr_offset + col_offsets + i, mask=mask)
|
||||
v_reg = tl.load(value_cache_ptr_offset + col_offsets + i, mask=mask)
|
||||
if DEQUANT:
|
||||
k_dtype = k_reg.dtype
|
||||
v_dtype = v_reg.dtype
|
||||
k_reg = (k_reg.to(tl.float32) * k_scale).to(k_dtype)
|
||||
v_reg = (v_reg.to(tl.float32) * v_scale).to(v_dtype)
|
||||
tl.store(key_ptr_offset + col_offsets + i, k_reg, mask=mask)
|
||||
tl.store(value_ptr_offset + col_offsets + i, v_reg, mask=mask)
|
||||
elif CACHE_FORMAT == "SHUFFLE":
|
||||
# for kv cache layout as
|
||||
# K: [num_blocks, num_head, head_dim // x, page_size, x]
|
||||
# V: [num_blocks, num_head, page_size // x, head_dim, x]
|
||||
key_cache_ptr_offset = (
|
||||
key_cache_ptr
|
||||
+ block_id * num_heads * head_size * PAGE_SIZE
|
||||
+ head_id * head_size * PAGE_SIZE
|
||||
+ slot_id * x
|
||||
)
|
||||
value_cache_ptr_offset = (
|
||||
value_cache_ptr
|
||||
+ block_id * num_heads * head_size * PAGE_SIZE
|
||||
+ head_id * head_size * PAGE_SIZE
|
||||
+ (slot_id // x) * head_size * x
|
||||
+ slot_id % x
|
||||
)
|
||||
k_reg_offset = col_offsets // x * PAGE_SIZE * x + col_offsets % x
|
||||
v_reg_offset = col_offsets * x
|
||||
k_reg = tl.load(key_cache_ptr_offset + k_reg_offset)
|
||||
v_reg = tl.load(value_cache_ptr_offset + v_reg_offset)
|
||||
if DEQUANT:
|
||||
k_scale = 1.0
|
||||
v_scale = 1.0
|
||||
k_reg = k_reg.to(tl.float32) * k_scale
|
||||
v_reg = v_reg.to(tl.float32) * v_scale
|
||||
tl.store(key_ptr_offset + col_offsets, k_reg)
|
||||
tl.store(value_ptr_offset + col_offsets, v_reg)
|
||||
|
||||
def cp_mha_gather_cache(
|
||||
key_cache: torch.Tensor,
|
||||
@@ -123,17 +155,14 @@ if current_platform.is_rocm():
|
||||
kv_cache_layout: str,
|
||||
total_tokens: int,
|
||||
):
|
||||
assert kv_cache_layout in ["v0", "NHD", "HND"], (
|
||||
"kv_cache_layout only support v0, NHD, HND"
|
||||
assert kv_cache_layout in ["NHD", "SHUFFLE"], (
|
||||
"kv_cache_layout only support NHD, SHUFFLE"
|
||||
)
|
||||
head_dim = key.shape[2]
|
||||
x = 0
|
||||
x = 16 // key_cache.element_size()
|
||||
# assert dequant is True, "Currently, we only support "\
|
||||
# "gather cache with dequant"
|
||||
# For k cache layout: [num_blocks, num_heads, page_size, head_dim]
|
||||
assert kv_cache_layout == "NHD", (
|
||||
"ROCM_AITER_FA_BACKEND Only support NHD kv cache layout for now"
|
||||
)
|
||||
assert head_dim == key_cache.shape[3], (
|
||||
"We assume your kv cache layout is [num_blocks, "
|
||||
"page_size, num_heads, head_dim], but got otherwise"
|
||||
@@ -141,7 +170,7 @@ if current_platform.is_rocm():
|
||||
page_size = key_cache.shape[1]
|
||||
num_heads = key_cache.shape[2]
|
||||
|
||||
grid = lambda meta: (total_tokens,)
|
||||
grid = lambda meta: (total_tokens, num_heads)
|
||||
cp_mha_gather_cache_kernel[grid](
|
||||
key_cache,
|
||||
value_cache,
|
||||
@@ -163,6 +192,112 @@ if current_platform.is_rocm():
|
||||
BLOCK_SIZE=head_dim,
|
||||
)
|
||||
|
||||
@triton.jit
def reshape_and_cache_shuffle_kernel(
    key_ptr,  # [num_tokens, num_kv_heads, head_size]
    value_ptr,  # [num_tokens, num_kv_heads, head_size]
    key_cache_ptr,  # [num_blocks, num_kv_heads, head_size // x, block_size, x]
    value_cache_ptr,  # [num_blocks, num_kv_heads, block_size // x, head_size, x]
    slot_mapping_ptr,  # [num_tokens]
    k_scale_ptr,  # [num_blocks, num_kv_heads, block_size]
    v_scale_ptr,  # [num_blocks, num_kv_heads, block_size]
    x,
    k_stride0,
    v_stride0,
    block_size,
    head_size,
    num_kv_heads,
    BLOCK_SIZE: tl.constexpr,
    QUANT: tl.constexpr,
    IS_FNUZ: tl.constexpr,
):
    """Write one (token, kv-head) row of K and V into the shuffled cache.

    Program axis 0 selects the token (it indexes ``slot_mapping_ptr``) and
    program axis 1 selects the kv head. Each program copies ``BLOCK_SIZE``
    contiguous source elements to the destination offsets implied by the
    cache shapes noted on the parameters above.

    ``k_scale_ptr``, ``v_scale_ptr`` and ``IS_FNUZ`` are accepted but not read
    here: when ``QUANT`` is set the values are divided by a fixed scale of
    1.0 and cast to the cache element type.
    """
    token_idx = tl.program_id(0)
    kv_head = tl.program_id(1)
    lane = tl.arange(0, BLOCK_SIZE)

    # Start of this token/head row in the source tensors; only the token
    # stride is passed in, heads are stepped by head_size.
    k_src = token_idx * k_stride0 + kv_head * head_size
    v_src = token_idx * v_stride0 + kv_head * head_size

    slot = tl.load(slot_mapping_ptr + token_idx)
    if slot < 0:
        # Negative slots are skipped: nothing is written for this token.
        return

    page = slot // block_size
    pos_in_page = slot % block_size

    # First element of the (page, head) tile; the K and V tiles both span
    # head_size * block_size elements.
    tile_base = (
        page * num_kv_heads * head_size * block_size
        + kv_head * head_size * block_size
    )
    # K tile is [head_size // x, block_size, x]: lane -> (lane // x, pos, lane % x).
    k_dst = tile_base + lane // x * block_size * x + pos_in_page * x + lane % x
    # V tile is [block_size // x, head_size, x]: lane -> (pos // x, lane, pos % x).
    v_dst = (
        tile_base
        + pos_in_page // x * head_size * x
        + lane * x
        + pos_in_page % x
    )

    k_row = tl.load(key_ptr + k_src + lane)
    v_row = tl.load(value_ptr + v_src + lane)
    if QUANT:
        k_scale = 1.0
        v_scale = 1.0
        k_cache_ty = key_cache_ptr.type.element_ty
        v_cache_ty = value_cache_ptr.type.element_ty
        k_row = (k_row.to(tl.float32) / k_scale).to(k_cache_ty)
        v_row = (v_row.to(tl.float32) / v_scale).to(v_cache_ty)
    tl.store(key_cache_ptr + k_dst, k_row)
    tl.store(value_cache_ptr + v_dst, v_row)
|
||||
|
||||
def reshape_and_cache_shuffle_triton(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scales: torch.Tensor,
    v_scales: torch.Tensor,
):
    """Launch ``reshape_and_cache_shuffle_kernel`` for new K/V tokens.

    The caches arrive as 4-D tensors whose first two dims are unpacked as
    ``(num_blocks, block_size)``. They are re-viewed (no copy) with the
    shuffle-layout shapes before being handed to the kernel:

    * K: ``[num_blocks, num_kv_heads, head_size // x, block_size, x]``
    * V: ``[num_blocks, num_kv_heads, block_size // x, head_size, x]``

    where ``x`` is the number of cache elements that fit in 16 bytes. One
    program is launched per (token, kv head); the token count is taken from
    ``slot_mapping``. The kernel's ``QUANT`` flag is set when
    ``kv_cache_dtype`` starts with ``"fp8"``.
    """
    _, num_kv_heads, head_size = key.shape
    num_blocks, block_size, _cache_dim2, _cache_dim3 = key_cache.shape
    token_count = slot_mapping.shape[0]

    # Elements per 16-byte group for the cache dtype.
    x = 16 // key_cache.element_size()

    # Reinterpret the existing storage with the shuffle-layout shapes.
    shuffled_k = key_cache.view(
        [num_blocks, num_kv_heads, head_size // x, block_size, x]
    )
    shuffled_v = value_cache.view(
        [num_blocks, num_kv_heads, block_size // x, head_size, x]
    )

    quantize = kv_cache_dtype.startswith("fp8")
    launch_grid = (token_count, num_kv_heads)
    reshape_and_cache_shuffle_kernel[launch_grid](
        key,
        value,
        shuffled_k,
        shuffled_v,
        slot_mapping,
        k_scales,
        v_scales,
        x,
        key.stride(0),
        value.stride(0),
        block_size,
        head_size,
        num_kv_heads,
        BLOCK_SIZE=head_size,
        QUANT=quantize,
        IS_FNUZ=current_platform.fp8_dtype() == torch.float8_e4m3fnuz,
    )
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -253,6 +388,11 @@ class AiterFlashAttentionMetadata:
|
||||
common_prefix_len: int
|
||||
total_tokens: int
|
||||
|
||||
# Only for fp8 shuffle layout kv cache, we allocate kv_scale for each layer
|
||||
# since we might integrate per token quant for kv cache in the future.
|
||||
k_scale: dict[str, torch.Tensor] | None
|
||||
v_scale: dict[str, torch.Tensor] | None
|
||||
|
||||
|
||||
class AiterFlashAttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[AiterFlashAttentionMetadata]
|
||||
@@ -303,6 +443,7 @@ class AiterFlashAttentionMetadataBuilder(
|
||||
dtype=self.model_config.dtype,
|
||||
device=device,
|
||||
)
|
||||
self.scale = torch.tensor([1.0], dtype=torch.float, device=self.device)
|
||||
|
||||
def build_for_cudagraph_capture(
|
||||
self, common_attn_metadata: CommonAttentionMetadata
|
||||
@@ -325,7 +466,27 @@ class AiterFlashAttentionMetadataBuilder(
|
||||
common_attn_metadata,
|
||||
decode_threshold=self.reorder_batch_threshold,
|
||||
)
|
||||
|
||||
# Allocate scales for fp8 shuffle kv cache with shuffle_kv_cache enabled
|
||||
if (
|
||||
rocm_aiter_ops.is_shuffle_kv_cache_enabled()
|
||||
and self.scale.numel() == 1
|
||||
and self.vllm_config.cache_config.cache_dtype.startswith("fp8")
|
||||
):
|
||||
layers = get_layers_from_vllm_config(self.vllm_config, Attention)
|
||||
first_layer_name = [k for k in layers][0]
|
||||
kv_cache_shape = (
|
||||
self.vllm_config.compilation_config.static_forward_context[
|
||||
first_layer_name
|
||||
]
|
||||
.kv_cache[0]
|
||||
.shape
|
||||
)
|
||||
num_blocks = kv_cache_shape[1]
|
||||
self.scale = torch.ones(
|
||||
[num_blocks, self.num_heads_kv, self.block_size],
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
(
|
||||
num_decodes,
|
||||
num_extends,
|
||||
@@ -507,6 +668,8 @@ class AiterFlashAttentionMetadataBuilder(
|
||||
use_cascade=use_cascade,
|
||||
common_prefix_len=common_prefix_len,
|
||||
total_tokens=self.total_tokens,
|
||||
k_scale=self.scale,
|
||||
v_scale=self.scale,
|
||||
)
|
||||
return attn_metadata
|
||||
|
||||
@@ -548,7 +711,6 @@ class AiterFlashAttentionBackend(AttentionBackend):
|
||||
) -> tuple[int, ...]:
|
||||
if block_size % 16 != 0:
|
||||
raise ValueError("Block size must be a multiple of 16.")
|
||||
|
||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
|
||||
@@ -630,7 +792,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
cu_seqlens_kv=swa_cu_seqlens,
|
||||
token_to_batch=swa_token_to_batch,
|
||||
seq_starts=swa_seq_starts,
|
||||
dequant=False,
|
||||
dequant=self.kv_cache_dtype.startswith("fp8"),
|
||||
kv_cache_layout="NHD",
|
||||
total_tokens=swa_total_tokens,
|
||||
)
|
||||
@@ -668,8 +830,8 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
min_seqlen_q: int,
|
||||
block_table: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
):
|
||||
if self.sliding_window[0] != -1:
|
||||
self.extend_for_sliding_window(
|
||||
@@ -725,8 +887,10 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
cu_seqlens_kv=cu_seqlens_kv[chunk_idx],
|
||||
token_to_batch=token_to_batch[chunk_idx],
|
||||
seq_starts=chunk_starts[chunk_idx],
|
||||
dequant=False,
|
||||
kv_cache_layout="NHD",
|
||||
dequant=self.kv_cache_dtype.startswith("fp8"),
|
||||
kv_cache_layout="SHUFFLE"
|
||||
if rocm_aiter_ops.is_shuffle_kv_cache_enabled()
|
||||
else "NHD",
|
||||
total_tokens=total_token_per_batch[chunk_idx],
|
||||
)
|
||||
|
||||
@@ -823,6 +987,9 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
# key and value may be None in the case of cross attention. They are
|
||||
# calculated once based on the output from the encoder and then cached
|
||||
# in KV cache.
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
key_cache = key_cache.view(current_platform.fp8_dtype())
|
||||
value_cache = value_cache.view(current_platform.fp8_dtype())
|
||||
if (
|
||||
self.kv_sharing_target_layer_name is None
|
||||
and key is not None
|
||||
@@ -835,21 +1002,31 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
# key[:num_actual_tokens] and value[:num_actual_tokens] because
|
||||
# the reshape_and_cache_flash op uses the slot_mapping's shape
|
||||
# to determine the number of actual tokens.
|
||||
|
||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
key_cache = key_cache.view(current_platform.fp8_dtype())
|
||||
value_cache = value_cache.view(current_platform.fp8_dtype())
|
||||
if rocm_aiter_ops.is_shuffle_kv_cache_enabled():
|
||||
# We may calculate per token quant scale in
|
||||
# reshape_and_cache_shuffle_triton which might differ from
|
||||
# vllm's style when shuffle layout is used.
|
||||
reshape_and_cache_shuffle_triton(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
attn_metadata.k_scale,
|
||||
attn_metadata.v_scale,
|
||||
)
|
||||
else:
|
||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
# decode:extend:prefill
|
||||
query = query[:num_actual_tokens]
|
||||
@@ -902,6 +1079,11 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
extend_keys = key[extend_tokens_slice]
|
||||
extend_values = value[extend_tokens_slice]
|
||||
extend_outputs = output[extend_tokens_slice]
|
||||
k_scale = layer._k_scale
|
||||
v_scale = layer._v_scale
|
||||
if rocm_aiter_ops.is_shuffle_kv_cache_enabled():
|
||||
k_scale = attn_metadata.k_scale
|
||||
v_scale = attn_metadata.v_scale
|
||||
self.extend_forward(
|
||||
attn_metadata=attn_metadata,
|
||||
query=extend_querys,
|
||||
@@ -920,14 +1102,17 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
slot_mapping=attn_metadata.slot_mapping[
|
||||
num_decodes : num_decodes + num_extends
|
||||
],
|
||||
k_scale=layer._k_scale,
|
||||
v_scale=layer._v_scale,
|
||||
k_scale=k_scale,
|
||||
v_scale=v_scale,
|
||||
)
|
||||
|
||||
# calculate for decodes
|
||||
if num_decodes > 0:
|
||||
assert attn_metadata.decode_metadata is not None
|
||||
if self.sliding_window[0] != -1:
|
||||
assert not rocm_aiter_ops.is_shuffle_kv_cache_enabled(), (
|
||||
"Sliding window with shuffle layout is not supported yet."
|
||||
)
|
||||
from aiter.ops.triton.unified_attention import (
|
||||
unified_attention,
|
||||
)
|
||||
@@ -957,41 +1142,71 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
)
|
||||
return
|
||||
assert attn_metadata.decode_metadata is not None
|
||||
_, num_heads, head_size = query.shape
|
||||
nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8
|
||||
num_seqs = attn_metadata.seq_lens.shape[0]
|
||||
max_num_partitions = (
|
||||
attn_metadata.max_seq_len + _PARTITION_SIZE_ROCM - 1
|
||||
) // _PARTITION_SIZE_ROCM
|
||||
|
||||
workspace_buffer = torch.empty(
|
||||
(num_seqs * num_heads * max_num_partitions * head_size)
|
||||
* nbytes_per_qo_elem
|
||||
+ 2 * (num_seqs * num_heads * max_num_partitions) * 4,
|
||||
dtype=torch.uint8,
|
||||
device=output.device,
|
||||
)
|
||||
if rocm_aiter_ops.is_shuffle_kv_cache_enabled():
|
||||
num_blocks, block_size, num_kv_heads, head_size = key_cache.shape
|
||||
x = 16 // key_cache.element_size()
|
||||
k_cache_template = torch.empty(
|
||||
[num_blocks, num_kv_heads, head_size // x, block_size, x],
|
||||
dtype=key_cache.dtype,
|
||||
device="meta",
|
||||
)
|
||||
v_cache_template = torch.empty(
|
||||
[num_blocks, num_kv_heads, block_size // x, head_size, x],
|
||||
dtype=value_cache.dtype,
|
||||
device="meta",
|
||||
)
|
||||
new_key_cache = key_cache.view_as(k_cache_template)
|
||||
new_value_cache = value_cache.view_as(v_cache_template)
|
||||
aiter.pa_fwd_asm(
|
||||
Q=query[:num_decode_tokens],
|
||||
K=new_key_cache,
|
||||
V=new_value_cache,
|
||||
block_tables=attn_metadata.block_table[:num_decodes],
|
||||
context_lens=attn_metadata.seq_lens[:num_decodes],
|
||||
block_tables_stride0=attn_metadata.block_table[
|
||||
:num_decodes
|
||||
].stride(0),
|
||||
K_QScale=attn_metadata.k_scale,
|
||||
V_QScale=attn_metadata.v_scale,
|
||||
out_=output[:num_decode_tokens],
|
||||
)
|
||||
else:
|
||||
_, num_heads, head_size = query.shape
|
||||
nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8
|
||||
num_seqs = attn_metadata.seq_lens.shape[0]
|
||||
max_num_partitions = (
|
||||
attn_metadata.max_seq_len + _PARTITION_SIZE_ROCM - 1
|
||||
) // _PARTITION_SIZE_ROCM
|
||||
|
||||
torch.ops.aiter.paged_attention_v1(
|
||||
output[:num_decode_tokens],
|
||||
workspace_buffer,
|
||||
query[:num_decode_tokens],
|
||||
key_cache,
|
||||
value_cache,
|
||||
self.scale,
|
||||
attn_metadata.block_table[:num_decodes],
|
||||
attn_metadata.query_start_loc[:num_decodes],
|
||||
attn_metadata.seq_lens[:num_decodes],
|
||||
attn_metadata.max_seq_len,
|
||||
self.alibi_slopes,
|
||||
self.kv_cache_dtype,
|
||||
"NHD",
|
||||
self.logits_soft_cap,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
None,
|
||||
_PARTITION_SIZE_ROCM,
|
||||
)
|
||||
workspace_buffer = torch.empty(
|
||||
(num_seqs * num_heads * max_num_partitions * head_size)
|
||||
* nbytes_per_qo_elem
|
||||
+ 2 * (num_seqs * num_heads * max_num_partitions) * 4,
|
||||
dtype=torch.uint8,
|
||||
device=output.device,
|
||||
)
|
||||
|
||||
torch.ops.aiter.paged_attention_v1(
|
||||
output[:num_decode_tokens],
|
||||
workspace_buffer,
|
||||
query[:num_decode_tokens],
|
||||
key_cache,
|
||||
value_cache,
|
||||
self.scale,
|
||||
attn_metadata.block_table[:num_decodes],
|
||||
attn_metadata.query_start_loc[:num_decodes],
|
||||
attn_metadata.seq_lens[:num_decodes],
|
||||
attn_metadata.max_seq_len,
|
||||
self.alibi_slopes,
|
||||
self.kv_cache_dtype,
|
||||
"NHD",
|
||||
self.logits_soft_cap,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
None,
|
||||
_PARTITION_SIZE_ROCM,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Cascade attention is not implemented for ROCM AITER"
|
||||
|
||||
Reference in New Issue
Block a user