Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -3,8 +3,7 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.utils import (create_kv_caches_with_random,
|
||||
create_kv_caches_with_random_flash)
|
||||
from vllm.utils import create_kv_caches_with_random, create_kv_caches_with_random_flash
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
|
||||
@@ -39,7 +39,7 @@ def ref_paged_attn(
|
||||
for i in range(num_seqs):
|
||||
query_len = query_lens[i]
|
||||
kv_len = kv_lens[i]
|
||||
q = query[start_idx:start_idx + query_len]
|
||||
q = query[start_idx : start_idx + query_len]
|
||||
q *= scale
|
||||
|
||||
num_kv_blocks = (kv_len + block_size - 1) // block_size
|
||||
@@ -57,10 +57,13 @@ def ref_paged_attn(
|
||||
empty_mask = torch.ones(query_len, kv_len)
|
||||
mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
|
||||
if sliding_window is not None:
|
||||
sliding_window_mask = torch.triu(empty_mask,
|
||||
diagonal=kv_len -
|
||||
(query_len + sliding_window) +
|
||||
1).bool().logical_not()
|
||||
sliding_window_mask = (
|
||||
torch.triu(
|
||||
empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1
|
||||
)
|
||||
.bool()
|
||||
.logical_not()
|
||||
)
|
||||
mask |= sliding_window_mask
|
||||
if soft_cap is not None:
|
||||
attn = soft_cap * torch.tanh(attn / soft_cap)
|
||||
@@ -74,11 +77,10 @@ def ref_paged_attn(
|
||||
return torch.cat(outputs, dim=0)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_rocm(),
|
||||
reason="Only ROCm is supported")
|
||||
@pytest.mark.parametrize("seq_lens",
|
||||
[[(10, 1328), (5, 18),
|
||||
(129, 463)], [(8, 523), (24, 37), (3, 2011)]])
|
||||
@pytest.mark.skipif(not current_platform.is_rocm(), reason="Only ROCm is supported")
|
||||
@pytest.mark.parametrize(
|
||||
"seq_lens", [[(10, 1328), (5, 18), (129, 463)], [(8, 523), (24, 37), (3, 2011)]]
|
||||
)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@@ -109,34 +111,27 @@ def test_varlen_with_paged_kv(
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
max_query_len = max(query_lens)
|
||||
max_kv_len = max(kv_lens)
|
||||
window_size = ((sliding_window - 1, 0) if sliding_window is not None else
|
||||
(-1, -1))
|
||||
window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1)
|
||||
scale = head_size**-0.5
|
||||
|
||||
query = torch.randn(sum(query_lens),
|
||||
num_query_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
key_cache = torch.randn(num_blocks,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype)
|
||||
key_cache = torch.randn(
|
||||
num_blocks, block_size, num_kv_heads, head_size, dtype=dtype
|
||||
)
|
||||
value_cache = torch.randn_like(key_cache)
|
||||
cu_query_lens = torch.tensor([0] + query_lens,
|
||||
dtype=torch.int32).cumsum(dim=0,
|
||||
dtype=torch.int32)
|
||||
cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum(
|
||||
dim=0, dtype=torch.int32
|
||||
)
|
||||
|
||||
cu_seq_lens = torch.tensor([0] + kv_lens,
|
||||
dtype=torch.int32).cumsum(dim=0,
|
||||
dtype=torch.int32)
|
||||
cu_seq_lens = torch.tensor([0] + kv_lens, dtype=torch.int32).cumsum(
|
||||
dim=0, dtype=torch.int32
|
||||
)
|
||||
kv_lens = torch.tensor(kv_lens, dtype=torch.int32)
|
||||
|
||||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(0,
|
||||
num_blocks,
|
||||
(num_seqs, max_num_blocks_per_seq),
|
||||
dtype=torch.int32)
|
||||
block_tables = torch.randint(
|
||||
0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
|
||||
output = torch.empty_like(query)
|
||||
|
||||
@@ -187,5 +182,7 @@ def test_varlen_with_paged_kv(
|
||||
atol, rtol = 2e-2, 2e-2
|
||||
if q_dtype is not None:
|
||||
atol, rtol = 1.5e-1, 1.5e-1
|
||||
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
|
||||
f"{torch.max(torch.abs(output - ref_output))}"
|
||||
(
|
||||
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol),
|
||||
f"{torch.max(torch.abs(output - ref_output))}",
|
||||
)
|
||||
|
||||
@@ -42,9 +42,7 @@ BLOCK_SIZES = [16, 32]
|
||||
USE_ALIBI = [False, True]
|
||||
KV_CACHE_DTYPE = ["auto", "fp8"]
|
||||
SEEDS = [0]
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
|
||||
|
||||
def ref_masked_attention(
|
||||
@@ -110,8 +108,7 @@ def ref_single_query_cached_kv_attention(
|
||||
# Create the ALiBi bias used in the paged attention kernel.
|
||||
position_ids = torch.arange(seq_len).int()
|
||||
alibi_bias = (position_ids - seq_len + 1).float()
|
||||
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
|
||||
1, 1, -1)
|
||||
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(1, 1, -1)
|
||||
|
||||
out = ref_masked_attention(q, keys, values, scale, alibi_bias)
|
||||
out = out.view(num_query_heads, head_size)
|
||||
@@ -119,8 +116,8 @@ def ref_single_query_cached_kv_attention(
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"version",
|
||||
["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"])
|
||||
"version", ["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"]
|
||||
)
|
||||
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@@ -143,13 +140,18 @@ def test_paged_attention(
|
||||
seed: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
if ((kv_cache_dtype == "fp8" and head_size % 16)
|
||||
or (version == "rocm" and head_size not in (64, 128))):
|
||||
if (kv_cache_dtype == "fp8" and head_size % 16) or (
|
||||
version == "rocm" and head_size not in (64, 128)
|
||||
):
|
||||
pytest.skip()
|
||||
|
||||
if (version == "rocm" and current_platform.is_navi()
|
||||
and (kv_cache_dtype == "fp8" or head_size != 128
|
||||
or block_size != 16 or use_alibi)):
|
||||
if (
|
||||
version == "rocm"
|
||||
and current_platform.is_navi()
|
||||
and (
|
||||
kv_cache_dtype == "fp8" or head_size != 128 or block_size != 16 or use_alibi
|
||||
)
|
||||
):
|
||||
pytest.skip()
|
||||
|
||||
global PARTITION_SIZE
|
||||
@@ -177,18 +179,24 @@ def test_paged_attention(
|
||||
block_tables_lst: list[list[int]] = []
|
||||
for _ in range(num_seqs):
|
||||
block_table = [
|
||||
random.randint(0, NUM_BLOCKS - 1)
|
||||
for _ in range(max_num_blocks_per_seq)
|
||||
random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)
|
||||
]
|
||||
block_tables_lst.append(block_table)
|
||||
|
||||
block_tables = torch.tensor(block_tables_lst, dtype=torch.int)
|
||||
|
||||
# Create the KV caches.
|
||||
key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
|
||||
num_kv_heads, head_size,
|
||||
kv_cache_dtype, dtype, seed,
|
||||
device)
|
||||
key_caches, value_caches = kv_cache_factory(
|
||||
NUM_BLOCKS,
|
||||
block_size,
|
||||
1,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
kv_cache_dtype,
|
||||
dtype,
|
||||
seed,
|
||||
device,
|
||||
)
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
|
||||
# Using default kv_scale
|
||||
@@ -214,18 +222,37 @@ def test_paged_attention(
|
||||
v_scale,
|
||||
)
|
||||
|
||||
opcheck(torch.ops._C.paged_attention_v1,
|
||||
(output, query, key_cache, value_cache, num_kv_heads, scale,
|
||||
block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
|
||||
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
|
||||
cond=(head_size == HEAD_SIZES[0]
|
||||
and block_size == BLOCK_SIZES[0]))
|
||||
opcheck(
|
||||
torch.ops._C.paged_attention_v1,
|
||||
(
|
||||
output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
64,
|
||||
0,
|
||||
),
|
||||
cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
|
||||
)
|
||||
|
||||
elif version in ("v2", "rocm"):
|
||||
if current_platform.is_rocm() and version == "rocm":
|
||||
PARTITION_SIZE = PARTITION_SIZE_ROCM
|
||||
|
||||
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
|
||||
num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
|
||||
assert PARTITION_SIZE % block_size == 0
|
||||
num_seqs, num_heads, head_size = output.shape
|
||||
tmp_output = torch.empty(
|
||||
@@ -258,13 +285,34 @@ def test_paged_attention(
|
||||
v_scale,
|
||||
)
|
||||
|
||||
opcheck(torch.ops._C.paged_attention_v2,
|
||||
(output, exp_sums, max_logits, tmp_output, query,
|
||||
key_cache, value_cache, num_kv_heads, scale, block_tables,
|
||||
seq_lens, block_size, max_seq_len, alibi_slopes,
|
||||
kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
|
||||
cond=(head_size == HEAD_SIZES[0]
|
||||
and block_size == BLOCK_SIZES[0]))
|
||||
opcheck(
|
||||
torch.ops._C.paged_attention_v2,
|
||||
(
|
||||
output,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
64,
|
||||
0,
|
||||
),
|
||||
cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
|
||||
)
|
||||
|
||||
else:
|
||||
ops.paged_attention_rocm(
|
||||
@@ -288,13 +336,30 @@ def test_paged_attention(
|
||||
v_scale,
|
||||
)
|
||||
|
||||
opcheck(torch.ops._rocm_C.paged_attention,
|
||||
(output, exp_sums, max_logits, tmp_output, query,
|
||||
key_cache, value_cache, num_kv_heads, scale, block_tables,
|
||||
seq_lens, None, block_size, max_seq_len, alibi_slopes,
|
||||
kv_cache_dtype, k_scale, v_scale),
|
||||
cond=(head_size == HEAD_SIZES[0]
|
||||
and block_size == BLOCK_SIZES[0]))
|
||||
opcheck(
|
||||
torch.ops._rocm_C.paged_attention,
|
||||
(
|
||||
output,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
seq_lens,
|
||||
None,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
),
|
||||
cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0]),
|
||||
)
|
||||
|
||||
else:
|
||||
raise AssertionError(f"Unknown version: {version}")
|
||||
@@ -303,18 +368,17 @@ def test_paged_attention(
|
||||
if kv_cache_dtype == "fp8":
|
||||
# Convert cache data back to dtype.
|
||||
x = 16 // torch.tensor([], dtype=dtype).element_size()
|
||||
key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
|
||||
block_size, x)
|
||||
dequantized_key_cache = torch.empty(size=key_cache_shape,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
|
||||
dequantized_key_cache = torch.empty(
|
||||
size=key_cache_shape, dtype=dtype, device=device
|
||||
)
|
||||
ops.convert_fp8(dequantized_key_cache, key_cache)
|
||||
key_cache = dequantized_key_cache
|
||||
|
||||
value_cache_shape = value_cache.shape
|
||||
dequantized_value_cache = torch.empty(size=value_cache_shape,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
dequantized_value_cache = torch.empty(
|
||||
size=value_cache_shape, dtype=dtype, device=device
|
||||
)
|
||||
ops.convert_fp8(dequantized_value_cache, value_cache)
|
||||
value_cache = dequantized_value_cache
|
||||
|
||||
@@ -367,8 +431,9 @@ def ref_multi_query_kv_attention(
|
||||
if alibi_bias:
|
||||
attn_mask = alibi_bias[i]
|
||||
else:
|
||||
attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
|
||||
diagonal=1)
|
||||
attn_mask = torch.triu(
|
||||
torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1
|
||||
)
|
||||
attn_mask = attn_mask * torch.finfo(dtype).min
|
||||
attn_mask = attn_mask.to(dtype=dtype)
|
||||
|
||||
@@ -390,8 +455,9 @@ def ref_multi_query_kv_attention(
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.skipif(current_platform.is_rocm(),
|
||||
reason="Xformers backend is not supported on ROCm.")
|
||||
@pytest.mark.skipif(
|
||||
current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm."
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_multi_query_kv_attention(
|
||||
num_seqs: int,
|
||||
@@ -413,13 +479,11 @@ def test_multi_query_kv_attention(
|
||||
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
num_query_heads, num_kv_heads = num_heads
|
||||
qkv = torch.empty(num_tokens,
|
||||
num_query_heads + 2 * num_kv_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
qkv = torch.empty(
|
||||
num_tokens, num_query_heads + 2 * num_kv_heads, head_size, dtype=dtype
|
||||
)
|
||||
qkv.uniform_(-scale, scale)
|
||||
query, key, value = qkv.split(
|
||||
[num_query_heads, num_kv_heads, num_kv_heads], dim=1)
|
||||
query, key, value = qkv.split([num_query_heads, num_kv_heads, num_kv_heads], dim=1)
|
||||
|
||||
num_queries_per_kv = num_query_heads // num_kv_heads
|
||||
if num_queries_per_kv > 1:
|
||||
@@ -429,8 +493,7 @@ def test_multi_query_kv_attention(
|
||||
alibi_bias = None
|
||||
if use_alibi:
|
||||
alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
|
||||
attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype,
|
||||
seq_lens)
|
||||
attn_bias = make_alibi_bias(alibi_slopes, num_kv_heads, dtype, seq_lens)
|
||||
output = torch.empty_like(query)
|
||||
start = 0
|
||||
# Dynamic sequence length not supported with custom attn_bias.
|
||||
@@ -442,7 +505,8 @@ def test_multi_query_kv_attention(
|
||||
value[None, start:end],
|
||||
attn_bias=attn_bias[i],
|
||||
p=0.0,
|
||||
scale=scale)
|
||||
scale=scale,
|
||||
)
|
||||
output[start:end].copy_(out.view_as(query[start:end]))
|
||||
start += seq_len
|
||||
# xformers.AttentionBias to Tensor for use in reference impl.
|
||||
@@ -485,8 +549,9 @@ def test_multi_query_kv_attention(
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.skipif(current_platform.is_rocm(),
|
||||
reason="Xformers backend is not supported on ROCm.")
|
||||
@pytest.mark.skipif(
|
||||
current_platform.is_rocm(), reason="Xformers backend is not supported on ROCm."
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_multi_query_kv_attention_with_alibi(
|
||||
num_seqs: int,
|
||||
|
||||
@@ -15,16 +15,18 @@ from vllm.utils import STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, STR_INVALID_VAL
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_cache():
|
||||
"""Clear lru cache to ensure each test case runs without caching.
|
||||
"""
|
||||
"""Clear lru cache to ensure each test case runs without caching."""
|
||||
_cached_get_attn_backend.cache_clear()
|
||||
|
||||
|
||||
# Define MLA and non-MLA backends separately
|
||||
DEVICE_MLA_BACKENDS = {
|
||||
"cuda": [
|
||||
"TRITON_MLA", "FLASHMLA", "FLASHINFER_MLA", "FLASH_ATTN_MLA",
|
||||
"CUTLASS_MLA"
|
||||
"TRITON_MLA",
|
||||
"FLASHMLA",
|
||||
"FLASHINFER_MLA",
|
||||
"FLASH_ATTN_MLA",
|
||||
"CUTLASS_MLA",
|
||||
],
|
||||
"hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
|
||||
"cpu": [],
|
||||
@@ -40,7 +42,7 @@ DEVICE_MLA_BLOCK_SIZES = {
|
||||
"cuda": [16, 64], # CUDA supports both standard and extended block sizes
|
||||
"hip": [16, 1], # HIP requires special handling for block_size=1
|
||||
# "cpu": [16] # CPU uses fixed block size from test cases
|
||||
"cpu": [] # FIXME(woosuk): Temporarily disable CPU tests
|
||||
"cpu": [], # FIXME(woosuk): Temporarily disable CPU tests
|
||||
}
|
||||
|
||||
|
||||
@@ -48,12 +50,13 @@ def generate_params():
|
||||
params = []
|
||||
for use_mla in [True, False]:
|
||||
for device in ["cuda", "hip", "cpu"]:
|
||||
backends = DEVICE_MLA_BACKENDS[
|
||||
device] if use_mla else DEVICE_REGULAR_ATTN_BACKENDS[device]
|
||||
backends = (
|
||||
DEVICE_MLA_BACKENDS[device]
|
||||
if use_mla
|
||||
else DEVICE_REGULAR_ATTN_BACKENDS[device]
|
||||
)
|
||||
for name in backends:
|
||||
block_sizes = DEVICE_MLA_BLOCK_SIZES[device] if use_mla else [
|
||||
16
|
||||
]
|
||||
block_sizes = DEVICE_MLA_BLOCK_SIZES[device] if use_mla else [16]
|
||||
for block_size in block_sizes:
|
||||
params.append(
|
||||
pytest.param(
|
||||
@@ -61,14 +64,13 @@ def generate_params():
|
||||
name,
|
||||
use_mla,
|
||||
block_size,
|
||||
id=
|
||||
f"{device}_{name}_mla_{str(use_mla)[0]}_blks{block_size}"
|
||||
))
|
||||
id=f"{device}_{name}_mla_{str(use_mla)[0]}_blks{block_size}",
|
||||
)
|
||||
)
|
||||
return params
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device, name, use_mla, block_size",
|
||||
generate_params())
|
||||
@pytest.mark.parametrize("device, name, use_mla, block_size", generate_params())
|
||||
def test_env(
|
||||
device: str,
|
||||
name: str,
|
||||
@@ -83,14 +85,12 @@ def test_env(
|
||||
m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")
|
||||
|
||||
if device == "cpu":
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
CpuPlatform()):
|
||||
with patch("vllm.attention.selector.current_platform", CpuPlatform()):
|
||||
backend = get_attn_backend(16, torch.float16, None, block_size)
|
||||
assert backend.get_name() == "TORCH_SDPA"
|
||||
|
||||
elif device == "hip":
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
RocmPlatform()):
|
||||
with patch("vllm.attention.selector.current_platform", RocmPlatform()):
|
||||
if use_mla:
|
||||
# ROCm MLA backend logic:
|
||||
# - TRITON_MLA: supported when block_size != 1
|
||||
@@ -101,44 +101,33 @@ def test_env(
|
||||
if name == "TRITON_MLA" and block_size == 1:
|
||||
# TRITON_MLA doesn't support block_size == 1
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
get_attn_backend(16,
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
use_mla=use_mla)
|
||||
assert f"The selected backend, {name}" in str(
|
||||
exc_info.value)
|
||||
get_attn_backend(
|
||||
16, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
assert f"The selected backend, {name}" in str(exc_info.value)
|
||||
elif name == "ROCM_AITER_MLA" and block_size != 1:
|
||||
# ROCM_AITER_MLA only supports block_size == 1
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
get_attn_backend(16,
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
use_mla=use_mla)
|
||||
assert f"The selected backend, {name}" in str(
|
||||
exc_info.value)
|
||||
get_attn_backend(
|
||||
16, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
assert f"The selected backend, {name}" in str(exc_info.value)
|
||||
else:
|
||||
# Valid backend-block_size combination
|
||||
backend = get_attn_backend(16,
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
use_mla=use_mla)
|
||||
backend = get_attn_backend(
|
||||
16, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = name
|
||||
assert backend.get_name() == expected
|
||||
else:
|
||||
backend = get_attn_backend(16,
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
use_mla=use_mla)
|
||||
backend = get_attn_backend(
|
||||
16, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = "TRITON_ATTN"
|
||||
assert backend.get_name() == expected
|
||||
|
||||
elif device == "cuda":
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
CudaPlatform()):
|
||||
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
|
||||
if use_mla:
|
||||
# CUDA MLA backend logic:
|
||||
# - CUTLASS_MLA: only supported with block_size == 128
|
||||
@@ -152,28 +141,23 @@ def test_env(
|
||||
if name == "CUTLASS_MLA":
|
||||
if block_size != 128:
|
||||
# CUTLASS_MLA only supports block_size == 128
|
||||
pytest.skip(
|
||||
"CUTLASS_MLA only supports block_size 128")
|
||||
pytest.skip("CUTLASS_MLA only supports block_size 128")
|
||||
else:
|
||||
backend = get_attn_backend(16,
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
use_mla=use_mla)
|
||||
backend = get_attn_backend(
|
||||
16, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = "CUTLASS_MLA"
|
||||
assert backend.get_name() == expected
|
||||
elif name == "FLASHINFER_MLA":
|
||||
if block_size not in [32, 64]:
|
||||
# FlashInfer MLA only supports block_size 32 or 64
|
||||
pytest.skip(
|
||||
"FlashInfer MLA only supports block_size 32 "
|
||||
"or 64")
|
||||
"FlashInfer MLA only supports block_size 32 or 64"
|
||||
)
|
||||
else:
|
||||
backend = get_attn_backend(16,
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
use_mla=use_mla)
|
||||
backend = get_attn_backend(
|
||||
16, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = "FLASHINFER_MLA"
|
||||
assert backend.get_name() == expected
|
||||
elif name == "FLASHMLA":
|
||||
@@ -182,58 +166,47 @@ def test_env(
|
||||
pytest.skip("FlashMLA only supports block_size 64")
|
||||
else:
|
||||
from vllm.v1.attention.backends.mla.flashmla import ( # noqa: E501
|
||||
is_flashmla_supported)
|
||||
is_flashmla_supported,
|
||||
)
|
||||
|
||||
is_supported, _ = is_flashmla_supported()
|
||||
if not is_supported:
|
||||
pytest.skip(
|
||||
"FlashMLA not supported on this platform")
|
||||
pytest.skip("FlashMLA not supported on this platform")
|
||||
else:
|
||||
backend = get_attn_backend(16,
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
use_mla=use_mla)
|
||||
backend = get_attn_backend(
|
||||
16, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = name
|
||||
assert backend.get_name() == expected
|
||||
elif name == "FLASH_ATTN_MLA":
|
||||
backend = get_attn_backend(16,
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
use_mla=use_mla)
|
||||
backend = get_attn_backend(
|
||||
16, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = "FLASH_ATTN_MLA"
|
||||
assert backend.get_name() == expected
|
||||
else:
|
||||
# TRITON_MLA or other fallback
|
||||
backend = get_attn_backend(16,
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
use_mla=use_mla)
|
||||
backend = get_attn_backend(
|
||||
16, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = "TRITON_MLA"
|
||||
assert backend.get_name() == expected
|
||||
elif name == "FLASHINFER":
|
||||
backend = get_attn_backend(16,
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
use_mla=use_mla)
|
||||
backend = get_attn_backend(
|
||||
16, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = "FLASHINFER"
|
||||
assert backend.get_name() == expected
|
||||
elif name == "XFORMERS":
|
||||
backend = get_attn_backend(32,
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
use_mla=use_mla)
|
||||
backend = get_attn_backend(
|
||||
32, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = "XFORMERS"
|
||||
assert backend.get_name() == expected
|
||||
elif name == "FLASH_ATTN":
|
||||
backend = get_attn_backend(32,
|
||||
torch.float16,
|
||||
None,
|
||||
block_size,
|
||||
use_mla=use_mla)
|
||||
backend = get_attn_backend(
|
||||
32, torch.float16, None, block_size, use_mla=use_mla
|
||||
)
|
||||
expected = "FLASH_ATTN"
|
||||
assert backend.get_name() == expected
|
||||
|
||||
@@ -248,14 +221,12 @@ def test_fp32_fallback(
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
if device == "cpu":
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
CpuPlatform()):
|
||||
with patch("vllm.attention.selector.current_platform", CpuPlatform()):
|
||||
backend = get_attn_backend(16, torch.float32, None, 16)
|
||||
assert backend.get_name() == "TORCH_SDPA"
|
||||
|
||||
elif device == "cuda":
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
CudaPlatform()):
|
||||
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
|
||||
backend = get_attn_backend(16, torch.float32, None, 16)
|
||||
assert backend.get_name() == "FLEX_ATTENTION"
|
||||
|
||||
@@ -265,16 +236,16 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
|
||||
# TODO: When testing for v1, pipe in `use_v1` as an argument to
|
||||
# get_attn_backend
|
||||
|
||||
pytest.skip("Skipping as current backend selector does not " \
|
||||
"handle fallbacks when a backend is set via env var.")
|
||||
pytest.skip(
|
||||
"Skipping as current backend selector does not "
|
||||
"handle fallbacks when a backend is set via env var."
|
||||
)
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL)
|
||||
|
||||
# Unsupported CUDA arch
|
||||
monkeypatch.setattr(torch.cuda,
|
||||
"get_device_capability",
|
||||
lambda _=None: (7, 5))
|
||||
monkeypatch.setattr(torch.cuda, "get_device_capability", lambda _=None: (7, 5))
|
||||
backend = get_attn_backend(16, torch.float16, None, 16)
|
||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||
|
||||
@@ -295,17 +266,17 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
# flash-attn is not installed
|
||||
import sys
|
||||
original_module = sys.modules.get('vllm_flash_attn')
|
||||
monkeypatch.setitem(sys.modules, 'vllm_flash_attn', None)
|
||||
|
||||
original_module = sys.modules.get("vllm_flash_attn")
|
||||
monkeypatch.setitem(sys.modules, "vllm_flash_attn", None)
|
||||
backend = get_attn_backend(16, torch.float16, None, 16)
|
||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||
|
||||
# Restore the original module if it existed
|
||||
if original_module is not None:
|
||||
monkeypatch.setitem(sys.modules, 'vllm_flash_attn',
|
||||
original_module)
|
||||
monkeypatch.setitem(sys.modules, "vllm_flash_attn", original_module)
|
||||
else:
|
||||
monkeypatch.delitem(sys.modules, 'vllm_flash_attn', raising=False)
|
||||
monkeypatch.delitem(sys.modules, "vllm_flash_attn", raising=False)
|
||||
|
||||
# Unsupported head size
|
||||
backend = get_attn_backend(17, torch.float16, None, 16)
|
||||
@@ -314,8 +285,10 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
def test_invalid_env(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test that invalid attention backend names raise ValueError."""
|
||||
with monkeypatch.context() as m, patch(
|
||||
"vllm.attention.selector.current_platform", CudaPlatform()):
|
||||
with (
|
||||
monkeypatch.context() as m,
|
||||
patch("vllm.attention.selector.current_platform", CudaPlatform()),
|
||||
):
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
|
||||
COPYING_DIRECTION = [("cuda", "cpu"), ("cuda", "cuda"), ("cpu", "cuda")]
|
||||
DTYPES = [torch.bfloat16, torch.float]
|
||||
NUM_TOKENS = [42] # Arbitrary values for testing
|
||||
NUM_LAYERS = [1] # Arbitrary values for testing
|
||||
@@ -32,9 +32,7 @@ NUM_BLOCKS = [1024, 10000]
|
||||
|
||||
NUM_MAPPINGS = [256] # Arbitrary values for testing
|
||||
SEEDS = [0]
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
|
||||
# We assume fp8 is always enabled for testing.
|
||||
KV_CACHE_DTYPE = ["auto", "fp8"]
|
||||
@@ -85,24 +83,33 @@ def test_copy_blocks(
|
||||
block_mapping.append((src, dst2))
|
||||
|
||||
# Create the KV caches.
|
||||
key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
|
||||
num_layers, num_heads,
|
||||
head_size, kv_cache_dtype,
|
||||
dtype, seed, device)
|
||||
key_caches, value_caches = kv_cache_factory(
|
||||
num_blocks,
|
||||
block_size,
|
||||
num_layers,
|
||||
num_heads,
|
||||
head_size,
|
||||
kv_cache_dtype,
|
||||
dtype,
|
||||
seed,
|
||||
device,
|
||||
)
|
||||
|
||||
# Clone the KV caches.
|
||||
cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
|
||||
cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
|
||||
|
||||
# Call the copy blocks kernel.
|
||||
block_mapping_tensor = torch.tensor(block_mapping,
|
||||
dtype=torch.int64,
|
||||
device=device).view(-1, 2)
|
||||
block_mapping_tensor = torch.tensor(
|
||||
block_mapping, dtype=torch.int64, device=device
|
||||
).view(-1, 2)
|
||||
|
||||
opcheck(torch.ops._C_cache_ops.copy_blocks,
|
||||
(key_caches, value_caches, block_mapping_tensor),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
|
||||
cond=(head_size == HEAD_SIZES[0]))
|
||||
opcheck(
|
||||
torch.ops._C_cache_ops.copy_blocks,
|
||||
(key_caches, value_caches, block_mapping_tensor),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
|
||||
cond=(head_size == HEAD_SIZES[0]),
|
||||
)
|
||||
ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)
|
||||
|
||||
# Run the reference implementation.
|
||||
@@ -115,8 +122,7 @@ def test_copy_blocks(
|
||||
# Compare the results.
|
||||
for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
|
||||
torch.testing.assert_close(key_cache, cloned_key_cache)
|
||||
for value_cache, cloned_value_cache in zip(value_caches,
|
||||
cloned_value_caches):
|
||||
for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
|
||||
torch.testing.assert_close(value_cache, cloned_value_cache)
|
||||
|
||||
|
||||
@@ -155,10 +161,17 @@ def test_reshape_and_cache(
|
||||
_, key, value = qkv.unbind(dim=1)
|
||||
|
||||
# Create the KV caches.
|
||||
key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
|
||||
num_heads, head_size,
|
||||
kv_cache_dtype, dtype, seed,
|
||||
device)
|
||||
key_caches, value_caches = kv_cache_factory(
|
||||
num_blocks,
|
||||
block_size,
|
||||
1,
|
||||
num_heads,
|
||||
head_size,
|
||||
kv_cache_dtype,
|
||||
dtype,
|
||||
seed,
|
||||
device,
|
||||
)
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
|
||||
# Using default kv_scale
|
||||
@@ -176,12 +189,30 @@ def test_reshape_and_cache(
|
||||
cloned_value_cache = value_cache.clone()
|
||||
|
||||
# Call the reshape_and_cache kernel.
|
||||
opcheck(torch.ops._C_cache_ops.reshape_and_cache,
|
||||
(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
|
||||
k_scale, v_scale),
|
||||
cond=(head_size == HEAD_SIZES[0]))
|
||||
ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
|
||||
kv_cache_dtype, k_scale, v_scale)
|
||||
opcheck(
|
||||
torch.ops._C_cache_ops.reshape_and_cache,
|
||||
(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
),
|
||||
cond=(head_size == HEAD_SIZES[0]),
|
||||
)
|
||||
ops.reshape_and_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
|
||||
if kv_cache_dtype == "fp8":
|
||||
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
|
||||
@@ -202,14 +233,12 @@ def test_reshape_and_cache(
|
||||
cloned_value_cache[block_idx, :, :, block_offset] = value[i]
|
||||
|
||||
if kv_cache_dtype == "fp8":
|
||||
torch.testing.assert_close(result_key_cache,
|
||||
cloned_key_cache,
|
||||
atol=0.001,
|
||||
rtol=0.1)
|
||||
torch.testing.assert_close(result_value_cache,
|
||||
cloned_value_cache,
|
||||
atol=0.001,
|
||||
rtol=0.1)
|
||||
torch.testing.assert_close(
|
||||
result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1
|
||||
)
|
||||
else:
|
||||
torch.testing.assert_close(key_cache, cloned_key_cache)
|
||||
torch.testing.assert_close(value_cache, cloned_value_cache)
|
||||
@@ -254,15 +283,8 @@ def test_reshape_and_cache_flash(
|
||||
# Create a random slot mapping.
|
||||
num_slots = block_size * num_blocks
|
||||
slot_mapping_lst = random.sample(range(num_slots), num_tokens)
|
||||
slot_mapping = torch.tensor(slot_mapping_lst,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
qkv = torch.randn(num_tokens,
|
||||
3,
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)
|
||||
qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype, device=device)
|
||||
_, key, value = qkv.unbind(dim=1)
|
||||
|
||||
# Create the KV caches.
|
||||
@@ -293,48 +315,73 @@ def test_reshape_and_cache_flash(
|
||||
|
||||
# Clone the KV caches.
|
||||
if kv_cache_dtype == "fp8":
|
||||
cloned_key_cache = torch.empty_like(key_cache_compact,
|
||||
dtype=torch.float16)
|
||||
ops.convert_fp8(cloned_key_cache, key_cache_compact, k_scale.item(),
|
||||
kv_cache_dtype)
|
||||
cloned_value_cache = torch.empty_like(value_cache_compact,
|
||||
dtype=torch.float16)
|
||||
ops.convert_fp8(cloned_value_cache, value_cache_compact,
|
||||
v_scale.item(), kv_cache_dtype)
|
||||
cloned_key_cache = torch.empty_like(key_cache_compact, dtype=torch.float16)
|
||||
ops.convert_fp8(
|
||||
cloned_key_cache, key_cache_compact, k_scale.item(), kv_cache_dtype
|
||||
)
|
||||
cloned_value_cache = torch.empty_like(value_cache_compact, dtype=torch.float16)
|
||||
ops.convert_fp8(
|
||||
cloned_value_cache, value_cache_compact, v_scale.item(), kv_cache_dtype
|
||||
)
|
||||
else:
|
||||
cloned_key_cache = key_cache_compact.clone()
|
||||
cloned_value_cache = value_cache_compact.clone()
|
||||
# Call the reshape_and_cache kernel.
|
||||
if implementation == "cuda":
|
||||
opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
|
||||
(key, value, key_cache, value_cache, slot_mapping,
|
||||
kv_cache_dtype, k_scale, v_scale),
|
||||
cond=(head_size == HEAD_SIZES[0]))
|
||||
ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
|
||||
slot_mapping, kv_cache_dtype, k_scale,
|
||||
v_scale)
|
||||
opcheck(
|
||||
torch.ops._C_cache_ops.reshape_and_cache_flash,
|
||||
(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
),
|
||||
cond=(head_size == HEAD_SIZES[0]),
|
||||
)
|
||||
ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
elif implementation == "triton":
|
||||
from vllm.attention.ops.triton_reshape_and_cache_flash import (
|
||||
triton_reshape_and_cache_flash)
|
||||
triton_reshape_and_cache_flash(key, value, key_cache, value_cache,
|
||||
slot_mapping, kv_cache_dtype, k_scale,
|
||||
v_scale)
|
||||
triton_reshape_and_cache_flash,
|
||||
)
|
||||
|
||||
triton_reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
key_cache_compact = permute_and_compact(key_cache)
|
||||
value_cache_compact = permute_and_compact(value_cache)
|
||||
|
||||
if kv_cache_dtype == "fp8":
|
||||
result_key_cache = torch.empty_like(key_cache_compact,
|
||||
dtype=torch.float16)
|
||||
ops.convert_fp8(result_key_cache,
|
||||
key_cache_compact,
|
||||
k_scale.item(),
|
||||
kv_dtype=kv_cache_dtype)
|
||||
result_value_cache = torch.empty_like(value_cache_compact,
|
||||
dtype=torch.float16)
|
||||
ops.convert_fp8(result_value_cache,
|
||||
value_cache_compact,
|
||||
v_scale.item(),
|
||||
kv_dtype=kv_cache_dtype)
|
||||
result_key_cache = torch.empty_like(key_cache_compact, dtype=torch.float16)
|
||||
ops.convert_fp8(
|
||||
result_key_cache, key_cache_compact, k_scale.item(), kv_dtype=kv_cache_dtype
|
||||
)
|
||||
result_value_cache = torch.empty_like(value_cache_compact, dtype=torch.float16)
|
||||
ops.convert_fp8(
|
||||
result_value_cache,
|
||||
value_cache_compact,
|
||||
v_scale.item(),
|
||||
kv_dtype=kv_cache_dtype,
|
||||
)
|
||||
|
||||
# Run the reference implementation.
|
||||
block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
|
||||
@@ -352,14 +399,12 @@ def test_reshape_and_cache_flash(
|
||||
cloned_value_cache[block_idx, :, block_offset, :] = value[i]
|
||||
|
||||
if kv_cache_dtype == "fp8":
|
||||
torch.testing.assert_close(result_key_cache,
|
||||
cloned_key_cache,
|
||||
atol=0.001,
|
||||
rtol=0.1)
|
||||
torch.testing.assert_close(result_value_cache,
|
||||
cloned_value_cache,
|
||||
atol=0.001,
|
||||
rtol=0.1)
|
||||
torch.testing.assert_close(
|
||||
result_key_cache, cloned_key_cache, atol=0.001, rtol=0.1
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
result_value_cache, cloned_value_cache, atol=0.001, rtol=0.1
|
||||
)
|
||||
else:
|
||||
torch.testing.assert_close(key_cache_compact, cloned_key_cache)
|
||||
torch.testing.assert_close(value_cache_compact, cloned_value_cache)
|
||||
@@ -396,8 +441,8 @@ def test_swap_blocks(
|
||||
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
src_device = device if direction[0] == "cuda" else 'cpu'
|
||||
dst_device = device if direction[1] == "cuda" else 'cpu'
|
||||
src_device = device if direction[0] == "cuda" else "cpu"
|
||||
dst_device = device if direction[1] == "cuda" else "cpu"
|
||||
|
||||
src_blocks = random.sample(range(num_blocks), num_mappings)
|
||||
# For the same device, mapping must not overlap
|
||||
@@ -408,42 +453,62 @@ def test_swap_blocks(
|
||||
dst_blocks = random.sample(range(num_blocks), num_mappings)
|
||||
|
||||
block_mapping = list(zip(src_blocks, dst_blocks))
|
||||
block_mapping_tensor = torch.tensor(block_mapping,
|
||||
dtype=torch.int64,
|
||||
device="cpu").view(-1, 2)
|
||||
block_mapping_tensor = torch.tensor(
|
||||
block_mapping, dtype=torch.int64, device="cpu"
|
||||
).view(-1, 2)
|
||||
|
||||
# Create the KV caches on the first device.
|
||||
src_key_caches, src_value_caches = kv_cache_factory(
|
||||
num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
|
||||
seed, src_device)
|
||||
num_blocks,
|
||||
block_size,
|
||||
1,
|
||||
num_heads,
|
||||
head_size,
|
||||
kv_cache_dtype,
|
||||
dtype,
|
||||
seed,
|
||||
src_device,
|
||||
)
|
||||
|
||||
# Create the KV caches on the second device.
|
||||
dist_key_caches, dist_value_caches = kv_cache_factory(
|
||||
num_blocks, block_size, 1, num_heads, head_size, kv_cache_dtype, dtype,
|
||||
seed, dst_device)
|
||||
num_blocks,
|
||||
block_size,
|
||||
1,
|
||||
num_heads,
|
||||
head_size,
|
||||
kv_cache_dtype,
|
||||
dtype,
|
||||
seed,
|
||||
dst_device,
|
||||
)
|
||||
|
||||
src_key_caches_clone = src_key_caches[0].clone()
|
||||
src_value_caches_clone = src_value_caches[0].clone()
|
||||
|
||||
# Call the swap_blocks kernel.
|
||||
do_opcheck = (head_size == HEAD_SIZES[0])
|
||||
opcheck(torch.ops._C_cache_ops.swap_blocks,
|
||||
(src_key_caches[0], dist_key_caches[0], block_mapping_tensor),
|
||||
cond=do_opcheck)
|
||||
opcheck(torch.ops._C_cache_ops.swap_blocks,
|
||||
(src_value_caches[0], dist_value_caches[0], block_mapping_tensor),
|
||||
cond=do_opcheck)
|
||||
do_opcheck = head_size == HEAD_SIZES[0]
|
||||
opcheck(
|
||||
torch.ops._C_cache_ops.swap_blocks,
|
||||
(src_key_caches[0], dist_key_caches[0], block_mapping_tensor),
|
||||
cond=do_opcheck,
|
||||
)
|
||||
opcheck(
|
||||
torch.ops._C_cache_ops.swap_blocks,
|
||||
(src_value_caches[0], dist_value_caches[0], block_mapping_tensor),
|
||||
cond=do_opcheck,
|
||||
)
|
||||
|
||||
ops.swap_blocks(src_key_caches[0], dist_key_caches[0],
|
||||
block_mapping_tensor)
|
||||
ops.swap_blocks(src_value_caches[0], dist_value_caches[0],
|
||||
block_mapping_tensor)
|
||||
ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping_tensor)
|
||||
ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping_tensor)
|
||||
|
||||
for src, dst in block_mapping:
|
||||
torch.testing.assert_close(src_key_caches_clone[src].cpu(),
|
||||
dist_key_caches[0][dst].cpu())
|
||||
torch.testing.assert_close(src_value_caches_clone[src].cpu(),
|
||||
dist_value_caches[0][dst].cpu())
|
||||
torch.testing.assert_close(
|
||||
src_key_caches_clone[src].cpu(), dist_key_caches[0][dst].cpu()
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
src_value_caches_clone[src].cpu(), dist_value_caches[0][dst].cpu()
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@@ -489,11 +554,9 @@ def _create_mla_cache(
|
||||
device: str,
|
||||
) -> torch.Tensor:
|
||||
cache_dtype = torch.uint8 if kv_cache_dtype == "fp8" else dtype
|
||||
return torch.zeros(num_blocks,
|
||||
block_size,
|
||||
entry_size,
|
||||
dtype=cache_dtype,
|
||||
device=device)
|
||||
return torch.zeros(
|
||||
num_blocks, block_size, entry_size, dtype=cache_dtype, device=device
|
||||
)
|
||||
|
||||
|
||||
def _fill_mla_cache(cache: torch.Tensor, kv_cache_dtype: str):
|
||||
@@ -533,20 +596,16 @@ def test_concat_and_cache_mla(
|
||||
|
||||
total_slots = num_blocks * block_size
|
||||
slot_mapping_lst = random.sample(range(total_slots), num_tokens)
|
||||
slot_mapping = torch.tensor(slot_mapping_lst,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)
|
||||
|
||||
kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
|
||||
k_pe = torch.randn(num_tokens,
|
||||
qk_rope_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device)
|
||||
entry_size = kv_lora_rank + qk_rope_head_dim
|
||||
|
||||
scale = torch.tensor(0.1, dtype=torch.float32, device=device)
|
||||
kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
|
||||
kv_cache_dtype, device)
|
||||
kv_cache = _create_mla_cache(
|
||||
num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
|
||||
)
|
||||
ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
|
||||
|
||||
for i in range(num_tokens):
|
||||
@@ -558,10 +617,7 @@ def test_concat_and_cache_mla(
|
||||
|
||||
if kv_cache_dtype == "fp8":
|
||||
ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
|
||||
ops.convert_fp8(ref_kv_cache,
|
||||
ref_temp,
|
||||
scale.item(),
|
||||
kv_dtype=kv_cache_dtype)
|
||||
ops.convert_fp8(ref_kv_cache, ref_temp, scale.item(), kv_dtype=kv_cache_dtype)
|
||||
else:
|
||||
ref_kv_cache = ref_temp
|
||||
|
||||
@@ -571,24 +627,18 @@ def test_concat_and_cache_mla(
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
|
||||
)
|
||||
|
||||
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
|
||||
kv_cache_dtype, scale)
|
||||
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)
|
||||
|
||||
if kv_cache_dtype == "fp8":
|
||||
result_temp = torch.empty_like(kv_cache, dtype=torch.float16)
|
||||
ops.convert_fp8(result_temp,
|
||||
kv_cache.contiguous(),
|
||||
scale.item(),
|
||||
kv_dtype=kv_cache_dtype)
|
||||
ops.convert_fp8(
|
||||
result_temp, kv_cache.contiguous(), scale.item(), kv_dtype=kv_cache_dtype
|
||||
)
|
||||
expected_temp = torch.empty_like(ref_kv_cache, dtype=torch.float16)
|
||||
ops.convert_fp8(expected_temp,
|
||||
ref_kv_cache,
|
||||
scale.item(),
|
||||
kv_dtype=kv_cache_dtype)
|
||||
torch.testing.assert_close(result_temp,
|
||||
expected_temp,
|
||||
atol=0.001,
|
||||
rtol=0.1)
|
||||
ops.convert_fp8(
|
||||
expected_temp, ref_kv_cache, scale.item(), kv_dtype=kv_cache_dtype
|
||||
)
|
||||
torch.testing.assert_close(result_temp, expected_temp, atol=0.001, rtol=0.1)
|
||||
else:
|
||||
torch.testing.assert_close(kv_cache, ref_kv_cache)
|
||||
|
||||
@@ -620,24 +670,21 @@ def test_concat_and_cache_ds_mla(
|
||||
|
||||
total_slots = num_blocks * block_size
|
||||
slot_mapping_lst = random.sample(range(total_slots), num_tokens)
|
||||
slot_mapping = torch.tensor(slot_mapping_lst,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)
|
||||
|
||||
kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
|
||||
k_pe = torch.randn(num_tokens,
|
||||
qk_rope_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device)
|
||||
entry_size = kv_lora_rank + (4 * 4) + (2 * qk_rope_head_dim)
|
||||
|
||||
scale = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
kv_cache = _create_mla_cache(num_blocks,
|
||||
block_size,
|
||||
entry_size,
|
||||
dtype=torch.uint8,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
device=device)
|
||||
kv_cache = _create_mla_cache(
|
||||
num_blocks,
|
||||
block_size,
|
||||
entry_size,
|
||||
dtype=torch.uint8,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
ref_cache = torch.zeros_like(kv_cache, dtype=kv_cache.dtype)
|
||||
tile_data = torch.zeros(128, dtype=dtype, device=device)
|
||||
@@ -664,14 +711,16 @@ def test_concat_and_cache_ds_mla(
|
||||
manual_max = abs(tile_data_float[0])
|
||||
for j in range(1, 128):
|
||||
manual_max = max(manual_max, abs(tile_data_float[j]))
|
||||
tile_scale = manual_max / 448.
|
||||
tile_scale = manual_max / 448.0
|
||||
|
||||
ref_cache_32bit[kv_lora_rank // 4 + tile_idx] = tile_scale
|
||||
|
||||
ops.convert_fp8(ref_cache_slice[tile_start:tile_end],
|
||||
tile_data,
|
||||
tile_scale.item(),
|
||||
kv_dtype="fp8")
|
||||
ops.convert_fp8(
|
||||
ref_cache_slice[tile_start:tile_end],
|
||||
tile_data,
|
||||
tile_scale.item(),
|
||||
kv_dtype="fp8",
|
||||
)
|
||||
|
||||
for j in range(qk_rope_head_dim):
|
||||
ref_cache_16bit[kv_lora_rank // 2 + 8 + j] = k_pe[i, j]
|
||||
@@ -682,8 +731,7 @@ def test_concat_and_cache_ds_mla(
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
|
||||
)
|
||||
|
||||
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
|
||||
kv_cache_dtype, scale)
|
||||
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)
|
||||
|
||||
for i in range(num_tokens):
|
||||
slot = slot_mapping[i].item()
|
||||
@@ -694,12 +742,14 @@ def test_concat_and_cache_ds_mla(
|
||||
|
||||
kv_nope = kv_cache_slice[:kv_lora_rank]
|
||||
ref_nope = ref_cache_slice[:kv_lora_rank]
|
||||
kv_scales = kv_cache_slice.view(torch.float32)[kv_lora_rank //
|
||||
4:kv_lora_rank // 4 + 4]
|
||||
ref_scales = ref_cache_slice.view(
|
||||
torch.float32)[kv_lora_rank // 4:kv_lora_rank // 4 + 4]
|
||||
kv_rope = kv_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:]
|
||||
ref_rope = ref_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:]
|
||||
kv_scales = kv_cache_slice.view(torch.float32)[
|
||||
kv_lora_rank // 4 : kv_lora_rank // 4 + 4
|
||||
]
|
||||
ref_scales = ref_cache_slice.view(torch.float32)[
|
||||
kv_lora_rank // 4 : kv_lora_rank // 4 + 4
|
||||
]
|
||||
kv_rope = kv_cache_slice.view(dtype)[kv_lora_rank // 2 + 8 :]
|
||||
ref_rope = ref_cache_slice.view(dtype)[kv_lora_rank // 2 + 8 :]
|
||||
|
||||
torch.testing.assert_close(kv_nope, ref_nope, atol=0.001, rtol=0.1)
|
||||
torch.testing.assert_close(kv_scales, ref_scales, atol=0.001, rtol=0.1)
|
||||
@@ -734,8 +784,9 @@ def test_copy_blocks_mla(
|
||||
|
||||
kv_caches = []
|
||||
for _ in range(num_layers):
|
||||
kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
|
||||
kv_cache_dtype, device)
|
||||
kv_cache = _create_mla_cache(
|
||||
num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
|
||||
)
|
||||
_fill_mla_cache(kv_cache, kv_cache_dtype=kv_cache_dtype)
|
||||
kv_caches.append(kv_cache)
|
||||
|
||||
@@ -752,9 +803,9 @@ def test_copy_blocks_mla(
|
||||
dst2 = dst_blocks[2 * i + 1]
|
||||
block_mapping.append((src, dst1))
|
||||
block_mapping.append((src, dst2))
|
||||
block_mapping_tensor = torch.tensor(block_mapping,
|
||||
dtype=torch.int64,
|
||||
device=device).view(-1, 2)
|
||||
block_mapping_tensor = torch.tensor(
|
||||
block_mapping, dtype=torch.int64, device=device
|
||||
).view(-1, 2)
|
||||
|
||||
for src, dst in block_mapping:
|
||||
for ref_cache in ref_caches:
|
||||
@@ -795,10 +846,12 @@ def test_swap_blocks_mla(
|
||||
|
||||
entry_size = kv_lora_rank + qk_rope_head_dim
|
||||
|
||||
src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
|
||||
kv_cache_dtype, device)
|
||||
dst_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
|
||||
kv_cache_dtype, device)
|
||||
src_cache = _create_mla_cache(
|
||||
num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
|
||||
)
|
||||
dst_cache = _create_mla_cache(
|
||||
num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
|
||||
)
|
||||
|
||||
_fill_mla_cache(src_cache, kv_cache_dtype)
|
||||
_fill_mla_cache(dst_cache, kv_cache_dtype)
|
||||
@@ -810,9 +863,9 @@ def test_swap_blocks_mla(
|
||||
remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
|
||||
dst_blocks = random.sample(remaining_blocks, num_mappings)
|
||||
block_mapping = list(zip(src_blocks, dst_blocks))
|
||||
block_mapping_tensor = torch.tensor(block_mapping,
|
||||
dtype=torch.int64,
|
||||
device="cpu").view(-1, 2)
|
||||
block_mapping_tensor = torch.tensor(
|
||||
block_mapping, dtype=torch.int64, device="cpu"
|
||||
).view(-1, 2)
|
||||
|
||||
opcheck(
|
||||
torch.ops._C_cache_ops.swap_blocks,
|
||||
@@ -827,7 +880,8 @@ def test_swap_blocks_mla(
|
||||
src_cache_clone[src].cpu(),
|
||||
dst_cache[dst].cpu(),
|
||||
msg=f"Block {src} from src should have been swapped to block "
|
||||
f"{dst} in dst_cache.")
|
||||
f"{dst} in dst_cache.",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kv_lora_rank", [512])
|
||||
@@ -840,32 +894,36 @@ def test_swap_blocks_mla(
|
||||
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim,
|
||||
block_size, num_blocks,
|
||||
max_seq_len, batch_size, dtype,
|
||||
kv_cache_dtype, device):
|
||||
def test_gather_and_maybe_dequant_cache_mla(
|
||||
kv_lora_rank,
|
||||
qk_rope_head_dim,
|
||||
block_size,
|
||||
num_blocks,
|
||||
max_seq_len,
|
||||
batch_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
device,
|
||||
):
|
||||
entry_size = kv_lora_rank + qk_rope_head_dim
|
||||
scale = torch.tensor(0.1, dtype=torch.float32, device=device)
|
||||
src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
|
||||
kv_cache_dtype, device)
|
||||
src_cache = _create_mla_cache(
|
||||
num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
|
||||
)
|
||||
_fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
|
||||
|
||||
seq_len_tensor = torch.randint(0,
|
||||
max_seq_len + 1, (batch_size, ),
|
||||
device=device)
|
||||
seq_len_tensor = torch.randint(0, max_seq_len + 1, (batch_size,), device=device)
|
||||
|
||||
total_tokens = seq_len_tensor.sum()
|
||||
cu_seq_lens = torch.empty((batch_size + 1),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
cu_seq_lens = torch.empty((batch_size + 1), dtype=torch.int32, device=device)
|
||||
cu_seq_lens[0] = 0
|
||||
cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)
|
||||
print("seq_len_tensor", seq_len_tensor)
|
||||
|
||||
tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size
|
||||
block_table = torch.empty((batch_size, num_blocks),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
block_table = torch.empty(
|
||||
(batch_size, num_blocks), dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
for b in range(batch_size):
|
||||
perm = torch.randperm(num_blocks, device=device)
|
||||
@@ -893,10 +951,8 @@ def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim,
|
||||
remaining = s - (tot - 1) * block_size
|
||||
last_block_data = src_cache[blocks[-1], :remaining, :]
|
||||
if kv_cache_dtype == "fp8":
|
||||
dequantized_last_block = torch.empty_like(last_block_data,
|
||||
dtype=dtype)
|
||||
ops.convert_fp8(dequantized_last_block, last_block_data,
|
||||
scale.item())
|
||||
dequantized_last_block = torch.empty_like(last_block_data, dtype=dtype)
|
||||
ops.convert_fp8(dequantized_last_block, last_block_data, scale.item())
|
||||
gathered_rows.append(dequantized_last_block)
|
||||
else:
|
||||
gathered_rows.append(last_block_data)
|
||||
@@ -907,14 +963,29 @@ def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim,
|
||||
|
||||
opcheck(
|
||||
torch.ops._C_cache_ops.gather_and_maybe_dequant_cache,
|
||||
(src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype,
|
||||
scale, None),
|
||||
(
|
||||
src_cache,
|
||||
dst,
|
||||
block_table,
|
||||
cu_seq_lens,
|
||||
batch_size,
|
||||
kv_cache_dtype,
|
||||
scale,
|
||||
None,
|
||||
),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
|
||||
)
|
||||
|
||||
ops.gather_and_maybe_dequant_cache(src_cache, dst, block_table,
|
||||
cu_seq_lens, batch_size, kv_cache_dtype,
|
||||
scale, None)
|
||||
ops.gather_and_maybe_dequant_cache(
|
||||
src_cache,
|
||||
dst,
|
||||
block_table,
|
||||
cu_seq_lens,
|
||||
batch_size,
|
||||
kv_cache_dtype,
|
||||
scale,
|
||||
None,
|
||||
)
|
||||
torch.testing.assert_close(dst, expected)
|
||||
|
||||
|
||||
@@ -925,42 +996,46 @@ def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim,
|
||||
@pytest.mark.parametrize("max_seq_len", [512])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("dtype", [torch.float32])
|
||||
@pytest.mark.parametrize("kv_cache_dtype",
|
||||
["auto"]) # You can also test "fp8" if needed.
|
||||
@pytest.mark.parametrize(
|
||||
"kv_cache_dtype", ["auto"]
|
||||
) # You can also test "fp8" if needed.
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_cp_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
|
||||
num_blocks, max_seq_len, batch_size, dtype,
|
||||
kv_cache_dtype, device):
|
||||
def test_cp_gather_cache_mla(
|
||||
kv_lora_rank,
|
||||
qk_rope_head_dim,
|
||||
block_size,
|
||||
num_blocks,
|
||||
max_seq_len,
|
||||
batch_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
device,
|
||||
):
|
||||
entry_size = kv_lora_rank + qk_rope_head_dim
|
||||
src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
|
||||
kv_cache_dtype, device)
|
||||
src_cache = _create_mla_cache(
|
||||
num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
|
||||
)
|
||||
_fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
|
||||
|
||||
seq_len_tensor = torch.randint(0,
|
||||
max_seq_len + 1, (batch_size, ),
|
||||
device=device)
|
||||
seq_len_tensor = torch.randint(0, max_seq_len + 1, (batch_size,), device=device)
|
||||
|
||||
total_tokens = seq_len_tensor.sum()
|
||||
cu_seq_lens = torch.empty((batch_size + 1),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
cu_seq_lens = torch.empty((batch_size + 1), dtype=torch.int32, device=device)
|
||||
cu_seq_lens[0] = 0
|
||||
cu_seq_lens[1:] = seq_len_tensor.cumsum(dim=0).to(dtype=torch.int32)
|
||||
print("seq_len_tensor", seq_len_tensor)
|
||||
|
||||
tot_blocks_tensor = (seq_len_tensor + block_size - 1) // block_size
|
||||
block_table = torch.empty((batch_size, num_blocks),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
block_table = torch.empty(
|
||||
(batch_size, num_blocks), dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
for b in range(batch_size):
|
||||
perm = torch.randperm(num_blocks, device=device)
|
||||
block_table[b, :] = perm
|
||||
|
||||
dst = torch.zeros((total_tokens, entry_size),
|
||||
dtype=src_cache.dtype,
|
||||
device=device)
|
||||
dst = torch.zeros((total_tokens, entry_size), dtype=src_cache.dtype, device=device)
|
||||
|
||||
expected_batches = []
|
||||
for b in range(batch_size):
|
||||
@@ -1016,20 +1091,16 @@ def test_concat_and_cache_mla_cpu(
|
||||
|
||||
total_slots = num_blocks * block_size
|
||||
slot_mapping_lst = random.sample(range(total_slots), num_tokens)
|
||||
slot_mapping = torch.tensor(slot_mapping_lst,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)
|
||||
|
||||
kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
|
||||
k_pe = torch.randn(num_tokens,
|
||||
qk_rope_head_dim,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
k_pe = torch.randn(num_tokens, qk_rope_head_dim, dtype=dtype, device=device)
|
||||
entry_size = kv_lora_rank + qk_rope_head_dim
|
||||
|
||||
scale = torch.tensor(0.1, dtype=torch.float32, device=device)
|
||||
kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
|
||||
kv_cache_dtype, device)
|
||||
kv_cache = _create_mla_cache(
|
||||
num_blocks, block_size, entry_size, dtype, kv_cache_dtype, device
|
||||
)
|
||||
ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
|
||||
|
||||
for i in range(num_tokens):
|
||||
@@ -1041,10 +1112,7 @@ def test_concat_and_cache_mla_cpu(
|
||||
|
||||
if kv_cache_dtype == "fp8":
|
||||
ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
|
||||
ops.convert_fp8(ref_kv_cache,
|
||||
ref_temp,
|
||||
scale.item(),
|
||||
kv_dtype=kv_cache_dtype)
|
||||
ops.convert_fp8(ref_kv_cache, ref_temp, scale.item(), kv_dtype=kv_cache_dtype)
|
||||
else:
|
||||
ref_kv_cache = ref_temp
|
||||
|
||||
@@ -1054,6 +1122,5 @@ def test_concat_and_cache_mla_cpu(
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
|
||||
)
|
||||
|
||||
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
|
||||
kv_cache_dtype, scale)
|
||||
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)
|
||||
torch.testing.assert_close(kv_cache, ref_kv_cache)
|
||||
|
||||
@@ -7,11 +7,12 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.flash_attn import (cascade_attention,
|
||||
merge_attn_states)
|
||||
from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
|
||||
flash_attn_varlen_func,
|
||||
is_fa_version_supported)
|
||||
from vllm.v1.attention.backends.flash_attn import cascade_attention, merge_attn_states
|
||||
from vllm.vllm_flash_attn import (
|
||||
fa_version_unsupported_reason,
|
||||
flash_attn_varlen_func,
|
||||
is_fa_version_supported,
|
||||
)
|
||||
|
||||
NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
|
||||
HEAD_SIZES = [128, 192, 256]
|
||||
@@ -37,21 +38,14 @@ def test_merge_kernel(
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
|
||||
# Prepare inputs.
|
||||
prefix_output = torch.randn(num_tokens,
|
||||
num_query_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
suffix_output = torch.randn(num_tokens,
|
||||
num_query_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
prefix_output = torch.randn(num_tokens, num_query_heads, head_size, dtype=dtype)
|
||||
suffix_output = torch.randn(num_tokens, num_query_heads, head_size, dtype=dtype)
|
||||
prefix_lse = torch.randn(num_query_heads, num_tokens, dtype=torch.float32)
|
||||
suffix_lse = torch.randn(num_query_heads, num_tokens, dtype=torch.float32)
|
||||
|
||||
# Run the kernel.
|
||||
output = torch.empty(num_tokens, num_query_heads, head_size, dtype=dtype)
|
||||
merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
|
||||
suffix_lse)
|
||||
merge_attn_states(output, prefix_output, prefix_lse, suffix_output, suffix_lse)
|
||||
|
||||
# Reference implementation.
|
||||
max_lse = torch.maximum(prefix_lse, suffix_lse)
|
||||
@@ -97,8 +91,10 @@ def test_cascade(
|
||||
) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
if not is_fa_version_supported(fa_version):
|
||||
pytest.skip(f"Flash attention version {fa_version} not supported due "
|
||||
f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
|
||||
pytest.skip(
|
||||
f"Flash attention version {fa_version} not supported due "
|
||||
f'to: "{fa_version_unsupported_reason(fa_version)}"'
|
||||
)
|
||||
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
@@ -107,11 +103,9 @@ def test_cascade(
|
||||
num_query_heads = num_heads[0]
|
||||
num_kv_heads = num_heads[1]
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
key_cache = torch.randn(num_blocks,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
key_cache = torch.randn(
|
||||
num_blocks, block_size, num_kv_heads, head_size, dtype=dtype
|
||||
)
|
||||
value_cache = torch.randn_like(key_cache)
|
||||
|
||||
seq_lens, common_prefix_len = seq_lens_and_common_prefix
|
||||
@@ -122,26 +116,21 @@ def test_cascade(
|
||||
max_kv_len = max(kv_lens)
|
||||
|
||||
total_num_query_tokens = sum(query_lens)
|
||||
query = torch.randn(total_num_query_tokens,
|
||||
num_query_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
cu_query_lens = torch.tensor([0] + query_lens,
|
||||
dtype=torch.int32).cumsum(dim=0,
|
||||
dtype=torch.int32)
|
||||
query = torch.randn(total_num_query_tokens, num_query_heads, head_size, dtype=dtype)
|
||||
cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum(
|
||||
dim=0, dtype=torch.int32
|
||||
)
|
||||
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
|
||||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(0,
|
||||
num_blocks,
|
||||
(num_seqs, max_num_blocks_per_seq),
|
||||
dtype=torch.int32)
|
||||
block_tables = torch.randint(
|
||||
0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
|
||||
assert common_prefix_len > 0
|
||||
assert common_prefix_len % block_size == 0
|
||||
num_common_kv_blocks = common_prefix_len // block_size
|
||||
# Make sure the first `num_common_kv_blocks` blocks are the same.
|
||||
block_tables[:, :num_common_kv_blocks] = \
|
||||
block_tables[0, :num_common_kv_blocks]
|
||||
block_tables[:, :num_common_kv_blocks] = block_tables[0, :num_common_kv_blocks]
|
||||
|
||||
# Run the regular attention.
|
||||
ref_output = flash_attn_varlen_func(
|
||||
@@ -161,8 +150,7 @@ def test_cascade(
|
||||
|
||||
# Run cascade attention.
|
||||
assert all(common_prefix_len < kv_len for kv_len in kv_lens)
|
||||
cu_prefix_query_lens = torch.tensor([0, total_num_query_tokens],
|
||||
dtype=torch.int32)
|
||||
cu_prefix_query_lens = torch.tensor([0, total_num_query_tokens], dtype=torch.int32)
|
||||
prefix_kv_lens = torch.tensor([common_prefix_len], dtype=torch.int32)
|
||||
suffix_kv_lens = kv_lens_tensor - common_prefix_len
|
||||
output = torch.empty_like(query)
|
||||
|
||||
@@ -12,33 +12,37 @@ from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
|
||||
def cal_diff(x: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
name: str,
|
||||
use_fp8: bool = False,
|
||||
diff_threshold: Optional[float] = None) -> None:
|
||||
def cal_diff(
|
||||
x: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
name: str,
|
||||
use_fp8: bool = False,
|
||||
diff_threshold: Optional[float] = None,
|
||||
) -> None:
|
||||
x, y = x.double(), y.double()
|
||||
cos_diff = 1 - 2 * (x * y).sum().item() / max(
|
||||
(x * x + y * y).sum().item(), 1e-12)
|
||||
cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
|
||||
if diff_threshold is not None:
|
||||
# directly compare the cos_diff with the threshold
|
||||
assert cos_diff < diff_threshold
|
||||
else:
|
||||
# use the default threshold
|
||||
if (use_fp8):
|
||||
if use_fp8:
|
||||
assert cos_diff < 1e-4
|
||||
else:
|
||||
assert cos_diff < 1e-5
|
||||
|
||||
|
||||
CUTLASS_MLA_UNSUPPORTED_REASON = \
|
||||
"Cutlass MLA Requires compute capability of 10 or above." \
|
||||
if not current_platform.is_device_capability(100) \
|
||||
CUTLASS_MLA_UNSUPPORTED_REASON = (
|
||||
"Cutlass MLA Requires compute capability of 10 or above."
|
||||
if not current_platform.is_device_capability(100)
|
||||
else "Cutlass MLA is supported"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(100),
|
||||
reason=CUTLASS_MLA_UNSUPPORTED_REASON)
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(100),
|
||||
reason=CUTLASS_MLA_UNSUPPORTED_REASON,
|
||||
)
|
||||
@pytest.mark.parametrize("b", [128])
|
||||
@pytest.mark.parametrize("s_q", [1])
|
||||
@pytest.mark.parametrize("mean_sk", [4096, 8192, 16384])
|
||||
@@ -54,11 +58,13 @@ CUTLASS_MLA_UNSUPPORTED_REASON = \
|
||||
[
|
||||
torch.bfloat16,
|
||||
# fp8 can have occasional precision-related failures.
|
||||
pytest.param(torch.float8_e4m3fn, marks=pytest.mark.flaky(reruns=2))
|
||||
])
|
||||
pytest.param(torch.float8_e4m3fn, marks=pytest.mark.flaky(reruns=2)),
|
||||
],
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_cutlass_mla_decode(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size,
|
||||
causal, varlen, torch_dtype):
|
||||
def test_cutlass_mla_decode(
|
||||
b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, varlen, torch_dtype
|
||||
):
|
||||
device = torch.device("cuda:0")
|
||||
if torch_dtype == torch.float8_e4m3fn:
|
||||
init_dtype = torch.bfloat16
|
||||
@@ -70,24 +76,25 @@ def test_cutlass_mla_decode(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size,
|
||||
torch.manual_seed(42)
|
||||
random.seed(42)
|
||||
|
||||
print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, "
|
||||
f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}")
|
||||
print(
|
||||
f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, "
|
||||
f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}"
|
||||
)
|
||||
|
||||
use_fp8 = torch_dtype == torch.float8_e4m3fn
|
||||
scale = math.sqrt(d)**(-1)
|
||||
cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32)
|
||||
scale = math.sqrt(d) ** (-1)
|
||||
cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
|
||||
if varlen:
|
||||
for i in range(b):
|
||||
cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2),
|
||||
s_q)
|
||||
cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q)
|
||||
total_seqlens = cache_seqlens.sum().item()
|
||||
max_seqlen = cache_seqlens.max().item()
|
||||
max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
|
||||
|
||||
q = torch.randn(b, s_q, h_q, d)
|
||||
block_table = torch.arange(b * max_seqlen_pad // block_size,
|
||||
dtype=torch.int32).view(
|
||||
b, max_seqlen_pad // block_size)
|
||||
block_table = torch.arange(
|
||||
b * max_seqlen_pad // block_size, dtype=torch.int32
|
||||
).view(b, max_seqlen_pad // block_size)
|
||||
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
|
||||
blocked_v = blocked_k[..., :dv]
|
||||
|
||||
@@ -121,22 +128,29 @@ def test_cutlass_mla_decode(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size,
|
||||
q_pe = q_pe_padded
|
||||
|
||||
kv_cache_flat = blocked_k.squeeze(2)
|
||||
device_properties = torch.cuda.get_device_properties(
|
||||
torch.device("cuda:0"))
|
||||
device_properties = torch.cuda.get_device_properties(torch.device("cuda:0"))
|
||||
sm_count = device_properties.multi_processor_count
|
||||
workspace_size = ops.sm100_cutlass_mla_get_workspace_size(
|
||||
max_seqlen * block_size, b, sm_count, num_kv_splits=1)
|
||||
workspace = torch.empty(workspace_size,
|
||||
device="cuda",
|
||||
dtype=torch.uint8)
|
||||
max_seqlen * block_size, b, sm_count, num_kv_splits=1
|
||||
)
|
||||
workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)
|
||||
|
||||
out_ans = torch.empty(b, MAX_HEADS, dv, dtype=init_dtype)
|
||||
output_lse = torch.empty((b, MAX_HEADS),
|
||||
dtype=torch.float32,
|
||||
device=q_nope.device)
|
||||
ops.sm100_cutlass_mla_decode(out_ans, output_lse, q_nope, q_pe,
|
||||
kv_cache_flat, cache_seqlens, block_table,
|
||||
workspace, scale, 1)
|
||||
output_lse = torch.empty(
|
||||
(b, MAX_HEADS), dtype=torch.float32, device=q_nope.device
|
||||
)
|
||||
ops.sm100_cutlass_mla_decode(
|
||||
out_ans,
|
||||
output_lse,
|
||||
q_nope,
|
||||
q_pe,
|
||||
kv_cache_flat,
|
||||
cache_seqlens,
|
||||
block_table,
|
||||
workspace,
|
||||
scale,
|
||||
1,
|
||||
)
|
||||
return out_ans[:, :h_q].contiguous(), output_lse[:, :h_q].contiguous()
|
||||
|
||||
def scaled_dot_product_attention(query, key, value, is_causal=False):
|
||||
@@ -150,8 +164,7 @@ def test_cutlass_mla_decode(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size,
|
||||
s_q = query.shape[-2]
|
||||
s_k = key.shape[-2]
|
||||
attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)
|
||||
temp_mask = torch.ones(s_q, s_k,
|
||||
dtype=torch.bool).tril(diagonal=s_k - s_q)
|
||||
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
|
||||
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
|
||||
attn_bias.to(query.dtype)
|
||||
attn_weight += attn_bias
|
||||
@@ -161,10 +174,16 @@ def test_cutlass_mla_decode(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size,
|
||||
|
||||
def ref_mla():
|
||||
q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q
|
||||
blocked_k_ = (blocked_k.to(torch.float) *
|
||||
descale_k).to(init_dtype) if use_fp8 else blocked_k
|
||||
blocked_v_ = (blocked_v.to(torch.float) *
|
||||
descale_k).to(init_dtype) if use_fp8 else blocked_v
|
||||
blocked_k_ = (
|
||||
(blocked_k.to(torch.float) * descale_k).to(init_dtype)
|
||||
if use_fp8
|
||||
else blocked_k
|
||||
)
|
||||
blocked_v_ = (
|
||||
(blocked_v.to(torch.float) * descale_k).to(init_dtype)
|
||||
if use_fp8
|
||||
else blocked_v
|
||||
)
|
||||
out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
|
||||
lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
|
||||
for i in range(b):
|
||||
@@ -191,8 +210,9 @@ def test_cutlass_mla_decode(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size,
|
||||
|
||||
t = triton.testing.do_bench(cutlass_mla)
|
||||
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
|
||||
bytes = (total_seqlens * h_kv * d +
|
||||
b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + (
|
||||
b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8)
|
||||
print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS,",
|
||||
f"{bytes / 10 ** 6 / t:.0f} GB/s")
|
||||
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d) * (
|
||||
torch.finfo(torch_dtype).bits // 8
|
||||
) + (b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8)
|
||||
print(
|
||||
f"{t:.3f} ms, {FLOPS / 10**9 / t:.0f} TFLOPS,", f"{bytes / 10**6 / t:.0f} GB/s"
|
||||
)
|
||||
|
||||
@@ -7,9 +7,14 @@ import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cdiv, has_deep_gemm
|
||||
from vllm.utils.deep_gemm import (_ceil_to_ue8m0, calc_diff, fp8_mqa_logits,
|
||||
fp8_paged_mqa_logits, get_num_sms,
|
||||
get_paged_mqa_logits_metadata)
|
||||
from vllm.utils.deep_gemm import (
|
||||
_ceil_to_ue8m0,
|
||||
calc_diff,
|
||||
fp8_mqa_logits,
|
||||
fp8_paged_mqa_logits,
|
||||
get_num_sms,
|
||||
get_paged_mqa_logits_metadata,
|
||||
)
|
||||
|
||||
|
||||
def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -24,17 +29,18 @@ def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor:
|
||||
device=x.device,
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
x_fp8[:, :block_size * head_dim] = x_scaled.view(
|
||||
num_blocks, block_size * head_dim).view(dtype=torch.uint8)
|
||||
x_fp8[:,
|
||||
block_size * head_dim:] = sf.view(num_blocks,
|
||||
block_size).view(dtype=torch.uint8)
|
||||
x_fp8[:, : block_size * head_dim] = x_scaled.view(
|
||||
num_blocks, block_size * head_dim
|
||||
).view(dtype=torch.uint8)
|
||||
x_fp8[:, block_size * head_dim :] = sf.view(num_blocks, block_size).view(
|
||||
dtype=torch.uint8
|
||||
)
|
||||
return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4)
|
||||
|
||||
|
||||
def per_custom_dims_cast_to_fp8(
|
||||
x: torch.Tensor, dims: tuple,
|
||||
use_ue8m0: bool) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
x: torch.Tensor, dims: tuple, use_ue8m0: bool
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)])
|
||||
x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4)
|
||||
sf = x_amax / 448.0
|
||||
@@ -69,10 +75,12 @@ def _ref_fp8_mqa_logits(
|
||||
q = q.float()
|
||||
k = k.float()
|
||||
|
||||
mask_lo = (torch.arange(0, seq_len_kv, device="cuda")[None, :]
|
||||
>= cu_seqlen_ks[:, None])
|
||||
mask_hi = (torch.arange(0, seq_len_kv, device="cuda")[None, :]
|
||||
< cu_seqlen_ke[:, None])
|
||||
mask_lo = (
|
||||
torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None]
|
||||
)
|
||||
mask_hi = (
|
||||
torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None]
|
||||
)
|
||||
mask = mask_lo & mask_hi
|
||||
|
||||
score = torch.einsum("mhd,and->hmn", q, k)
|
||||
@@ -84,14 +92,15 @@ def _ref_fp8_mqa_logits(
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
|
||||
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(90),
|
||||
reason="SM90 and SM100 only")
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
|
||||
)
|
||||
def test_deepgemm_fp8_mqa_logits():
|
||||
torch.manual_seed(0)
|
||||
random.seed(0)
|
||||
num_heads, head_dim = 32, 128
|
||||
for seq_len in (512, ):
|
||||
for seq_len_kv in (1024, ):
|
||||
for seq_len in (512,):
|
||||
for seq_len_kv in (1024,):
|
||||
for disable_cp in (False, True):
|
||||
q = torch.randn(
|
||||
seq_len,
|
||||
@@ -100,24 +109,23 @@ def test_deepgemm_fp8_mqa_logits():
|
||||
device="cuda",
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
kv = torch.randn(seq_len_kv,
|
||||
head_dim,
|
||||
device="cuda",
|
||||
dtype=torch.bfloat16)
|
||||
weights = torch.randn(seq_len,
|
||||
num_heads,
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
kv = torch.randn(
|
||||
seq_len_kv, head_dim, device="cuda", dtype=torch.bfloat16
|
||||
)
|
||||
weights = torch.randn(
|
||||
seq_len, num_heads, device="cuda", dtype=torch.float32
|
||||
)
|
||||
|
||||
if disable_cp:
|
||||
ks = torch.zeros(seq_len, dtype=torch.int, device="cuda")
|
||||
ke = torch.arange(seq_len, dtype=torch.int,
|
||||
device="cuda") + (seq_len_kv - seq_len)
|
||||
ke = torch.arange(seq_len, dtype=torch.int, device="cuda") + (
|
||||
seq_len_kv - seq_len
|
||||
)
|
||||
else:
|
||||
ks, ke = _generate_cp_test_data(seq_len, seq_len_kv)
|
||||
|
||||
q_fp8 = q.to(torch.float8_e4m3fn)
|
||||
kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0, ), False)
|
||||
kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0,), False)
|
||||
logits = fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke)
|
||||
|
||||
ref_logits = _ref_fp8_mqa_logits(
|
||||
@@ -157,11 +165,10 @@ def _ref_fp8_paged_mqa_logits(
|
||||
context_lens_list = context_lens.tolist()
|
||||
for i in range(batch_size):
|
||||
context_len = context_lens_list[i]
|
||||
q_offsets = torch.arange(context_len - next_n,
|
||||
context_len,
|
||||
device="cuda")
|
||||
weight_slice = (weights[i * next_n:(i + 1) * next_n, :].transpose(
|
||||
0, 1).contiguous())
|
||||
q_offsets = torch.arange(context_len - next_n, context_len, device="cuda")
|
||||
weight_slice = (
|
||||
weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous()
|
||||
)
|
||||
for block_rk in range(cdiv(context_len, block_size)):
|
||||
block_idx = block_tables[i][block_rk]
|
||||
qx, kx = q[i], kv_cache[block_idx]
|
||||
@@ -170,28 +177,30 @@ def _ref_fp8_paged_mqa_logits(
|
||||
(block_rk + 1) * block_size,
|
||||
device="cuda",
|
||||
)
|
||||
mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :]
|
||||
<= q_offsets[:, None])
|
||||
mask = (k_offsets[None, :] < context_len) & (
|
||||
k_offsets[None, :] <= q_offsets[:, None]
|
||||
)
|
||||
s = torch.where(
|
||||
mask[None, :, :],
|
||||
(qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
|
||||
logits.dtype),
|
||||
logits.dtype
|
||||
),
|
||||
float("-inf"),
|
||||
)
|
||||
s = torch.relu(s) * weight_slice[..., None]
|
||||
s = s.sum(dim=0)
|
||||
logits[
|
||||
i * next_n:(i + 1) * next_n,
|
||||
block_rk * block_size:(block_rk + 1) * block_size,
|
||||
] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s,
|
||||
float("-inf"))
|
||||
i * next_n : (i + 1) * next_n,
|
||||
block_rk * block_size : (block_rk + 1) * block_size,
|
||||
] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf"))
|
||||
return logits
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
|
||||
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(90),
|
||||
reason="SM90 and SM100 only")
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
|
||||
)
|
||||
def test_deepgemm_fp8_paged_mqa_logits():
|
||||
torch.manual_seed(0)
|
||||
random.seed(0)
|
||||
@@ -199,7 +208,7 @@ def test_deepgemm_fp8_paged_mqa_logits():
|
||||
max_model_len = 4096
|
||||
for batch_size, next_n in [(4, 1), (2, 2)]:
|
||||
for heads, index_dim in [(32, 128)]:
|
||||
for avg_kv in (2048, ):
|
||||
for avg_kv in (2048,):
|
||||
num_blocks, blocksize = max_model_len * 2, 64
|
||||
|
||||
q = torch.randn(
|
||||
@@ -218,12 +227,14 @@ def test_deepgemm_fp8_paged_mqa_logits():
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
context_lens = (torch.randint(int(0.8 * avg_kv),
|
||||
int(1.2 * avg_kv),
|
||||
(batch_size, )).cuda().to(
|
||||
torch.int32))
|
||||
max_block_len = ((context_lens.max().item() + blocksize - 1) //
|
||||
blocksize * blocksize)
|
||||
context_lens = (
|
||||
torch.randint(int(0.8 * avg_kv), int(1.2 * avg_kv), (batch_size,))
|
||||
.cuda()
|
||||
.to(torch.int32)
|
||||
)
|
||||
max_block_len = (
|
||||
(context_lens.max().item() + blocksize - 1) // blocksize * blocksize
|
||||
)
|
||||
block_tables = torch.zeros(
|
||||
(batch_size, max_block_len),
|
||||
device="cuda",
|
||||
@@ -243,7 +254,8 @@ def test_deepgemm_fp8_paged_mqa_logits():
|
||||
kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache)
|
||||
|
||||
schedule_metadata = get_paged_mqa_logits_metadata(
|
||||
context_lens, blocksize, get_num_sms())
|
||||
context_lens, blocksize, get_num_sms()
|
||||
)
|
||||
logits = fp8_paged_mqa_logits(
|
||||
q_fp8,
|
||||
kv_cache_fp8,
|
||||
@@ -263,15 +275,18 @@ def test_deepgemm_fp8_paged_mqa_logits():
|
||||
max_model_len,
|
||||
)
|
||||
|
||||
positions = (torch.arange(max_model_len,
|
||||
device="cuda").unsqueeze(0).expand(
|
||||
batch_size * next_n, -1))
|
||||
row_indices = (
|
||||
torch.arange(batch_size * next_n, device="cuda") // next_n)
|
||||
positions = (
|
||||
torch.arange(max_model_len, device="cuda")
|
||||
.unsqueeze(0)
|
||||
.expand(batch_size * next_n, -1)
|
||||
)
|
||||
row_indices = torch.arange(batch_size * next_n, device="cuda") // next_n
|
||||
next_n_offset = (
|
||||
torch.arange(batch_size * next_n, device="cuda") % next_n)
|
||||
mask = positions <= (context_lens[row_indices] - next_n +
|
||||
next_n_offset).unsqueeze(1)
|
||||
torch.arange(batch_size * next_n, device="cuda") % next_n
|
||||
)
|
||||
mask = positions <= (
|
||||
context_lens[row_indices] - next_n + next_n_offset
|
||||
).unsqueeze(1)
|
||||
|
||||
logits = logits.masked_fill(~mask, 0)
|
||||
ref_logits = ref_logits.masked_fill(~mask, 0)
|
||||
|
||||
@@ -7,10 +7,12 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
|
||||
flash_attn_varlen_func,
|
||||
flash_attn_with_kvcache,
|
||||
is_fa_version_supported)
|
||||
from vllm.vllm_flash_attn import (
|
||||
fa_version_unsupported_reason,
|
||||
flash_attn_varlen_func,
|
||||
flash_attn_with_kvcache,
|
||||
is_fa_version_supported,
|
||||
)
|
||||
|
||||
NUM_HEADS = [(4, 4), (8, 2)]
|
||||
HEAD_SIZES = [128, 256]
|
||||
@@ -44,7 +46,7 @@ def ref_paged_attn(
|
||||
for i in range(num_seqs):
|
||||
query_len = query_lens[i]
|
||||
kv_len = kv_lens[i]
|
||||
q = query[start_idx:start_idx + query_len]
|
||||
q = query[start_idx : start_idx + query_len]
|
||||
q *= scale
|
||||
|
||||
num_kv_blocks = (kv_len + block_size - 1) // block_size
|
||||
@@ -62,10 +64,13 @@ def ref_paged_attn(
|
||||
empty_mask = torch.ones(query_len, kv_len)
|
||||
mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
|
||||
if sliding_window is not None:
|
||||
sliding_window_mask = torch.triu(empty_mask,
|
||||
diagonal=kv_len -
|
||||
(query_len + sliding_window) +
|
||||
1).bool().logical_not()
|
||||
sliding_window_mask = (
|
||||
torch.triu(
|
||||
empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1
|
||||
)
|
||||
.bool()
|
||||
.logical_not()
|
||||
)
|
||||
mask |= sliding_window_mask
|
||||
if soft_cap is not None:
|
||||
attn = soft_cap * torch.tanh(attn / soft_cap)
|
||||
@@ -106,11 +111,15 @@ def test_flash_attn_with_paged_kv(
|
||||
) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
if not is_fa_version_supported(fa_version):
|
||||
pytest.skip(f"Flash attention version {fa_version} not supported due "
|
||||
f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
|
||||
pytest.skip(
|
||||
f"Flash attention version {fa_version} not supported due "
|
||||
f'to: "{fa_version_unsupported_reason(fa_version)}"'
|
||||
)
|
||||
if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
|
||||
pytest.skip("Flash attention with quantized inputs is only "
|
||||
"supported on version 3 with bfloat16 base type")
|
||||
pytest.skip(
|
||||
"Flash attention with quantized inputs is only "
|
||||
"supported on version 3 with bfloat16 base type"
|
||||
)
|
||||
|
||||
current_platform.seed_everything(0)
|
||||
num_seqs = len(kv_lens)
|
||||
@@ -119,23 +128,19 @@ def test_flash_attn_with_paged_kv(
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
max_kv_len = max(kv_lens)
|
||||
scale = head_size**-0.5
|
||||
window_size = ((sliding_window - 1, 0) if sliding_window is not None else
|
||||
(-1, -1))
|
||||
window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1)
|
||||
|
||||
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
|
||||
key_cache = torch.randn(num_blocks,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
key_cache = torch.randn(
|
||||
num_blocks, block_size, num_kv_heads, head_size, dtype=dtype
|
||||
)
|
||||
value_cache = torch.randn_like(key_cache)
|
||||
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
|
||||
|
||||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(0,
|
||||
num_blocks,
|
||||
(num_seqs, max_num_blocks_per_seq),
|
||||
dtype=torch.int32)
|
||||
block_tables = torch.randint(
|
||||
0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
|
||||
q = query.unsqueeze(1)
|
||||
out = torch.empty_like(q) if use_out else None
|
||||
@@ -180,23 +185,27 @@ def test_flash_attn_with_paged_kv(
|
||||
if q_dtype is not None:
|
||||
atol, rtol = 1.5e-1, 1.5e-1
|
||||
|
||||
ref_output = ref_paged_attn(query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
query_lens=[1] * num_seqs,
|
||||
kv_lens=kv_lens,
|
||||
block_tables=block_tables,
|
||||
scale=scale,
|
||||
soft_cap=soft_cap,
|
||||
sliding_window=sliding_window)
|
||||
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
|
||||
f"{torch.max(torch.abs(output - ref_output))}"
|
||||
ref_output = ref_paged_attn(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
query_lens=[1] * num_seqs,
|
||||
kv_lens=kv_lens,
|
||||
block_tables=block_tables,
|
||||
scale=scale,
|
||||
soft_cap=soft_cap,
|
||||
sliding_window=sliding_window,
|
||||
)
|
||||
(
|
||||
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol),
|
||||
f"{torch.max(torch.abs(output - ref_output))}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_out", [True, False])
|
||||
@pytest.mark.parametrize("seq_lens",
|
||||
[[(1, 1328), (5, 18),
|
||||
(129, 463)], [(1, 523), (1, 37), (1, 2011)]])
|
||||
@pytest.mark.parametrize(
|
||||
"seq_lens", [[(1, 1328), (5, 18), (129, 463)], [(1, 523), (1, 37), (1, 2011)]]
|
||||
)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@@ -222,11 +231,15 @@ def test_varlen_with_paged_kv(
|
||||
) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
if not is_fa_version_supported(fa_version):
|
||||
pytest.skip(f"Flash attention version {fa_version} not supported due "
|
||||
f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
|
||||
pytest.skip(
|
||||
f"Flash attention version {fa_version} not supported due "
|
||||
f'to: "{fa_version_unsupported_reason(fa_version)}"'
|
||||
)
|
||||
if q_dtype is not None and (dtype != torch.bfloat16 or fa_version == 2):
|
||||
pytest.skip("Flash attention with quantized inputs is only "
|
||||
"supported on version 3 with bfloat16 base type")
|
||||
pytest.skip(
|
||||
"Flash attention with quantized inputs is only "
|
||||
"supported on version 3 with bfloat16 base type"
|
||||
)
|
||||
current_platform.seed_everything(0)
|
||||
num_seqs = len(seq_lens)
|
||||
query_lens = [x[0] for x in seq_lens]
|
||||
@@ -236,30 +249,23 @@ def test_varlen_with_paged_kv(
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
max_query_len = max(query_lens)
|
||||
max_kv_len = max(kv_lens)
|
||||
window_size = ((sliding_window - 1, 0) if sliding_window is not None else
|
||||
(-1, -1))
|
||||
window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1)
|
||||
scale = head_size**-0.5
|
||||
|
||||
query = torch.randn(sum(query_lens),
|
||||
num_query_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
key_cache = torch.randn(num_blocks,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype)
|
||||
key_cache = torch.randn(
|
||||
num_blocks, block_size, num_kv_heads, head_size, dtype=dtype
|
||||
)
|
||||
value_cache = torch.randn_like(key_cache)
|
||||
cu_query_lens = torch.tensor([0] + query_lens,
|
||||
dtype=torch.int32).cumsum(dim=0,
|
||||
dtype=torch.int32)
|
||||
cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum(
|
||||
dim=0, dtype=torch.int32
|
||||
)
|
||||
kv_lens = torch.tensor(kv_lens, dtype=torch.int32)
|
||||
|
||||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(0,
|
||||
num_blocks,
|
||||
(num_seqs, max_num_blocks_per_seq),
|
||||
dtype=torch.int32)
|
||||
block_tables = torch.randint(
|
||||
0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
|
||||
out = torch.empty_like(query) if use_out else None
|
||||
|
||||
@@ -315,5 +321,7 @@ def test_varlen_with_paged_kv(
|
||||
atol, rtol = 1.5e-2, 1e-2
|
||||
if q_dtype is not None:
|
||||
atol, rtol = 1.5e-1, 1.5e-1
|
||||
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
|
||||
f"{torch.max(torch.abs(output - ref_output))}"
|
||||
(
|
||||
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol),
|
||||
f"{torch.max(torch.abs(output - ref_output))}",
|
||||
)
|
||||
|
||||
@@ -38,7 +38,7 @@ def ref_paged_attn(
|
||||
for i in range(num_seqs):
|
||||
query_len = query_lens[i]
|
||||
kv_len = kv_lens[i]
|
||||
q = query[start_idx:start_idx + query_len]
|
||||
q = query[start_idx : start_idx + query_len]
|
||||
q *= scale
|
||||
|
||||
num_kv_blocks = (kv_len + block_size - 1) // block_size
|
||||
@@ -56,10 +56,13 @@ def ref_paged_attn(
|
||||
empty_mask = torch.ones(query_len, kv_len)
|
||||
mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
|
||||
if sliding_window is not None:
|
||||
sliding_window_mask = torch.triu(empty_mask,
|
||||
diagonal=kv_len -
|
||||
(query_len + sliding_window) +
|
||||
1).bool().logical_not()
|
||||
sliding_window_mask = (
|
||||
torch.triu(
|
||||
empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1
|
||||
)
|
||||
.bool()
|
||||
.logical_not()
|
||||
)
|
||||
mask |= sliding_window_mask
|
||||
if soft_cap is not None:
|
||||
attn = soft_cap * torch.tanh(attn / soft_cap)
|
||||
@@ -101,20 +104,16 @@ def test_flashinfer_decode_with_paged_kv(
|
||||
|
||||
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
|
||||
|
||||
key_value_cache = torch.randn(NUM_BLOCKS,
|
||||
2,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
key_value_cache = torch.randn(
|
||||
NUM_BLOCKS, 2, block_size, num_kv_heads, head_size, dtype=dtype
|
||||
)
|
||||
key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
|
||||
value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)
|
||||
|
||||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(0,
|
||||
NUM_BLOCKS,
|
||||
(num_seqs, max_num_blocks_per_seq),
|
||||
dtype=torch.int32)
|
||||
block_tables = torch.randint(
|
||||
0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
|
||||
kv_indptr = [0]
|
||||
kv_indices = []
|
||||
@@ -135,9 +134,9 @@ def test_flashinfer_decode_with_paged_kv(
|
||||
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
|
||||
|
||||
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
|
||||
wrapper = flashinfer.\
|
||||
BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
|
||||
use_tensor_cores=True)
|
||||
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
||||
workspace_buffer, "NHD", use_tensor_cores=True
|
||||
)
|
||||
wrapper.plan(
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
@@ -155,17 +154,21 @@ def test_flashinfer_decode_with_paged_kv(
|
||||
|
||||
output = wrapper.run(query, key_value_cache)
|
||||
|
||||
ref_output = ref_paged_attn(query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
query_lens=[1] * num_seqs,
|
||||
kv_lens=kv_lens,
|
||||
block_tables=block_tables,
|
||||
scale=scale,
|
||||
soft_cap=soft_cap,
|
||||
sliding_window=sliding_window)
|
||||
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
|
||||
f"{torch.max(torch.abs(output - ref_output))}"
|
||||
ref_output = ref_paged_attn(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
query_lens=[1] * num_seqs,
|
||||
kv_lens=kv_lens,
|
||||
block_tables=block_tables,
|
||||
scale=scale,
|
||||
soft_cap=soft_cap,
|
||||
sliding_window=sliding_window,
|
||||
)
|
||||
(
|
||||
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2),
|
||||
f"{torch.max(torch.abs(output - ref_output))}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]])
|
||||
@@ -196,16 +199,10 @@ def test_flashinfer_prefill_with_paged_kv(
|
||||
max_kv_len = max(kv_lens)
|
||||
scale = head_size**-0.5
|
||||
|
||||
query = torch.randn(sum(query_lens),
|
||||
num_query_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
key_value_cache = torch.randn(NUM_BLOCKS,
|
||||
2,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype)
|
||||
key_value_cache = torch.randn(
|
||||
NUM_BLOCKS, 2, block_size, num_kv_heads, head_size, dtype=dtype
|
||||
)
|
||||
key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
|
||||
value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)
|
||||
|
||||
@@ -215,10 +212,9 @@ def test_flashinfer_prefill_with_paged_kv(
|
||||
value_cache /= head_size**0.5
|
||||
|
||||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(0,
|
||||
NUM_BLOCKS,
|
||||
(num_seqs, max_num_blocks_per_seq),
|
||||
dtype=torch.int32)
|
||||
block_tables = torch.randint(
|
||||
0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
|
||||
qo_indptr = [0]
|
||||
kv_indptr = [0]
|
||||
@@ -242,8 +238,7 @@ def test_flashinfer_prefill_with_paged_kv(
|
||||
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
|
||||
|
||||
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
|
||||
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
|
||||
workspace_buffer, "NHD")
|
||||
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
|
||||
wrapper.plan(
|
||||
qo_indptr,
|
||||
kv_indptr,
|
||||
@@ -264,17 +259,21 @@ def test_flashinfer_prefill_with_paged_kv(
|
||||
key_value_cache,
|
||||
)
|
||||
|
||||
ref_output = ref_paged_attn(query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
query_lens=query_lens,
|
||||
kv_lens=kv_lens,
|
||||
block_tables=block_tables,
|
||||
scale=scale,
|
||||
soft_cap=soft_cap,
|
||||
sliding_window=sliding_window)
|
||||
torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \
|
||||
f"{torch.max(torch.abs(output - ref_output))}"
|
||||
ref_output = ref_paged_attn(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
query_lens=query_lens,
|
||||
kv_lens=kv_lens,
|
||||
block_tables=block_tables,
|
||||
scale=scale,
|
||||
soft_cap=soft_cap,
|
||||
sliding_window=sliding_window,
|
||||
)
|
||||
(
|
||||
torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2),
|
||||
f"{torch.max(torch.abs(output - ref_output))}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_lens", [[(1, 132), (5, 18)]])
|
||||
@@ -284,9 +283,13 @@ def test_flashinfer_prefill_with_paged_kv(
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
|
||||
def test_flashinfer_prefill_with_paged_fp8_kv(
|
||||
seq_lens: list[tuple[int, int]], num_heads: tuple[int, int],
|
||||
head_size: int, dtype: torch.dtype, block_size: int,
|
||||
soft_cap: Optional[float]) -> None:
|
||||
seq_lens: list[tuple[int, int]],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: Optional[float],
|
||||
) -> None:
|
||||
pytest.skip("TODO: fix the accuracy issue")
|
||||
torch.set_default_device("cuda")
|
||||
current_platform.seed_everything(0)
|
||||
@@ -301,17 +304,11 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
|
||||
|
||||
kv_cache_dtype = torch.float8_e4m3fn
|
||||
|
||||
query = torch.randn(sum(query_lens),
|
||||
num_query_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype)
|
||||
NUM_BLOCKS_FP8 = 2048
|
||||
key_value_cache = torch.randn(NUM_BLOCKS_FP8,
|
||||
2,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
key_value_cache = torch.randn(
|
||||
NUM_BLOCKS_FP8, 2, block_size, num_kv_heads, head_size, dtype=dtype
|
||||
)
|
||||
key_cache, value_cache = torch.chunk(key_value_cache, 2, dim=1)
|
||||
key_cache /= head_size**0.5
|
||||
value_cache /= head_size**0.5
|
||||
@@ -319,15 +316,15 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
|
||||
k_scale = key_cache.amax().item() / 448.0
|
||||
v_scale = value_cache.amax().item() / 448.0
|
||||
|
||||
kv_cache_fp8 = torch.cat([key_cache / k_scale, value_cache / v_scale],
|
||||
dim=1).to(kv_cache_dtype)
|
||||
kv_cache_fp8 = torch.cat([key_cache / k_scale, value_cache / v_scale], dim=1).to(
|
||||
kv_cache_dtype
|
||||
)
|
||||
|
||||
assert (kv_cache_fp8.shape == key_value_cache.shape)
|
||||
assert kv_cache_fp8.shape == key_value_cache.shape
|
||||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(0,
|
||||
NUM_BLOCKS_FP8,
|
||||
(num_seqs, max_num_blocks_per_seq),
|
||||
dtype=torch.int32)
|
||||
block_tables = torch.randint(
|
||||
0, NUM_BLOCKS_FP8, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
|
||||
qo_indptr = [0]
|
||||
kv_indptr = [0]
|
||||
@@ -351,8 +348,7 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
|
||||
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
|
||||
|
||||
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
|
||||
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
|
||||
workspace_buffer, "NHD")
|
||||
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
|
||||
wrapper.plan(
|
||||
qo_indptr,
|
||||
kv_indptr,
|
||||
@@ -369,19 +365,23 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
|
||||
|
||||
output = wrapper.run(query, kv_cache_fp8, k_scale=k_scale, v_scale=v_scale)
|
||||
|
||||
ref_output = ref_paged_attn(query=query,
|
||||
key_cache=key_cache.squeeze(1),
|
||||
value_cache=value_cache.squeeze(1),
|
||||
query_lens=query_lens,
|
||||
kv_lens=kv_lens,
|
||||
block_tables=block_tables,
|
||||
scale=scale,
|
||||
soft_cap=soft_cap)
|
||||
ref_output = ref_paged_attn(
|
||||
query=query,
|
||||
key_cache=key_cache.squeeze(1),
|
||||
value_cache=value_cache.squeeze(1),
|
||||
query_lens=query_lens,
|
||||
kv_lens=kv_lens,
|
||||
block_tables=block_tables,
|
||||
scale=scale,
|
||||
soft_cap=soft_cap,
|
||||
)
|
||||
del query
|
||||
del block_tables
|
||||
# verify prefill fp8
|
||||
torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \
|
||||
f"{torch.max(torch.abs(output - ref_output))}"
|
||||
(
|
||||
torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2),
|
||||
f"{torch.max(torch.abs(output - ref_output))}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
|
||||
@@ -414,12 +414,9 @@ def test_flashinfer_decode_with_paged_fp8_kv(
|
||||
|
||||
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
|
||||
NUM_BLOCKS_FP8 = 2048
|
||||
key_value_cache = torch.randn(NUM_BLOCKS_FP8,
|
||||
2,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
key_value_cache = torch.randn(
|
||||
NUM_BLOCKS_FP8, 2, block_size, num_kv_heads, head_size, dtype=dtype
|
||||
)
|
||||
key_cache, value_cache = torch.chunk(key_value_cache, 2, dim=1)
|
||||
key_cache /= head_size**0.5
|
||||
value_cache /= head_size**0.5
|
||||
@@ -429,14 +426,13 @@ def test_flashinfer_decode_with_paged_fp8_kv(
|
||||
|
||||
key_cache_fp8 = (key_cache / k_scale).to(kv_cache_dtype)
|
||||
value_cache_fp8 = (value_cache / v_scale).to(kv_cache_dtype)
|
||||
assert (key_cache_fp8.shape[1] == 1 and value_cache_fp8.shape[1] == 1)
|
||||
assert key_cache_fp8.shape[1] == 1 and value_cache_fp8.shape[1] == 1
|
||||
kv_cache_fp8 = torch.cat([key_cache_fp8, value_cache_fp8], dim=1)
|
||||
|
||||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(0,
|
||||
NUM_BLOCKS_FP8,
|
||||
(num_seqs, max_num_blocks_per_seq),
|
||||
dtype=torch.int32)
|
||||
block_tables = torch.randint(
|
||||
0, NUM_BLOCKS_FP8, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
|
||||
kv_indptr = [0]
|
||||
kv_indices = []
|
||||
@@ -457,32 +453,38 @@ def test_flashinfer_decode_with_paged_fp8_kv(
|
||||
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
|
||||
|
||||
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
|
||||
wrapper = flashinfer.\
|
||||
BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
|
||||
use_tensor_cores=use_tensor_cores)
|
||||
wrapper.plan(kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
num_query_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
"NONE",
|
||||
q_data_type=dtype,
|
||||
kv_data_type=kv_cache_dtype,
|
||||
logits_soft_cap=soft_cap)
|
||||
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
||||
workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores
|
||||
)
|
||||
wrapper.plan(
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
num_query_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
"NONE",
|
||||
q_data_type=dtype,
|
||||
kv_data_type=kv_cache_dtype,
|
||||
logits_soft_cap=soft_cap,
|
||||
)
|
||||
output = wrapper.run(query, kv_cache_fp8, k_scale=k_scale, v_scale=v_scale)
|
||||
key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
|
||||
value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)
|
||||
|
||||
ref_output = ref_paged_attn(query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
query_lens=[1] * num_seqs,
|
||||
kv_lens=kv_lens,
|
||||
block_tables=block_tables,
|
||||
scale=scale,
|
||||
soft_cap=soft_cap)
|
||||
ref_output = ref_paged_attn(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
query_lens=[1] * num_seqs,
|
||||
kv_lens=kv_lens,
|
||||
block_tables=block_tables,
|
||||
scale=scale,
|
||||
soft_cap=soft_cap,
|
||||
)
|
||||
# Temporary fix: Increasing the tolerance. Seems like a flashinfer issue
|
||||
torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
|
||||
f"{torch.max(torch.abs(output - ref_output))}"
|
||||
(
|
||||
torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2),
|
||||
f"{torch.max(torch.abs(output - ref_output))}",
|
||||
)
|
||||
|
||||
@@ -13,34 +13,29 @@ FLASHINFER_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
|
||||
if not current_platform.has_device_capability(100):
|
||||
pytest.skip(
|
||||
reason="FlashInfer MLA Requires compute capability of 10 or above.",
|
||||
allow_module_level=True)
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
|
||||
def ref_mla(
|
||||
out: Tensor, # (bs, num_heads, v_head_dim)
|
||||
query: Tensor, # (bs, num_heads, head_dim)
|
||||
kv_cache: Tensor, # (num_blocks, block_size, head_dim)
|
||||
scale: float,
|
||||
block_tables: Tensor, # (bs, max_num_blocks)
|
||||
seq_lens: Tensor, # (bs,)
|
||||
out: Tensor, # (bs, num_heads, v_head_dim)
|
||||
query: Tensor, # (bs, num_heads, head_dim)
|
||||
kv_cache: Tensor, # (num_blocks, block_size, head_dim)
|
||||
scale: float,
|
||||
block_tables: Tensor, # (bs, max_num_blocks)
|
||||
seq_lens: Tensor, # (bs,)
|
||||
):
|
||||
bs, num_heads, v_head_dim = out.shape
|
||||
head_dim = query.shape[2]
|
||||
|
||||
for i in range(bs):
|
||||
# gather and flatten KV-cache
|
||||
kv = kv_cache[
|
||||
block_tables[i]] # (max_num_blocks, block_size, head_dim)
|
||||
kv = kv.view(1, -1,
|
||||
head_dim)[:, :seq_lens[i]] # (1, seq_len, head_dim)
|
||||
kv = kv_cache[block_tables[i]] # (max_num_blocks, block_size, head_dim)
|
||||
kv = kv.view(1, -1, head_dim)[:, : seq_lens[i]] # (1, seq_len, head_dim)
|
||||
v = kv[:, :, :v_head_dim]
|
||||
|
||||
q = query[i].view(num_heads, 1, head_dim)
|
||||
o = F.scaled_dot_product_attention(q,
|
||||
kv,
|
||||
v,
|
||||
scale=scale,
|
||||
enable_gqa=True)
|
||||
o = F.scaled_dot_product_attention(q, kv, v, scale=scale, enable_gqa=True)
|
||||
out[i] = o.view(num_heads, v_head_dim)
|
||||
|
||||
return out
|
||||
@@ -50,7 +45,7 @@ def ref_mla(
|
||||
@pytest.mark.parametrize("bs", [1, 2, 4, 16])
|
||||
@pytest.mark.parametrize("block_size", [32, 64])
|
||||
def test_flashinfer_mla_decode(dtype: torch.dtype, bs: int, block_size: int):
|
||||
torch.set_default_device('cuda')
|
||||
torch.set_default_device("cuda")
|
||||
torch.manual_seed(42)
|
||||
|
||||
# Deepseek R1 config
|
||||
@@ -59,11 +54,11 @@ def test_flashinfer_mla_decode(dtype: torch.dtype, bs: int, block_size: int):
|
||||
qk_nope_head_dim = 128
|
||||
qk_rope_head_dim = 64
|
||||
qk_head_dim = kv_lora_rank + qk_rope_head_dim
|
||||
scale = (qk_nope_head_dim + qk_rope_head_dim)**-0.5
|
||||
scale = (qk_nope_head_dim + qk_rope_head_dim) ** -0.5
|
||||
|
||||
MAX_SEQ_LEN = 1024
|
||||
|
||||
seq_lens = [torch.randint(2, MAX_SEQ_LEN, (1, )).item() for _ in range(bs)]
|
||||
seq_lens = [torch.randint(2, MAX_SEQ_LEN, (1,)).item() for _ in range(bs)]
|
||||
seq_lens[-1] = MAX_SEQ_LEN
|
||||
max_seq_len = max(seq_lens)
|
||||
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32)
|
||||
@@ -86,12 +81,12 @@ def test_flashinfer_mla_decode(dtype: torch.dtype, bs: int, block_size: int):
|
||||
block_id = 0
|
||||
for i in range(bs):
|
||||
num_blocks_needed = blocks_per_seq[i]
|
||||
block_tables[i, :num_blocks_needed] = all_block_ids[block_id:block_id +
|
||||
num_blocks_needed]
|
||||
block_tables[i, :num_blocks_needed] = all_block_ids[
|
||||
block_id : block_id + num_blocks_needed
|
||||
]
|
||||
block_id += num_blocks_needed
|
||||
|
||||
kv_cache = torch.randn(block_tables.numel(), block_size,
|
||||
qk_head_dim).to(dtype)
|
||||
kv_cache = torch.randn(block_tables.numel(), block_size, qk_head_dim).to(dtype)
|
||||
q = torch.randn(bs, num_heads, qk_head_dim).to(dtype)
|
||||
|
||||
out_ref = q.new_zeros(bs, num_heads, kv_lora_rank)
|
||||
|
||||
@@ -6,15 +6,18 @@ import flashinfer
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
|
||||
FLOAT8_E4M3_MAX,
|
||||
dequantize_nvfp4_to_dtype)
|
||||
from tests.kernels.quantization.nvfp4_utils import (
|
||||
FLOAT4_E2M1_MAX,
|
||||
FLOAT8_E4M3_MAX,
|
||||
dequantize_nvfp4_to_dtype,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import round_up
|
||||
|
||||
if not current_platform.is_device_capability(100):
|
||||
pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.",
|
||||
allow_module_level=True)
|
||||
pytest.skip(
|
||||
"This TRTLLM kernel requires NVIDIA Blackwell.", allow_module_level=True
|
||||
)
|
||||
|
||||
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
@@ -64,8 +67,9 @@ NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
|
||||
@torch.inference_mode
|
||||
def test_flashinfer_trtllm_decode_with_baseline(
|
||||
dtype: torch.dtype,
|
||||
quant_dtypes: tuple[Optional[torch.dtype], Optional[torch.dtype],
|
||||
Optional[torch.dtype]],
|
||||
quant_dtypes: tuple[
|
||||
Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
|
||||
],
|
||||
batch_size: int,
|
||||
max_seq_lens: tuple[int, int],
|
||||
num_heads: tuple[int, int],
|
||||
@@ -106,7 +110,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
||||
q_scale = 1.0
|
||||
ref_query = query
|
||||
|
||||
kv_lens = torch.randint(1, max_kv_len, (batch_size, ), dtype=torch.int32)
|
||||
kv_lens = torch.randint(1, max_kv_len, (batch_size,), dtype=torch.int32)
|
||||
kv_lens[-1] = max_kv_len
|
||||
|
||||
seq_lens = kv_lens
|
||||
@@ -122,10 +126,9 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
||||
k_scale = v_scale = kv_scale
|
||||
|
||||
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(0,
|
||||
NUM_BLOCKS,
|
||||
(batch_size, max_num_blocks_per_seq),
|
||||
dtype=torch.int32)
|
||||
block_tables = torch.randint(
|
||||
0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
kv_indptr = [0]
|
||||
kv_indices = []
|
||||
kv_last_page_lens = []
|
||||
@@ -147,20 +150,23 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
||||
|
||||
# Baseline Decode
|
||||
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
||||
workspace_buffer, kv_layout, use_tensor_cores=True)
|
||||
wrapper.plan(kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
"NONE",
|
||||
sm_scale=sm_scale,
|
||||
q_data_type=dtype,
|
||||
kv_data_type=dtype,
|
||||
window_left=window_left,
|
||||
logits_soft_cap=soft_cap)
|
||||
workspace_buffer, kv_layout, use_tensor_cores=True
|
||||
)
|
||||
wrapper.plan(
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
"NONE",
|
||||
sm_scale=sm_scale,
|
||||
q_data_type=dtype,
|
||||
kv_data_type=dtype,
|
||||
window_left=window_left,
|
||||
logits_soft_cap=soft_cap,
|
||||
)
|
||||
|
||||
output = torch.empty(ref_query.shape, dtype=dtype)
|
||||
wrapper.run(ref_query, ref_kv_cache, out=output)
|
||||
@@ -169,17 +175,21 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
||||
if o_quant_dtype == FP8_DTYPE:
|
||||
_, o_scale = to_float8(output)
|
||||
elif o_quant_dtype == FP4_DTYPE:
|
||||
o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
|
||||
torch.amax(output.flatten(), dim=-1)).to(torch.float32)
|
||||
o_sf_scale = (
|
||||
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(output.flatten(), dim=-1)
|
||||
).to(torch.float32)
|
||||
|
||||
# TRTLLM Decode
|
||||
if o_quant_dtype == FP4_DTYPE:
|
||||
output_trtllm = flashinfer.utils.FP4Tensor(
|
||||
torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ),
|
||||
dtype=torch.uint8),
|
||||
torch.empty((round_up(query.shape[0], 128),
|
||||
round_up(query.shape[1] * query.shape[2] // 16, 4)),
|
||||
dtype=torch.float8_e4m3fn),
|
||||
torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
|
||||
torch.empty(
|
||||
(
|
||||
round_up(query.shape[0], 128),
|
||||
round_up(query.shape[1] * query.shape[2] // 16, 4),
|
||||
),
|
||||
dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
)
|
||||
else:
|
||||
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
||||
@@ -201,13 +211,12 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
||||
output_trtllm = output_trtllm.to(dtype) * o_scale
|
||||
elif o_quant_dtype == FP4_DTYPE:
|
||||
output_trtllm.data = output_trtllm.data.reshape(
|
||||
-1, query.shape[1] * query.shape[2] // 2)
|
||||
output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data,
|
||||
output_trtllm.scale,
|
||||
o_sf_scale, dtype,
|
||||
query.device)
|
||||
output_trtllm = output_trtllm.reshape(-1, query.shape[1],
|
||||
query.shape[2])
|
||||
-1, query.shape[1] * query.shape[2] // 2
|
||||
)
|
||||
output_trtllm = dequantize_nvfp4_to_dtype(
|
||||
output_trtllm.data, output_trtllm.scale, o_sf_scale, dtype, query.device
|
||||
)
|
||||
output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2])
|
||||
|
||||
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
|
||||
rtol, atol = 3e-1, 1e0
|
||||
@@ -216,8 +225,10 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
||||
else:
|
||||
rtol, atol = 1e-2, 2e-2
|
||||
|
||||
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \
|
||||
f"{torch.max(torch.abs(output - output_trtllm))}"
|
||||
(
|
||||
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol),
|
||||
f"{torch.max(torch.abs(output - output_trtllm))}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", DTYPE)
|
||||
@@ -233,8 +244,9 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
||||
@torch.inference_mode
|
||||
def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
dtype: torch.dtype,
|
||||
quant_dtypes: tuple[Optional[torch.dtype], Optional[torch.dtype],
|
||||
Optional[torch.dtype]],
|
||||
quant_dtypes: tuple[
|
||||
Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
|
||||
],
|
||||
batch_size: int,
|
||||
max_seq_lens: tuple[int, int],
|
||||
num_heads: tuple[int, int],
|
||||
@@ -270,17 +282,16 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
else:
|
||||
raise ValueError(f"Invalid kv_layout: {kv_layout}")
|
||||
|
||||
q_lens = torch.randint(1, max_q_len, (batch_size, ), dtype=torch.int32)
|
||||
q_lens = torch.randint(1, max_q_len, (batch_size,), dtype=torch.int32)
|
||||
q_lens[-1] = max_q_len
|
||||
q_indptr = torch.cat([
|
||||
torch.tensor([0], dtype=torch.int32),
|
||||
torch.cumsum(q_lens, dim=0, dtype=torch.int32),
|
||||
])
|
||||
q_indptr = torch.cat(
|
||||
[
|
||||
torch.tensor([0], dtype=torch.int32),
|
||||
torch.cumsum(q_lens, dim=0, dtype=torch.int32),
|
||||
]
|
||||
)
|
||||
|
||||
query = torch.randn(torch.sum(q_lens).item(),
|
||||
num_qo_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
query = torch.randn(torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype)
|
||||
if q_quant_dtype == FP8_DTYPE:
|
||||
query, q_scale = to_float8(query)
|
||||
ref_query = query.to(dtype) * q_scale
|
||||
@@ -288,7 +299,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
q_scale = 1.0
|
||||
ref_query = query
|
||||
|
||||
kv_lens = torch.randint(0, max_kv_len, (batch_size, ), dtype=torch.int32)
|
||||
kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32)
|
||||
kv_lens[-1] = max_kv_len
|
||||
|
||||
seq_lens = kv_lens + q_lens
|
||||
@@ -304,10 +315,9 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
k_scale = v_scale = kv_scale
|
||||
|
||||
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(0,
|
||||
NUM_BLOCKS,
|
||||
(batch_size, max_num_blocks_per_seq),
|
||||
dtype=torch.int32)
|
||||
block_tables = torch.randint(
|
||||
0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
kv_indptr = [0]
|
||||
kv_indices = []
|
||||
kv_last_page_lens = []
|
||||
@@ -329,21 +339,24 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
|
||||
# Baseline Prefill
|
||||
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
|
||||
workspace_buffer, kv_layout)
|
||||
wrapper.plan(q_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
causal=True,
|
||||
sm_scale=sm_scale,
|
||||
q_data_type=dtype,
|
||||
kv_data_type=dtype,
|
||||
window_left=window_left,
|
||||
logits_soft_cap=soft_cap)
|
||||
workspace_buffer, kv_layout
|
||||
)
|
||||
wrapper.plan(
|
||||
q_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
causal=True,
|
||||
sm_scale=sm_scale,
|
||||
q_data_type=dtype,
|
||||
kv_data_type=dtype,
|
||||
window_left=window_left,
|
||||
logits_soft_cap=soft_cap,
|
||||
)
|
||||
|
||||
output = torch.empty(ref_query.shape, dtype=dtype)
|
||||
wrapper.run(ref_query, ref_kv_cache, out=output)
|
||||
@@ -352,17 +365,21 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
if o_quant_dtype == FP8_DTYPE:
|
||||
_, o_scale = to_float8(output)
|
||||
elif o_quant_dtype == FP4_DTYPE:
|
||||
o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
|
||||
torch.amax(output.flatten(), dim=-1)).to(torch.float32)
|
||||
o_sf_scale = (
|
||||
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(output.flatten(), dim=-1)
|
||||
).to(torch.float32)
|
||||
|
||||
# TRTLLM Prefill
|
||||
if o_quant_dtype == FP4_DTYPE:
|
||||
output_trtllm = flashinfer.utils.FP4Tensor(
|
||||
torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ),
|
||||
dtype=torch.uint8),
|
||||
torch.empty((round_up(query.shape[0], 128),
|
||||
round_up(query.shape[1] * query.shape[2] // 16, 4)),
|
||||
dtype=torch.float8_e4m3fn),
|
||||
torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
|
||||
torch.empty(
|
||||
(
|
||||
round_up(query.shape[0], 128),
|
||||
round_up(query.shape[1] * query.shape[2] // 16, 4),
|
||||
),
|
||||
dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
)
|
||||
else:
|
||||
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
||||
@@ -388,13 +405,12 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
output_trtllm = output_trtllm.to(dtype) * o_scale
|
||||
elif o_quant_dtype == FP4_DTYPE:
|
||||
output_trtllm.data = output_trtllm.data.reshape(
|
||||
-1, query.shape[1] * query.shape[2] // 2)
|
||||
output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data,
|
||||
output_trtllm.scale,
|
||||
o_sf_scale, dtype,
|
||||
query.device)
|
||||
output_trtllm = output_trtllm.reshape(-1, query.shape[1],
|
||||
query.shape[2])
|
||||
-1, query.shape[1] * query.shape[2] // 2
|
||||
)
|
||||
output_trtllm = dequantize_nvfp4_to_dtype(
|
||||
output_trtllm.data, output_trtllm.scale, o_sf_scale, dtype, query.device
|
||||
)
|
||||
output_trtllm = output_trtllm.reshape(-1, query.shape[1], query.shape[2])
|
||||
|
||||
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
|
||||
rtol, atol = 4e-1, 1e0
|
||||
@@ -405,5 +421,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
else:
|
||||
rtol, atol = 1e-2, 1e-2
|
||||
|
||||
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \
|
||||
f"{torch.max(torch.abs(output - output_trtllm))}"
|
||||
(
|
||||
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol),
|
||||
f"{torch.max(torch.abs(output - output_trtllm))}",
|
||||
)
|
||||
|
||||
@@ -7,30 +7,33 @@ import random
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
|
||||
get_mla_metadata,
|
||||
is_flashmla_supported)
|
||||
from vllm.attention.ops.flashmla import (
|
||||
flash_mla_with_kvcache,
|
||||
get_mla_metadata,
|
||||
is_flashmla_supported,
|
||||
)
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
|
||||
def cal_diff(x: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
name: str,
|
||||
use_fp8: bool = False) -> None:
|
||||
def cal_diff(
|
||||
x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool = False
|
||||
) -> None:
|
||||
x, y = x.double(), y.double()
|
||||
cos_diff = 1 - 2 * (x * y).sum().item() / max(
|
||||
(x * x + y * y).sum().item(), 1e-12)
|
||||
if (use_fp8):
|
||||
cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
|
||||
if use_fp8:
|
||||
assert cos_diff < 1e-4
|
||||
else:
|
||||
assert cos_diff < 1e-5
|
||||
|
||||
FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
|
||||
if not is_flashmla_supported()[0] else "FlashMLA is supported"
|
||||
|
||||
FLASH_MLA_UNSUPPORTED_REASON = (
|
||||
is_flashmla_supported()[1]
|
||||
if not is_flashmla_supported()[0]
|
||||
else "FlashMLA is supported"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_flashmla_supported()[0],
|
||||
reason=FLASH_MLA_UNSUPPORTED_REASON)
|
||||
@pytest.mark.skipif(not is_flashmla_supported()[0], reason=FLASH_MLA_UNSUPPORTED_REASON)
|
||||
@pytest.mark.parametrize("b", [128])
|
||||
@pytest.mark.parametrize("s_q", [1, 2])
|
||||
@pytest.mark.parametrize("mean_sk", [4096, 8192, 16384])
|
||||
@@ -41,11 +44,13 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
|
||||
@pytest.mark.parametrize("block_size", [64])
|
||||
@pytest.mark.parametrize("causal", [True])
|
||||
@pytest.mark.parametrize("varlen", [False, True])
|
||||
@pytest.mark.parametrize("torch_dtype",
|
||||
[torch.bfloat16, torch.float16, torch.float8_e4m3fn])
|
||||
@pytest.mark.parametrize(
|
||||
"torch_dtype", [torch.bfloat16, torch.float16, torch.float8_e4m3fn]
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
|
||||
varlen, torch_dtype):
|
||||
def test_flash_mla(
|
||||
b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, varlen, torch_dtype
|
||||
):
|
||||
device = torch.device("cuda:0")
|
||||
if torch_dtype == torch.float8_e4m3fn:
|
||||
init_dtype = torch.bfloat16
|
||||
@@ -57,31 +62,34 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
|
||||
torch.manual_seed(0)
|
||||
random.seed(0)
|
||||
|
||||
print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, "
|
||||
f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}")
|
||||
print(
|
||||
f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, "
|
||||
f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}"
|
||||
)
|
||||
|
||||
use_fp8 = torch_dtype == torch.float8_e4m3fn
|
||||
cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32)
|
||||
cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
|
||||
if varlen:
|
||||
for i in range(b):
|
||||
cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2),
|
||||
s_q)
|
||||
cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q)
|
||||
total_seqlens = cache_seqlens.sum().item()
|
||||
max_seqlen = cache_seqlens.max().item()
|
||||
max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
|
||||
|
||||
q = torch.randn(b, s_q, h_q, d)
|
||||
block_table = torch.arange(b * max_seqlen_pad // block_size,
|
||||
dtype=torch.int32).view(
|
||||
b, max_seqlen_pad // block_size)
|
||||
block_table = torch.arange(
|
||||
b * max_seqlen_pad // block_size, dtype=torch.int32
|
||||
).view(b, max_seqlen_pad // block_size)
|
||||
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
|
||||
for i in range(b):
|
||||
blocked_k.view(b, max_seqlen_pad, h_kv,
|
||||
d)[i, cache_seqlens[i].item():] = float("nan")
|
||||
blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item() :] = (
|
||||
float("nan")
|
||||
)
|
||||
blocked_v = blocked_k[..., :dv]
|
||||
|
||||
tile_scheduler_metadata, num_splits = get_mla_metadata(
|
||||
cache_seqlens, s_q * h_q // h_kv, h_kv)
|
||||
cache_seqlens, s_q * h_q // h_kv, h_kv
|
||||
)
|
||||
|
||||
init_dtype = q.dtype
|
||||
if use_fp8:
|
||||
@@ -97,16 +105,18 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
|
||||
descale_k = None
|
||||
|
||||
def flash_mla():
|
||||
return flash_mla_with_kvcache(q,
|
||||
blocked_k,
|
||||
block_table,
|
||||
cache_seqlens,
|
||||
dv,
|
||||
tile_scheduler_metadata,
|
||||
num_splits,
|
||||
causal=causal,
|
||||
descale_q=descale_q,
|
||||
descale_k=descale_k)
|
||||
return flash_mla_with_kvcache(
|
||||
q,
|
||||
blocked_k,
|
||||
block_table,
|
||||
cache_seqlens,
|
||||
dv,
|
||||
tile_scheduler_metadata,
|
||||
num_splits,
|
||||
causal=causal,
|
||||
descale_q=descale_q,
|
||||
descale_k=descale_k,
|
||||
)
|
||||
|
||||
def scaled_dot_product_attention(query, key, value, is_causal=False):
|
||||
query = query.float()
|
||||
@@ -119,8 +129,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
|
||||
s_q = query.shape[-2]
|
||||
s_k = key.shape[-2]
|
||||
attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)
|
||||
temp_mask = torch.ones(s_q, s_k,
|
||||
dtype=torch.bool).tril(diagonal=s_k - s_q)
|
||||
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
|
||||
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
|
||||
attn_bias.to(query.dtype)
|
||||
attn_weight += attn_bias
|
||||
@@ -130,10 +139,16 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
|
||||
|
||||
def ref_mla():
|
||||
q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q
|
||||
blocked_k_ = (blocked_k.to(torch.float) *
|
||||
descale_k).to(init_dtype) if use_fp8 else blocked_k
|
||||
blocked_v_ = (blocked_v.to(torch.float) *
|
||||
descale_k).to(init_dtype) if use_fp8 else blocked_v
|
||||
blocked_k_ = (
|
||||
(blocked_k.to(torch.float) * descale_k).to(init_dtype)
|
||||
if use_fp8
|
||||
else blocked_k
|
||||
)
|
||||
blocked_v_ = (
|
||||
(blocked_v.to(torch.float) * descale_k).to(init_dtype)
|
||||
if use_fp8
|
||||
else blocked_v
|
||||
)
|
||||
out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
|
||||
lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
|
||||
for i in range(b):
|
||||
@@ -156,8 +171,9 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
|
||||
|
||||
t = triton.testing.do_bench(flash_mla)
|
||||
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
|
||||
bytes = (total_seqlens * h_kv * d +
|
||||
b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + (
|
||||
b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8)
|
||||
print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS,",
|
||||
f"{bytes / 10 ** 6 / t:.0f} GB/s")
|
||||
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d) * (
|
||||
torch.finfo(torch_dtype).bits // 8
|
||||
) + (b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8)
|
||||
print(
|
||||
f"{t:.3f} ms, {FLOPS / 10**9 / t:.0f} TFLOPS,", f"{bytes / 10**6 / t:.0f} GB/s"
|
||||
)
|
||||
|
||||
@@ -13,6 +13,7 @@ def _cuda_sm90_available() -> bool:
|
||||
|
||||
def test_sparse_flashmla_metadata_smoke():
|
||||
import vllm.attention.ops.flashmla as fm
|
||||
|
||||
ok, reason = fm.is_flashmla_supported()
|
||||
if not ok or not _cuda_sm90_available():
|
||||
pytest.skip(reason or "SM90 not available")
|
||||
@@ -27,18 +28,21 @@ def test_sparse_flashmla_metadata_smoke():
|
||||
|
||||
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
|
||||
|
||||
tile_md, num_splits = fm.get_mla_metadata(cache_seqlens,
|
||||
q_seq_per_hk,
|
||||
num_heads_k,
|
||||
num_heads_q=num_heads_q,
|
||||
topk=topk,
|
||||
is_fp8_kvcache=True)
|
||||
tile_md, num_splits = fm.get_mla_metadata(
|
||||
cache_seqlens,
|
||||
q_seq_per_hk,
|
||||
num_heads_k,
|
||||
num_heads_q=num_heads_q,
|
||||
topk=topk,
|
||||
is_fp8_kvcache=True,
|
||||
)
|
||||
assert tile_md.dtype == torch.int32
|
||||
assert num_splits.dtype == torch.int32
|
||||
|
||||
|
||||
def test_sparse_flashmla_decode_smoke():
|
||||
import vllm.attention.ops.flashmla as fm
|
||||
|
||||
ok, reason = fm.is_flashmla_supported()
|
||||
if not ok or not _cuda_sm90_available():
|
||||
pytest.skip(reason or "SM90 not available")
|
||||
@@ -58,36 +62,42 @@ def test_sparse_flashmla_decode_smoke():
|
||||
q_seq_per_hk = seqlen_q * num_heads_q // num_heads_k
|
||||
# q_heads_per_hk = num_heads_q // num_heads_k
|
||||
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device)
|
||||
tile_md, num_splits = fm.get_mla_metadata(cache_seqlens,
|
||||
q_seq_per_hk,
|
||||
num_heads_k,
|
||||
num_heads_q=num_heads_q,
|
||||
topk=topk,
|
||||
is_fp8_kvcache=True)
|
||||
tile_md, num_splits = fm.get_mla_metadata(
|
||||
cache_seqlens,
|
||||
q_seq_per_hk,
|
||||
num_heads_k,
|
||||
num_heads_q=num_heads_q,
|
||||
topk=topk,
|
||||
is_fp8_kvcache=True,
|
||||
)
|
||||
|
||||
# Inputs
|
||||
q = torch.zeros((batch_size, seqlen_q, num_heads_q, head_dim_k),
|
||||
dtype=torch.bfloat16,
|
||||
device=device)
|
||||
k_cache = torch.zeros((1, page_block_size, num_heads_k, bytes_per_token),
|
||||
dtype=torch.uint8,
|
||||
device=device)
|
||||
indices = torch.zeros((batch_size, seqlen_q, topk),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
q = torch.zeros(
|
||||
(batch_size, seqlen_q, num_heads_q, head_dim_k),
|
||||
dtype=torch.bfloat16,
|
||||
device=device,
|
||||
)
|
||||
k_cache = torch.zeros(
|
||||
(1, page_block_size, num_heads_k, bytes_per_token),
|
||||
dtype=torch.uint8,
|
||||
device=device,
|
||||
)
|
||||
indices = torch.zeros(
|
||||
(batch_size, seqlen_q, topk), dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
block_table = torch.zeros((batch_size, 128),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
out, lse = fm.flash_mla_with_kvcache(q,
|
||||
k_cache,
|
||||
block_table,
|
||||
cache_seqlens,
|
||||
head_dim_v,
|
||||
tile_md,
|
||||
num_splits,
|
||||
indices=indices,
|
||||
is_fp8_kvcache=True)
|
||||
block_table = torch.zeros((batch_size, 128), dtype=torch.int32, device=device)
|
||||
out, lse = fm.flash_mla_with_kvcache(
|
||||
q,
|
||||
k_cache,
|
||||
block_table,
|
||||
cache_seqlens,
|
||||
head_dim_v,
|
||||
tile_md,
|
||||
num_splits,
|
||||
indices=indices,
|
||||
is_fp8_kvcache=True,
|
||||
)
|
||||
assert out.shape[0] == batch_size
|
||||
assert out.shape[-1] == head_dim_v
|
||||
assert lse.shape[0] == batch_size
|
||||
@@ -95,6 +105,7 @@ def test_sparse_flashmla_decode_smoke():
|
||||
|
||||
def test_sparse_flashmla_prefill_smoke():
|
||||
import vllm.attention.ops.flashmla as fm
|
||||
|
||||
ok, reason = fm.is_flashmla_supported()
|
||||
if not ok or not _cuda_sm90_available():
|
||||
pytest.skip(reason or "SM90 not available")
|
||||
@@ -112,8 +123,7 @@ def test_sparse_flashmla_prefill_smoke():
|
||||
kv = torch.zeros((s_kv, h_kv, d_qk), dtype=torch.bfloat16, device=device)
|
||||
indices = torch.zeros((s_q, h_kv, topk), dtype=torch.int32, device=device)
|
||||
|
||||
out, max_logits, lse = fm.flash_mla_sparse_prefill(q, kv, indices, 1.0,
|
||||
d_v)
|
||||
out, max_logits, lse = fm.flash_mla_sparse_prefill(q, kv, indices, 1.0, d_v)
|
||||
assert out.shape == (s_q, h_q, d_v)
|
||||
assert max_logits.shape == (s_q, h_q)
|
||||
assert lse.shape == (s_q, h_q)
|
||||
|
||||
@@ -4,8 +4,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.lightning_attn import (
|
||||
linear_decode_forward_triton)
|
||||
from vllm.model_executor.layers.lightning_attn import linear_decode_forward_triton
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
NUM_HEADS = [4, 8]
|
||||
@@ -17,8 +16,8 @@ DTYPES = [torch.float32]
|
||||
|
||||
def reference_lightning_attention(q, k, v, ed, block_size, kv_history):
|
||||
"""Reference implementation of lightning attention core algorithm
|
||||
|
||||
The difference from the main implementation is that this processes
|
||||
|
||||
The difference from the main implementation is that this processes
|
||||
each step sequentially, instead of using parallelized triton kernels
|
||||
"""
|
||||
B, H, S, D = q.shape
|
||||
@@ -62,8 +61,7 @@ def reference_lightning_attention(q, k, v, ed, block_size, kv_history):
|
||||
# The actual implementation returns a tensor of shape [B, H, 2, D, E]
|
||||
# where dimension 2 contains both KV and KV history
|
||||
kv_reshaped = kv_cache.unsqueeze(2) # [B, H, 1, D, E]
|
||||
final_kv_cache = torch.cat([kv_reshaped, kv_reshaped],
|
||||
dim=2) # [B, H, 2, D, E]
|
||||
final_kv_cache = torch.cat([kv_reshaped, kv_reshaped], dim=2) # [B, H, 2, D, E]
|
||||
|
||||
return output, final_kv_cache
|
||||
|
||||
@@ -109,7 +107,7 @@ def reference_linear_decode(q, k, v, kv_caches, slope_rate, slot_idx):
|
||||
out_h = torch.matmul(q_bh, kv_new)
|
||||
|
||||
# Update output and cache
|
||||
output[b, h * D:(h + 1) * D] = out_h
|
||||
output[b, h * D : (h + 1) * D] = out_h
|
||||
kv_caches[b, h] = kv_new
|
||||
|
||||
return output
|
||||
@@ -135,12 +133,9 @@ def test_linear_decode_forward_triton(
|
||||
k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
|
||||
kv_caches = base * torch.randn(batch_size,
|
||||
num_heads,
|
||||
head_size,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
kv_caches = base * torch.randn(
|
||||
batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda"
|
||||
)
|
||||
|
||||
kv_caches_copy = kv_caches.clone()
|
||||
|
||||
@@ -150,15 +145,14 @@ def test_linear_decode_forward_triton(
|
||||
|
||||
slot_idx = torch.arange(batch_size, device="cuda")
|
||||
|
||||
triton_output = linear_decode_forward_triton(q, k, v, kv_caches,
|
||||
slope_rate, slot_idx)
|
||||
triton_output = linear_decode_forward_triton(
|
||||
q, k, v, kv_caches, slope_rate, slot_idx
|
||||
)
|
||||
|
||||
reference_output = reference_linear_decode(q, k, v, kv_caches_copy,
|
||||
slope_rate, slot_idx)
|
||||
torch.testing.assert_close(triton_output,
|
||||
reference_output,
|
||||
rtol=1e-1,
|
||||
atol=1e-1)
|
||||
reference_output = reference_linear_decode(
|
||||
q, k, v, kv_caches_copy, slope_rate, slot_idx
|
||||
)
|
||||
torch.testing.assert_close(triton_output, reference_output, rtol=1e-1, atol=1e-1)
|
||||
torch.testing.assert_close(kv_caches, kv_caches_copy, rtol=1e-1, atol=1e-1)
|
||||
|
||||
assert triton_output.shape == (batch_size, num_heads * head_size)
|
||||
@@ -184,12 +178,9 @@ def test_linear_decode_forward_triton_with_padding(
|
||||
k = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
v = base * torch.randn(batch_size, num_heads, 1, head_size, dtype=dtype)
|
||||
|
||||
kv_caches = base * torch.randn(batch_size,
|
||||
num_heads,
|
||||
head_size,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
kv_caches = base * torch.randn(
|
||||
batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda"
|
||||
)
|
||||
|
||||
kv_caches_copy = kv_caches.clone()
|
||||
|
||||
@@ -199,14 +190,15 @@ def test_linear_decode_forward_triton_with_padding(
|
||||
|
||||
slot_idx = torch.tensor([0, 1, -1, 2], device="cuda")
|
||||
|
||||
triton_output = linear_decode_forward_triton(q, k, v, kv_caches,
|
||||
slope_rate, slot_idx)
|
||||
triton_output = linear_decode_forward_triton(
|
||||
q, k, v, kv_caches, slope_rate, slot_idx
|
||||
)
|
||||
|
||||
reference_output = reference_linear_decode(q, k, v, kv_caches_copy,
|
||||
slope_rate, slot_idx)
|
||||
reference_output = reference_linear_decode(
|
||||
q, k, v, kv_caches_copy, slope_rate, slot_idx
|
||||
)
|
||||
|
||||
padding_mask = (slot_idx
|
||||
!= -1).unsqueeze(1).expand(-1, num_heads * head_size)
|
||||
padding_mask = (slot_idx != -1).unsqueeze(1).expand(-1, num_heads * head_size)
|
||||
|
||||
triton_masked = triton_output[padding_mask]
|
||||
reference_masked = reference_output[padding_mask]
|
||||
@@ -217,15 +209,11 @@ def test_linear_decode_forward_triton_with_padding(
|
||||
|
||||
for i in range(batch_size):
|
||||
if valid_indices[i] > 0:
|
||||
torch.testing.assert_close(kv_caches[i],
|
||||
kv_caches_copy[i],
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
torch.testing.assert_close(
|
||||
kv_caches[i], kv_caches_copy[i], rtol=rtol, atol=atol
|
||||
)
|
||||
|
||||
torch.testing.assert_close(triton_masked,
|
||||
reference_masked,
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
torch.testing.assert_close(triton_masked, reference_masked, rtol=rtol, atol=atol)
|
||||
|
||||
assert triton_output.shape == (batch_size, num_heads * head_size)
|
||||
|
||||
@@ -249,39 +237,33 @@ def test_lightning_attention_reference(
|
||||
current_platform.seed_everything(42)
|
||||
|
||||
base = 0.01
|
||||
q = base * torch.randn(
|
||||
batch_size, num_heads, seq_len, head_size, dtype=dtype)
|
||||
k = base * torch.randn(
|
||||
batch_size, num_heads, seq_len, head_size, dtype=dtype)
|
||||
v = base * torch.randn(
|
||||
batch_size, num_heads, seq_len, head_size, dtype=dtype)
|
||||
q = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype)
|
||||
k = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype)
|
||||
v = base * torch.randn(batch_size, num_heads, seq_len, head_size, dtype=dtype)
|
||||
|
||||
ed = torch.zeros(num_heads, device="cuda")
|
||||
for h in range(num_heads):
|
||||
ed[h] = 0.1 * (h + 1)
|
||||
|
||||
kv_history = base * torch.randn(batch_size,
|
||||
num_heads,
|
||||
head_size,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
kv_history = base * torch.randn(
|
||||
batch_size, num_heads, head_size, head_size, dtype=dtype, device="cuda"
|
||||
)
|
||||
|
||||
kv_history_clone = kv_history.clone()
|
||||
|
||||
ref_output, ref_kv_cache = reference_lightning_attention(
|
||||
q, k, v, ed, 256, kv_history)
|
||||
q, k, v, ed, 256, kv_history
|
||||
)
|
||||
|
||||
from vllm.model_executor.layers.lightning_attn import lightning_attention
|
||||
|
||||
actual_output, actual_kv_cache = lightning_attention(
|
||||
q, k, v, ed, 256, kv_history_clone)
|
||||
q, k, v, ed, 256, kv_history_clone
|
||||
)
|
||||
|
||||
atol, rtol = 1.5e-1, 1.5e-1
|
||||
torch.testing.assert_close(ref_output, actual_output, rtol=rtol, atol=atol)
|
||||
torch.testing.assert_close(ref_kv_cache,
|
||||
actual_kv_cache,
|
||||
rtol=rtol,
|
||||
atol=atol)
|
||||
torch.testing.assert_close(ref_kv_cache, actual_kv_cache, rtol=rtol, atol=atol)
|
||||
|
||||
assert ref_output.shape == (batch_size, num_heads, seq_len, head_size)
|
||||
assert ref_kv_cache.shape == actual_kv_cache.shape
|
||||
|
||||
@@ -7,19 +7,20 @@ import torch
|
||||
|
||||
from vllm._custom_ops import merge_attn_states as merge_attn_states_cuda
|
||||
from vllm.attention.ops.triton_merge_attn_states import (
|
||||
merge_attn_states as merge_attn_states_triton)
|
||||
merge_attn_states as merge_attn_states_triton,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
# Naive PyTorch Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
|
||||
# can be used to combine partial attention results (in the split-KV case)
|
||||
def merge_attn_states_torch(
|
||||
output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
prefix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
prefix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS]
|
||||
suffix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
suffix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS]
|
||||
output_lse: Optional[torch.Tensor] = None, # [NUM_HEADS, NUM_TOKENS]
|
||||
output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
prefix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
prefix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS]
|
||||
suffix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
suffix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS]
|
||||
output_lse: Optional[torch.Tensor] = None, # [NUM_HEADS, NUM_TOKENS]
|
||||
):
|
||||
p_lse = prefix_lse
|
||||
s_lse = suffix_lse
|
||||
@@ -32,15 +33,13 @@ def merge_attn_states_torch(
|
||||
s_lse = s_lse - max_lse
|
||||
p_lse_exp = torch.exp(p_lse)
|
||||
s_lse_exp = torch.exp(s_lse)
|
||||
out_se = (p_lse_exp + s_lse_exp)
|
||||
out_se = p_lse_exp + s_lse_exp
|
||||
if output_lse is not None:
|
||||
output_lse = torch.log(out_se) + max_lse
|
||||
p_scale = p_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS]
|
||||
s_scale = s_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS]
|
||||
p_scale = torch.transpose(p_scale, 0,
|
||||
1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1]
|
||||
s_scale = torch.transpose(s_scale, 0,
|
||||
1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1]
|
||||
p_scale = torch.transpose(p_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1]
|
||||
s_scale = torch.transpose(s_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1]
|
||||
output = prefix_output * p_scale + suffix_output * s_scale
|
||||
return output, output_lse
|
||||
|
||||
@@ -55,8 +54,10 @@ all_case_info: list[tuple] = []
|
||||
|
||||
def generate_markdown_table():
|
||||
global all_case_info
|
||||
table_header = ("| tokens | heads | headsize | dtype "
|
||||
"| device | torch | triton | cuda | speedup |")
|
||||
table_header = (
|
||||
"| tokens | heads | headsize | dtype "
|
||||
"| device | torch | triton | cuda | speedup |"
|
||||
)
|
||||
table_separator = "| --- | --- | --- | --- | --- | --- | --- | --- | --- |"
|
||||
|
||||
def shortly_dtype(dtype: torch.dtype) -> str:
|
||||
@@ -68,16 +69,26 @@ def generate_markdown_table():
|
||||
print(table_header)
|
||||
print(table_separator)
|
||||
for info in all_case_info:
|
||||
(num_tokens, num_heads, head_size, dtype, device,
|
||||
avg_time_torch_kernel, avg_time_triton_kernel, avg_time_cuda_kernel,
|
||||
performance_improved) = info
|
||||
(
|
||||
num_tokens,
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype,
|
||||
device,
|
||||
avg_time_torch_kernel,
|
||||
avg_time_triton_kernel,
|
||||
avg_time_cuda_kernel,
|
||||
performance_improved,
|
||||
) = info
|
||||
dtype = shortly_dtype(dtype)
|
||||
device = shortly_device(device)
|
||||
print(f"| {num_tokens} | {num_heads} | {head_size} "
|
||||
f"| {dtype} | {device} | {avg_time_torch_kernel:.5f}ms "
|
||||
f"| {avg_time_triton_kernel:.5f}ms "
|
||||
f"| {avg_time_cuda_kernel:.5f}ms "
|
||||
f"| {performance_improved:.4f}x |")
|
||||
print(
|
||||
f"| {num_tokens} | {num_heads} | {head_size} "
|
||||
f"| {dtype} | {device} | {avg_time_torch_kernel:.5f}ms "
|
||||
f"| {avg_time_triton_kernel:.5f}ms "
|
||||
f"| {avg_time_cuda_kernel:.5f}ms "
|
||||
f"| {performance_improved:.4f}x |"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_BATCH_TOKENS)
|
||||
@@ -85,29 +96,28 @@ def generate_markdown_table():
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("output_dtype", DTYPES)
|
||||
@torch.inference_mode()
|
||||
def test_merge_attn_states(num_tokens: int, num_query_heads: int,
|
||||
head_size: int, output_dtype: torch.dtype):
|
||||
def test_merge_attn_states(
|
||||
num_tokens: int, num_query_heads: int, head_size: int, output_dtype: torch.dtype
|
||||
):
|
||||
if not current_platform.is_cuda():
|
||||
pytest.skip('Currently only support compare triton merge_attn_states '
|
||||
'with custom cuda merge_attn_states kernel')
|
||||
pytest.skip(
|
||||
"Currently only support compare triton merge_attn_states "
|
||||
"with custom cuda merge_attn_states kernel"
|
||||
)
|
||||
|
||||
NUM_TOKENS = num_tokens
|
||||
NUM_HEADS = num_query_heads
|
||||
HEAD_SIZE = head_size
|
||||
|
||||
print(f"\nNUM_TOKENS:{NUM_TOKENS}, NUM_HEADS:{NUM_HEADS}, "
|
||||
f"HEAD_SIZE:{HEAD_SIZE}, DTYPE: {output_dtype}, "
|
||||
f"Device: {current_platform.get_device_name()}")
|
||||
print(
|
||||
f"\nNUM_TOKENS:{NUM_TOKENS}, NUM_HEADS:{NUM_HEADS}, "
|
||||
f"HEAD_SIZE:{HEAD_SIZE}, DTYPE: {output_dtype}, "
|
||||
f"Device: {current_platform.get_device_name()}"
|
||||
)
|
||||
|
||||
# prefix_lse and suffix_lse contain inf and normal values
|
||||
prefix_lse = torch.randn(NUM_HEADS,
|
||||
NUM_TOKENS,
|
||||
dtype=torch.float32,
|
||||
device="cuda")
|
||||
suffix_lse = torch.randn(NUM_HEADS,
|
||||
NUM_TOKENS,
|
||||
dtype=torch.float32,
|
||||
device="cuda")
|
||||
prefix_lse = torch.randn(NUM_HEADS, NUM_TOKENS, dtype=torch.float32, device="cuda")
|
||||
suffix_lse = torch.randn(NUM_HEADS, NUM_TOKENS, dtype=torch.float32, device="cuda")
|
||||
|
||||
# Generate boolean masks
|
||||
mask_prefix = torch.rand(NUM_HEADS, NUM_TOKENS) < 0.1
|
||||
@@ -117,23 +127,23 @@ def test_merge_attn_states(num_tokens: int, num_query_heads: int,
|
||||
mask_prefix = torch.logical_and(mask_prefix, ~combined_mask)
|
||||
mask_suffix = torch.logical_and(mask_suffix, ~combined_mask)
|
||||
|
||||
prefix_lse[mask_prefix] = float('inf')
|
||||
suffix_lse[mask_suffix] = float('inf')
|
||||
prefix_lse[mask_prefix] = float("inf")
|
||||
suffix_lse[mask_suffix] = float("inf")
|
||||
|
||||
# Other input tensors (need to be initialized but
|
||||
# no actual calculation needed)
|
||||
output = torch.zeros((NUM_TOKENS, NUM_HEADS, HEAD_SIZE),
|
||||
dtype=output_dtype,
|
||||
device="cuda")
|
||||
output_lse = torch.zeros((NUM_HEADS, NUM_TOKENS),
|
||||
dtype=torch.float32,
|
||||
device="cuda")
|
||||
prefix_output = torch.randn((NUM_TOKENS, NUM_HEADS, HEAD_SIZE),
|
||||
dtype=output_dtype,
|
||||
device="cuda")
|
||||
suffix_output = torch.randn((NUM_TOKENS, NUM_HEADS, HEAD_SIZE),
|
||||
dtype=output_dtype,
|
||||
device="cuda")
|
||||
output = torch.zeros(
|
||||
(NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
|
||||
)
|
||||
output_lse = torch.zeros(
|
||||
(NUM_HEADS, NUM_TOKENS), dtype=torch.float32, device="cuda"
|
||||
)
|
||||
prefix_output = torch.randn(
|
||||
(NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
|
||||
)
|
||||
suffix_output = torch.randn(
|
||||
(NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
|
||||
)
|
||||
|
||||
warmup_times = 2
|
||||
repeat_times = 20
|
||||
@@ -149,15 +159,25 @@ def test_merge_attn_states(num_tokens: int, num_query_heads: int,
|
||||
suffix_lse_torch = suffix_lse.clone()
|
||||
for _ in range(warmup_times):
|
||||
output_torch, output_lse_torch = merge_attn_states_torch(
|
||||
output_torch, prefix_output, prefix_lse_torch, suffix_output,
|
||||
suffix_lse_torch, output_lse_torch)
|
||||
output_torch,
|
||||
prefix_output,
|
||||
prefix_lse_torch,
|
||||
suffix_output,
|
||||
suffix_lse_torch,
|
||||
output_lse_torch,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
for _ in range(repeat_times):
|
||||
start.record()
|
||||
output_torch, output_lse_torch = merge_attn_states_torch(
|
||||
output_torch, prefix_output, prefix_lse_torch, suffix_output,
|
||||
suffix_lse_torch, output_lse_torch)
|
||||
output_torch,
|
||||
prefix_output,
|
||||
prefix_lse_torch,
|
||||
suffix_output,
|
||||
suffix_lse_torch,
|
||||
output_lse_torch,
|
||||
)
|
||||
end.record()
|
||||
torch.cuda.synchronize()
|
||||
total_time_torch_kernel += start.elapsed_time(end)
|
||||
@@ -173,16 +193,26 @@ def test_merge_attn_states(num_tokens: int, num_query_heads: int,
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
for _ in range(warmup_times):
|
||||
merge_attn_states_triton(output_ref_triton, prefix_output, prefix_lse,
|
||||
suffix_output, suffix_lse,
|
||||
output_lse_ref_triton)
|
||||
merge_attn_states_triton(
|
||||
output_ref_triton,
|
||||
prefix_output,
|
||||
prefix_lse,
|
||||
suffix_output,
|
||||
suffix_lse,
|
||||
output_lse_ref_triton,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
for _ in range(repeat_times):
|
||||
start.record()
|
||||
merge_attn_states_triton(output_ref_triton, prefix_output, prefix_lse,
|
||||
suffix_output, suffix_lse,
|
||||
output_lse_ref_triton)
|
||||
merge_attn_states_triton(
|
||||
output_ref_triton,
|
||||
prefix_output,
|
||||
prefix_lse,
|
||||
suffix_output,
|
||||
suffix_lse,
|
||||
output_lse_ref_triton,
|
||||
)
|
||||
end.record()
|
||||
torch.cuda.synchronize()
|
||||
total_time_triton_kernel += start.elapsed_time(end)
|
||||
@@ -195,14 +225,26 @@ def test_merge_attn_states(num_tokens: int, num_query_heads: int,
|
||||
output_lse_cuda = output_lse.clone()
|
||||
|
||||
for _ in range(warmup_times):
|
||||
merge_attn_states_cuda(output_cuda, prefix_output, prefix_lse,
|
||||
suffix_output, suffix_lse, output_lse_cuda)
|
||||
merge_attn_states_cuda(
|
||||
output_cuda,
|
||||
prefix_output,
|
||||
prefix_lse,
|
||||
suffix_output,
|
||||
suffix_lse,
|
||||
output_lse_cuda,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
for _ in range(repeat_times):
|
||||
start.record()
|
||||
merge_attn_states_cuda(output_cuda, prefix_output, prefix_lse,
|
||||
suffix_output, suffix_lse, output_lse_cuda)
|
||||
merge_attn_states_cuda(
|
||||
output_cuda,
|
||||
prefix_output,
|
||||
prefix_lse,
|
||||
suffix_output,
|
||||
suffix_lse,
|
||||
output_lse_cuda,
|
||||
)
|
||||
end.record()
|
||||
torch.cuda.synchronize()
|
||||
total_time_cuda_kernel += start.elapsed_time(end)
|
||||
@@ -213,8 +255,10 @@ def test_merge_attn_states(num_tokens: int, num_query_heads: int,
|
||||
performance_improved = avg_time_triton_kernel / avg_time_cuda_kernel
|
||||
print(f" Torch time: {avg_time_torch_kernel:.6f}ms")
|
||||
print(f"Triton time: {avg_time_triton_kernel:.6f}ms")
|
||||
print(f" CUDA time: {avg_time_cuda_kernel:.6f}ms, "
|
||||
f"Performance: {performance_improved:.5f}x")
|
||||
print(
|
||||
f" CUDA time: {avg_time_cuda_kernel:.6f}ms, "
|
||||
f"Performance: {performance_improved:.5f}x"
|
||||
)
|
||||
print("-" * 100)
|
||||
|
||||
# 4. Correctness compare
|
||||
@@ -232,35 +276,45 @@ def test_merge_attn_states(num_tokens: int, num_query_heads: int,
|
||||
# states operation.
|
||||
output_ref = output_ref_triton
|
||||
output_lse_ref = output_lse_ref_triton
|
||||
torch.testing.assert_close(output_cuda.float(),
|
||||
output_ref.float(),
|
||||
atol=1e-3,
|
||||
rtol=rtol)
|
||||
torch.testing.assert_close(
|
||||
output_cuda.float(), output_ref.float(), atol=1e-3, rtol=rtol
|
||||
)
|
||||
print("Output all match, max abs diff:")
|
||||
print(f"(Triton vs Torch) : {diff(output_torch, output_ref)}")
|
||||
print(f" (CUDA vs Torch) : {diff(output_torch, output_cuda)}")
|
||||
print(f" (CUDA vs Triton): {diff(output_ref, output_cuda)}")
|
||||
print("-" * 100)
|
||||
|
||||
torch.testing.assert_close(output_lse_cuda.float(),
|
||||
output_lse_ref.float(),
|
||||
atol=1e-3,
|
||||
rtol=rtol)
|
||||
torch.testing.assert_close(
|
||||
output_lse_cuda.float(), output_lse_ref.float(), atol=1e-3, rtol=rtol
|
||||
)
|
||||
print("Output LSE all match, max abs diff:")
|
||||
print(f"(Triton vs Torch) : {diff(output_lse_torch, output_lse_ref)}")
|
||||
print(f" (CUDA vs Torch) : {diff(output_lse_torch, output_lse_cuda)}")
|
||||
print(f" (CUDA vs Triton): {diff(output_lse_ref, output_lse_cuda)}")
|
||||
print("-" * 100)
|
||||
|
||||
print("All output values test passed! All inf values "
|
||||
"are correctly replaced with -inf.")
|
||||
print(
|
||||
"All output values test passed! All inf values "
|
||||
"are correctly replaced with -inf."
|
||||
)
|
||||
print("-" * 100)
|
||||
|
||||
device = current_platform.get_device_name()
|
||||
all_case_info.append(
|
||||
(NUM_TOKENS, NUM_HEADS, HEAD_SIZE, output_dtype, device,
|
||||
avg_time_torch_kernel, avg_time_triton_kernel, avg_time_cuda_kernel,
|
||||
performance_improved))
|
||||
if len(all_case_info) == (len(NUM_BATCH_TOKENS) * len(HEAD_SIZES) *
|
||||
len(NUM_QUERY_HEADS) * len(DTYPES)):
|
||||
(
|
||||
NUM_TOKENS,
|
||||
NUM_HEADS,
|
||||
HEAD_SIZE,
|
||||
output_dtype,
|
||||
device,
|
||||
avg_time_torch_kernel,
|
||||
avg_time_triton_kernel,
|
||||
avg_time_cuda_kernel,
|
||||
performance_improved,
|
||||
)
|
||||
)
|
||||
if len(all_case_info) == (
|
||||
len(NUM_BATCH_TOKENS) * len(HEAD_SIZES) * len(NUM_QUERY_HEADS) * len(DTYPES)
|
||||
):
|
||||
generate_markdown_table()
|
||||
|
||||
@@ -5,6 +5,7 @@ Test:
|
||||
|
||||
* Tests for MultiHeadAttention layer
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@@ -21,11 +22,11 @@ from vllm.platforms.rocm import RocmPlatform
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_cache():
|
||||
"""Clear lru cache to ensure each test case runs without caching.
|
||||
"""
|
||||
"""Clear lru cache to ensure each test case runs without caching."""
|
||||
_cached_get_attn_backend.cache_clear()
|
||||
# Clear xformers availability cache
|
||||
import vllm.attention.layer as layer_module
|
||||
|
||||
layer_module.USE_XFORMERS_OPS = None
|
||||
|
||||
|
||||
@@ -37,49 +38,63 @@ def test_mha_attn_platform(device: str):
|
||||
torch.set_default_dtype(torch.float16)
|
||||
|
||||
if device == "cpu":
|
||||
with patch("vllm.attention.layer.current_platform", CpuPlatform()), \
|
||||
patch("vllm.model_executor.models.vision.current_platform",
|
||||
CpuPlatform()):
|
||||
with (
|
||||
patch("vllm.attention.layer.current_platform", CpuPlatform()),
|
||||
patch("vllm.model_executor.models.vision.current_platform", CpuPlatform()),
|
||||
):
|
||||
attn = MultiHeadAttention(16, 64, scale=1)
|
||||
assert attn.attn_backend == _Backend.TORCH_SDPA
|
||||
elif device == "hip":
|
||||
with patch("vllm.attention.layer.current_platform", RocmPlatform()), \
|
||||
patch("vllm.model_executor.models.vision.current_platform",
|
||||
RocmPlatform()):
|
||||
with (
|
||||
patch("vllm.attention.layer.current_platform", RocmPlatform()),
|
||||
patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()),
|
||||
):
|
||||
attn = MultiHeadAttention(16, 64, scale=1)
|
||||
assert attn.attn_backend == _Backend.TORCH_SDPA
|
||||
else:
|
||||
# Test CUDA with head_size=64 (divisible by 32)
|
||||
# - should use vLLM's FlashAttention
|
||||
with patch("vllm.attention.layer.current_platform", CudaPlatform()), \
|
||||
patch("vllm.model_executor.models.vision.current_platform",
|
||||
CudaPlatform()):
|
||||
with (
|
||||
patch("vllm.attention.layer.current_platform", CudaPlatform()),
|
||||
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
|
||||
):
|
||||
attn = MultiHeadAttention(16, 64, scale=1)
|
||||
assert attn.attn_backend == _Backend.FLASH_ATTN
|
||||
|
||||
# Test CUDA with head_size=72 (not divisible by 32)
|
||||
# - with upstream FA not available
|
||||
# - should use xformers
|
||||
with patch("vllm.attention.layer.current_platform", CudaPlatform()), \
|
||||
patch("vllm.model_executor.models.vision.current_platform",
|
||||
CudaPlatform()), \
|
||||
patch("vllm.attention.layer.check_upstream_fa_availability",
|
||||
return_value=False):
|
||||
with (
|
||||
patch("vllm.attention.layer.current_platform", CudaPlatform()),
|
||||
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
|
||||
patch(
|
||||
"vllm.attention.layer.check_upstream_fa_availability",
|
||||
return_value=False,
|
||||
),
|
||||
):
|
||||
attn = MultiHeadAttention(16, 72, scale=1)
|
||||
assert attn.attn_backend == _Backend.XFORMERS
|
||||
|
||||
# Test CUDA with head_size=72 (not divisible by 32)
|
||||
# - with upstream FA available
|
||||
# - should use upstream FA
|
||||
with patch("vllm.attention.layer.current_platform", CudaPlatform()), \
|
||||
patch("vllm.model_executor.models.vision.current_platform",
|
||||
CudaPlatform()), \
|
||||
patch("vllm.attention.layer.check_upstream_fa_availability",
|
||||
return_value=True), \
|
||||
patch.dict('sys.modules', {'flash_attn': type('MockFlashAttn', (),
|
||||
{
|
||||
'flash_attn_varlen_func': lambda *args, **kwargs: None
|
||||
})()}):
|
||||
with (
|
||||
patch("vllm.attention.layer.current_platform", CudaPlatform()),
|
||||
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
|
||||
patch(
|
||||
"vllm.attention.layer.check_upstream_fa_availability", return_value=True
|
||||
),
|
||||
patch.dict(
|
||||
"sys.modules",
|
||||
{
|
||||
"flash_attn": type(
|
||||
"MockFlashAttn",
|
||||
(),
|
||||
{"flash_attn_varlen_func": lambda *args, **kwargs: None},
|
||||
)()
|
||||
},
|
||||
),
|
||||
):
|
||||
attn = MultiHeadAttention(16, 72, scale=1)
|
||||
assert attn.attn_backend == _Backend.FLASH_ATTN
|
||||
|
||||
@@ -108,9 +123,11 @@ NUM_HEADS = [1, 16]
|
||||
NUM_KV_HEADS = [1]
|
||||
HEAD_SIZES = [64, 80]
|
||||
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
|
||||
DTYPES = [
|
||||
torch.half, torch.bfloat16, torch.float
|
||||
] if not current_platform.is_rocm() else [torch.half, torch.bfloat16]
|
||||
DTYPES = (
|
||||
[torch.half, torch.bfloat16, torch.float]
|
||||
if not current_platform.is_rocm()
|
||||
else [torch.half, torch.bfloat16]
|
||||
)
|
||||
CUDA_DEVICES = ["cuda"]
|
||||
|
||||
|
||||
@@ -138,10 +155,9 @@ def test_mha_attn_forward(
|
||||
k = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
|
||||
v = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
|
||||
scale = 1.0 / head_size**0.5
|
||||
attn = MultiHeadAttention(num_heads,
|
||||
head_size,
|
||||
scale=scale,
|
||||
num_kv_heads=num_kv_heads)
|
||||
attn = MultiHeadAttention(
|
||||
num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads
|
||||
)
|
||||
output = attn(q, k, v)
|
||||
|
||||
assert num_heads % num_kv_heads == 0
|
||||
|
||||
@@ -11,30 +11,24 @@ from vllm.utils import cdiv
|
||||
|
||||
|
||||
def ref_mla(
|
||||
out: Tensor, # (bs, num_heads, v_head_dim)
|
||||
query: Tensor, # (bs, num_heads, head_dim)
|
||||
kv_cache: Tensor, # (num_blocks, block_size, head_dim)
|
||||
scale: float,
|
||||
block_tables: Tensor, # (bs, max_num_blocks)
|
||||
seq_lens: Tensor, # (bs,)
|
||||
out: Tensor, # (bs, num_heads, v_head_dim)
|
||||
query: Tensor, # (bs, num_heads, head_dim)
|
||||
kv_cache: Tensor, # (num_blocks, block_size, head_dim)
|
||||
scale: float,
|
||||
block_tables: Tensor, # (bs, max_num_blocks)
|
||||
seq_lens: Tensor, # (bs,)
|
||||
):
|
||||
bs, num_heads, v_head_dim = out.shape
|
||||
head_dim = query.shape[2]
|
||||
|
||||
for i in range(bs):
|
||||
# gather and flatten KV-cache
|
||||
kv = kv_cache[
|
||||
block_tables[i]] # (max_num_blocks, block_size, head_dim)
|
||||
kv = kv.view(1, -1,
|
||||
head_dim)[:, :seq_lens[i]] # (1, seq_len, head_dim)
|
||||
kv = kv_cache[block_tables[i]] # (max_num_blocks, block_size, head_dim)
|
||||
kv = kv.view(1, -1, head_dim)[:, : seq_lens[i]] # (1, seq_len, head_dim)
|
||||
v = kv[:, :, :v_head_dim]
|
||||
|
||||
q = query[i].view(num_heads, 1, head_dim)
|
||||
o = F.scaled_dot_product_attention(q,
|
||||
kv,
|
||||
v,
|
||||
scale=scale,
|
||||
enable_gqa=True)
|
||||
o = F.scaled_dot_product_attention(q, kv, v, scale=scale, enable_gqa=True)
|
||||
out[i] = o.view(num_heads, v_head_dim)
|
||||
|
||||
return out
|
||||
@@ -63,18 +57,17 @@ def test_mla_decode_cpu(
|
||||
torch.set_default_dtype(dtype)
|
||||
torch.manual_seed(0)
|
||||
|
||||
scale = d**(-0.5)
|
||||
scale = d ** (-0.5)
|
||||
if varlen:
|
||||
seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2)
|
||||
seq_lens = seq_lens.clip(2).to(torch.int32)
|
||||
else:
|
||||
seq_lens = torch.full((bs, ), mean_seq_len, dtype=torch.int32)
|
||||
seq_lens = torch.full((bs,), mean_seq_len, dtype=torch.int32)
|
||||
max_seq_len = seq_lens.max().item()
|
||||
seqlen_pad = cdiv(max_seq_len, 256) * 256 # is this necessary?
|
||||
|
||||
q = torch.randn(bs, h_q, d)
|
||||
block_table = torch.arange(bs * seqlen_pad // block_size,
|
||||
dtype=torch.int32)
|
||||
block_table = torch.arange(bs * seqlen_pad // block_size, dtype=torch.int32)
|
||||
block_table = block_table.view(bs, seqlen_pad // block_size)
|
||||
|
||||
kv_cache = torch.randn(block_table.numel(), block_size, d)
|
||||
@@ -82,8 +75,7 @@ def test_mla_decode_cpu(
|
||||
kv_cache.view(bs, seqlen_pad, d)[i, seq_len:] = float("nan")
|
||||
|
||||
out_mla = q.new_zeros(bs, h_q, dv)
|
||||
ops.mla_decode_kvcache_cpu(out_mla, q, kv_cache, scale, block_table,
|
||||
seq_lens)
|
||||
ops.mla_decode_kvcache_cpu(out_mla, q, kv_cache, scale, block_table, seq_lens)
|
||||
|
||||
out_ref = q.new_zeros(bs, h_q, dv)
|
||||
ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)
|
||||
|
||||
@@ -39,7 +39,7 @@ def test_pack_seq_basic_fp8():
|
||||
start_idx = sum(lengths_list[:b])
|
||||
seq_len = lengths_list[b]
|
||||
|
||||
expected_data = x[start_idx:start_idx + seq_len].to(torch.float32)
|
||||
expected_data = x[start_idx : start_idx + seq_len].to(torch.float32)
|
||||
actual_data = packed[b, :seq_len].to(torch.float32)
|
||||
|
||||
assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2)
|
||||
@@ -62,7 +62,7 @@ def test_pack_seq_custom_padding_fp8():
|
||||
# Check valid data
|
||||
for b in range(B):
|
||||
start_idx = b * 10
|
||||
expected_data = x[start_idx:start_idx + 10].to(torch.float32)
|
||||
expected_data = x[start_idx : start_idx + 10].to(torch.float32)
|
||||
actual_data = result[b, :10].to(torch.float32)
|
||||
assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2)
|
||||
|
||||
@@ -73,9 +73,7 @@ def test_pack_seq_custom_padding_fp8():
|
||||
elif pad_value > 0:
|
||||
assert torch.all(padded_data > 50) # Large positive values
|
||||
else:
|
||||
assert torch.allclose(padded_data,
|
||||
torch.zeros_like(padded_data),
|
||||
atol=1e-2)
|
||||
assert torch.allclose(padded_data, torch.zeros_like(padded_data), atol=1e-2)
|
||||
|
||||
|
||||
def test_pack_seq_default_negative_inf_padding_fp8():
|
||||
@@ -93,7 +91,8 @@ def test_pack_seq_default_negative_inf_padding_fp8():
|
||||
# Check that padding is large negative values (fp8 representation of -inf)
|
||||
padded_data = result[:, 10:].to(torch.float32)
|
||||
assert torch.all(
|
||||
padded_data < -100) # fp8 -inf is represented as large negative number
|
||||
padded_data < -100
|
||||
) # fp8 -inf is represented as large negative number
|
||||
|
||||
|
||||
def test_pack_seq_edge_cases_fp8():
|
||||
@@ -142,7 +141,7 @@ def test_pack_seq_different_block_sizes_fp8():
|
||||
# Check that valid data is preserved (within fp8 precision)
|
||||
for b in range(B):
|
||||
start_idx = b * 25
|
||||
expected_data = x[start_idx:start_idx + 25].to(torch.float32)
|
||||
expected_data = x[start_idx : start_idx + 25].to(torch.float32)
|
||||
actual_data = result[b, :25].to(torch.float32)
|
||||
assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2)
|
||||
|
||||
@@ -198,10 +197,7 @@ def test_pack_unpack_roundtrip_fp8():
|
||||
|
||||
# Unpack without explicit start locations (computed in kernel)
|
||||
unpacked_with_loc = unpack_seq_triton(packed, lengths)
|
||||
assert_close(x_f32,
|
||||
unpacked_with_loc.to(torch.float32),
|
||||
rtol=1e-3,
|
||||
atol=1e-2)
|
||||
assert_close(x_f32, unpacked_with_loc.to(torch.float32), rtol=1e-3, atol=1e-2)
|
||||
|
||||
|
||||
def test_unpack_seq_triton_edge_cases_fp8():
|
||||
@@ -216,10 +212,7 @@ def test_unpack_seq_triton_edge_cases_fp8():
|
||||
packed = pack_seq_triton(x, lengths)
|
||||
unpacked = unpack_seq_triton(packed, lengths)
|
||||
assert unpacked.shape == x.shape
|
||||
assert_close(x.to(torch.float32),
|
||||
unpacked.to(torch.float32),
|
||||
rtol=1e-1,
|
||||
atol=1e-2)
|
||||
assert_close(x.to(torch.float32), unpacked.to(torch.float32), rtol=1e-1, atol=1e-2)
|
||||
|
||||
# Test with very short sequences
|
||||
x = torch.randn(20, 4, 8, dtype=torch.float32, device=device) * 0.1
|
||||
@@ -228,10 +221,9 @@ def test_unpack_seq_triton_edge_cases_fp8():
|
||||
packed = pack_seq_triton(x, lengths)
|
||||
unpacked = unpack_seq_triton(packed, lengths)
|
||||
# Only compare the first 3 elements that were actually packed
|
||||
assert_close(x[:3].to(torch.float32),
|
||||
unpacked.to(torch.float32),
|
||||
rtol=1e-1,
|
||||
atol=1e-2)
|
||||
assert_close(
|
||||
x[:3].to(torch.float32), unpacked.to(torch.float32), rtol=1e-1, atol=1e-2
|
||||
)
|
||||
|
||||
x = torch.randn(15, 8, 16, dtype=torch.float32, device=device) * 0.1
|
||||
x = x.to(dtype=dtype)
|
||||
@@ -239,7 +231,4 @@ def test_unpack_seq_triton_edge_cases_fp8():
|
||||
packed = pack_seq_triton(x, lengths)
|
||||
unpacked = unpack_seq_triton(packed, lengths)
|
||||
assert unpacked.shape == x.shape
|
||||
assert_close(x.to(torch.float32),
|
||||
unpacked.to(torch.float32),
|
||||
rtol=1e-1,
|
||||
atol=1e-2)
|
||||
assert_close(x.to(torch.float32), unpacked.to(torch.float32), rtol=1e-1, atol=1e-2)
|
||||
|
||||
@@ -12,8 +12,7 @@ from xformers import ops as xops
|
||||
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
|
||||
|
||||
from tests.kernels.utils import make_alibi_bias
|
||||
from vllm.attention.ops.chunked_prefill_paged_decode import (
|
||||
chunked_prefill_paged_decode)
|
||||
from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode
|
||||
from vllm.attention.ops.prefix_prefill import context_attention_fwd
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
@@ -22,9 +21,7 @@ NUM_HEADS = [64]
|
||||
NUM_QUERIES_PER_KV = [1, 64]
|
||||
HEAD_SIZES = [24, 128]
|
||||
DTYPES = [torch.float16]
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
SLIDING_WINDOW = [0, 16, 2048]
|
||||
KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"]
|
||||
|
||||
@@ -50,12 +47,10 @@ def test_contexted_kv_attention(
|
||||
device: str,
|
||||
op: Callable,
|
||||
) -> None:
|
||||
|
||||
if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
|
||||
89):
|
||||
if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89):
|
||||
pytest.skip(
|
||||
'Triton limitation: fp8e4nv data type is not supported on CUDA'
|
||||
' arch < 89')
|
||||
"Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
|
||||
)
|
||||
|
||||
current_platform.seed_everything(0)
|
||||
torch.set_default_device(device)
|
||||
@@ -93,38 +88,29 @@ def test_contexted_kv_attention(
|
||||
cache_dtype = dtype
|
||||
else:
|
||||
cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
|
||||
k_cache = torch.zeros(cache_size,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype=cache_dtype)
|
||||
v_cache = torch.zeros(cache_size,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype=cache_dtype)
|
||||
k_cache = torch.zeros(
|
||||
cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype
|
||||
)
|
||||
v_cache = torch.zeros(
|
||||
cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype
|
||||
)
|
||||
k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
|
||||
v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
|
||||
values = torch.arange(0, cache_size, dtype=torch.long)
|
||||
values = values[torch.randperm(cache_size)]
|
||||
block_table = values[:BS * max_block_per_request].view(
|
||||
BS, max_block_per_request)
|
||||
block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
|
||||
b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
|
||||
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
|
||||
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
|
||||
dtype=torch.long),
|
||||
dim=0)
|
||||
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0)
|
||||
max_input_len = MAX_SEQ_LEN
|
||||
# copy kv to cache
|
||||
b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1],
|
||||
dtype=torch.long),
|
||||
dim=0)
|
||||
b_seq_start_loc = torch.cumsum(
|
||||
torch.tensor([0] + seq_lens[:-1], dtype=torch.long), dim=0
|
||||
)
|
||||
for i in range(BS):
|
||||
for j in range(query_lens[i]):
|
||||
k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] +
|
||||
j])
|
||||
v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] +
|
||||
b_ctx_len[i] + j])
|
||||
k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + j])
|
||||
v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + b_ctx_len[i] + j])
|
||||
cur_ctx = 0
|
||||
block_id = 0
|
||||
while cur_ctx < b_ctx_len[i]:
|
||||
@@ -135,61 +121,71 @@ def test_contexted_kv_attention(
|
||||
end_loc = start_loc + block_size
|
||||
start_slot = block_table[i, block_id] * block_size
|
||||
end_slot = start_slot + end_loc - start_loc
|
||||
k_cache.view(-1, num_kv_heads,
|
||||
head_size)[start_slot:end_slot].copy_(
|
||||
key[start_loc:end_loc])
|
||||
v_cache.view(-1, num_kv_heads,
|
||||
head_size)[start_slot:end_slot].copy_(
|
||||
value[start_loc:end_loc])
|
||||
k_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_(
|
||||
key[start_loc:end_loc]
|
||||
)
|
||||
v_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_(
|
||||
value[start_loc:end_loc]
|
||||
)
|
||||
cur_ctx += block_size
|
||||
block_id += 1
|
||||
# transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
|
||||
# to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
|
||||
k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8,
|
||||
8).permute(0, 2, 3, 1, 4).contiguous()
|
||||
k_cache = (
|
||||
k_cache.view(-1, block_size, num_kv_heads, head_size // 8, 8)
|
||||
.permute(0, 2, 3, 1, 4)
|
||||
.contiguous()
|
||||
)
|
||||
# transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
|
||||
# to V_cache[num_blocks, num_kv_heads, head_size, block_size]
|
||||
v_cache = v_cache.view(-1, block_size, num_kv_heads,
|
||||
head_size).permute(0, 2, 3, 1).contiguous()
|
||||
v_cache = (
|
||||
v_cache.view(-1, block_size, num_kv_heads, head_size)
|
||||
.permute(0, 2, 3, 1)
|
||||
.contiguous()
|
||||
)
|
||||
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
|
||||
# Warm up the Triton kernel by calling it once before actually measuring
|
||||
# generation time
|
||||
op(query,
|
||||
k,
|
||||
v,
|
||||
output,
|
||||
kv_cache_dtype,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_table,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
MAX_CTX_LEN,
|
||||
max_input_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
sliding_window=sliding_window)
|
||||
op(
|
||||
query,
|
||||
k,
|
||||
v,
|
||||
output,
|
||||
kv_cache_dtype,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_table,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
MAX_CTX_LEN,
|
||||
max_input_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
sliding_window=sliding_window,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
op(query,
|
||||
k,
|
||||
v,
|
||||
output,
|
||||
kv_cache_dtype,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_table,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
MAX_CTX_LEN,
|
||||
max_input_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
sliding_window=sliding_window)
|
||||
op(
|
||||
query,
|
||||
k,
|
||||
v,
|
||||
output,
|
||||
kv_cache_dtype,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_table,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
MAX_CTX_LEN,
|
||||
max_input_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
sliding_window=sliding_window,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
end_time = time.time()
|
||||
print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
|
||||
print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms")
|
||||
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
|
||||
@@ -201,22 +197,24 @@ def test_contexted_kv_attention(
|
||||
# heads.
|
||||
#
|
||||
# see also: vllm/model_executor/layers/attention.py
|
||||
query = query.view(query.shape[0], num_kv_heads, num_queries_per_kv,
|
||||
query.shape[-1])
|
||||
key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
|
||||
num_queries_per_kv, key.shape[-1])
|
||||
value = value[:, :,
|
||||
None, :].expand(value.shape[0], num_kv_heads,
|
||||
num_queries_per_kv, value.shape[-1])
|
||||
query = query.view(
|
||||
query.shape[0], num_kv_heads, num_queries_per_kv, query.shape[-1]
|
||||
)
|
||||
key = key[:, :, None, :].expand(
|
||||
key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
|
||||
)
|
||||
value = value[:, :, None, :].expand(
|
||||
value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
|
||||
)
|
||||
query = query.unsqueeze(0)
|
||||
key = key.unsqueeze(0)
|
||||
value = value.unsqueeze(0)
|
||||
|
||||
attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(
|
||||
query_lens, seq_lens)
|
||||
query_lens, seq_lens
|
||||
)
|
||||
if sliding_window > 0:
|
||||
attn_bias = attn_bias.make_local_attention_from_bottomright(
|
||||
sliding_window)
|
||||
attn_bias = attn_bias.make_local_attention_from_bottomright(sliding_window)
|
||||
output_ref = xops.memory_efficient_attention_forward(
|
||||
query,
|
||||
key,
|
||||
@@ -239,7 +237,7 @@ def test_contexted_kv_attention(
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
end_time = time.time()
|
||||
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
|
||||
print(f"xformers Time: {(end_time - start_time) * 1000:.2f} ms")
|
||||
output_ref = output_ref.reshape(output.shape)
|
||||
atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-4
|
||||
torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
|
||||
@@ -262,12 +260,10 @@ def test_contexted_kv_attention_alibi(
|
||||
device: str,
|
||||
op: Callable,
|
||||
) -> None:
|
||||
|
||||
if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
|
||||
89):
|
||||
if "fp8" in kv_cache_dtype and not current_platform.has_device_capability(89):
|
||||
pytest.skip(
|
||||
'Triton limitation: fp8e4nv data type is not supported on CUDA'
|
||||
' arch < 89')
|
||||
"Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
|
||||
)
|
||||
|
||||
current_platform.seed_everything(0)
|
||||
torch.set_default_device(device)
|
||||
@@ -280,9 +276,9 @@ def test_contexted_kv_attention_alibi(
|
||||
|
||||
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
|
||||
# Fork from: vllm/vllm/model_executor/models/bloom.py#L44
|
||||
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
|
||||
closest_power_of_2 = 2 ** math.floor(math.log2(total_num_heads))
|
||||
base = torch.tensor(
|
||||
2**(-(2**-(math.log2(closest_power_of_2) - 3))),
|
||||
2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
|
||||
@@ -290,17 +286,16 @@ def test_contexted_kv_attention_alibi(
|
||||
|
||||
if closest_power_of_2 != total_num_heads:
|
||||
extra_base = torch.tensor(
|
||||
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
|
||||
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
num_remaining_heads = min(closest_power_of_2,
|
||||
total_num_heads - closest_power_of_2)
|
||||
extra_powers = torch.arange(start=1,
|
||||
end=1 + 2 * num_remaining_heads,
|
||||
step=2,
|
||||
dtype=torch.int32)
|
||||
slopes = torch.cat(
|
||||
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
||||
num_remaining_heads = min(
|
||||
closest_power_of_2, total_num_heads - closest_power_of_2
|
||||
)
|
||||
extra_powers = torch.arange(
|
||||
start=1, end=1 + 2 * num_remaining_heads, step=2, dtype=torch.int32
|
||||
)
|
||||
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
||||
return slopes
|
||||
|
||||
alibi_slopes = _get_alibi_slopes(num_heads).to(device)
|
||||
@@ -328,38 +323,29 @@ def test_contexted_kv_attention_alibi(
|
||||
cache_dtype = dtype
|
||||
else:
|
||||
cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
|
||||
k_cache = torch.zeros(cache_size,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype=cache_dtype)
|
||||
v_cache = torch.zeros(cache_size,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype=cache_dtype)
|
||||
k_cache = torch.zeros(
|
||||
cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype
|
||||
)
|
||||
v_cache = torch.zeros(
|
||||
cache_size, block_size, num_kv_heads, head_size, dtype=cache_dtype
|
||||
)
|
||||
k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
|
||||
v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype)
|
||||
values = torch.arange(0, cache_size, dtype=torch.long)
|
||||
values = values[torch.randperm(cache_size)]
|
||||
block_table = values[:BS * max_block_per_request].view(
|
||||
BS, max_block_per_request)
|
||||
block_table = values[: BS * max_block_per_request].view(BS, max_block_per_request)
|
||||
b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
|
||||
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
|
||||
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
|
||||
dtype=torch.long),
|
||||
dim=0)
|
||||
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0)
|
||||
max_input_len = MAX_SEQ_LEN
|
||||
# copy kv to cache
|
||||
b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1],
|
||||
dtype=torch.long),
|
||||
dim=0)
|
||||
b_seq_start_loc = torch.cumsum(
|
||||
torch.tensor([0] + seq_lens[:-1], dtype=torch.long), dim=0
|
||||
)
|
||||
for i in range(BS):
|
||||
for j in range(query_lens[i]):
|
||||
k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] +
|
||||
j])
|
||||
v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] +
|
||||
b_ctx_len[i] + j])
|
||||
k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + j])
|
||||
v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + b_ctx_len[i] + j])
|
||||
cur_ctx = 0
|
||||
block_id = 0
|
||||
while cur_ctx < b_ctx_len[i]:
|
||||
@@ -370,82 +356,90 @@ def test_contexted_kv_attention_alibi(
|
||||
end_loc = start_loc + block_size
|
||||
start_slot = block_table[i, block_id] * block_size
|
||||
end_slot = start_slot + end_loc - start_loc
|
||||
k_cache.view(-1, num_kv_heads,
|
||||
head_size)[start_slot:end_slot].copy_(
|
||||
key[start_loc:end_loc])
|
||||
v_cache.view(-1, num_kv_heads,
|
||||
head_size)[start_slot:end_slot].copy_(
|
||||
value[start_loc:end_loc])
|
||||
k_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_(
|
||||
key[start_loc:end_loc]
|
||||
)
|
||||
v_cache.view(-1, num_kv_heads, head_size)[start_slot:end_slot].copy_(
|
||||
value[start_loc:end_loc]
|
||||
)
|
||||
cur_ctx += block_size
|
||||
block_id += 1
|
||||
# transpose K_cache[num_blocks, block_size, num_kv_heads, head_size]
|
||||
# to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8]
|
||||
k_cache = k_cache.view(-1, block_size, num_kv_heads, head_size // 8,
|
||||
8).permute(0, 2, 3, 1, 4).contiguous()
|
||||
k_cache = (
|
||||
k_cache.view(-1, block_size, num_kv_heads, head_size // 8, 8)
|
||||
.permute(0, 2, 3, 1, 4)
|
||||
.contiguous()
|
||||
)
|
||||
# transpose V_cache[num_blocks, block_size, num_kv_heads, head_size]
|
||||
# to V_cache[num_blocks, num_kv_heads, head_size, block_size]
|
||||
v_cache = v_cache.view(-1, block_size, num_kv_heads,
|
||||
head_size).permute(0, 2, 3, 1).contiguous()
|
||||
v_cache = (
|
||||
v_cache.view(-1, block_size, num_kv_heads, head_size)
|
||||
.permute(0, 2, 3, 1)
|
||||
.contiguous()
|
||||
)
|
||||
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
|
||||
# Warm up the Triton kernel by calling it once before actually measuring
|
||||
# generation time
|
||||
op(query,
|
||||
k,
|
||||
v,
|
||||
output,
|
||||
kv_cache_dtype,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_table,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
MAX_CTX_LEN,
|
||||
max_input_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
alibi_slopes=alibi_slopes)
|
||||
op(
|
||||
query,
|
||||
k,
|
||||
v,
|
||||
output,
|
||||
kv_cache_dtype,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_table,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
MAX_CTX_LEN,
|
||||
max_input_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
alibi_slopes=alibi_slopes,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
op(query,
|
||||
k,
|
||||
v,
|
||||
output,
|
||||
kv_cache_dtype,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_table,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
MAX_CTX_LEN,
|
||||
max_input_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
alibi_slopes=alibi_slopes)
|
||||
op(
|
||||
query,
|
||||
k,
|
||||
v,
|
||||
output,
|
||||
kv_cache_dtype,
|
||||
k_cache,
|
||||
v_cache,
|
||||
block_table,
|
||||
b_start_loc,
|
||||
b_seq_len,
|
||||
MAX_CTX_LEN,
|
||||
max_input_len,
|
||||
k_scale,
|
||||
v_scale,
|
||||
alibi_slopes=alibi_slopes,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
end_time = time.time()
|
||||
print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
|
||||
print(f"triton Time: {(end_time - start_time) * 1000:.2f} ms")
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
|
||||
# NOTE(DefTruth): In order to reuse _make_alibi_bias function,
|
||||
# we have to pad query tensor before MQA/GQA expanding.
|
||||
if query.shape[0] != key.shape[0]:
|
||||
query_pad = torch.empty(sum(seq_lens),
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
query_pad = torch.empty(sum(seq_lens), num_heads, head_size, dtype=dtype)
|
||||
query_pad.uniform_(-1e-3, 1e-3)
|
||||
seq_start = 0
|
||||
query_start = 0
|
||||
for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
|
||||
seq_end = seq_start + seq_len
|
||||
query_end = query_start + query_len
|
||||
query_pad[seq_start:seq_end, ...] = torch.cat([
|
||||
torch.zeros(
|
||||
seq_len - query_len, num_heads, head_size, dtype=dtype),
|
||||
query[query_start:query_end, ...]
|
||||
],
|
||||
dim=0)
|
||||
query_pad[seq_start:seq_end, ...] = torch.cat(
|
||||
[
|
||||
torch.zeros(seq_len - query_len, num_heads, head_size, dtype=dtype),
|
||||
query[query_start:query_end, ...],
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
seq_start += seq_len
|
||||
query_start += query_len
|
||||
query = query_pad
|
||||
@@ -456,11 +450,12 @@ def test_contexted_kv_attention_alibi(
|
||||
# heads.
|
||||
#
|
||||
# see also: vllm/model_executor/layers/attention.py
|
||||
key = key[:, :, None, :].expand(key.shape[0], num_kv_heads,
|
||||
num_queries_per_kv, key.shape[-1])
|
||||
value = value[:, :,
|
||||
None, :].expand(value.shape[0], num_kv_heads,
|
||||
num_queries_per_kv, value.shape[-1])
|
||||
key = key[:, :, None, :].expand(
|
||||
key.shape[0], num_kv_heads, num_queries_per_kv, key.shape[-1]
|
||||
)
|
||||
value = value[:, :, None, :].expand(
|
||||
value.shape[0], num_kv_heads, num_queries_per_kv, value.shape[-1]
|
||||
)
|
||||
# [seq, num_kv_heads, num_queries_per_kv, dk]=>
|
||||
# [seq, num_kv_heads*num_queries_per_kv, dk] to comply with rest of the
|
||||
# codebase. We save some time reshaping alibi matrix at runtime.
|
||||
@@ -483,24 +478,23 @@ def test_contexted_kv_attention_alibi(
|
||||
for i, (query_len, seq_len) in enumerate(zip(query_lens, seq_lens)):
|
||||
seq_end = seq_start + seq_len
|
||||
query_end = query_start + query_len
|
||||
out = xops.memory_efficient_attention_forward(query[:,
|
||||
seq_start:seq_end],
|
||||
key[:,
|
||||
seq_start:seq_end],
|
||||
value[:,
|
||||
seq_start:seq_end],
|
||||
attn_bias=attn_bias[i],
|
||||
p=0.0,
|
||||
scale=scale)
|
||||
out = xops.memory_efficient_attention_forward(
|
||||
query[:, seq_start:seq_end],
|
||||
key[:, seq_start:seq_end],
|
||||
value[:, seq_start:seq_end],
|
||||
attn_bias=attn_bias[i],
|
||||
p=0.0,
|
||||
scale=scale,
|
||||
)
|
||||
out = out.view_as(query[:, seq_start:seq_end]).view(
|
||||
seq_len, num_heads, head_size)
|
||||
output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len:,
|
||||
...])
|
||||
seq_len, num_heads, head_size
|
||||
)
|
||||
output_ref[query_start:query_end, ...].copy_(out[seq_len - query_len :, ...])
|
||||
seq_start += seq_len
|
||||
query_start += query_len
|
||||
torch.cuda.synchronize()
|
||||
end_time = time.time()
|
||||
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
|
||||
print(f"xformers Time: {(end_time - start_time) * 1000:.2f} ms")
|
||||
atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
|
||||
torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
|
||||
|
||||
@@ -532,9 +526,16 @@ def test_contexted_kv_attention_f32(
|
||||
device: str,
|
||||
op: Callable,
|
||||
) -> None:
|
||||
test_contexted_kv_attention(num_heads, num_queries_per_kv, head_size,
|
||||
sliding_window, dtype, kv_cache_dtype, device,
|
||||
op)
|
||||
test_contexted_kv_attention(
|
||||
num_heads,
|
||||
num_queries_per_kv,
|
||||
head_size,
|
||||
sliding_window,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
device,
|
||||
op,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.optional
|
||||
@@ -555,5 +556,6 @@ def test_contexted_kv_attention_alibi_f32(
|
||||
device: str,
|
||||
op: Callable,
|
||||
) -> None:
|
||||
test_contexted_kv_attention_alibi(num_heads, num_queries_per_kv, head_size,
|
||||
dtype, kv_cache_dtype, device, op)
|
||||
test_contexted_kv_attention_alibi(
|
||||
num_heads, num_queries_per_kv, head_size, dtype, kv_cache_dtype, device, op
|
||||
)
|
||||
|
||||
@@ -11,8 +11,7 @@ from vllm.utils import STR_BACKEND_ENV_VAR
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_cache():
|
||||
"""Clear lru cache to ensure each test case runs without caching.
|
||||
"""
|
||||
"""Clear lru cache to ensure each test case runs without caching."""
|
||||
_cached_get_attn_backend.cache_clear()
|
||||
|
||||
|
||||
@@ -22,46 +21,29 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
|
||||
m.setenv(STR_BACKEND_ENV_VAR, "ROCM_FLASH")
|
||||
|
||||
# Set the current platform to ROCm using monkeypatch
|
||||
monkeypatch.setattr("vllm.attention.selector.current_platform",
|
||||
RocmPlatform())
|
||||
monkeypatch.setattr("vllm.attention.selector.current_platform", RocmPlatform())
|
||||
|
||||
# Test standard ROCm attention
|
||||
backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
|
||||
assert (backend.get_name() == "ROCM_FLASH"
|
||||
or backend.get_name() == "TRITON_ATTN")
|
||||
assert backend.get_name() == "ROCM_FLASH" or backend.get_name() == "TRITON_ATTN"
|
||||
|
||||
# MLA test for deepseek related
|
||||
|
||||
# change the attention backend to triton MLA
|
||||
m.setenv(STR_BACKEND_ENV_VAR, "TRITON_MLA")
|
||||
backend = get_attn_backend(576,
|
||||
torch.bfloat16,
|
||||
"auto",
|
||||
16,
|
||||
False,
|
||||
use_mla=True)
|
||||
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True)
|
||||
assert backend.get_name() == "TRITON_MLA"
|
||||
|
||||
# If attention backend is None
|
||||
# If use_mla is true
|
||||
# The selected backend is triton MLA
|
||||
m.setenv(STR_BACKEND_ENV_VAR, None)
|
||||
backend = get_attn_backend(576,
|
||||
torch.bfloat16,
|
||||
"auto",
|
||||
16,
|
||||
False,
|
||||
use_mla=True)
|
||||
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True)
|
||||
assert backend.get_name() == "TRITON_MLA"
|
||||
|
||||
# change the attention backend to AITER MLA
|
||||
m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA")
|
||||
backend = get_attn_backend(576,
|
||||
torch.bfloat16,
|
||||
"auto",
|
||||
1,
|
||||
False,
|
||||
use_mla=True)
|
||||
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
|
||||
assert backend.get_name() == "ROCM_AITER_MLA"
|
||||
|
||||
# If attention backend is None
|
||||
@@ -70,10 +52,5 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
|
||||
# The selected backend is ROCM_AITER_MLA
|
||||
m.setenv(STR_BACKEND_ENV_VAR, None)
|
||||
m.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
backend = get_attn_backend(576,
|
||||
torch.bfloat16,
|
||||
"auto",
|
||||
1,
|
||||
False,
|
||||
use_mla=True)
|
||||
backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
|
||||
assert backend.get_name() == "ROCM_AITER_MLA"
|
||||
|
||||
@@ -24,14 +24,12 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
|
||||
num_kv_splits = 8
|
||||
|
||||
num_pages_per_batch = cdiv(seq_len, PAGE_SIZE)
|
||||
req_to_page = torch.randint(0,
|
||||
CACHE_SIZE // PAGE_SIZE,
|
||||
(B, num_pages_per_batch, 1),
|
||||
device="cuda")
|
||||
req_to_page = torch.randint(
|
||||
0, CACHE_SIZE // PAGE_SIZE, (B, num_pages_per_batch, 1), device="cuda"
|
||||
)
|
||||
req_to_token = req_to_page * PAGE_SIZE
|
||||
req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE)
|
||||
req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(
|
||||
1, 1, -1)
|
||||
req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(1, 1, -1)
|
||||
req_to_token = req_to_token.view(B, -1)
|
||||
req_to_token = req_to_token[:, :seq_len].contiguous()
|
||||
|
||||
@@ -48,7 +46,7 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
|
||||
|
||||
lse = torch.zeros(B, H_Q, dtype=dtype, device="cuda")
|
||||
|
||||
b_seq_len = torch.full((B, ), seq_len, device="cuda")
|
||||
b_seq_len = torch.full((B,), seq_len, device="cuda")
|
||||
|
||||
attn_logits = torch.empty(
|
||||
(B, H_Q, num_kv_splits, D_V + 1),
|
||||
|
||||
@@ -14,9 +14,11 @@ HEAD_SIZES = [128, 256]
|
||||
BLOCK_SIZES = [16]
|
||||
|
||||
DTYPES = [torch.bfloat16]
|
||||
QDTYPES = [None, torch.float8_e4m3fn] if not current_platform.is_rocm() else [
|
||||
None, torch.float8_e4m3fnuz
|
||||
]
|
||||
QDTYPES = (
|
||||
[None, torch.float8_e4m3fn]
|
||||
if not current_platform.is_rocm()
|
||||
else [None, torch.float8_e4m3fnuz]
|
||||
)
|
||||
# one value large enough to test overflow in index calculation.
|
||||
# one value small enough to test the schema op check
|
||||
NUM_BLOCKS = [32768, 2048]
|
||||
@@ -42,7 +44,7 @@ def ref_paged_attn(
|
||||
for i in range(num_seqs):
|
||||
query_len = query_lens[i]
|
||||
kv_len = kv_lens[i]
|
||||
q = query[start_idx:start_idx + query_len]
|
||||
q = query[start_idx : start_idx + query_len]
|
||||
q *= scale
|
||||
|
||||
num_kv_blocks = (kv_len + block_size - 1) // block_size
|
||||
@@ -60,10 +62,13 @@ def ref_paged_attn(
|
||||
empty_mask = torch.ones(query_len, kv_len)
|
||||
mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
|
||||
if sliding_window is not None:
|
||||
sliding_window_mask = torch.triu(empty_mask,
|
||||
diagonal=kv_len -
|
||||
(query_len + sliding_window) +
|
||||
1).bool().logical_not()
|
||||
sliding_window_mask = (
|
||||
torch.triu(
|
||||
empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1
|
||||
)
|
||||
.bool()
|
||||
.logical_not()
|
||||
)
|
||||
mask |= sliding_window_mask
|
||||
if soft_cap is not None and soft_cap > 0:
|
||||
attn = soft_cap * torch.tanh(attn / soft_cap)
|
||||
@@ -77,9 +82,9 @@ def ref_paged_attn(
|
||||
return torch.cat(outputs, dim=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_lens",
|
||||
[[(1, 1328), (5, 18),
|
||||
(129, 463)], [(1, 523), (1, 37), (1, 2011)]])
|
||||
@pytest.mark.parametrize(
|
||||
"seq_lens", [[(1, 1328), (5, 18), (129, 463)], [(1, 523), (1, 37), (1, 2011)]]
|
||||
)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@@ -111,30 +116,23 @@ def test_triton_unified_attn(
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
max_query_len = max(query_lens)
|
||||
max_kv_len = max(kv_lens)
|
||||
window_size = ((sliding_window - 1, 0) if sliding_window is not None else
|
||||
(-1, -1))
|
||||
window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1)
|
||||
scale = head_size**-0.5
|
||||
|
||||
query = torch.randn(sum(query_lens),
|
||||
num_query_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
key_cache = torch.randn(num_blocks,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype)
|
||||
key_cache = torch.randn(
|
||||
num_blocks, block_size, num_kv_heads, head_size, dtype=dtype
|
||||
)
|
||||
value_cache = torch.randn_like(key_cache)
|
||||
cu_query_lens = torch.tensor([0] + query_lens,
|
||||
dtype=torch.int32).cumsum(dim=0,
|
||||
dtype=torch.int32)
|
||||
cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum(
|
||||
dim=0, dtype=torch.int32
|
||||
)
|
||||
kv_lens = torch.tensor(kv_lens, dtype=torch.int32)
|
||||
|
||||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(0,
|
||||
num_blocks,
|
||||
(num_seqs, max_num_blocks_per_seq),
|
||||
dtype=torch.int32)
|
||||
block_tables = torch.randint(
|
||||
0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
|
||||
output = torch.empty_like(query)
|
||||
|
||||
@@ -188,5 +186,7 @@ def test_triton_unified_attn(
|
||||
atol, rtol = 1.5e-2, 1e-2
|
||||
if q_dtype is not None:
|
||||
atol, rtol = 1.5e-1, 1.5e-1
|
||||
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \
|
||||
f"{torch.max(torch.abs(output - ref_output))}"
|
||||
(
|
||||
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol),
|
||||
f"{torch.max(torch.abs(output - ref_output))}",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user