diff --git a/tests/kernels/attention/test_trtllm_kvfp8_dequant.py b/tests/kernels/attention/test_trtllm_kvfp8_dequant.py new file mode 100644 index 000000000..a2ea372c0 --- /dev/null +++ b/tests/kernels/attention/test_trtllm_kvfp8_dequant.py @@ -0,0 +1,434 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Standalone unit tests for trtllm_prefill_attn_kvfp8_dequant. + +Tests both contiguous and non-contiguous (cross-layer unified) KV cache +layouts against a pure-PyTorch reference implementation. +""" + +import pytest +import torch + +from vllm.platforms import current_platform + +FP8_DTYPE = current_platform.fp8_dtype() + +NUM_BLOCKS = 128 + + +def to_float8(x, dtype=None): + if dtype is None: + dtype = FP8_DTYPE + finfo = torch.finfo(dtype) + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + scale = finfo.max / amax * 0.1 + x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) + return x_scl_sat.to(dtype), scale.float().reciprocal() + + +def make_contiguous_kv_cache(num_blocks, num_kv_heads, block_size, head_size): + """Create a standard contiguous fp8 KV cache (HND layout).""" + raw = torch.randn( + num_blocks, + 2, + num_kv_heads, + block_size, + head_size, + dtype=torch.bfloat16, + device="cuda", + ) + kv_cache, scale = to_float8(raw) + return kv_cache, scale + + +def make_cross_layer_kv_cache( + num_blocks, + num_kv_heads, + block_size, + head_size, + num_layers=4, +): + """ + Create a non-contiguous per-layer view mimicking cross-layer allocation. + + Physical layout: (num_blocks, 2, num_kv_heads, num_layers, block_size, head_size) + Returned view: (num_blocks, 2, num_kv_heads, block_size, head_size) + with non-contiguous strides on dims 0, 1, 2 (they skip over num_layers). + """ + raw = torch.randn( + num_blocks, + 2, + num_kv_heads, + num_layers, + block_size, + head_size, + dtype=torch.bfloat16, + device="cuda", + ) + fp8_full, scale = to_float8(raw) + layer_view = fp8_full[:, :, :, 0, :, :] + assert not layer_view.is_contiguous(), ( + f"Expected non-contiguous view, got strides {layer_view.stride()}" + ) + return layer_view, scale + + +def ref_dequant(kv_cache, block_tables, k_scale, v_scale, dequant_dtype): + """Pure PyTorch reference: gather pages and dequantize fp8 -> dequant_dtype.""" + batch_size, num_pages_per_seq = block_tables.shape + s = kv_cache.shape + out = torch.zeros( + batch_size * num_pages_per_seq + 1, + s[1], + s[2], + s[3], + s[4], + dtype=dequant_dtype, + device=kv_cache.device, + ) + for b in range(batch_size): + for p in range(num_pages_per_seq): + page_idx = block_tables[b, p].item() + if page_idx <= 0: + continue + mock_idx = b * num_pages_per_seq + p + 1 + out[mock_idx, 0] = (kv_cache[page_idx, 0].float() * k_scale.item()).to( + dequant_dtype + ) + out[mock_idx, 1] = (kv_cache[page_idx, 1].float() * v_scale.item()).to( + dequant_dtype + ) + return out + + +@pytest.mark.parametrize("num_kv_heads", [1, 8]) +@pytest.mark.parametrize("head_size", [64, 128]) +@pytest.mark.parametrize("block_size", [16, 32]) +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("num_pages_per_seq", [3, 8]) +@pytest.mark.parametrize("contiguous", [True, False]) +@torch.inference_mode() +def test_trtllm_kvfp8_dequant( + num_kv_heads: int, + head_size: int, + block_size: int, + batch_size: int, + num_pages_per_seq: int, + contiguous: bool, +): + from vllm.v1.attention.backends.flashinfer import ( + trtllm_prefill_attn_kvfp8_dequant, + ) + + torch.set_default_device("cuda") + + if contiguous: + kv_cache, scale = make_contiguous_kv_cache( + NUM_BLOCKS, + num_kv_heads, + block_size, + head_size, + ) + else: + kv_cache, scale = make_cross_layer_kv_cache( + NUM_BLOCKS, + num_kv_heads, + block_size, + head_size, + ) + + k_scale = scale.clone() + v_scale = scale.clone() + + block_tables = torch.randint( + 1, + NUM_BLOCKS, + (batch_size, num_pages_per_seq), + dtype=torch.int32, + ) + + mock_kv_cache, mock_block_table = trtllm_prefill_attn_kvfp8_dequant( + kv_cache, + block_tables, + k_scale, + v_scale, + torch.bfloat16, + ) + + ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16) + + expected_bt = torch.arange( + 1, + batch_size * num_pages_per_seq + 1, + dtype=torch.int32, + device="cuda", + ).reshape(batch_size, num_pages_per_seq) + torch.testing.assert_close(mock_block_table, expected_bt) + + # Page 0 is padding (never written), compare only pages 1+ + torch.testing.assert_close(mock_kv_cache[1:], ref[1:], atol=1e-3, rtol=1e-3) + + +@torch.inference_mode() +def test_block_tables_with_zero_pages(): + """Pages with index <= 0 must be skipped (early return in kernel).""" + from vllm.v1.attention.backends.flashinfer import ( + trtllm_prefill_attn_kvfp8_dequant, + ) + + torch.set_default_device("cuda") + num_kv_heads, block_size, head_size = 8, 16, 64 + + kv_cache, scale = make_contiguous_kv_cache( + NUM_BLOCKS, + num_kv_heads, + block_size, + head_size, + ) + k_scale = v_scale = scale.clone() + + # Mix of valid pages and zeros (padding) + block_tables = torch.tensor( + [[5, 0, 10], [0, 0, 0], [3, 7, 0]], + dtype=torch.int32, + device="cuda", + ) + + mock_kv_cache, _ = trtllm_prefill_attn_kvfp8_dequant( + kv_cache, + block_tables, + k_scale, + v_scale, + torch.bfloat16, + ) + ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16) + + # Only compare pages that were actually written (non-zero page indices) + for b in range(block_tables.shape[0]): + for p in range(block_tables.shape[1]): + if block_tables[b, p].item() > 0: + idx = b * block_tables.shape[1] + p + 1 + torch.testing.assert_close( + mock_kv_cache[idx], + ref[idx], + atol=1e-3, + rtol=1e-3, + ) + + +@torch.inference_mode() +def test_all_zero_block_tables(): + """All-zero block_tables: kernel should write nothing.""" + from vllm.v1.attention.backends.flashinfer import ( + trtllm_prefill_attn_kvfp8_dequant, + ) + + torch.set_default_device("cuda") + num_kv_heads, block_size, head_size = 4, 16, 64 + + kv_cache, scale = make_contiguous_kv_cache( + NUM_BLOCKS, + num_kv_heads, + block_size, + head_size, + ) + k_scale = v_scale = scale.clone() + + block_tables = torch.zeros(2, 4, dtype=torch.int32, device="cuda") + + # Should not crash even though no pages are valid + mock_kv_cache, mock_block_table = trtllm_prefill_attn_kvfp8_dequant( + kv_cache, + block_tables, + k_scale, + v_scale, + torch.bfloat16, + ) + assert mock_kv_cache.shape[0] == 2 * 4 + 1 + assert mock_block_table.shape == (2, 4) + + +@torch.inference_mode() +def test_different_k_v_scales(): + """Verify K and V are dequantized with independent scales.""" + from vllm.v1.attention.backends.flashinfer import ( + trtllm_prefill_attn_kvfp8_dequant, + ) + + torch.set_default_device("cuda") + num_kv_heads, block_size, head_size = 8, 16, 64 + + kv_cache, _ = make_contiguous_kv_cache( + NUM_BLOCKS, + num_kv_heads, + block_size, + head_size, + ) + k_scale = torch.tensor([0.5], dtype=torch.float32, device="cuda") + v_scale = torch.tensor([2.0], dtype=torch.float32, device="cuda") + + block_tables = torch.tensor([[1, 2]], dtype=torch.int32, device="cuda") + + mock_kv_cache, _ = trtllm_prefill_attn_kvfp8_dequant( + kv_cache, + block_tables, + k_scale, + v_scale, + torch.bfloat16, + ) + ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16) + + torch.testing.assert_close(mock_kv_cache[1:], ref[1:], atol=1e-3, rtol=1e-3) + + +@torch.inference_mode() +def test_single_page_per_seq(): + """Minimum grid dim 1 = 1 page per sequence.""" + from vllm.v1.attention.backends.flashinfer import ( + trtllm_prefill_attn_kvfp8_dequant, + ) + + torch.set_default_device("cuda") + num_kv_heads, block_size, head_size = 8, 16, 128 + + kv_cache, scale = make_contiguous_kv_cache( + NUM_BLOCKS, + num_kv_heads, + block_size, + head_size, + ) + k_scale = v_scale = scale.clone() + + block_tables = torch.tensor([[5], [10], [20]], dtype=torch.int32, device="cuda") + + mock_kv_cache, _ = trtllm_prefill_attn_kvfp8_dequant( + kv_cache, + block_tables, + k_scale, + v_scale, + torch.bfloat16, + ) + ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16) + + torch.testing.assert_close(mock_kv_cache[1:], ref[1:], atol=1e-3, rtol=1e-3) + + +@torch.inference_mode() +def test_large_page_indices(): + """Page indices near the top of the buffer stress offset arithmetic.""" + from vllm.v1.attention.backends.flashinfer import ( + trtllm_prefill_attn_kvfp8_dequant, + ) + + torch.set_default_device("cuda") + num_kv_heads, block_size, head_size = 8, 16, 128 + large_num_blocks = 32768 + + kv_cache, scale = make_contiguous_kv_cache( + large_num_blocks, + num_kv_heads, + block_size, + head_size, + ) + k_scale = v_scale = scale.clone() + + # Use page indices near the top of the buffer + block_tables = torch.tensor( + [[large_num_blocks - 1, large_num_blocks - 2, 1]], + dtype=torch.int32, + device="cuda", + ) + + mock_kv_cache, _ = trtllm_prefill_attn_kvfp8_dequant( + kv_cache, + block_tables, + k_scale, + v_scale, + torch.bfloat16, + ) + ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16) + + torch.testing.assert_close(mock_kv_cache[1:], ref[1:], atol=1e-3, rtol=1e-3) + + +@torch.inference_mode() +def test_large_block_size(): + """block_size=64 -> HEAD_STRIDE=8192, large tl.arange per thread block.""" + from vllm.v1.attention.backends.flashinfer import ( + trtllm_prefill_attn_kvfp8_dequant, + ) + + torch.set_default_device("cuda") + num_kv_heads, block_size, head_size = 4, 64, 128 + + kv_cache, scale = make_contiguous_kv_cache( + NUM_BLOCKS, + num_kv_heads, + block_size, + head_size, + ) + k_scale = v_scale = scale.clone() + + block_tables = torch.randint( + 1, + NUM_BLOCKS, + (2, 4), + dtype=torch.int32, + device="cuda", + ) + + mock_kv_cache, _ = trtllm_prefill_attn_kvfp8_dequant( + kv_cache, + block_tables, + k_scale, + v_scale, + torch.bfloat16, + ) + ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16) + + torch.testing.assert_close(mock_kv_cache[1:], ref[1:], atol=1e-3, rtol=1e-3) + + +@torch.inference_mode() +def test_cross_layer_many_layers(): + """ + Non-contiguous with 36 layers -- matches real gpt-oss-120b. + Strides are far from contiguous (factor of 36 in the gaps). + """ + from vllm.v1.attention.backends.flashinfer import ( + trtllm_prefill_attn_kvfp8_dequant, + ) + + torch.set_default_device("cuda") + num_kv_heads, block_size, head_size = 8, 16, 64 + num_layers = 36 + + kv_cache, scale = make_cross_layer_kv_cache( + NUM_BLOCKS, + num_kv_heads, + block_size, + head_size, + num_layers=num_layers, + ) + k_scale = v_scale = scale.clone() + + block_tables = torch.randint( + 1, + NUM_BLOCKS, + (4, 6), + dtype=torch.int32, + device="cuda", + ) + + mock_kv_cache, _ = trtllm_prefill_attn_kvfp8_dequant( + kv_cache, + block_tables, + k_scale, + v_scale, + torch.bfloat16, + ) + ref = ref_dequant(kv_cache, block_tables, k_scale, v_scale, torch.bfloat16) + + torch.testing.assert_close(mock_kv_cache[1:], ref[1:], atol=1e-3, rtol=1e-3) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 595f4ffa5..411ec746c 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -96,8 +96,13 @@ def _trtllm_prefill_attn_kvfp8_dequant( mock_kv_cache_ptr, k_scale_ptr, v_scale_ptr, - K_CACHE_STRIDE: tl.constexpr, - KV_CACHE_STRIDE: tl.constexpr, + src_stride_page, + src_stride_kv, + src_stride_head, + DST_K_CACHE_STRIDE: tl.constexpr, + DST_KV_CACHE_STRIDE: tl.constexpr, + HEAD_STRIDE: tl.constexpr, + NUM_KV_HEADS: tl.constexpr, ): batch_idx = tl.program_id(0).to(tl.int64) mock_block_table_idx = tl.program_id(1).to(tl.int64) @@ -108,31 +113,42 @@ def _trtllm_prefill_attn_kvfp8_dequant( return dequant_dtype = mock_kv_cache_ptr.dtype.element_ty - # Dequantize K k_scale_val = tl.load(k_scale_ptr) - offset = orig_page_num * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE) - fp8_vals = tl.load(kv_cache_ptr + offset) - dequantized_vals = fp8_vals.to(tl.float32) * k_scale_val - mock_cache_offset = ( - batch_idx * block_table_stride + mock_block_table_idx + 1 - ) * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE) - dequantized_vals = dequantized_vals.to(dequant_dtype) - tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals) - - # Dequantize V v_scale_val = tl.load(v_scale_ptr) - offset = ( - orig_page_num * KV_CACHE_STRIDE + K_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE) - ) - fp8_vals = tl.load(kv_cache_ptr + offset) - dequantized_vals = fp8_vals.to(tl.float32) * v_scale_val - mock_cache_offset = ( - (batch_idx * block_table_stride + mock_block_table_idx + 1) * KV_CACHE_STRIDE - + K_CACHE_STRIDE - + tl.arange(0, K_CACHE_STRIDE) - ) - dequantized_vals = dequantized_vals.to(dequant_dtype) - tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals) + + mock_page_idx = batch_idx * block_table_stride + mock_block_table_idx + 1 + head_offsets = tl.arange(0, HEAD_STRIDE) + + for h in range(NUM_KV_HEADS): + h_off = tl.cast(h, tl.int64) + + # Read K from source (supports non-contiguous page/kv/head strides) + src_k = orig_page_num * src_stride_page + h_off * src_stride_head + head_offsets + fp8_k = tl.load(kv_cache_ptr + src_k) + dequant_k = (fp8_k.to(tl.float32) * k_scale_val).to(dequant_dtype) + + # Write K to contiguous mock cache + dst_k = mock_page_idx * DST_KV_CACHE_STRIDE + h * HEAD_STRIDE + head_offsets + tl.store(mock_kv_cache_ptr + dst_k, dequant_k) + + # Read V from source (offset by src_stride_kv for the V half) + src_v = ( + orig_page_num * src_stride_page + + src_stride_kv + + h_off * src_stride_head + + head_offsets + ) + fp8_v = tl.load(kv_cache_ptr + src_v) + dequant_v = (fp8_v.to(tl.float32) * v_scale_val).to(dequant_dtype) + + # Write V to contiguous mock cache + dst_v = ( + mock_page_idx * DST_KV_CACHE_STRIDE + + DST_K_CACHE_STRIDE + + h * HEAD_STRIDE + + head_offsets + ) + tl.store(mock_kv_cache_ptr + dst_v, dequant_v) def trtllm_prefill_attn_kvfp8_dequant( @@ -146,8 +162,18 @@ def trtllm_prefill_attn_kvfp8_dequant( s = kv_cache.shape assert s[1] == 2 assert dequant_dtype in (torch.bfloat16, torch.float16) - k_cache_stride = s[2] * s[3] * s[4] + + num_kv_heads, block_size, head_size = s[2], s[3], s[4] + head_stride = block_size * head_size + k_cache_stride = num_kv_heads * head_stride kv_cache_stride = k_cache_stride * s[1] + + strides = kv_cache.stride() + assert strides[3] == head_size and strides[4] == 1, ( + "For kv cache layouts, (block_size, head_size) " + f"dimensions must be contiguous, got strides {strides}" + ) + new_s = (batch_size * num_of_page_per_token + 1, s[1], s[2], s[3], s[4]) # mock kv cache contains just the pages needed by this prefill mock_kv_cache = torch.empty(new_s, dtype=dequant_dtype, device=kv_cache.device) @@ -166,8 +192,13 @@ def trtllm_prefill_attn_kvfp8_dequant( mock_kv_cache, k_scale, v_scale, + strides[0], + strides[1], + strides[2], k_cache_stride, kv_cache_stride, + head_stride, + num_kv_heads, ) return mock_kv_cache, mock_block_table