Revert "[Bugfix] Disable TRTLLM attention with KV transfer enabled (#33192)" (#34832)

Signed-off-by: Zhanqiu Hu <zh338@cornell.edu>
This commit is contained in:
zhanqiuhu
2026-03-01 17:32:37 -05:00
committed by GitHub
parent e82fbeec7b
commit 57a96e26c9

View File

@@ -575,20 +575,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# try to use fp8 q if kv cache is fp8, and will fall back to model dtype
# if TRTLLM attention kernel is not used when building attn metadata
can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)
# TRTLLM attention requires strictly contiguous KV cache tensors.
# When KV transfer (P/D disaggregation) is enabled, the KV cache may be
# permuted into non-contiguous views, which causes assertion failures.
self._kv_transfer_enabled = vllm_config.kv_transfer_config is not None
if can_use_trtllm and self._kv_transfer_enabled:
logger.info_once(
"TRTLLM attention is disabled because KV transfer "
"(P/D disaggregation) is enabled. TRTLLM attention requires "
"strictly contiguous KV cache tensors which may not be "
"guaranteed with KV transfer."
)
can_use_trtllm = False
if (
can_use_trtllm
and not vllm_config.attention_config.disable_flashinfer_q_quantization
@@ -865,9 +851,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
has_sinks=self.has_sinks,
has_spec=uses_spec_reorder,
)
# KV transfer requires non-contiguous KV cache views, incompatible with TRTLLM
if self._kv_transfer_enabled:
prefill_use_trtllm = False
decode_use_trtllm = (
self.use_trtllm_decode_attention and self.dcp_world_size <= 1
)