Support non-contiguous KV cache in TRTLLM fp8 dequant kernel (#36867)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com> Signed-off-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com> Co-authored-by: Pavani Majety <pavanimajety@gmail.com>
This commit is contained in:
@@ -96,8 +96,13 @@ def _trtllm_prefill_attn_kvfp8_dequant(
|
||||
mock_kv_cache_ptr,
|
||||
k_scale_ptr,
|
||||
v_scale_ptr,
|
||||
K_CACHE_STRIDE: tl.constexpr,
|
||||
KV_CACHE_STRIDE: tl.constexpr,
|
||||
src_stride_page,
|
||||
src_stride_kv,
|
||||
src_stride_head,
|
||||
DST_K_CACHE_STRIDE: tl.constexpr,
|
||||
DST_KV_CACHE_STRIDE: tl.constexpr,
|
||||
HEAD_STRIDE: tl.constexpr,
|
||||
NUM_KV_HEADS: tl.constexpr,
|
||||
):
|
||||
batch_idx = tl.program_id(0).to(tl.int64)
|
||||
mock_block_table_idx = tl.program_id(1).to(tl.int64)
|
||||
@@ -108,31 +113,42 @@ def _trtllm_prefill_attn_kvfp8_dequant(
|
||||
return
|
||||
dequant_dtype = mock_kv_cache_ptr.dtype.element_ty
|
||||
|
||||
# Dequantize K
|
||||
k_scale_val = tl.load(k_scale_ptr)
|
||||
offset = orig_page_num * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
|
||||
fp8_vals = tl.load(kv_cache_ptr + offset)
|
||||
dequantized_vals = fp8_vals.to(tl.float32) * k_scale_val
|
||||
mock_cache_offset = (
|
||||
batch_idx * block_table_stride + mock_block_table_idx + 1
|
||||
) * KV_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
|
||||
dequantized_vals = dequantized_vals.to(dequant_dtype)
|
||||
tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals)
|
||||
|
||||
# Dequantize V
|
||||
v_scale_val = tl.load(v_scale_ptr)
|
||||
offset = (
|
||||
orig_page_num * KV_CACHE_STRIDE + K_CACHE_STRIDE + tl.arange(0, K_CACHE_STRIDE)
|
||||
)
|
||||
fp8_vals = tl.load(kv_cache_ptr + offset)
|
||||
dequantized_vals = fp8_vals.to(tl.float32) * v_scale_val
|
||||
mock_cache_offset = (
|
||||
(batch_idx * block_table_stride + mock_block_table_idx + 1) * KV_CACHE_STRIDE
|
||||
+ K_CACHE_STRIDE
|
||||
+ tl.arange(0, K_CACHE_STRIDE)
|
||||
)
|
||||
dequantized_vals = dequantized_vals.to(dequant_dtype)
|
||||
tl.store(mock_kv_cache_ptr + mock_cache_offset, dequantized_vals)
|
||||
|
||||
mock_page_idx = batch_idx * block_table_stride + mock_block_table_idx + 1
|
||||
head_offsets = tl.arange(0, HEAD_STRIDE)
|
||||
|
||||
for h in range(NUM_KV_HEADS):
|
||||
h_off = tl.cast(h, tl.int64)
|
||||
|
||||
# Read K from source (supports non-contiguous page/kv/head strides)
|
||||
src_k = orig_page_num * src_stride_page + h_off * src_stride_head + head_offsets
|
||||
fp8_k = tl.load(kv_cache_ptr + src_k)
|
||||
dequant_k = (fp8_k.to(tl.float32) * k_scale_val).to(dequant_dtype)
|
||||
|
||||
# Write K to contiguous mock cache
|
||||
dst_k = mock_page_idx * DST_KV_CACHE_STRIDE + h * HEAD_STRIDE + head_offsets
|
||||
tl.store(mock_kv_cache_ptr + dst_k, dequant_k)
|
||||
|
||||
# Read V from source (offset by src_stride_kv for the V half)
|
||||
src_v = (
|
||||
orig_page_num * src_stride_page
|
||||
+ src_stride_kv
|
||||
+ h_off * src_stride_head
|
||||
+ head_offsets
|
||||
)
|
||||
fp8_v = tl.load(kv_cache_ptr + src_v)
|
||||
dequant_v = (fp8_v.to(tl.float32) * v_scale_val).to(dequant_dtype)
|
||||
|
||||
# Write V to contiguous mock cache
|
||||
dst_v = (
|
||||
mock_page_idx * DST_KV_CACHE_STRIDE
|
||||
+ DST_K_CACHE_STRIDE
|
||||
+ h * HEAD_STRIDE
|
||||
+ head_offsets
|
||||
)
|
||||
tl.store(mock_kv_cache_ptr + dst_v, dequant_v)
|
||||
|
||||
|
||||
def trtllm_prefill_attn_kvfp8_dequant(
|
||||
@@ -146,8 +162,18 @@ def trtllm_prefill_attn_kvfp8_dequant(
|
||||
s = kv_cache.shape
|
||||
assert s[1] == 2
|
||||
assert dequant_dtype in (torch.bfloat16, torch.float16)
|
||||
k_cache_stride = s[2] * s[3] * s[4]
|
||||
|
||||
num_kv_heads, block_size, head_size = s[2], s[3], s[4]
|
||||
head_stride = block_size * head_size
|
||||
k_cache_stride = num_kv_heads * head_stride
|
||||
kv_cache_stride = k_cache_stride * s[1]
|
||||
|
||||
strides = kv_cache.stride()
|
||||
assert strides[3] == head_size and strides[4] == 1, (
|
||||
"For kv cache layouts, (block_size, head_size) "
|
||||
f"dimensions must be contiguous, got strides {strides}"
|
||||
)
|
||||
|
||||
new_s = (batch_size * num_of_page_per_token + 1, s[1], s[2], s[3], s[4])
|
||||
# mock kv cache contains just the pages needed by this prefill
|
||||
mock_kv_cache = torch.empty(new_s, dtype=dequant_dtype, device=kv_cache.device)
|
||||
@@ -166,8 +192,13 @@ def trtllm_prefill_attn_kvfp8_dequant(
|
||||
mock_kv_cache,
|
||||
k_scale,
|
||||
v_scale,
|
||||
strides[0],
|
||||
strides[1],
|
||||
strides[2],
|
||||
k_cache_stride,
|
||||
kv_cache_stride,
|
||||
head_stride,
|
||||
num_kv_heads,
|
||||
)
|
||||
return mock_kv_cache, mock_block_table
|
||||
|
||||
|
||||
Reference in New Issue
Block a user