Add batch invariant kernel override for FlashInfer backend [2/n] (#25769)

Signed-off-by: Bram Wasti <bwasti@meta.com>
Signed-off-by: Bram Wasti <bwasti@fb.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
Bram Wasti
2025-10-03 21:49:30 -05:00
committed by GitHub
parent ea25a76c05
commit 2f7dbc9b42
3 changed files with 84 additions and 29 deletions

View File

@@ -8,8 +8,12 @@ from typing import Any, Union
import torch
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
logger = init_logger(__name__)
def _matmul_launch_metadata(grid: Callable[..., Any], kernel: Any,
args: dict[str, Any]) -> dict[str, Any]:
@@ -557,5 +561,12 @@ def vllm_kernel_override_batch_invariant():
def init_batch_invariance():
# this will hit all the csrc overrides as well
if vllm_kernel_override_batch_invariant():
os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION"
curr_attn_backend = envs.VLLM_ATTENTION_BACKEND
supported_backends = ["FLEX_ATTENTION", "FLASHINFER"]
if curr_attn_backend not in supported_backends:
warning = "Forcibly updating attention backend to" \
f" {supported_backends[0]} for batch_invariant. " \
f" Supported backends: {supported_backends}."
logger.warning_once(warning)
os.environ["VLLM_ATTENTION_BACKEND"] = supported_backends[0]
enable_batch_invariant_mode()