[Bugfix] Fix the tensor non-contiguous issue for Flashinfer TRT-LLM backend attention kernel (#21133)

This commit is contained in:
elvischenv
2025-07-18 08:35:58 +08:00
committed by GitHub
parent 8a8fc94639
commit 8dfb45ca33

View File

@@ -353,8 +353,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                 attn_metadata.decode_wrapper = self._get_decode_wrapper()
                 if not FlashInferBackend.use_trtllm_decode_attention(
-                        num_decodes, attn_metadata.max_seq_len,
-                        attn_metadata.kv_data_type, attn_metadata.num_qo_heads,
-                        attn_metadata.num_kv_heads, attn_metadata.head_dim):
+                        num_decodes, attn_metadata.max_seq_len,
+                        self.cache_config.cache_dtype,
+                        attn_metadata.num_qo_heads, attn_metadata.num_kv_heads,
+                        attn_metadata.head_dim):
                     attn_metadata.decode_wrapper.plan(
                         attn_metadata.paged_kv_indptr[:num_decodes + 1],
                         attn_metadata.paged_kv_indices,
@@ -539,10 +540,10 @@ class FlashInferImpl(AttentionImpl):
             query: shape = [num_tokens, num_heads, head_size]
             key: shape = [num_tokens, num_kv_heads, head_size]
             value: shape = [num_tokens, num_kv_heads, head_size]
             kv_cache: shape -
             # NHD: [num_blocks, 2, block_size, num_kv_heads, head_size]
             # HND: [num_blocks, 2, num_kv_heads, block_size, head_size]
             attn_metadata: Metadata for attention.
         Returns:
@@ -614,6 +615,7 @@ class FlashInferImpl(AttentionImpl):
             num_prefill_tokens = attn_metadata.num_prefill_tokens
             stride_order = FlashInferBackend.get_kv_cache_stride_order()
+            kv_cache_permute = kv_cache.permute(*stride_order)
             # Regular attention (common case).
             # Decodes are at the front and prefills are at the back,
             # according to reorder_batch()
@@ -628,7 +630,7 @@ class FlashInferImpl(AttentionImpl):
                 assert prefill_wrapper._sm_scale == self.scale
                 prefill_wrapper.run(
                     prefill_query,
-                    kv_cache.permute(*stride_order),
+                    kv_cache_permute,
                     k_scale=layer._k_scale_float,
                     v_scale=layer._v_scale_float,
                     out=output[num_decode_tokens:],
@@ -647,7 +649,7 @@ class FlashInferImpl(AttentionImpl):
                 assert decode_wrapper._sm_scale == self.scale
                 decode_wrapper.run(
                     decode_query,
-                    kv_cache.permute(*stride_order),
+                    kv_cache_permute,
                     k_scale=layer._k_scale_float,
                     v_scale=layer._v_scale_float,
                     out=output[:num_decode_tokens],
@@ -655,19 +657,29 @@ class FlashInferImpl(AttentionImpl):
             else:
                 # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
                 if num_decode_tokens > 0:
+                    # decode_query may be non-contiguous
+                    decode_query = decode_query.contiguous()
+                    block_tables_decode = attn_metadata.block_table_tensor[:
+                                                                           num_decode_tokens]
+                    seq_lens_decode = attn_metadata.seq_lens[:
+                                                             num_decode_tokens]
                     assert get_kv_cache_layout() == "HND"
+                    assert decode_query.is_contiguous()
+                    assert kv_cache_permute.is_contiguous()
+                    assert block_tables_decode.is_contiguous()
+                    assert seq_lens_decode.is_contiguous()
                     output[:num_decode_tokens] = (
                         trtllm_batch_decode_with_kv_cache(
                             query=decode_query,
-                            kv_cache=kv_cache.permute(*stride_order),
+                            kv_cache=kv_cache_permute,
                             workspace_buffer=attn_metadata.workspace_buffer,
                             num_heads=self.num_heads,
                             num_kv_heads=self.num_kv_heads,
                             scale=self.scale,
-                            block_tables=attn_metadata.
-                            block_table_tensor[:num_decode_tokens],
-                            seq_lens=attn_metadata.
-                            seq_lens[:num_decode_tokens],
+                            block_tables=block_tables_decode,
+                            seq_lens=seq_lens_decode,
                             block_size=attn_metadata.page_size,
                             max_seq_len=attn_metadata.max_seq_len,
                             kv_cache_dtype=self.kv_cache_dtype,