Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author:       Harry Mellor
Date:         2025-10-05 15:06:22 +01:00
Committed by: GitHub
Parent:       17edd8a807
Commit:       d6953beb91

1508 changed files with 115244 additions and 94146 deletions
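
The diff below is almost entirely mechanical: yapf's hanging-indent continuation lines are replaced by the Black-compatible output of `ruff format` (a call or import either fits on one line or is exploded to one item per line with a trailing comma), and import sorting moves from isort to ruff's isort-compatible lint rules. A minimal sketch of the pattern, using a hypothetical make_batch helper that is not part of vLLM:

    # Hypothetical helper, used only to illustrate the formatting change.
    def make_batch(seq_lens: list[int], query_lens: list[int], block_size: int) -> dict:
        return {"seq_lens": seq_lens, "query_lens": query_lens, "block_size": block_size}

    # Before (yapf hanging indent): continuation arguments aligned under the
    # opening parenthesis.
    batch = make_batch([32, 40],
                       [1, 1],
                       block_size=16)

    # After (ruff format, Black-compatible): one argument per line with a
    # trailing comma once the call no longer fits on a single line.
    batch = make_batch(
        [32, 40],
        [1, 1],
        block_size=16,
    )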


@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for v1 attention backends without GPUModelRunner dependency."""
from functools import partial
from typing import Optional, Union
@@ -8,21 +9,30 @@ import pytest
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
create_standard_kv_cache_spec,
create_vllm_config,
get_attention_backend)
from tests.v1.attention.utils import (
BatchSpec,
create_common_attn_metadata,
create_standard_kv_cache_spec,
create_vllm_config,
get_attention_backend,
)
from vllm.attention.backends.registry import _Backend
from vllm.config import ModelConfig
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
set_kv_cache_layout)
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
set_kv_cache_layout,
)
from vllm.v1.kv_cache_interface import FullAttentionSpec
BACKENDS_TO_TEST = [
_Backend.FLASH_ATTN, _Backend.FLASHINFER, _Backend.FLEX_ATTENTION,
_Backend.TRITON_ATTN, _Backend.TREE_ATTN, "FLEX_ATTENTION_SLOW"
_Backend.FLASH_ATTN,
_Backend.FLASHINFER,
_Backend.FLEX_ATTENTION,
_Backend.TRITON_ATTN,
_Backend.TREE_ATTN,
"FLEX_ATTENTION_SLOW",
]
# Remove flashinfer from the list if it's not available
@@ -49,42 +59,38 @@ def _convert_dtype_to_torch(dtype):
# Define common batch configurations
BATCH_SPECS = {
"small_decode":
BatchSpec(seq_lens=[32, 40], query_lens=[1, 1]),
"small_prefill":
BatchSpec(seq_lens=[32, 40], query_lens=[8, 8]),
"mixed_small":
BatchSpec(seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5, 5]),
"medium_decode":
BatchSpec(seq_lens=[128, 256, 512, 1024, 128, 256, 512, 1024],
query_lens=[1, 1, 1, 1, 1, 1, 1, 1]),
"medium_prefill":
BatchSpec(seq_lens=[256, 512, 1024, 2048], query_lens=[16, 16, 16, 16]),
"mixed_medium":
BatchSpec(seq_lens=[512, 1024, 2048, 512, 1024, 2048],
query_lens=[1, 1, 1, 7, 7, 7]),
"large_decode":
BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32),
"large_prefill":
BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8),
"single_decode":
BatchSpec(seq_lens=[1024], query_lens=[1]),
"single_prefill":
BatchSpec(seq_lens=[1024], query_lens=[64]),
"small_decode": BatchSpec(seq_lens=[32, 40], query_lens=[1, 1]),
"small_prefill": BatchSpec(seq_lens=[32, 40], query_lens=[8, 8]),
"mixed_small": BatchSpec(seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5, 5]),
"medium_decode": BatchSpec(
seq_lens=[128, 256, 512, 1024, 128, 256, 512, 1024],
query_lens=[1, 1, 1, 1, 1, 1, 1, 1],
),
"medium_prefill": BatchSpec(
seq_lens=[256, 512, 1024, 2048], query_lens=[16, 16, 16, 16]
),
"mixed_medium": BatchSpec(
seq_lens=[512, 1024, 2048, 512, 1024, 2048], query_lens=[1, 1, 1, 7, 7, 7]
),
"large_decode": BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32),
"large_prefill": BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8),
"single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]),
"single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]),
}
def create_and_prepopulate_kv_cache(
k_contexts: list[torch.Tensor],
v_contexts: list[torch.Tensor],
block_size: int,
num_kv_heads: int,
head_size: int,
dtype: torch.dtype,
device: torch.device,
num_blocks: int,
common_attn_metadata: CommonAttentionMetadata,
randomize_blocks: bool = True) -> torch.Tensor:
k_contexts: list[torch.Tensor],
v_contexts: list[torch.Tensor],
block_size: int,
num_kv_heads: int,
head_size: int,
dtype: torch.dtype,
device: torch.device,
num_blocks: int,
common_attn_metadata: CommonAttentionMetadata,
randomize_blocks: bool = True,
) -> torch.Tensor:
"""Create and prepopulate a KV cache with context data.
Args:
@@ -106,20 +112,18 @@ def create_and_prepopulate_kv_cache(
"""
batch_size = len(k_contexts)
seq_lens = common_attn_metadata.seq_lens_cpu
query_lens = common_attn_metadata.query_start_loc_cpu[
1:] - common_attn_metadata.query_start_loc_cpu[:-1]
query_lens = (
common_attn_metadata.query_start_loc_cpu[1:]
- common_attn_metadata.query_start_loc_cpu[:-1]
)
context_lens = common_attn_metadata.num_computed_tokens_cpu
block_table = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
# Create KV cache
kv_cache = torch.empty(2,
num_blocks,
block_size,
num_kv_heads,
head_size,
dtype=dtype,
device=device)
kv_cache = torch.empty(
2, num_blocks, block_size, num_kv_heads, head_size, dtype=dtype, device=device
)
kv_cache_flat = kv_cache.view(2, -1, num_kv_heads, head_size)
# Populate the cache with the context tokens
@@ -168,8 +172,8 @@ def create_and_prepopulate_kv_cache(
start = common_attn_metadata.query_start_loc_cpu[i]
end = common_attn_metadata.query_start_loc_cpu[i + 1]
slot_mapping[start:end] = block_table[
i,
block_indices] * block_size + token_inter_block_offsets.to(device)
i, block_indices
] * block_size + token_inter_block_offsets.to(device)
return kv_cache
@@ -222,20 +226,19 @@ def run_attention_backend(
# Return mock parameters for a single layer
head_size = vllm_config.model_config.get_head_size()
return {
layer_name:
PerLayerParameters(
layer_name: PerLayerParameters(
window_left=-1, # No sliding window
logits_soft_cap=0.0, # No soft cap
sm_scale=1.0 / (head_size**0.5) # Standard scale
sm_scale=1.0 / (head_size**0.5), # Standard scale
)
for layer_name in layer_names
}
with unittest.mock.patch(
'vllm.v1.attention.backends.flashinfer.get_per_layer_parameters',
mock_get_per_layer_parameters):
builder = builder_cls(kv_cache_spec, layer_names, vllm_config,
device)
"vllm.v1.attention.backends.flashinfer.get_per_layer_parameters",
mock_get_per_layer_parameters,
):
builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
attn_metadata = builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
@@ -252,9 +255,11 @@ def run_attention_backend(
# Instantiate implementation
num_heads = vllm_config.model_config.get_num_attention_heads(
vllm_config.parallel_config)
vllm_config.parallel_config
)
num_kv_heads = vllm_config.model_config.get_num_kv_heads(
vllm_config.parallel_config)
vllm_config.parallel_config
)
head_size = vllm_config.model_config.get_head_size()
scale = 1.0 / (head_size**0.5)
impl = impl_cls(
@@ -274,13 +279,9 @@ def run_attention_backend(
# Run forward pass
# NOTE: The query, key, and value are already shaped correctly
# in the calling test function.
output = impl.forward(mock_layer,
query,
key,
value,
kv_cache,
attn_metadata,
output=output)
output = impl.forward(
mock_layer, query, key, value, kv_cache, attn_metadata, output=output
)
return output
@@ -311,10 +312,12 @@ def _test_backend_correctness(
5. Comparing the vLLM backend's output to the ground-truth SDPA output.
"""
current_platform.seed_everything(42)
vllm_config = create_vllm_config(model_name=model,
max_model_len=max(batch_spec.seq_lens),
block_size=block_size,
num_gpu_blocks=8192)
vllm_config = create_vllm_config(
model_name=model,
max_model_len=max(batch_spec.seq_lens),
block_size=block_size,
num_gpu_blocks=8192,
)
device = torch.device("cuda:0")
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
@@ -324,9 +327,11 @@ def _test_backend_correctness(
seq_lens = batch_spec.seq_lens
query_lens = batch_spec.query_lens
num_q_heads = vllm_config.model_config.get_num_attention_heads(
vllm_config.parallel_config)
vllm_config.parallel_config
)
num_kv_heads = vllm_config.model_config.get_num_kv_heads(
vllm_config.parallel_config)
vllm_config.parallel_config
)
head_size = vllm_config.model_config.get_head_size()
sliding_window = vllm_config.model_config.get_sliding_window()
dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
@@ -344,21 +349,9 @@ def _test_backend_correctness(
context_len = s_len - q_len
# Generate Q, K, V for the whole sequence to be used in SDPA
q = torch.randn(q_len,
num_q_heads,
head_size,
dtype=dtype,
device=device)
k_full = torch.randn(s_len,
num_kv_heads,
head_size,
dtype=dtype,
device=device)
v_full = torch.randn(s_len,
num_kv_heads,
head_size,
dtype=dtype,
device=device)
q = torch.randn(q_len, num_q_heads, head_size, dtype=dtype, device=device)
k_full = torch.randn(s_len, num_kv_heads, head_size, dtype=dtype, device=device)
v_full = torch.randn(s_len, num_kv_heads, head_size, dtype=dtype, device=device)
# SDPA expects (N, H, L, D), so unsqueeze batch and permute
q_sdpa_in = q.unsqueeze(0).transpose(1, 2)
@@ -368,7 +361,8 @@ def _test_backend_correctness(
if num_q_heads != num_kv_heads:
assert num_q_heads % num_kv_heads == 0, (
f"num_q_heads ({num_q_heads}) must be divisible by "
f"num_kv_heads ({num_kv_heads})")
f"num_kv_heads ({num_kv_heads})"
)
repeats = num_q_heads // num_kv_heads
k_sdpa_in = k_sdpa_in.repeat_interleave(repeats, dim=1)
v_sdpa_in = v_sdpa_in.repeat_interleave(repeats, dim=1)
@@ -378,18 +372,17 @@ def _test_backend_correctness(
kv_len = s_len
final_mask_mod = partial(mask_mod, context_len=context_len)
block_mask = create_block_mask(final_mask_mod,
B=None,
H=None,
Q_LEN=q_len,
KV_LEN=kv_len,
device=device)
sdpa_out_i = flex_attention(q_sdpa_in,
k_sdpa_in,
v_sdpa_in,
block_mask=block_mask,
scale=scale,
enable_gqa=True)
block_mask = create_block_mask(
final_mask_mod, B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len, device=device
)
sdpa_out_i = flex_attention(
q_sdpa_in,
k_sdpa_in,
v_sdpa_in,
block_mask=block_mask,
scale=scale,
enable_gqa=True,
)
all_sdpa_outputs.append(sdpa_out_i.transpose(1, 2).squeeze(0))
@@ -408,7 +401,8 @@ def _test_backend_correctness(
sdpa_output = torch.cat(all_sdpa_outputs, dim=0)
common_attn_metadata = create_common_attn_metadata(
batch_spec, vllm_config.cache_config.block_size, device)
batch_spec, vllm_config.cache_config.block_size, device
)
# 3. Simulate Paged KV Cache and a realistic slot_mapping
kv_cache = create_and_prepopulate_kv_cache(
@@ -421,7 +415,8 @@ def _test_backend_correctness(
device=device,
num_blocks=vllm_config.cache_config.num_gpu_blocks or 1000,
common_attn_metadata=common_attn_metadata,
randomize_blocks=True)
randomize_blocks=True,
)
# 4. Run vLLM backends and compare
# Note: flex_attention has known Triton kernel compatibility issues
@@ -437,8 +432,9 @@ def _test_backend_correctness(
kv_cache_for_backend = kv_cache.transpose(0, 1)
# For FlashInfer default to HND layout and
kv_cache_for_backend = kv_cache_for_backend.transpose(
2, 3).contiguous().transpose(2, 3)
kv_cache_for_backend = (
kv_cache_for_backend.transpose(2, 3).contiguous().transpose(2, 3)
)
set_kv_cache_layout("HND")
backend_output = run_attention_backend(
@@ -458,32 +454,45 @@ def _test_backend_correctness(
# Check shape and dtype consistency
assert backend_output.shape == sdpa_output.shape, (
f"[{backend_name}] shape {backend_output.shape} != "
f"SDPA shape {sdpa_output.shape}")
f"SDPA shape {sdpa_output.shape}"
)
assert backend_output.dtype == sdpa_output.dtype, (
f"[{backend_name}] dtype {backend_output.dtype} != "
f"SDPA dtype {sdpa_output.dtype}")
f"SDPA dtype {sdpa_output.dtype}"
)
assert torch.isfinite(backend_output).all(), (
f"[{backend_name}] produced non-finite values")
f"[{backend_name}] produced non-finite values"
)
# Check numerical similarity
def error_msg(msg: str, backend_name: str):
return (f"[{backend_name}] output differs from SDPA baseline. "
f"{msg}")
return f"[{backend_name}] output differs from SDPA baseline. {msg}"
torch.testing.assert_close(backend_output,
sdpa_output,
rtol=rtol,
atol=atol,
msg=partial(error_msg,
backend_name=backend_name))
torch.testing.assert_close(
backend_output,
sdpa_output,
rtol=rtol,
atol=atol,
msg=partial(error_msg, backend_name=backend_name),
)
@pytest.mark.parametrize("batch_spec_name", [
"small_decode", "small_prefill", "mixed_small", "medium_decode",
"medium_prefill", "mixed_medium", "large_decode", "large_prefill",
"single_decode", "single_prefill"
])
@pytest.mark.parametrize(
"batch_spec_name",
[
"small_decode",
"small_prefill",
"mixed_small",
"medium_decode",
"medium_prefill",
"mixed_medium",
"large_decode",
"large_prefill",
"single_decode",
"single_prefill",
],
)
@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
def test_causal_backend_correctness(batch_spec_name: str, model: str):
"""Test backend's correctness with causal attention."""
@@ -499,33 +508,33 @@ def test_causal_backend_correctness(batch_spec_name: str, model: str):
return (q_idx + context_len) >= kv_idx
batch_spec = BATCH_SPECS[batch_spec_name]
LARGE_BLOCK_BACKENDS = ([_Backend.FLEX_ATTENTION]
if is_torch_equal_or_newer("2.9.0.dev0") else [])
LARGE_BLOCK_BACKENDS = (
[_Backend.FLEX_ATTENTION] if is_torch_equal_or_newer("2.9.0.dev0") else []
)
SMALL_BLOCK_BACKENDS = [
x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
]
_test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS,
causal_mask_mod)
_test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS, causal_mask_mod)
# Fast FlexAttention needs to run with block_size=128
if LARGE_BLOCK_BACKENDS:
_test_backend_correctness(batch_spec,
model,
LARGE_BLOCK_BACKENDS,
causal_mask_mod,
block_size=128)
_test_backend_correctness(
batch_spec, model, LARGE_BLOCK_BACKENDS, causal_mask_mod, block_size=128
)
SLIDING_WINDOW_BACKENDS_TO_TEST = [
_Backend.FLASH_ATTN, _Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN,
"FLEX_ATTENTION_SLOW"
_Backend.FLASH_ATTN,
_Backend.FLEX_ATTENTION,
_Backend.TRITON_ATTN,
"FLEX_ATTENTION_SLOW",
]
@pytest.mark.parametrize("batch_spec_name", [
"small_decode", "small_prefill", "mixed_medium", "large_decode",
"large_prefill"
])
@pytest.mark.parametrize(
"batch_spec_name",
["small_decode", "small_prefill", "mixed_medium", "large_decode", "large_prefill"],
)
@pytest.mark.parametrize("model", ["microsoft/Phi-tiny-MoE-instruct"])
def test_sliding_window_backend_correctness(batch_spec_name: str, model: str):
"""Test backend's correctness with sliding window attention."""
@@ -544,25 +553,28 @@ def test_sliding_window_backend_correctness(batch_spec_name: str, model: str):
return causal_mask & window_mask
batch_spec = BATCH_SPECS[batch_spec_name]
model_config = ModelConfig(model=model,
max_model_len=max(batch_spec.seq_lens))
model_config = ModelConfig(model=model, max_model_len=max(batch_spec.seq_lens))
sliding_window = model_config.get_sliding_window()
sliding_window_mask_mod_fn = partial(sliding_window_mask_mod,
sliding_window=sliding_window)
sliding_window_mask_mod_fn = partial(
sliding_window_mask_mod, sliding_window=sliding_window
)
LARGE_BLOCK_BACKENDS = ([_Backend.FLEX_ATTENTION]
if is_torch_equal_or_newer("2.9.0.dev0") else [])
LARGE_BLOCK_BACKENDS = (
[_Backend.FLEX_ATTENTION] if is_torch_equal_or_newer("2.9.0.dev0") else []
)
SMALL_BLOCK_BACKENDS = [
x for x in SLIDING_WINDOW_BACKENDS_TO_TEST
if x not in LARGE_BLOCK_BACKENDS
x for x in SLIDING_WINDOW_BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
]
_test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS,
sliding_window_mask_mod_fn)
_test_backend_correctness(
batch_spec, model, SMALL_BLOCK_BACKENDS, sliding_window_mask_mod_fn
)
# Fast FlexAttention needs to run with block_size=128
if LARGE_BLOCK_BACKENDS:
_test_backend_correctness(batch_spec,
model,
LARGE_BLOCK_BACKENDS,
sliding_window_mask_mod_fn,
block_size=128)
_test_backend_correctness(
batch_spec,
model,
LARGE_BLOCK_BACKENDS,
sliding_window_mask_mod_fn,
block_size=128,
)