[Attention][MLA] Make FLASHINFER_MLA the default MLA backend on Blackwell, and TRTLLM the default prefill (#32339)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
Matthew Bonanni
2026-01-15 09:49:57 -05:00
committed by GitHub
parent b89275d018
commit 8ebfacaa75
3 changed files with 16 additions and 15 deletions

View File

@@ -450,7 +450,6 @@ def use_flashinfer_prefill() -> bool:
not vllm_config.attention_config.disable_flashinfer_prefill
and flashinfer_available
and not vllm_config.attention_config.use_cudnn_prefill
and not vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill
and current_platform.is_device_capability_family(100)
)
@@ -1294,25 +1293,27 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
if use_flashinfer_prefill():
logger.debug_once("Using FlashInfer prefill for MLA")
self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi
self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi
self._pad_v = False
elif use_trtllm_ragged_deepseek_prefill():
logger.debug_once("Using TRT-LLM ragged DeepSeek prefill for MLA")
if use_trtllm_ragged_deepseek_prefill():
logger.info_once(
"Using TRT-LLM ragged DeepSeek prefill for MLA", scope="local"
)
self._run_prefill_context_chunk = (
self._run_prefill_context_chunk_trtllm_ragged
)
self._run_prefill_new_tokens = self._run_prefill_new_tokens_trtllm_ragged
self._pad_v = False
elif use_flashinfer_prefill():
logger.info_once("Using FlashInfer prefill for MLA")
self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi
self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi
self._pad_v = False
elif use_cudnn_prefill():
logger.debug_once("Using CUDNN prefill for MLA")
logger.info_once("Using CUDNN prefill for MLA", scope="local")
self._run_prefill_context_chunk = self._run_prefill_context_chunk_cudnn
self._run_prefill_new_tokens = self._run_prefill_new_tokens_cudnn
self._pad_v = False
else: # Use FlashAttention
logger.debug_once("Using FlashAttention prefill for MLA")
logger.info_once("Using FlashAttention prefill for MLA", scope="local")
self._run_prefill_context_chunk = self._run_prefill_context_chunk_fa
self._run_prefill_new_tokens = self._run_prefill_new_tokens_fa