Signed-off-by: jennyyyyzhen <yzhen@hmc.edu>
Co-authored-by: yZhen <yZhen@fb.com>
@@ -59,6 +59,12 @@ if current_platform.is_rocm():
         head_size,
         x,
         max_block_num,
+        k_cache_stride0,
+        k_cache_stride1,
+        k_cache_stride2,
+        v_cache_stride0,
+        v_cache_stride1,
+        v_cache_stride2,
         DEQUANT: tl.constexpr,
         PAGE_SIZE: tl.constexpr,
         CACHE_FORMAT: tl.constexpr,
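Note (reviewer sketch, not part of the diff): the pattern this hunk introduces, passing the cache's element strides into the kernel and doing pointer arithmetic with them, can be exercised in isolation. Below is a minimal, hypothetical Triton kernel (the name `gather_head_kernel` and its shapes are made up; it is not vLLM's `cp_mha_gather_cache_kernel`) that copies one head vector per program through explicit strides, so it stays correct on non-contiguous views:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def gather_head_kernel(
    cache_ptr, out_ptr,
    c_stride0, c_stride1, c_stride2,   # input cache strides (in elements)
    o_stride0, o_stride1, o_stride2,   # output strides (in elements)
    HEAD_DIM: tl.constexpr,
):
    block_id = tl.program_id(0)
    slot_id = tl.program_id(1)
    head_id = tl.program_id(2)
    cols = tl.arange(0, HEAD_DIM)
    # Index via the tensors' real strides instead of assuming contiguity.
    src = cache_ptr + block_id * c_stride0 + slot_id * c_stride1 + head_id * c_stride2
    dst = out_ptr + block_id * o_stride0 + slot_id * o_stride1 + head_id * o_stride2
    tl.store(dst + cols, tl.load(src + cols))


cache = torch.randn(4, 8, 2, 64, device="cuda")
view = cache[:, ::2]                       # non-contiguous view along the page dim
out = torch.empty(view.shape, device="cuda")
grid = (view.shape[0], view.shape[1], view.shape[2])
gather_head_kernel[grid](
    view, out,
    *view.stride()[:3], *out.stride()[:3],
    HEAD_DIM=view.shape[-1],
)
assert torch.equal(out, view)
```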
@@ -90,15 +96,15 @@ if current_platform.is_rocm():
         # V: [num_blocks, page_size, num_head, head_dim]
         key_cache_ptr_offset = (
             key_cache_ptr
-            + block_id * num_heads * head_size * PAGE_SIZE
-            + slot_id * num_heads * head_size
-            + head_id * head_size
+            + block_id * k_cache_stride0
+            + slot_id * k_cache_stride1
+            + head_id * k_cache_stride2
         )
         value_cache_ptr_offset = (
             value_cache_ptr
-            + block_id * num_heads * head_size * PAGE_SIZE
-            + slot_id * num_heads * head_size
-            + head_id * head_size
+            + block_id * v_cache_stride0
+            + slot_id * v_cache_stride1
+            + head_id * v_cache_stride2
         )
         k_reg = tl.load(key_cache_ptr_offset + col_offsets)
         v_reg = tl.load(value_cache_ptr_offset + col_offsets)
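For the contiguous layout documented in the comment above (`[num_blocks, page_size, num_head, head_dim]`), the new stride-based offsets reduce to exactly the old hardcoded factors, so this hunk changes nothing for contiguous caches and only adds correctness for strided ones. A quick sanity check (a sketch, not part of the diff):

```python
import torch

num_blocks, page_size, num_heads, head_dim = 16, 32, 8, 128
v_cache = torch.empty(num_blocks, page_size, num_heads, head_dim)

# Contiguous case: strides match the factors the old code hardcoded.
assert v_cache.stride(0) == page_size * num_heads * head_dim  # block_id factor
assert v_cache.stride(1) == num_heads * head_dim              # slot_id factor
assert v_cache.stride(2) == head_dim                          # head_id factor
```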
@@ -171,6 +177,11 @@ if current_platform.is_rocm():
         page_size = key_cache.shape[1]
         num_heads = key_cache.shape[2]
 
+        # Pass actual tensor strides so the kernel works correctly
+        # even when the cache is non-contiguous (e.g, for hybrid model)
+        k_strides = key_cache.stride()
+        v_strides = value_cache.stride()
+
         grid = lambda meta: (total_tokens, num_heads)
         cp_mha_gather_cache_kernel[grid](
             key_cache,
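The comment's point about non-contiguous caches can be illustrated with an assumed over-allocated buffer (a sketch only; vLLM's actual hybrid-model allocation may differ). Once the cache is a strided view, factors derived from the shape no longer match `.stride()`, which is why the strides are read from the tensor here and forwarded to the kernel:

```python
import torch

num_blocks, page_size, num_heads, head_dim = 16, 32, 8, 128
# Hypothetical over-allocated buffer; the KV cache only occupies the first
# `page_size` slots of each block.
raw = torch.empty(num_blocks, page_size + 4, num_heads, head_dim)
key_cache = raw[:, :page_size]

assert not key_cache.is_contiguous()
# The shape-derived factor the old kernel assumed is now wrong...
assert key_cache.stride(0) != page_size * num_heads * head_dim
# ...while the tensor's real stride reflects the underlying allocation.
assert key_cache.stride(0) == (page_size + 4) * num_heads * head_dim
```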
@@ -187,6 +198,12 @@ if current_platform.is_rocm():
             head_dim,
             x,
             block_tables.size(1),
+            k_strides[0],
+            k_strides[1],
+            k_strides[2],
+            v_strides[0],
+            v_strides[1],
+            v_strides[2],
             DEQUANT=dequant,
             PAGE_SIZE=page_size,
             CACHE_FORMAT=kv_cache_layout,
@@ -1337,6 +1354,8 @@ class AiterFlashAttentionImpl(AttentionImpl):
             assert k_scale is not None and v_scale is not None, (
                 "k_scale and v_scale are required for shuffled update"
             )
+            # TODO: Add correct KV cache handling for hybrid model. KV cache
+            # may not be contiguous if mamba state exists.
             reshape_and_cache_shuffle_triton(
                 key,
                 value,