[unrevert] Add batch invariant kernel override for FlashInfer backend [2/n] (#26373)
Signed-off-by: Bram Wasti <bwasti@meta.com>
Signed-off-by: Bram Wasti <bwasti@fb.com>
This commit is contained in:
@@ -8,8 +8,12 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _matmul_launch_metadata(
|
||||
grid: Callable[..., Any], kernel: Any, args: dict[str, Any]
|
||||
@@ -562,5 +566,14 @@ def vllm_kernel_override_batch_invariant():
|
||||
def init_batch_invariance():
    """Enable batch-invariant kernel overrides when the env flag requests it.

    When ``vllm_kernel_override_batch_invariant()`` is true, ensure the
    attention backend is one that supports batch invariance (forcing it to
    the default supported backend if not) and then turn on the
    batch-invariant kernel overrides. Enabling the override mode also hits
    all the csrc overrides.

    Side effects:
        May mutate ``os.environ["VLLM_ATTENTION_BACKEND"]`` and emits a
        one-time warning when the user's backend is replaced.
    """
    if vllm_kernel_override_batch_invariant():
        # Read the user's chosen backend BEFORE touching the environment.
        # Unconditionally overwriting os.environ first would make the
        # check below always pass (envs.VLLM_ATTENTION_BACKEND reads the
        # environment lazily), turning the warning branch into dead code
        # and silently clobbering a supported choice such as FLASHINFER.
        curr_attn_backend = envs.VLLM_ATTENTION_BACKEND
        supported_backends = ["FLEX_ATTENTION", "FLASHINFER"]
        if curr_attn_backend not in supported_backends:
            warning = (
                "Forcibly updating attention backend to"
                f" {supported_backends[0]} for batch_invariant. "
                f" Supported backends: {supported_backends}."
            )
            logger.warning_once(warning)
            os.environ["VLLM_ATTENTION_BACKEND"] = supported_backends[0]
        # this will hit all the csrc overrides as well
        enable_batch_invariant_mode()
|
||||
|
||||
Reference in New Issue
Block a user