# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Throughput benchmark for ops.cp_gather_and_upconvert_fp8_kv_cache."""
import argparse
import math

import torch

from vllm import _custom_ops as ops
from vllm.triton_utils import triton

# DeepSeek V3 MLA dimensions
NOPE_DIM = 512
ROPE_DIM = 64
HEAD_DIM = NOPE_DIM + ROPE_DIM  # 576 BF16 output elements per token
ENTRY_BYTES = 656  # 512 FP8 + 16 scales + 128 BF16 RoPE
BLOCK_SIZE = 64  # tokens per physical cache block - get_supported_kernel_block_sizes

# Realistic prefill scenarios:
# - 1 long prefill: single request, 16K-96K tokens
# - 4 medium prefills: 4 requests, 4K-24K tokens each
# - 16 shorter prefills: 16 requests, 1K-6K tokens each
SCENARIOS = [
    # (label, num_reqs, total_tokens_list)
    ("1-req", 1, [8192, 16384, 32768, 65536, 98304]),
    ("4-reqs", 4, [8192, 16384, 32768, 65536, 98304]),
    ("16-reqs", 16, [8192, 16384, 32768, 65536, 98304]),
]


def make_inputs(total_tokens, num_reqs, block_size):
    """Build a synthetic FP8 cache, block table, and BF16 output buffer.

    The cache holds random bytes — only throughput is measured here, not
    correctness — and each request is mapped to a contiguous run of
    physical blocks.

    Returns:
        (cache, dst, block_table, seq_lens_t, workspace_starts_t)
    """
    # Spread tokens as evenly as possible; the first `extra` requests
    # each get one token more than the rest.
    base, extra = divmod(total_tokens, num_reqs)
    seq_lens = [base + 1 if r < extra else base for r in range(num_reqs)]

    # Exclusive prefix sum of seq_lens: output offset of each request.
    workspace_starts = []
    running = 0
    for sl in seq_lens:
        workspace_starts.append(running)
        running += sl

    # Physical blocks needed per request
    blocks_per_req = [math.ceil(sl / block_size) for sl in seq_lens]
    total_blocks = sum(blocks_per_req)
    max_blocks = max(blocks_per_req)

    # Random cache contents (values are irrelevant for a perf run).
    cache = torch.randint(
        0,
        256,
        (total_blocks, block_size, ENTRY_BYTES),
        dtype=torch.uint8,
        device="cuda",
    )

    # Contiguous physical-block assignment per request.
    block_table = torch.zeros(num_reqs, max_blocks, dtype=torch.int32, device="cuda")
    next_block = 0
    for r, nb in enumerate(blocks_per_req):
        block_table[r, :nb] = torch.arange(
            next_block, next_block + nb, dtype=torch.int32, device="cuda"
        )
        next_block += nb

    # BF16 output workspace: one 576-element row per gathered token.
    dst = torch.zeros(total_tokens, HEAD_DIM, dtype=torch.bfloat16, device="cuda")

    seq_lens_t = torch.tensor(seq_lens, dtype=torch.int32, device="cuda")
    workspace_starts_t = torch.tensor(
        workspace_starts, dtype=torch.int32, device="cuda"
    )

    return cache, dst, block_table, seq_lens_t, workspace_starts_t


def bench_scenario(label, num_reqs, total_tokens_list, save_path):
    """Benchmark one (num_reqs, total_tokens) scenario and print/save it."""

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["total_tokens"],
            x_vals=total_tokens_list,
            line_arg="provider",
            line_vals=["cuda_kernel"],
            line_names=["cp_gather_fp8 (CUDA)"],
            styles=[("green", "-")],
            ylabel="Latency (us)",
            plot_name=f"cp_gather_fp8-{label}-bs{BLOCK_SIZE}",
            args={"num_reqs": num_reqs},
        )
    )
    def bench_fn(total_tokens, provider, num_reqs):
        cache, dst, table, lens, starts = make_inputs(
            total_tokens, num_reqs, BLOCK_SIZE
        )

        def run():
            ops.cp_gather_and_upconvert_fp8_kv_cache(
                cache, dst, table, lens, starts, num_reqs
            )

        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            run, quantiles=[0.5, 0.2, 0.8], rep=500
        )

        # NOTE(review): triton's perf_report treats a 3-tuple return as
        # (y, y_min, y_max); this returns (median, max, min), which for a
        # latency ylabel looks swapped — confirm the intended order.
        return ms * 1000, max_ms * 1000, min_ms * 1000  # us

    per_req_lo = total_tokens_list[0] // num_reqs
    per_req_hi = total_tokens_list[-1] // num_reqs
    print(
        f"\n--- {label}: {num_reqs} request(s), "
        f"~{per_req_lo}-{per_req_hi} tokens/req ---"
    )
    bench_fn.run(print_data=True, save_path=save_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark cp_gather_and_upconvert_fp8_kv_cache"
    )
    parser.add_argument(
        "--save-path",
        type=str,
        default=None,
        help="Path to save benchmark results as CSV",
    )
    args = parser.parse_args()

    # Per-token traffic for back-of-envelope bandwidth analysis.
    read_per_token = ENTRY_BYTES  # 656 bytes from cache
    write_per_token = HEAD_DIM * 2  # 576 * 2 = 1152 bytes to workspace
    total_per_token = read_per_token + write_per_token  # 1808 bytes

    banner = "=" * 70
    print("\n" + banner)
    print("CP_GATHER_AND_UPCONVERT_FP8_KV_CACHE BENCHMARKS")
    print(banner)
    print(f"Cache entry: {ENTRY_BYTES} bytes (512 FP8 + 16 scales + 128 RoPE)")
    print(f"Output row: {HEAD_DIM} BF16 = {HEAD_DIM * 2} bytes")
    print(f"Per token: {total_per_token} bytes (read + write)")
    print(f"Block size: {BLOCK_SIZE} tokens/block")
    print(banner)

    for label, num_reqs, total_tokens_list in SCENARIOS:
        bench_scenario(label, num_reqs, total_tokens_list, args.save_path)

    print("\n" + banner)
    print("Benchmarking complete!")
    print(banner)
// One warp per output token: each 32-lane warp gathers one 656-byte cache
// entry, dequantizes the 512 FP8 NoPE values (4 groups of 128, one float32
// scale per group) to BF16, copies the 64 BF16 RoPE values, and writes one
// 576-element BF16 row into the contiguous workspace `dst`.
//
// NOTE(review): reconstructed from a garbled diff hunk — the paste stripped
// the template arguments of the reinterpret_cast / scaled_vec_conversion
// calls; they have been restored from context. Verify against the original.
__global__ void cp_gather_and_upconvert_fp8_kv_cache(
    const uint8_t* __restrict__ src_cache,  // [NUM_BLOCKS, BLOCK_SIZE, 656]
    __nv_bfloat16* __restrict__ dst,        // [total_tokens, 576]
    const int32_t* __restrict__ block_table,       // [num_reqs, BLOCK_INDICES]
    const int32_t* __restrict__ workspace_starts,  // [num_reqs]
    const int32_t num_reqs, const int32_t block_size,
    const int32_t total_tokens, const int64_t block_table_stride,
    const int64_t cache_block_stride, const int64_t cache_entry_stride,
    const int64_t dst_entry_stride) {
  // Global warp index doubles as the output token index.
  const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) >> 5;
  if (flat_warp_id >= total_tokens) return;
  const int lane_id = threadIdx.x & 31;

  // Binary search over workspace_starts (ascending prefix sums) for the
  // request owning this token: largest r with workspace_starts[r] <= id.
  int lo = 0, hi = num_reqs - 1;
  while (lo < hi) {
    int mid = (lo + hi + 1) >> 1;
    if (workspace_starts[mid] <= flat_warp_id)
      lo = mid;
    else
      hi = mid - 1;
  }
  const int req_id = lo;

  // Translate the request-relative token offset into a physical cache slot
  // through the block table.
  const int out_token_id = flat_warp_id;
  const int token_offset = out_token_id - workspace_starts[req_id];
  const int cache_block_idx = token_offset / block_size;
  const int offset_in_block = token_offset % block_size;
  const int physical_block =
      block_table[req_id * block_table_stride + cache_block_idx];

  const uint8_t* token_ptr = src_cache + physical_block * cache_block_stride +
                             offset_in_block * cache_entry_stride;

  // Each lane loads 16 bytes (one int4) of FP8 data: 32 lanes x 16 B = 512 B.
  const int4* nope_src = reinterpret_cast<const int4*>(token_ptr);
  const int4 fp8_data = nope_src[lane_id];

  // Four float32 scales start at byte 512; 8 lanes share each scale
  // (128-element quantization groups).
  const float* scales_ptr = reinterpret_cast<const float*>(token_ptr + 512);
  const float scale = scales_ptr[lane_id >> 3];

  // Dequantize the lane's 16 FP8 values as two 8-wide vector conversions.
  const uint2 fp8_lo = make_uint2(fp8_data.x, fp8_data.y);
  const uint2 fp8_hi = make_uint2(fp8_data.z, fp8_data.w);
#ifdef USE_ROCM
  const bf16_8_t bf16_lo =
      fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_lo, scale);
  const bf16_8_t bf16_hi =
      fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_hi, scale);
#else
  const bf16_8_t bf16_lo =
      fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_lo, scale, __NV_E4M3);
  const bf16_8_t bf16_hi =
      fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_hi, scale, __NV_E4M3);
#endif

  // Write the lane's 16 BF16 results (32 B) as two int4 stores.
  __nv_bfloat16* dst_ptr = dst + out_token_id * dst_entry_stride;
  int4* nope_dst = reinterpret_cast<int4*>(dst_ptr) + lane_id * 2;
  nope_dst[0] = *reinterpret_cast<const int4*>(&bf16_lo);
  nope_dst[1] = *reinterpret_cast<const int4*>(&bf16_hi);

  // RoPE: 128 B of BF16 at byte 528, copied verbatim (4 B per lane).
  const int* rope_src = reinterpret_cast<const int*>(token_ptr + 528);
  int* rope_dst = reinterpret_cast<int*>(dst_ptr + 512);
  rope_dst[lane_id] = rope_src[lane_id];
}

// NOTE(review): the remainder of this diff hunk (which begins mid-function
// and is not fully visible here, so it is not reproduced) updates the
// host-side launcher to a 1-D grid of ceil(total_tokens / 8) blocks of
// 256 threads (8 warps per block, one warp per output token), replacing the
// old (batch_size x num_splits) grid of 576-thread blocks, and its launch
// expression lost the <<<grid_size, block_size_threads, 0, stream>>>
// arguments to the same bracket-stripping — restore them from the original.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math

import pytest
import torch

from vllm import _custom_ops as ops

# DeepSeek V3 MLA dimensions
NOPE_DIM = 512  # NoPE latent dimension (FP8 quantized in cache)
ROPE_DIM = 64  # RoPE dimension (stored as BF16 in cache)
NUM_TILES = 4  # NOPE_DIM / GROUP_SIZE = 512 / 128
GROUP_SIZE = 128  # FP8 quantization group size (one scale per group)
ENTRY_BYTES = 656  # 512 (FP8) + 16 (4×float32 scales) + 128 (64×BF16 RoPE)


def _build_test_case(seq_lens, block_size, seed=42):
    """Build a synthetic FP8 cache and compute the expected BF16 output.

    This simulates what concat_and_cache_ds_mla_kernel writes into the
    KV cache, then computes what cp_gather_and_upconvert should produce.

    Args:
        seq_lens: List of sequence lengths, one per request.
        block_size: Number of tokens per physical cache block.
        seed: Random seed for reproducibility.

    Returns:
        Tuple of (cache, block_table, seq_lens_t, workspace_starts_t,
        num_reqs, total_tokens, expected_output).
    """
    torch.manual_seed(seed)

    num_reqs = len(seq_lens)
    total_tokens = sum(seq_lens)

    # workspace_starts[r] = sum of seq_lens[0..r-1]
    # This tells the kernel where in the output buffer each request's
    # gathered tokens should be written.
    workspace_starts = []
    s = 0
    for sl in seq_lens:
        workspace_starts.append(s)
        s += sl

    # How many physical cache blocks each request needs
    blocks_per_req = [math.ceil(s / block_size) for s in seq_lens]
    total_blocks = sum(blocks_per_req)
    max_blocks = max(blocks_per_req)

    # Block table maps (request, logical_block_idx) -> physical_block_id.
    # Here we assign blocks contiguously: request 0 gets blocks [0, 1, ...],
    # request 1 gets the next set, etc.
    block_table = torch.zeros(num_reqs, max_blocks, dtype=torch.int32, device="cuda")
    block_idx = 0
    for r in range(num_reqs):
        for b in range(blocks_per_req[r]):
            block_table[r, b] = block_idx
            block_idx += 1

    # The raw paged cache: [num_blocks, block_size, 656] as uint8
    cache = torch.zeros(
        total_blocks, block_size, ENTRY_BYTES, dtype=torch.uint8, device="cuda"
    )
    # Expected kernel output: [total_tokens, 576] as BF16
    expected = torch.zeros(
        total_tokens, NOPE_DIM + ROPE_DIM, dtype=torch.bfloat16, device="cuda"
    )

    # Fill each token's cache entry and compute expected output
    for r in range(num_reqs):
        for t in range(seq_lens[r]):
            out_idx = workspace_starts[r] + t
            # Map token position -> (physical_block, offset_within_block)
            phys = block_table[r, t // block_size].item()
            off = t % block_size

            # --- NoPE section: 4 tiles of 128 FP8 values, each with a scale ---
            for tile in range(NUM_TILES):
                start = tile * GROUP_SIZE

                # Generate random data and quantize to FP8 e4m3
                fp8_vals = torch.randn(GROUP_SIZE, device="cuda").to(
                    torch.float8_e4m3fn
                )
                # Pack FP8 bytes into cache at bytes [start : start+128]
                cache[phys, off, start : start + GROUP_SIZE] = fp8_vals.view(
                    torch.uint8
                )

                # Random positive scale in [0.1, 2.1]
                scale = (torch.rand(1, device="cuda") * 2.0 + 0.1).item()
                scale_t = torch.tensor([scale], dtype=torch.float32, device="cuda")
                # Pack scale as 4 raw bytes at bytes [512 + tile*4 : ...]
                cache[phys, off, NOPE_DIM + tile * 4 : NOPE_DIM + (tile + 1) * 4] = (
                    scale_t.view(torch.uint8)
                )

                # Reference dequant: fp8 -> float32, multiply scale, -> bf16.
                # This matches the CUDA path: fp8 -> half -> float * scale -> bf16.
                # (fp8 -> half is exact, half -> float is exact, so fp8 -> float
                # gives the same result regardless of intermediate type.)
                expected[out_idx, start : start + GROUP_SIZE] = (
                    fp8_vals.float() * scale
                ).bfloat16()

            # --- RoPE section: 64 BF16 values, direct copy (no dequant) ---
            rope = torch.randn(ROPE_DIM, dtype=torch.bfloat16, device="cuda")
            # Pack RoPE bytes into cache at bytes [528 : 656]
            cache[phys, off, NOPE_DIM + 16 :] = rope.view(torch.uint8)
            # Expected output: exact copy
            expected[out_idx, NOPE_DIM:] = rope

    seq_lens_t = torch.tensor(seq_lens, dtype=torch.int32, device="cuda")
    workspace_starts_t = torch.tensor(
        workspace_starts, dtype=torch.int32, device="cuda"
    )

    return (
        cache,
        block_table,
        seq_lens_t,
        workspace_starts_t,
        num_reqs,
        total_tokens,
        expected,
    )


def _build_test_case_fast(seq_lens, block_size, seed=42):
    """Vectorized test-case builder for large sequence lengths.

    Same logic as _build_test_case but uses tensor operations instead of
    per-token Python loops, making it practical for seq_lens up to 128K+.
    """
    torch.manual_seed(seed)

    num_reqs = len(seq_lens)
    total_tokens = sum(seq_lens)

    # Exclusive prefix sum: output offset of each request's first token.
    workspace_starts = []
    s = 0
    for sl in seq_lens:
        workspace_starts.append(s)
        s += sl

    blocks_per_req = [math.ceil(sl / block_size) for sl in seq_lens]
    total_blocks = sum(blocks_per_req)
    max_blocks = max(blocks_per_req)

    # Contiguous block allocation
    block_table = torch.zeros(num_reqs, max_blocks, dtype=torch.int32, device="cuda")
    block_idx = 0
    for r in range(num_reqs):
        for b in range(blocks_per_req[r]):
            block_table[r, b] = block_idx
            block_idx += 1

    cache = torch.zeros(
        total_blocks, block_size, ENTRY_BYTES, dtype=torch.uint8, device="cuda"
    )

    # Generate all data vectorized
    nope_fp8 = torch.randn(total_tokens, NOPE_DIM, device="cuda").to(
        torch.float8_e4m3fn
    )
    scales = (torch.rand(total_tokens, NUM_TILES, device="cuda") * 2.0 + 0.1).float()
    rope = torch.randn(total_tokens, ROPE_DIM, dtype=torch.bfloat16, device="cuda")

    # Compute expected output vectorized (same dequant logic as kernel)
    expected = torch.zeros(
        total_tokens, NOPE_DIM + ROPE_DIM, dtype=torch.bfloat16, device="cuda"
    )
    for tile in range(NUM_TILES):
        start = tile * GROUP_SIZE
        expected[:, start : start + GROUP_SIZE] = (
            nope_fp8[:, start : start + GROUP_SIZE].float() * scales[:, tile : tile + 1]
        ).bfloat16()
    expected[:, NOPE_DIM:] = rope

    # Build per-token cache entries as [total_tokens, 656] uint8
    token_data = torch.zeros(
        total_tokens, ENTRY_BYTES, dtype=torch.uint8, device="cuda"
    )
    token_data[:, :NOPE_DIM] = nope_fp8.view(torch.uint8)
    token_data[:, NOPE_DIM : NOPE_DIM + 16] = scales.view(torch.uint8)
    token_data[:, NOPE_DIM + 16 :] = rope.view(torch.uint8)

    # Scatter into paged cache (loop over requests, not tokens).
    # flat_cache is a reshape-view of the request's block range, so the
    # assignment writes through into `cache`.
    block_start = 0
    for r in range(num_reqs):
        sl = seq_lens[r]
        nb = blocks_per_req[r]
        ws = workspace_starts[r]
        flat_cache = cache[block_start : block_start + nb].reshape(-1, ENTRY_BYTES)
        flat_cache[:sl] = token_data[ws : ws + sl]
        block_start += nb

    seq_lens_t = torch.tensor(seq_lens, dtype=torch.int32, device="cuda")
    workspace_starts_t = torch.tensor(
        workspace_starts, dtype=torch.int32, device="cuda"
    )

    return (
        cache,
        block_table,
        seq_lens_t,
        workspace_starts_t,
        num_reqs,
        total_tokens,
        expected,
    )


@pytest.mark.parametrize(
    "seq_lens,block_size",
    [
        # Production block_size=64 (only supported value for FlashMLA sparse).
        # Realistic prefill scenarios with varying request counts.
        ([1], 64),  # single token edge case
        ([64], 64),  # 1 req, exactly one block
        ([128], 64),  # 1 req, crosses block boundary
        ([512], 64),  # 1 req, longer prefill
        ([256, 128, 384], 64),  # 3 reqs, varying lengths
        ([128] * 4, 64),  # 4 reqs, equal lengths
        ([64] * 16, 64),  # 16 reqs, shorter prefills
    ],
)
def test_cp_gather_and_upconvert_fp8_kv_cache(seq_lens, block_size):
    """Core correctness test: build cache, run kernel, compare output."""
    (
        cache,
        block_table,
        seq_lens_t,
        workspace_starts_t,
        num_reqs,
        total_tokens,
        expected,
    ) = _build_test_case(seq_lens, block_size)

    dst = torch.zeros(
        total_tokens, NOPE_DIM + ROPE_DIM, dtype=torch.bfloat16, device="cuda"
    )

    ops.cp_gather_and_upconvert_fp8_kv_cache(
        cache, dst, block_table, seq_lens_t, workspace_starts_t, num_reqs
    )

    # NoPE: fp8 dequant has rounding error, so we allow small tolerance.
    # The fp8 -> float -> bf16 path can differ by up to ~1 ULP of bf16.
    torch.testing.assert_close(
        dst[:, :NOPE_DIM], expected[:, :NOPE_DIM], atol=1e-3, rtol=1e-2
    )

    # RoPE: pure bf16 copy, must be bit-exact
    assert torch.equal(dst[:, NOPE_DIM:], expected[:, NOPE_DIM:])


def test_cp_gather_fp8_shuffled_blocks():
    """Test that the kernel correctly follows the block table when
    physical blocks are non-contiguous and out of order.

    Here we allocate 4 physical blocks but map the request's 2 logical
    blocks to physical blocks [3, 1] (reversed, with gaps).
    """
    torch.manual_seed(123)
    block_size = 4
    seq_lens = [8]  # needs 2 blocks (tokens 0-3 in block 0, 4-7 in block 1)
    total_tokens = 8

    # 4 physical blocks, but only blocks 3 and 1 are used (in that order).
    # Tokens 0-3 -> physical block 3, tokens 4-7 -> physical block 1.
    num_phys_blocks = 4
    cache = torch.zeros(
        num_phys_blocks, block_size, ENTRY_BYTES, dtype=torch.uint8, device="cuda"
    )
    block_table = torch.tensor([[3, 1]], dtype=torch.int32, device="cuda")
    workspace_starts = torch.tensor([0], dtype=torch.int32, device="cuda")
    seq_lens_t = torch.tensor(seq_lens, dtype=torch.int32, device="cuda")

    expected = torch.zeros(
        total_tokens, NOPE_DIM + ROPE_DIM, dtype=torch.bfloat16, device="cuda"
    )

    # Fill cache at the shuffled physical locations
    for t in range(total_tokens):
        # Follow the same block_table lookup the kernel will use
        phys = block_table[0, t // block_size].item()
        off = t % block_size

        for tile in range(NUM_TILES):
            start = tile * GROUP_SIZE
            fp8_vals = torch.randn(GROUP_SIZE, device="cuda").to(torch.float8_e4m3fn)
            cache[phys, off, start : start + GROUP_SIZE] = fp8_vals.view(torch.uint8)

            # Use a fixed scale to keep this test simple
            scale = 1.5
            scale_t = torch.tensor([scale], dtype=torch.float32, device="cuda")
            cache[phys, off, NOPE_DIM + tile * 4 : NOPE_DIM + (tile + 1) * 4] = (
                scale_t.view(torch.uint8)
            )

            expected[t, start : start + GROUP_SIZE] = (
                fp8_vals.float() * scale
            ).bfloat16()

        rope = torch.randn(ROPE_DIM, dtype=torch.bfloat16, device="cuda")
        cache[phys, off, NOPE_DIM + 16 :] = rope.view(torch.uint8)
        expected[t, NOPE_DIM:] = rope

    dst = torch.zeros(
        total_tokens, NOPE_DIM + ROPE_DIM, dtype=torch.bfloat16, device="cuda"
    )

    ops.cp_gather_and_upconvert_fp8_kv_cache(
        cache, dst, block_table, seq_lens_t, workspace_starts, len(seq_lens)
    )

    torch.testing.assert_close(
        dst[:, :NOPE_DIM], expected[:, :NOPE_DIM], atol=1e-3, rtol=1e-2
    )
    assert torch.equal(dst[:, NOPE_DIM:], expected[:, NOPE_DIM:])


@pytest.mark.parametrize(
    "seq_lens,block_size",
    [
        # Large sequence lengths matching end-to-end benchmark scenarios.
        # Uses vectorized builder since per-token Python loops would be too slow.
        ([8000], 64),
        ([16000], 64),
        ([32000], 64),
        ([64000], 64),
        ([96000], 64),
        ([128000], 64),
    ],
)
def test_cp_gather_fp8_large_seqlens(seq_lens, block_size):
    """Correctness test with large sequence lengths matching benchmark
    scenarios (8K-128K prefill)."""
    (
        cache,
        block_table,
        seq_lens_t,
        workspace_starts_t,
        num_reqs,
        total_tokens,
        expected,
    ) = _build_test_case_fast(seq_lens, block_size)

    dst = torch.zeros(
        total_tokens, NOPE_DIM + ROPE_DIM, dtype=torch.bfloat16, device="cuda"
    )

    ops.cp_gather_and_upconvert_fp8_kv_cache(
        cache, dst, block_table, seq_lens_t, workspace_starts_t, num_reqs
    )

    torch.testing.assert_close(
        dst[:, :NOPE_DIM], expected[:, :NOPE_DIM], atol=1e-3, rtol=1e-2
    )
    assert torch.equal(dst[:, NOPE_DIM:], expected[:, NOPE_DIM:])