diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index a79a7480b..7e272ab25 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -586,6 +586,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): # try to use fp8 q if kv cache is fp8, and will fall back to model dtype # if TRTLLM attention kernel is not used when building attn metadata can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads) + if ( can_use_trtllm and not vllm_config.attention_config.disable_flashinfer_q_quantization @@ -1436,7 +1437,6 @@ class FlashInferImpl(AttentionImpl): # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND assert get_kv_cache_layout() == "HND" assert is_strictly_contiguous(prefill_query) - assert is_strictly_contiguous(kv_cache_permute) assert is_strictly_contiguous(workspace_buffer) assert is_strictly_contiguous(block_tables_prefill) assert is_strictly_contiguous(seq_lens_prefill) @@ -1461,6 +1461,20 @@ class FlashInferImpl(AttentionImpl): # and fp8 kv cache. So to enable prefill attention # with fp8 kv cache, we can construct a mock block # and mock kv cache with BF16 KV involved in the prefill + # + # The inner (block_size, head_size) dims must be + # contiguous; outer dims may have non-canonical strides + # (e.g. cross-layer unified allocation). + # Degenerate strides on outer dims break TMA descriptors + # (see flashinfer-ai/flashinfer#2232). + kv_strides = kv_cache_permute.stride() + assert ( + kv_strides[-1] == 1 + and kv_strides[-2] == kv_cache_permute.shape[-1] + ), ( + "KV cache inner dims (block_size, head_size) must be " + f"contiguous, got strides {kv_strides}" + ) mock_kv_cache, mock_block_table = trtllm_prefill_attn_kvfp8_dequant( kv_cache_permute, block_tables_prefill, @@ -1549,10 +1563,21 @@ class FlashInferImpl(AttentionImpl): # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND assert get_kv_cache_layout() == "HND" assert is_strictly_contiguous(decode_query) - assert is_strictly_contiguous(kv_cache_permute) assert is_strictly_contiguous(workspace_buffer) assert is_strictly_contiguous(block_tables_decode) assert is_strictly_contiguous(seq_lens_decode) + # kv_cache outer dims may be non-contiguous (e.g. + # cross-layer unified allocation), but inner dims + # (block_size, head_size) must be contiguous and + # strides must be canonical to avoid TMA descriptor + # failures (see flashinfer-ai/flashinfer#2232). + kv_strides = kv_cache_permute.stride() + assert ( + kv_strides[-1] == 1 and kv_strides[-2] == kv_cache_permute.shape[-1] + ), ( + "KV cache inner dims (block_size, head_size) must be " + f"contiguous, got strides {kv_strides}" + ) if output.dtype == FP4_DTYPE: assert self.o_sf_scale is not None