[unrevert] Add batch invariant kernel override for FlashInfer backend [2/n] (#26373)

Signed-off-by: Bram Wasti <bwasti@meta.com>
Signed-off-by: Bram Wasti <bwasti@fb.com>
This commit is contained in:
Bram Wasti
2025-10-13 07:24:53 -07:00
committed by GitHub
parent 8e67b2557a
commit 3263799056
4 changed files with 81 additions and 35 deletions

View File

@@ -8,8 +8,12 @@ from typing import Any
import torch
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
logger = init_logger(__name__)
def _matmul_launch_metadata(
grid: Callable[..., Any], kernel: Any, args: dict[str, Any]
@@ -562,5 +566,14 @@ def vllm_kernel_override_batch_invariant():
def init_batch_invariance():
    """Activate batch-invariant kernel overrides if the env flag is set.

    When ``vllm_kernel_override_batch_invariant()`` reports the feature is
    enabled, ensure the attention backend is one that supports batch
    invariance (FLEX_ATTENTION or FLASHINFER), forcing FLEX_ATTENTION only
    when the currently selected backend is unsupported, and then enable the
    batch-invariant kernel mode (which also covers the csrc overrides).

    Side effects:
        May overwrite ``os.environ["VLLM_ATTENTION_BACKEND"]`` and emits a
        one-time warning when it does.
    """
    # this will hit all the csrc overrides as well
    if vllm_kernel_override_batch_invariant():
        # BUGFIX: do NOT unconditionally force FLEX_ATTENTION here — that
        # would clobber an already-valid FLASHINFER selection and defeat
        # the supported-backends check below. Only override when the
        # current backend is not batch-invariant capable.
        curr_attn_backend = envs.VLLM_ATTENTION_BACKEND
        supported_backends = ["FLEX_ATTENTION", "FLASHINFER"]
        if curr_attn_backend not in supported_backends:
            warning = (
                "Forcibly updating attention backend to"
                f" {supported_backends[0]} for batch_invariant. "
                f" Supported backends: {supported_backends}."
            )
            # warning_once avoids log spam if initialization runs repeatedly
            logger.warning_once(warning)
            os.environ["VLLM_ATTENTION_BACKEND"] = supported_backends[0]
        enable_batch_invariant_mode()