[Bugfix] Register VLLM_BATCH_INVARIANT in envs.py to fix spurious unknown env var warning (#35007)

Signed-off-by: Ranran <1012869439@qq.com>
Signed-off-by: Ranran <hzz5361@psu.edu>
Signed-off-by: ran <hzz5361@psu.edu>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
Ranran
2026-03-23 17:31:14 -05:00
committed by GitHub
parent e85f8f0932
commit dc6908ac6a
30 changed files with 70 additions and 130 deletions

View File

@@ -15,7 +15,6 @@ from vllm.model_executor.layers.attention.kv_transfer_utils import (
maybe_transfer_kv_layer,
)
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.model_executor.layers.linear import (
UnquantizedLinearMethod,
)
@@ -296,7 +295,7 @@ class Attention(nn.Module, AttentionLayerBase):
if (
cache_config is not None
and cache_config.enable_prefix_caching
and vllm_is_batch_invariant()
and envs.VLLM_BATCH_INVARIANT
and (
self.attn_backend.get_name() == "FLASHINFER"
or self.attn_backend.get_name() == "TRITON_MLA"

View File

@@ -227,7 +227,6 @@ from vllm.model_executor.layers.attention.kv_transfer_utils import (
maybe_transfer_kv_layer,
)
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
)
@@ -372,7 +371,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
if (
cache_config is not None
and cache_config.enable_prefix_caching
and vllm_is_batch_invariant()
and envs.VLLM_BATCH_INVARIANT
and (
self.attn_backend.get_name() == "TRITON_MLA"
or self.attn_backend.get_name() == "FLASHINFER"
@@ -2188,7 +2187,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
# ROCm leverages the upstream flash_attn, which takes a parameter
# called "return_attn_probs" instead of return_softmax_lse
kwargs["return_attn_probs"] = return_softmax_lse
if vllm_is_batch_invariant():
if envs.VLLM_BATCH_INVARIANT:
kwargs["num_splits"] = 1
attn_out = self.flash_attn_varlen_func(

View File

@@ -6,6 +6,7 @@ from typing import Any
import torch
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
@@ -986,21 +987,6 @@ def enable_batch_invariant_mode():
torch.backends.cuda.preferred_blas_library(backend="cublaslt")
def _read_vllm_batch_invariant() -> bool:
val = os.getenv("VLLM_BATCH_INVARIANT", "0")
try:
return int(val) != 0
except ValueError:
return False
VLLM_BATCH_INVARIANT: bool = _read_vllm_batch_invariant()
def vllm_is_batch_invariant() -> bool:
return VLLM_BATCH_INVARIANT
def override_envs_for_invariance(
attention_backend: AttentionBackendEnum | None,
):
@@ -1059,7 +1045,7 @@ def init_batch_invariance(
attention_backend: AttentionBackendEnum | None,
):
# this will hit all the csrc overrides as well
if vllm_is_batch_invariant():
if envs.VLLM_BATCH_INVARIANT:
override_envs_for_invariance(attention_backend)
enable_batch_invariant_mode()

View File

@@ -14,9 +14,6 @@ import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.model_executor.layers.fused_moe.activation import (
MoEActivation,
apply_moe_activation,
@@ -1051,7 +1048,7 @@ def get_moe_configs(
"""
# Avoid optimizing for the batch invariant case. Use default config
if vllm_is_batch_invariant():
if envs.VLLM_BATCH_INVARIANT:
return None
# First look up if an optimized configuration is available in the configs
@@ -1232,7 +1229,7 @@ def get_default_config(
dtype: str | None,
block_shape: list[int] | None = None,
) -> dict[str, int]:
if vllm_is_batch_invariant():
if envs.VLLM_BATCH_INVARIANT:
return {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,

View File

@@ -6,11 +6,9 @@ from collections.abc import Callable
import torch
import vllm._custom_ops as ops
import vllm.envs as envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.model_executor.layers.fused_moe.config import (
RoutingMethodType,
get_routing_method_type,
@@ -160,7 +158,7 @@ def fused_topk_bias(
) + e_score_correction_bias.unsqueeze(0)
# For batch invariance, use sorted=True to ensure deterministic expert selection
use_sorted = vllm_is_batch_invariant()
use_sorted = envs.VLLM_BATCH_INVARIANT
topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1]
topk_weights = scores.gather(1, topk_indices)
if renormalize:

View File

@@ -10,9 +10,6 @@ from vllm import envs as envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.model_executor.layers.fused_moe.config import (
RoutingMethodType,
get_routing_method_type,
@@ -135,7 +132,7 @@ def grouped_topk(
) # [n, n_group]
# For batch invariance, use sorted=True to ensure deterministic expert selection
use_sorted = vllm_is_batch_invariant()
use_sorted = envs.VLLM_BATCH_INVARIANT
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[
1
] # [n, top_k_group]

View File

@@ -12,7 +12,6 @@ from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.batch_invariant import (
rms_norm_batch_invariant,
vllm_is_batch_invariant,
)
from vllm.platforms import current_platform
@@ -57,7 +56,7 @@ def rms_norm(
) -> torch.Tensor:
from vllm import _custom_ops as ops
if vllm_is_batch_invariant():
if envs.VLLM_BATCH_INVARIANT:
return rms_norm_batch_invariant(x, weight, variance_epsilon)
out = torch.empty_like(x)
ops.rms_norm(
@@ -77,7 +76,7 @@ def fused_add_rms_norm(
) -> tuple[torch.Tensor, torch.Tensor]:
from vllm import _custom_ops as ops
if vllm_is_batch_invariant():
if envs.VLLM_BATCH_INVARIANT:
return rms_norm_batch_invariant(
x + residual, weight, variance_epsilon
), x + residual
@@ -300,7 +299,7 @@ class RMSNorm(CustomOp):
and x.is_cuda
and x.dim() >= 2
and self.has_weight
and not vllm_is_batch_invariant()
and not envs.VLLM_BATCH_INVARIANT
and self.weight.data.dtype == x.dtype
and self.weight.data.is_contiguous()
):
@@ -328,7 +327,7 @@ class RMSNorm(CustomOp):
and x.dtype == residual.dtype
and x.dim() >= 2
and self.has_weight
and not vllm_is_batch_invariant()
and not envs.VLLM_BATCH_INVARIANT
and self.weight.data.dtype == x.dtype
and self.weight.data.is_contiguous()
):

View File

@@ -7,6 +7,7 @@ from abc import abstractmethod
import torch
from torch.nn.parameter import Parameter, UninitializedParameter
import vllm.envs as envs
from vllm.distributed import (
divide,
get_tensor_model_parallel_rank,
@@ -19,7 +20,6 @@ from vllm.logger import init_logger
from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.batch_invariant import (
linear_batch_invariant,
vllm_is_batch_invariant,
)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
@@ -223,7 +223,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
if vllm_is_batch_invariant() and current_platform.is_cuda_alike():
if envs.VLLM_BATCH_INVARIANT and current_platform.is_cuda_alike():
return linear_batch_invariant(x, layer.weight, bias)
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)

View File

@@ -7,6 +7,7 @@ import torch
from torch.nn import Module
from torch.utils._python_dispatch import TorchDispatchMode
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
@@ -17,9 +18,6 @@ from vllm.model_executor.kernels.linear import (
)
from vllm.model_executor.kernels.linear.scaled_mm import MarlinFP8ScaledMMLinearKernel
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoEMethodBase,
@@ -441,7 +439,7 @@ class Fp8LinearMethod(LinearMethodBase):
) -> torch.Tensor:
# if batch invariant mode is enabled, prefer DeepGEMM FP8 path
# we will use BF16 dequant when DeepGEMM is not supported.
if vllm_is_batch_invariant():
if envs.VLLM_BATCH_INVARIANT:
if self.block_quant:
assert self.weight_block_size is not None
return self.w8a8_block_fp8_linear.apply(

View File

@@ -305,9 +305,7 @@ def _flashinfer_fp8_blockscale_gemm_impl(
)
return output
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
if vllm_is_batch_invariant():
if envs.VLLM_BATCH_INVARIANT:
return run_deepgemm(input, weight, weight_scale)
condition = input.shape[0] < 32