[Attention][Perf] Optimize cp_gather_and_upconvert_fp8_kv_cache - DeepSeek-v3.2 (#35290)

Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
Roberto L. Castro
2026-03-09 17:46:57 +01:00
committed by GitHub
parent 70485a11bd
commit 2b28b9b269
3 changed files with 583 additions and 74 deletions

View File

@@ -0,0 +1,153 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import math
import torch
from vllm import _custom_ops as ops
from vllm.triton_utils import triton
# DeepSeek V3 MLA dimensions
NOPE_DIM = 512  # NoPE latent dim, stored as FP8 bytes in the cache
ROPE_DIM = 64  # RoPE dim, stored as BF16 in the cache
HEAD_DIM = NOPE_DIM + ROPE_DIM # 576 BF16 output elements per token
ENTRY_BYTES = 656 # 512 FP8 + 16 scales + 128 BF16 RoPE
BLOCK_SIZE = 64 # tokens per physical cache block - get_supported_kernel_block_sizes
# Realistic prefill scenarios:
# - 1 long prefill: single request, 16K-96K tokens
# - 4 medium prefills: 4 requests, 4K-24K tokens each
# - 16 shorter prefills: 16 requests, 1K-6K tokens each
# Each entry is (label, num_reqs, total_tokens_list); make_inputs divides
# the total tokens evenly across the requests.
SCENARIOS = [
# (label, num_reqs, total_tokens_list)
("1-req", 1, [8192, 16384, 32768, 65536, 98304]),
("4-reqs", 4, [8192, 16384, 32768, 65536, 98304]),
("16-reqs", 16, [8192, 16384, 32768, 65536, 98304]),
]
def make_inputs(total_tokens, num_reqs, block_size):
    """Create synthetic FP8 cache, block table, and output buffer.

    The cache is filled with random bytes (only throughput is measured,
    not correctness). The block table maps each request to contiguous
    physical blocks.
    """
    # Spread tokens as evenly as possible; the first `extra` requests get
    # one additional token each.
    base, extra = divmod(total_tokens, num_reqs)
    seq_lens = [base + 1] * extra + [base] * (num_reqs - extra)
    # Exclusive prefix sum of seq_lens: where each request's gathered
    # rows start in the output workspace.
    workspace_starts = [0] * num_reqs
    for r in range(1, num_reqs):
        workspace_starts[r] = workspace_starts[r - 1] + seq_lens[r - 1]
    # Physical cache blocks needed per request (ceil division).
    blocks_per_req = [-(-sl // block_size) for sl in seq_lens]
    total_blocks = sum(blocks_per_req)
    max_blocks = max(blocks_per_req)
    # Random cache contents (values are irrelevant for a perf run).
    cache = torch.randint(
        0,
        256,
        (total_blocks, block_size, ENTRY_BYTES),
        dtype=torch.uint8,
        device="cuda",
    )
    # Contiguous physical-block assignment: request 0 owns the first
    # blocks, request 1 the next set, and so on.
    block_table = torch.zeros(num_reqs, max_blocks, dtype=torch.int32, device="cuda")
    next_free = 0
    for r, nb in enumerate(blocks_per_req):
        block_table[r, :nb] = torch.arange(
            next_free, next_free + nb, dtype=torch.int32
        )
        next_free += nb
    # Destination workspace for the gathered/upconverted BF16 rows.
    dst = torch.zeros(total_tokens, HEAD_DIM, dtype=torch.bfloat16, device="cuda")
    seq_lens_t = torch.tensor(seq_lens, dtype=torch.int32, device="cuda")
    workspace_starts_t = torch.tensor(
        workspace_starts, dtype=torch.int32, device="cuda"
    )
    return cache, dst, block_table, seq_lens_t, workspace_starts_t
def bench_scenario(label, num_reqs, total_tokens_list, save_path):
    """Run the benchmark for one (num_reqs, total_tokens) scenario.

    Args:
        label: Scenario name used in plot/CSV titles.
        num_reqs: Number of concurrent prefill requests.
        total_tokens_list: x-axis values (total tokens across all requests).
        save_path: Directory for CSV/plot output, or None to skip saving.
    """

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["total_tokens"],
            x_vals=total_tokens_list,
            line_arg="provider",
            line_vals=["cuda_kernel"],
            line_names=["cp_gather_fp8 (CUDA)"],
            styles=[("green", "-")],
            ylabel="Latency (us)",
            plot_name=f"cp_gather_fp8-{label}-bs{BLOCK_SIZE}",
            args={"num_reqs": num_reqs},
        )
    )
    def bench_fn(total_tokens, provider, num_reqs):
        cache, dst, block_table, seq_lens_t, ws_starts = make_inputs(
            total_tokens, num_reqs, BLOCK_SIZE
        )
        # With quantiles [0.5, 0.2, 0.8], do_bench returns (median, p20, p80).
        quantiles = [0.5, 0.2, 0.8]
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: ops.cp_gather_and_upconvert_fp8_kv_cache(
                cache, dst, block_table, seq_lens_t, ws_starts, num_reqs
            ),
            quantiles=quantiles,
            rep=500,
        )
        # perf_report expects (y, y_min, y_max). Since y is latency, the
        # lower quantile IS the minimum; the previous return order
        # (ms, max_ms, min_ms) swapped the min/max columns and error bars.
        return ms * 1000, min_ms * 1000, max_ms * 1000  # us

    seq_len_per_req = total_tokens_list[0] // num_reqs
    seq_len_per_req_max = total_tokens_list[-1] // num_reqs
    print(
        f"\n--- {label}: {num_reqs} request(s), "
        f"~{seq_len_per_req}-{seq_len_per_req_max} tokens/req ---"
    )
    bench_fn.run(print_data=True, save_path=save_path)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark cp_gather_and_upconvert_fp8_kv_cache"
    )
    parser.add_argument(
        "--save-path",
        type=str,
        default=None,
        help="Path to save benchmark results as CSV",
    )
    args = parser.parse_args()
    # Per-token traffic for back-of-the-envelope bandwidth analysis:
    # 656 B read from the paged cache + 576 BF16 (1152 B) written out.
    total_per_token = ENTRY_BYTES + HEAD_DIM * 2  # 1808 bytes
    sep = "=" * 70
    print("\n" + sep)
    print("CP_GATHER_AND_UPCONVERT_FP8_KV_CACHE BENCHMARKS")
    print(sep)
    print(f"Cache entry: {ENTRY_BYTES} bytes (512 FP8 + 16 scales + 128 RoPE)")
    print(f"Output row: {HEAD_DIM} BF16 = {HEAD_DIM * 2} bytes")
    print(f"Per token: {total_per_token} bytes (read + write)")
    print(f"Block size: {BLOCK_SIZE} tokens/block")
    print(sep)
    for scenario_label, scenario_reqs, scenario_tokens in SCENARIOS:
        bench_scenario(scenario_label, scenario_reqs, scenario_tokens, args.save_path)
    print("\n" + sep)
    print("Benchmarking complete!")
    print(sep)

View File

@@ -995,75 +995,67 @@ namespace vllm {
// Similar to cp_gather_cache but specifically for FP8->BF16 conversion
// NOTE(review): the text below appears to be a commit-diff rendering whose
// +/- markers were stripped: lines from the OLD kernel and the NEW kernel
// are interleaved, so this span is not compilable as-is. The annotations
// below mark which version each run of lines appears to belong to —
// confirm against the actual repository diff before relying on them.
__global__ void cp_gather_and_upconvert_fp8_kv_cache(
const uint8_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
// OLD parameters (removed): took per-request seq_lens and head_dim; the
// launch used a (batch, num_splits) grid.
__nv_bfloat16* __restrict__ dst, // [TOT_TOKENS, 576]
const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES]
const int32_t* __restrict__ seq_lens, // [BATCH]
const int32_t* __restrict__ workspace_starts, // [BATCH]
const int32_t block_size, const int32_t head_dim,
const int64_t block_table_stride, const int64_t cache_block_stride,
const int64_t cache_entry_stride, const int64_t dst_entry_stride) {
// OLD body start (removed): one grid column per batch entry, work split
// along gridDim.y; each split walked its token range sequentially.
const int64_t bid = blockIdx.x; // Batch ID
const int32_t num_splits = gridDim.y;
const int32_t split = blockIdx.y;
const int32_t seq_start = workspace_starts[bid];
const int32_t seq_len = seq_lens[bid];
const int32_t tot_slots = seq_len;
const int32_t split_slots = cuda_utils::ceil_div(tot_slots, num_splits);
// NEW parameters (added): seq_lens/head_dim dropped; num_reqs and
// total_tokens passed so the grid can be flattened over output tokens.
__nv_bfloat16* __restrict__ dst, // [total_tokens, 576]
const int32_t* __restrict__ block_table, // [num_reqs, BLOCK_INDICES]
const int32_t* __restrict__ workspace_starts, // [num_reqs]
const int32_t num_reqs, const int32_t block_size,
const int32_t total_tokens, const int64_t block_table_stride,
const int64_t cache_block_stride, const int64_t cache_entry_stride,
const int64_t dst_entry_stride) {
// NEW body start (added): one warp per output token; warps beyond
// total_tokens exit immediately.
const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) >> 5;
if (flat_warp_id >= total_tokens) return;
const int lane_id = threadIdx.x & 31;
// OLD (removed): split bounds, block-table walking state, and the
// per-token loop in which threads 0..575 dequantized one token at a time.
const int32_t split_start = split * split_slots;
const int32_t split_end = min((split + 1) * split_slots, tot_slots);
const bool is_active_split = (split_start < tot_slots);
if (!is_active_split) return;
// Adjust the pointer for the block_table for this batch
const int32_t batch_offset = bid * block_table_stride;
int32_t offset = split_start;
int32_t offset_div = offset / block_size;
offset = offset % block_size;
const int32_t* batch_block_table = block_table + batch_offset;
// Adjust dst pointer based on the cumulative sequence lengths
dst += seq_start * dst_entry_stride;
const int tid = threadIdx.x;
// Process each token in this split
for (int pid = split_start; pid < split_end; ++pid) {
auto block_id = batch_block_table[offset_div];
const uint8_t* token_ptr =
src_cache + block_id * cache_block_stride + offset * cache_entry_stride;
__nv_bfloat16* dst_ptr = dst + pid * dst_entry_stride;
// FP8 format: 512 bytes fp8 + 16 bytes scales + 128 bytes rope (64 bf16)
const uint8_t* no_pe_ptr = token_ptr;
const float* scales_ptr = reinterpret_cast<const float*>(token_ptr + 512);
const __nv_bfloat16* rope_ptr =
reinterpret_cast<const __nv_bfloat16*>(token_ptr + 512 + 16);
// Parallelize fp8 dequant (512 elements) and rope copy (64 elements)
if (tid < 512) {
// FP8 dequantization
const int tile = tid >> 7; // each tile is 128 elements
const float scale = scales_ptr[tile];
const uint8_t val = no_pe_ptr[tid];
dst_ptr[tid] =
fp8::scaled_convert<__nv_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3>(val, scale);
} else if (tid < 576) {
// Rope copy (64 bf16 elements)
const int rope_idx = tid - 512;
dst_ptr[512 + rope_idx] = rope_ptr[rope_idx];
}
// Move to next token
offset += 1;
if (offset == block_size) {
offset_div += 1;
offset = 0;
}
// NEW (added): each warp finds its owning request via binary search over
// workspace_starts, then processes its single token with vectorized
// int4 loads/stores (16 FP8 bytes per lane; lane>>3 selects the scale).
// Binary search to find which request owns this output token
int lo = 0, hi = num_reqs - 1;
while (lo < hi) {
int mid = (lo + hi + 1) >> 1;
if (workspace_starts[mid] <= flat_warp_id)
lo = mid;
else
hi = mid - 1;
}
const int req_id = lo;
// Compute physical token address via block table
const int out_token_id = flat_warp_id;
const int token_offset = out_token_id - workspace_starts[req_id];
const int cache_block_idx = token_offset / block_size;
const int offset_in_block = token_offset % block_size;
const int physical_block =
block_table[req_id * block_table_stride + cache_block_idx];
const uint8_t* token_ptr = src_cache + physical_block * cache_block_stride +
offset_in_block * cache_entry_stride;
const int4* nope_src = reinterpret_cast<const int4*>(token_ptr);
const int4 fp8_data = nope_src[lane_id];
const float* scales_ptr = reinterpret_cast<const float*>(token_ptr + 512);
const float scale = scales_ptr[lane_id >> 3];
const uint2 fp8_lo = make_uint2(fp8_data.x, fp8_data.y);
const uint2 fp8_hi = make_uint2(fp8_data.z, fp8_data.w);
#ifdef USE_ROCM
const bf16_8_t bf16_lo =
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_lo, scale);
const bf16_8_t bf16_hi =
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_hi, scale);
#else
const bf16_8_t bf16_lo =
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_lo, scale, __NV_E4M3);
const bf16_8_t bf16_hi =
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_hi, scale, __NV_E4M3);
#endif
__nv_bfloat16* dst_ptr = dst + out_token_id * dst_entry_stride;
int4* nope_dst = reinterpret_cast<int4*>(dst_ptr) + lane_id * 2;
nope_dst[0] = *reinterpret_cast<const int4*>(&bf16_lo);
nope_dst[1] = *reinterpret_cast<const int4*>(&bf16_hi);
const int* rope_src = reinterpret_cast<const int*>(token_ptr + 528);
int* rope_dst = reinterpret_cast<int*>(dst_ptr + 512);
rope_dst[lane_id] = rope_src[lane_id];
}
template <typename scalar_t>
// NOTE(review): a diff-hunk header survives in the rendered text below;
// the following lines mix the OLD and NEW host-side launch code for
// cp_gather_and_upconvert_fp8_kv_cache and are not compilable as-is.
@@ -1257,15 +1249,16 @@ void cp_gather_and_upconvert_fp8_kv_cache(
src_ptr = reinterpret_cast<const uint8_t*>(src_cache.data_ptr());
}
// Decide on the number of splits based on the batch size
// OLD launch config (removed): grid (batch_size, num_splits), 576 threads.
int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
dim3 grid(batch_size, num_splits);
dim3 block(576);
// NEW launch config (added): 8 warps (256 threads) per block, one warp
// per output token, grid sized to cover dst.size(0) tokens.
const int total_tokens = dst.size(0);
constexpr int warps_per_block = 8;
const int grid_size = (total_tokens + warps_per_block - 1) / warps_per_block;
const int block_size_threads = warps_per_block * 32; // 256 threads
// OLD launch line (removed):
vllm::cp_gather_and_upconvert_fp8_kv_cache<<<grid, block, 0, stream>>>(
// NEW launch (added):
vllm::cp_gather_and_upconvert_fp8_kv_cache<<<grid_size, block_size_threads, 0,
stream>>>(
src_ptr, reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()),
// OLD argument list (removed): passed seq_lens and head_dim.
block_table.data_ptr<int32_t>(), seq_lens.data_ptr<int32_t>(),
workspace_starts.data_ptr<int32_t>(), block_size, head_dim,
// NEW argument list (added): workspace_starts, num_reqs, total_tokens.
block_table.data_ptr<int32_t>(), workspace_starts.data_ptr<int32_t>(),
static_cast<int32_t>(batch_size), block_size, total_tokens,
block_table_stride, cache_block_stride, cache_entry_stride,
dst_entry_stride);
}

View File

@@ -0,0 +1,363 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import pytest
import torch
from vllm import _custom_ops as ops
# DeepSeek V3 MLA dimensions
# Per-token cache entry layout (uint8 bytes):
#   [0:512)   FP8 NoPE values | [512:528) 4 float32 scales | [528:656) 64 BF16 RoPE
NOPE_DIM = 512 # NoPE latent dimension (FP8 quantized in cache)
ROPE_DIM = 64 # RoPE dimension (stored as BF16 in cache)
NUM_TILES = 4 # NOPE_DIM / GROUP_SIZE = 512 / 128
GROUP_SIZE = 128 # FP8 quantization group size (one scale per group)
ENTRY_BYTES = 656 # 512 (FP8) + 16 (4×float32 scales) + 128 (64×BF16 RoPE)
def _build_test_case(seq_lens, block_size, seed=42):
    """Construct a synthetic paged FP8 cache plus its expected BF16 gather.

    Mirrors the per-token layout concat_and_cache_ds_mla_kernel writes
    (512 FP8 bytes, 4 float32 group scales, 64 BF16 RoPE values) and
    computes the reference output cp_gather_and_upconvert should produce.

    Args:
        seq_lens: Per-request sequence lengths.
        block_size: Tokens held by one physical cache block.
        seed: RNG seed for reproducibility.

    Returns:
        Tuple of (cache, block_table, seq_lens_t, workspace_starts_t,
        num_reqs, total_tokens, expected_output).
    """
    torch.manual_seed(seed)
    num_reqs = len(seq_lens)
    total_tokens = sum(seq_lens)
    # Exclusive prefix sum over seq_lens: the output row where each
    # request's gathered tokens begin.
    workspace_starts = [0] * num_reqs
    for r in range(1, num_reqs):
        workspace_starts[r] = workspace_starts[r - 1] + seq_lens[r - 1]
    # Physical cache blocks required per request (ceil division).
    blocks_per_req = [-(-sl // block_size) for sl in seq_lens]
    total_blocks = sum(blocks_per_req)
    max_blocks = max(blocks_per_req)
    # Contiguous allocation: request 0 owns the first blocks, then
    # request 1 takes the next set, and so on.
    block_table = torch.zeros(num_reqs, max_blocks, dtype=torch.int32, device="cuda")
    next_free = 0
    for req, nb in enumerate(blocks_per_req):
        block_table[req, :nb] = torch.arange(
            next_free, next_free + nb, dtype=torch.int32
        )
        next_free += nb
    # Raw paged cache [num_blocks, block_size, 656] and reference output.
    cache = torch.zeros(
        total_blocks, block_size, ENTRY_BYTES, dtype=torch.uint8, device="cuda"
    )
    expected = torch.zeros(
        total_tokens, NOPE_DIM + ROPE_DIM, dtype=torch.bfloat16, device="cuda"
    )
    # Populate every token entry and the matching expected row.
    for req in range(num_reqs):
        for tok in range(seq_lens[req]):
            row = workspace_starts[req] + tok
            # Token position -> (physical block, offset within block).
            phys = block_table[req, tok // block_size].item()
            off = tok % block_size
            # NoPE: 4 groups of 128 FP8 values, one float32 scale each.
            for tile in range(NUM_TILES):
                lo = tile * GROUP_SIZE
                hi = lo + GROUP_SIZE
                fp8_vals = torch.randn(GROUP_SIZE, device="cuda").to(
                    torch.float8_e4m3fn
                )
                # FP8 payload bytes at [lo : hi).
                cache[phys, off, lo:hi] = fp8_vals.view(torch.uint8)
                # Positive scale drawn from [0.1, 2.1], stored as raw
                # float32 bytes at [512 + tile*4 : 512 + (tile+1)*4).
                scale = (torch.rand(1, device="cuda") * 2.0 + 0.1).item()
                scale_t = torch.tensor([scale], dtype=torch.float32, device="cuda")
                cache[phys, off, NOPE_DIM + tile * 4 : NOPE_DIM + (tile + 1) * 4] = (
                    scale_t.view(torch.uint8)
                )
                # Reference dequant: fp8 -> float32 * scale -> bf16. This
                # matches the CUDA path, since fp8 -> half -> float is
                # exact regardless of the intermediate type.
                expected[row, lo:hi] = (fp8_vals.float() * scale).bfloat16()
            # RoPE: 64 BF16 values copied verbatim into bytes [528:656).
            rope = torch.randn(ROPE_DIM, dtype=torch.bfloat16, device="cuda")
            cache[phys, off, NOPE_DIM + 16 :] = rope.view(torch.uint8)
            expected[row, NOPE_DIM:] = rope
    seq_lens_t = torch.tensor(seq_lens, dtype=torch.int32, device="cuda")
    workspace_starts_t = torch.tensor(
        workspace_starts, dtype=torch.int32, device="cuda"
    )
    return (
        cache,
        block_table,
        seq_lens_t,
        workspace_starts_t,
        num_reqs,
        total_tokens,
        expected,
    )
def _build_test_case_fast(seq_lens, block_size, seed=42):
    """Vectorized variant of _build_test_case for very long sequences.

    Produces the same cache layout and expected output, but with batched
    tensor operations instead of per-token Python loops, so sequence
    lengths up to 128K+ build in reasonable time.
    """
    torch.manual_seed(seed)
    num_reqs = len(seq_lens)
    total_tokens = sum(seq_lens)
    # Exclusive prefix sum of sequence lengths.
    workspace_starts = [0] * num_reqs
    for r in range(1, num_reqs):
        workspace_starts[r] = workspace_starts[r - 1] + seq_lens[r - 1]
    blocks_per_req = [math.ceil(sl / block_size) for sl in seq_lens]
    total_blocks = sum(blocks_per_req)
    max_blocks = max(blocks_per_req)
    # Contiguous physical-block assignment per request.
    block_table = torch.zeros(num_reqs, max_blocks, dtype=torch.int32, device="cuda")
    next_free = 0
    for req, nb in enumerate(blocks_per_req):
        block_table[req, :nb] = torch.arange(
            next_free, next_free + nb, dtype=torch.int32
        )
        next_free += nb
    cache = torch.zeros(
        total_blocks, block_size, ENTRY_BYTES, dtype=torch.uint8, device="cuda"
    )
    # Draw every random payload in one shot.
    nope_fp8 = torch.randn(total_tokens, NOPE_DIM, device="cuda").to(
        torch.float8_e4m3fn
    )
    scales = (torch.rand(total_tokens, NUM_TILES, device="cuda") * 2.0 + 0.1).float()
    rope = torch.randn(total_tokens, ROPE_DIM, dtype=torch.bfloat16, device="cuda")
    # Reference output using the same dequant math as the kernel.
    expected = torch.zeros(
        total_tokens, NOPE_DIM + ROPE_DIM, dtype=torch.bfloat16, device="cuda"
    )
    for tile in range(NUM_TILES):
        lo = tile * GROUP_SIZE
        hi = lo + GROUP_SIZE
        expected[:, lo:hi] = (
            nope_fp8[:, lo:hi].float() * scales[:, tile : tile + 1]
        ).bfloat16()
    expected[:, NOPE_DIM:] = rope
    # Serialize each token into its 656-byte cache entry.
    token_data = torch.zeros(
        total_tokens, ENTRY_BYTES, dtype=torch.uint8, device="cuda"
    )
    token_data[:, :NOPE_DIM] = nope_fp8.view(torch.uint8)
    token_data[:, NOPE_DIM : NOPE_DIM + 16] = scales.view(torch.uint8)
    token_data[:, NOPE_DIM + 16 :] = rope.view(torch.uint8)
    # Copy the entries into the paged layout, one request at a time
    # (loops over requests, never over individual tokens).
    block_start = 0
    for req in range(num_reqs):
        sl = seq_lens[req]
        nb = blocks_per_req[req]
        ws = workspace_starts[req]
        flat = cache[block_start : block_start + nb].reshape(-1, ENTRY_BYTES)
        flat[:sl] = token_data[ws : ws + sl]
        block_start += nb
    seq_lens_t = torch.tensor(seq_lens, dtype=torch.int32, device="cuda")
    workspace_starts_t = torch.tensor(
        workspace_starts, dtype=torch.int32, device="cuda"
    )
    return (
        cache,
        block_table,
        seq_lens_t,
        workspace_starts_t,
        num_reqs,
        total_tokens,
        expected,
    )
@pytest.mark.parametrize(
    "seq_lens,block_size",
    [
        # Production block_size=64 (only supported value for FlashMLA sparse).
        # Realistic prefill scenarios with varying request counts.
        ([1], 64),  # single token edge case
        ([64], 64),  # 1 req, exactly one block
        ([128], 64),  # 1 req, crosses block boundary
        ([512], 64),  # 1 req, longer prefill
        ([256, 128, 384], 64),  # 3 reqs, varying lengths
        ([128] * 4, 64),  # 4 reqs, equal lengths
        ([64] * 16, 64),  # 16 reqs, shorter prefills
    ],
)
def test_cp_gather_and_upconvert_fp8_kv_cache(seq_lens, block_size):
    """Core correctness test: build cache, run kernel, compare output."""
    case = _build_test_case(seq_lens, block_size)
    cache, block_table, seq_lens_t, workspace_starts_t = case[:4]
    num_reqs, total_tokens, expected = case[4:]
    out = torch.zeros(
        total_tokens, NOPE_DIM + ROPE_DIM, dtype=torch.bfloat16, device="cuda"
    )
    ops.cp_gather_and_upconvert_fp8_kv_cache(
        cache, out, block_table, seq_lens_t, workspace_starts_t, num_reqs
    )
    # NoPE half goes through fp8 dequant, which can differ from the
    # reference by ~1 bf16 ULP, so allow a small tolerance.
    torch.testing.assert_close(
        out[:, :NOPE_DIM], expected[:, :NOPE_DIM], atol=1e-3, rtol=1e-2
    )
    # RoPE half is a straight bf16 copy and must match bit-exactly.
    assert torch.equal(out[:, NOPE_DIM:], expected[:, NOPE_DIM:])
def test_cp_gather_fp8_shuffled_blocks():
    """Verify the kernel honors a non-contiguous, out-of-order block table.

    One request needing 2 logical blocks is mapped to physical blocks
    [3, 1] — reversed, with unused blocks in between.
    """
    torch.manual_seed(123)
    block_size = 4
    seq_lens = [8]  # 2 logical blocks: tokens 0-3 and 4-7
    total_tokens = 8
    # 4 physical blocks allocated; only blocks 3 and 1 are referenced.
    num_phys_blocks = 4
    cache = torch.zeros(
        num_phys_blocks, block_size, ENTRY_BYTES, dtype=torch.uint8, device="cuda"
    )
    block_table = torch.tensor([[3, 1]], dtype=torch.int32, device="cuda")
    workspace_starts = torch.tensor([0], dtype=torch.int32, device="cuda")
    seq_lens_t = torch.tensor(seq_lens, dtype=torch.int32, device="cuda")
    expected = torch.zeros(
        total_tokens, NOPE_DIM + ROPE_DIM, dtype=torch.bfloat16, device="cuda"
    )
    # A fixed scale keeps the reference math trivial.
    scale = 1.5
    scale_t = torch.tensor([scale], dtype=torch.float32, device="cuda")
    # Fill the cache at the shuffled physical locations, using the same
    # block_table lookup the kernel performs.
    for tok in range(total_tokens):
        phys = block_table[0, tok // block_size].item()
        off = tok % block_size
        for tile in range(NUM_TILES):
            lo = tile * GROUP_SIZE
            fp8_vals = torch.randn(GROUP_SIZE, device="cuda").to(torch.float8_e4m3fn)
            cache[phys, off, lo : lo + GROUP_SIZE] = fp8_vals.view(torch.uint8)
            cache[phys, off, NOPE_DIM + tile * 4 : NOPE_DIM + (tile + 1) * 4] = (
                scale_t.view(torch.uint8)
            )
            expected[tok, lo : lo + GROUP_SIZE] = (
                fp8_vals.float() * scale
            ).bfloat16()
        rope = torch.randn(ROPE_DIM, dtype=torch.bfloat16, device="cuda")
        cache[phys, off, NOPE_DIM + 16 :] = rope.view(torch.uint8)
        expected[tok, NOPE_DIM:] = rope
    out = torch.zeros(
        total_tokens, NOPE_DIM + ROPE_DIM, dtype=torch.bfloat16, device="cuda"
    )
    ops.cp_gather_and_upconvert_fp8_kv_cache(
        cache, out, block_table, seq_lens_t, workspace_starts, len(seq_lens)
    )
    # NoPE half: dequant tolerance; RoPE half: bit-exact copy.
    torch.testing.assert_close(
        out[:, :NOPE_DIM], expected[:, :NOPE_DIM], atol=1e-3, rtol=1e-2
    )
    assert torch.equal(out[:, NOPE_DIM:], expected[:, NOPE_DIM:])
@pytest.mark.parametrize(
    "seq_lens,block_size",
    [
        # Benchmark-scale prefills; built with the vectorized helper since
        # per-token Python loops would be far too slow at these sizes.
        ([8000], 64),
        ([16000], 64),
        ([32000], 64),
        ([64000], 64),
        ([96000], 64),
        ([128000], 64),
    ],
)
def test_cp_gather_fp8_large_seqlens(seq_lens, block_size):
    """Correctness at 8K-128K token prefills (benchmark scenarios)."""
    case = _build_test_case_fast(seq_lens, block_size)
    cache, block_table, seq_lens_t, workspace_starts_t = case[:4]
    num_reqs, total_tokens, expected = case[4:]
    out = torch.zeros(
        total_tokens, NOPE_DIM + ROPE_DIM, dtype=torch.bfloat16, device="cuda"
    )
    ops.cp_gather_and_upconvert_fp8_kv_cache(
        cache, out, block_table, seq_lens_t, workspace_starts_t, num_reqs
    )
    # NoPE half: dequant tolerance; RoPE half: bit-exact copy.
    torch.testing.assert_close(
        out[:, :NOPE_DIM], expected[:, :NOPE_DIM], atol=1e-3, rtol=1e-2
    )
    assert torch.equal(out[:, NOPE_DIM:], expected[:, NOPE_DIM:])