[Kernel][CPU] CPU MLA (#14744)

Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
This commit is contained in:
Thien Tran
2025-03-25 17:34:59 +08:00
committed by GitHub
parent 4157f563b4
commit 4f044b1d67
15 changed files with 1010 additions and 17 deletions

View File

@@ -749,3 +749,72 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
torch.testing.assert_close(dst, expected)
@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS)
@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA)
@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.cpu_model
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
@torch.inference_mode()
def test_concat_and_cache_mla_cpu(
kv_lora_rank: int,
qk_rope_head_dim: int,
num_tokens: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
seed: int,
) -> None:
device = "cpu"
kv_cache_dtype = "auto"
current_platform.seed_everything(seed)
torch.set_default_device(device)
total_slots = num_blocks * block_size
slot_mapping_lst = random.sample(range(total_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst,
dtype=torch.long,
device=device)
kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device)
k_pe = torch.randn(num_tokens,
qk_rope_head_dim,
dtype=dtype,
device=device)
entry_size = kv_lora_rank + qk_rope_head_dim
scale = torch.tensor(0.1, dtype=torch.float32, device=device)
kv_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
kv_cache_dtype, device)
ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
for i in range(num_tokens):
slot = slot_mapping[i].item()
block_idx = slot // block_size
block_offset = slot % block_size
ref_temp[block_idx, block_offset, :kv_lora_rank] = kv_c[i]
ref_temp[block_idx, block_offset, kv_lora_rank:] = k_pe[i]
if kv_cache_dtype == "fp8":
ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
ops.convert_fp8(ref_kv_cache,
ref_temp,
scale.item(),
kv_dtype=kv_cache_dtype)
else:
ref_kv_cache = ref_temp
opcheck(
torch.ops._C_cache_ops.concat_and_cache_mla,
(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
kv_cache_dtype, scale)
torch.testing.assert_close(kv_cache, ref_kv_cache)

View File

@@ -0,0 +1,94 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
import torch.nn.functional as F
from torch import Tensor
import vllm._custom_ops as ops
from vllm.platforms import current_platform
def cdiv(a, b):
return (a + b - 1) // b
def ref_mla(
out: Tensor, # (bs, num_heads, v_head_dim)
query: Tensor, # (bs, num_heads, head_dim)
kv_cache: Tensor, # (num_blocks, block_size, head_dim)
scale: float,
block_tables: Tensor, # (bs, max_num_blocks)
seq_lens: Tensor, # (bs,)
):
bs, num_heads, v_head_dim = out.shape
head_dim = query.shape[2]
for i in range(bs):
# gather and flatten KV-cache
kv = kv_cache[
block_tables[i]] # (max_num_blocks, block_size, head_dim)
kv = kv.view(1, -1,
head_dim)[:, :seq_lens[i]] # (1, seq_len, head_dim)
v = kv[:, :, :v_head_dim]
q = query[i].view(num_heads, 1, head_dim)
o = F.scaled_dot_product_attention(q,
kv,
v,
scale=scale,
enable_gqa=True)
out[i] = o.view(num_heads, v_head_dim)
return out
@pytest.mark.parametrize("bs", [4])
@pytest.mark.parametrize("mean_seq_len", [256])
@pytest.mark.parametrize("h_q", [16])
@pytest.mark.parametrize("d", [576])
@pytest.mark.parametrize("dv", [512])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("dtype", [torch.float, torch.half, torch.bfloat16])
@pytest.mark.parametrize("varlen", [False, True])
@pytest.mark.cpu_model
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
def test_mla_decode_cpu(
bs: int,
mean_seq_len: int,
h_q: int,
d: int,
dv: int,
block_size: int,
dtype: torch.dtype,
varlen: bool,
):
torch.set_default_dtype(dtype)
torch.manual_seed(0)
scale = d**(-0.5)
if varlen:
seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2)
seq_lens = seq_lens.clip(2).to(torch.int32)
else:
seq_lens = torch.full((bs, ), mean_seq_len, dtype=torch.int32)
max_seq_len = seq_lens.max().item()
seqlen_pad = cdiv(max_seq_len, 256) * 256 # is this necessary?
q = torch.randn(bs, h_q, d)
block_table = torch.arange(bs * seqlen_pad // block_size,
dtype=torch.int32)
block_table = block_table.view(bs, seqlen_pad // block_size)
kv_cache = torch.randn(block_table.numel(), block_size, d)
for i, seq_len in enumerate(seq_lens.tolist()):
kv_cache.view(bs, seqlen_pad, d)[i, seq_len:] = float("nan")
out_mla = q.new_zeros(bs, h_q, dv)
ops.mla_decode_kvcache_cpu(out_mla, q, kv_cache, scale, block_table,
seq_lens)
out_ref = q.new_zeros(bs, h_q, dv)
ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)
assert not out_mla.isnan().any(), "Likely read out of bounds"
torch.testing.assert_close(out_mla, out_ref)