Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -1,15 +1,19 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for v1 MLA backends without GPUModelRunner dependency."""
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
|
||||
create_standard_kv_cache_spec,
|
||||
create_vllm_config,
|
||||
get_attention_backend)
|
||||
from tests.v1.attention.utils import (
|
||||
BatchSpec,
|
||||
create_common_attn_metadata,
|
||||
create_standard_kv_cache_spec,
|
||||
create_vllm_config,
|
||||
get_attention_backend,
|
||||
)
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
|
||||
@@ -17,13 +21,14 @@ from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
|
||||
BACKENDS_TO_TEST = [
|
||||
_Backend.CUTLASS_MLA, _Backend.FLASHMLA, _Backend.FLASH_ATTN_MLA,
|
||||
_Backend.TRITON_MLA
|
||||
_Backend.CUTLASS_MLA,
|
||||
_Backend.FLASHMLA,
|
||||
_Backend.FLASH_ATTN_MLA,
|
||||
_Backend.TRITON_MLA,
|
||||
]
|
||||
|
||||
# Remove CUTLASS_MLA from the list if not using sm100
|
||||
if not torch.cuda.is_available() or torch.cuda.get_device_properties(
|
||||
0).major < 10:
|
||||
if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10:
|
||||
BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA)
|
||||
|
||||
torch.manual_seed(42)
|
||||
@@ -46,45 +51,41 @@ def _convert_dtype_to_torch(dtype):
|
||||
|
||||
# Define common batch configurations
|
||||
BATCH_SPECS = {
|
||||
"small_decode":
|
||||
BatchSpec(seq_lens=[32, 40], query_lens=[1, 1]),
|
||||
"small_prefill":
|
||||
BatchSpec(seq_lens=[32, 40], query_lens=[8, 8]),
|
||||
"mixed_small":
|
||||
BatchSpec(seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5, 5]),
|
||||
"medium_decode":
|
||||
BatchSpec(seq_lens=[128, 256, 512, 1024, 128, 256, 512, 1024],
|
||||
query_lens=[1, 1, 1, 1, 1, 1, 1, 1]),
|
||||
"medium_prefill":
|
||||
BatchSpec(seq_lens=[256, 512, 1024, 2048], query_lens=[16, 16, 16, 16]),
|
||||
"mixed_medium":
|
||||
BatchSpec(seq_lens=[512, 1024, 2048, 512, 1024, 2048],
|
||||
query_lens=[1, 1, 1, 7, 7, 7]),
|
||||
"large_decode":
|
||||
BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32),
|
||||
"large_prefill":
|
||||
BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8),
|
||||
"single_decode":
|
||||
BatchSpec(seq_lens=[1024], query_lens=[1]),
|
||||
"single_prefill":
|
||||
BatchSpec(seq_lens=[1024], query_lens=[64]),
|
||||
"small_decode": BatchSpec(seq_lens=[32, 40], query_lens=[1, 1]),
|
||||
"small_prefill": BatchSpec(seq_lens=[32, 40], query_lens=[8, 8]),
|
||||
"mixed_small": BatchSpec(seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5, 5]),
|
||||
"medium_decode": BatchSpec(
|
||||
seq_lens=[128, 256, 512, 1024, 128, 256, 512, 1024],
|
||||
query_lens=[1, 1, 1, 1, 1, 1, 1, 1],
|
||||
),
|
||||
"medium_prefill": BatchSpec(
|
||||
seq_lens=[256, 512, 1024, 2048], query_lens=[16, 16, 16, 16]
|
||||
),
|
||||
"mixed_medium": BatchSpec(
|
||||
seq_lens=[512, 1024, 2048, 512, 1024, 2048], query_lens=[1, 1, 1, 7, 7, 7]
|
||||
),
|
||||
"large_decode": BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32),
|
||||
"large_prefill": BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8),
|
||||
"single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]),
|
||||
"single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]),
|
||||
}
|
||||
|
||||
|
||||
def create_and_prepopulate_kv_cache(
|
||||
kv_c_contexts: list[torch.Tensor],
|
||||
k_pe_contexts: list[torch.Tensor],
|
||||
block_size: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
num_blocks: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
randomize_blocks: bool = True,
|
||||
kv_cache_dtype: Optional[str] = None,
|
||||
scale: Union[float, torch.Tensor] = 1.0) -> torch.Tensor:
|
||||
kv_c_contexts: list[torch.Tensor],
|
||||
k_pe_contexts: list[torch.Tensor],
|
||||
block_size: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
num_blocks: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
randomize_blocks: bool = True,
|
||||
kv_cache_dtype: Optional[str] = None,
|
||||
scale: Union[float, torch.Tensor] = 1.0,
|
||||
) -> torch.Tensor:
|
||||
"""Create and prepopulate an MLA KV cache with context data.
|
||||
|
||||
|
||||
Args:
|
||||
kv_c_contexts: List of latent KV context tensors for each sequence
|
||||
k_pe_contexts: List of key positional embedding context tensors
|
||||
@@ -95,21 +96,23 @@ def create_and_prepopulate_kv_cache(
|
||||
device: Device to create the cache on
|
||||
num_blocks: Total number of blocks in the cache
|
||||
common_attn_metadata: Common attention metadata
|
||||
randomize_blocks: Whether to randomly permute blocks
|
||||
randomize_blocks: Whether to randomly permute blocks
|
||||
or use sequential order
|
||||
kv_cache_dtype: Optional kv cache dtype string. When set to
|
||||
"fp8_ds_mla" the cache is populated using the
|
||||
fp8 DeepSeek MLA layout via concat_and_cache_mla.
|
||||
scale: Scaling factor forwarded to concat_and_cache_mla when the
|
||||
fp8 cache layout is requested.
|
||||
|
||||
|
||||
Returns:
|
||||
MLA KV cache tensor
|
||||
"""
|
||||
batch_size = len(kv_c_contexts)
|
||||
seq_lens = common_attn_metadata.seq_lens_cpu
|
||||
query_lens = common_attn_metadata.query_start_loc_cpu[
|
||||
1:] - common_attn_metadata.query_start_loc_cpu[:-1]
|
||||
query_lens = (
|
||||
common_attn_metadata.query_start_loc_cpu[1:]
|
||||
- common_attn_metadata.query_start_loc_cpu[:-1]
|
||||
)
|
||||
context_lens = common_attn_metadata.num_computed_tokens_cpu
|
||||
block_table = common_attn_metadata.block_table_tensor
|
||||
slot_mapping = common_attn_metadata.slot_mapping
|
||||
@@ -118,27 +121,26 @@ def create_and_prepopulate_kv_cache(
|
||||
|
||||
if use_fp8_ds_mla:
|
||||
if not kv_c_contexts:
|
||||
raise ValueError("kv_c_contexts cannot be empty when using"
|
||||
" fp8_ds_mla cache dtype")
|
||||
raise ValueError(
|
||||
"kv_c_contexts cannot be empty when using fp8_ds_mla cache dtype"
|
||||
)
|
||||
kv_lora_rank = kv_c_contexts[0].shape[-1]
|
||||
rope_dim = k_pe_contexts[0].shape[-1]
|
||||
entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim
|
||||
kv_cache = torch.zeros(num_blocks,
|
||||
block_size,
|
||||
entry_size,
|
||||
dtype=torch.uint8,
|
||||
device=device)
|
||||
scale_tensor = (scale
|
||||
if isinstance(scale, torch.Tensor) else torch.tensor(
|
||||
scale, dtype=torch.float32, device=device))
|
||||
kv_cache = torch.zeros(
|
||||
num_blocks, block_size, entry_size, dtype=torch.uint8, device=device
|
||||
)
|
||||
scale_tensor = (
|
||||
scale
|
||||
if isinstance(scale, torch.Tensor)
|
||||
else torch.tensor(scale, dtype=torch.float32, device=device)
|
||||
)
|
||||
scale_tensor = scale_tensor.to(device=device, dtype=torch.float32)
|
||||
else:
|
||||
# Create MLA KV cache: (num_blocks, block_size, head_size)
|
||||
kv_cache = torch.empty(num_blocks,
|
||||
block_size,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
kv_cache = torch.empty(
|
||||
num_blocks, block_size, head_size, dtype=dtype, device=device
|
||||
)
|
||||
kv_cache_flat = kv_cache.view(-1, head_size)
|
||||
|
||||
# Populate the cache with the context tokens
|
||||
@@ -154,8 +156,7 @@ def create_and_prepopulate_kv_cache(
|
||||
start = start_block_idx * block_size
|
||||
|
||||
if use_fp8_ds_mla:
|
||||
slots = torch.arange(context_len, device=device,
|
||||
dtype=torch.long) + start
|
||||
slots = torch.arange(context_len, device=device, dtype=torch.long) + start
|
||||
ops.concat_and_cache_mla(
|
||||
kv_c_context,
|
||||
k_pe_context.squeeze(1),
|
||||
@@ -165,8 +166,7 @@ def create_and_prepopulate_kv_cache(
|
||||
scale=scale_tensor,
|
||||
)
|
||||
else:
|
||||
kv_context = torch.cat(
|
||||
[kv_c_context, k_pe_context.squeeze(1)], dim=-1)
|
||||
kv_context = torch.cat([kv_c_context, k_pe_context.squeeze(1)], dim=-1)
|
||||
end = start + kv_context.shape[0]
|
||||
kv_cache_flat[start:end, ...] = kv_context
|
||||
|
||||
@@ -177,15 +177,14 @@ def create_and_prepopulate_kv_cache(
|
||||
|
||||
# Permute the context blocks (excluding block 0 which is null)
|
||||
if randomize_blocks:
|
||||
perm = torch.randperm(
|
||||
blocks_end - 1) + 1 # Random permutation starting from block 1
|
||||
perm = (
|
||||
torch.randperm(blocks_end - 1) + 1
|
||||
) # Random permutation starting from block 1
|
||||
else:
|
||||
perm = torch.arange(
|
||||
1, blocks_end) # Sequential order starting from block 1
|
||||
perm = torch.arange(1, blocks_end) # Sequential order starting from block 1
|
||||
|
||||
inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device)
|
||||
inv_perm[1:] = torch.argsort(
|
||||
perm) + 1 # Add 1 to account for starting from block 1
|
||||
inv_perm[1:] = torch.argsort(perm) + 1 # Add 1 to account for starting from block 1
|
||||
kv_cache[1:blocks_end, ...] = kv_cache[perm, ...]
|
||||
|
||||
# Construct the right block table
|
||||
@@ -206,8 +205,8 @@ def create_and_prepopulate_kv_cache(
|
||||
start = common_attn_metadata.query_start_loc_cpu[i]
|
||||
end = common_attn_metadata.query_start_loc_cpu[i + 1]
|
||||
slot_mapping[start:end] = block_table[
|
||||
i,
|
||||
block_indices] * block_size + token_inter_block_offsets.to(device)
|
||||
i, block_indices
|
||||
] * block_size + token_inter_block_offsets.to(device)
|
||||
|
||||
return kv_cache
|
||||
|
||||
@@ -221,15 +220,23 @@ class MockAttentionLayer:
|
||||
self._v_scale = torch.tensor(1.0, device=device)
|
||||
|
||||
|
||||
def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
|
||||
layer_names: list[str], vllm_config,
|
||||
device: torch.device,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
query: torch.Tensor, kv_c: torch.Tensor,
|
||||
k_pe: torch.Tensor, kv_cache: torch.Tensor,
|
||||
kv_lora_rank: int, qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int, v_head_dim: int,
|
||||
mock_kv_b_proj) -> torch.Tensor:
|
||||
def run_attention_backend(
|
||||
backend: _Backend,
|
||||
kv_cache_spec: FullAttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config,
|
||||
device: torch.device,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
query: torch.Tensor,
|
||||
kv_c: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
kv_lora_rank: int,
|
||||
qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int,
|
||||
v_head_dim: int,
|
||||
mock_kv_b_proj,
|
||||
) -> torch.Tensor:
|
||||
"""Run attention computation using the specified backend's AttentionImpl."""
|
||||
|
||||
builder_cls, impl_cls = get_attention_backend(backend)
|
||||
@@ -243,9 +250,11 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
|
||||
|
||||
# Instantiate MLA implementation
|
||||
num_heads = vllm_config.model_config.get_num_attention_heads(
|
||||
vllm_config.parallel_config)
|
||||
vllm_config.parallel_config
|
||||
)
|
||||
num_kv_heads = vllm_config.model_config.get_num_kv_heads(
|
||||
vllm_config.parallel_config)
|
||||
vllm_config.parallel_config
|
||||
)
|
||||
head_size = vllm_config.model_config.get_head_size()
|
||||
scale = 1.0 / (head_size**0.5)
|
||||
impl = impl_cls(
|
||||
@@ -275,30 +284,35 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
|
||||
# Create mock layer and output buffer
|
||||
mock_layer = MockAttentionLayer(device)
|
||||
num_tokens = query.shape[0]
|
||||
output = torch.empty(num_tokens,
|
||||
num_heads * v_head_dim,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
output = torch.empty(
|
||||
num_tokens, num_heads * v_head_dim, dtype=query.dtype, device=query.device
|
||||
)
|
||||
|
||||
# Run forward pass
|
||||
# NOTE: The query, key, and value are already shaped correctly
|
||||
# in the calling test function.
|
||||
output = impl.forward(mock_layer,
|
||||
query,
|
||||
kv_c,
|
||||
k_pe,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
output=output)
|
||||
output = impl.forward(
|
||||
mock_layer, query, kv_c, k_pe, kv_cache, attn_metadata, output=output
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_spec_name", [
|
||||
"small_decode", "small_prefill", "mixed_small", "medium_decode",
|
||||
"medium_prefill", "mixed_medium", "large_decode", "large_prefill",
|
||||
"single_decode", "single_prefill"
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"batch_spec_name",
|
||||
[
|
||||
"small_decode",
|
||||
"small_prefill",
|
||||
"mixed_small",
|
||||
"medium_decode",
|
||||
"medium_prefill",
|
||||
"mixed_medium",
|
||||
"large_decode",
|
||||
"large_prefill",
|
||||
"single_decode",
|
||||
"single_prefill",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-V2-Lite-Chat"])
|
||||
def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
"""
|
||||
@@ -317,9 +331,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
5. Comparing the vLLM backend's output to the ground-truth SDPA output.
|
||||
"""
|
||||
batch_spec = BATCH_SPECS[batch_spec_name]
|
||||
vllm_config = create_vllm_config(model_name=model,
|
||||
max_model_len=max(batch_spec.seq_lens),
|
||||
num_gpu_blocks=2048)
|
||||
vllm_config = create_vllm_config(
|
||||
model_name=model, max_model_len=max(batch_spec.seq_lens), num_gpu_blocks=2048
|
||||
)
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
|
||||
@@ -329,7 +343,8 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
seq_lens = batch_spec.seq_lens
|
||||
query_lens = batch_spec.query_lens
|
||||
num_q_heads = vllm_config.model_config.get_num_attention_heads(
|
||||
vllm_config.parallel_config)
|
||||
vllm_config.parallel_config
|
||||
)
|
||||
head_size = vllm_config.model_config.get_head_size()
|
||||
dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
|
||||
block_size = vllm_config.cache_config.block_size
|
||||
@@ -338,8 +353,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
qk_nope_head_dim = 128
|
||||
v_head_dim = 128
|
||||
total_head_size = kv_lora_rank + qk_rope_head_dim
|
||||
assert kv_lora_rank + qk_rope_head_dim == head_size, \
|
||||
assert kv_lora_rank + qk_rope_head_dim == head_size, (
|
||||
f"MLA dimensions don't match: {total_head_size} != {head_size}"
|
||||
)
|
||||
scale = 1.0 / (total_head_size**0.5)
|
||||
|
||||
# 2. Generate data and compute SDPA reference output for MLA
|
||||
@@ -348,16 +364,12 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
kv_c_contexts, k_pe_contexts = [], []
|
||||
|
||||
# Create shared MLA weight matrices for consistency across all sequences
|
||||
W_UK = torch.randn(kv_lora_rank,
|
||||
num_q_heads,
|
||||
qk_nope_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
W_UV = torch.randn(kv_lora_rank,
|
||||
num_q_heads,
|
||||
v_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
W_UK = torch.randn(
|
||||
kv_lora_rank, num_q_heads, qk_nope_head_dim, dtype=dtype, device=device
|
||||
)
|
||||
W_UV = torch.randn(
|
||||
kv_lora_rank, num_q_heads, v_head_dim, dtype=dtype, device=device
|
||||
)
|
||||
kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)
|
||||
|
||||
for i, backend in enumerate(BACKENDS_TO_TEST):
|
||||
@@ -371,24 +383,19 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
# Generate MLA tensors
|
||||
# Q has both nope and rope components:
|
||||
# [q_len, num_heads, qk_nope_head_dim + qk_rope_head_dim]
|
||||
q_c = torch.randn(q_len,
|
||||
num_q_heads,
|
||||
qk_nope_head_dim + qk_rope_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
q_c = torch.randn(
|
||||
q_len,
|
||||
num_q_heads,
|
||||
qk_nope_head_dim + qk_rope_head_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# KV_C (latent K/V): [s_len, kv_lora_rank]
|
||||
kv_c_full = torch.randn(s_len,
|
||||
kv_lora_rank,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
kv_c_full = torch.randn(s_len, kv_lora_rank, dtype=dtype, device=device)
|
||||
|
||||
# K_PE (rope component): [s_len, 1, qk_rope_head_dim]
|
||||
k_pe_full = torch.randn(s_len,
|
||||
1,
|
||||
qk_rope_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
k_pe_full = torch.randn(s_len, 1, qk_rope_head_dim, dtype=dtype, device=device)
|
||||
|
||||
# Determine if this is decode or prefill
|
||||
is_decode = []
|
||||
@@ -404,8 +411,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
# Transform q_nope to latent space: q_nope @ W_UK
|
||||
# q_nope: [1, num_heads, qk_nope_head_dim]
|
||||
# W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim]
|
||||
ql_nope = torch.einsum("qnh,lnh->qnl", q_nope,
|
||||
W_UK) # [1, num_heads, kv_lora_rank]
|
||||
ql_nope = torch.einsum(
|
||||
"qnh,lnh->qnl", q_nope, W_UK
|
||||
) # [1, num_heads, kv_lora_rank]
|
||||
|
||||
# Build MQA attention inputs
|
||||
# Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim]
|
||||
@@ -431,25 +439,24 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2)
|
||||
|
||||
sdpa_out_i_decode = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale)
|
||||
q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale
|
||||
)
|
||||
sdpa_out_i_decode = sdpa_out_i_decode.transpose(1, 2).squeeze(
|
||||
0) # [1, num_heads, kv_lora_rank]
|
||||
0
|
||||
) # [1, num_heads, kv_lora_rank]
|
||||
|
||||
# Project back to output space: sdpa_out @ W_UV
|
||||
sdpa_out_i_decode = torch.einsum("qnl,lnv->qnv", sdpa_out_i_decode,
|
||||
W_UV)
|
||||
sdpa_out_i_decode = torch.einsum("qnl,lnv->qnv", sdpa_out_i_decode, W_UV)
|
||||
sdpa_out_i_decode = sdpa_out_i_decode.flatten(start_dim=-2)
|
||||
|
||||
#######################################################
|
||||
# Prefill path: MHA-style attention with full sequence
|
||||
# Apply kv_b_proj to the full kv_c tensor
|
||||
kv_nope_full = torch.einsum("sl,lnh->snh", kv_c_full, kv_b_proj_weight)
|
||||
k_nope_full, v_full = kv_nope_full.split(
|
||||
[qk_nope_head_dim, v_head_dim], dim=-1)
|
||||
k_nope_full, v_full = kv_nope_full.split([qk_nope_head_dim, v_head_dim], dim=-1)
|
||||
|
||||
# Build attention inputs for full sequence
|
||||
q_mha = torch.cat([q_nope, q_pe],
|
||||
dim=-1) # [q_len, num_heads, total_dim]
|
||||
q_mha = torch.cat([q_nope, q_pe], dim=-1) # [q_len, num_heads, total_dim]
|
||||
k_pe_full_expanded = k_pe_full.expand(-1, num_q_heads, -1)
|
||||
k_full = torch.cat([k_nope_full, k_pe_full_expanded], dim=-1)
|
||||
|
||||
@@ -468,7 +475,8 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
|
||||
# Single attention call with custom mask
|
||||
sdpa_out_i_prefill = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale)
|
||||
q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale
|
||||
)
|
||||
sdpa_out_i_prefill = sdpa_out_i_prefill.transpose(1, 2).squeeze(0)
|
||||
sdpa_out_i_prefill = sdpa_out_i_prefill.flatten(start_dim=-2)
|
||||
|
||||
@@ -497,22 +505,25 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
|
||||
# Create mock kv_b_proj using the same weights as reference implementation
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
mock_kv_b_proj = ColumnParallelLinear(input_size=kv_lora_rank,
|
||||
output_size=num_q_heads *
|
||||
(qk_nope_head_dim + v_head_dim),
|
||||
bias=False).to(device=device,
|
||||
dtype=dtype)
|
||||
|
||||
mock_kv_b_proj = ColumnParallelLinear(
|
||||
input_size=kv_lora_rank,
|
||||
output_size=num_q_heads * (qk_nope_head_dim + v_head_dim),
|
||||
bias=False,
|
||||
).to(device=device, dtype=dtype)
|
||||
|
||||
# Set the mock weights to match our reference implementation
|
||||
# Reshape W_UK and W_UV to match the expected kv_b_proj format
|
||||
# [kv_lora_rank, num_heads, qk_nope_head_dim + v_head_dim]
|
||||
kv_b_proj_weight = kv_b_proj_weight.view(
|
||||
kv_lora_rank, num_q_heads * (qk_nope_head_dim + v_head_dim))
|
||||
kv_lora_rank, num_q_heads * (qk_nope_head_dim + v_head_dim)
|
||||
)
|
||||
mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T)
|
||||
|
||||
# Create metadata using original batch spec
|
||||
common_attn_metadata = create_common_attn_metadata(
|
||||
batch_spec, vllm_config.cache_config.block_size, device)
|
||||
batch_spec, vllm_config.cache_config.block_size, device
|
||||
)
|
||||
|
||||
# 3. Simulate Paged KV Cache and a realistic slot_mapping
|
||||
kv_cache = create_and_prepopulate_kv_cache(
|
||||
@@ -524,41 +535,56 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
device=device,
|
||||
num_blocks=vllm_config.cache_config.num_gpu_blocks,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
randomize_blocks=True)
|
||||
randomize_blocks=True,
|
||||
)
|
||||
|
||||
# 4. Run vLLM backends and compare
|
||||
for i, backend_name in enumerate(BACKENDS_TO_TEST):
|
||||
backend_output = run_attention_backend(
|
||||
backend_name, kv_cache_spec, ["placeholder"], vllm_config, device,
|
||||
common_attn_metadata, query_vllm, kv_c_vllm, k_pe_vllm, kv_cache,
|
||||
kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim,
|
||||
mock_kv_b_proj)
|
||||
backend_name,
|
||||
kv_cache_spec,
|
||||
["placeholder"],
|
||||
vllm_config,
|
||||
device,
|
||||
common_attn_metadata,
|
||||
query_vllm,
|
||||
kv_c_vllm,
|
||||
k_pe_vllm,
|
||||
kv_cache,
|
||||
kv_lora_rank,
|
||||
qk_nope_head_dim,
|
||||
qk_rope_head_dim,
|
||||
v_head_dim,
|
||||
mock_kv_b_proj,
|
||||
)
|
||||
|
||||
# Check shape and dtype consistency
|
||||
assert backend_output.shape == sdpa_outputs[i].shape, (
|
||||
f"[{backend_name}] shape {backend_output.shape} != "
|
||||
f"SDPA shape {sdpa_outputs[i].shape}")
|
||||
f"SDPA shape {sdpa_outputs[i].shape}"
|
||||
)
|
||||
assert backend_output.dtype == sdpa_outputs[i].dtype, (
|
||||
f"[{backend_name}] dtype {backend_output.dtype} != "
|
||||
f"SDPA dtype {sdpa_outputs[i].dtype}")
|
||||
f"SDPA dtype {sdpa_outputs[i].dtype}"
|
||||
)
|
||||
|
||||
assert torch.isfinite(backend_output).all(), (
|
||||
f"[{backend_name}] produced non-finite values")
|
||||
f"[{backend_name}] produced non-finite values"
|
||||
)
|
||||
|
||||
# Check numerical similarity
|
||||
rtol = 1e-2
|
||||
atol = 5e-1
|
||||
|
||||
max_diff = torch.max(torch.abs(backend_output -
|
||||
sdpa_outputs[i])).item()
|
||||
max_diff = torch.max(torch.abs(backend_output - sdpa_outputs[i])).item()
|
||||
max_rel_diff = torch.max(
|
||||
torch.abs(backend_output - sdpa_outputs[i]) /
|
||||
torch.abs(sdpa_outputs[i])).item()
|
||||
all_close = torch.allclose(backend_output,
|
||||
sdpa_outputs[i],
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
torch.abs(backend_output - sdpa_outputs[i]) / torch.abs(sdpa_outputs[i])
|
||||
).item()
|
||||
all_close = torch.allclose(
|
||||
backend_output, sdpa_outputs[i], rtol=rtol, atol=atol
|
||||
)
|
||||
|
||||
assert all_close, (
|
||||
f"[{backend_name}] output differs from SDPA baseline. "
|
||||
f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})")
|
||||
f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user