[Bugfix] Relax TRTLLM KV cache contiguity assertion for cross-layer layout (#34158)
Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Itay Etelis <itay.etelis@ibm.com>
This commit is contained in:
@@ -586,6 +586,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
# try to use fp8 q if kv cache is fp8, and will fall back to model dtype
|
||||
# if TRTLLM attention kernel is not used when building attn metadata
|
||||
can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)
|
||||
|
||||
if (
|
||||
can_use_trtllm
|
||||
and not vllm_config.attention_config.disable_flashinfer_q_quantization
|
||||
@@ -1436,7 +1437,6 @@ class FlashInferImpl(AttentionImpl):
|
||||
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
|
||||
assert get_kv_cache_layout() == "HND"
|
||||
assert is_strictly_contiguous(prefill_query)
|
||||
assert is_strictly_contiguous(kv_cache_permute)
|
||||
assert is_strictly_contiguous(workspace_buffer)
|
||||
assert is_strictly_contiguous(block_tables_prefill)
|
||||
assert is_strictly_contiguous(seq_lens_prefill)
|
||||
@@ -1461,6 +1461,20 @@ class FlashInferImpl(AttentionImpl):
|
||||
# and fp8 kv cache. So to enable prefill attention
|
||||
# with fp8 kv cache, we can construct a mock block
|
||||
# and mock kv cache with BF16 KV involved in the prefill
|
||||
#
|
||||
# The inner (block_size, head_size) dims must be
|
||||
# contiguous; outer dims may have non-canonical strides
|
||||
# (e.g. cross-layer unified allocation).
|
||||
# Degenerate strides on outer dims break TMA descriptors
|
||||
# (see flashinfer-ai/flashinfer#2232).
|
||||
kv_strides = kv_cache_permute.stride()
|
||||
assert (
|
||||
kv_strides[-1] == 1
|
||||
and kv_strides[-2] == kv_cache_permute.shape[-1]
|
||||
), (
|
||||
"KV cache inner dims (block_size, head_size) must be "
|
||||
f"contiguous, got strides {kv_strides}"
|
||||
)
|
||||
mock_kv_cache, mock_block_table = trtllm_prefill_attn_kvfp8_dequant(
|
||||
kv_cache_permute,
|
||||
block_tables_prefill,
|
||||
@@ -1549,10 +1563,21 @@ class FlashInferImpl(AttentionImpl):
|
||||
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
|
||||
assert get_kv_cache_layout() == "HND"
|
||||
assert is_strictly_contiguous(decode_query)
|
||||
assert is_strictly_contiguous(kv_cache_permute)
|
||||
assert is_strictly_contiguous(workspace_buffer)
|
||||
assert is_strictly_contiguous(block_tables_decode)
|
||||
assert is_strictly_contiguous(seq_lens_decode)
|
||||
# kv_cache outer dims may be non-contiguous (e.g.
|
||||
# cross-layer unified allocation), but inner dims
|
||||
# (block_size, head_size) must be contiguous, with
|
||||
# canonical strides, to avoid TMA descriptor
|
||||
# failures (see flashinfer-ai/flashinfer#2232).
|
||||
kv_strides = kv_cache_permute.stride()
|
||||
assert (
|
||||
kv_strides[-1] == 1 and kv_strides[-2] == kv_cache_permute.shape[-1]
|
||||
), (
|
||||
"KV cache inner dims (block_size, head_size) must be "
|
||||
f"contiguous, got strides {kv_strides}"
|
||||
)
|
||||
|
||||
if output.dtype == FP4_DTYPE:
|
||||
assert self.o_sf_scale is not None
|
||||
|
||||
Reference in New Issue
Block a user