Support non-contiguous KV cache in TRTLLM fp8 dequant kernel (#36867)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Signed-off-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com>
Co-authored-by: Pavani Majety <pavanimajety@gmail.com>
This commit is contained in:
Vadim Gimpelson
2026-03-17 04:48:42 +04:00
committed by GitHub
parent 45f526d652
commit 6c1cfbad32
2 changed files with 491 additions and 26 deletions

View File

@@ -96,8 +96,13 @@ def _trtllm_prefill_attn_kvfp8_dequant(
mock_kv_cache_ptr,
k_scale_ptr,
v_scale_ptr,
K_CACHE_STRIDE: tl.constexpr,
KV_CACHE_STRIDE: tl.constexpr,
src_stride_page,
src_stride_kv,
src_stride_head,
DST_K_CACHE_STRIDE: tl.constexpr,
DST_KV_CACHE_STRIDE: tl.constexpr,
HEAD_STRIDE: tl.constexpr,
NUM_KV_HEADS: tl.constexpr,
):
batch_idx = tl.program_id(0).to(tl.int64)
mock_block_table_idx = tl.program_id(1).to(tl.int64)
@@ -108,31 +113,42 @@ def _trtllm_prefill_attn_kvfp8_dequant(
return
dequant_dtype = mock_kv_cache_ptr.dtype.element_ty
# Dequantize K
k_scale_val = tl.load(k_scale_ptr)
offset = orig_page_num * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
fp8_vals = tl.load(kv_cache_ptr + offset)
dequantized_vals = fp8_vals.to(tl.float32) * k_scale_val
mock_cache_offset = (
batch_idx * block_table_stride + mock_block_table_idx + 1
) * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
dequantized_vals = dequantized_vals.to(dequant_dtype)
tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals)
# Dequantize V
v_scale_val = tl.load(v_scale_ptr)
offset = (
orig_page_num * KV_CACHE_STRIDE + K_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
)
fp8_vals = tl.load(kv_cache_ptr + offset)
dequantized_vals = fp8_vals.to(tl.float32) * v_scale_val
mock_cache_offset = (
(batch_idx * block_table_stride + mock_block_table_idx + 1) * KV_CACHE_STRIDE
+ K_CACHE_STRIDE
+ tl.arange(0, K_CACHE_STRIDE)
)
dequantized_vals = dequantized_vals.to(dequant_dtype)
tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals)
mock_page_idx = batch_idx * block_table_stride + mock_block_table_idx + 1
head_offsets = tl.arange(0, HEAD_STRIDE)
for h in range(NUM_KV_HEADS):
h_off = tl.cast(h, tl.int64)
# Read K from source (supports non-contiguous page/kv/head strides)
src_k = orig_page_num * src_stride_page + h_off * src_stride_head + head_offsets
fp8_k = tl.load(kv_cache_ptr + src_k)
dequant_k = (fp8_k.to(tl.float32) * k_scale_val).to(dequant_dtype)
# Write K to contiguous mock cache
dst_k = mock_page_idx * DST_KV_CACHE_STRIDE + h * HEAD_STRIDE + head_offsets
tl.store(mock_kv_cache_ptr + dst_k, dequant_k)
# Read V from source (offset by src_stride_kv for the V half)
src_v = (
orig_page_num * src_stride_page
+ src_stride_kv
+ h_off * src_stride_head
+ head_offsets
)
fp8_v = tl.load(kv_cache_ptr + src_v)
dequant_v = (fp8_v.to(tl.float32) * v_scale_val).to(dequant_dtype)
# Write V to contiguous mock cache
dst_v = (
mock_page_idx * DST_KV_CACHE_STRIDE
+ DST_K_CACHE_STRIDE
+ h * HEAD_STRIDE
+ head_offsets
)
tl.store(mock_kv_cache_ptr + dst_v, dequant_v)
def trtllm_prefill_attn_kvfp8_dequant(
@@ -146,8 +162,18 @@ def trtllm_prefill_attn_kvfp8_dequant(
s = kv_cache.shape
assert s[1] == 2
assert dequant_dtype in (torch.bfloat16, torch.float16)
k_cache_stride = s[2] * s[3] * s[4]
num_kv_heads, block_size, head_size = s[2], s[3], s[4]
head_stride = block_size * head_size
k_cache_stride = num_kv_heads * head_stride
kv_cache_stride = k_cache_stride * s[1]
strides = kv_cache.stride()
assert strides[3] == head_size and strides[4] == 1, (
"For kv cache layouts, (block_size, head_size) "
f"dimensions must be contiguous, got strides {strides}"
)
new_s = (batch_size * num_of_page_per_token + 1, s[1], s[2], s[3], s[4])
# mock kv cache contains just the pages needed by this prefill
mock_kv_cache = torch.empty(new_s, dtype=dequant_dtype, device=kv_cache.device)
@@ -166,8 +192,13 @@ def trtllm_prefill_attn_kvfp8_dequant(
mock_kv_cache,
k_scale,
v_scale,
strides[0],
strides[1],
strides[2],
k_cache_stride,
kv_cache_stride,
head_stride,
num_kv_heads,
)
return mock_kv_cache, mock_block_table