[Attention][MLA] Make FLASHINFER_MLA the default MLA backend on Blackwell, and TRTLLM the default prefill (#32615)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
@@ -455,7 +455,7 @@ steps:
|
|||||||
- vllm/v1/attention
|
- vllm/v1/attention
|
||||||
- tests/v1/attention
|
- tests/v1/attention
|
||||||
commands:
|
commands:
|
||||||
- VLLM_DISABLE_FLASHINFER_PREFILL=1 pytest -v -s v1/attention # TODO: FI prefill is bugged and causes incorrectness, fix this
|
- pytest -v -s v1/attention
|
||||||
|
|
||||||
- label: V1 Test others (CPU) # 5 mins
|
- label: V1 Test others (CPU) # 5 mins
|
||||||
mirror_hardwares: [amdexperimental, amdproduction, amdtentative]
|
mirror_hardwares: [amdexperimental, amdproduction, amdtentative]
|
||||||
|
|||||||
@@ -399,7 +399,7 @@ steps:
|
|||||||
- vllm/v1/attention
|
- vllm/v1/attention
|
||||||
- tests/v1/attention
|
- tests/v1/attention
|
||||||
commands:
|
commands:
|
||||||
- VLLM_DISABLE_FLASHINFER_PREFILL=1 pytest -v -s v1/attention # TODO: FI prefill is bugged and causes incorrectness, fix this
|
- pytest -v -s v1/attention
|
||||||
|
|
||||||
- label: V1 Test others (CPU) # 5 mins
|
- label: V1 Test others (CPU) # 5 mins
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
|
|||||||
@@ -18,4 +18,4 @@ steps:
|
|||||||
- vllm/v1/attention
|
- vllm/v1/attention
|
||||||
- tests/v1/attention
|
- tests/v1/attention
|
||||||
commands:
|
commands:
|
||||||
- VLLM_DISABLE_FLASHINFER_PREFILL=1 pytest -v -s v1/attention # TODO: FI prefill is bugged and causes incorrectness, fix this
|
- pytest -v -s v1/attention
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ class AttentionConfig:
|
|||||||
use_cudnn_prefill: bool = False
|
use_cudnn_prefill: bool = False
|
||||||
"""Whether to use cudnn prefill."""
|
"""Whether to use cudnn prefill."""
|
||||||
|
|
||||||
use_trtllm_ragged_deepseek_prefill: bool = False
|
use_trtllm_ragged_deepseek_prefill: bool = True
|
||||||
"""Whether to use TRTLLM ragged deepseek prefill."""
|
"""Whether to use TRTLLM ragged deepseek prefill."""
|
||||||
|
|
||||||
use_trtllm_attention: bool | None = None
|
use_trtllm_attention: bool | None = None
|
||||||
|
|||||||
@@ -450,7 +450,6 @@ def use_flashinfer_prefill() -> bool:
|
|||||||
not vllm_config.attention_config.disable_flashinfer_prefill
|
not vllm_config.attention_config.disable_flashinfer_prefill
|
||||||
and flashinfer_available
|
and flashinfer_available
|
||||||
and not vllm_config.attention_config.use_cudnn_prefill
|
and not vllm_config.attention_config.use_cudnn_prefill
|
||||||
and not vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill
|
|
||||||
and current_platform.is_device_capability_family(100)
|
and current_platform.is_device_capability_family(100)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1324,25 +1323,27 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
def __init__(self, *args, **kwargs) -> None:
|
def __init__(self, *args, **kwargs) -> None:
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
if use_flashinfer_prefill():
|
if use_trtllm_ragged_deepseek_prefill():
|
||||||
logger.debug_once("Using FlashInfer prefill for MLA")
|
logger.info_once(
|
||||||
self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi
|
"Using TRT-LLM ragged DeepSeek prefill for MLA", scope="local"
|
||||||
self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi
|
)
|
||||||
self._pad_v = False
|
|
||||||
elif use_trtllm_ragged_deepseek_prefill():
|
|
||||||
logger.debug_once("Using TRT-LLM ragged DeepSeek prefill for MLA")
|
|
||||||
self._run_prefill_context_chunk = (
|
self._run_prefill_context_chunk = (
|
||||||
self._run_prefill_context_chunk_trtllm_ragged
|
self._run_prefill_context_chunk_trtllm_ragged
|
||||||
)
|
)
|
||||||
self._run_prefill_new_tokens = self._run_prefill_new_tokens_trtllm_ragged
|
self._run_prefill_new_tokens = self._run_prefill_new_tokens_trtllm_ragged
|
||||||
self._pad_v = False
|
self._pad_v = False
|
||||||
|
elif use_flashinfer_prefill():
|
||||||
|
logger.info_once("Using FlashInfer prefill for MLA", scope="local")
|
||||||
|
self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi
|
||||||
|
self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi
|
||||||
|
self._pad_v = False
|
||||||
elif use_cudnn_prefill():
|
elif use_cudnn_prefill():
|
||||||
logger.debug_once("Using CUDNN prefill for MLA")
|
logger.info_once("Using CUDNN prefill for MLA", scope="local")
|
||||||
self._run_prefill_context_chunk = self._run_prefill_context_chunk_cudnn
|
self._run_prefill_context_chunk = self._run_prefill_context_chunk_cudnn
|
||||||
self._run_prefill_new_tokens = self._run_prefill_new_tokens_cudnn
|
self._run_prefill_new_tokens = self._run_prefill_new_tokens_cudnn
|
||||||
self._pad_v = False
|
self._pad_v = False
|
||||||
else: # Use FlashAttention
|
else: # Use FlashAttention
|
||||||
logger.debug_once("Using FlashAttention prefill for MLA")
|
logger.info_once("Using FlashAttention prefill for MLA", scope="local")
|
||||||
self._run_prefill_context_chunk = self._run_prefill_context_chunk_fa
|
self._run_prefill_context_chunk = self._run_prefill_context_chunk_fa
|
||||||
self._run_prefill_new_tokens = self._run_prefill_new_tokens_fa
|
self._run_prefill_new_tokens = self._run_prefill_new_tokens_fa
|
||||||
|
|
||||||
|
|||||||
@@ -50,8 +50,8 @@ def _get_backend_priorities(
|
|||||||
if use_mla:
|
if use_mla:
|
||||||
if device_capability.major == 10:
|
if device_capability.major == 10:
|
||||||
return [
|
return [
|
||||||
AttentionBackendEnum.CUTLASS_MLA,
|
|
||||||
AttentionBackendEnum.FLASHINFER_MLA,
|
AttentionBackendEnum.FLASHINFER_MLA,
|
||||||
|
AttentionBackendEnum.CUTLASS_MLA,
|
||||||
AttentionBackendEnum.FLASH_ATTN_MLA,
|
AttentionBackendEnum.FLASH_ATTN_MLA,
|
||||||
AttentionBackendEnum.FLASHMLA,
|
AttentionBackendEnum.FLASHMLA,
|
||||||
AttentionBackendEnum.TRITON_MLA,
|
AttentionBackendEnum.TRITON_MLA,
|
||||||
@@ -183,12 +183,12 @@ class CudaPlatformBase(Platform):
|
|||||||
if vllm_config.attention_config.backend is None:
|
if vllm_config.attention_config.backend is None:
|
||||||
# Default case
|
# Default case
|
||||||
if cls.is_device_capability_family(100) and not use_sparse:
|
if cls.is_device_capability_family(100) and not use_sparse:
|
||||||
# Blackwell => Force CutlassMLA (unless sparse, i.e. DSv3.2).
|
# Blackwell => Force FlashInfer MLA (unless sparse, i.e. DSv3.2).
|
||||||
use_cutlass_mla = True
|
use_flashinfer_mla = True
|
||||||
# Set the backend in AttentionConfig so it's used during
|
# Set the backend in AttentionConfig so it's used during
|
||||||
# backend selection
|
# backend selection
|
||||||
vllm_config.attention_config.backend = (
|
vllm_config.attention_config.backend = (
|
||||||
AttentionBackendEnum.CUTLASS_MLA
|
AttentionBackendEnum.FLASHINFER_MLA
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Not Blackwell
|
# Not Blackwell
|
||||||
|
|||||||
Reference in New Issue
Block a user