From a4cf9b22ba8f08df8d7977c2b0753fe20231cffc Mon Sep 17 00:00:00 2001 From: jennyyyyzhen <47012288+jennyyyyzhen@users.noreply.github.com> Date: Thu, 26 Mar 2026 10:33:39 -0700 Subject: [PATCH] [ROCM][Bugfix] Use correct stride in cp_mha_gather_cache_kernel for hybrid model (#37228) Signed-off-by: jennyyyyzhen Co-authored-by: yZhen --- vllm/v1/attention/backends/rocm_aiter_fa.py | 31 +++++++++++++++++---- 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 6834918b8..294414796 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -59,6 +59,12 @@ if current_platform.is_rocm(): head_size, x, max_block_num, + k_cache_stride0, + k_cache_stride1, + k_cache_stride2, + v_cache_stride0, + v_cache_stride1, + v_cache_stride2, DEQUANT: tl.constexpr, PAGE_SIZE: tl.constexpr, CACHE_FORMAT: tl.constexpr, @@ -90,15 +96,15 @@ if current_platform.is_rocm(): # V: [num_blocks, page_size, num_head, head_dim] key_cache_ptr_offset = ( key_cache_ptr - + block_id * num_heads * head_size * PAGE_SIZE - + slot_id * num_heads * head_size - + head_id * head_size + + block_id * k_cache_stride0 + + slot_id * k_cache_stride1 + + head_id * k_cache_stride2 ) value_cache_ptr_offset = ( value_cache_ptr - + block_id * num_heads * head_size * PAGE_SIZE - + slot_id * num_heads * head_size - + head_id * head_size + + block_id * v_cache_stride0 + + slot_id * v_cache_stride1 + + head_id * v_cache_stride2 ) k_reg = tl.load(key_cache_ptr_offset + col_offsets) v_reg = tl.load(value_cache_ptr_offset + col_offsets) @@ -171,6 +177,11 @@ if current_platform.is_rocm(): page_size = key_cache.shape[1] num_heads = key_cache.shape[2] + # Pass actual tensor strides so the kernel works correctly + # even when the cache is non-contiguous (e.g., for hybrid model) + k_strides = key_cache.stride() + v_strides = value_cache.stride() + grid = lambda meta: 
(total_tokens, num_heads) cp_mha_gather_cache_kernel[grid]( key_cache, @@ -187,6 +198,12 @@ if current_platform.is_rocm(): head_dim, x, block_tables.size(1), + k_strides[0], + k_strides[1], + k_strides[2], + v_strides[0], + v_strides[1], + v_strides[2], DEQUANT=dequant, PAGE_SIZE=page_size, CACHE_FORMAT=kv_cache_layout, @@ -1337,6 +1354,8 @@ class AiterFlashAttentionImpl(AttentionImpl): assert k_scale is not None and v_scale is not None, ( "k_scale and v_scale are required for shuffled update" ) + # TODO: Add correct KV cache handling for hybrid model. KV cache + # may not be contiguous if mamba state exists. reshape_and_cache_shuffle_triton( key, value,