[Bugfix] Disable TRTLLM attention when KV transfer is enabled (#33192)
Signed-off-by: Zhanqiu Hu <zh338@cornell.edu>
@@ -573,6 +573,20 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
        # try to use fp8 q if kv cache is fp8, and will fall back to model dtype
        # if TRTLLM attention kernel is not used when building attn metadata
        can_use_trtllm = can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)

        # TRTLLM attention requires strictly contiguous KV cache tensors.
        # When KV transfer (P/D disaggregation) is enabled, the KV cache may be
        # permuted into non-contiguous views, which causes assertion failures.
        self._kv_transfer_enabled = vllm_config.kv_transfer_config is not None
        if can_use_trtllm and self._kv_transfer_enabled:
            logger.info_once(
                "TRTLLM attention is disabled because KV transfer "
                "(P/D disaggregation) is enabled. TRTLLM attention requires "
                "strictly contiguous KV cache tensors which may not be "
                "guaranteed with KV transfer."
            )
            can_use_trtllm = False

        if (
            can_use_trtllm
            and not vllm_config.attention_config.disable_flashinfer_q_quantization
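For context, the contiguity concern described in the comments above can be reproduced with a minimal PyTorch sketch. The shape and the permutation below are illustrative only, not the actual vLLM/FlashInfer KV cache layout:

```python
import torch

# Hypothetical KV cache block: (num_blocks, 2, block_size, num_kv_heads, head_dim).
# The real layout used by the FlashInfer backend is not shown in this diff.
kv_cache = torch.zeros(4, 2, 16, 8, 64)
print(kv_cache.is_contiguous())              # True: freshly allocated storage

# A permuted view (e.g. produced while staging KV blocks for transfer) shares
# storage but is no longer contiguous, so a kernel that asserts contiguity on
# its KV input would fail on it.
kv_view = kv_cache.permute(0, 2, 1, 3, 4)
print(kv_view.is_contiguous())               # False: strides no longer match shape
print(kv_view.contiguous().is_contiguous())  # True, but only after an explicit copy
```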
@@ -822,6 +836,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
            has_sinks=self.has_sinks,
            has_spec=uses_spec_reorder,
        )
        # KV transfer requires non-contiguous KV cache views, incompatible with TRTLLM
        if self._kv_transfer_enabled:
            prefill_use_trtllm = False
        decode_use_trtllm = (
            self.use_trtllm_decode_attention and self.dcp_world_size <= 1
        )
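As a rough sketch of the gating pattern this change introduces, the snippet below mirrors the logic in isolation. `VllmConfigSketch` and `resolve_trtllm_usage` are simplified stand-ins for illustration, not the real vLLM API:

```python
from dataclasses import dataclass


@dataclass
class VllmConfigSketch:
    """Simplified stand-in for vllm_config; only the field this gate reads."""
    kv_transfer_config: object | None = None


def resolve_trtllm_usage(config: VllmConfigSketch, can_use_trtllm: bool) -> bool:
    """TRTLLM attention is opted out whenever a KV transfer config is present."""
    kv_transfer_enabled = config.kv_transfer_config is not None
    if can_use_trtllm and kv_transfer_enabled:
        # In the real code this decision is logged once via logger.info_once(...)
        return False
    return can_use_trtllm


# Usage: TRTLLM stays enabled only without a KV transfer (P/D disaggregation) config.
assert resolve_trtllm_usage(VllmConfigSketch(), True) is True
assert resolve_trtllm_usage(VllmConfigSketch(kv_transfer_config=object()), True) is False
```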