Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -10,18 +10,26 @@ import pytest
import torch
from tests.v1.attention.test_mla_backends import (
BATCH_SPECS, BatchSpec, MockAttentionLayer,
create_and_prepopulate_kv_cache)
from tests.v1.attention.utils import (create_common_attn_metadata,
create_standard_kv_cache_spec,
create_vllm_config)
BATCH_SPECS,
BatchSpec,
MockAttentionLayer,
create_and_prepopulate_kv_cache,
)
from tests.v1.attention.utils import (
create_common_attn_metadata,
create_standard_kv_cache_spec,
create_vllm_config,
)
from vllm import _custom_ops as ops
from vllm.attention.ops import flashmla
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.utils import cdiv
from vllm.v1.attention.backends.mla.flashmla_sparse import (
FlashMLASparseBackend, FlashMLASparseDecodeAndContextMetadata,
FlashMLASparseImpl, FlashMLASparseMetadata)
FlashMLASparseBackend,
FlashMLASparseDecodeAndContextMetadata,
FlashMLASparseImpl,
FlashMLASparseMetadata,
)
from vllm.v1.attention.backends.mla.indexer import split_prefill_chunks
SPARSE_BACKEND_BATCH_SPECS = {
@@ -35,41 +43,42 @@ SPARSE_BACKEND_BATCH_SPECS = {
]
}
SPARSE_BACKEND_BATCH_SPECS["large_q_prefill"] = BatchSpec(seq_lens=[1024] * 2,
query_lens=[256] * 2)
SPARSE_BACKEND_BATCH_SPECS["large_q_prefill"] = BatchSpec(
seq_lens=[1024] * 2, query_lens=[256] * 2
)
SPARSE_BACKEND_BATCH_SPECS["large_q_pure_prefill"] = BatchSpec(
seq_lens=[256] * 2, query_lens=[256] * 2)
seq_lens=[256] * 2, query_lens=[256] * 2
)
def _dequantize_fp8_ds_mla_entry(
cache_slice: torch.Tensor, kv_lora_rank: int, rope_dim: int,
dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
cache_slice: torch.Tensor, kv_lora_rank: int, rope_dim: int, dtype: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
"""Dequantize a single fp8_ds_mla cache entry back to latent + rope."""
# The first kv_lora_rank bytes store FP8 latent values with one scale per
# 128 element tile written as float32 right after the latent payload.
scales = cache_slice.view(torch.float32)[kv_lora_rank //
4:kv_lora_rank // 4 + 4]
latent = torch.empty(kv_lora_rank,
dtype=torch.float16,
device=cache_slice.device)
scales = cache_slice.view(torch.float32)[kv_lora_rank // 4 : kv_lora_rank // 4 + 4]
latent = torch.empty(kv_lora_rank, dtype=torch.float16, device=cache_slice.device)
for tile_idx in range(4):
tile_start = tile_idx * 128
tile_end = tile_start + 128
ops.convert_fp8(latent[tile_start:tile_end],
cache_slice[tile_start:tile_end],
float(scales[tile_idx].item()),
kv_dtype="fp8")
ops.convert_fp8(
latent[tile_start:tile_end],
cache_slice[tile_start:tile_end],
float(scales[tile_idx].item()),
kv_dtype="fp8",
)
latent = latent.to(dtype)
rope_offset = kv_lora_rank // 2 + 8
rope_vals = cache_slice.view(dtype)[rope_offset:rope_offset + rope_dim]
rope_vals = cache_slice.view(dtype)[rope_offset : rope_offset + rope_dim]
return latent, rope_vals.clone()
def _quantize_dequantize_fp8_ds_mla(
kv_c: torch.Tensor, k_pe: torch.Tensor, block_size: int,
scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
kv_c: torch.Tensor, k_pe: torch.Tensor, block_size: int, scale: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""Round-trip kv_c/k_pe though the fp8_ds_mla cache layout."""
if kv_c.numel() == 0:
@@ -81,21 +90,14 @@ def _quantize_dequantize_fp8_ds_mla(
num_blocks = max(1, math.ceil(num_tokens / block_size))
entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim
tmp_cache = torch.zeros(num_blocks,
block_size,
entry_size,
dtype=torch.uint8,
device=kv_c.device)
slot_mapping = torch.arange(num_tokens,
dtype=torch.long,
device=kv_c.device)
tmp_cache = torch.zeros(
num_blocks, block_size, entry_size, dtype=torch.uint8, device=kv_c.device
)
slot_mapping = torch.arange(num_tokens, dtype=torch.long, device=kv_c.device)
ops.concat_and_cache_mla(kv_c,
k_pe,
tmp_cache,
slot_mapping,
kv_cache_dtype="fp8_ds_mla",
scale=scale)
ops.concat_and_cache_mla(
kv_c, k_pe, tmp_cache, slot_mapping, kv_cache_dtype="fp8_ds_mla", scale=scale
)
dequant_kv_c = torch.empty_like(kv_c)
dequant_k_pe = torch.empty_like(k_pe)
@@ -106,7 +108,8 @@ def _quantize_dequantize_fp8_ds_mla(
block_offset = slot % block_size
cache_slice = tmp_cache[block_idx, block_offset]
latent, rope_vals = _dequantize_fp8_ds_mla_entry(
cache_slice, kv_lora_rank, rope_dim, kv_c.dtype)
cache_slice, kv_lora_rank, rope_dim, kv_c.dtype
)
dequant_kv_c[token_idx] = latent
dequant_k_pe[token_idx] = rope_vals
@@ -123,10 +126,9 @@ def test_sparse_backend_metadata_registration():
dtype_list = backend.get_supported_dtypes()
assert torch.bfloat16 in dtype_list
shape = backend.get_kv_cache_shape(num_blocks=2,
block_size=64,
num_kv_heads=1,
head_size=576)
shape = backend.get_kv_cache_shape(
num_blocks=2, block_size=64, num_kv_heads=1, head_size=576
)
assert shape == (2, 64, 576)
@@ -141,13 +143,10 @@ def test_sparse_decode_metadata_filters_prefill_indices():
indices = torch.tensor([[0, 3, 5], [1, 2, 4]], dtype=torch.int32)
context_indices, new_token_indices = metadata.filter_prefill_indices(
indices)
context_indices, new_token_indices = metadata.filter_prefill_indices(indices)
expected_context = torch.tensor([[-1, -1, 5], [-1, -1, 4]],
dtype=torch.int32)
expected_new_tokens = torch.tensor([[-1, -1, 1], [-1, 0, 2]],
dtype=torch.int32)
expected_context = torch.tensor([[-1, -1, 5], [-1, -1, 4]], dtype=torch.int32)
expected_new_tokens = torch.tensor([[-1, -1, 1], [-1, 0, 2]], dtype=torch.int32)
assert torch.equal(context_indices, expected_context)
assert torch.equal(new_token_indices, expected_new_tokens)
@@ -162,14 +161,9 @@ def test_sparse_impl_zero_fills_when_metadata_missing():
kv_cache = torch.zeros((1, 1, 1))
output = torch.ones((2, 4))
result = FlashMLASparseImpl.forward(impl,
dummy_layer,
q,
k_c,
k_pe,
kv_cache,
attn_metadata=None,
output=output)
result = FlashMLASparseImpl.forward(
impl, dummy_layer, q, k_c, k_pe, kv_cache, attn_metadata=None, output=output
)
assert result is output
assert torch.all(result == 0)
@@ -177,8 +171,7 @@ def test_sparse_impl_zero_fills_when_metadata_missing():
@pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys()))
@pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"])
def test_sparse_backend_decode_correctness(dist_init, batch_name,
kv_cache_dtype):
def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype):
if not torch.cuda.is_available():
pytest.skip("CUDA is required for sparse MLA decode test")
@@ -203,14 +196,13 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name,
vllm_config = create_vllm_config(
model_name="deepseek-ai/DeepSeek-V2-Lite-Chat",
max_model_len=max_seqlen,
num_gpu_blocks=max(2048,
cdiv(total_cache_tokens, block_size) + 1),
block_size=block_size)
num_gpu_blocks=max(2048, cdiv(total_cache_tokens, block_size) + 1),
block_size=block_size,
)
model_config = vllm_config.model_config
model_config.hf_config = SimpleNamespace(
attn_module_list_cfg=[{
"topk_tokens": topk_tokens
}])
attn_module_list_cfg=[{"topk_tokens": topk_tokens}]
)
model_config.hf_text_config = SimpleNamespace(
q_lora_rank=None,
kv_lora_rank=kv_lora_rank,
@@ -221,13 +213,13 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name,
)
model_config.dtype = dtype
model_config.get_num_attention_heads = MethodType(
lambda self, parallel_config: num_heads, model_config)
model_config.get_num_kv_heads = MethodType(lambda self, parallel_config: 1,
model_config)
model_config.get_head_size = MethodType(lambda self: head_size,
model_config)
model_config.get_sliding_window = MethodType(lambda self: None,
model_config)
lambda self, parallel_config: num_heads, model_config
)
model_config.get_num_kv_heads = MethodType(
lambda self, parallel_config: 1, model_config
)
model_config.get_head_size = MethodType(lambda self: head_size, model_config)
model_config.get_sliding_window = MethodType(lambda self: None, model_config)
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
@@ -236,16 +228,10 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name,
scale = 1.0 / math.sqrt(head_size)
# Shared MLA projection weights to keep reference and backend in sync
W_UK = torch.randn(kv_lora_rank,
num_heads,
qk_nope_head_dim,
dtype=dtype,
device=device)
W_UV = torch.randn(kv_lora_rank,
num_heads,
v_head_dim,
dtype=dtype,
device=device)
W_UK = torch.randn(
kv_lora_rank, num_heads, qk_nope_head_dim, dtype=dtype, device=device
)
W_UV = torch.randn(kv_lora_rank, num_heads, v_head_dim, dtype=dtype, device=device)
# Build synthetic decode-only workload
seq_lens = batch_spec.seq_lens
@@ -262,17 +248,15 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name,
q_len = query_lens[i]
ctx_len = s_len - q_len
q_c = torch.rand(q_len,
num_heads,
qk_nope_head_dim + qk_rope_head_dim,
dtype=dtype,
device=device)
q_c = torch.rand(
q_len,
num_heads,
qk_nope_head_dim + qk_rope_head_dim,
dtype=dtype,
device=device,
)
kv_c_full = torch.rand(s_len, kv_lora_rank, dtype=dtype, device=device)
k_pe_full = torch.rand(s_len,
1,
qk_rope_head_dim,
dtype=dtype,
device=device)
k_pe_full = torch.rand(s_len, 1, qk_rope_head_dim, dtype=dtype, device=device)
kv_c_full, k_pe_full = _quantize_dequantize_fp8_ds_mla(
kv_c_full,
@@ -298,7 +282,8 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name,
v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2)
sdpa_out = torch.nn.functional.scaled_dot_product_attention(
q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale)
q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale
)
sdpa_out = sdpa_out.transpose(1, 2).squeeze(0)
sdpa_out = torch.einsum("qnl,lnv->qnv", sdpa_out, W_UV)
@@ -307,8 +292,8 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name,
all_q_vllm.append(q_c)
all_kv_c_vllm.append(kv_c_full[ctx_len:])
all_k_pe_vllm.append(k_pe_full[ctx_len:])
kv_c_contexts.append(kv_c_full[:ctx_len + 1])
k_pe_contexts.append(k_pe_full[:ctx_len + 1])
kv_c_contexts.append(kv_c_full[: ctx_len + 1])
k_pe_contexts.append(k_pe_full[: ctx_len + 1])
query_vllm = torch.cat(all_q_vllm, dim=0)
kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0)
@@ -321,7 +306,8 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name,
batch_spec,
vllm_config.cache_config.block_size,
device,
arange_block_indices=True)
arange_block_indices=True,
)
kv_cache = create_and_prepopulate_kv_cache(
kv_c_contexts=kv_c_contexts,
@@ -339,31 +325,31 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name,
builder_cls = FlashMLASparseBackend.get_builder_cls()
builder = builder_cls(kv_cache_spec, ["placeholder"], vllm_config, device)
metadata = builder.build(common_prefix_len=0,
common_attn_metadata=common_attn_metadata)
metadata = builder.build(
common_prefix_len=0, common_attn_metadata=common_attn_metadata
)
starts = np.asarray(common_attn_metadata.query_start_loc_cpu,
dtype=np.int32)
starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32)
seg_lengths = np.diff(starts)
positions = np.arange(starts[-1], dtype=np.int32) - np.repeat(
starts[:-1], seg_lengths)
starts[:-1], seg_lengths
)
seq_lengths = np.asarray(common_attn_metadata.seq_lens_cpu, dtype=np.int32)
prefix_lengths = seq_lengths - seg_lengths
positions += np.repeat(prefix_lengths, seg_lengths)
pos_gpu = torch.as_tensor(positions, device=device, dtype=torch.int32)
topk = metadata.topk_tokens
debug_indices = torch.arange(topk, device=device,
dtype=torch.int32).unsqueeze(0)
debug_indices = torch.arange(topk, device=device, dtype=torch.int32).unsqueeze(0)
token_positions = pos_gpu.unsqueeze(1)
causal_mask = (debug_indices <= token_positions)
debug_indices = torch.where(causal_mask, debug_indices,
torch.full_like(debug_indices, -1))
causal_mask = debug_indices <= token_positions
debug_indices = torch.where(
causal_mask, debug_indices, torch.full_like(debug_indices, -1)
)
# FlashMLASparseImpl now reads top-k indices from the indexer-provided
# buffer, so emulate that contract with a simple namespace mock.
debug_indices = debug_indices.expand(metadata.num_actual_tokens,
-1).clone()
debug_indices = debug_indices.expand(metadata.num_actual_tokens, -1).clone()
mock_indexer = SimpleNamespace(topk_indices_buffer=debug_indices)
ok, reason = flashmla.is_flashmla_supported()
@@ -372,59 +358,54 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name,
kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)
kv_b_proj_weight = kv_b_proj_weight.view(
kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim))
kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim)
)
mock_kv_b_proj = ColumnParallelLinear(input_size=kv_lora_rank,
output_size=num_heads *
(qk_nope_head_dim + v_head_dim),
bias=False).to(device=device,
dtype=dtype)
mock_kv_b_proj = ColumnParallelLinear(
input_size=kv_lora_rank,
output_size=num_heads * (qk_nope_head_dim + v_head_dim),
bias=False,
).to(device=device, dtype=dtype)
mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T.contiguous())
impl_cls = FlashMLASparseBackend.get_impl_cls()
impl = impl_cls(num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=1,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
logits_soft_cap=None,
attn_type="decoder",
kv_sharing_target_layer_name=None,
q_lora_rank=None,
kv_lora_rank=kv_lora_rank,
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim,
qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
v_head_dim=v_head_dim,
kv_b_proj=mock_kv_b_proj,
indexer=mock_indexer)
impl = impl_cls(
num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=1,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
logits_soft_cap=None,
attn_type="decoder",
kv_sharing_target_layer_name=None,
q_lora_rank=None,
kv_lora_rank=kv_lora_rank,
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim,
qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
v_head_dim=v_head_dim,
kv_b_proj=mock_kv_b_proj,
indexer=mock_indexer,
)
impl.process_weights_after_loading(dtype)
layer = MockAttentionLayer(device)
out_buffer = torch.empty(metadata.num_actual_tokens,
num_heads * v_head_dim,
dtype=dtype,
device=device)
out_buffer = torch.empty(
metadata.num_actual_tokens, num_heads * v_head_dim, dtype=dtype, device=device
)
backend_output = impl.forward(layer,
query_vllm,
kv_c_vllm,
k_pe_vllm,
kv_cache,
metadata,
output=out_buffer)
backend_output = impl.forward(
layer, query_vllm, kv_c_vllm, k_pe_vllm, kv_cache, metadata, output=out_buffer
)
assert backend_output.shape == sdpa_reference.shape
assert backend_output.dtype == sdpa_reference.dtype
assert torch.isfinite(backend_output).all()
torch.testing.assert_close(backend_output,
sdpa_reference,
rtol=0.5,
atol=0.5)
torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.5, atol=0.5)
@pytest.mark.parametrize(