[ROCm][Perf] Enable shuffle kv cache layout and assembly paged attention kernel for AiterFlashAttentionBackend (#29887)
Signed-off-by: ganyi <ygan@amd.com>
This commit is contained in:
@@ -833,6 +833,7 @@ class rocm_aiter_ops:
|
|||||||
_FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
|
_FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
|
||||||
_MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
|
_MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
|
||||||
_MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
|
_MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
|
||||||
|
_SHUFFLE_KV_CACHE_ENABLED = envs.VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT
|
||||||
_TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
|
_TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
|
||||||
# TODO: Consolidate under _LINEAR_ENABLED
|
# TODO: Consolidate under _LINEAR_ENABLED
|
||||||
_FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
|
_FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
|
||||||
@@ -859,6 +860,7 @@ class rocm_aiter_ops:
|
|||||||
cls._FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
|
cls._FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
|
||||||
cls._MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
|
cls._MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
|
||||||
cls._MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
|
cls._MHA_ENABLED = envs.VLLM_ROCM_USE_AITER_MHA
|
||||||
|
cls._SHUFFLE_KV_CACHE_ENABLED = envs.VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT
|
||||||
cls._TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
|
cls._TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION
|
||||||
cls._FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
|
cls._FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM
|
||||||
cls._FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
|
cls._FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
|
||||||
@@ -906,6 +908,11 @@ class rocm_aiter_ops:
|
|||||||
def is_mha_enabled(cls) -> bool:
|
def is_mha_enabled(cls) -> bool:
|
||||||
return cls._AITER_ENABLED and cls._MHA_ENABLED
|
return cls._AITER_ENABLED and cls._MHA_ENABLED
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@if_aiter_supported
|
||||||
|
def is_shuffle_kv_cache_enabled(cls) -> bool:
|
||||||
|
return cls._AITER_ENABLED and cls._SHUFFLE_KV_CACHE_ENABLED
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@if_aiter_supported
|
@if_aiter_supported
|
||||||
def is_triton_unified_attn_enabled(cls) -> bool:
|
def is_triton_unified_attn_enabled(cls) -> bool:
|
||||||
|
|||||||
@@ -128,6 +128,7 @@ if TYPE_CHECKING:
|
|||||||
VLLM_ROCM_FP8_PADDING: bool = True
|
VLLM_ROCM_FP8_PADDING: bool = True
|
||||||
VLLM_ROCM_MOE_PADDING: bool = True
|
VLLM_ROCM_MOE_PADDING: bool = True
|
||||||
VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True
|
VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True
|
||||||
|
VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT: bool = False
|
||||||
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
|
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
|
||||||
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
|
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
|
||||||
VLLM_DISABLE_COMPILE_CACHE: bool = False
|
VLLM_DISABLE_COMPILE_CACHE: bool = False
|
||||||
@@ -1018,6 +1019,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
"VLLM_ROCM_CUSTOM_PAGED_ATTN": lambda: (
|
"VLLM_ROCM_CUSTOM_PAGED_ATTN": lambda: (
|
||||||
os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in ("true", "1")
|
os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in ("true", "1")
|
||||||
),
|
),
|
||||||
|
# Whether to use the shuffled kv cache layout
|
||||||
|
"VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT": lambda: (
|
||||||
|
os.getenv("VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT", "False").lower() in ("true", "1")
|
||||||
|
),
|
||||||
# Custom quick allreduce kernel for MI3* cards
|
# Custom quick allreduce kernel for MI3* cards
|
||||||
# Choice of quantization level: FP, INT8, INT6, INT4 or NONE
|
# Choice of quantization level: FP, INT8, INT6, INT4 or NONE
|
||||||
# Recommended for large models to get allreduce
|
# Recommended for large models to get allreduce
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from typing import ClassVar
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm._aiter_ops import rocm_aiter_ops
|
||||||
from vllm.attention.layer import Attention
|
from vllm.attention.layer import Attention
|
||||||
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
from vllm.config import VllmConfig, get_layers_from_vllm_config
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@@ -30,7 +31,6 @@ from vllm.v1.kv_cache_interface import AttentionSpec
|
|||||||
|
|
||||||
_PARTITION_SIZE_ROCM = 256
|
_PARTITION_SIZE_ROCM = 256
|
||||||
_CP_TOKENS_PER_ITER_ROCM = 32 * 1024
|
_CP_TOKENS_PER_ITER_ROCM = 32 * 1024
|
||||||
|
|
||||||
if current_platform.is_rocm():
|
if current_platform.is_rocm():
|
||||||
import aiter
|
import aiter
|
||||||
|
|
||||||
@@ -52,7 +52,7 @@ if current_platform.is_rocm():
|
|||||||
cu_seqlens_kv_ptr, # [num_batches + 1]
|
cu_seqlens_kv_ptr, # [num_batches + 1]
|
||||||
token_to_batch_ptr, # [max_cum_tokens]
|
token_to_batch_ptr, # [max_cum_tokens]
|
||||||
seq_start_ptr, # [num_batches]
|
seq_start_ptr, # [num_batches]
|
||||||
k_scale_ptr,
|
k_scale_ptr, # [1] / [num_blocks, num_kv_heads, page_size]
|
||||||
v_scale_ptr,
|
v_scale_ptr,
|
||||||
num_heads,
|
num_heads,
|
||||||
head_size,
|
head_size,
|
||||||
@@ -64,13 +64,15 @@ if current_platform.is_rocm():
|
|||||||
BLOCK_SIZE: tl.constexpr,
|
BLOCK_SIZE: tl.constexpr,
|
||||||
):
|
):
|
||||||
token_id = tl.program_id(0)
|
token_id = tl.program_id(0)
|
||||||
|
head_id = tl.program_id(1)
|
||||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||||
if DEQUANT:
|
|
||||||
k_scale = tl.load(k_scale_ptr)
|
|
||||||
v_scale = tl.load(v_scale_ptr)
|
|
||||||
|
|
||||||
key_ptr_offset = key_ptr + token_id * head_size * num_heads
|
key_ptr_offset = (
|
||||||
value_ptr_offset = value_ptr + token_id * head_size * num_heads
|
key_ptr + token_id * head_size * num_heads + head_id * head_size
|
||||||
|
)
|
||||||
|
value_ptr_offset = (
|
||||||
|
value_ptr + token_id * head_size * num_heads + head_id * head_size
|
||||||
|
)
|
||||||
batch_idx = tl.load(token_to_batch_ptr + token_id)
|
batch_idx = tl.load(token_to_batch_ptr + token_id)
|
||||||
batch_start = tl.load(seq_start_ptr + batch_idx)
|
batch_start = tl.load(seq_start_ptr + batch_idx)
|
||||||
token_start = tl.load(cu_seqlens_kv_ptr + batch_idx)
|
token_start = tl.load(cu_seqlens_kv_ptr + batch_idx)
|
||||||
@@ -89,24 +91,54 @@ if current_platform.is_rocm():
|
|||||||
key_cache_ptr
|
key_cache_ptr
|
||||||
+ block_id * num_heads * head_size * PAGE_SIZE
|
+ block_id * num_heads * head_size * PAGE_SIZE
|
||||||
+ slot_id * num_heads * head_size
|
+ slot_id * num_heads * head_size
|
||||||
|
+ head_id * head_size
|
||||||
)
|
)
|
||||||
value_cache_ptr_offset = (
|
value_cache_ptr_offset = (
|
||||||
value_cache_ptr
|
value_cache_ptr
|
||||||
+ block_id * num_heads * head_size * PAGE_SIZE
|
+ block_id * num_heads * head_size * PAGE_SIZE
|
||||||
+ slot_id * num_heads * head_size
|
+ slot_id * num_heads * head_size
|
||||||
|
+ head_id * head_size
|
||||||
)
|
)
|
||||||
|
k_reg = tl.load(key_cache_ptr_offset + col_offsets)
|
||||||
|
v_reg = tl.load(value_cache_ptr_offset + col_offsets)
|
||||||
|
if DEQUANT:
|
||||||
|
k_scale = tl.load(k_scale_ptr)
|
||||||
|
v_scale = tl.load(v_scale_ptr)
|
||||||
|
k_dtype = k_reg.dtype
|
||||||
|
v_dtype = v_reg.dtype
|
||||||
|
k_reg = (k_reg.to(tl.float32) * k_scale).to(k_dtype)
|
||||||
|
v_reg = (v_reg.to(tl.float32) * v_scale).to(v_dtype)
|
||||||
|
tl.store(key_ptr_offset + col_offsets, k_reg)
|
||||||
|
tl.store(value_ptr_offset + col_offsets, v_reg)
|
||||||
|
|
||||||
for i in tl.range(0, head_size * num_heads, BLOCK_SIZE):
|
elif CACHE_FORMAT == "SHUFFLE":
|
||||||
mask = (col_offsets + i) < head_size * num_heads
|
# for kv cache layout as
|
||||||
k_reg = tl.load(key_cache_ptr_offset + col_offsets + i, mask=mask)
|
# K: [num_blocks, num_head, head_dim // x, page_size, x]
|
||||||
v_reg = tl.load(value_cache_ptr_offset + col_offsets + i, mask=mask)
|
# V: [num_blocks, num_head, page_size // x, head_dim, x]
|
||||||
if DEQUANT:
|
key_cache_ptr_offset = (
|
||||||
k_dtype = k_reg.dtype
|
key_cache_ptr
|
||||||
v_dtype = v_reg.dtype
|
+ block_id * num_heads * head_size * PAGE_SIZE
|
||||||
k_reg = (k_reg.to(tl.float32) * k_scale).to(k_dtype)
|
+ head_id * head_size * PAGE_SIZE
|
||||||
v_reg = (v_reg.to(tl.float32) * v_scale).to(v_dtype)
|
+ slot_id * x
|
||||||
tl.store(key_ptr_offset + col_offsets + i, k_reg, mask=mask)
|
)
|
||||||
tl.store(value_ptr_offset + col_offsets + i, v_reg, mask=mask)
|
value_cache_ptr_offset = (
|
||||||
|
value_cache_ptr
|
||||||
|
+ block_id * num_heads * head_size * PAGE_SIZE
|
||||||
|
+ head_id * head_size * PAGE_SIZE
|
||||||
|
+ (slot_id // x) * head_size * x
|
||||||
|
+ slot_id % x
|
||||||
|
)
|
||||||
|
k_reg_offset = col_offsets // x * PAGE_SIZE * x + col_offsets % x
|
||||||
|
v_reg_offset = col_offsets * x
|
||||||
|
k_reg = tl.load(key_cache_ptr_offset + k_reg_offset)
|
||||||
|
v_reg = tl.load(value_cache_ptr_offset + v_reg_offset)
|
||||||
|
if DEQUANT:
|
||||||
|
k_scale = 1.0
|
||||||
|
v_scale = 1.0
|
||||||
|
k_reg = k_reg.to(tl.float32) * k_scale
|
||||||
|
v_reg = v_reg.to(tl.float32) * v_scale
|
||||||
|
tl.store(key_ptr_offset + col_offsets, k_reg)
|
||||||
|
tl.store(value_ptr_offset + col_offsets, v_reg)
|
||||||
|
|
||||||
def cp_mha_gather_cache(
|
def cp_mha_gather_cache(
|
||||||
key_cache: torch.Tensor,
|
key_cache: torch.Tensor,
|
||||||
@@ -123,17 +155,14 @@ if current_platform.is_rocm():
|
|||||||
kv_cache_layout: str,
|
kv_cache_layout: str,
|
||||||
total_tokens: int,
|
total_tokens: int,
|
||||||
):
|
):
|
||||||
assert kv_cache_layout in ["v0", "NHD", "HND"], (
|
assert kv_cache_layout in ["NHD", "SHUFFLE"], (
|
||||||
"kv_cache_layout only support v0, NHD, HND"
|
"kv_cache_layout only support NHD, SHUFFLE"
|
||||||
)
|
)
|
||||||
head_dim = key.shape[2]
|
head_dim = key.shape[2]
|
||||||
x = 0
|
x = 16 // key_cache.element_size()
|
||||||
# assert dequant is True, "Currently, we only support "\
|
# assert dequant is True, "Currently, we only support "\
|
||||||
# "gather cache with dequant"
|
# "gather cache with dequant"
|
||||||
# For k cache layout: [num_blocks, page_size, num_heads, head_dim]
|
# For k cache layout: [num_blocks, page_size, num_heads, head_dim]
|
||||||
assert kv_cache_layout == "NHD", (
|
|
||||||
"ROCM_AITER_FA_BACKEND Only support NHD kv cache layout for now"
|
|
||||||
)
|
|
||||||
assert head_dim == key_cache.shape[3], (
|
assert head_dim == key_cache.shape[3], (
|
||||||
"We assume your kv cache layout is [num_blocks, "
|
"We assume your kv cache layout is [num_blocks, "
|
||||||
"page_size, num_heads, head_dim], but got otherwise"
|
"page_size, num_heads, head_dim], but got otherwise"
|
||||||
@@ -141,7 +170,7 @@ if current_platform.is_rocm():
|
|||||||
page_size = key_cache.shape[1]
|
page_size = key_cache.shape[1]
|
||||||
num_heads = key_cache.shape[2]
|
num_heads = key_cache.shape[2]
|
||||||
|
|
||||||
grid = lambda meta: (total_tokens,)
|
grid = lambda meta: (total_tokens, num_heads)
|
||||||
cp_mha_gather_cache_kernel[grid](
|
cp_mha_gather_cache_kernel[grid](
|
||||||
key_cache,
|
key_cache,
|
||||||
value_cache,
|
value_cache,
|
||||||
@@ -163,6 +192,112 @@ if current_platform.is_rocm():
|
|||||||
BLOCK_SIZE=head_dim,
|
BLOCK_SIZE=head_dim,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def reshape_and_cache_shuffle_kernel(
|
||||||
|
key_ptr, # [num_tokens, num_kv_heads, head_size]
|
||||||
|
value_ptr, # [num_tokens, num_kv_heads, head_size]
|
||||||
|
key_cache_ptr, # [num_blocks, num_kv_heads, head_size // x, block_size, x]
|
||||||
|
value_cache_ptr, # [num_blocks, num_kv_heads, block_size // x, head_size, x]
|
||||||
|
slot_mapping_ptr, # [num_tokens]
|
||||||
|
k_scale_ptr, # [num_blocks, num_kv_heads, block_size]
|
||||||
|
v_scale_ptr, # [num_blocks, num_kv_heads, block_size]
|
||||||
|
x,
|
||||||
|
k_stride0,
|
||||||
|
v_stride0,
|
||||||
|
block_size,
|
||||||
|
head_size,
|
||||||
|
num_kv_heads,
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
QUANT: tl.constexpr,
|
||||||
|
IS_FNUZ: tl.constexpr,
|
||||||
|
):
|
||||||
|
tid = tl.program_id(0)
|
||||||
|
head_id = tl.program_id(1)
|
||||||
|
offset = tl.arange(0, BLOCK_SIZE)
|
||||||
|
src_offset_k = tid * k_stride0 + head_id * head_size
|
||||||
|
src_offset_v = tid * v_stride0 + head_id * head_size
|
||||||
|
slot_id = tl.load(slot_mapping_ptr + tid)
|
||||||
|
if slot_id < 0:
|
||||||
|
return
|
||||||
|
block_id = slot_id // block_size
|
||||||
|
block_offset = slot_id % block_size
|
||||||
|
dst_offset = (
|
||||||
|
block_id * num_kv_heads * head_size * block_size
|
||||||
|
+ head_id * head_size * block_size
|
||||||
|
)
|
||||||
|
dst_k_shuffle_offset = (
|
||||||
|
dst_offset + offset // x * block_size * x + block_offset * x + offset % x
|
||||||
|
)
|
||||||
|
dst_v_shuffle_offset = (
|
||||||
|
dst_offset
|
||||||
|
+ block_offset // x * head_size * x
|
||||||
|
+ offset * x
|
||||||
|
+ block_offset % x
|
||||||
|
)
|
||||||
|
k_val = tl.load(key_ptr + src_offset_k + offset)
|
||||||
|
v_val = tl.load(value_ptr + src_offset_v + offset)
|
||||||
|
if QUANT:
|
||||||
|
k_scale = 1.0
|
||||||
|
v_scale = 1.0
|
||||||
|
k_dtype = key_cache_ptr.type.element_ty
|
||||||
|
v_dtype = value_cache_ptr.type.element_ty
|
||||||
|
k_val = (k_val.to(tl.float32) / k_scale).to(k_dtype)
|
||||||
|
v_val = (v_val.to(tl.float32) / v_scale).to(v_dtype)
|
||||||
|
tl.store(key_cache_ptr + dst_k_shuffle_offset, k_val)
|
||||||
|
tl.store(value_cache_ptr + dst_v_shuffle_offset, v_val)
|
||||||
|
|
||||||
|
def reshape_and_cache_shuffle_triton(
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
key_cache: torch.Tensor,
|
||||||
|
value_cache: torch.Tensor,
|
||||||
|
slot_mapping: torch.Tensor,
|
||||||
|
kv_cache_dtype: str,
|
||||||
|
k_scales: torch.Tensor,
|
||||||
|
v_scales: torch.Tensor,
|
||||||
|
):
|
||||||
|
num_tokens = slot_mapping.shape[0]
|
||||||
|
_, num_kv_heads, head_size = key.shape
|
||||||
|
num_blocks, block_size, _, _ = key_cache.shape
|
||||||
|
x = 16 // key_cache.element_size()
|
||||||
|
k_cache_template = torch.empty(
|
||||||
|
[num_blocks, num_kv_heads, head_size // x, block_size, x],
|
||||||
|
dtype=key_cache.dtype,
|
||||||
|
device="meta",
|
||||||
|
)
|
||||||
|
v_cache_template = torch.empty(
|
||||||
|
[num_blocks, num_kv_heads, block_size // x, head_size, x],
|
||||||
|
dtype=value_cache.dtype,
|
||||||
|
device="meta",
|
||||||
|
)
|
||||||
|
new_key_cache = key_cache.view_as(k_cache_template)
|
||||||
|
new_value_cache = value_cache.view_as(v_cache_template)
|
||||||
|
QUANT = False
|
||||||
|
if kv_cache_dtype.startswith("fp8"):
|
||||||
|
QUANT = True
|
||||||
|
grid = (
|
||||||
|
num_tokens,
|
||||||
|
num_kv_heads,
|
||||||
|
)
|
||||||
|
reshape_and_cache_shuffle_kernel[grid](
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
new_key_cache,
|
||||||
|
new_value_cache,
|
||||||
|
slot_mapping,
|
||||||
|
k_scales,
|
||||||
|
v_scales,
|
||||||
|
x,
|
||||||
|
key.stride(0),
|
||||||
|
value.stride(0),
|
||||||
|
block_size,
|
||||||
|
head_size,
|
||||||
|
num_kv_heads,
|
||||||
|
BLOCK_SIZE=head_size,
|
||||||
|
QUANT=QUANT,
|
||||||
|
IS_FNUZ=current_platform.fp8_dtype() == torch.float8_e4m3fnuz,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@@ -253,6 +388,11 @@ class AiterFlashAttentionMetadata:
|
|||||||
common_prefix_len: int
|
common_prefix_len: int
|
||||||
total_tokens: int
|
total_tokens: int
|
||||||
|
|
||||||
|
# Only used for the fp8 shuffle-layout kv cache: kv_scale is allocated for each
|
||||||
|
# layer because per-token kv cache quantization may be integrated in the future.
|
||||||
|
k_scale: dict[str, torch.Tensor] | None
|
||||||
|
v_scale: dict[str, torch.Tensor] | None
|
||||||
|
|
||||||
|
|
||||||
class AiterFlashAttentionMetadataBuilder(
|
class AiterFlashAttentionMetadataBuilder(
|
||||||
AttentionMetadataBuilder[AiterFlashAttentionMetadata]
|
AttentionMetadataBuilder[AiterFlashAttentionMetadata]
|
||||||
@@ -303,6 +443,7 @@ class AiterFlashAttentionMetadataBuilder(
|
|||||||
dtype=self.model_config.dtype,
|
dtype=self.model_config.dtype,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
self.scale = torch.tensor([1.0], dtype=torch.float, device=self.device)
|
||||||
|
|
||||||
def build_for_cudagraph_capture(
|
def build_for_cudagraph_capture(
|
||||||
self, common_attn_metadata: CommonAttentionMetadata
|
self, common_attn_metadata: CommonAttentionMetadata
|
||||||
@@ -325,7 +466,27 @@ class AiterFlashAttentionMetadataBuilder(
|
|||||||
common_attn_metadata,
|
common_attn_metadata,
|
||||||
decode_threshold=self.reorder_batch_threshold,
|
decode_threshold=self.reorder_batch_threshold,
|
||||||
)
|
)
|
||||||
|
# Allocate scales for fp8 shuffle kv cache with shuffle_kv_cache enabled
|
||||||
|
if (
|
||||||
|
rocm_aiter_ops.is_shuffle_kv_cache_enabled()
|
||||||
|
and self.scale.numel() == 1
|
||||||
|
and self.vllm_config.cache_config.cache_dtype.startswith("fp8")
|
||||||
|
):
|
||||||
|
layers = get_layers_from_vllm_config(self.vllm_config, Attention)
|
||||||
|
first_layer_name = [k for k in layers][0]
|
||||||
|
kv_cache_shape = (
|
||||||
|
self.vllm_config.compilation_config.static_forward_context[
|
||||||
|
first_layer_name
|
||||||
|
]
|
||||||
|
.kv_cache[0]
|
||||||
|
.shape
|
||||||
|
)
|
||||||
|
num_blocks = kv_cache_shape[1]
|
||||||
|
self.scale = torch.ones(
|
||||||
|
[num_blocks, self.num_heads_kv, self.block_size],
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
(
|
(
|
||||||
num_decodes,
|
num_decodes,
|
||||||
num_extends,
|
num_extends,
|
||||||
@@ -507,6 +668,8 @@ class AiterFlashAttentionMetadataBuilder(
|
|||||||
use_cascade=use_cascade,
|
use_cascade=use_cascade,
|
||||||
common_prefix_len=common_prefix_len,
|
common_prefix_len=common_prefix_len,
|
||||||
total_tokens=self.total_tokens,
|
total_tokens=self.total_tokens,
|
||||||
|
k_scale=self.scale,
|
||||||
|
v_scale=self.scale,
|
||||||
)
|
)
|
||||||
return attn_metadata
|
return attn_metadata
|
||||||
|
|
||||||
@@ -548,7 +711,6 @@ class AiterFlashAttentionBackend(AttentionBackend):
|
|||||||
) -> tuple[int, ...]:
|
) -> tuple[int, ...]:
|
||||||
if block_size % 16 != 0:
|
if block_size % 16 != 0:
|
||||||
raise ValueError("Block size must be a multiple of 16.")
|
raise ValueError("Block size must be a multiple of 16.")
|
||||||
|
|
||||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||||
|
|
||||||
|
|
||||||
@@ -630,7 +792,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
|||||||
cu_seqlens_kv=swa_cu_seqlens,
|
cu_seqlens_kv=swa_cu_seqlens,
|
||||||
token_to_batch=swa_token_to_batch,
|
token_to_batch=swa_token_to_batch,
|
||||||
seq_starts=swa_seq_starts,
|
seq_starts=swa_seq_starts,
|
||||||
dequant=False,
|
dequant=self.kv_cache_dtype.startswith("fp8"),
|
||||||
kv_cache_layout="NHD",
|
kv_cache_layout="NHD",
|
||||||
total_tokens=swa_total_tokens,
|
total_tokens=swa_total_tokens,
|
||||||
)
|
)
|
||||||
@@ -668,8 +830,8 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
|||||||
min_seqlen_q: int,
|
min_seqlen_q: int,
|
||||||
block_table: torch.Tensor,
|
block_table: torch.Tensor,
|
||||||
slot_mapping: torch.Tensor,
|
slot_mapping: torch.Tensor,
|
||||||
k_scale: float,
|
k_scale: torch.Tensor,
|
||||||
v_scale: float,
|
v_scale: torch.Tensor,
|
||||||
):
|
):
|
||||||
if self.sliding_window[0] != -1:
|
if self.sliding_window[0] != -1:
|
||||||
self.extend_for_sliding_window(
|
self.extend_for_sliding_window(
|
||||||
@@ -725,8 +887,10 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
|||||||
cu_seqlens_kv=cu_seqlens_kv[chunk_idx],
|
cu_seqlens_kv=cu_seqlens_kv[chunk_idx],
|
||||||
token_to_batch=token_to_batch[chunk_idx],
|
token_to_batch=token_to_batch[chunk_idx],
|
||||||
seq_starts=chunk_starts[chunk_idx],
|
seq_starts=chunk_starts[chunk_idx],
|
||||||
dequant=False,
|
dequant=self.kv_cache_dtype.startswith("fp8"),
|
||||||
kv_cache_layout="NHD",
|
kv_cache_layout="SHUFFLE"
|
||||||
|
if rocm_aiter_ops.is_shuffle_kv_cache_enabled()
|
||||||
|
else "NHD",
|
||||||
total_tokens=total_token_per_batch[chunk_idx],
|
total_tokens=total_token_per_batch[chunk_idx],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -823,6 +987,9 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
|||||||
# key and value may be None in the case of cross attention. They are
|
# key and value may be None in the case of cross attention. They are
|
||||||
# calculated once based on the output from the encoder and then cached
|
# calculated once based on the output from the encoder and then cached
|
||||||
# in KV cache.
|
# in KV cache.
|
||||||
|
if self.kv_cache_dtype.startswith("fp8"):
|
||||||
|
key_cache = key_cache.view(current_platform.fp8_dtype())
|
||||||
|
value_cache = value_cache.view(current_platform.fp8_dtype())
|
||||||
if (
|
if (
|
||||||
self.kv_sharing_target_layer_name is None
|
self.kv_sharing_target_layer_name is None
|
||||||
and key is not None
|
and key is not None
|
||||||
@@ -835,21 +1002,31 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
|||||||
# key[:num_actual_tokens] and value[:num_actual_tokens] because
|
# key[:num_actual_tokens] and value[:num_actual_tokens] because
|
||||||
# the reshape_and_cache_flash op uses the slot_mapping's shape
|
# the reshape_and_cache_flash op uses the slot_mapping's shape
|
||||||
# to determine the number of actual tokens.
|
# to determine the number of actual tokens.
|
||||||
|
if rocm_aiter_ops.is_shuffle_kv_cache_enabled():
|
||||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
# We may calculate per token quant scale in
|
||||||
key,
|
# reshape_and_cache_shuffle_triton which might differ from
|
||||||
value,
|
# vllm's style when shuffle layout is used.
|
||||||
key_cache,
|
reshape_and_cache_shuffle_triton(
|
||||||
value_cache,
|
key,
|
||||||
attn_metadata.slot_mapping,
|
value,
|
||||||
self.kv_cache_dtype,
|
key_cache,
|
||||||
layer._k_scale,
|
value_cache,
|
||||||
layer._v_scale,
|
attn_metadata.slot_mapping,
|
||||||
)
|
self.kv_cache_dtype,
|
||||||
|
attn_metadata.k_scale,
|
||||||
if self.kv_cache_dtype.startswith("fp8"):
|
attn_metadata.v_scale,
|
||||||
key_cache = key_cache.view(current_platform.fp8_dtype())
|
)
|
||||||
value_cache = value_cache.view(current_platform.fp8_dtype())
|
else:
|
||||||
|
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
attn_metadata.slot_mapping,
|
||||||
|
self.kv_cache_dtype,
|
||||||
|
layer._k_scale,
|
||||||
|
layer._v_scale,
|
||||||
|
)
|
||||||
|
|
||||||
# decode:extend:prefill
|
# decode:extend:prefill
|
||||||
query = query[:num_actual_tokens]
|
query = query[:num_actual_tokens]
|
||||||
@@ -902,6 +1079,11 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
|||||||
extend_keys = key[extend_tokens_slice]
|
extend_keys = key[extend_tokens_slice]
|
||||||
extend_values = value[extend_tokens_slice]
|
extend_values = value[extend_tokens_slice]
|
||||||
extend_outputs = output[extend_tokens_slice]
|
extend_outputs = output[extend_tokens_slice]
|
||||||
|
k_scale = layer._k_scale
|
||||||
|
v_scale = layer._v_scale
|
||||||
|
if rocm_aiter_ops.is_shuffle_kv_cache_enabled():
|
||||||
|
k_scale = attn_metadata.k_scale
|
||||||
|
v_scale = attn_metadata.v_scale
|
||||||
self.extend_forward(
|
self.extend_forward(
|
||||||
attn_metadata=attn_metadata,
|
attn_metadata=attn_metadata,
|
||||||
query=extend_querys,
|
query=extend_querys,
|
||||||
@@ -920,14 +1102,17 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
|||||||
slot_mapping=attn_metadata.slot_mapping[
|
slot_mapping=attn_metadata.slot_mapping[
|
||||||
num_decodes : num_decodes + num_extends
|
num_decodes : num_decodes + num_extends
|
||||||
],
|
],
|
||||||
k_scale=layer._k_scale,
|
k_scale=k_scale,
|
||||||
v_scale=layer._v_scale,
|
v_scale=v_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
# calculate for decodes
|
# calculate for decodes
|
||||||
if num_decodes > 0:
|
if num_decodes > 0:
|
||||||
assert attn_metadata.decode_metadata is not None
|
assert attn_metadata.decode_metadata is not None
|
||||||
if self.sliding_window[0] != -1:
|
if self.sliding_window[0] != -1:
|
||||||
|
assert not rocm_aiter_ops.is_shuffle_kv_cache_enabled(), (
|
||||||
|
"Sliding window with shuffle layout is not supported yet."
|
||||||
|
)
|
||||||
from aiter.ops.triton.unified_attention import (
|
from aiter.ops.triton.unified_attention import (
|
||||||
unified_attention,
|
unified_attention,
|
||||||
)
|
)
|
||||||
@@ -957,41 +1142,71 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
assert attn_metadata.decode_metadata is not None
|
assert attn_metadata.decode_metadata is not None
|
||||||
_, num_heads, head_size = query.shape
|
|
||||||
nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8
|
|
||||||
num_seqs = attn_metadata.seq_lens.shape[0]
|
|
||||||
max_num_partitions = (
|
|
||||||
attn_metadata.max_seq_len + _PARTITION_SIZE_ROCM - 1
|
|
||||||
) // _PARTITION_SIZE_ROCM
|
|
||||||
|
|
||||||
workspace_buffer = torch.empty(
|
if rocm_aiter_ops.is_shuffle_kv_cache_enabled():
|
||||||
(num_seqs * num_heads * max_num_partitions * head_size)
|
num_blocks, block_size, num_kv_heads, head_size = key_cache.shape
|
||||||
* nbytes_per_qo_elem
|
x = 16 // key_cache.element_size()
|
||||||
+ 2 * (num_seqs * num_heads * max_num_partitions) * 4,
|
k_cache_template = torch.empty(
|
||||||
dtype=torch.uint8,
|
[num_blocks, num_kv_heads, head_size // x, block_size, x],
|
||||||
device=output.device,
|
dtype=key_cache.dtype,
|
||||||
)
|
device="meta",
|
||||||
|
)
|
||||||
|
v_cache_template = torch.empty(
|
||||||
|
[num_blocks, num_kv_heads, block_size // x, head_size, x],
|
||||||
|
dtype=value_cache.dtype,
|
||||||
|
device="meta",
|
||||||
|
)
|
||||||
|
new_key_cache = key_cache.view_as(k_cache_template)
|
||||||
|
new_value_cache = value_cache.view_as(v_cache_template)
|
||||||
|
aiter.pa_fwd_asm(
|
||||||
|
Q=query[:num_decode_tokens],
|
||||||
|
K=new_key_cache,
|
||||||
|
V=new_value_cache,
|
||||||
|
block_tables=attn_metadata.block_table[:num_decodes],
|
||||||
|
context_lens=attn_metadata.seq_lens[:num_decodes],
|
||||||
|
block_tables_stride0=attn_metadata.block_table[
|
||||||
|
:num_decodes
|
||||||
|
].stride(0),
|
||||||
|
K_QScale=attn_metadata.k_scale,
|
||||||
|
V_QScale=attn_metadata.v_scale,
|
||||||
|
out_=output[:num_decode_tokens],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
_, num_heads, head_size = query.shape
|
||||||
|
nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8
|
||||||
|
num_seqs = attn_metadata.seq_lens.shape[0]
|
||||||
|
max_num_partitions = (
|
||||||
|
attn_metadata.max_seq_len + _PARTITION_SIZE_ROCM - 1
|
||||||
|
) // _PARTITION_SIZE_ROCM
|
||||||
|
|
||||||
torch.ops.aiter.paged_attention_v1(
|
workspace_buffer = torch.empty(
|
||||||
output[:num_decode_tokens],
|
(num_seqs * num_heads * max_num_partitions * head_size)
|
||||||
workspace_buffer,
|
* nbytes_per_qo_elem
|
||||||
query[:num_decode_tokens],
|
+ 2 * (num_seqs * num_heads * max_num_partitions) * 4,
|
||||||
key_cache,
|
dtype=torch.uint8,
|
||||||
value_cache,
|
device=output.device,
|
||||||
self.scale,
|
)
|
||||||
attn_metadata.block_table[:num_decodes],
|
|
||||||
attn_metadata.query_start_loc[:num_decodes],
|
torch.ops.aiter.paged_attention_v1(
|
||||||
attn_metadata.seq_lens[:num_decodes],
|
output[:num_decode_tokens],
|
||||||
attn_metadata.max_seq_len,
|
workspace_buffer,
|
||||||
self.alibi_slopes,
|
query[:num_decode_tokens],
|
||||||
self.kv_cache_dtype,
|
key_cache,
|
||||||
"NHD",
|
value_cache,
|
||||||
self.logits_soft_cap,
|
self.scale,
|
||||||
layer._k_scale,
|
attn_metadata.block_table[:num_decodes],
|
||||||
layer._v_scale,
|
attn_metadata.query_start_loc[:num_decodes],
|
||||||
None,
|
attn_metadata.seq_lens[:num_decodes],
|
||||||
_PARTITION_SIZE_ROCM,
|
attn_metadata.max_seq_len,
|
||||||
)
|
self.alibi_slopes,
|
||||||
|
self.kv_cache_dtype,
|
||||||
|
"NHD",
|
||||||
|
self.logits_soft_cap,
|
||||||
|
layer._k_scale,
|
||||||
|
layer._v_scale,
|
||||||
|
None,
|
||||||
|
_PARTITION_SIZE_ROCM,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Cascade attention is not implemented for ROCM AITER"
|
"Cascade attention is not implemented for ROCM AITER"
|
||||||
|
|||||||
Reference in New Issue
Block a user