Signed-off-by: jennyyyyzhen <yzhen@hmc.edu> Co-authored-by: yZhen <yZhen@fb.com>
This commit is contained in:
@@ -59,6 +59,12 @@ if current_platform.is_rocm():
|
|||||||
head_size,
|
head_size,
|
||||||
x,
|
x,
|
||||||
max_block_num,
|
max_block_num,
|
||||||
|
k_cache_stride0,
|
||||||
|
k_cache_stride1,
|
||||||
|
k_cache_stride2,
|
||||||
|
v_cache_stride0,
|
||||||
|
v_cache_stride1,
|
||||||
|
v_cache_stride2,
|
||||||
DEQUANT: tl.constexpr,
|
DEQUANT: tl.constexpr,
|
||||||
PAGE_SIZE: tl.constexpr,
|
PAGE_SIZE: tl.constexpr,
|
||||||
CACHE_FORMAT: tl.constexpr,
|
CACHE_FORMAT: tl.constexpr,
|
||||||
@@ -90,15 +96,15 @@ if current_platform.is_rocm():
|
|||||||
# V: [num_blocks, page_size, num_head, head_dim]
|
# V: [num_blocks, page_size, num_head, head_dim]
|
||||||
key_cache_ptr_offset = (
|
key_cache_ptr_offset = (
|
||||||
key_cache_ptr
|
key_cache_ptr
|
||||||
+ block_id * num_heads * head_size * PAGE_SIZE
|
+ block_id * k_cache_stride0
|
||||||
+ slot_id * num_heads * head_size
|
+ slot_id * k_cache_stride1
|
||||||
+ head_id * head_size
|
+ head_id * k_cache_stride2
|
||||||
)
|
)
|
||||||
value_cache_ptr_offset = (
|
value_cache_ptr_offset = (
|
||||||
value_cache_ptr
|
value_cache_ptr
|
||||||
+ block_id * num_heads * head_size * PAGE_SIZE
|
+ block_id * v_cache_stride0
|
||||||
+ slot_id * num_heads * head_size
|
+ slot_id * v_cache_stride1
|
||||||
+ head_id * head_size
|
+ head_id * v_cache_stride2
|
||||||
)
|
)
|
||||||
k_reg = tl.load(key_cache_ptr_offset + col_offsets)
|
k_reg = tl.load(key_cache_ptr_offset + col_offsets)
|
||||||
v_reg = tl.load(value_cache_ptr_offset + col_offsets)
|
v_reg = tl.load(value_cache_ptr_offset + col_offsets)
|
||||||
@@ -171,6 +177,11 @@ if current_platform.is_rocm():
|
|||||||
page_size = key_cache.shape[1]
|
page_size = key_cache.shape[1]
|
||||||
num_heads = key_cache.shape[2]
|
num_heads = key_cache.shape[2]
|
||||||
|
|
||||||
|
# Pass actual tensor strides so the kernel works correctly
|
||||||
|
# even when the cache is non-contiguous (e.g., for hybrid models)
|
||||||
|
k_strides = key_cache.stride()
|
||||||
|
v_strides = value_cache.stride()
|
||||||
|
|
||||||
grid = lambda meta: (total_tokens, num_heads)
|
grid = lambda meta: (total_tokens, num_heads)
|
||||||
cp_mha_gather_cache_kernel[grid](
|
cp_mha_gather_cache_kernel[grid](
|
||||||
key_cache,
|
key_cache,
|
||||||
@@ -187,6 +198,12 @@ if current_platform.is_rocm():
|
|||||||
head_dim,
|
head_dim,
|
||||||
x,
|
x,
|
||||||
block_tables.size(1),
|
block_tables.size(1),
|
||||||
|
k_strides[0],
|
||||||
|
k_strides[1],
|
||||||
|
k_strides[2],
|
||||||
|
v_strides[0],
|
||||||
|
v_strides[1],
|
||||||
|
v_strides[2],
|
||||||
DEQUANT=dequant,
|
DEQUANT=dequant,
|
||||||
PAGE_SIZE=page_size,
|
PAGE_SIZE=page_size,
|
||||||
CACHE_FORMAT=kv_cache_layout,
|
CACHE_FORMAT=kv_cache_layout,
|
||||||
@@ -1337,6 +1354,8 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
|||||||
assert k_scale is not None and v_scale is not None, (
|
assert k_scale is not None and v_scale is not None, (
|
||||||
"k_scale and v_scale are required for shuffled update"
|
"k_scale and v_scale are required for shuffled update"
|
||||||
)
|
)
|
||||||
|
# TODO: Add correct KV cache handling for hybrid models. KV cache
|
||||||
|
# may not be contiguous if mamba state exists.
|
||||||
reshape_and_cache_shuffle_triton(
|
reshape_and_cache_shuffle_triton(
|
||||||
key,
|
key,
|
||||||
value,
|
value,
|
||||||
|
|||||||
Reference in New Issue
Block a user