[Perf] Optimize batch invariant BMM, 18.1% Throughput improvement, 10.7% TTFT improvement (#29345)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit was authored by Wentao Ye on 2025-11-26 11:38:52 -05:00 and committed via GitHub.
parent 70d5953f82
commit 0b0aa874e8
3 changed files with 217 additions and 16 deletions

View File

@@ -8,6 +8,7 @@ import torch
from vllm.attention.utils.fa_utils import flash_attn_supports_mla
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer
skip_unsupported = pytest.mark.skipif(
not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
@@ -16,9 +17,11 @@ skip_unsupported = pytest.mark.skipif(
# Attention backends to parametrize the tests over. FLASH_ATTN is always
# available on the supported platform; the other backends are appended only
# when their runtime prerequisites are met, so a missing optional dependency
# skips rather than fails.
BACKENDS: list[str] = [
    "FLASH_ATTN",
]
# Fix: "FLASHINFER" was listed unconditionally above AND appended here,
# duplicating the entry (and breaking runs without flashinfer installed).
# Gate it solely on has_flashinfer().
if has_flashinfer():
    BACKENDS.append("FLASHINFER")
if flash_attn_supports_mla():
    BACKENDS.append("FLASH_ATTN_MLA")