[ROCM][Bugfix] Use correct stride in cp_mha_gather_cache_kernel for hybrid model (#37228)

Signed-off-by: jennyyyyzhen <yzhen@hmc.edu>
Co-authored-by: yZhen <yZhen@fb.com>
This commit is contained in:
jennyyyyzhen
2026-03-26 10:33:39 -07:00
committed by GitHub
parent 9c3ae04bfe
commit a4cf9b22ba

View File

@@ -59,6 +59,12 @@ if current_platform.is_rocm():
head_size, head_size,
x, x,
max_block_num, max_block_num,
k_cache_stride0,
k_cache_stride1,
k_cache_stride2,
v_cache_stride0,
v_cache_stride1,
v_cache_stride2,
DEQUANT: tl.constexpr, DEQUANT: tl.constexpr,
PAGE_SIZE: tl.constexpr, PAGE_SIZE: tl.constexpr,
CACHE_FORMAT: tl.constexpr, CACHE_FORMAT: tl.constexpr,
@@ -90,15 +96,15 @@ if current_platform.is_rocm():
# V: [num_blocks, page_size, num_head, head_dim] # V: [num_blocks, page_size, num_head, head_dim]
key_cache_ptr_offset = ( key_cache_ptr_offset = (
key_cache_ptr key_cache_ptr
+ block_id * num_heads * head_size * PAGE_SIZE + block_id * k_cache_stride0
+ slot_id * num_heads * head_size + slot_id * k_cache_stride1
+ head_id * head_size + head_id * k_cache_stride2
) )
value_cache_ptr_offset = ( value_cache_ptr_offset = (
value_cache_ptr value_cache_ptr
+ block_id * num_heads * head_size * PAGE_SIZE + block_id * v_cache_stride0
+ slot_id * num_heads * head_size + slot_id * v_cache_stride1
+ head_id * head_size + head_id * v_cache_stride2
) )
k_reg = tl.load(key_cache_ptr_offset + col_offsets) k_reg = tl.load(key_cache_ptr_offset + col_offsets)
v_reg = tl.load(value_cache_ptr_offset + col_offsets) v_reg = tl.load(value_cache_ptr_offset + col_offsets)
@@ -171,6 +177,11 @@ if current_platform.is_rocm():
page_size = key_cache.shape[1] page_size = key_cache.shape[1]
num_heads = key_cache.shape[2] num_heads = key_cache.shape[2]
# Pass actual tensor strides so the kernel works correctly
# even when the cache is non-contiguous (e.g., for hybrid models)
k_strides = key_cache.stride()
v_strides = value_cache.stride()
grid = lambda meta: (total_tokens, num_heads) grid = lambda meta: (total_tokens, num_heads)
cp_mha_gather_cache_kernel[grid]( cp_mha_gather_cache_kernel[grid](
key_cache, key_cache,
@@ -187,6 +198,12 @@ if current_platform.is_rocm():
head_dim, head_dim,
x, x,
block_tables.size(1), block_tables.size(1),
k_strides[0],
k_strides[1],
k_strides[2],
v_strides[0],
v_strides[1],
v_strides[2],
DEQUANT=dequant, DEQUANT=dequant,
PAGE_SIZE=page_size, PAGE_SIZE=page_size,
CACHE_FORMAT=kv_cache_layout, CACHE_FORMAT=kv_cache_layout,
@@ -1337,6 +1354,8 @@ class AiterFlashAttentionImpl(AttentionImpl):
assert k_scale is not None and v_scale is not None, ( assert k_scale is not None and v_scale is not None, (
"k_scale and v_scale are required for shuffled update" "k_scale and v_scale are required for shuffled update"
) )
# TODO: Add correct KV cache handling for hybrid model. KV cache
# may not be contiguous if mamba state exists.
reshape_and_cache_shuffle_triton( reshape_and_cache_shuffle_triton(
key, key,
value, value,