Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -42,9 +42,7 @@ BLOCK_SIZES = [16, 32]
|
||||
USE_ALIBI = [False, True]
|
||||
KV_CACHE_DTYPE = ["auto", "fp8"]
|
||||
SEEDS = [0]
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
|
||||
|
||||
def ref_masked_attention(
|
||||
@@ -110,8 +108,7 @@ def ref_single_query_cached_kv_attention(
|
||||
# Create the ALiBi bias used in the paged attention kernel.
|
||||
position_ids = torch.arange(seq_len).int()
|
||||
alibi_bias = (position_ids - seq_len + 1).float()
|
||||
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
|
||||
1, 1, -1)
|
||||
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(1, 1, -1)
|
||||
|
||||
out = ref_masked_attention(q, keys, values, scale, alibi_bias)
|
||||
out = out.view(num_query_heads, head_size)
|
||||
@@ -119,8 +116,8 @@ def ref_single_query_cached_kv_attention(
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"version",
|
||||
["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"])
|
||||
"version", ["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"]
|
||||
)
|
||||
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@@ -143,13 +140,18 @@ def test_paged_attention(
|
||||
seed: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
if ((kv_cache_dtype == "fp8" and head_size % 16)
|
||||
or (version == "rocm" and head_size not in (64, 128))):
|
||||
if (kv_cache_dtype == "fp8" and head_size % 16) or (
|
||||
version == "rocm" and head_size not in (64, 128)
|
||||
):
|
||||
pytest.skip()
|
||||
|
||||
if (version == "rocm" and current_platform.is_navi()
|
||||
and (kv_cache_dtype == "fp8" or head_size != 128
|
||||
or block_size != 16 or use_alibi)):
|
||||
if (
|
||||
version == "rocm"
|
||||
and current_platform.is_navi()
|
||||
and (
|
||||
kv_cache_dtype == "fp8" or head_size != 128 or block_size != 16 or use_alibi
|
||||
)
|
||||
):
|
||||
pytest.skip()
|
||||
|
||||
global PARTITION_SIZE
|
||||
@@ -177,18 +179,24 @@ def test_paged_attention(
|
||||
block_tables_lst: list[list[int]] = []
|
||||
for _ in range(num_seqs):
|
||||
block_table = [
|
||||
random.randint(0, NUM_BLOCKS - 1)
|
||||
for _ in range(max_num_blocks_per_seq)
|
||||
random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)
|
||||
]
|
||||
block_tables_lst.append(block_table)
|
||||
|
||||
block_tables = torch.tensor(block_tables_lst, dtype=torch.int)
|
||||
|
||||
# Create the KV caches.
|
||||
key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
|
||||
num_kv_heads, head_size,
|
||||
kv_cache_dtype, dtype, seed,
|
||||
device)
|
||||
key_caches, value_caches = kv_cache_factory(
|
||||
NUM_BLOCKS,
|
||||
block_size,
|
||||
1,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
kv_cache_dtype,
|
||||
dtype,
|
||||
seed,
|
||||
device,
|
||||
)
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
|
||||
# Using default kv_scale
|
||||
@@ -214,18 +222,37 @@ def test_paged_attention(
|
||||
v_scale,
|
||||
)
|
||||
|
||||
opcheck(torch.ops._C.paged_attention_v1,
|
||||
(output, query, key_cache, value_cache, num_kv_heads, scale,
|
||||
block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
|
||||
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
|
||||
cond=(head_size == HEAD_SIZES[0]
|
||||
and block_size == BLOCK_SIZES[0]))
|
||||
opcheck(
|
||||
torch.ops._C.paged_attention_v1,
|
||||
(
|
||||
output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
64,
|
||||
0,
|
||||
),
|
||||
cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
|
||||
)
|
||||
|
||||
elif version in ("v2", "rocm"):
|
||||
if current_platform.is_rocm() and version == "rocm":
|
||||
PARTITION_SIZE = PARTITION_SIZE_ROCM
|
||||
|
||||
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
|
||||
num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
|
||||
assert PARTITION_SIZE % block_size == 0
|
||||
num_seqs, num_heads, head_size = output.shape
|
||||
tmp_output = torch.empty(
|
||||
@@ -258,13 +285,34 @@ def test_paged_attention(
|
||||
v_scale,
|
||||
)
|
||||
|
||||
opcheck(torch.ops._C.paged_attention_v2,
|
||||
(output, exp_sums, max_logits, tmp_output, query,
|
||||
key_cache, value_cache, num_kv_heads, scale, block_tables,
|
||||
seq_lens, block_size, max_seq_len, alibi_slopes,
|
||||
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
|
||||
cond=(head_size == HEAD_SIZES[0]
|
||||
and block_size == BLOCK_SIZES[0]))
|
||||
opcheck(
|
||||
torch.ops._C.paged_attention_v2,
|
||||
(
|
||||
output,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
64,
|
||||
0,
|
||||
),
|
||||
cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
|
||||
)
|
||||
|
||||
else:
|
||||
ops.paged_attention_rocm(
|
||||
@@ -288,13 +336,30 @@ def test_paged_attention(
|
||||
v_scale,
|
||||
)
|
||||
|
||||
opcheck(torch.ops._rocm_C.paged_attention,
|
||||
(output, exp_sums, max_logits, tmp_output, query,
|
||||
key_cache, value_cache, num_kv_heads, scale, block_tables,
|
||||
seq_lens, None, block_size, max_seq_len, alibi_slopes,
|
||||
kv_cache_dtype, k_scale, v_scale),
|
||||
cond=(head_size == HEAD_SIZES[0]
|
||||
and block_size == BLOCK_SIZES[0]))
|
||||
opcheck(
|
||||
torch.ops._rocm_C.paged_attention,
|
||||
(
|
||||
output,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
None,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
),
|
||||
cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
|
||||
)
|
||||
|
||||
else:
|
||||
raise AssertionError(f"Unknown version: {version}")
|
||||
@@ -303,18 +368,17 @@ def test_paged_attention(
|
||||
if kv_cache_dtype == "fp8":
|
||||
# Convert cache data back to dtype.
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
|
||||
block_size, x)
|
||||
dequantized_key_cache = torch.empty(size=key_cache_shape,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
|
||||
dequantized_key_cache = torch.empty(
|
||||
size=key_cache_shape, dtype=dtype, device=device
|
||||
)
|
||||
ops.convert_fp8(dequantized_key_cache, key_cache)
|
||||
key_cache = dequantized_key_cache
|
||||
|
||||
value_cache_shape = value_cache.shape
|
||||
dequantized_value_cache = torch.empty(size=value_cache_shape,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
dequantized_value_cache = torch.empty(
|
||||
size=value_cache_shape, dtype=dtype, device=device
|
||||
)
|
||||
ops.convert_fp8(dequantized_value_cache, value_cache)
|
||||
value_cache = dequantized_value_cache
|
||||
|
||||
@@ -367,8 +431,9 @@ def ref_multi_query_kv_attention(
|
||||
if alibi_bias:
|
||||
attn_mask = alibi_bias[i]
|
||||
else:
|
||||
attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
|
||||
diagonal=1)
|
||||
attn_mask = torch.triu(
|
||||
torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1
|
||||
)
|
||||
attn_mask = attn_mask * torch.finfo(dtype).min
|
||||
attn_mask = attn_mask.to(dtype=dtype)
|
||||
|
||||
@@ -390,8 +455,9 @@ def ref_multi_query_kv_attention(
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.skipif(current_platform.is_rocm(),
|
||||
reason="Xformers backend is not supported on ROCm.")
|
||||
@pytest.mark.skipif(
|
||||
current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm."
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_multi_query_kv_attention(
|
||||
num_seqs: int,
|
||||
@@ -413,13 +479,11 @@ def test_multi_query_kv_attention(
|
||||
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
num_query_heads, num_kv_heads = num_heads
|
||||
qkv = torch.empty(num_tokens,
|
||||
num_query_heads + 2 * num_kv_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
qkv = torch.empty(
|
||||
num_tokens, num_query_heads + 2 * num_kv_heads, head_size, dtype=dtype
|
||||
)
|
||||
qkv.uniform_(-scale, scale)
|
||||
query, key, value = qkv.split(
|
||||
[num_query_heads, num_kv_heads, num_kv_heads], dim=1)
|
||||
query, key, value = qkv.split([num_query_heads, num_kv_heads, num_kv_heads], dim=1)
|
||||
|
||||
num_queries_per_kv = num_query_heads // num_kv_heads
|
||||
if num_queries_per_kv > 1:
|
||||
@@ -429,8 +493,7 @@ def test_multi_query_kv_attention(
|
||||
alibi_bias = None
|
||||
if use_alibi:
|
||||
alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
|
||||
attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype,
|
||||
seq_lens)
|
||||
attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
|
||||
output = torch.empty_like(query)
|
||||
start = 0
|
||||
# Dynamic sequence length not supported with custom attn_bias.
|
||||
@@ -442,7 +505,8 @@ def test_multi_query_kv_attention(
|
||||
value[None, start:end],
|
||||
attn_bias=attn_bias[i],
|
||||
p=0.0,
|
||||
scale=scale)
|
||||
scale=scale,
|
||||
)
|
||||
output[start:end].copy_(out.view_as(query[start:end]))
|
||||
start += seq_len
|
||||
# xformers.AttentionBias to Tensor for use in reference impl.
|
||||
@@ -485,8 +549,9 @@ def test_multi_query_kv_attention(
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.skipif(current_platform.is_rocm(),
|
||||
reason="Xformers backend is not supported on ROCm.")
|
||||
@pytest.mark.skipif(
|
||||
current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm."
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_multi_query_kv_attention_with_alibi(
|
||||
num_seqs: int,
|
||||
|
||||
Reference in New Issue
Block a user