[Bugfix] Fix the tensor non-contiguous issue for Flashinfer TRT-LLM backend attention kernel (#21133)
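Summary of the change, as reflected in the hunks below: every tensor handed to the TRT-LLM decode kernel is made contiguous before the call. The permuted KV-cache view is computed once as kv_cache_permute and reused by the prefill, decode, and TRT-LLM paths; decode_query is explicitly made contiguous; the decode block tables and sequence lengths are sliced into dedicated tensors; and assertions verify the contiguity expectations of trtllm_batch_decode_with_kv_cache. The metadata builder additionally passes the configured cache dtype to the TRT-LLM eligibility check.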
@@ -353,8 +353,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             attn_metadata.decode_wrapper = self._get_decode_wrapper()
             if not FlashInferBackend.use_trtllm_decode_attention(
                     num_decodes, attn_metadata.max_seq_len,
-                    attn_metadata.kv_data_type, attn_metadata.num_qo_heads,
-                    attn_metadata.num_kv_heads, attn_metadata.head_dim):
+                    self.cache_config.cache_dtype,
+                    attn_metadata.num_qo_heads, attn_metadata.num_kv_heads,
+                    attn_metadata.head_dim):
                 attn_metadata.decode_wrapper.plan(
                     attn_metadata.paged_kv_indptr[:num_decodes + 1],
                     attn_metadata.paged_kv_indices,
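This hunk only changes how the TRT-LLM decode eligibility check is fed: the dtype argument switches from attn_metadata.kv_data_type to self.cache_config.cache_dtype (the dtype recorded on the cache config rather than on the per-batch metadata), and the remaining head and dimension arguments are re-wrapped across lines accordingly.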
@@ -539,10 +540,10 @@ class FlashInferImpl(AttentionImpl):
             query: shape = [num_tokens, num_heads, head_size]
             key: shape = [num_tokens, num_kv_heads, head_size]
             value: shape = [num_tokens, num_kv_heads, head_size]
             kv_cache: shape -
             # NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
             # HND: [num_blocks, 2, num_kv_heads, block_size, head_size]


             attn_metadata: Metadata for attention.
         Returns:
@@ -614,6 +615,7 @@ class FlashInferImpl(AttentionImpl):
         num_prefill_tokens = attn_metadata.num_prefill_tokens

         stride_order = FlashInferBackend.get_kv_cache_stride_order()
+        kv_cache_permute = kv_cache.permute(*stride_order)
         # Regular attention (common case).
         # Decodes are at the front and prefills are at the back,
         # according to reorder_batch()
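For context on the hoisted kv_cache_permute: torch.Tensor.permute returns a strided view without copying, so computing it once and reusing it at every call site below costs nothing extra, and the view is contiguous only when the requested order already matches the physical memory layout. A minimal standalone PyTorch sketch (not part of this commit; shapes chosen only for illustration):

    import torch

    # Illustrative shape: [num_blocks, 2, block_size, num_kv_heads, head_size]
    x = torch.randn(3, 2, 8, 4, 16)
    identity = x.permute(0, 1, 2, 3, 4)        # same order as the physical layout
    swapped = x.permute(0, 1, 3, 2, 4)         # an NHD -> HND style reordering
    print(identity.is_contiguous())            # True: strides still describe packed memory
    print(swapped.is_contiguous())             # False: it is only a strided view
    print(swapped.data_ptr() == x.data_ptr())  # True: permute never copies data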
@@ -628,7 +630,7 @@ class FlashInferImpl(AttentionImpl):
             assert prefill_wrapper._sm_scale == self.scale
             prefill_wrapper.run(
                 prefill_query,
-                kv_cache.permute(*stride_order),
+                kv_cache_permute,
                 k_scale=layer._k_scale_float,
                 v_scale=layer._v_scale_float,
                 out=output[num_decode_tokens:],
@@ -647,7 +649,7 @@ class FlashInferImpl(AttentionImpl):
                 assert decode_wrapper._sm_scale == self.scale
                 decode_wrapper.run(
                     decode_query,
-                    kv_cache.permute(*stride_order),
+                    kv_cache_permute,
                     k_scale=layer._k_scale_float,
                     v_scale=layer._v_scale_float,
                     out=output[:num_decode_tokens],
@@ -655,19 +657,29 @@ class FlashInferImpl(AttentionImpl):
             else:
                 # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
                 if num_decode_tokens > 0:
+                    # decode_query may be non-contiguous
+                    decode_query = decode_query.contiguous()
+                    block_tables_decode = attn_metadata.block_table_tensor[:
+                                                                           num_decode_tokens]
+                    seq_lens_decode = attn_metadata.seq_lens[:
+                                                             num_decode_tokens]
+
                     assert get_kv_cache_layout() == "HND"
+                    assert decode_query.is_contiguous()
+                    assert kv_cache_permute.is_contiguous()
+                    assert block_tables_decode.is_contiguous()
+                    assert seq_lens_decode.is_contiguous()
+
                     output[:num_decode_tokens] = (
                         trtllm_batch_decode_with_kv_cache(
                             query=decode_query,
-                            kv_cache=kv_cache.permute(*stride_order),
+                            kv_cache=kv_cache_permute,
                             workspace_buffer=attn_metadata.workspace_buffer,
                             num_heads=self.num_heads,
                             num_kv_heads=self.num_kv_heads,
                             scale=self.scale,
-                            block_tables=attn_metadata.
-                            block_table_tensor[:num_decode_tokens],
-                            seq_lens=attn_metadata.
-                            seq_lens[:num_decode_tokens],
+                            block_tables=block_tables_decode,
+                            seq_lens=seq_lens_decode,
                             block_size=attn_metadata.page_size,
                             max_seq_len=attn_metadata.max_seq_len,
                             kv_cache_dtype=self.kv_cache_dtype,
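The core of the fix is in this last hunk: decode_query can reach the TRT-LLM path as a non-contiguous view, and a kernel that consumes raw device pointers assumes densely packed memory. Below is a minimal standalone PyTorch sketch (not part of this commit; shapes are arbitrary) of the behaviour that the added .contiguous() call and assertions rely on:

    import torch

    q = torch.randn(16, 4, 64)             # arbitrary [num_tokens, num_heads, head_size]
    view = q.transpose(0, 1)[:2]            # a strided view; memory is no longer packed
    print(view.is_contiguous())             # False
    packed = view.contiguous()              # materializes a packed copy a raw-pointer kernel can use
    print(packed.is_contiguous())           # True
    front = q[:8]                           # slicing the leading dim of a packed tensor stays packed
    print(front.contiguous() is front)      # True: .contiguous() is then a no-op (no copy)

Slicing block_table_tensor and seq_lens along their leading dimension preserves contiguity as long as the source tensors are themselves contiguous, which is exactly what the added assertions check before the pointers are handed to trtllm_batch_decode_with_kv_cache.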