Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -10,18 +10,26 @@ import pytest
|
||||
import torch
|
||||
|
||||
from tests.v1.attention.test_mla_backends import (
|
||||
BATCH_SPECS, BatchSpec, MockAttentionLayer,
|
||||
create_and_prepopulate_kv_cache)
|
||||
from tests.v1.attention.utils import (create_common_attn_metadata,
|
||||
create_standard_kv_cache_spec,
|
||||
create_vllm_config)
|
||||
BATCH_SPECS,
|
||||
BatchSpec,
|
||||
MockAttentionLayer,
|
||||
create_and_prepopulate_kv_cache,
|
||||
)
|
||||
from tests.v1.attention.utils import (
|
||||
create_common_attn_metadata,
|
||||
create_standard_kv_cache_spec,
|
||||
create_vllm_config,
|
||||
)
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.ops import flashmla
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.attention.backends.mla.flashmla_sparse import (
|
||||
FlashMLASparseBackend, FlashMLASparseDecodeAndContextMetadata,
|
||||
FlashMLASparseImpl, FlashMLASparseMetadata)
|
||||
FlashMLASparseBackend,
|
||||
FlashMLASparseDecodeAndContextMetadata,
|
||||
FlashMLASparseImpl,
|
||||
FlashMLASparseMetadata,
|
||||
)
|
||||
from vllm.v1.attention.backends.mla.indexer import split_prefill_chunks
|
||||
|
||||
SPARSE_BACKEND_BATCH_SPECS = {
|
||||
@@ -35,41 +43,42 @@ SPARSE_BACKEND_BATCH_SPECS = {
|
||||
]
|
||||
}
|
||||
|
||||
SPARSE_BACKEND_BATCH_SPECS["large_q_prefill"] = BatchSpec(seq_lens=[1024] * 2,
|
||||
query_lens=[256] * 2)
|
||||
SPARSE_BACKEND_BATCH_SPECS["large_q_prefill"] = BatchSpec(
|
||||
seq_lens=[1024] * 2, query_lens=[256] * 2
|
||||
)
|
||||
SPARSE_BACKEND_BATCH_SPECS["large_q_pure_prefill"] = BatchSpec(
|
||||
seq_lens=[256] * 2, query_lens=[256] * 2)
|
||||
seq_lens=[256] * 2, query_lens=[256] * 2
|
||||
)
|
||||
|
||||
|
||||
def _dequantize_fp8_ds_mla_entry(
|
||||
cache_slice: torch.Tensor, kv_lora_rank: int, rope_dim: int,
|
||||
dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
cache_slice: torch.Tensor, kv_lora_rank: int, rope_dim: int, dtype: torch.dtype
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Dequantize a single fp8_ds_mla cache entry back to latent + rope."""
|
||||
|
||||
# The first kv_lora_rank bytes store FP8 latent values with one scale per
|
||||
# 128 element tile written as float32 right after the latent payload.
|
||||
scales = cache_slice.view(torch.float32)[kv_lora_rank //
|
||||
4:kv_lora_rank // 4 + 4]
|
||||
latent = torch.empty(kv_lora_rank,
|
||||
dtype=torch.float16,
|
||||
device=cache_slice.device)
|
||||
scales = cache_slice.view(torch.float32)[kv_lora_rank // 4 : kv_lora_rank // 4 + 4]
|
||||
latent = torch.empty(kv_lora_rank, dtype=torch.float16, device=cache_slice.device)
|
||||
for tile_idx in range(4):
|
||||
tile_start = tile_idx * 128
|
||||
tile_end = tile_start + 128
|
||||
ops.convert_fp8(latent[tile_start:tile_end],
|
||||
cache_slice[tile_start:tile_end],
|
||||
float(scales[tile_idx].item()),
|
||||
kv_dtype="fp8")
|
||||
ops.convert_fp8(
|
||||
latent[tile_start:tile_end],
|
||||
cache_slice[tile_start:tile_end],
|
||||
float(scales[tile_idx].item()),
|
||||
kv_dtype="fp8",
|
||||
)
|
||||
latent = latent.to(dtype)
|
||||
|
||||
rope_offset = kv_lora_rank // 2 + 8
|
||||
rope_vals = cache_slice.view(dtype)[rope_offset:rope_offset + rope_dim]
|
||||
rope_vals = cache_slice.view(dtype)[rope_offset : rope_offset + rope_dim]
|
||||
return latent, rope_vals.clone()
|
||||
|
||||
|
||||
def _quantize_dequantize_fp8_ds_mla(
|
||||
kv_c: torch.Tensor, k_pe: torch.Tensor, block_size: int,
|
||||
scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
kv_c: torch.Tensor, k_pe: torch.Tensor, block_size: int, scale: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Round-trip kv_c/k_pe though the fp8_ds_mla cache layout."""
|
||||
|
||||
if kv_c.numel() == 0:
|
||||
@@ -81,21 +90,14 @@ def _quantize_dequantize_fp8_ds_mla(
|
||||
num_blocks = max(1, math.ceil(num_tokens / block_size))
|
||||
entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim
|
||||
|
||||
tmp_cache = torch.zeros(num_blocks,
|
||||
block_size,
|
||||
entry_size,
|
||||
dtype=torch.uint8,
|
||||
device=kv_c.device)
|
||||
slot_mapping = torch.arange(num_tokens,
|
||||
dtype=torch.long,
|
||||
device=kv_c.device)
|
||||
tmp_cache = torch.zeros(
|
||||
num_blocks, block_size, entry_size, dtype=torch.uint8, device=kv_c.device
|
||||
)
|
||||
slot_mapping = torch.arange(num_tokens, dtype=torch.long, device=kv_c.device)
|
||||
|
||||
ops.concat_and_cache_mla(kv_c,
|
||||
k_pe,
|
||||
tmp_cache,
|
||||
slot_mapping,
|
||||
kv_cache_dtype="fp8_ds_mla",
|
||||
scale=scale)
|
||||
ops.concat_and_cache_mla(
|
||||
kv_c, k_pe, tmp_cache, slot_mapping, kv_cache_dtype="fp8_ds_mla", scale=scale
|
||||
)
|
||||
|
||||
dequant_kv_c = torch.empty_like(kv_c)
|
||||
dequant_k_pe = torch.empty_like(k_pe)
|
||||
@@ -106,7 +108,8 @@ def _quantize_dequantize_fp8_ds_mla(
|
||||
block_offset = slot % block_size
|
||||
cache_slice = tmp_cache[block_idx, block_offset]
|
||||
latent, rope_vals = _dequantize_fp8_ds_mla_entry(
|
||||
cache_slice, kv_lora_rank, rope_dim, kv_c.dtype)
|
||||
cache_slice, kv_lora_rank, rope_dim, kv_c.dtype
|
||||
)
|
||||
dequant_kv_c[token_idx] = latent
|
||||
dequant_k_pe[token_idx] = rope_vals
|
||||
|
||||
@@ -123,10 +126,9 @@ def test_sparse_backend_metadata_registration():
|
||||
dtype_list = backend.get_supported_dtypes()
|
||||
assert torch.bfloat16 in dtype_list
|
||||
|
||||
shape = backend.get_kv_cache_shape(num_blocks=2,
|
||||
block_size=64,
|
||||
num_kv_heads=1,
|
||||
head_size=576)
|
||||
shape = backend.get_kv_cache_shape(
|
||||
num_blocks=2, block_size=64, num_kv_heads=1, head_size=576
|
||||
)
|
||||
assert shape == (2, 64, 576)
|
||||
|
||||
|
||||
@@ -141,13 +143,10 @@ def test_sparse_decode_metadata_filters_prefill_indices():
|
||||
|
||||
indices = torch.tensor([[0, 3, 5], [1, 2, 4]], dtype=torch.int32)
|
||||
|
||||
context_indices, new_token_indices = metadata.filter_prefill_indices(
|
||||
indices)
|
||||
context_indices, new_token_indices = metadata.filter_prefill_indices(indices)
|
||||
|
||||
expected_context = torch.tensor([[-1, -1, 5], [-1, -1, 4]],
|
||||
dtype=torch.int32)
|
||||
expected_new_tokens = torch.tensor([[-1, -1, 1], [-1, 0, 2]],
|
||||
dtype=torch.int32)
|
||||
expected_context = torch.tensor([[-1, -1, 5], [-1, -1, 4]], dtype=torch.int32)
|
||||
expected_new_tokens = torch.tensor([[-1, -1, 1], [-1, 0, 2]], dtype=torch.int32)
|
||||
|
||||
assert torch.equal(context_indices, expected_context)
|
||||
assert torch.equal(new_token_indices, expected_new_tokens)
|
||||
@@ -162,14 +161,9 @@ def test_sparse_impl_zero_fills_when_metadata_missing():
|
||||
kv_cache = torch.zeros((1, 1, 1))
|
||||
output = torch.ones((2, 4))
|
||||
|
||||
result = FlashMLASparseImpl.forward(impl,
|
||||
dummy_layer,
|
||||
q,
|
||||
k_c,
|
||||
k_pe,
|
||||
kv_cache,
|
||||
attn_metadata=None,
|
||||
output=output)
|
||||
result = FlashMLASparseImpl.forward(
|
||||
impl, dummy_layer, q, k_c, k_pe, kv_cache, attn_metadata=None, output=output
|
||||
)
|
||||
|
||||
assert result is output
|
||||
assert torch.all(result == 0)
|
||||
@@ -177,8 +171,7 @@ def test_sparse_impl_zero_fills_when_metadata_missing():
|
||||
|
||||
@pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys()))
|
||||
@pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"])
|
||||
def test_sparse_backend_decode_correctness(dist_init, batch_name,
|
||||
kv_cache_dtype):
|
||||
def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype):
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA is required for sparse MLA decode test")
|
||||
|
||||
@@ -203,14 +196,13 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name,
|
||||
vllm_config = create_vllm_config(
|
||||
model_name="deepseek-ai/DeepSeek-V2-Lite-Chat",
|
||||
max_model_len=max_seqlen,
|
||||
num_gpu_blocks=max(2048,
|
||||
cdiv(total_cache_tokens, block_size) + 1),
|
||||
block_size=block_size)
|
||||
num_gpu_blocks=max(2048, cdiv(total_cache_tokens, block_size) + 1),
|
||||
block_size=block_size,
|
||||
)
|
||||
model_config = vllm_config.model_config
|
||||
model_config.hf_config = SimpleNamespace(
|
||||
attn_module_list_cfg=[{
|
||||
"topk_tokens": topk_tokens
|
||||
}])
|
||||
attn_module_list_cfg=[{"topk_tokens": topk_tokens}]
|
||||
)
|
||||
model_config.hf_text_config = SimpleNamespace(
|
||||
q_lora_rank=None,
|
||||
kv_lora_rank=kv_lora_rank,
|
||||
@@ -221,13 +213,13 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name,
|
||||
)
|
||||
model_config.dtype = dtype
|
||||
model_config.get_num_attention_heads = MethodType(
|
||||
lambda self, parallel_config: num_heads, model_config)
|
||||
model_config.get_num_kv_heads = MethodType(lambda self, parallel_config: 1,
|
||||
model_config)
|
||||
model_config.get_head_size = MethodType(lambda self: head_size,
|
||||
model_config)
|
||||
model_config.get_sliding_window = MethodType(lambda self: None,
|
||||
model_config)
|
||||
lambda self, parallel_config: num_heads, model_config
|
||||
)
|
||||
model_config.get_num_kv_heads = MethodType(
|
||||
lambda self, parallel_config: 1, model_config
|
||||
)
|
||||
model_config.get_head_size = MethodType(lambda self: head_size, model_config)
|
||||
model_config.get_sliding_window = MethodType(lambda self: None, model_config)
|
||||
|
||||
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
|
||||
|
||||
@@ -236,16 +228,10 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name,
|
||||
scale = 1.0 / math.sqrt(head_size)
|
||||
|
||||
# Shared MLA projection weights to keep reference and backend in sync
|
||||
W_UK = torch.randn(kv_lora_rank,
|
||||
num_heads,
|
||||
qk_nope_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
W_UV = torch.randn(kv_lora_rank,
|
||||
num_heads,
|
||||
v_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
W_UK = torch.randn(
|
||||
kv_lora_rank, num_heads, qk_nope_head_dim, dtype=dtype, device=device
|
||||
)
|
||||
W_UV = torch.randn(kv_lora_rank, num_heads, v_head_dim, dtype=dtype, device=device)
|
||||
|
||||
# Build synthetic decode-only workload
|
||||
seq_lens = batch_spec.seq_lens
|
||||
@@ -262,17 +248,15 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name,
|
||||
q_len = query_lens[i]
|
||||
ctx_len = s_len - q_len
|
||||
|
||||
q_c = torch.rand(q_len,
|
||||
num_heads,
|
||||
qk_nope_head_dim + qk_rope_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
q_c = torch.rand(
|
||||
q_len,
|
||||
num_heads,
|
||||
qk_nope_head_dim + qk_rope_head_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
kv_c_full = torch.rand(s_len, kv_lora_rank, dtype=dtype, device=device)
|
||||
k_pe_full = torch.rand(s_len,
|
||||
1,
|
||||
qk_rope_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
k_pe_full = torch.rand(s_len, 1, qk_rope_head_dim, dtype=dtype, device=device)
|
||||
|
||||
kv_c_full, k_pe_full = _quantize_dequantize_fp8_ds_mla(
|
||||
kv_c_full,
|
||||
@@ -298,7 +282,8 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name,
|
||||
v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2)
|
||||
|
||||
sdpa_out = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale)
|
||||
q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale
|
||||
)
|
||||
sdpa_out = sdpa_out.transpose(1, 2).squeeze(0)
|
||||
|
||||
sdpa_out = torch.einsum("qnl,lnv->qnv", sdpa_out, W_UV)
|
||||
@@ -307,8 +292,8 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name,
|
||||
all_q_vllm.append(q_c)
|
||||
all_kv_c_vllm.append(kv_c_full[ctx_len:])
|
||||
all_k_pe_vllm.append(k_pe_full[ctx_len:])
|
||||
kv_c_contexts.append(kv_c_full[:ctx_len + 1])
|
||||
k_pe_contexts.append(k_pe_full[:ctx_len + 1])
|
||||
kv_c_contexts.append(kv_c_full[: ctx_len + 1])
|
||||
k_pe_contexts.append(k_pe_full[: ctx_len + 1])
|
||||
|
||||
query_vllm = torch.cat(all_q_vllm, dim=0)
|
||||
kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0)
|
||||
@@ -321,7 +306,8 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name,
|
||||
batch_spec,
|
||||
vllm_config.cache_config.block_size,
|
||||
device,
|
||||
arange_block_indices=True)
|
||||
arange_block_indices=True,
|
||||
)
|
||||
|
||||
kv_cache = create_and_prepopulate_kv_cache(
|
||||
kv_c_contexts=kv_c_contexts,
|
||||
@@ -339,31 +325,31 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name,
|
||||
|
||||
builder_cls = FlashMLASparseBackend.get_builder_cls()
|
||||
builder = builder_cls(kv_cache_spec, ["placeholder"], vllm_config, device)
|
||||
metadata = builder.build(common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata)
|
||||
metadata = builder.build(
|
||||
common_prefix_len=0, common_attn_metadata=common_attn_metadata
|
||||
)
|
||||
|
||||
starts = np.asarray(common_attn_metadata.query_start_loc_cpu,
|
||||
dtype=np.int32)
|
||||
starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32)
|
||||
seg_lengths = np.diff(starts)
|
||||
positions = np.arange(starts[-1], dtype=np.int32) - np.repeat(
|
||||
starts[:-1], seg_lengths)
|
||||
starts[:-1], seg_lengths
|
||||
)
|
||||
seq_lengths = np.asarray(common_attn_metadata.seq_lens_cpu, dtype=np.int32)
|
||||
prefix_lengths = seq_lengths - seg_lengths
|
||||
positions += np.repeat(prefix_lengths, seg_lengths)
|
||||
|
||||
pos_gpu = torch.as_tensor(positions, device=device, dtype=torch.int32)
|
||||
topk = metadata.topk_tokens
|
||||
debug_indices = torch.arange(topk, device=device,
|
||||
dtype=torch.int32).unsqueeze(0)
|
||||
debug_indices = torch.arange(topk, device=device, dtype=torch.int32).unsqueeze(0)
|
||||
token_positions = pos_gpu.unsqueeze(1)
|
||||
causal_mask = (debug_indices <= token_positions)
|
||||
debug_indices = torch.where(causal_mask, debug_indices,
|
||||
torch.full_like(debug_indices, -1))
|
||||
causal_mask = debug_indices <= token_positions
|
||||
debug_indices = torch.where(
|
||||
causal_mask, debug_indices, torch.full_like(debug_indices, -1)
|
||||
)
|
||||
|
||||
# FlashMLASparseImpl now reads top-k indices from the indexer-provided
|
||||
# buffer, so emulate that contract with a simple namespace mock.
|
||||
debug_indices = debug_indices.expand(metadata.num_actual_tokens,
|
||||
-1).clone()
|
||||
debug_indices = debug_indices.expand(metadata.num_actual_tokens, -1).clone()
|
||||
mock_indexer = SimpleNamespace(topk_indices_buffer=debug_indices)
|
||||
|
||||
ok, reason = flashmla.is_flashmla_supported()
|
||||
@@ -372,59 +358,54 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name,
|
||||
|
||||
kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)
|
||||
kv_b_proj_weight = kv_b_proj_weight.view(
|
||||
kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim))
|
||||
kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim)
|
||||
)
|
||||
|
||||
mock_kv_b_proj = ColumnParallelLinear(input_size=kv_lora_rank,
|
||||
output_size=num_heads *
|
||||
(qk_nope_head_dim + v_head_dim),
|
||||
bias=False).to(device=device,
|
||||
dtype=dtype)
|
||||
mock_kv_b_proj = ColumnParallelLinear(
|
||||
input_size=kv_lora_rank,
|
||||
output_size=num_heads * (qk_nope_head_dim + v_head_dim),
|
||||
bias=False,
|
||||
).to(device=device, dtype=dtype)
|
||||
mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T.contiguous())
|
||||
|
||||
impl_cls = FlashMLASparseBackend.get_impl_cls()
|
||||
impl = impl_cls(num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
scale=scale,
|
||||
num_kv_heads=1,
|
||||
alibi_slopes=None,
|
||||
sliding_window=None,
|
||||
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
|
||||
logits_soft_cap=None,
|
||||
attn_type="decoder",
|
||||
kv_sharing_target_layer_name=None,
|
||||
q_lora_rank=None,
|
||||
kv_lora_rank=kv_lora_rank,
|
||||
qk_nope_head_dim=qk_nope_head_dim,
|
||||
qk_rope_head_dim=qk_rope_head_dim,
|
||||
qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
|
||||
v_head_dim=v_head_dim,
|
||||
kv_b_proj=mock_kv_b_proj,
|
||||
indexer=mock_indexer)
|
||||
impl = impl_cls(
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
scale=scale,
|
||||
num_kv_heads=1,
|
||||
alibi_slopes=None,
|
||||
sliding_window=None,
|
||||
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
|
||||
logits_soft_cap=None,
|
||||
attn_type="decoder",
|
||||
kv_sharing_target_layer_name=None,
|
||||
q_lora_rank=None,
|
||||
kv_lora_rank=kv_lora_rank,
|
||||
qk_nope_head_dim=qk_nope_head_dim,
|
||||
qk_rope_head_dim=qk_rope_head_dim,
|
||||
qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
|
||||
v_head_dim=v_head_dim,
|
||||
kv_b_proj=mock_kv_b_proj,
|
||||
indexer=mock_indexer,
|
||||
)
|
||||
|
||||
impl.process_weights_after_loading(dtype)
|
||||
|
||||
layer = MockAttentionLayer(device)
|
||||
out_buffer = torch.empty(metadata.num_actual_tokens,
|
||||
num_heads * v_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
out_buffer = torch.empty(
|
||||
metadata.num_actual_tokens, num_heads * v_head_dim, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
backend_output = impl.forward(layer,
|
||||
query_vllm,
|
||||
kv_c_vllm,
|
||||
k_pe_vllm,
|
||||
kv_cache,
|
||||
metadata,
|
||||
output=out_buffer)
|
||||
backend_output = impl.forward(
|
||||
layer, query_vllm, kv_c_vllm, k_pe_vllm, kv_cache, metadata, output=out_buffer
|
||||
)
|
||||
|
||||
assert backend_output.shape == sdpa_reference.shape
|
||||
assert backend_output.dtype == sdpa_reference.dtype
|
||||
assert torch.isfinite(backend_output).all()
|
||||
|
||||
torch.testing.assert_close(backend_output,
|
||||
sdpa_reference,
|
||||
rtol=0.5,
|
||||
atol=0.5)
|
||||
torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.5, atol=0.5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
Reference in New Issue
Block a user