Revert "Add batch invariant kernel override for FlashInfer backend [2/n]" (#26220)
This commit is contained in:
@@ -8,12 +8,8 @@ from typing import Any, Union
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _matmul_launch_metadata(grid: Callable[..., Any], kernel: Any,
|
||||
args: dict[str, Any]) -> dict[str, Any]:
|
||||
@@ -561,12 +557,5 @@ def vllm_kernel_override_batch_invariant():
|
||||
def init_batch_invariance():
|
||||
# this will hit all the csrc overrides as well
|
||||
if vllm_kernel_override_batch_invariant():
|
||||
curr_attn_backend = envs.VLLM_ATTENTION_BACKEND
|
||||
supported_backends = ["FLEX_ATTENTION", "FLASHINFER"]
|
||||
if curr_attn_backend not in supported_backends:
|
||||
warning = "Forcibly updating attention backend to" \
|
||||
f" {supported_backends[0]} for batch_invariant. " \
|
||||
f" Supported backends: {supported_backends}."
|
||||
logger.warning_once(warning)
|
||||
os.environ["VLLM_ATTENTION_BACKEND"] = supported_backends[0]
|
||||
os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION"
|
||||
enable_batch_invariant_mode()
|
||||
|
||||
Reference in New Issue
Block a user