[ROCm] Faster Custom Paged Attention kernels (#12348)

This commit is contained in:
TJian
2025-03-04 01:24:45 +08:00
committed by GitHub
parent 98175b2816
commit 848a6438ae
6 changed files with 1145 additions and 447 deletions

View File

@@ -25,6 +25,7 @@ MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
# Reduce NUM_BLOCKS when it happens.
NUM_BLOCKS = 4321 # Arbitrary values for testing
PARTITION_SIZE = 512
PARTITION_SIZE_ROCM = 256
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
DTYPES = [
torch.half, torch.bfloat16, torch.float
@@ -146,6 +147,8 @@ def test_paged_attention(
or (version == "rocm" and head_size not in (64, 128))):
pytest.skip()
global PARTITION_SIZE
current_platform.seed_everything(seed)
torch.set_default_device(device)
scale = float(1.0 / (head_size**0.5))
@@ -214,6 +217,9 @@ def test_paged_attention(
and block_size == BLOCK_SIZES[0]))
elif version in ("v2", "rocm"):
if current_platform.is_rocm() and version == "rocm":
PARTITION_SIZE = PARTITION_SIZE_ROCM
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
assert PARTITION_SIZE % block_size == 0
num_seqs, num_heads, head_size = output.shape
@@ -432,4 +438,4 @@ def test_multi_query_kv_attention(
)
atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)