Revert "Add batch invariant kernel override for FlashInfer backend [2/n]" (#26220)

This commit is contained in:
Cyrus Leung
2025-10-04 17:45:08 +08:00
committed by GitHub
parent 7d6b03381e
commit 1838cd4860
3 changed files with 29 additions and 84 deletions

View File

@@ -8,12 +8,8 @@ from typing import Any, Union
import torch
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
logger = init_logger(__name__)
def _matmul_launch_metadata(grid: Callable[..., Any], kernel: Any,
args: dict[str, Any]) -> dict[str, Any]:
@@ -561,12 +557,5 @@ def vllm_kernel_override_batch_invariant():
def init_batch_invariance():
# this will hit all the csrc overrides as well
if vllm_kernel_override_batch_invariant():
curr_attn_backend = envs.VLLM_ATTENTION_BACKEND
supported_backends = ["FLEX_ATTENTION", "FLASHINFER"]
if curr_attn_backend not in supported_backends:
warning = "Forcibly updating attention backend to" \
f" {supported_backends[0]} for batch_invariant. " \
f" Supported backends: {supported_backends}."
logger.warning_once(warning)
os.environ["VLLM_ATTENTION_BACKEND"] = supported_backends[0]
os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION"
enable_batch_invariant_mode()