[ROCm][Kernel][V1] Enable AMD Radeon GPU Custom Paged Attention on v1 (#17004)
Signed-off-by: Hosang Yoon <hosang.yoon@amd.com>
This commit is contained in:
@@ -102,26 +102,42 @@ def on_mi250_mi300() -> bool:
|
||||
|
||||
|
||||
@cache
|
||||
def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
|
||||
block_size: int, gqa_ratio: int,
|
||||
max_seq_len: int,
|
||||
sliding_window: int) -> bool:
|
||||
def use_rocm_custom_paged_attention(
|
||||
qtype: torch.dtype,
|
||||
head_size: int,
|
||||
block_size: int,
|
||||
gqa_ratio: int,
|
||||
max_seq_len: int,
|
||||
sliding_window: int,
|
||||
kv_cache_dtype: str,
|
||||
alibi_slopes: Optional[torch.Tensor] = None) -> bool:
|
||||
|
||||
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
|
||||
ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
|
||||
ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
|
||||
|
||||
# rocm custom page attention not support on gfx1*
|
||||
# custom paged attn always supported on V0. On V1, requires sliding window
|
||||
# disabled due to observed numerical discrepancy.
|
||||
return (ON_GFX9 and (not envs.VLLM_USE_V1 or sliding_window == 0
|
||||
or sliding_window == (-1, -1))
|
||||
and (qtype == torch.half or qtype == torch.bfloat16)
|
||||
and (head_size == 64 or head_size == 128)
|
||||
and (block_size == 16 or block_size == 32)
|
||||
and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768
|
||||
and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
|
||||
and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
|
||||
and envs.VLLM_ROCM_USE_AITER))
|
||||
if ON_GFX9:
|
||||
return ((not envs.VLLM_USE_V1 or sliding_window == 0
|
||||
or sliding_window == (-1, -1))
|
||||
and (qtype == torch.half or qtype == torch.bfloat16)
|
||||
and (head_size == 64 or head_size == 128)
|
||||
and (block_size == 16 or block_size == 32)
|
||||
and (gqa_ratio >= 1 and gqa_ratio <= 16)
|
||||
and max_seq_len <= 32768 and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
|
||||
and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
|
||||
and envs.VLLM_ROCM_USE_AITER))
|
||||
|
||||
else:
|
||||
return (ON_GFX11_GFX12 and (not envs.VLLM_USE_V1 or sliding_window == 0
|
||||
or sliding_window == (-1, -1))
|
||||
and (qtype == torch.half or qtype == torch.bfloat16)
|
||||
and head_size == 128 and block_size == 16
|
||||
and (gqa_ratio >= 3 and gqa_ratio <= 16)
|
||||
and max_seq_len <= 32768 and alibi_slopes is None
|
||||
and kv_cache_dtype == "auto"
|
||||
and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
|
||||
|
||||
|
||||
class RocmPlatform(Platform):
|
||||
@@ -362,3 +378,7 @@ class RocmPlatform(Platform):
|
||||
def get_cu_count(cls, device_id: int = 0) -> int:
|
||||
return torch.cuda.get_device_properties(
|
||||
device_id).multi_processor_count
|
||||
|
||||
@classmethod
|
||||
def is_navi(cls) -> bool:
|
||||
return 'gfx1' in torch.cuda.get_device_properties(0).gcnArchName
|
||||
|
||||
Reference in New Issue
Block a user