[small][batch invariance] Rename the env and internal flags to simplify usage (#26855)

Signed-off-by: Bram Wasti <bwasti@meta.com>
This commit is contained in:
Bram Wasti
2025-10-16 14:40:25 -07:00
committed by GitHub
parent 23583ee28c
commit b2f78cbad4
20 changed files with 61 additions and 61 deletions

View File

@@ -741,8 +741,8 @@ def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
return AttentionBlockSize(block_m=16, block_n=16)
def vllm_kernel_override_batch_invariant():
env_key = "VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"
def vllm_is_batch_invariant():
env_key = "VLLM_BATCH_INVARIANT"
is_overridden = False
val = os.getenv(env_key, "0")
try:
@@ -797,7 +797,7 @@ def override_envs_for_invariance():
def init_batch_invariance():
# this will hit all the csrc overrides as well
if vllm_kernel_override_batch_invariant():
if vllm_is_batch_invariant():
override_envs_for_invariance()
enable_batch_invariant_mode()

View File

@@ -16,7 +16,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
vllm_is_batch_invariant,
)
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
@@ -841,7 +841,7 @@ def get_moe_configs(
"""
# Avoid optimizing for the batch invariant case. Use default config
if vllm_kernel_override_batch_invariant():
if vllm_is_batch_invariant():
return None
# First look up if an optimized configuration is available in the configs
@@ -976,7 +976,7 @@ def get_default_config(
dtype: str | None,
block_shape: list[int] | None = None,
) -> dict[str, int]:
if vllm_kernel_override_batch_invariant():
if vllm_is_batch_invariant():
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
@@ -1136,7 +1136,7 @@ def fused_topk_bias(
) + e_score_correction_bias.unsqueeze(0)
# For batch invariance, use sorted=True to ensure deterministic expert selection
use_sorted = vllm_kernel_override_batch_invariant()
use_sorted = vllm_is_batch_invariant()
topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1]
topk_weights = scores.gather(1, topk_indices)
if renormalize:
@@ -1200,7 +1200,7 @@ def grouped_topk(
) # [n, n_group]
# For batch invariance, use sorted=True to ensure deterministic expert selection
use_sorted = vllm_kernel_override_batch_invariant()
use_sorted = vllm_is_batch_invariant()
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[
1
] # [n, top_k_group]

View File

@@ -10,7 +10,7 @@ import vllm.envs as envs
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.batch_invariant import (
rms_norm_batch_invariant,
vllm_kernel_override_batch_invariant,
vllm_is_batch_invariant,
)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
@@ -25,7 +25,7 @@ def rms_norm(
) -> torch.Tensor:
from vllm import _custom_ops as ops
if vllm_kernel_override_batch_invariant():
if vllm_is_batch_invariant():
return rms_norm_batch_invariant(x, weight, variance_epsilon)
out = torch.empty_like(x)
ops.rms_norm(
@@ -45,7 +45,7 @@ def fused_add_rms_norm(
) -> tuple[torch.Tensor, torch.Tensor]:
from vllm import _custom_ops as ops
if vllm_kernel_override_batch_invariant():
if vllm_is_batch_invariant():
return rms_norm_batch_invariant(
x + residual, weight, variance_epsilon
), x + residual

View File

@@ -15,7 +15,7 @@ from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
vllm_is_batch_invariant,
)
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
@@ -356,7 +356,7 @@ class Fp8LinearMethod(LinearMethodBase):
# Disable marlin for rocm
if current_platform.is_rocm():
self.use_marlin = False
if vllm_kernel_override_batch_invariant():
if vllm_is_batch_invariant():
self.use_marlin = False
self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
@@ -540,7 +540,7 @@ class Fp8LinearMethod(LinearMethodBase):
bias: torch.Tensor | None = None,
) -> torch.Tensor:
# If batch invariant mode is enabled, dequantize and use BF16 compute
if vllm_kernel_override_batch_invariant():
if vllm_is_batch_invariant():
# Dequantize FP8 weights to BF16
weight_fp8 = layer.weight.to(torch.bfloat16)
weight_scale = layer.weight_scale.to(torch.bfloat16)