Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -14,9 +14,11 @@ HEAD_SIZES = [128, 256]
BLOCK_SIZES = [16]
DTYPES = [torch.bfloat16]
QDTYPES = [None, torch.float8_e4m3fn] if not current_platform.is_rocm() else [
None, torch.float8_e4m3fnuz
]
QDTYPES = (
[None, torch.float8_e4m3fn]
if not current_platform.is_rocm()
else [None, torch.float8_e4m3fnuz]
)
# one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check
NUM_BLOCKS = [32768, 2048]
@@ -42,7 +44,7 @@ def ref_paged_attn(
for i in range(num_seqs):
query_len = query_lens[i]
kv_len = kv_lens[i]
q = query[start_idx:start_idx + query_len]
q = query[start_idx : start_idx + query_len]
q *= scale
num_kv_blocks = (kv_len + block_size - 1) // block_size
@@ -60,10 +62,13 @@ def ref_paged_attn(
empty_mask = torch.ones(query_len, kv_len)
mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
if sliding_window is not None:
sliding_window_mask = torch.triu(empty_mask,
diagonal=kv_len -
(query_len + sliding_window) +
1).bool().logical_not()
sliding_window_mask = (
torch.triu(
empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1
)
.bool()
.logical_not()
)
mask |= sliding_window_mask
if soft_cap is not None and soft_cap > 0:
attn = soft_cap * torch.tanh(attn / soft_cap)
@@ -77,9 +82,9 @@ def ref_paged_attn(
return torch.cat(outputs, dim=0)
@pytest.mark.parametrize("seq_lens",
[[(1, 1328), (5, 18),
(129, 463)], [(1, 523), (1, 37), (1, 2011)]])
@pytest.mark.parametrize(
"seq_lens", [[(1, 1328), (5, 18), (129, 463)], [(1, 523), (1, 37), (1, 2011)]]
)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@@ -111,30 +116,23 @@ def test_triton_unified_attn(
assert num_query_heads % num_kv_heads == 0
max_query_len = max(query_lens)
max_kv_len = max(kv_lens)
window_size = ((sliding_window - 1, 0) if sliding_window is not None else
(-1, -1))
window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1)
scale = head_size**-0.5
query = torch.randn(sum(query_lens),
num_query_heads,
head_size,
dtype=dtype)
key_cache = torch.randn(num_blocks,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype)
key_cache = torch.randn(
num_blocks, block_size, num_kv_heads, head_size, dtype=dtype
)
value_cache = torch.randn_like(key_cache)
cu_query_lens = torch.tensor([0] + query_lens,
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)
cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum(
dim=0, dtype=torch.int32
)
kv_lens = torch.tensor(kv_lens, dtype=torch.int32)
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
num_blocks,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)
block_tables = torch.randint(
0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
)
output = torch.empty_like(query)
@@ -188,5 +186,7 @@ def test_triton_unified_attn(
atol, rtol = 1.5e-2, 1e-2
if q_dtype is not None:
atol, rtol = 1.5e-1, 1.5e-1
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
f"{torch.max(torch.abs(output - ref_output))}"
(
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol),
f"{torch.max(torch.abs(output - ref_output))}",
)