[Bugfix] Relax TRTLLM KV cache contiguity assertion for cross-layer layout (#34158)

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Itay Etelis <itay.etelis@ibm.com>
This commit is contained in:
Itay Etelis
2026-03-16 17:20:51 +02:00
committed by GitHub
parent ce8cf9161d
commit 5ae685c1c8

View File

@@ -586,6 +586,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# Try to use an fp8 query when the kv cache is fp8; fall back to the model
# dtype if the TRTLLM attention kernel is not used when building attn metadata
can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)
if (
can_use_trtllm
and not vllm_config.attention_config.disable_flashinfer_q_quantization
@@ -1436,7 +1437,6 @@ class FlashInferImpl(AttentionImpl):
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
assert get_kv_cache_layout() == "HND"
assert is_strictly_contiguous(prefill_query)
assert is_strictly_contiguous(kv_cache_permute)
assert is_strictly_contiguous(workspace_buffer)
assert is_strictly_contiguous(block_tables_prefill)
assert is_strictly_contiguous(seq_lens_prefill)
@@ -1461,6 +1461,20 @@ class FlashInferImpl(AttentionImpl):
# and fp8 kv cache. So to enable prefill attention
# with fp8 kv cache, we can construct a mock block table
# and a mock kv cache with BF16 KV involved in the prefill
#
# The inner (block_size, head_size) dims must be
# contiguous; outer dims may have non-canonical strides
# (e.g. cross-layer unified allocation).
# Degenerate strides on outer dims break TMA descriptors
# (see flashinfer-ai/flashinfer#2232).
kv_strides = kv_cache_permute.stride()
assert (
kv_strides[-1] == 1
and kv_strides[-2] == kv_cache_permute.shape[-1]
), (
"KV cache inner dims (block_size, head_size) must be "
f"contiguous, got strides {kv_strides}"
)
mock_kv_cache, mock_block_table = trtllm_prefill_attn_kvfp8_dequant(
kv_cache_permute,
block_tables_prefill,
@@ -1549,10 +1563,21 @@ class FlashInferImpl(AttentionImpl):
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
assert get_kv_cache_layout() == "HND"
assert is_strictly_contiguous(decode_query)
assert is_strictly_contiguous(kv_cache_permute)
assert is_strictly_contiguous(workspace_buffer)
assert is_strictly_contiguous(block_tables_decode)
assert is_strictly_contiguous(seq_lens_decode)
# kv_cache outer dims may be non-contiguous (e.g.
# cross-layer unified allocation), but inner dims
# (block_size, head_size) must be contiguous and
# strides must be canonical to avoid TMA descriptor
# failures (see flashinfer-ai/flashinfer#2232).
kv_strides = kv_cache_permute.stride()
assert (
kv_strides[-1] == 1 and kv_strides[-2] == kv_cache_permute.shape[-1]
), (
"KV cache inner dims (block_size, head_size) must be "
f"contiguous, got strides {kv_strides}"
)
if output.dtype == FP4_DTYPE:
assert self.o_sf_scale is not None