[ROCm] Faster Custom Paged Attention kernels (#12348)
This commit is contained in:
@@ -25,6 +25,7 @@ MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
|
||||
# Reduce NUM_BLOCKS when it happens.
|
||||
NUM_BLOCKS = 4321 # Arbitrary values for testing
|
||||
PARTITION_SIZE = 512
|
||||
PARTITION_SIZE_ROCM = 256
|
||||
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
|
||||
DTYPES = [
|
||||
torch.half, torch.bfloat16, torch.float
|
||||
@@ -146,6 +147,8 @@ def test_paged_attention(
|
||||
or (version == "rocm" and head_size not in (64, 128))):
|
||||
pytest.skip()
|
||||
|
||||
global PARTITION_SIZE
|
||||
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
@@ -214,6 +217,9 @@ def test_paged_attention(
|
||||
and block_size == BLOCK_SIZES[0]))
|
||||
|
||||
elif version in ("v2", "rocm"):
|
||||
if current_platform.is_rocm() and version == "rocm":
|
||||
PARTITION_SIZE = PARTITION_SIZE_ROCM
|
||||
|
||||
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
|
||||
assert PARTITION_SIZE % block_size == 0
|
||||
num_seqs, num_heads, head_size = output.shape
|
||||
@@ -432,4 +438,4 @@ def test_multi_query_kv_attention(
|
||||
)
|
||||
atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
|
||||
rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
|
||||
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
|
||||
Reference in New Issue
Block a user