diff --git a/vllm/config/attention.py b/vllm/config/attention.py index 354ef056c..293045787 100644 --- a/vllm/config/attention.py +++ b/vllm/config/attention.py @@ -35,7 +35,7 @@ class AttentionConfig: use_cudnn_prefill: bool = False """Whether to use cudnn prefill.""" - use_trtllm_ragged_deepseek_prefill: bool = True + use_trtllm_ragged_deepseek_prefill: bool = False """Whether to use TRTLLM ragged deepseek prefill.""" use_trtllm_attention: bool | None = None diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index e9cfa4a08..fc2a4fad9 100755 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -450,6 +450,7 @@ def use_flashinfer_prefill() -> bool: not vllm_config.attention_config.disable_flashinfer_prefill and flashinfer_available and not vllm_config.attention_config.use_cudnn_prefill + and not vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill and current_platform.is_device_capability_family(100) ) @@ -1323,27 +1324,25 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - if use_trtllm_ragged_deepseek_prefill(): - logger.info_once( - "Using TRT-LLM ragged DeepSeek prefill for MLA", scope="local" - ) + if use_flashinfer_prefill(): + logger.debug_once("Using FlashInfer prefill for MLA") + self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi + self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi + self._pad_v = False + elif use_trtllm_ragged_deepseek_prefill(): + logger.debug_once("Using TRT-LLM ragged DeepSeek prefill for MLA") self._run_prefill_context_chunk = ( self._run_prefill_context_chunk_trtllm_ragged ) self._run_prefill_new_tokens = self._run_prefill_new_tokens_trtllm_ragged self._pad_v = False - elif use_flashinfer_prefill(): - logger.info_once("Using FlashInfer prefill for MLA") - self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi - self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi - self._pad_v = False elif use_cudnn_prefill(): - logger.info_once("Using CUDNN prefill for MLA", scope="local") + logger.debug_once("Using CUDNN prefill for MLA") self._run_prefill_context_chunk = self._run_prefill_context_chunk_cudnn self._run_prefill_new_tokens = self._run_prefill_new_tokens_cudnn self._pad_v = False else: # Use FlashAttention - logger.info_once("Using FlashAttention prefill for MLA", scope="local") + logger.debug_once("Using FlashAttention prefill for MLA") self._run_prefill_context_chunk = self._run_prefill_context_chunk_fa self._run_prefill_new_tokens = self._run_prefill_new_tokens_fa diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 770fbde69..47d634416 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -50,8 +50,8 @@ def _get_backend_priorities( if use_mla: if device_capability.major == 10: return [ - AttentionBackendEnum.FLASHINFER_MLA, AttentionBackendEnum.CUTLASS_MLA, + AttentionBackendEnum.FLASHINFER_MLA, AttentionBackendEnum.FLASH_ATTN_MLA, AttentionBackendEnum.FLASHMLA, AttentionBackendEnum.TRITON_MLA, @@ -183,12 +183,12 @@ class CudaPlatformBase(Platform): if vllm_config.attention_config.backend is None: # Default case if cls.is_device_capability_family(100) and not use_sparse: - # Blackwell => Force FlashInferMLA (unless sparse, i.e. DSv3.2). - use_flashinfer_mla = True + # Blackwell => Force CutlassMLA (unless sparse, i.e. DSv3.2). + use_cutlass_mla = True # Set the backend in AttentionConfig so it's used during # backend selection vllm_config.attention_config.backend = ( - AttentionBackendEnum.FLASHINFER_MLA + AttentionBackendEnum.CUTLASS_MLA ) else: # Not Blackwell