Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -42,9 +42,7 @@ BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8"]
SEEDS = [0]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
def ref_masked_attention(
@@ -110,8 +108,7 @@ def ref_single_query_cached_kv_attention(
# Create the ALiBi bias used in the paged attention kernel.
position_ids = torch.arange(seq_len).int()
alibi_bias = (position_ids - seq_len + 1).float()
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
1, 1, -1)
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(1, 1, -1)
out = ref_masked_attention(q, keys, values, scale, alibi_bias)
out = out.view(num_query_heads, head_size)
@@ -119,8 +116,8 @@ def ref_single_query_cached_kv_attention(
@pytest.mark.parametrize(
"version",
["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"])
"version", ["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"]
)
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@@ -143,13 +140,18 @@ def test_paged_attention(
seed: int,
device: str,
) -> None:
if ((kv_cache_dtype == "fp8" and head_size % 16)
or (version == "rocm" and head_size not in (64, 128))):
if (kv_cache_dtype == "fp8" and head_size % 16) or (
version == "rocm" and head_size not in (64, 128)
):
pytest.skip()
if (version == "rocm" and current_platform.is_navi()
and (kv_cache_dtype == "fp8" or head_size != 128
or block_size != 16 or use_alibi)):
if (
version == "rocm"
and current_platform.is_navi()
and (
kv_cache_dtype == "fp8" or head_size != 128 or block_size != 16 or use_alibi
)
):
pytest.skip()
global PARTITION_SIZE
@@ -177,18 +179,24 @@ def test_paged_attention(
block_tables_lst: list[list[int]] = []
for _ in range(num_seqs):
block_table = [
random.randint(0, NUM_BLOCKS - 1)
for _ in range(max_num_blocks_per_seq)
random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)
]
block_tables_lst.append(block_table)
block_tables = torch.tensor(block_tables_lst, dtype=torch.int)
# Create the KV caches.
key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
num_kv_heads, head_size,
kv_cache_dtype, dtype, seed,
device)
key_caches, value_caches = kv_cache_factory(
NUM_BLOCKS,
block_size,
1,
num_kv_heads,
head_size,
kv_cache_dtype,
dtype,
seed,
device,
)
key_cache, value_cache = key_caches[0], value_caches[0]
# Using default kv_scale
@@ -214,18 +222,37 @@ def test_paged_attention(
v_scale,
)
opcheck(torch.ops._C.paged_attention_v1,
(output, query, key_cache, value_cache, num_kv_heads, scale,
block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0]))
opcheck(
torch.ops._C.paged_attention_v1,
(
output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
0,
0,
0,
64,
0,
),
cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
)
elif version in ("v2", "rocm"):
if current_platform.is_rocm() and version == "rocm":
PARTITION_SIZE = PARTITION_SIZE_ROCM
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
assert PARTITION_SIZE % block_size == 0
num_seqs, num_heads, head_size = output.shape
tmp_output = torch.empty(
@@ -258,13 +285,34 @@ def test_paged_attention(
v_scale,
)
opcheck(torch.ops._C.paged_attention_v2,
(output, exp_sums, max_logits, tmp_output, query,
key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0]))
opcheck(
torch.ops._C.paged_attention_v2,
(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
0,
0,
0,
64,
0,
),
cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
)
else:
ops.paged_attention_rocm(
@@ -288,13 +336,30 @@ def test_paged_attention(
v_scale,
)
opcheck(torch.ops._rocm_C.paged_attention,
(output, exp_sums, max_logits, tmp_output, query,
key_cache, value_cache, num_kv_heads, scale, block_tables,
seq_lens, None, block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale),
cond=(head_size == HEAD_SIZES[0]
and block_size == BLOCK_SIZES[0]))
opcheck(
torch.ops._rocm_C.paged_attention,
(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
seq_lens,
None,
block_size,
max_seq_len,
alibi_slopes,
kv_cache_dtype,
k_scale,
v_scale,
),
cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
)
else:
raise AssertionError(f"Unknown version: {version}")
@@ -303,18 +368,17 @@ def test_paged_attention(
if kv_cache_dtype == "fp8":
# Convert cache data back to dtype.
x = 16 // torch.tensor([], dtype=dtype).element_size()
key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
block_size, x)
dequantized_key_cache = torch.empty(size=key_cache_shape,
dtype=dtype,
device=device)
key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
dequantized_key_cache = torch.empty(
size=key_cache_shape, dtype=dtype, device=device
)
ops.convert_fp8(dequantized_key_cache, key_cache)
key_cache = dequantized_key_cache
value_cache_shape = value_cache.shape
dequantized_value_cache = torch.empty(size=value_cache_shape,
dtype=dtype,
device=device)
dequantized_value_cache = torch.empty(
size=value_cache_shape, dtype=dtype, device=device
)
ops.convert_fp8(dequantized_value_cache, value_cache)
value_cache = dequantized_value_cache
@@ -367,8 +431,9 @@ def ref_multi_query_kv_attention(
if alibi_bias:
attn_mask = alibi_bias[i]
else:
attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
diagonal=1)
attn_mask = torch.triu(
torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1
)
attn_mask = attn_mask * torch.finfo(dtype).min
attn_mask = attn_mask.to(dtype=dtype)
@@ -390,8 +455,9 @@ def ref_multi_query_kv_attention(
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
@pytest.mark.skipif(
current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm."
)
@torch.inference_mode()
def test_multi_query_kv_attention(
num_seqs: int,
@@ -413,13 +479,11 @@ def test_multi_query_kv_attention(
scale = float(1.0 / (head_size**0.5))
num_query_heads, num_kv_heads = num_heads
qkv = torch.empty(num_tokens,
num_query_heads + 2 * num_kv_heads,
head_size,
dtype=dtype)
qkv = torch.empty(
num_tokens, num_query_heads + 2 * num_kv_heads, head_size, dtype=dtype
)
qkv.uniform_(-scale, scale)
query, key, value = qkv.split(
[num_query_heads, num_kv_heads, num_kv_heads], dim=1)
query, key, value = qkv.split([num_query_heads, num_kv_heads, num_kv_heads], dim=1)
num_queries_per_kv = num_query_heads // num_kv_heads
if num_queries_per_kv > 1:
@@ -429,8 +493,7 @@ def test_multi_query_kv_attention(
alibi_bias = None
if use_alibi:
alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype,
seq_lens)
attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
output = torch.empty_like(query)
start = 0
# Dynamic sequence length not supported with custom attn_bias.
@@ -442,7 +505,8 @@ def test_multi_query_kv_attention(
value[None, start:end],
attn_bias=attn_bias[i],
p=0.0,
scale=scale)
scale=scale,
)
output[start:end].copy_(out.view_as(query[start:end]))
start += seq_len
# xformers.AttentionBias to Tensor for use in reference impl.
@@ -485,8 +549,9 @@ def test_multi_query_kv_attention(
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
@pytest.mark.skipif(
current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm."
)
@torch.inference_mode()
def test_multi_query_kv_attention_with_alibi(
num_seqs: int,