364 lines
13 KiB
Python
364 lines
13 KiB
Python
|
|
# SPDX-License-Identifier: Apache-2.0
|
|||
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|||
|
|
import math
|
|||
|
|
|
|||
|
|
import pytest
|
|||
|
|
import torch
|
|||
|
|
|
|||
|
|
from vllm import _custom_ops as ops
|
|||
|
|
|
|||
|
|
# DeepSeek V3 MLA dimensions
|
|||
|
|
NOPE_DIM = 512 # NoPE latent dimension (FP8 quantized in cache)
|
|||
|
|
ROPE_DIM = 64 # RoPE dimension (stored as BF16 in cache)
|
|||
|
|
NUM_TILES = 4 # NOPE_DIM / GROUP_SIZE = 512 / 128
|
|||
|
|
GROUP_SIZE = 128 # FP8 quantization group size (one scale per group)
|
|||
|
|
ENTRY_BYTES = 656 # 512 (FP8) + 16 (4×float32 scales) + 128 (64×BF16 RoPE)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _build_test_case(seq_lens, block_size, seed=42):
|
|||
|
|
"""Build a synthetic FP8 cache and compute the expected BF16 output.
|
|||
|
|
|
|||
|
|
This simulates what concat_and_cache_ds_mla_kernel writes into the
|
|||
|
|
KV cache, then computes what cp_gather_and_upconvert should produce.
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
seq_lens: List of sequence lengths, one per request.
|
|||
|
|
block_size: Number of tokens per physical cache block.
|
|||
|
|
seed: Random seed for reproducibility.
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
Tuple of (cache, block_table, seq_lens_t, workspace_starts_t,
|
|||
|
|
num_reqs, total_tokens, expected_output).
|
|||
|
|
"""
|
|||
|
|
torch.manual_seed(seed)
|
|||
|
|
|
|||
|
|
num_reqs = len(seq_lens)
|
|||
|
|
total_tokens = sum(seq_lens)
|
|||
|
|
|
|||
|
|
# workspace_starts[r] = sum of seq_lens[0..r-1]
|
|||
|
|
# This tells the kernel where in the output buffer each request's
|
|||
|
|
# gathered tokens should be written.
|
|||
|
|
workspace_starts = []
|
|||
|
|
s = 0
|
|||
|
|
for sl in seq_lens:
|
|||
|
|
workspace_starts.append(s)
|
|||
|
|
s += sl
|
|||
|
|
|
|||
|
|
# How many physical cache blocks each request needs
|
|||
|
|
blocks_per_req = [math.ceil(s / block_size) for s in seq_lens]
|
|||
|
|
total_blocks = sum(blocks_per_req)
|
|||
|
|
max_blocks = max(blocks_per_req)
|
|||
|
|
|
|||
|
|
# Block table maps (request, logical_block_idx) -> physical_block_id.
|
|||
|
|
# Here we assign blocks contiguously: request 0 gets blocks [0, 1, ...],
|
|||
|
|
# request 1 gets the next set, etc.
|
|||
|
|
block_table = torch.zeros(num_reqs, max_blocks, dtype=torch.int32, device="cuda")
|
|||
|
|
block_idx = 0
|
|||
|
|
for r in range(num_reqs):
|
|||
|
|
for b in range(blocks_per_req[r]):
|
|||
|
|
block_table[r, b] = block_idx
|
|||
|
|
block_idx += 1
|
|||
|
|
|
|||
|
|
# The raw paged cache: [num_blocks, block_size, 656] as uint8
|
|||
|
|
cache = torch.zeros(
|
|||
|
|
total_blocks, block_size, ENTRY_BYTES, dtype=torch.uint8, device="cuda"
|
|||
|
|
)
|
|||
|
|
# Expected kernel output: [total_tokens, 576] as BF16
|
|||
|
|
expected = torch.zeros(
|
|||
|
|
total_tokens, NOPE_DIM + ROPE_DIM, dtype=torch.bfloat16, device="cuda"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# Fill each token's cache entry and compute expected output
|
|||
|
|
for r in range(num_reqs):
|
|||
|
|
for t in range(seq_lens[r]):
|
|||
|
|
out_idx = workspace_starts[r] + t
|
|||
|
|
# Map token position -> (physical_block, offset_within_block)
|
|||
|
|
phys = block_table[r, t // block_size].item()
|
|||
|
|
off = t % block_size
|
|||
|
|
|
|||
|
|
# --- NoPE section: 4 tiles of 128 FP8 values, each with a scale ---
|
|||
|
|
for tile in range(NUM_TILES):
|
|||
|
|
start = tile * GROUP_SIZE
|
|||
|
|
|
|||
|
|
# Generate random data and quantize to FP8 e4m3
|
|||
|
|
fp8_vals = torch.randn(GROUP_SIZE, device="cuda").to(
|
|||
|
|
torch.float8_e4m3fn
|
|||
|
|
)
|
|||
|
|
# Pack FP8 bytes into cache at bytes [start : start+128]
|
|||
|
|
cache[phys, off, start : start + GROUP_SIZE] = fp8_vals.view(
|
|||
|
|
torch.uint8
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# Random positive scale in [0.1, 2.1]
|
|||
|
|
scale = (torch.rand(1, device="cuda") * 2.0 + 0.1).item()
|
|||
|
|
scale_t = torch.tensor([scale], dtype=torch.float32, device="cuda")
|
|||
|
|
# Pack scale as 4 raw bytes at bytes [512 + tile*4 : ...]
|
|||
|
|
cache[phys, off, NOPE_DIM + tile * 4 : NOPE_DIM + (tile + 1) * 4] = (
|
|||
|
|
scale_t.view(torch.uint8)
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# Reference dequant: fp8 -> float32, multiply scale, -> bf16.
|
|||
|
|
# This matches the CUDA path: fp8 -> half -> float * scale -> bf16.
|
|||
|
|
# (fp8 -> half is exact, half -> float is exact, so fp8 -> float
|
|||
|
|
# gives the same result regardless of intermediate type.)
|
|||
|
|
expected[out_idx, start : start + GROUP_SIZE] = (
|
|||
|
|
fp8_vals.float() * scale
|
|||
|
|
).bfloat16()
|
|||
|
|
|
|||
|
|
# --- RoPE section: 64 BF16 values, direct copy (no dequant) ---
|
|||
|
|
rope = torch.randn(ROPE_DIM, dtype=torch.bfloat16, device="cuda")
|
|||
|
|
# Pack RoPE bytes into cache at bytes [528 : 656]
|
|||
|
|
cache[phys, off, NOPE_DIM + 16 :] = rope.view(torch.uint8)
|
|||
|
|
# Expected output: exact copy
|
|||
|
|
expected[out_idx, NOPE_DIM:] = rope
|
|||
|
|
|
|||
|
|
seq_lens_t = torch.tensor(seq_lens, dtype=torch.int32, device="cuda")
|
|||
|
|
workspace_starts_t = torch.tensor(
|
|||
|
|
workspace_starts, dtype=torch.int32, device="cuda"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
return (
|
|||
|
|
cache,
|
|||
|
|
block_table,
|
|||
|
|
seq_lens_t,
|
|||
|
|
workspace_starts_t,
|
|||
|
|
num_reqs,
|
|||
|
|
total_tokens,
|
|||
|
|
expected,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _build_test_case_fast(seq_lens, block_size, seed=42):
|
|||
|
|
"""Vectorized test-case builder for large sequence lengths.
|
|||
|
|
|
|||
|
|
Same logic as _build_test_case but uses tensor operations instead of
|
|||
|
|
per-token Python loops, making it practical for seq_lens up to 128K+.
|
|||
|
|
"""
|
|||
|
|
torch.manual_seed(seed)
|
|||
|
|
|
|||
|
|
num_reqs = len(seq_lens)
|
|||
|
|
total_tokens = sum(seq_lens)
|
|||
|
|
|
|||
|
|
workspace_starts = []
|
|||
|
|
s = 0
|
|||
|
|
for sl in seq_lens:
|
|||
|
|
workspace_starts.append(s)
|
|||
|
|
s += sl
|
|||
|
|
|
|||
|
|
blocks_per_req = [math.ceil(sl / block_size) for sl in seq_lens]
|
|||
|
|
total_blocks = sum(blocks_per_req)
|
|||
|
|
max_blocks = max(blocks_per_req)
|
|||
|
|
|
|||
|
|
# Contiguous block allocation
|
|||
|
|
block_table = torch.zeros(num_reqs, max_blocks, dtype=torch.int32, device="cuda")
|
|||
|
|
block_idx = 0
|
|||
|
|
for r in range(num_reqs):
|
|||
|
|
for b in range(blocks_per_req[r]):
|
|||
|
|
block_table[r, b] = block_idx
|
|||
|
|
block_idx += 1
|
|||
|
|
|
|||
|
|
cache = torch.zeros(
|
|||
|
|
total_blocks, block_size, ENTRY_BYTES, dtype=torch.uint8, device="cuda"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# Generate all data vectorized
|
|||
|
|
nope_fp8 = torch.randn(total_tokens, NOPE_DIM, device="cuda").to(
|
|||
|
|
torch.float8_e4m3fn
|
|||
|
|
)
|
|||
|
|
scales = (torch.rand(total_tokens, NUM_TILES, device="cuda") * 2.0 + 0.1).float()
|
|||
|
|
rope = torch.randn(total_tokens, ROPE_DIM, dtype=torch.bfloat16, device="cuda")
|
|||
|
|
|
|||
|
|
# Compute expected output vectorized (same dequant logic as kernel)
|
|||
|
|
expected = torch.zeros(
|
|||
|
|
total_tokens, NOPE_DIM + ROPE_DIM, dtype=torch.bfloat16, device="cuda"
|
|||
|
|
)
|
|||
|
|
for tile in range(NUM_TILES):
|
|||
|
|
start = tile * GROUP_SIZE
|
|||
|
|
expected[:, start : start + GROUP_SIZE] = (
|
|||
|
|
nope_fp8[:, start : start + GROUP_SIZE].float() * scales[:, tile : tile + 1]
|
|||
|
|
).bfloat16()
|
|||
|
|
expected[:, NOPE_DIM:] = rope
|
|||
|
|
|
|||
|
|
# Build per-token cache entries as [total_tokens, 656] uint8
|
|||
|
|
token_data = torch.zeros(
|
|||
|
|
total_tokens, ENTRY_BYTES, dtype=torch.uint8, device="cuda"
|
|||
|
|
)
|
|||
|
|
token_data[:, :NOPE_DIM] = nope_fp8.view(torch.uint8)
|
|||
|
|
token_data[:, NOPE_DIM : NOPE_DIM + 16] = scales.view(torch.uint8)
|
|||
|
|
token_data[:, NOPE_DIM + 16 :] = rope.view(torch.uint8)
|
|||
|
|
|
|||
|
|
# Scatter into paged cache (loop over requests, not tokens)
|
|||
|
|
block_start = 0
|
|||
|
|
for r in range(num_reqs):
|
|||
|
|
sl = seq_lens[r]
|
|||
|
|
nb = blocks_per_req[r]
|
|||
|
|
ws = workspace_starts[r]
|
|||
|
|
flat_cache = cache[block_start : block_start + nb].reshape(-1, ENTRY_BYTES)
|
|||
|
|
flat_cache[:sl] = token_data[ws : ws + sl]
|
|||
|
|
block_start += nb
|
|||
|
|
|
|||
|
|
seq_lens_t = torch.tensor(seq_lens, dtype=torch.int32, device="cuda")
|
|||
|
|
workspace_starts_t = torch.tensor(
|
|||
|
|
workspace_starts, dtype=torch.int32, device="cuda"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
return (
|
|||
|
|
cache,
|
|||
|
|
block_table,
|
|||
|
|
seq_lens_t,
|
|||
|
|
workspace_starts_t,
|
|||
|
|
num_reqs,
|
|||
|
|
total_tokens,
|
|||
|
|
expected,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@pytest.mark.parametrize(
|
|||
|
|
"seq_lens,block_size",
|
|||
|
|
[
|
|||
|
|
# Production block_size=64 (only supported value for FlashMLA sparse).
|
|||
|
|
# Realistic prefill scenarios with varying request counts.
|
|||
|
|
([1], 64), # single token edge case
|
|||
|
|
([64], 64), # 1 req, exactly one block
|
|||
|
|
([128], 64), # 1 req, crosses block boundary
|
|||
|
|
([512], 64), # 1 req, longer prefill
|
|||
|
|
([256, 128, 384], 64), # 3 reqs, varying lengths
|
|||
|
|
([128] * 4, 64), # 4 reqs, equal lengths
|
|||
|
|
([64] * 16, 64), # 16 reqs, shorter prefills
|
|||
|
|
],
|
|||
|
|
)
|
|||
|
|
def test_cp_gather_and_upconvert_fp8_kv_cache(seq_lens, block_size):
|
|||
|
|
"""Core correctness test: build cache, run kernel, compare output."""
|
|||
|
|
(
|
|||
|
|
cache,
|
|||
|
|
block_table,
|
|||
|
|
seq_lens_t,
|
|||
|
|
workspace_starts_t,
|
|||
|
|
num_reqs,
|
|||
|
|
total_tokens,
|
|||
|
|
expected,
|
|||
|
|
) = _build_test_case(seq_lens, block_size)
|
|||
|
|
|
|||
|
|
dst = torch.zeros(
|
|||
|
|
total_tokens, NOPE_DIM + ROPE_DIM, dtype=torch.bfloat16, device="cuda"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
ops.cp_gather_and_upconvert_fp8_kv_cache(
|
|||
|
|
cache, dst, block_table, seq_lens_t, workspace_starts_t, num_reqs
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# NoPE: fp8 dequant has rounding error, so we allow small tolerance.
|
|||
|
|
# The fp8 -> float -> bf16 path can differ by up to ~1 ULP of bf16.
|
|||
|
|
torch.testing.assert_close(
|
|||
|
|
dst[:, :NOPE_DIM], expected[:, :NOPE_DIM], atol=1e-3, rtol=1e-2
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# RoPE: pure bf16 copy, must be bit-exact
|
|||
|
|
assert torch.equal(dst[:, NOPE_DIM:], expected[:, NOPE_DIM:])
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_cp_gather_fp8_shuffled_blocks():
|
|||
|
|
"""Test that the kernel correctly follows the block table when
|
|||
|
|
physical blocks are non-contiguous and out of order.
|
|||
|
|
|
|||
|
|
Here we allocate 4 physical blocks but map the request's 2 logical
|
|||
|
|
blocks to physical blocks [3, 1] (reversed, with gaps).
|
|||
|
|
"""
|
|||
|
|
torch.manual_seed(123)
|
|||
|
|
block_size = 4
|
|||
|
|
seq_lens = [8] # needs 2 blocks (tokens 0-3 in block 0, 4-7 in block 1)
|
|||
|
|
total_tokens = 8
|
|||
|
|
|
|||
|
|
# 4 physical blocks, but only blocks 3 and 1 are used (in that order).
|
|||
|
|
# Tokens 0-3 -> physical block 3, tokens 4-7 -> physical block 1.
|
|||
|
|
num_phys_blocks = 4
|
|||
|
|
cache = torch.zeros(
|
|||
|
|
num_phys_blocks, block_size, ENTRY_BYTES, dtype=torch.uint8, device="cuda"
|
|||
|
|
)
|
|||
|
|
block_table = torch.tensor([[3, 1]], dtype=torch.int32, device="cuda")
|
|||
|
|
workspace_starts = torch.tensor([0], dtype=torch.int32, device="cuda")
|
|||
|
|
seq_lens_t = torch.tensor(seq_lens, dtype=torch.int32, device="cuda")
|
|||
|
|
|
|||
|
|
expected = torch.zeros(
|
|||
|
|
total_tokens, NOPE_DIM + ROPE_DIM, dtype=torch.bfloat16, device="cuda"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# Fill cache at the shuffled physical locations
|
|||
|
|
for t in range(total_tokens):
|
|||
|
|
# Follow the same block_table lookup the kernel will use
|
|||
|
|
phys = block_table[0, t // block_size].item()
|
|||
|
|
off = t % block_size
|
|||
|
|
|
|||
|
|
for tile in range(NUM_TILES):
|
|||
|
|
start = tile * GROUP_SIZE
|
|||
|
|
fp8_vals = torch.randn(GROUP_SIZE, device="cuda").to(torch.float8_e4m3fn)
|
|||
|
|
cache[phys, off, start : start + GROUP_SIZE] = fp8_vals.view(torch.uint8)
|
|||
|
|
|
|||
|
|
# Use a fixed scale to keep this test simple
|
|||
|
|
scale = 1.5
|
|||
|
|
scale_t = torch.tensor([scale], dtype=torch.float32, device="cuda")
|
|||
|
|
cache[phys, off, NOPE_DIM + tile * 4 : NOPE_DIM + (tile + 1) * 4] = (
|
|||
|
|
scale_t.view(torch.uint8)
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
expected[t, start : start + GROUP_SIZE] = (
|
|||
|
|
fp8_vals.float() * scale
|
|||
|
|
).bfloat16()
|
|||
|
|
|
|||
|
|
rope = torch.randn(ROPE_DIM, dtype=torch.bfloat16, device="cuda")
|
|||
|
|
cache[phys, off, NOPE_DIM + 16 :] = rope.view(torch.uint8)
|
|||
|
|
expected[t, NOPE_DIM:] = rope
|
|||
|
|
|
|||
|
|
dst = torch.zeros(
|
|||
|
|
total_tokens, NOPE_DIM + ROPE_DIM, dtype=torch.bfloat16, device="cuda"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
ops.cp_gather_and_upconvert_fp8_kv_cache(
|
|||
|
|
cache, dst, block_table, seq_lens_t, workspace_starts, len(seq_lens)
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
torch.testing.assert_close(
|
|||
|
|
dst[:, :NOPE_DIM], expected[:, :NOPE_DIM], atol=1e-3, rtol=1e-2
|
|||
|
|
)
|
|||
|
|
assert torch.equal(dst[:, NOPE_DIM:], expected[:, NOPE_DIM:])
|
|||
|
|
|
|||
|
|
|
|||
|
|
@pytest.mark.parametrize(
|
|||
|
|
"seq_lens,block_size",
|
|||
|
|
[
|
|||
|
|
# Large sequence lengths matching end-to-end benchmark scenarios.
|
|||
|
|
# Uses vectorized builder since per-token Python loops would be too slow.
|
|||
|
|
([8000], 64),
|
|||
|
|
([16000], 64),
|
|||
|
|
([32000], 64),
|
|||
|
|
([64000], 64),
|
|||
|
|
([96000], 64),
|
|||
|
|
([128000], 64),
|
|||
|
|
],
|
|||
|
|
)
|
|||
|
|
def test_cp_gather_fp8_large_seqlens(seq_lens, block_size):
|
|||
|
|
"""Correctness test with large sequence lengths matching benchmark
|
|||
|
|
scenarios (8K-128K prefill)."""
|
|||
|
|
(
|
|||
|
|
cache,
|
|||
|
|
block_table,
|
|||
|
|
seq_lens_t,
|
|||
|
|
workspace_starts_t,
|
|||
|
|
num_reqs,
|
|||
|
|
total_tokens,
|
|||
|
|
expected,
|
|||
|
|
) = _build_test_case_fast(seq_lens, block_size)
|
|||
|
|
|
|||
|
|
dst = torch.zeros(
|
|||
|
|
total_tokens, NOPE_DIM + ROPE_DIM, dtype=torch.bfloat16, device="cuda"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
ops.cp_gather_and_upconvert_fp8_kv_cache(
|
|||
|
|
cache, dst, block_table, seq_lens_t, workspace_starts_t, num_reqs
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
torch.testing.assert_close(
|
|||
|
|
dst[:, :NOPE_DIM], expected[:, :NOPE_DIM], atol=1e-3, rtol=1e-2
|
|||
|
|
)
|
|||
|
|
assert torch.equal(dst[:, NOPE_DIM:], expected[:, NOPE_DIM:])
|