[Attention][Perf] Optimize cp_gather_and_upconvert_fp8_kv_cache - DeepSeek-v3.2 (#35290)
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com> Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
committed by
GitHub
parent
70485a11bd
commit
2b28b9b269
153
benchmarks/kernels/bench_cp_gather_fp8.py
Normal file
153
benchmarks/kernels/bench_cp_gather_fp8.py
Normal file
@@ -0,0 +1,153 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
# DeepSeek V3 MLA dimensions
NOPE_DIM = 512  # FP8-quantized NoPE latent elements per token
ROPE_DIM = 64  # BF16 RoPE elements per token
HEAD_DIM = NOPE_DIM + ROPE_DIM  # 576 BF16 output elements per token
ENTRY_BYTES = 656  # 512 FP8 + 16 scales + 128 BF16 RoPE
BLOCK_SIZE = 64  # tokens per physical cache block - get_supported_kernel_block_sizes

# Realistic prefill scenarios:
# - 1 long prefill: single request, 16K-96K tokens
# - 4 medium prefills: 4 requests, 4K-24K tokens each
# - 16 shorter prefills: 16 requests, 1K-6K tokens each
SCENARIOS = [
    # (label, num_reqs, total_tokens_list)
    ("1-req", 1, [8192, 16384, 32768, 65536, 98304]),
    ("4-reqs", 4, [8192, 16384, 32768, 65536, 98304]),
    ("16-reqs", 16, [8192, 16384, 32768, 65536, 98304]),
]
|
||||
|
||||
|
||||
def make_inputs(total_tokens, num_reqs, block_size):
    """Create synthetic FP8 cache, block table, and output buffer.

    Fills the cache with random bytes (we only measure throughput,
    not correctness). Block table maps each request to contiguous
    physical blocks.

    Args:
        total_tokens: Total token count, divided evenly across requests.
        num_reqs: Number of simulated requests (must be positive).
        block_size: Tokens per physical cache block.

    Returns:
        Tuple of (cache, dst, block_table, seq_lens_t, workspace_starts_t).

    Raises:
        ValueError: If num_reqs is not positive.
    """
    if num_reqs <= 0:
        raise ValueError(f"num_reqs must be positive, got {num_reqs}")

    # Divide tokens evenly across requests; the first `remainder`
    # requests absorb one extra token each.
    base_len, remainder = divmod(total_tokens, num_reqs)
    seq_lens = [base_len + (1 if r < remainder else 0) for r in range(num_reqs)]

    # workspace_starts: exclusive prefix sum of seq_lens
    workspace_starts = [0] * num_reqs
    for r in range(1, num_reqs):
        workspace_starts[r] = workspace_starts[r - 1] + seq_lens[r - 1]

    # Physical blocks needed per request
    blocks_per_req = [math.ceil(s / block_size) for s in seq_lens]
    total_blocks = sum(blocks_per_req)
    max_blocks = max(blocks_per_req)

    # Allocate cache with random data (content doesn't matter for perf)
    cache = torch.randint(
        0,
        256,
        (total_blocks, block_size, ENTRY_BYTES),
        dtype=torch.uint8,
        device="cuda",
    )

    # Block table: contiguous block assignments. Build on the CPU first —
    # per-element assignment into a CUDA tensor issues one tiny host-to-device
    # copy per block, which is very slow for the ~1.5K blocks of a 96K-token
    # scenario. A single .to("cuda") transfer at the end is one copy.
    block_table_cpu = torch.zeros(num_reqs, max_blocks, dtype=torch.int32)
    block_idx = 0
    for r in range(num_reqs):
        nb = blocks_per_req[r]
        block_table_cpu[r, :nb] = torch.arange(
            block_idx, block_idx + nb, dtype=torch.int32
        )
        block_idx += nb
    block_table = block_table_cpu.to("cuda")

    # Output workspace
    dst = torch.zeros(total_tokens, HEAD_DIM, dtype=torch.bfloat16, device="cuda")

    seq_lens_t = torch.tensor(seq_lens, dtype=torch.int32, device="cuda")
    workspace_starts_t = torch.tensor(
        workspace_starts, dtype=torch.int32, device="cuda"
    )

    return cache, dst, block_table, seq_lens_t, workspace_starts_t
|
||||
|
||||
|
||||
def bench_scenario(label, num_reqs, total_tokens_list, save_path):
    """Run benchmark for a specific (num_reqs, total_tokens) scenario."""

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["total_tokens"],
            x_vals=total_tokens_list,
            line_arg="provider",
            line_vals=["cuda_kernel"],
            line_names=["cp_gather_fp8 (CUDA)"],
            styles=[("green", "-")],
            ylabel="Latency (us)",
            plot_name=f"cp_gather_fp8-{label}-bs{BLOCK_SIZE}",
            args={"num_reqs": num_reqs},
        )
    )
    def bench_fn(total_tokens, provider, num_reqs):
        # Fresh inputs for every measured point
        cache, dst, block_table, seq_lens_t, ws_starts = make_inputs(
            total_tokens, num_reqs, BLOCK_SIZE
        )

        # Report median plus the 20th/80th percentile spread
        quantiles = [0.5, 0.2, 0.8]
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: ops.cp_gather_and_upconvert_fp8_kv_cache(
                cache, dst, block_table, seq_lens_t, ws_starts, num_reqs
            ),
            quantiles=quantiles,
            rep=500,
        )

        # Convert milliseconds to microseconds for the report
        return ms * 1000, max_ms * 1000, min_ms * 1000

    tokens_per_req_lo = total_tokens_list[0] // num_reqs
    tokens_per_req_hi = total_tokens_list[-1] // num_reqs
    print(
        f"\n--- {label}: {num_reqs} request(s), "
        f"~{tokens_per_req_lo}-{tokens_per_req_hi} tokens/req ---"
    )
    bench_fn.run(print_data=True, save_path=save_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark cp_gather_and_upconvert_fp8_kv_cache"
    )
    parser.add_argument(
        "--save-path",
        type=str,
        default=None,
        help="Path to save benchmark results as CSV",
    )
    args = parser.parse_args()

    # Per-token data volume, for bandwidth analysis
    read_per_token = ENTRY_BYTES  # 656 bytes from cache
    write_per_token = HEAD_DIM * 2  # 576 * 2 = 1152 bytes to workspace
    total_per_token = read_per_token + write_per_token  # 1808 bytes

    banner = "=" * 70
    print("\n" + banner)
    print("CP_GATHER_AND_UPCONVERT_FP8_KV_CACHE BENCHMARKS")
    print(banner)
    print(f"Cache entry: {ENTRY_BYTES} bytes (512 FP8 + 16 scales + 128 RoPE)")
    print(f"Output row: {HEAD_DIM} BF16 = {HEAD_DIM * 2} bytes")
    print(f"Per token: {total_per_token} bytes (read + write)")
    print(f"Block size: {BLOCK_SIZE} tokens/block")
    print(banner)

    for label, num_reqs, total_tokens_list in SCENARIOS:
        bench_scenario(label, num_reqs, total_tokens_list, args.save_path)

    print("\n" + banner)
    print("Benchmarking complete!")
    print(banner)
|
||||
@@ -995,75 +995,67 @@ namespace vllm {
|
||||
// Similar to cp_gather_cache but specifically for FP8->BF16 conversion.
// One warp handles one output token: the 32 lanes cooperatively load the
// 512-byte FP8 NoPE section as int4 (16 bytes per lane), dequantize with
// the per-128-element tile scales, and copy the 64 BF16 RoPE values.
__global__ void cp_gather_and_upconvert_fp8_kv_cache(
    const uint8_t* __restrict__ src_cache,  // [NUM_BLOCKS, BLOCK_SIZE, 656]
    __nv_bfloat16* __restrict__ dst,        // [total_tokens, 576]
    const int32_t* __restrict__ block_table,       // [num_reqs, BLOCK_INDICES]
    const int32_t* __restrict__ workspace_starts,  // [num_reqs]
    const int32_t num_reqs, const int32_t block_size,
    const int32_t total_tokens, const int64_t block_table_stride,
    const int64_t cache_block_stride, const int64_t cache_entry_stride,
    const int64_t dst_entry_stride) {
  // Global warp index doubles as the output token index
  const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) >> 5;
  if (flat_warp_id >= total_tokens) return;
  const int lane_id = threadIdx.x & 31;

  // Binary search to find which request owns this output token.
  // workspace_starts is the exclusive prefix sum of the sequence lengths,
  // so we want the largest index whose start is <= flat_warp_id.
  int lo = 0, hi = num_reqs - 1;
  while (lo < hi) {
    int mid = (lo + hi + 1) >> 1;
    if (workspace_starts[mid] <= flat_warp_id)
      lo = mid;
    else
      hi = mid - 1;
  }
  const int req_id = lo;

  // Compute physical token address via block table
  const int out_token_id = flat_warp_id;
  const int token_offset = out_token_id - workspace_starts[req_id];
  const int cache_block_idx = token_offset / block_size;
  const int offset_in_block = token_offset % block_size;
  const int physical_block =
      block_table[req_id * block_table_stride + cache_block_idx];

  const uint8_t* token_ptr = src_cache + physical_block * cache_block_stride +
                             offset_in_block * cache_entry_stride;

  // Each lane loads 16 of the 512 FP8 NoPE bytes in a single int4
  const int4* nope_src = reinterpret_cast<const int4*>(token_ptr);
  const int4 fp8_data = nope_src[lane_id];

  // One float scale per 128-element tile; 8 consecutive lanes share a scale
  const float* scales_ptr = reinterpret_cast<const float*>(token_ptr + 512);
  const float scale = scales_ptr[lane_id >> 3];

  // Dequantize the 16 FP8 bytes as two 8-element vector conversions
  const uint2 fp8_lo = make_uint2(fp8_data.x, fp8_data.y);
  const uint2 fp8_hi = make_uint2(fp8_data.z, fp8_data.w);
#ifdef USE_ROCM
  const bf16_8_t bf16_lo =
      fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_lo, scale);
  const bf16_8_t bf16_hi =
      fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_hi, scale);
#else
  const bf16_8_t bf16_lo =
      fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_lo, scale, __NV_E4M3);
  const bf16_8_t bf16_hi =
      fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_hi, scale, __NV_E4M3);
#endif

  // Write the 16 dequantized BF16 values (two int4 stores) per lane
  __nv_bfloat16* dst_ptr = dst + out_token_id * dst_entry_stride;
  int4* nope_dst = reinterpret_cast<int4*>(dst_ptr) + lane_id * 2;
  nope_dst[0] = *reinterpret_cast<const int4*>(&bf16_lo);
  nope_dst[1] = *reinterpret_cast<const int4*>(&bf16_hi);

  // RoPE section: 128 bytes of BF16 at byte offset 528, copied verbatim
  // as 32 x 4-byte words (one per lane)
  const int* rope_src = reinterpret_cast<const int*>(token_ptr + 528);
  int* rope_dst = reinterpret_cast<int*>(dst_ptr + 512);
  rope_dst[lane_id] = rope_src[lane_id];
}
|
||||
|
||||
template <typename scalar_t>
|
||||
@@ -1257,15 +1249,16 @@ void cp_gather_and_upconvert_fp8_kv_cache(
|
||||
src_ptr = reinterpret_cast<const uint8_t*>(src_cache.data_ptr());
|
||||
}
|
||||
|
||||
// Decide on the number of splits based on the batch size
|
||||
int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
|
||||
dim3 grid(batch_size, num_splits);
|
||||
dim3 block(576);
|
||||
const int total_tokens = dst.size(0);
|
||||
constexpr int warps_per_block = 8;
|
||||
const int grid_size = (total_tokens + warps_per_block - 1) / warps_per_block;
|
||||
const int block_size_threads = warps_per_block * 32; // 256 threads
|
||||
|
||||
vllm::cp_gather_and_upconvert_fp8_kv_cache<<<grid, block, 0, stream>>>(
|
||||
vllm::cp_gather_and_upconvert_fp8_kv_cache<<<grid_size, block_size_threads, 0,
|
||||
stream>>>(
|
||||
src_ptr, reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()),
|
||||
block_table.data_ptr<int32_t>(), seq_lens.data_ptr<int32_t>(),
|
||||
workspace_starts.data_ptr<int32_t>(), block_size, head_dim,
|
||||
block_table.data_ptr<int32_t>(), workspace_starts.data_ptr<int32_t>(),
|
||||
static_cast<int32_t>(batch_size), block_size, total_tokens,
|
||||
block_table_stride, cache_block_stride, cache_entry_stride,
|
||||
dst_entry_stride);
|
||||
}
|
||||
|
||||
363
tests/kernels/test_cp_gather_fp8.py
Normal file
363
tests/kernels/test_cp_gather_fp8.py
Normal file
@@ -0,0 +1,363 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
# DeepSeek V3 MLA dimensions
NOPE_DIM = 512  # NoPE latent dimension (FP8 quantized in cache)
ROPE_DIM = 64  # RoPE dimension (stored as BF16 in cache)
NUM_TILES = 4  # NOPE_DIM / GROUP_SIZE = 512 / 128
GROUP_SIZE = 128  # FP8 quantization group size (one scale per group)
ENTRY_BYTES = 656  # 512 (FP8) + 16 (4×float32 scales) + 128 (64×BF16 RoPE)
|
||||
|
||||
|
||||
def _build_test_case(seq_lens, block_size, seed=42):
    """Build a synthetic FP8 cache and compute the expected BF16 output.

    This simulates what concat_and_cache_ds_mla_kernel writes into the
    KV cache, then computes what cp_gather_and_upconvert should produce.

    Args:
        seq_lens: List of sequence lengths, one per request.
        block_size: Number of tokens per physical cache block.
        seed: Random seed for reproducibility.

    Returns:
        Tuple of (cache, block_table, seq_lens_t, workspace_starts_t,
        num_reqs, total_tokens, expected_output).
    """
    torch.manual_seed(seed)

    num_reqs = len(seq_lens)
    total_tokens = sum(seq_lens)

    # workspace_starts[r] = sum of seq_lens[0..r-1]: where in the output
    # buffer each request's gathered tokens should be written.
    workspace_starts = []
    running = 0
    for sl in seq_lens:
        workspace_starts.append(running)
        running += sl

    # How many physical cache blocks each request needs
    blocks_per_req = [math.ceil(sl / block_size) for sl in seq_lens]
    total_blocks = sum(blocks_per_req)
    max_blocks = max(blocks_per_req)

    # Block table maps (request, logical_block_idx) -> physical_block_id.
    # Blocks are assigned contiguously: request 0 gets [0, 1, ...],
    # request 1 gets the next run, and so on.
    block_table = torch.zeros(num_reqs, max_blocks, dtype=torch.int32, device="cuda")
    next_block = 0
    for r in range(num_reqs):
        for b in range(blocks_per_req[r]):
            block_table[r, b] = next_block
            next_block += 1

    # The raw paged cache: [num_blocks, block_size, 656] as uint8
    cache = torch.zeros(
        total_blocks, block_size, ENTRY_BYTES, dtype=torch.uint8, device="cuda"
    )
    # Expected kernel output: [total_tokens, 576] as BF16
    expected = torch.zeros(
        total_tokens, NOPE_DIM + ROPE_DIM, dtype=torch.bfloat16, device="cuda"
    )

    # Fill each token's cache entry and compute its expected output row.
    # NOTE: the order of random draws below (per-tile randn, per-tile rand,
    # then per-token rope randn) is part of the reproducible fixture.
    for r in range(num_reqs):
        for t in range(seq_lens[r]):
            out_idx = workspace_starts[r] + t
            # Map token position -> (physical_block, offset_within_block)
            phys = block_table[r, t // block_size].item()
            off = t % block_size

            # --- NoPE section: 4 tiles of 128 FP8 values, each with a scale ---
            for tile in range(NUM_TILES):
                start = tile * GROUP_SIZE

                # Generate random data and quantize to FP8 e4m3
                fp8_vals = torch.randn(GROUP_SIZE, device="cuda").to(
                    torch.float8_e4m3fn
                )
                # Pack FP8 bytes into cache at bytes [start : start+128]
                cache[phys, off, start : start + GROUP_SIZE] = fp8_vals.view(
                    torch.uint8
                )

                # Random positive scale in [0.1, 2.1]
                scale = (torch.rand(1, device="cuda") * 2.0 + 0.1).item()
                scale_t = torch.tensor([scale], dtype=torch.float32, device="cuda")
                # Pack scale as 4 raw bytes at bytes [512 + tile*4 : ...]
                cache[phys, off, NOPE_DIM + tile * 4 : NOPE_DIM + (tile + 1) * 4] = (
                    scale_t.view(torch.uint8)
                )

                # Reference dequant: fp8 -> float32, multiply scale, -> bf16.
                # This matches the CUDA path: fp8 -> half -> float * scale -> bf16.
                # (fp8 -> half is exact, half -> float is exact, so fp8 -> float
                # gives the same result regardless of intermediate type.)
                expected[out_idx, start : start + GROUP_SIZE] = (
                    fp8_vals.float() * scale
                ).bfloat16()

            # --- RoPE section: 64 BF16 values, direct copy (no dequant) ---
            rope = torch.randn(ROPE_DIM, dtype=torch.bfloat16, device="cuda")
            # Pack RoPE bytes into cache at bytes [528 : 656]
            cache[phys, off, NOPE_DIM + 16 :] = rope.view(torch.uint8)
            # Expected output: exact copy
            expected[out_idx, NOPE_DIM:] = rope

    seq_lens_t = torch.tensor(seq_lens, dtype=torch.int32, device="cuda")
    workspace_starts_t = torch.tensor(
        workspace_starts, dtype=torch.int32, device="cuda"
    )

    return (
        cache,
        block_table,
        seq_lens_t,
        workspace_starts_t,
        num_reqs,
        total_tokens,
        expected,
    )
|
||||
|
||||
|
||||
def _build_test_case_fast(seq_lens, block_size, seed=42):
    """Vectorized test-case builder for large sequence lengths.

    Same logic as _build_test_case but uses tensor operations instead of
    per-token Python loops, making it practical for seq_lens up to 128K+.
    """
    torch.manual_seed(seed)

    num_reqs = len(seq_lens)
    total_tokens = sum(seq_lens)

    # Exclusive prefix sum of seq_lens
    workspace_starts = []
    running = 0
    for sl in seq_lens:
        workspace_starts.append(running)
        running += sl

    blocks_per_req = [math.ceil(sl / block_size) for sl in seq_lens]
    total_blocks = sum(blocks_per_req)
    max_blocks = max(blocks_per_req)

    # Contiguous block allocation
    block_table = torch.zeros(num_reqs, max_blocks, dtype=torch.int32, device="cuda")
    next_block = 0
    for r in range(num_reqs):
        for b in range(blocks_per_req[r]):
            block_table[r, b] = next_block
            next_block += 1

    cache = torch.zeros(
        total_blocks, block_size, ENTRY_BYTES, dtype=torch.uint8, device="cuda"
    )

    # Generate all random data in one shot (draw order matters for the seed)
    nope_fp8 = torch.randn(total_tokens, NOPE_DIM, device="cuda").to(
        torch.float8_e4m3fn
    )
    scales = (torch.rand(total_tokens, NUM_TILES, device="cuda") * 2.0 + 0.1).float()
    rope = torch.randn(total_tokens, ROPE_DIM, dtype=torch.bfloat16, device="cuda")

    # Expected output, computed tile-by-tile with broadcast scales
    # (same dequant logic as the kernel)
    expected = torch.zeros(
        total_tokens, NOPE_DIM + ROPE_DIM, dtype=torch.bfloat16, device="cuda"
    )
    for tile in range(NUM_TILES):
        start = tile * GROUP_SIZE
        expected[:, start : start + GROUP_SIZE] = (
            nope_fp8[:, start : start + GROUP_SIZE].float() * scales[:, tile : tile + 1]
        ).bfloat16()
    expected[:, NOPE_DIM:] = rope

    # Assemble per-token cache entries as [total_tokens, 656] uint8
    token_data = torch.zeros(
        total_tokens, ENTRY_BYTES, dtype=torch.uint8, device="cuda"
    )
    token_data[:, :NOPE_DIM] = nope_fp8.view(torch.uint8)
    token_data[:, NOPE_DIM : NOPE_DIM + 16] = scales.view(torch.uint8)
    token_data[:, NOPE_DIM + 16 :] = rope.view(torch.uint8)

    # Scatter into the paged cache (loop over requests, not tokens)
    block_start = 0
    for r in range(num_reqs):
        sl = seq_lens[r]
        nb = blocks_per_req[r]
        ws = workspace_starts[r]
        flat_cache = cache[block_start : block_start + nb].reshape(-1, ENTRY_BYTES)
        flat_cache[:sl] = token_data[ws : ws + sl]
        block_start += nb

    seq_lens_t = torch.tensor(seq_lens, dtype=torch.int32, device="cuda")
    workspace_starts_t = torch.tensor(
        workspace_starts, dtype=torch.int32, device="cuda"
    )

    return (
        cache,
        block_table,
        seq_lens_t,
        workspace_starts_t,
        num_reqs,
        total_tokens,
        expected,
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "seq_lens,block_size",
    [
        # Production block_size=64 (only supported value for FlashMLA sparse).
        # Realistic prefill scenarios with varying request counts.
        ([1], 64),  # single token edge case
        ([64], 64),  # 1 req, exactly one block
        ([128], 64),  # 1 req, crosses block boundary
        ([512], 64),  # 1 req, longer prefill
        ([256, 128, 384], 64),  # 3 reqs, varying lengths
        ([128] * 4, 64),  # 4 reqs, equal lengths
        ([64] * 16, 64),  # 16 reqs, shorter prefills
    ],
)
def test_cp_gather_and_upconvert_fp8_kv_cache(seq_lens, block_size):
    """Core correctness test: build cache, run kernel, compare output."""
    (
        cache,
        block_table,
        seq_lens_t,
        workspace_starts_t,
        num_reqs,
        total_tokens,
        expected,
    ) = _build_test_case(seq_lens, block_size)

    dst = torch.zeros(
        total_tokens, NOPE_DIM + ROPE_DIM, dtype=torch.bfloat16, device="cuda"
    )

    ops.cp_gather_and_upconvert_fp8_kv_cache(
        cache, dst, block_table, seq_lens_t, workspace_starts_t, num_reqs
    )

    # NoPE: fp8 dequant has rounding error, so we allow small tolerance.
    # The fp8 -> float -> bf16 path can differ by up to ~1 ULP of bf16.
    torch.testing.assert_close(
        dst[:, :NOPE_DIM], expected[:, :NOPE_DIM], atol=1e-3, rtol=1e-2
    )

    # RoPE: pure bf16 copy, must be bit-exact
    assert torch.equal(dst[:, NOPE_DIM:], expected[:, NOPE_DIM:])
|
||||
|
||||
|
||||
def test_cp_gather_fp8_shuffled_blocks():
    """Test that the kernel correctly follows the block table when
    physical blocks are non-contiguous and out of order.

    Here we allocate 4 physical blocks but map the request's 2 logical
    blocks to physical blocks [3, 1] (reversed, with gaps).
    """
    torch.manual_seed(123)
    block_size = 4
    seq_lens = [8]  # needs 2 blocks (tokens 0-3 in block 0, 4-7 in block 1)
    total_tokens = 8

    # 4 physical blocks, but only blocks 3 and 1 are used (in that order).
    # Tokens 0-3 -> physical block 3, tokens 4-7 -> physical block 1.
    num_phys_blocks = 4
    cache = torch.zeros(
        num_phys_blocks, block_size, ENTRY_BYTES, dtype=torch.uint8, device="cuda"
    )
    block_table = torch.tensor([[3, 1]], dtype=torch.int32, device="cuda")
    workspace_starts = torch.tensor([0], dtype=torch.int32, device="cuda")
    seq_lens_t = torch.tensor(seq_lens, dtype=torch.int32, device="cuda")

    expected = torch.zeros(
        total_tokens, NOPE_DIM + ROPE_DIM, dtype=torch.bfloat16, device="cuda"
    )

    # Fill the cache at the shuffled physical locations
    for t in range(total_tokens):
        # Follow the same block_table lookup the kernel will use
        phys = block_table[0, t // block_size].item()
        off = t % block_size

        for tile in range(NUM_TILES):
            start = tile * GROUP_SIZE
            fp8_vals = torch.randn(GROUP_SIZE, device="cuda").to(torch.float8_e4m3fn)
            cache[phys, off, start : start + GROUP_SIZE] = fp8_vals.view(torch.uint8)

            # Use a fixed scale to keep this test simple
            scale = 1.5
            scale_t = torch.tensor([scale], dtype=torch.float32, device="cuda")
            cache[phys, off, NOPE_DIM + tile * 4 : NOPE_DIM + (tile + 1) * 4] = (
                scale_t.view(torch.uint8)
            )

            expected[t, start : start + GROUP_SIZE] = (
                fp8_vals.float() * scale
            ).bfloat16()

        rope = torch.randn(ROPE_DIM, dtype=torch.bfloat16, device="cuda")
        cache[phys, off, NOPE_DIM + 16 :] = rope.view(torch.uint8)
        expected[t, NOPE_DIM:] = rope

    dst = torch.zeros(
        total_tokens, NOPE_DIM + ROPE_DIM, dtype=torch.bfloat16, device="cuda"
    )

    ops.cp_gather_and_upconvert_fp8_kv_cache(
        cache, dst, block_table, seq_lens_t, workspace_starts, len(seq_lens)
    )

    # NoPE allows fp8 rounding tolerance; RoPE must be bit-exact
    torch.testing.assert_close(
        dst[:, :NOPE_DIM], expected[:, :NOPE_DIM], atol=1e-3, rtol=1e-2
    )
    assert torch.equal(dst[:, NOPE_DIM:], expected[:, NOPE_DIM:])
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "seq_lens,block_size",
    [
        # Large sequence lengths matching end-to-end benchmark scenarios.
        # Uses vectorized builder since per-token Python loops would be too slow.
        ([8000], 64),
        ([16000], 64),
        ([32000], 64),
        ([64000], 64),
        ([96000], 64),
        ([128000], 64),
    ],
)
def test_cp_gather_fp8_large_seqlens(seq_lens, block_size):
    """Correctness test with large sequence lengths matching benchmark
    scenarios (8K-128K prefill)."""
    (
        cache,
        block_table,
        seq_lens_t,
        workspace_starts_t,
        num_reqs,
        total_tokens,
        expected,
    ) = _build_test_case_fast(seq_lens, block_size)

    dst = torch.zeros(
        total_tokens, NOPE_DIM + ROPE_DIM, dtype=torch.bfloat16, device="cuda"
    )

    ops.cp_gather_and_upconvert_fp8_kv_cache(
        cache, dst, block_table, seq_lens_t, workspace_starts_t, num_reqs
    )

    # NoPE allows fp8 rounding tolerance; RoPE must be bit-exact
    torch.testing.assert_close(
        dst[:, :NOPE_DIM], expected[:, :NOPE_DIM], atol=1e-3, rtol=1e-2
    )
    assert torch.equal(dst[:, NOPE_DIM:], expected[:, NOPE_DIM:])
|
||||
Reference in New Issue
Block a user