Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -7,10 +7,12 @@ import pytest
import torch
from vllm.platforms import current_platform
from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
flash_attn_varlen_func,
flash_attn_with_kvcache,
is_fa_version_supported)
from vllm.vllm_flash_attn import (
fa_version_unsupported_reason,
flash_attn_varlen_func,
flash_attn_with_kvcache,
is_fa_version_supported,
)
NUM_HEADS = [(4, 4), (8, 2)]
HEAD_SIZES = [128, 256]
@@ -44,7 +46,7 @@ def ref_paged_attn(
for i in range(num_seqs):
query_len = query_lens[i]
kv_len = kv_lens[i]
q = query[start_idx:start_idx + query_len]
q = query[start_idx : start_idx + query_len]
q *= scale
num_kv_blocks = (kv_len + block_size - 1) // block_size
@@ -62,10 +64,13 @@ def ref_paged_attn(
empty_mask = torch.ones(query_len, kv_len)
mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
if sliding_window is not None:
sliding_window_mask = torch.triu(empty_mask,
diagonal=kv_len -
(query_len + sliding_window) +
1).bool().logical_not()
sliding_window_mask = (
torch.triu(
empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1
)
.bool()
.logical_not()
)
mask |= sliding_window_mask
if soft_cap is not None:
attn = soft_cap * torch.tanh(attn / soft_cap)
@@ -106,11 +111,15 @@ def test_flash_attn_with_paged_kv(
) -> None:
torch.set_default_device("cuda")
if not is_fa_version_supported(fa_version):
pytest.skip(f"Flash attention version {fa_version} not supported due "
f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
pytest.skip(
f"Flash attention version {fa_version} not supported due "
f'to: "{fa_version_unsupported_reason(fa_version)}"'
)
if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
pytest.skip("Flash attention with quantized inputs is only "
"supported on version 3 with bfloat16 base type")
pytest.skip(
"Flash attention with quantized inputs is only "
"supported on version 3 with bfloat16 base type"
)
current_platform.seed_everything(0)
num_seqs = len(kv_lens)
@@ -119,23 +128,19 @@ def test_flash_attn_with_paged_kv(
assert num_query_heads % num_kv_heads == 0
max_kv_len = max(kv_lens)
scale = head_size**-0.5
window_size = ((sliding_window - 1, 0) if sliding_window is not None else
(-1, -1))
window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1)
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
key_cache = torch.randn(num_blocks,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
key_cache = torch.randn(
num_blocks, block_size, num_kv_heads, head_size, dtype=dtype
)
value_cache = torch.randn_like(key_cache)
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
num_blocks,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)
block_tables = torch.randint(
0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
)
q = query.unsqueeze(1)
out = torch.empty_like(q) if use_out else None
@@ -180,23 +185,27 @@ def test_flash_attn_with_paged_kv(
if q_dtype is not None:
atol, rtol = 1.5e-1, 1.5e-1
ref_output = ref_paged_attn(query=query,
key_cache=key_cache,
value_cache=value_cache,
query_lens=[1] * num_seqs,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap,
sliding_window=sliding_window)
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
f"{torch.max(torch.abs(output - ref_output))}"
ref_output = ref_paged_attn(
query=query,
key_cache=key_cache,
value_cache=value_cache,
query_lens=[1] * num_seqs,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap,
sliding_window=sliding_window,
)
(
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol),
f"{torch.max(torch.abs(output - ref_output))}",
)
@pytest.mark.parametrize("use_out", [True, False])
@pytest.mark.parametrize("seq_lens",
[[(1, 1328), (5, 18),
(129, 463)], [(1, 523), (1, 37), (1, 2011)]])
@pytest.mark.parametrize(
"seq_lens", [[(1, 1328), (5, 18), (129, 463)], [(1, 523), (1, 37), (1, 2011)]]
)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@@ -222,11 +231,15 @@ def test_varlen_with_paged_kv(
) -> None:
torch.set_default_device("cuda")
if not is_fa_version_supported(fa_version):
pytest.skip(f"Flash attention version {fa_version} not supported due "
f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
pytest.skip(
f"Flash attention version {fa_version} not supported due "
f'to: "{fa_version_unsupported_reason(fa_version)}"'
)
if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
pytest.skip("Flash attention with quantized inputs is only "
"supported on version 3 with bfloat16 base type")
pytest.skip(
"Flash attention with quantized inputs is only "
"supported on version 3 with bfloat16 base type"
)
current_platform.seed_everything(0)
num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens]
@@ -236,30 +249,23 @@ def test_varlen_with_paged_kv(
assert num_query_heads % num_kv_heads == 0
max_query_len = max(query_lens)
max_kv_len = max(kv_lens)
window_size = ((sliding_window - 1, 0) if sliding_window is not None else
(-1, -1))
window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1)
scale = head_size**-0.5
query = torch.randn(sum(query_lens),
num_query_heads,
head_size,
dtype=dtype)
key_cache = torch.randn(num_blocks,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype)
key_cache = torch.randn(
num_blocks, block_size, num_kv_heads, head_size, dtype=dtype
)
value_cache = torch.randn_like(key_cache)
cu_query_lens = torch.tensor([0] + query_lens,
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)
cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum(
dim=0, dtype=torch.int32
)
kv_lens = torch.tensor(kv_lens, dtype=torch.int32)
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
num_blocks,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)
block_tables = torch.randint(
0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
)
out = torch.empty_like(query) if use_out else None
@@ -315,5 +321,7 @@ def test_varlen_with_paged_kv(
atol, rtol = 1.5e-2, 1e-2
if q_dtype is not None:
atol, rtol = 1.5e-1, 1.5e-1
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
f"{torch.max(torch.abs(output - ref_output))}"
(
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol),
f"{torch.max(torch.abs(output - ref_output))}",
)