Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
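The diff below is almost entirely mechanical restyling: yapf aligns continuation lines under the opening bracket, while ruff's formatter (Black-style) uses a 4-space hanging indent and a magic trailing comma that forces one element per line. A minimal sketch of the two styles on a hypothetical helper (do_work is illustrative, not taken from this commit):

def do_work(a, b, c=0):
    """Hypothetical helper, used only to illustrate the restyle."""
    return a + b + c

# Before (yapf): continuation arguments aligned under the opening parenthesis
result = do_work(1,
                 2,
                 c=3)

# After (ruff format, Black-style): hanging indent plus a magic trailing
# comma, the pattern behind most of the reflowed call sites below
result = do_work(
    1,
    2,
    c=3,
)

The same applies to imports: the isort-style wrapped import lists become one name per line with a trailing comma, which is why every multi-name import block in this diff is rewritten.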
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for v1 attention backends without GPUModelRunner dependency."""

from functools import partial
from typing import Optional, Union

@@ -8,21 +9,30 @@ import pytest
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
create_standard_kv_cache_spec,
create_vllm_config,
get_attention_backend)
from tests.v1.attention.utils import (
BatchSpec,
create_common_attn_metadata,
create_standard_kv_cache_spec,
create_vllm_config,
get_attention_backend,
)
from vllm.attention.backends.registry import _Backend
from vllm.config import ModelConfig
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
set_kv_cache_layout)
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
set_kv_cache_layout,
)
from vllm.v1.kv_cache_interface import FullAttentionSpec

BACKENDS_TO_TEST = [
_Backend.FLASH_ATTN, _Backend.FLASHINFER, _Backend.FLEX_ATTENTION,
_Backend.TRITON_ATTN, _Backend.TREE_ATTN, "FLEX_ATTENTION_SLOW"
_Backend.FLASH_ATTN,
_Backend.FLASHINFER,
_Backend.FLEX_ATTENTION,
_Backend.TRITON_ATTN,
_Backend.TREE_ATTN,
"FLEX_ATTENTION_SLOW",
]

# Remove flashinfer from the list if it's not available
@@ -49,42 +59,38 @@ def _convert_dtype_to_torch(dtype):

# Define common batch configurations
BATCH_SPECS = {
"small_decode":
BatchSpec(seq_lens=[32, 40], query_lens=[1, 1]),
"small_prefill":
BatchSpec(seq_lens=[32, 40], query_lens=[8, 8]),
"mixed_small":
BatchSpec(seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5, 5]),
"medium_decode":
BatchSpec(seq_lens=[128, 256, 512, 1024, 128, 256, 512, 1024],
query_lens=[1, 1, 1, 1, 1, 1, 1, 1]),
"medium_prefill":
BatchSpec(seq_lens=[256, 512, 1024, 2048], query_lens=[16, 16, 16, 16]),
"mixed_medium":
BatchSpec(seq_lens=[512, 1024, 2048, 512, 1024, 2048],
query_lens=[1, 1, 1, 7, 7, 7]),
"large_decode":
BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32),
"large_prefill":
BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8),
"single_decode":
BatchSpec(seq_lens=[1024], query_lens=[1]),
"single_prefill":
BatchSpec(seq_lens=[1024], query_lens=[64]),
"small_decode": BatchSpec(seq_lens=[32, 40], query_lens=[1, 1]),
"small_prefill": BatchSpec(seq_lens=[32, 40], query_lens=[8, 8]),
"mixed_small": BatchSpec(seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5, 5]),
"medium_decode": BatchSpec(
seq_lens=[128, 256, 512, 1024, 128, 256, 512, 1024],
query_lens=[1, 1, 1, 1, 1, 1, 1, 1],
),
"medium_prefill": BatchSpec(
seq_lens=[256, 512, 1024, 2048], query_lens=[16, 16, 16, 16]
),
"mixed_medium": BatchSpec(
seq_lens=[512, 1024, 2048, 512, 1024, 2048], query_lens=[1, 1, 1, 7, 7, 7]
),
"large_decode": BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32),
"large_prefill": BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8),
"single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]),
"single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]),
}


def create_and_prepopulate_kv_cache(
k_contexts: list[torch.Tensor],
v_contexts: list[torch.Tensor],
block_size: int,
num_kv_heads: int,
head_size: int,
dtype: torch.dtype,
device: torch.device,
num_blocks: int,
common_attn_metadata: CommonAttentionMetadata,
randomize_blocks: bool = True) -> torch.Tensor:
k_contexts: list[torch.Tensor],
v_contexts: list[torch.Tensor],
block_size: int,
num_kv_heads: int,
head_size: int,
dtype: torch.dtype,
device: torch.device,
num_blocks: int,
common_attn_metadata: CommonAttentionMetadata,
randomize_blocks: bool = True,
) -> torch.Tensor:
"""Create and prepopulate a KV cache with context data.

Args:
@@ -106,20 +112,18 @@ def create_and_prepopulate_kv_cache(
"""
batch_size = len(k_contexts)
seq_lens = common_attn_metadata.seq_lens_cpu
query_lens = common_attn_metadata.query_start_loc_cpu[
1:] - common_attn_metadata.query_start_loc_cpu[:-1]
query_lens = (
common_attn_metadata.query_start_loc_cpu[1:]
- common_attn_metadata.query_start_loc_cpu[:-1]
)
context_lens = common_attn_metadata.num_computed_tokens_cpu
block_table = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping

# Create KV cache
kv_cache = torch.empty(2,
num_blocks,
block_size,
num_kv_heads,
head_size,
dtype=dtype,
device=device)
kv_cache = torch.empty(
2, num_blocks, block_size, num_kv_heads, head_size, dtype=dtype, device=device
)
kv_cache_flat = kv_cache.view(2, -1, num_kv_heads, head_size)

# Populate the cache with the context tokens
@@ -168,8 +172,8 @@ def create_and_prepopulate_kv_cache(
start = common_attn_metadata.query_start_loc_cpu[i]
end = common_attn_metadata.query_start_loc_cpu[i + 1]
slot_mapping[start:end] = block_table[
i,
block_indices] * block_size + token_inter_block_offsets.to(device)
i, block_indices
] * block_size + token_inter_block_offsets.to(device)

return kv_cache

@@ -222,20 +226,19 @@ def run_attention_backend(
# Return mock parameters for a single layer
head_size = vllm_config.model_config.get_head_size()
return {
layer_name:
PerLayerParameters(
layer_name: PerLayerParameters(
window_left=-1, # No sliding window
logits_soft_cap=0.0, # No soft cap
sm_scale=1.0 / (head_size**0.5) # Standard scale
sm_scale=1.0 / (head_size**0.5), # Standard scale
)
for layer_name in layer_names
}

with unittest.mock.patch(
'vllm.v1.attention.backends.flashinfer.get_per_layer_parameters',
mock_get_per_layer_parameters):
builder = builder_cls(kv_cache_spec, layer_names, vllm_config,
device)
"vllm.v1.attention.backends.flashinfer.get_per_layer_parameters",
mock_get_per_layer_parameters,
):
builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
attn_metadata = builder.build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
@@ -252,9 +255,11 @@ def run_attention_backend(

# Instantiate implementation
num_heads = vllm_config.model_config.get_num_attention_heads(
vllm_config.parallel_config)
vllm_config.parallel_config
)
num_kv_heads = vllm_config.model_config.get_num_kv_heads(
vllm_config.parallel_config)
vllm_config.parallel_config
)
head_size = vllm_config.model_config.get_head_size()
scale = 1.0 / (head_size**0.5)
impl = impl_cls(
@@ -274,13 +279,9 @@ def run_attention_backend(
# Run forward pass
# NOTE: The query, key, and value are already shaped correctly
# in the calling test function.
output = impl.forward(mock_layer,
query,
key,
value,
kv_cache,
attn_metadata,
output=output)
output = impl.forward(
mock_layer, query, key, value, kv_cache, attn_metadata, output=output
)

return output

@@ -311,10 +312,12 @@ def _test_backend_correctness(
5. Comparing the vLLM backend's output to the ground-truth SDPA output.
"""
current_platform.seed_everything(42)
vllm_config = create_vllm_config(model_name=model,
max_model_len=max(batch_spec.seq_lens),
block_size=block_size,
num_gpu_blocks=8192)
vllm_config = create_vllm_config(
model_name=model,
max_model_len=max(batch_spec.seq_lens),
block_size=block_size,
num_gpu_blocks=8192,
)
device = torch.device("cuda:0")

kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
@@ -324,9 +327,11 @@ def _test_backend_correctness(
seq_lens = batch_spec.seq_lens
query_lens = batch_spec.query_lens
num_q_heads = vllm_config.model_config.get_num_attention_heads(
vllm_config.parallel_config)
vllm_config.parallel_config
)
num_kv_heads = vllm_config.model_config.get_num_kv_heads(
vllm_config.parallel_config)
vllm_config.parallel_config
)
head_size = vllm_config.model_config.get_head_size()
sliding_window = vllm_config.model_config.get_sliding_window()
dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
@@ -344,21 +349,9 @@ def _test_backend_correctness(
context_len = s_len - q_len

# Generate Q, K, V for the whole sequence to be used in SDPA
q = torch.randn(q_len,
num_q_heads,
head_size,
dtype=dtype,
device=device)
k_full = torch.randn(s_len,
num_kv_heads,
head_size,
dtype=dtype,
device=device)
v_full = torch.randn(s_len,
num_kv_heads,
head_size,
dtype=dtype,
device=device)
q = torch.randn(q_len, num_q_heads, head_size, dtype=dtype, device=device)
k_full = torch.randn(s_len, num_kv_heads, head_size, dtype=dtype, device=device)
v_full = torch.randn(s_len, num_kv_heads, head_size, dtype=dtype, device=device)

# SDPA expects (N, H, L, D), so unsqueeze batch and permute
q_sdpa_in = q.unsqueeze(0).transpose(1, 2)
@@ -368,7 +361,8 @@ def _test_backend_correctness(
if num_q_heads != num_kv_heads:
assert num_q_heads % num_kv_heads == 0, (
f"num_q_heads ({num_q_heads}) must be divisible by "
f"num_kv_heads ({num_kv_heads})")
f"num_kv_heads ({num_kv_heads})"
)
repeats = num_q_heads // num_kv_heads
k_sdpa_in = k_sdpa_in.repeat_interleave(repeats, dim=1)
v_sdpa_in = v_sdpa_in.repeat_interleave(repeats, dim=1)
@@ -378,18 +372,17 @@ def _test_backend_correctness(
kv_len = s_len

final_mask_mod = partial(mask_mod, context_len=context_len)
block_mask = create_block_mask(final_mask_mod,
B=None,
H=None,
Q_LEN=q_len,
KV_LEN=kv_len,
device=device)
sdpa_out_i = flex_attention(q_sdpa_in,
k_sdpa_in,
v_sdpa_in,
block_mask=block_mask,
scale=scale,
enable_gqa=True)
block_mask = create_block_mask(
final_mask_mod, B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len, device=device
)
sdpa_out_i = flex_attention(
q_sdpa_in,
k_sdpa_in,
v_sdpa_in,
block_mask=block_mask,
scale=scale,
enable_gqa=True,
)

all_sdpa_outputs.append(sdpa_out_i.transpose(1, 2).squeeze(0))

@@ -408,7 +401,8 @@ def _test_backend_correctness(
sdpa_output = torch.cat(all_sdpa_outputs, dim=0)

common_attn_metadata = create_common_attn_metadata(
batch_spec, vllm_config.cache_config.block_size, device)
batch_spec, vllm_config.cache_config.block_size, device
)

# 3. Simulate Paged KV Cache and a realistic slot_mapping
kv_cache = create_and_prepopulate_kv_cache(
@@ -421,7 +415,8 @@ def _test_backend_correctness(
device=device,
num_blocks=vllm_config.cache_config.num_gpu_blocks or 1000,
common_attn_metadata=common_attn_metadata,
randomize_blocks=True)
randomize_blocks=True,
)

# 4. Run vLLM backends and compare
# Note: flex_attention has known Triton kernel compatibility issues
@@ -437,8 +432,9 @@ def _test_backend_correctness(
kv_cache_for_backend = kv_cache.transpose(0, 1)

# For FlashInfer default to HND layout and
kv_cache_for_backend = kv_cache_for_backend.transpose(
2, 3).contiguous().transpose(2, 3)
kv_cache_for_backend = (
kv_cache_for_backend.transpose(2, 3).contiguous().transpose(2, 3)
)
set_kv_cache_layout("HND")

backend_output = run_attention_backend(
@@ -458,32 +454,45 @@ def _test_backend_correctness(
# Check shape and dtype consistency
assert backend_output.shape == sdpa_output.shape, (
f"[{backend_name}] shape {backend_output.shape} != "
f"SDPA shape {sdpa_output.shape}")
f"SDPA shape {sdpa_output.shape}"
)
assert backend_output.dtype == sdpa_output.dtype, (
f"[{backend_name}] dtype {backend_output.dtype} != "
f"SDPA dtype {sdpa_output.dtype}")
f"SDPA dtype {sdpa_output.dtype}"
)

assert torch.isfinite(backend_output).all(), (
f"[{backend_name}] produced non-finite values")
f"[{backend_name}] produced non-finite values"
)

# Check numerical similarity
def error_msg(msg: str, backend_name: str):
return (f"[{backend_name}] output differs from SDPA baseline. "
f"{msg}")
return f"[{backend_name}] output differs from SDPA baseline. {msg}"

torch.testing.assert_close(backend_output,
sdpa_output,
rtol=rtol,
atol=atol,
msg=partial(error_msg,
backend_name=backend_name))
torch.testing.assert_close(
backend_output,
sdpa_output,
rtol=rtol,
atol=atol,
msg=partial(error_msg, backend_name=backend_name),
)


@pytest.mark.parametrize("batch_spec_name", [
"small_decode", "small_prefill", "mixed_small", "medium_decode",
"medium_prefill", "mixed_medium", "large_decode", "large_prefill",
"single_decode", "single_prefill"
])
@pytest.mark.parametrize(
"batch_spec_name",
[
"small_decode",
"small_prefill",
"mixed_small",
"medium_decode",
"medium_prefill",
"mixed_medium",
"large_decode",
"large_prefill",
"single_decode",
"single_prefill",
],
)
@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
def test_causal_backend_correctness(batch_spec_name: str, model: str):
"""Test backend's correctness with causal attention."""
@@ -499,33 +508,33 @@ def test_causal_backend_correctness(batch_spec_name: str, model: str):
return (q_idx + context_len) >= kv_idx

batch_spec = BATCH_SPECS[batch_spec_name]
LARGE_BLOCK_BACKENDS = ([_Backend.FLEX_ATTENTION]
if is_torch_equal_or_newer("2.9.0.dev0") else [])
LARGE_BLOCK_BACKENDS = (
[_Backend.FLEX_ATTENTION] if is_torch_equal_or_newer("2.9.0.dev0") else []
)
SMALL_BLOCK_BACKENDS = [
x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
]
_test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS,
causal_mask_mod)
_test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS, causal_mask_mod)

# Fast FlexAttention needs to run with block_size=128
if LARGE_BLOCK_BACKENDS:
_test_backend_correctness(batch_spec,
model,
LARGE_BLOCK_BACKENDS,
causal_mask_mod,
block_size=128)
_test_backend_correctness(
batch_spec, model, LARGE_BLOCK_BACKENDS, causal_mask_mod, block_size=128
)


SLIDING_WINDOW_BACKENDS_TO_TEST = [
_Backend.FLASH_ATTN, _Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN,
"FLEX_ATTENTION_SLOW"
_Backend.FLASH_ATTN,
_Backend.FLEX_ATTENTION,
_Backend.TRITON_ATTN,
"FLEX_ATTENTION_SLOW",
]


@pytest.mark.parametrize("batch_spec_name", [
"small_decode", "small_prefill", "mixed_medium", "large_decode",
"large_prefill"
])
@pytest.mark.parametrize(
"batch_spec_name",
["small_decode", "small_prefill", "mixed_medium", "large_decode", "large_prefill"],
)
@pytest.mark.parametrize("model", ["microsoft/Phi-tiny-MoE-instruct"])
def test_sliding_window_backend_correctness(batch_spec_name: str, model: str):
"""Test backend's correctness with sliding window attention."""
@@ -544,25 +553,28 @@ def test_sliding_window_backend_correctness(batch_spec_name: str, model: str):
return causal_mask & window_mask

batch_spec = BATCH_SPECS[batch_spec_name]
model_config = ModelConfig(model=model,
max_model_len=max(batch_spec.seq_lens))
model_config = ModelConfig(model=model, max_model_len=max(batch_spec.seq_lens))
sliding_window = model_config.get_sliding_window()
sliding_window_mask_mod_fn = partial(sliding_window_mask_mod,
sliding_window=sliding_window)
sliding_window_mask_mod_fn = partial(
sliding_window_mask_mod, sliding_window=sliding_window
)

LARGE_BLOCK_BACKENDS = ([_Backend.FLEX_ATTENTION]
if is_torch_equal_or_newer("2.9.0.dev0") else [])
LARGE_BLOCK_BACKENDS = (
[_Backend.FLEX_ATTENTION] if is_torch_equal_or_newer("2.9.0.dev0") else []
)
SMALL_BLOCK_BACKENDS = [
x for x in SLIDING_WINDOW_BACKENDS_TO_TEST
if x not in LARGE_BLOCK_BACKENDS
x for x in SLIDING_WINDOW_BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
]
_test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS,
sliding_window_mask_mod_fn)
_test_backend_correctness(
batch_spec, model, SMALL_BLOCK_BACKENDS, sliding_window_mask_mod_fn
)

# Fast FlexAttention needs to run with block_size=128
if LARGE_BLOCK_BACKENDS:
_test_backend_correctness(batch_spec,
model,
LARGE_BLOCK_BACKENDS,
sliding_window_mask_mod_fn,
block_size=128)
_test_backend_correctness(
batch_spec,
model,
LARGE_BLOCK_BACKENDS,
sliding_window_mask_mod_fn,
block_size=128,
)

@@ -9,17 +9,16 @@ import pytest
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.short_conv import ShortConv
from vllm.model_executor.models.minimax_text_01 import (
MiniMaxText01LinearAttention)
from vllm.model_executor.models.minimax_text_01 import MiniMaxText01LinearAttention
from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.short_conv_attn import (
ShortConvAttentionBackend)
from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionBackend


@pytest.mark.parametrize(
"layer_class, init_kwargs, expected_backend, expected_mamba_type", [
"layer_class, init_kwargs, expected_backend, expected_mamba_type",
[
(
MambaMixer,
dict(
@@ -77,9 +76,11 @@ from vllm.v1.attention.backends.short_conv_attn import (
ShortConvAttentionBackend,
"short_conv",
),
])
def test_mamba_layers_get_attn_backend(dist_init, layer_class, init_kwargs,
expected_backend, expected_mamba_type):
],
)
def test_mamba_layers_get_attn_backend(
dist_init, layer_class, init_kwargs, expected_backend, expected_mamba_type
):
"""Test that Mamba-like layers return the correct attention backend."""
layer = layer_class(**init_kwargs)

@@ -88,17 +89,23 @@ def test_mamba_layers_get_attn_backend(dist_init, layer_class, init_kwargs,
assert layer.mamba_type == expected_mamba_type


@pytest.mark.parametrize("layer_class,expected_backend,expected_mamba_type", [
(MambaMixer, Mamba1AttentionBackend, "mamba1"),
(MambaMixer2, Mamba2AttentionBackend, "mamba2"),
(MiniMaxText01LinearAttention, LinearAttentionBackend, "linear_attention"),
(ShortConv, ShortConvAttentionBackend, "short_conv"),
])
def test_mamba_layers_have_unified_interface(layer_class, expected_backend,
expected_mamba_type):
"""Test that all Mamba layers have the unified get_attn_backend
@pytest.mark.parametrize(
"layer_class,expected_backend,expected_mamba_type",
[
(MambaMixer, Mamba1AttentionBackend, "mamba1"),
(MambaMixer2, Mamba2AttentionBackend, "mamba2"),
(MiniMaxText01LinearAttention, LinearAttentionBackend, "linear_attention"),
(ShortConv, ShortConvAttentionBackend, "short_conv"),
],
)
def test_mamba_layers_have_unified_interface(
layer_class, expected_backend, expected_mamba_type
):
"""Test that all Mamba layers have the unified get_attn_backend
interface."""
assert hasattr(layer_class, 'get_attn_backend'), (
f"{layer_class.__name__} should have get_attn_backend method")
assert hasattr(layer_class, 'mamba_type'), (
f"{layer_class.__name__} should have mamba_type property")
assert hasattr(layer_class, "get_attn_backend"), (
f"{layer_class.__name__} should have get_attn_backend method"
)
assert hasattr(layer_class, "mamba_type"), (
f"{layer_class.__name__} should have mamba_type property"
)

@@ -6,11 +6,13 @@ import torch

from tests.v1.attention.test_attention_backends import BATCH_SPECS
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm.v1.attention.backends.utils import (UBatchSlice,
_make_metadata_with_slice,
slice_query_start_locs,
split_attn_metadata,
split_decodes_and_prefills)
from vllm.v1.attention.backends.utils import (
UBatchSlice,
_make_metadata_with_slice,
slice_query_start_locs,
split_attn_metadata,
split_decodes_and_prefills,
)
from vllm.v1.worker.ubatch_splitting import create_ubatch_slices


@@ -79,9 +81,7 @@ def small_decode_metadata():
"""Create metadata for small decode batch"""
batch_spec = BATCH_SPECS["small_decode"]
device = torch.device("cpu")
return create_common_attn_metadata(batch_spec,
block_size=16,
device=device)
return create_common_attn_metadata(batch_spec, block_size=16, device=device)


@pytest.fixture
@@ -89,9 +89,7 @@ def large_decode_metadata():
"""Create metadata for small decode batch"""
batch_spec = BATCH_SPECS["large_decode"]
device = torch.device("cpu")
return create_common_attn_metadata(batch_spec,
block_size=16,
device=device)
return create_common_attn_metadata(batch_spec, block_size=16, device=device)


@pytest.fixture
@@ -99,9 +97,7 @@ def mixed_small_metadata():
"""Create metadata for mixed small batch"""
batch_spec = BATCH_SPECS["mixed_small"]
device = torch.device("cpu")
return create_common_attn_metadata(batch_spec,
block_size=16,
device=device)
return create_common_attn_metadata(batch_spec, block_size=16, device=device)


# Tests for _make_metadata_with_slice
@@ -122,8 +118,7 @@ def test_make_metadata_with_slice_decode_batch(small_decode_metadata):

def test_make_metadata_with_slice_mixed_batch(mixed_small_metadata):
"""Test slicing mixed batch metadata"""
ubatch_slice = UBatchSlice(slice(1, 3),
slice(1, 7)) # Requests 1-3, tokens 1-7
ubatch_slice = UBatchSlice(slice(1, 3), slice(1, 7)) # Requests 1-3, tokens 1-7

result = _make_metadata_with_slice(ubatch_slice, mixed_small_metadata)

@@ -140,8 +135,7 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata):
mid_point = num_tokens // 2
ubatch_slices = [
UBatchSlice(slice(0, mid_point), slice(0, mid_point)),
UBatchSlice(slice(mid_point, num_tokens), slice(mid_point,
num_tokens)),
UBatchSlice(slice(mid_point, num_tokens), slice(mid_point, num_tokens)),
]

results = split_attn_metadata(ubatch_slices, large_decode_metadata)
@@ -159,26 +153,30 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata):
assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point))


def apply_split_decodes_and_prefills(query_lens: list[int],
decode_threshold: int,
require_uniform: bool):
def apply_split_decodes_and_prefills(
query_lens: list[int], decode_threshold: int, require_uniform: bool
):
"""Helper function to apply split_decodes_and_prefills and return
the results."""
device = torch.device("cpu")
seq_lens = [10 * (i + 1) for i in range(len(query_lens))]
common_metadata = create_common_attn_metadata(BatchSpec(
seq_lens=seq_lens, query_lens=query_lens),
block_size=16,
device=device)
return split_decodes_and_prefills(common_metadata,
decode_threshold=decode_threshold,
require_uniform=require_uniform)
common_metadata = create_common_attn_metadata(
BatchSpec(seq_lens=seq_lens, query_lens=query_lens),
block_size=16,
device=device,
)
return split_decodes_and_prefills(
common_metadata,
decode_threshold=decode_threshold,
require_uniform=require_uniform,
)


def test_split_decodes_and_prefills_nonuniform_all_ones():
query_lens = [1, 1, 1]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 1, False))
apply_split_decodes_and_prefills(query_lens, 1, False)
)
assert num_decodes == 3
assert num_prefills == 0
assert num_decode_tokens == 3
@@ -188,7 +186,8 @@ def test_split_decodes_and_prefills_nonuniform_all_ones():
def test_split_decodes_and_prefills_nonuniform_all_short_decodes():
query_lens = [1, 2, 1, 3, 2, 1, 2]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 3, False))
apply_split_decodes_and_prefills(query_lens, 3, False)
)
assert num_decodes == 7
assert num_prefills == 0
assert num_decode_tokens == sum(query_lens)
@@ -198,7 +197,8 @@ def test_split_decodes_and_prefills_nonuniform_all_short_decodes():
def test_split_decodes_and_prefills_nonuniform_all_prefills():
query_lens = [4, 5, 6, 7]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 3, False))
apply_split_decodes_and_prefills(query_lens, 3, False)
)
assert num_decodes == 0
assert num_prefills == 4
assert num_decode_tokens == 0
@@ -208,7 +208,8 @@ def test_split_decodes_and_prefills_nonuniform_all_prefills():
def test_split_decodes_and_prefills_nonuniform_mixed_batch():
query_lens = [2, 1, 3, 4, 5, 6, 7, 8]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 4, False))
apply_split_decodes_and_prefills(query_lens, 4, False)
)
assert num_decodes == 4 # 2, 1, 3, 4 are all <= 4
assert num_prefills == 4 # 5, 6, 7, 8 are all > 4
assert num_decode_tokens == 10 # 2 + 1 + 3 + 4
@@ -218,7 +219,8 @@ def test_split_decodes_and_prefills_nonuniform_mixed_batch():
def test_split_decodes_and_prefills_uniform_all_ones():
query_lens = [1, 1, 1]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 1, True))
apply_split_decodes_and_prefills(query_lens, 1, True)
)
assert num_decodes == 3
assert num_prefills == 0
assert num_decode_tokens == 3
@@ -228,7 +230,8 @@ def test_split_decodes_and_prefills_uniform_all_ones():
def test_split_decodes_and_prefills_uniform_all_short_decodes():
query_lens = [2, 2, 1, 3, 2, 1, 2]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 3, True))
apply_split_decodes_and_prefills(query_lens, 3, True)
)
assert num_decodes == 2
assert num_prefills == 5
assert num_decode_tokens == 4
@@ -238,7 +241,8 @@ def test_split_decodes_and_prefills_uniform_all_short_decodes():
def test_split_decodes_and_prefills_uniform_all_prefills():
query_lens = [4, 5, 6, 7]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 3, True))
apply_split_decodes_and_prefills(query_lens, 3, True)
)
assert num_decodes == 0
assert num_prefills == 4
assert num_decode_tokens == 0
@@ -248,7 +252,8 @@ def test_split_decodes_and_prefills_uniform_all_prefills():
def test_split_decodes_and_prefills_uniform_mixed_batch_all_uniform_decodes():
query_lens = [2, 2, 2, 4, 5, 6, 7, 8]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 4, True))
apply_split_decodes_and_prefills(query_lens, 4, True)
)
assert num_decodes == 3 # 2, 2, 2 are all <= 4 and uniform
assert num_prefills == 5 # 4, 5, 6, 7, 8 are all > 4
assert num_decode_tokens == 6 # 2 + 2 + 2
@@ -258,7 +263,8 @@ def test_split_decodes_and_prefills_uniform_mixed_batch_all_uniform_decodes():
def test_split_decodes_and_prefills_uniform_mixed_batch_non_uniform_decodes():
query_lens = [2, 1, 2, 4, 5, 6, 7, 8]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 4, True))
apply_split_decodes_and_prefills(query_lens, 4, True)
)
assert num_decodes == 1 # only the first 2 is taken as decode
assert num_prefills == 7 # 1, 2, 4, 5, 6, 7, 8 are all > 4 or non-uniform
assert num_decode_tokens == 2 # only the first 2
@@ -274,17 +280,15 @@ def test_split_decodes_and_prefills_uniform_mixed_batch_non_uniform_decodes():
([32, 40], [8, 8], 4, 1, 2),
],
)
def test_prefill_split_across_ubatches(seq_lens, query_lens, split_point,
expected_first_reqs,
expected_second_reqs):
def test_prefill_split_across_ubatches(
seq_lens, query_lens, split_point, expected_first_reqs, expected_second_reqs
):
"""Test splitting a prefill across ubatches"""
import numpy as np

device = torch.device("cpu")
batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=query_lens)
common = create_common_attn_metadata(batch_spec,
block_size=16,
device=device)
common = create_common_attn_metadata(batch_spec, block_size=16, device=device)

num_scheduled_tokens = np.array(query_lens, dtype=np.int32)
qsl_np = common.query_start_loc_cpu.numpy()
@@ -307,19 +311,19 @@ def test_prefill_split_across_ubatches(seq_lens, query_lens, split_point,
# Identify which request is split and how many tokens are in the first chunk
split_req_idx = int(np.searchsorted(qsl_np, split_point, side="right") - 1)
tokens_in_first_chunk = split_point - int(qsl_np[split_req_idx])
orig_q_lens = (common.query_start_loc_cpu[1:] -
common.query_start_loc_cpu[:-1])
orig_q_lens = common.query_start_loc_cpu[1:] - common.query_start_loc_cpu[:-1]

# Check query length continuity: first-chunk + second-chunk == original qlen
# First ubatch last request query length
qlen_first_last = int(first_meta.query_start_loc_cpu[-1] -
first_meta.query_start_loc_cpu[-2])
qlen_first_last = int(
first_meta.query_start_loc_cpu[-1] - first_meta.query_start_loc_cpu[-2]
)
# Second ubatch first request query length
qlen_second_first = int(second_meta.query_start_loc_cpu[1] -
second_meta.query_start_loc_cpu[0])
qlen_second_first = int(
second_meta.query_start_loc_cpu[1] - second_meta.query_start_loc_cpu[0]
)
assert qlen_first_last == tokens_in_first_chunk
assert qlen_first_last + qlen_second_first == int(
orig_q_lens[split_req_idx])
assert qlen_first_last + qlen_second_first == int(orig_q_lens[split_req_idx])

# Check seq_lens adjustments
# Context lengths per original request

@@ -7,8 +7,7 @@ import pytest
import torch

from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm.v1.attention.backends.utils import (
make_local_attention_virtual_batches)
from vllm.v1.attention.backends.utils import make_local_attention_virtual_batches


@dataclass
@@ -46,21 +45,24 @@ test_data_list = [
[17, 17], # local-batch 5, (batch 1, starting from k[16])
[20, 21], # local-batch 6, (batch 2, starting from k[4])
[22, 23], # local-batch 7, (batch 2, starting from k[8])
]),
],
),
# Case where block indices are not clipped to block table ncols-1
# because tokens_in_last_block == attn_chunk_size
LocalAttentionTestData(batch_spec=BatchSpec(
query_lens=[8],
seq_lens=[12],
LocalAttentionTestData(
batch_spec=BatchSpec(
query_lens=[8],
seq_lens=[12],
),
attn_chunk_size=4,
block_size=2,
expected_q_seqlens=[4, 4],
expected_k_seqlens=[4, 4],
expected_local_block_table=[
[2, 3],
[4, 5],
],
),
attn_chunk_size=4,
block_size=2,
expected_q_seqlens=[4, 4],
expected_k_seqlens=[4, 4],
expected_local_block_table=[
[2, 3],
[4, 5],
]),
# Case where all kv_seq positions are involved in attn
LocalAttentionTestData(
batch_spec=BatchSpec(
@@ -76,7 +78,8 @@ test_data_list = [
[0, 1],
[2, 3],
[4, 4],
]),
],
),
# Case where attn_chunk_size > kv_seq_len
# so no extra mini virtual batches are created
LocalAttentionTestData(
@@ -97,7 +100,8 @@ test_data_list = [
# is calculated as (attn_chunk_size // block_size)
expected_local_block_table=[
[0, 1, 2, 2, 2],
]),
],
),
# Block size equal to chunk size
# Expect single page per batch in local batch table
LocalAttentionTestData(
@@ -118,7 +122,8 @@ test_data_list = [
[1], # local-batch 1, (batch 0, starting from k[4])
[2], # local-batch 1, (batch 0, starting from k[0])
[3], # local-batch 1, (batch 0, starting from k[4])
]),
],
),
# Case where query falls in the second attention chunk
# k_toks > 0 1 2 3 4
# q_toks v _____________
@@ -128,17 +133,19 @@ test_data_list = [
# 3 | 1 1 1 1
# 4 | 1
# where tokens 0,1,2,3 have been pre-computed
LocalAttentionTestData(batch_spec=BatchSpec(
query_lens=[1],
seq_lens=[5],
LocalAttentionTestData(
batch_spec=BatchSpec(
query_lens=[1],
seq_lens=[5],
),
attn_chunk_size=4,
block_size=2,
expected_q_seqlens=[1],
expected_k_seqlens=[1],
expected_local_block_table=[
[2, 2],
],
),
attn_chunk_size=4,
block_size=2,
expected_q_seqlens=[1],
expected_k_seqlens=[1],
expected_local_block_table=[
[2, 2],
]),
]


@@ -165,9 +172,9 @@ def test_local_attention_virtual_batches(test_data: LocalAttentionTestData):
)

# Call the function
result = make_local_attention_virtual_batches(attn_chunk_size,
common_attn_metadata,
block_size)
result = make_local_attention_virtual_batches(
attn_chunk_size, common_attn_metadata, block_size
)

# Convert to numpy for easier comparison
actual_q_seqlens = np.diff(result.query_start_loc_cpu.numpy())
@@ -184,13 +191,11 @@ def test_local_attention_virtual_batches(test_data: LocalAttentionTestData):
np.testing.assert_array_equal(actual_q_seqlens, expected_q_seqlens)
np.testing.assert_array_equal(actual_k_seqlens, expected_k_seqlens)

expected_block_table_tensor =\
torch.tensor(expected_local_block_table,
dtype=torch.int32,
device=device)
expected_block_table_tensor = torch.tensor(
expected_local_block_table, dtype=torch.int32, device=device
)

print(f"Expected block table:\n{expected_block_table_tensor}")
print(f"Actual block table:\n{result.block_table_tensor}")

torch.testing.assert_close(result.block_table_tensor,
expected_block_table_tensor)
torch.testing.assert_close(result.block_table_tensor, expected_block_table_tensor)

@@ -1,15 +1,19 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for v1 MLA backends without GPUModelRunner dependency."""
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
|
||||
create_standard_kv_cache_spec,
|
||||
create_vllm_config,
|
||||
get_attention_backend)
|
||||
from tests.v1.attention.utils import (
|
||||
BatchSpec,
|
||||
create_common_attn_metadata,
|
||||
create_standard_kv_cache_spec,
|
||||
create_vllm_config,
|
||||
get_attention_backend,
|
||||
)
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
|
||||
@@ -17,13 +21,14 @@ from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
|
||||
BACKENDS_TO_TEST = [
|
||||
_Backend.CUTLASS_MLA, _Backend.FLASHMLA, _Backend.FLASH_ATTN_MLA,
|
||||
_Backend.TRITON_MLA
|
||||
_Backend.CUTLASS_MLA,
|
||||
_Backend.FLASHMLA,
|
||||
_Backend.FLASH_ATTN_MLA,
|
||||
_Backend.TRITON_MLA,
|
||||
]
|
||||
|
||||
# Remove CUTLASS_MLA from the list if not using sm100
|
||||
if not torch.cuda.is_available() or torch.cuda.get_device_properties(
|
||||
0).major < 10:
|
||||
if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10:
|
||||
BACKENDS_TO_TEST.remove(_Backend.CUTLASS_MLA)
|
||||
|
||||
torch.manual_seed(42)
|
||||
@@ -46,45 +51,41 @@ def _convert_dtype_to_torch(dtype):
|
||||
|
||||
# Define common batch configurations
|
||||
BATCH_SPECS = {
|
||||
"small_decode":
|
||||
BatchSpec(seq_lens=[32, 40], query_lens=[1, 1]),
|
||||
"small_prefill":
|
||||
BatchSpec(seq_lens=[32, 40], query_lens=[8, 8]),
|
||||
"mixed_small":
|
||||
BatchSpec(seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5, 5]),
|
||||
"medium_decode":
|
||||
BatchSpec(seq_lens=[128, 256, 512, 1024, 128, 256, 512, 1024],
|
||||
query_lens=[1, 1, 1, 1, 1, 1, 1, 1]),
|
||||
"medium_prefill":
|
||||
BatchSpec(seq_lens=[256, 512, 1024, 2048], query_lens=[16, 16, 16, 16]),
|
||||
"mixed_medium":
|
||||
BatchSpec(seq_lens=[512, 1024, 2048, 512, 1024, 2048],
|
||||
query_lens=[1, 1, 1, 7, 7, 7]),
|
||||
"large_decode":
|
||||
BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32),
|
||||
"large_prefill":
|
||||
BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8),
|
||||
"single_decode":
|
||||
BatchSpec(seq_lens=[1024], query_lens=[1]),
|
||||
"single_prefill":
|
||||
BatchSpec(seq_lens=[1024], query_lens=[64]),
|
||||
"small_decode": BatchSpec(seq_lens=[32, 40], query_lens=[1, 1]),
|
||||
"small_prefill": BatchSpec(seq_lens=[32, 40], query_lens=[8, 8]),
|
||||
"mixed_small": BatchSpec(seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5, 5]),
|
||||
"medium_decode": BatchSpec(
|
||||
seq_lens=[128, 256, 512, 1024, 128, 256, 512, 1024],
|
||||
query_lens=[1, 1, 1, 1, 1, 1, 1, 1],
|
||||
),
|
||||
"medium_prefill": BatchSpec(
|
||||
seq_lens=[256, 512, 1024, 2048], query_lens=[16, 16, 16, 16]
|
||||
),
|
||||
"mixed_medium": BatchSpec(
|
||||
seq_lens=[512, 1024, 2048, 512, 1024, 2048], query_lens=[1, 1, 1, 7, 7, 7]
|
||||
),
|
||||
"large_decode": BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32),
|
||||
"large_prefill": BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8),
|
||||
"single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]),
|
||||
"single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]),
|
||||
}
|
||||
|
||||
|
||||
def create_and_prepopulate_kv_cache(
|
||||
kv_c_contexts: list[torch.Tensor],
|
||||
k_pe_contexts: list[torch.Tensor],
|
||||
block_size: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
num_blocks: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
randomize_blocks: bool = True,
|
||||
kv_cache_dtype: Optional[str] = None,
|
||||
scale: Union[float, torch.Tensor] = 1.0) -> torch.Tensor:
|
||||
kv_c_contexts: list[torch.Tensor],
|
||||
k_pe_contexts: list[torch.Tensor],
|
||||
block_size: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
num_blocks: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
randomize_blocks: bool = True,
|
||||
kv_cache_dtype: Optional[str] = None,
|
||||
scale: Union[float, torch.Tensor] = 1.0,
|
||||
) -> torch.Tensor:
|
||||
"""Create and prepopulate an MLA KV cache with context data.
|
||||
|
||||
|
||||
Args:
|
||||
kv_c_contexts: List of latent KV context tensors for each sequence
|
||||
k_pe_contexts: List of key positional embedding context tensors
|
||||
@@ -95,21 +96,23 @@ def create_and_prepopulate_kv_cache(
|
||||
device: Device to create the cache on
|
||||
num_blocks: Total number of blocks in the cache
|
||||
common_attn_metadata: Common attention metadata
|
||||
randomize_blocks: Whether to randomly permute blocks
|
||||
randomize_blocks: Whether to randomly permute blocks
|
||||
or use sequential order
|
||||
kv_cache_dtype: Optional kv cache dtype string. When set to
|
||||
"fp8_ds_mla" the cache is populated using the
|
||||
fp8 DeepSeek MLA layout via concat_and_cache_mla.
|
||||
scale: Scaling factor forwarded to concat_and_cache_mla when the
|
||||
fp8 cache layout is requested.
|
||||
|
||||
|
||||
Returns:
|
||||
MLA KV cache tensor
|
||||
"""
|
||||
batch_size = len(kv_c_contexts)
|
||||
seq_lens = common_attn_metadata.seq_lens_cpu
|
||||
query_lens = common_attn_metadata.query_start_loc_cpu[
|
||||
1:] - common_attn_metadata.query_start_loc_cpu[:-1]
|
||||
query_lens = (
|
||||
common_attn_metadata.query_start_loc_cpu[1:]
|
||||
- common_attn_metadata.query_start_loc_cpu[:-1]
|
||||
)
|
||||
context_lens = common_attn_metadata.num_computed_tokens_cpu
|
||||
block_table = common_attn_metadata.block_table_tensor
|
||||
slot_mapping = common_attn_metadata.slot_mapping
|
||||
@@ -118,27 +121,26 @@ def create_and_prepopulate_kv_cache(
|
||||
|
||||
if use_fp8_ds_mla:
|
||||
if not kv_c_contexts:
|
||||
raise ValueError("kv_c_contexts cannot be empty when using"
|
||||
" fp8_ds_mla cache dtype")
|
||||
raise ValueError(
|
||||
"kv_c_contexts cannot be empty when using fp8_ds_mla cache dtype"
|
||||
)
|
||||
kv_lora_rank = kv_c_contexts[0].shape[-1]
|
||||
rope_dim = k_pe_contexts[0].shape[-1]
|
||||
entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim
|
||||
kv_cache = torch.zeros(num_blocks,
|
||||
block_size,
|
||||
entry_size,
|
||||
dtype=torch.uint8,
|
||||
device=device)
|
||||
scale_tensor = (scale
|
||||
if isinstance(scale, torch.Tensor) else torch.tensor(
|
||||
scale, dtype=torch.float32, device=device))
|
||||
kv_cache = torch.zeros(
|
||||
num_blocks, block_size, entry_size, dtype=torch.uint8, device=device
|
||||
)
|
||||
scale_tensor = (
|
||||
scale
|
||||
if isinstance(scale, torch.Tensor)
|
||||
else torch.tensor(scale, dtype=torch.float32, device=device)
|
||||
)
|
||||
scale_tensor = scale_tensor.to(device=device, dtype=torch.float32)
|
||||
else:
|
||||
# Create MLA KV cache: (num_blocks, block_size, head_size)
|
||||
kv_cache = torch.empty(num_blocks,
|
||||
block_size,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
kv_cache = torch.empty(
|
||||
num_blocks, block_size, head_size, dtype=dtype, device=device
|
||||
)
|
||||
kv_cache_flat = kv_cache.view(-1, head_size)
|
||||
|
||||
# Populate the cache with the context tokens
|
||||
@@ -154,8 +156,7 @@ def create_and_prepopulate_kv_cache(
|
||||
start = start_block_idx * block_size
|
||||
|
||||
if use_fp8_ds_mla:
|
||||
slots = torch.arange(context_len, device=device,
|
||||
dtype=torch.long) + start
|
||||
slots = torch.arange(context_len, device=device, dtype=torch.long) + start
|
||||
ops.concat_and_cache_mla(
|
||||
kv_c_context,
|
||||
k_pe_context.squeeze(1),
|
||||
@@ -165,8 +166,7 @@ def create_and_prepopulate_kv_cache(
|
||||
scale=scale_tensor,
|
||||
)
|
||||
else:
|
||||
kv_context = torch.cat(
|
||||
[kv_c_context, k_pe_context.squeeze(1)], dim=-1)
|
||||
kv_context = torch.cat([kv_c_context, k_pe_context.squeeze(1)], dim=-1)
|
||||
end = start + kv_context.shape[0]
|
||||
kv_cache_flat[start:end, ...] = kv_context
|
||||
|
||||
@@ -177,15 +177,14 @@ def create_and_prepopulate_kv_cache(
|
||||
|
||||
# Permute the context blocks (excluding block 0 which is null)
|
||||
if randomize_blocks:
|
||||
perm = torch.randperm(
|
||||
blocks_end - 1) + 1 # Random permutation starting from block 1
|
||||
perm = (
|
||||
torch.randperm(blocks_end - 1) + 1
|
||||
) # Random permutation starting from block 1
|
||||
else:
|
||||
perm = torch.arange(
|
||||
1, blocks_end) # Sequential order starting from block 1
|
||||
perm = torch.arange(1, blocks_end) # Sequential order starting from block 1
|
||||
|
||||
inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device)
|
||||
inv_perm[1:] = torch.argsort(
|
||||
perm) + 1 # Add 1 to account for starting from block 1
|
||||
inv_perm[1:] = torch.argsort(perm) + 1 # Add 1 to account for starting from block 1
|
||||
kv_cache[1:blocks_end, ...] = kv_cache[perm, ...]
|
||||
|
||||
# Construct the right block table
|
||||
@@ -206,8 +205,8 @@ def create_and_prepopulate_kv_cache(
|
||||
start = common_attn_metadata.query_start_loc_cpu[i]
|
||||
end = common_attn_metadata.query_start_loc_cpu[i + 1]
|
||||
slot_mapping[start:end] = block_table[
|
||||
i,
|
||||
block_indices] * block_size + token_inter_block_offsets.to(device)
|
||||
i, block_indices
|
||||
] * block_size + token_inter_block_offsets.to(device)
|
||||
|
||||
return kv_cache
|
||||
|
||||
@@ -221,15 +220,23 @@ class MockAttentionLayer:
|
||||
self._v_scale = torch.tensor(1.0, device=device)
|
||||
|
||||
|
||||
def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
|
||||
layer_names: list[str], vllm_config,
|
||||
device: torch.device,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
query: torch.Tensor, kv_c: torch.Tensor,
|
||||
k_pe: torch.Tensor, kv_cache: torch.Tensor,
|
||||
kv_lora_rank: int, qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int, v_head_dim: int,
|
||||
mock_kv_b_proj) -> torch.Tensor:
|
||||
def run_attention_backend(
|
||||
backend: _Backend,
|
||||
kv_cache_spec: FullAttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config,
|
||||
device: torch.device,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
query: torch.Tensor,
|
||||
kv_c: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
kv_lora_rank: int,
|
||||
qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int,
|
||||
v_head_dim: int,
|
||||
mock_kv_b_proj,
|
||||
) -> torch.Tensor:
|
||||
"""Run attention computation using the specified backend's AttentionImpl."""
|
||||
|
||||
builder_cls, impl_cls = get_attention_backend(backend)
|
||||
@@ -243,9 +250,11 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
|
||||
|
||||
# Instantiate MLA implementation
|
||||
num_heads = vllm_config.model_config.get_num_attention_heads(
|
||||
vllm_config.parallel_config)
|
||||
vllm_config.parallel_config
|
||||
)
|
||||
num_kv_heads = vllm_config.model_config.get_num_kv_heads(
|
||||
vllm_config.parallel_config)
|
||||
vllm_config.parallel_config
|
||||
)
|
||||
head_size = vllm_config.model_config.get_head_size()
|
||||
scale = 1.0 / (head_size**0.5)
|
||||
impl = impl_cls(
|
||||
@@ -275,30 +284,35 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
|
||||
# Create mock layer and output buffer
|
||||
mock_layer = MockAttentionLayer(device)
|
||||
num_tokens = query.shape[0]
|
||||
output = torch.empty(num_tokens,
|
||||
num_heads * v_head_dim,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
output = torch.empty(
|
||||
num_tokens, num_heads * v_head_dim, dtype=query.dtype, device=query.device
|
||||
)
|
||||
|
||||
# Run forward pass
|
||||
# NOTE: The query, key, and value are already shaped correctly
|
||||
# in the calling test function.
|
||||
output = impl.forward(mock_layer,
|
||||
query,
|
||||
kv_c,
|
||||
k_pe,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
output=output)
|
||||
output = impl.forward(
|
||||
mock_layer, query, kv_c, k_pe, kv_cache, attn_metadata, output=output
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_spec_name", [
|
||||
"small_decode", "small_prefill", "mixed_small", "medium_decode",
|
||||
"medium_prefill", "mixed_medium", "large_decode", "large_prefill",
|
||||
"single_decode", "single_prefill"
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"batch_spec_name",
|
||||
[
|
||||
"small_decode",
|
||||
"small_prefill",
|
||||
"mixed_small",
|
||||
"medium_decode",
|
||||
"medium_prefill",
|
||||
"mixed_medium",
|
||||
"large_decode",
|
||||
"large_prefill",
|
||||
"single_decode",
|
||||
"single_prefill",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-V2-Lite-Chat"])
|
||||
def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
"""
|
||||
@@ -317,9 +331,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
5. Comparing the vLLM backend's output to the ground-truth SDPA output.
|
||||
"""
|
||||
batch_spec = BATCH_SPECS[batch_spec_name]
|
||||
vllm_config = create_vllm_config(model_name=model,
|
||||
max_model_len=max(batch_spec.seq_lens),
|
||||
num_gpu_blocks=2048)
|
||||
vllm_config = create_vllm_config(
|
||||
model_name=model, max_model_len=max(batch_spec.seq_lens), num_gpu_blocks=2048
|
||||
)
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
|
||||
@@ -329,7 +343,8 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
seq_lens = batch_spec.seq_lens
|
||||
query_lens = batch_spec.query_lens
|
||||
num_q_heads = vllm_config.model_config.get_num_attention_heads(
|
||||
vllm_config.parallel_config)
|
||||
vllm_config.parallel_config
|
||||
)
|
||||
head_size = vllm_config.model_config.get_head_size()
|
||||
dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
|
||||
block_size = vllm_config.cache_config.block_size
|
||||
@@ -338,8 +353,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
qk_nope_head_dim = 128
|
||||
v_head_dim = 128
|
||||
total_head_size = kv_lora_rank + qk_rope_head_dim
|
||||
assert kv_lora_rank + qk_rope_head_dim == head_size, \
|
||||
assert kv_lora_rank + qk_rope_head_dim == head_size, (
|
||||
f"MLA dimensions don't match: {total_head_size} != {head_size}"
|
||||
)
|
||||
scale = 1.0 / (total_head_size**0.5)
|
||||
|
||||
# 2. Generate data and compute SDPA reference output for MLA
|
||||
@@ -348,16 +364,12 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
kv_c_contexts, k_pe_contexts = [], []
|
||||
|
||||
# Create shared MLA weight matrices for consistency across all sequences
|
||||
W_UK = torch.randn(kv_lora_rank,
|
||||
num_q_heads,
|
||||
qk_nope_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
W_UV = torch.randn(kv_lora_rank,
|
||||
num_q_heads,
|
||||
v_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
W_UK = torch.randn(
|
||||
kv_lora_rank, num_q_heads, qk_nope_head_dim, dtype=dtype, device=device
|
||||
)
|
||||
W_UV = torch.randn(
|
||||
kv_lora_rank, num_q_heads, v_head_dim, dtype=dtype, device=device
|
||||
)
|
||||
kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)
|
||||
|
||||
for i, backend in enumerate(BACKENDS_TO_TEST):
|
||||
@@ -371,24 +383,19 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
# Generate MLA tensors
|
||||
# Q has both nope and rope components:
|
||||
# [q_len, num_heads, qk_nope_head_dim + qk_rope_head_dim]
|
||||
q_c = torch.randn(q_len,
|
||||
num_q_heads,
|
||||
qk_nope_head_dim + qk_rope_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
q_c = torch.randn(
|
||||
q_len,
|
||||
num_q_heads,
|
||||
qk_nope_head_dim + qk_rope_head_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# KV_C (latent K/V): [s_len, kv_lora_rank]
|
||||
kv_c_full = torch.randn(s_len,
|
||||
kv_lora_rank,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
kv_c_full = torch.randn(s_len, kv_lora_rank, dtype=dtype, device=device)
|
||||
|
||||
# K_PE (rope component): [s_len, 1, qk_rope_head_dim]
|
||||
k_pe_full = torch.randn(s_len,
|
||||
1,
|
||||
qk_rope_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
k_pe_full = torch.randn(s_len, 1, qk_rope_head_dim, dtype=dtype, device=device)
|
||||
|
||||
# Determine if this is decode or prefill
|
||||
is_decode = []
|
||||
@@ -404,8 +411,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
# Transform q_nope to latent space: q_nope @ W_UK
|
||||
# q_nope: [1, num_heads, qk_nope_head_dim]
|
||||
# W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim]
|
||||
ql_nope = torch.einsum("qnh,lnh->qnl", q_nope,
|
||||
W_UK) # [1, num_heads, kv_lora_rank]
|
||||
ql_nope = torch.einsum(
|
||||
"qnh,lnh->qnl", q_nope, W_UK
|
||||
) # [1, num_heads, kv_lora_rank]
|
||||
|
||||
# Build MQA attention inputs
|
||||
# Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim]
|
||||
@@ -431,25 +439,24 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
|
||||
v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2)
|
||||
|
||||
sdpa_out_i_decode = torch.nn.functional.scaled_dot_product_attention(
|
||||
q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale)
|
||||
q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale
|
||||
)
|
||||
sdpa_out_i_decode = sdpa_out_i_decode.transpose(1, 2).squeeze(
|
||||
0) # [1, num_heads, kv_lora_rank]
|
||||
0
|
||||
) # [1, num_heads, kv_lora_rank]
|
||||
|
||||
# Project back to output space: sdpa_out @ W_UV
|
||||
sdpa_out_i_decode = torch.einsum("qnl,lnv->qnv", sdpa_out_i_decode,
|
||||
W_UV)
|
||||
sdpa_out_i_decode = torch.einsum("qnl,lnv->qnv", sdpa_out_i_decode, W_UV)
|
||||
sdpa_out_i_decode = sdpa_out_i_decode.flatten(start_dim=-2)

        #######################################################
        # Prefill path: MHA-style attention with full sequence
        # Apply kv_b_proj to the full kv_c tensor
        kv_nope_full = torch.einsum("sl,lnh->snh", kv_c_full, kv_b_proj_weight)
        k_nope_full, v_full = kv_nope_full.split(
            [qk_nope_head_dim, v_head_dim], dim=-1)
        k_nope_full, v_full = kv_nope_full.split([qk_nope_head_dim, v_head_dim], dim=-1)

        # Build attention inputs for full sequence
        q_mha = torch.cat([q_nope, q_pe],
                          dim=-1)  # [q_len, num_heads, total_dim]
        q_mha = torch.cat([q_nope, q_pe], dim=-1)  # [q_len, num_heads, total_dim]
        k_pe_full_expanded = k_pe_full.expand(-1, num_q_heads, -1)
        k_full = torch.cat([k_nope_full, k_pe_full_expanded], dim=-1)

@@ -468,7 +475,8 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):

        # Single attention call with custom mask
        sdpa_out_i_prefill = torch.nn.functional.scaled_dot_product_attention(
            q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale)
            q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale
        )
        sdpa_out_i_prefill = sdpa_out_i_prefill.transpose(1, 2).squeeze(0)
        sdpa_out_i_prefill = sdpa_out_i_prefill.flatten(start_dim=-2)
@@ -497,22 +505,25 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):

    # Create mock kv_b_proj using the same weights as reference implementation
    from vllm.model_executor.layers.linear import ColumnParallelLinear
    mock_kv_b_proj = ColumnParallelLinear(input_size=kv_lora_rank,
                                          output_size=num_q_heads *
                                          (qk_nope_head_dim + v_head_dim),
                                          bias=False).to(device=device,
                                                         dtype=dtype)

    mock_kv_b_proj = ColumnParallelLinear(
        input_size=kv_lora_rank,
        output_size=num_q_heads * (qk_nope_head_dim + v_head_dim),
        bias=False,
    ).to(device=device, dtype=dtype)

    # Set the mock weights to match our reference implementation
    # Reshape W_UK and W_UV to match the expected kv_b_proj format
    # [kv_lora_rank, num_heads, qk_nope_head_dim + v_head_dim]
    kv_b_proj_weight = kv_b_proj_weight.view(
        kv_lora_rank, num_q_heads * (qk_nope_head_dim + v_head_dim))
        kv_lora_rank, num_q_heads * (qk_nope_head_dim + v_head_dim)
    )
    mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T)

    # Create metadata using original batch spec
    common_attn_metadata = create_common_attn_metadata(
        batch_spec, vllm_config.cache_config.block_size, device)
        batch_spec, vllm_config.cache_config.block_size, device
    )

    # 3. Simulate Paged KV Cache and a realistic slot_mapping
    kv_cache = create_and_prepopulate_kv_cache(
@@ -524,41 +535,56 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
        device=device,
        num_blocks=vllm_config.cache_config.num_gpu_blocks,
        common_attn_metadata=common_attn_metadata,
        randomize_blocks=True)
        randomize_blocks=True,
    )

    # 4. Run vLLM backends and compare
    for i, backend_name in enumerate(BACKENDS_TO_TEST):
        backend_output = run_attention_backend(
            backend_name, kv_cache_spec, ["placeholder"], vllm_config, device,
            common_attn_metadata, query_vllm, kv_c_vllm, k_pe_vllm, kv_cache,
            kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim,
            mock_kv_b_proj)
            backend_name,
            kv_cache_spec,
            ["placeholder"],
            vllm_config,
            device,
            common_attn_metadata,
            query_vllm,
            kv_c_vllm,
            k_pe_vllm,
            kv_cache,
            kv_lora_rank,
            qk_nope_head_dim,
            qk_rope_head_dim,
            v_head_dim,
            mock_kv_b_proj,
        )

        # Check shape and dtype consistency
        assert backend_output.shape == sdpa_outputs[i].shape, (
            f"[{backend_name}] shape {backend_output.shape} != "
            f"SDPA shape {sdpa_outputs[i].shape}")
            f"SDPA shape {sdpa_outputs[i].shape}"
        )
        assert backend_output.dtype == sdpa_outputs[i].dtype, (
            f"[{backend_name}] dtype {backend_output.dtype} != "
            f"SDPA dtype {sdpa_outputs[i].dtype}")
            f"SDPA dtype {sdpa_outputs[i].dtype}"
        )

        assert torch.isfinite(backend_output).all(), (
            f"[{backend_name}] produced non-finite values")
            f"[{backend_name}] produced non-finite values"
        )

        # Check numerical similarity
        rtol = 1e-2
        atol = 5e-1

        max_diff = torch.max(torch.abs(backend_output -
                                       sdpa_outputs[i])).item()
        max_diff = torch.max(torch.abs(backend_output - sdpa_outputs[i])).item()
        max_rel_diff = torch.max(
            torch.abs(backend_output - sdpa_outputs[i]) /
            torch.abs(sdpa_outputs[i])).item()
        all_close = torch.allclose(backend_output,
                                   sdpa_outputs[i],
                                   rtol=rtol,
                                   atol=atol)
            torch.abs(backend_output - sdpa_outputs[i]) / torch.abs(sdpa_outputs[i])
        ).item()
        all_close = torch.allclose(
            backend_output, sdpa_outputs[i], rtol=rtol, atol=atol
        )

        assert all_close, (
            f"[{backend_name}] output differs from SDPA baseline. "
            f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})")
            f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})"
        )

@@ -10,18 +10,26 @@ import pytest
import torch

from tests.v1.attention.test_mla_backends import (
    BATCH_SPECS, BatchSpec, MockAttentionLayer,
    create_and_prepopulate_kv_cache)
from tests.v1.attention.utils import (create_common_attn_metadata,
                                      create_standard_kv_cache_spec,
                                      create_vllm_config)
    BATCH_SPECS,
    BatchSpec,
    MockAttentionLayer,
    create_and_prepopulate_kv_cache,
)
from tests.v1.attention.utils import (
    create_common_attn_metadata,
    create_standard_kv_cache_spec,
    create_vllm_config,
)
from vllm import _custom_ops as ops
from vllm.attention.ops import flashmla
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.utils import cdiv
from vllm.v1.attention.backends.mla.flashmla_sparse import (
    FlashMLASparseBackend, FlashMLASparseDecodeAndContextMetadata,
    FlashMLASparseImpl, FlashMLASparseMetadata)
    FlashMLASparseBackend,
    FlashMLASparseDecodeAndContextMetadata,
    FlashMLASparseImpl,
    FlashMLASparseMetadata,
)
from vllm.v1.attention.backends.mla.indexer import split_prefill_chunks

SPARSE_BACKEND_BATCH_SPECS = {
@@ -35,41 +43,42 @@ SPARSE_BACKEND_BATCH_SPECS = {
    ]
}

SPARSE_BACKEND_BATCH_SPECS["large_q_prefill"] = BatchSpec(seq_lens=[1024] * 2,
                                                          query_lens=[256] * 2)
SPARSE_BACKEND_BATCH_SPECS["large_q_prefill"] = BatchSpec(
    seq_lens=[1024] * 2, query_lens=[256] * 2
)
SPARSE_BACKEND_BATCH_SPECS["large_q_pure_prefill"] = BatchSpec(
    seq_lens=[256] * 2, query_lens=[256] * 2)
    seq_lens=[256] * 2, query_lens=[256] * 2
)


def _dequantize_fp8_ds_mla_entry(
        cache_slice: torch.Tensor, kv_lora_rank: int, rope_dim: int,
        dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
    cache_slice: torch.Tensor, kv_lora_rank: int, rope_dim: int, dtype: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
    """Dequantize a single fp8_ds_mla cache entry back to latent + rope."""

    # The first kv_lora_rank bytes store FP8 latent values with one scale per
    # 128 element tile written as float32 right after the latent payload.
    scales = cache_slice.view(torch.float32)[kv_lora_rank //
                                             4:kv_lora_rank // 4 + 4]
    latent = torch.empty(kv_lora_rank,
                         dtype=torch.float16,
                         device=cache_slice.device)
    scales = cache_slice.view(torch.float32)[kv_lora_rank // 4 : kv_lora_rank // 4 + 4]
    latent = torch.empty(kv_lora_rank, dtype=torch.float16, device=cache_slice.device)
    for tile_idx in range(4):
        tile_start = tile_idx * 128
        tile_end = tile_start + 128
        ops.convert_fp8(latent[tile_start:tile_end],
                        cache_slice[tile_start:tile_end],
                        float(scales[tile_idx].item()),
                        kv_dtype="fp8")
        ops.convert_fp8(
            latent[tile_start:tile_end],
            cache_slice[tile_start:tile_end],
            float(scales[tile_idx].item()),
            kv_dtype="fp8",
        )
    latent = latent.to(dtype)

    rope_offset = kv_lora_rank // 2 + 8
    rope_vals = cache_slice.view(dtype)[rope_offset:rope_offset + rope_dim]
    rope_vals = cache_slice.view(dtype)[rope_offset : rope_offset + rope_dim]
    return latent, rope_vals.clone()
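# --- Editorial aside (not part of the diff): the byte layout this helper
# assumes can be sanity-checked with simple arithmetic. The concrete sizes
# below (kv_lora_rank=512, rope_dim=64, a 2-byte rope dtype) are illustrative.
kv_lora_rank, rope_dim = 512, 64
entry_bytes = kv_lora_rank + 4 * 4 + 2 * rope_dim  # fp8 latent + 4 fp32 scales + rope
assert entry_bytes == 656
# Viewed as fp32, the 4 scales sit right after the latent payload:
assert (kv_lora_rank // 4, kv_lora_rank // 4 + 4) == (128, 132)
# Viewed as a 2-byte dtype, rope values start after the latent
# (kv_lora_rank // 2 elements) plus the 16 bytes of scales (8 elements):
assert kv_lora_rank // 2 + 8 == 264  # element offset 264 == byte offset 528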


def _quantize_dequantize_fp8_ds_mla(
        kv_c: torch.Tensor, k_pe: torch.Tensor, block_size: int,
        scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    kv_c: torch.Tensor, k_pe: torch.Tensor, block_size: int, scale: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Round-trip kv_c/k_pe through the fp8_ds_mla cache layout."""

    if kv_c.numel() == 0:
@@ -81,21 +90,14 @@ def _quantize_dequantize_fp8_ds_mla(
    num_blocks = max(1, math.ceil(num_tokens / block_size))
    entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim

    tmp_cache = torch.zeros(num_blocks,
                            block_size,
                            entry_size,
                            dtype=torch.uint8,
                            device=kv_c.device)
    slot_mapping = torch.arange(num_tokens,
                                dtype=torch.long,
                                device=kv_c.device)
    tmp_cache = torch.zeros(
        num_blocks, block_size, entry_size, dtype=torch.uint8, device=kv_c.device
    )
    slot_mapping = torch.arange(num_tokens, dtype=torch.long, device=kv_c.device)

    ops.concat_and_cache_mla(kv_c,
                             k_pe,
                             tmp_cache,
                             slot_mapping,
                             kv_cache_dtype="fp8_ds_mla",
                             scale=scale)
    ops.concat_and_cache_mla(
        kv_c, k_pe, tmp_cache, slot_mapping, kv_cache_dtype="fp8_ds_mla", scale=scale
    )

    dequant_kv_c = torch.empty_like(kv_c)
    dequant_k_pe = torch.empty_like(k_pe)
@@ -106,7 +108,8 @@ def _quantize_dequantize_fp8_ds_mla(
        block_offset = slot % block_size
        cache_slice = tmp_cache[block_idx, block_offset]
        latent, rope_vals = _dequantize_fp8_ds_mla_entry(
            cache_slice, kv_lora_rank, rope_dim, kv_c.dtype)
            cache_slice, kv_lora_rank, rope_dim, kv_c.dtype
        )
        dequant_kv_c[token_idx] = latent
        dequant_k_pe[token_idx] = rope_vals
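# --- Editorial aside (not part of the diff): the paged-cache addressing used
# in the loop above. A flat slot id decomposes into a block index and an
# in-block offset; the numbers here are illustrative.
block_size = 64
slot = 130
block_idx, block_offset = slot // block_size, slot % block_size
assert (block_idx, block_offset) == (2, 2)  # token lives in block 2, row 2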

@@ -123,10 +126,9 @@ def test_sparse_backend_metadata_registration():
    dtype_list = backend.get_supported_dtypes()
    assert torch.bfloat16 in dtype_list

    shape = backend.get_kv_cache_shape(num_blocks=2,
                                       block_size=64,
                                       num_kv_heads=1,
                                       head_size=576)
    shape = backend.get_kv_cache_shape(
        num_blocks=2, block_size=64, num_kv_heads=1, head_size=576
    )
    assert shape == (2, 64, 576)


@@ -141,13 +143,10 @@ def test_sparse_decode_metadata_filters_prefill_indices():

    indices = torch.tensor([[0, 3, 5], [1, 2, 4]], dtype=torch.int32)

    context_indices, new_token_indices = metadata.filter_prefill_indices(
        indices)
    context_indices, new_token_indices = metadata.filter_prefill_indices(indices)

    expected_context = torch.tensor([[-1, -1, 5], [-1, -1, 4]],
                                    dtype=torch.int32)
    expected_new_tokens = torch.tensor([[-1, -1, 1], [-1, 0, 2]],
                                       dtype=torch.int32)
    expected_context = torch.tensor([[-1, -1, 5], [-1, -1, 4]], dtype=torch.int32)
    expected_new_tokens = torch.tensor([[-1, -1, 1], [-1, 0, 2]], dtype=torch.int32)

    assert torch.equal(context_indices, expected_context)
    assert torch.equal(new_token_indices, expected_new_tokens)
@@ -162,14 +161,9 @@ def test_sparse_impl_zero_fills_when_metadata_missing():
    kv_cache = torch.zeros((1, 1, 1))
    output = torch.ones((2, 4))

    result = FlashMLASparseImpl.forward(impl,
                                        dummy_layer,
                                        q,
                                        k_c,
                                        k_pe,
                                        kv_cache,
                                        attn_metadata=None,
                                        output=output)
    result = FlashMLASparseImpl.forward(
        impl, dummy_layer, q, k_c, k_pe, kv_cache, attn_metadata=None, output=output
    )

    assert result is output
    assert torch.all(result == 0)
@@ -177,8 +171,7 @@ def test_sparse_impl_zero_fills_when_metadata_missing():

@pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys()))
@pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"])
def test_sparse_backend_decode_correctness(dist_init, batch_name,
                                           kv_cache_dtype):
def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype):
    if not torch.cuda.is_available():
        pytest.skip("CUDA is required for sparse MLA decode test")

@@ -203,14 +196,13 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name,
    vllm_config = create_vllm_config(
        model_name="deepseek-ai/DeepSeek-V2-Lite-Chat",
        max_model_len=max_seqlen,
        num_gpu_blocks=max(2048,
                           cdiv(total_cache_tokens, block_size) + 1),
        block_size=block_size)
        num_gpu_blocks=max(2048, cdiv(total_cache_tokens, block_size) + 1),
        block_size=block_size,
    )
    model_config = vllm_config.model_config
    model_config.hf_config = SimpleNamespace(
        attn_module_list_cfg=[{
            "topk_tokens": topk_tokens
        }])
        attn_module_list_cfg=[{"topk_tokens": topk_tokens}]
    )
    model_config.hf_text_config = SimpleNamespace(
        q_lora_rank=None,
        kv_lora_rank=kv_lora_rank,
@@ -221,13 +213,13 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name,
    )
    model_config.dtype = dtype
    model_config.get_num_attention_heads = MethodType(
        lambda self, parallel_config: num_heads, model_config)
    model_config.get_num_kv_heads = MethodType(lambda self, parallel_config: 1,
                                               model_config)
    model_config.get_head_size = MethodType(lambda self: head_size,
                                            model_config)
    model_config.get_sliding_window = MethodType(lambda self: None,
                                                 model_config)
        lambda self, parallel_config: num_heads, model_config
    )
    model_config.get_num_kv_heads = MethodType(
        lambda self, parallel_config: 1, model_config
    )
    model_config.get_head_size = MethodType(lambda self: head_size, model_config)
    model_config.get_sliding_window = MethodType(lambda self: None, model_config)

    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)

@@ -236,16 +228,10 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name,
    scale = 1.0 / math.sqrt(head_size)

    # Shared MLA projection weights to keep reference and backend in sync
    W_UK = torch.randn(kv_lora_rank,
                       num_heads,
                       qk_nope_head_dim,
                       dtype=dtype,
                       device=device)
    W_UV = torch.randn(kv_lora_rank,
                       num_heads,
                       v_head_dim,
                       dtype=dtype,
                       device=device)
    W_UK = torch.randn(
        kv_lora_rank, num_heads, qk_nope_head_dim, dtype=dtype, device=device
    )
    W_UV = torch.randn(kv_lora_rank, num_heads, v_head_dim, dtype=dtype, device=device)

    # Build synthetic decode-only workload
    seq_lens = batch_spec.seq_lens
@@ -262,17 +248,15 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name,
        q_len = query_lens[i]
        ctx_len = s_len - q_len

        q_c = torch.rand(q_len,
                         num_heads,
                         qk_nope_head_dim + qk_rope_head_dim,
                         dtype=dtype,
                         device=device)
        q_c = torch.rand(
            q_len,
            num_heads,
            qk_nope_head_dim + qk_rope_head_dim,
            dtype=dtype,
            device=device,
        )
        kv_c_full = torch.rand(s_len, kv_lora_rank, dtype=dtype, device=device)
        k_pe_full = torch.rand(s_len,
                               1,
                               qk_rope_head_dim,
                               dtype=dtype,
                               device=device)
        k_pe_full = torch.rand(s_len, 1, qk_rope_head_dim, dtype=dtype, device=device)

        kv_c_full, k_pe_full = _quantize_dequantize_fp8_ds_mla(
            kv_c_full,
@@ -298,7 +282,8 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name,
        v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2)

        sdpa_out = torch.nn.functional.scaled_dot_product_attention(
            q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale)
            q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale
        )
        sdpa_out = sdpa_out.transpose(1, 2).squeeze(0)

        sdpa_out = torch.einsum("qnl,lnv->qnv", sdpa_out, W_UV)
@@ -307,8 +292,8 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name,
        all_q_vllm.append(q_c)
        all_kv_c_vllm.append(kv_c_full[ctx_len:])
        all_k_pe_vllm.append(k_pe_full[ctx_len:])
        kv_c_contexts.append(kv_c_full[:ctx_len + 1])
        k_pe_contexts.append(k_pe_full[:ctx_len + 1])
        kv_c_contexts.append(kv_c_full[: ctx_len + 1])
        k_pe_contexts.append(k_pe_full[: ctx_len + 1])

    query_vllm = torch.cat(all_q_vllm, dim=0)
    kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0)
@@ -321,7 +306,8 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name,
        batch_spec,
        vllm_config.cache_config.block_size,
        device,
        arange_block_indices=True)
        arange_block_indices=True,
    )

    kv_cache = create_and_prepopulate_kv_cache(
        kv_c_contexts=kv_c_contexts,
@@ -339,31 +325,31 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name,

    builder_cls = FlashMLASparseBackend.get_builder_cls()
    builder = builder_cls(kv_cache_spec, ["placeholder"], vllm_config, device)
    metadata = builder.build(common_prefix_len=0,
                             common_attn_metadata=common_attn_metadata)
    metadata = builder.build(
        common_prefix_len=0, common_attn_metadata=common_attn_metadata
    )

    starts = np.asarray(common_attn_metadata.query_start_loc_cpu,
                        dtype=np.int32)
    starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32)
    seg_lengths = np.diff(starts)
    positions = np.arange(starts[-1], dtype=np.int32) - np.repeat(
        starts[:-1], seg_lengths)
        starts[:-1], seg_lengths
    )
    seq_lengths = np.asarray(common_attn_metadata.seq_lens_cpu, dtype=np.int32)
    prefix_lengths = seq_lengths - seg_lengths
    positions += np.repeat(prefix_lengths, seg_lengths)
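# --- Editorial aside (not part of the diff): a worked example of the position
# computation above, with made-up numbers for a two-sequence batch.
import numpy as np

starts = np.array([0, 2, 5], dtype=np.int32)  # query_start_loc for 2 seqs
seg_lengths = np.diff(starts)                 # [2, 3] new tokens per seq
positions = np.arange(starts[-1], dtype=np.int32) - np.repeat(
    starts[:-1], seg_lengths
)                                             # [0, 1, 0, 1, 2] within-query offsets
seq_lengths = np.array([4, 7], dtype=np.int32)  # total tokens per seq
prefix_lengths = seq_lengths - seg_lengths      # [2, 4] cached context lengths
positions += np.repeat(prefix_lengths, seg_lengths)
assert positions.tolist() == [2, 3, 4, 5, 6]    # absolute token positions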

    pos_gpu = torch.as_tensor(positions, device=device, dtype=torch.int32)
    topk = metadata.topk_tokens
    debug_indices = torch.arange(topk, device=device,
                                 dtype=torch.int32).unsqueeze(0)
    debug_indices = torch.arange(topk, device=device, dtype=torch.int32).unsqueeze(0)
    token_positions = pos_gpu.unsqueeze(1)
    causal_mask = (debug_indices <= token_positions)
    debug_indices = torch.where(causal_mask, debug_indices,
                                torch.full_like(debug_indices, -1))
    causal_mask = debug_indices <= token_positions
    debug_indices = torch.where(
        causal_mask, debug_indices, torch.full_like(debug_indices, -1)
    )

    # FlashMLASparseImpl now reads top-k indices from the indexer-provided
    # buffer, so emulate that contract with a simple namespace mock.
    debug_indices = debug_indices.expand(metadata.num_actual_tokens,
                                         -1).clone()
    debug_indices = debug_indices.expand(metadata.num_actual_tokens, -1).clone()
    mock_indexer = SimpleNamespace(topk_indices_buffer=debug_indices)
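# --- Editorial aside (not part of the diff): what one row of the emulated
# top-k index buffer looks like for a tiny case (topk=4, a token at position
# 2). Indices past the token's own position are masked to -1.
import torch

topk, position = 4, 2
idx = torch.arange(topk, dtype=torch.int32)
idx = torch.where(idx <= position, idx, torch.full_like(idx, -1))
assert idx.tolist() == [0, 1, 2, -1]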

    ok, reason = flashmla.is_flashmla_supported()
@@ -372,59 +358,54 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name,

    kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)
    kv_b_proj_weight = kv_b_proj_weight.view(
        kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim))
        kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim)
    )

    mock_kv_b_proj = ColumnParallelLinear(input_size=kv_lora_rank,
                                          output_size=num_heads *
                                          (qk_nope_head_dim + v_head_dim),
                                          bias=False).to(device=device,
                                                         dtype=dtype)
    mock_kv_b_proj = ColumnParallelLinear(
        input_size=kv_lora_rank,
        output_size=num_heads * (qk_nope_head_dim + v_head_dim),
        bias=False,
    ).to(device=device, dtype=dtype)
    mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T.contiguous())

    impl_cls = FlashMLASparseBackend.get_impl_cls()
    impl = impl_cls(num_heads=num_heads,
                    head_size=head_size,
                    scale=scale,
                    num_kv_heads=1,
                    alibi_slopes=None,
                    sliding_window=None,
                    kv_cache_dtype=vllm_config.cache_config.cache_dtype,
                    logits_soft_cap=None,
                    attn_type="decoder",
                    kv_sharing_target_layer_name=None,
                    q_lora_rank=None,
                    kv_lora_rank=kv_lora_rank,
                    qk_nope_head_dim=qk_nope_head_dim,
                    qk_rope_head_dim=qk_rope_head_dim,
                    qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
                    v_head_dim=v_head_dim,
                    kv_b_proj=mock_kv_b_proj,
                    indexer=mock_indexer)
    impl = impl_cls(
        num_heads=num_heads,
        head_size=head_size,
        scale=scale,
        num_kv_heads=1,
        alibi_slopes=None,
        sliding_window=None,
        kv_cache_dtype=vllm_config.cache_config.cache_dtype,
        logits_soft_cap=None,
        attn_type="decoder",
        kv_sharing_target_layer_name=None,
        q_lora_rank=None,
        kv_lora_rank=kv_lora_rank,
        qk_nope_head_dim=qk_nope_head_dim,
        qk_rope_head_dim=qk_rope_head_dim,
        qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
        v_head_dim=v_head_dim,
        kv_b_proj=mock_kv_b_proj,
        indexer=mock_indexer,
    )

    impl.process_weights_after_loading(dtype)

    layer = MockAttentionLayer(device)
    out_buffer = torch.empty(metadata.num_actual_tokens,
                             num_heads * v_head_dim,
                             dtype=dtype,
                             device=device)
    out_buffer = torch.empty(
        metadata.num_actual_tokens, num_heads * v_head_dim, dtype=dtype, device=device
    )

    backend_output = impl.forward(layer,
                                  query_vllm,
                                  kv_c_vllm,
                                  k_pe_vllm,
                                  kv_cache,
                                  metadata,
                                  output=out_buffer)
    backend_output = impl.forward(
        layer, query_vllm, kv_c_vllm, k_pe_vllm, kv_cache, metadata, output=out_buffer
    )

    assert backend_output.shape == sdpa_reference.shape
    assert backend_output.dtype == sdpa_reference.dtype
    assert torch.isfinite(backend_output).all()

    torch.testing.assert_close(backend_output,
                               sdpa_reference,
                               rtol=0.5,
                               atol=0.5)
    torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.5, atol=0.5)


@pytest.mark.parametrize(

@@ -9,9 +9,17 @@ import pytest
import torch

from vllm.attention.backends.registry import _Backend
from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig,
                         LoadConfig, ModelConfig, ModelDType, ParallelConfig,
                         SchedulerConfig, VllmConfig)
from vllm.config import (
    CacheConfig,
    CompilationConfig,
    DeviceConfig,
    LoadConfig,
    ModelConfig,
    ModelDType,
    ParallelConfig,
    SchedulerConfig,
    VllmConfig,
)
from vllm.platforms import current_platform
from vllm.utils import resolve_obj_by_qualname
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
@@ -21,6 +29,7 @@ from vllm.v1.kv_cache_interface import FullAttentionSpec
@dataclass
class BatchSpec:
    """Specification for a batch configuration (workload shape only)."""

    seq_lens: list[int]
    query_lens: list[int]

@@ -38,26 +47,25 @@


def create_common_attn_metadata(
        batch_spec: BatchSpec,
        block_size: int,
        device: torch.device,
        max_block_idx: int = 1000,
        arange_block_indices: bool = False) -> CommonAttentionMetadata:
    batch_spec: BatchSpec,
    block_size: int,
    device: torch.device,
    max_block_idx: int = 1000,
    arange_block_indices: bool = False,
) -> CommonAttentionMetadata:
    """Create CommonAttentionMetadata from a BatchSpec and ModelParams."""
    # Create query start locations
    query_start_loc = torch.zeros(batch_spec.batch_size + 1,
                                  dtype=torch.int32,
                                  device=device)
    query_start_loc[1:] = torch.tensor(batch_spec.query_lens,
                                       dtype=torch.int32,
                                       device=device).cumsum(0)
    query_start_loc = torch.zeros(
        batch_spec.batch_size + 1, dtype=torch.int32, device=device
    )
    query_start_loc[1:] = torch.tensor(
        batch_spec.query_lens, dtype=torch.int32, device=device
    ).cumsum(0)
    query_start_loc_cpu = query_start_loc.cpu()
    num_tokens = batch_spec.compute_num_tokens()
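# --- Editorial aside (not part of the diff): query_start_loc is an exclusive
# prefix sum of the per-sequence query lengths. A small example with made-up
# lengths:
import torch

query_lens = [1, 1, 5]
qsl = torch.zeros(len(query_lens) + 1, dtype=torch.int32)
qsl[1:] = torch.tensor(query_lens, dtype=torch.int32).cumsum(0)
assert qsl.tolist() == [0, 1, 2, 7]  # seq i's tokens live in [qsl[i], qsl[i+1])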

    # Create sequence lengths
    seq_lens = torch.tensor(batch_spec.seq_lens,
                            dtype=torch.int32,
                            device=device)
    seq_lens = torch.tensor(batch_spec.seq_lens, dtype=torch.int32, device=device)
    seq_lens_cpu = seq_lens.cpu()
    max_seq_len = int(seq_lens_cpu.max())

@@ -72,24 +80,23 @@ def create_common_attn_metadata(
    max_blocks = (max(batch_spec.seq_lens) + block_size - 1) // block_size
    if arange_block_indices:
        num_blocks = batch_spec.batch_size * max_blocks
        block_table_tensor = torch.arange(num_blocks,
                                          dtype=torch.int32,
                                          device=device).view(
                                              batch_spec.batch_size,
                                              max_blocks)
        slot_mapping = torch.arange(num_tokens,
                                    dtype=torch.int64,
                                    device=device).view(num_tokens)
        block_table_tensor = torch.arange(
            num_blocks, dtype=torch.int32, device=device
        ).view(batch_spec.batch_size, max_blocks)
        slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device).view(
            num_tokens
        )
    else:
        block_table_tensor = torch.randint(0,
                                           max_block_idx,
                                           (batch_spec.batch_size, max_blocks),
                                           dtype=torch.int32,
                                           device=device)
        slot_mapping = torch.randint(0,
                                     max_block_idx, (num_tokens, ),
                                     dtype=torch.int64,
                                     device=device)
        block_table_tensor = torch.randint(
            0,
            max_block_idx,
            (batch_spec.batch_size, max_blocks),
            dtype=torch.int32,
            device=device,
        )
        slot_mapping = torch.randint(
            0, max_block_idx, (num_tokens,), dtype=torch.int64, device=device
        )

    # Calculate max query length
    max_query_len = max(batch_spec.query_lens)
@@ -121,31 +128,21 @@ def get_attention_backend(backend_name: _Backend):
        Tuple of (backend_builder_class, backend_impl_class)
    """
    backend_map = {
        _Backend.FLASH_ATTN:
        ("vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
         if current_platform.is_cuda() else
         "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
         ),
        _Backend.FLASHINFER:
        "vllm.v1.attention.backends.flashinfer.FlashInferBackend",
        _Backend.FLEX_ATTENTION:
        "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend",
        _Backend.TRITON_ATTN:
        "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend",
        _Backend.TREE_ATTN:
        "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend",
        _Backend.XFORMERS:
        "vllm.v1.attention.backends.xformers.XFormersAttentionBackend",
        _Backend.CUTLASS_MLA:
        "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend",
        _Backend.FLASHMLA:
        "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",
        _Backend.FLASH_ATTN_MLA:
        "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend",
        _Backend.FLASHINFER_MLA:
        "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend",
        _Backend.TRITON_MLA:
        "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend",
        _Backend.FLASH_ATTN: (
            "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
            if current_platform.is_cuda()
            else "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
        ),
        _Backend.FLASHINFER: "vllm.v1.attention.backends.flashinfer.FlashInferBackend",
        _Backend.FLEX_ATTENTION: "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend",
        _Backend.TRITON_ATTN: "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend",
        _Backend.TREE_ATTN: "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend",
        _Backend.XFORMERS: "vllm.v1.attention.backends.xformers.XFormersAttentionBackend",
        _Backend.CUTLASS_MLA: "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend",
        _Backend.FLASHMLA: "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",
        _Backend.FLASH_ATTN_MLA: "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend",
        _Backend.FLASHINFER_MLA: "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend",
        _Backend.TRITON_MLA: "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend",
    }

    if backend_name not in backend_map:
@@ -160,29 +157,31 @@ def get_attention_backend(backend_name: _Backend):
        pytest.skip(f"{backend_name} not available: {e}")
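# --- Editorial aside (not part of the diff): the map above stores dotted
# paths rather than classes, so unavailable backends are only imported on
# demand. A minimal sketch of such lazy resolution (vLLM's actual
# resolve_obj_by_qualname may differ in detail):
import importlib

def _resolve(qualname: str):
    # Split "pkg.module.Attr" into module path and attribute name, import
    # the module, and fetch the attribute from it.
    module_name, _, attr = qualname.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)

assert _resolve("math.sqrt")(4.0) == 2.0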


def create_standard_kv_cache_spec(
        vllm_config: VllmConfig) -> FullAttentionSpec:
def create_standard_kv_cache_spec(vllm_config: VllmConfig) -> FullAttentionSpec:
    """Create a FullAttentionSpec from ModelParams only."""
    return FullAttentionSpec(
        block_size=vllm_config.cache_config.block_size,
        num_kv_heads=vllm_config.model_config.get_num_kv_heads(
            vllm_config.parallel_config),
            vllm_config.parallel_config
        ),
        head_size=vllm_config.model_config.get_head_size(),
        dtype=vllm_config.model_config.dtype,
        sliding_window=vllm_config.model_config.get_sliding_window(),
    )


def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B",
                       tensor_parallel_size: int = 1,
                       max_model_len: int = 1024,
                       dtype: Union[ModelDType, torch.dtype] = "auto",
                       num_gpu_blocks: int = 1000,
                       block_size: int = 16,
                       max_num_seqs: int = 256,
                       max_num_batched_tokens: int = 8192,
                       enable_chunked_prefill: bool = True,
                       add_mock_model_methods: bool = True) -> VllmConfig:
def create_vllm_config(
    model_name: str = "meta-llama/Meta-Llama-3-8B",
    tensor_parallel_size: int = 1,
    max_model_len: int = 1024,
    dtype: Union[ModelDType, torch.dtype] = "auto",
    num_gpu_blocks: int = 1000,
    block_size: int = 16,
    max_num_seqs: int = 256,
    max_num_batched_tokens: int = 8192,
    enable_chunked_prefill: bool = True,
    add_mock_model_methods: bool = True,
) -> VllmConfig:
    """Create a VllmConfig for testing with reasonable defaults."""

    model_config = ModelConfig(
@@ -205,7 +204,8 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B",
    cache_config.num_cpu_blocks = 0

    parallel_config = ParallelConfig(
        tensor_parallel_size=tensor_parallel_size, )
        tensor_parallel_size=tensor_parallel_size,
    )

    scheduler_config = SchedulerConfig(
        max_num_seqs=max_num_seqs,
@@ -223,15 +223,17 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B",
        # but some backends expect to query the model for layer-specific
        # parameters
        import types
        model_config.get_num_layers = types.MethodType(lambda self: 1,
                                                       model_config)

        model_config.get_num_layers = types.MethodType(lambda self: 1, model_config)
        model_config.get_sliding_window_for_layer = types.MethodType(
            lambda self, i: None, model_config)
            lambda self, i: None, model_config
        )
        model_config.get_logits_soft_cap_for_layer = types.MethodType(
            lambda self, i: 0.0, model_config)
            lambda self, i: 0.0, model_config
        )
        model_config.get_sm_scale_for_layer = types.MethodType(
            lambda self, i: 1.0 / model_config.get_head_size()**0.5,
            model_config)
            lambda self, i: 1.0 / model_config.get_head_size() ** 0.5, model_config
        )

    return VllmConfig(
        model_config=model_config,
@@ -244,12 +246,14 @@ def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B",
    )


def create_dummy_kv_cache(block_size: int,
                          num_kv_heads: int,
                          head_size: int,
                          dtype: torch.dtype,
                          device: torch.device,
                          num_blocks: int = 100) -> torch.Tensor:
def create_dummy_kv_cache(
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    dtype: torch.dtype,
    device: torch.device,
    num_blocks: int = 100,
) -> torch.Tensor:
    """Create a dummy KV cache tensor for testing."""
    kv_cache = torch.randn(
        num_blocks,
@@ -258,7 +262,8 @@ def create_dummy_kv_cache(block_size: int,
        num_kv_heads,
        head_size,
        dtype=dtype,
        device=device)
        device=device,
    )
    return kv_cache

@@ -273,75 +278,80 @@ class BackendConfig:
# Define all backend configurations of full cudagraph to be tested
full_cg_backend_configs = {
    # FA3 on Hopper
    "FA3":
    BackendConfig(name="FA3",
                  env_vars={
                      "VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
                      "VLLM_FLASH_ATTN_VERSION": "3",
                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
                  },
                  comp_config={
                      "cudagraph_mode": "FULL",
                  },
                  specific_gpu_arch=(9, 0)),
    # FlashMLA on Hopper
    "FlashMLA":
    BackendConfig(name="FlashMLA",
                  env_vars={
                      "VLLM_ATTENTION_BACKEND": "FLASHMLA",
                  },
                  comp_config={
                      "cudagraph_mode": "FULL_AND_PIECEWISE",
                  },
                  specific_gpu_arch=(9, 0)),
    # Cutlass MLA on Blackwell
    "CutlassMLA":
    BackendConfig(
        name="CutlassMLA",
    "FA3": BackendConfig(
        name="FA3",
        env_vars={
            "VLLM_USE_V1": "1",
            "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
            "FORCE_NUM_KV_SPLITS":
            "1",  # TODO: remove this when hang issue is fixed
            "VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
            "VLLM_FLASH_ATTN_VERSION": "3",
            "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
        },
        comp_config={
            "cudagraph_mode": "FULL",
        },
        specific_gpu_arch=(9, 0),
    ),
    # FlashMLA on Hopper
    "FlashMLA": BackendConfig(
        name="FlashMLA",
        env_vars={
            "VLLM_ATTENTION_BACKEND": "FLASHMLA",
        },
        comp_config={
            "cudagraph_mode": "FULL_AND_PIECEWISE",
        },
        specific_gpu_arch=(10, 0)),
        specific_gpu_arch=(9, 0),
    ),
    # Cutlass MLA on Blackwell
    "CutlassMLA": BackendConfig(
        name="CutlassMLA",
        env_vars={
            "VLLM_USE_V1": "1",
            "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
            "FORCE_NUM_KV_SPLITS": "1",  # TODO: remove this when hang issue is fixed
        },
        comp_config={
            "cudagraph_mode": "FULL_AND_PIECEWISE",
        },
        specific_gpu_arch=(10, 0),
    ),
    # FlashAttention MLA on Hopper
    "FlashAttentionMLA":
    BackendConfig(name="FlashAttentionMLA",
                  env_vars={
                      "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
                  },
                  comp_config={
                      "cudagraph_mode": "FULL_DECODE_ONLY",
                  },
                  specific_gpu_arch=(9, 0)),
    "FlashAttentionMLA": BackendConfig(
        name="FlashAttentionMLA",
        env_vars={
            "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
            "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
        },
        comp_config={
            "cudagraph_mode": "FULL_DECODE_ONLY",
        },
        specific_gpu_arch=(9, 0),
    ),
    # FA2
    "FA2":
    BackendConfig(name="FA2",
                  env_vars={
                      "VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
                      "VLLM_FLASH_ATTN_VERSION": "2",
                      "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
                  },
                  comp_config={
                      "cudagraph_mode": "FULL_AND_PIECEWISE",
                  }),
    "FA2": BackendConfig(
        name="FA2",
        env_vars={
            "VLLM_ATTENTION_BACKEND": "FLASH_ATTN",
            "VLLM_FLASH_ATTN_VERSION": "2",
            "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
        },
        comp_config={
            "cudagraph_mode": "FULL_AND_PIECEWISE",
        },
    ),
    # Triton Attention
    "TritonAttn":
    BackendConfig(name="TritonAttn",
                  env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
                  comp_config={
                      "cudagraph_mode": "FULL_AND_PIECEWISE",
                  }),
    "TritonAttn": BackendConfig(
        name="TritonAttn",
        env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
        comp_config={
            "cudagraph_mode": "FULL_AND_PIECEWISE",
        },
    ),
    # FlashInfer
    "FlashInfer":
    BackendConfig(name="FlashInfer",
                  env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
                  comp_config={
                      "cudagraph_mode": "FULL_AND_PIECEWISE",
                  }),
    "FlashInfer": BackendConfig(
        name="FlashInfer",
        env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
        comp_config={
            "cudagraph_mode": "FULL_AND_PIECEWISE",
        },
    ),
}
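# --- Editorial aside (not part of the diff): a sketch of how a test harness
# might consume these configs. The helper below is illustrative, not vLLM's
# actual runner; it only relies on the BackendConfig fields used above.
import os

def apply_backend_config(cfg) -> None:
    # Export the backend-selecting environment variables; a test can then
    # build an engine, passing cfg.comp_config as its compilation config and
    # skipping when the GPU does not match cfg.specific_gpu_arch (if set).
    for key, value in cfg.env_vars.items():
        os.environ[key] = value

# e.g. apply_backend_config(full_cg_backend_configs["FA2"])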