diff --git a/tests/kernels/attention/test_use_trtllm_attention.py b/tests/kernels/attention/test_use_trtllm_attention.py
index e24ad1018..fba18fe46 100644
--- a/tests/kernels/attention/test_use_trtllm_attention.py
+++ b/tests/kernels/attention/test_use_trtllm_attention.py
@@ -55,37 +55,37 @@ def _clear_supports_cache():
 # supports_trtllm_attention
-@patch("vllm.utils.flashinfer.vllm_is_batch_invariant", return_value=True)
-def test_supports_batch_invariant_disables(_mock):
+@patch("vllm.envs.VLLM_BATCH_INVARIANT", True)
+def test_supports_batch_invariant_disables():
     assert supports_trtllm_attention() is False
 
 
-@patch("vllm.utils.flashinfer.vllm_is_batch_invariant", return_value=False)
+@patch("vllm.envs.VLLM_BATCH_INVARIANT", False)
 @patch(
     "vllm.utils.flashinfer.current_platform.is_device_capability_family",
     return_value=True,
 )
 @patch("vllm.utils.flashinfer.has_nvidia_artifactory", return_value=True)
-def test_supports_sm100_with_artifactory(_art, _cap, _bi):
+def test_supports_sm100_with_artifactory(_art, _cap):
     assert supports_trtllm_attention() is True
 
 
-@patch("vllm.utils.flashinfer.vllm_is_batch_invariant", return_value=False)
+@patch("vllm.envs.VLLM_BATCH_INVARIANT", False)
 @patch(
     "vllm.utils.flashinfer.current_platform.is_device_capability_family",
     return_value=False,
 )
-def test_supports_non_sm100_platform(_cap, _bi):
+def test_supports_non_sm100_platform(_cap):
     assert supports_trtllm_attention() is False
 
 
-@patch("vllm.utils.flashinfer.vllm_is_batch_invariant", return_value=False)
+@patch("vllm.envs.VLLM_BATCH_INVARIANT", False)
 @patch(
     "vllm.utils.flashinfer.current_platform.is_device_capability_family",
     return_value=True,
 )
 @patch("vllm.utils.flashinfer.has_nvidia_artifactory", return_value=False)
-def test_supports_sm100_without_artifactory(_art, _cap, _bi):
+def test_supports_sm100_without_artifactory(_art, _cap):
     assert supports_trtllm_attention() is False
diff --git a/tests/kernels/moe/test_grouped_topk.py b/tests/kernels/moe/test_grouped_topk.py
index 70c7285ac..c58c8474b 100644
--- a/tests/kernels/moe/test_grouped_topk.py
+++ b/tests/kernels/moe/test_grouped_topk.py
@@ -8,7 +8,7 @@ Run `pytest tests/kernels/moe/test_grouped_topk.py`.
 import pytest
 import torch
 
-import vllm.model_executor.layers.batch_invariant as batch_invariant
+import vllm.envs as envs
 from vllm.config import (
     CompilationConfig,
     VllmConfig,
@@ -69,7 +69,7 @@ def test_grouped_topk(
     with set_current_vllm_config(vllm_config), monkeypatch.context() as m:
         m.setenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0")
-        m.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", True)
+        m.setattr(envs, "VLLM_BATCH_INVARIANT", True)
         grouped_topk = GroupedTopk(
             topk=topk,
             renormalize=renormalize,
diff --git a/tests/v1/determinism/conftest.py b/tests/v1/determinism/conftest.py
index bde02bbd0..b682377cf 100644
--- a/tests/v1/determinism/conftest.py
+++ b/tests/v1/determinism/conftest.py
@@ -2,11 +2,11 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import pytest
 
-import vllm.model_executor.layers.batch_invariant as batch_invariant
+import vllm.envs as envs
 
 
 @pytest.fixture(autouse=True)
 def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
     """Automatically enable batch invariant kernel overrides for all tests."""
-    monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", True)
+    monkeypatch.setattr(envs, "VLLM_BATCH_INVARIANT", True)
     monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
diff --git a/tests/v1/determinism/test_batch_invariance.py b/tests/v1/determinism/test_batch_invariance.py
index 11550c190..6465985f0 100644
--- a/tests/v1/determinism/test_batch_invariance.py
+++ b/tests/v1/determinism/test_batch_invariance.py
@@ -15,7 +15,7 @@ from utils import (
     skip_unsupported,
 )
 
-import vllm.model_executor.layers.batch_invariant as batch_invariant
+import vllm.envs as envs
 from vllm import LLM, SamplingParams
 
 IS_DEVICE_CAPABILITY_BELOW_90 = is_device_capability_below_90()
@@ -173,11 +173,9 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
 
     # For batch invariance, disable custom all-reduce to ensure deterministic
     # all-reduce operations (custom all-reduce may not be deterministic)
-    from vllm.model_executor.layers.batch_invariant import (
-        vllm_is_batch_invariant,
-    )
+    import vllm.envs as envs
 
-    disable_custom_ar = vllm_is_batch_invariant()
+    disable_custom_ar = envs.VLLM_BATCH_INVARIANT
 
     if disable_custom_ar:
         print(f"\n{'=' * 80}")
@@ -454,7 +452,7 @@ def test_logprobs_without_batch_invariance_should_fail(
     """
     # CRITICAL: Disable batch invariance for this test
     monkeypatch.setenv("VLLM_BATCH_INVARIANT", "0")
-    monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", False)
+    monkeypatch.setattr(envs, "VLLM_BATCH_INVARIANT", False)
     seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
     random.seed(seed)
     tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
@@ -674,11 +672,9 @@ def test_decode_logprobs_match_prefill_logprobs(
     random.seed(seed)
     tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
 
-    from vllm.model_executor.layers.batch_invariant import (
-        vllm_is_batch_invariant,
-    )
+    import vllm.envs as envs
 
-    disable_custom_ar = vllm_is_batch_invariant()
+    disable_custom_ar = envs.VLLM_BATCH_INVARIANT
 
     if disable_custom_ar:
         print(f"\n{'=' * 80}")
diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py
index dd0d7b9cc..8332b0ec7 100644
--- a/vllm/config/parallel.py
+++ b/vllm/config/parallel.py
@@ -14,9 +14,6 @@ from typing_extensions import Self
 import vllm.envs as envs
 from vllm.config.utils import config
 from vllm.logger import init_logger
-from vllm.model_executor.layers.batch_invariant import (
-    vllm_is_batch_invariant,
-)
 from vllm.platforms import current_platform
 from vllm.utils.network_utils import get_open_ports_list
 from vllm.utils.torch_utils import cuda_device_count_stateless
@@ -786,7 +783,7 @@ class ParallelConfig:
         from vllm.v1.executor import Executor
 
         # Enable batch invariance settings if requested
-        if vllm_is_batch_invariant():
+        if envs.VLLM_BATCH_INVARIANT:
             self.disable_custom_all_reduce = True
 
         if (
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 014400fa9..a8088de63 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -1112,11 +1112,9 @@ class VllmConfig:  # type: ignore[misc]
                 "when cudagraph_mode piecewise cudagraphs is used, "
                 f"cudagraph_mode={self.compilation_config.cudagraph_mode}"
             )
 
-        from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
-
         if (
             self.model_config
-            and vllm_is_batch_invariant()
+            and envs.VLLM_BATCH_INVARIANT
             and not self.model_config.disable_cascade_attn
         ):
             self.model_config.disable_cascade_attn = True
diff --git a/vllm/distributed/device_communicators/all_reduce_utils.py b/vllm/distributed/device_communicators/all_reduce_utils.py
index 9777be5aa..108afa195 100644
--- a/vllm/distributed/device_communicators/all_reduce_utils.py
+++ b/vllm/distributed/device_communicators/all_reduce_utils.py
@@ -19,9 +19,6 @@ import torch.multiprocessing as mp
 import vllm.envs as envs
 from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
 from vllm.logger import init_logger
-from vllm.model_executor.layers.batch_invariant import (
-    vllm_is_batch_invariant,
-)
 from vllm.utils.system_utils import update_environment_variables
 from vllm.utils.torch_utils import cuda_device_count_stateless
@@ -115,7 +112,7 @@ def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor)
         is_symmetric_memory_enabled,
     )
 
-    if vllm_is_batch_invariant():
+    if envs.VLLM_BATCH_INVARIANT:
         return False
 
     if not is_symmetric_memory_enabled():
diff --git a/vllm/distributed/device_communicators/symm_mem.py b/vllm/distributed/device_communicators/symm_mem.py
index c25ff8cf1..3d964c640 100644
--- a/vllm/distributed/device_communicators/symm_mem.py
+++ b/vllm/distributed/device_communicators/symm_mem.py
@@ -5,13 +5,11 @@ import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
 
+import vllm.envs as envs
 from vllm.distributed.device_communicators.all_reduce_utils import (
     SYMM_MEM_ALL_REDUCE_MAX_SIZES,
 )
 from vllm.logger import init_logger
-from vllm.model_executor.layers.batch_invariant import (
-    vllm_is_batch_invariant,
-)
 from vllm.platforms import current_platform
 
 try:
@@ -112,7 +110,7 @@ class SymmMemCommunicator:
             return
         self.force_multimem = force_multimem
         self.disabled = False
-        if vllm_is_batch_invariant():
+        if envs.VLLM_BATCH_INVARIANT:
             self.disabled = True
 
     def should_use_symm_mem(self, inp: torch.Tensor):
diff --git a/vllm/envs.py b/vllm/envs.py
index 2f93b2cb3..5c2a01482 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -74,6 +74,7 @@ if TYPE_CHECKING:
     VLLM_TARGET_DEVICE: str = "cuda"
     VLLM_MAIN_CUDA_VERSION: str = "12.9"
     VLLM_FLOAT32_MATMUL_PRECISION: Literal["highest", "high", "medium"] = "highest"
+    VLLM_BATCH_INVARIANT: bool = False
     MAX_JOBS: str | None = None
     NVCC_THREADS: str | None = None
     VLLM_USE_PRECOMPILED: bool = False
@@ -280,9 +281,6 @@ def disable_compile_cache() -> bool:
 
 
 def use_aot_compile() -> bool:
-    from vllm.model_executor.layers.batch_invariant import (
-        vllm_is_batch_invariant,
-    )
     from vllm.utils.torch_utils import is_torch_equal_or_newer
 
     default_value = (
@@ -292,7 +290,7 @@ def use_aot_compile() -> bool:
     )
 
     return (
-        not vllm_is_batch_invariant()
+        not bool(int(os.getenv("VLLM_BATCH_INVARIANT", "0")))
         and os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1"
     )
@@ -498,6 +496,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
         ["highest", "high", "medium"],
         case_sensitive=False,
     ),
+    # Enable batch-invariant mode: deterministic results regardless of
+    # batch composition. Requires NVIDIA GPU with compute capability >= 9.0.
+    "VLLM_BATCH_INVARIANT": lambda: bool(int(os.getenv("VLLM_BATCH_INVARIANT", "0"))),
     # Maximum number of compilation jobs to run in parallel.
     # By default this is the number of CPUs
     "MAX_JOBS": lambda: os.getenv("MAX_JOBS", None),
diff --git a/vllm/lora/ops/triton_ops/utils.py b/vllm/lora/ops/triton_ops/utils.py
index ac32dd471..0ab52e698 100644
--- a/vllm/lora/ops/triton_ops/utils.py
+++ b/vllm/lora/ops/triton_ops/utils.py
@@ -11,12 +11,11 @@ import torch
 
 from vllm import envs
 from vllm.logger import init_logger
-from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
 from vllm.platforms import current_platform
 from vllm.utils.math_utils import next_power_of_2
 
 logger = init_logger(__name__)
-is_batch_invariant = vllm_is_batch_invariant()
+is_batch_invariant = envs.VLLM_BATCH_INVARIANT
 
 _LORA_A_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
 _LORA_B_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
diff --git a/vllm/model_executor/kernels/linear/scaled_mm/marlin.py b/vllm/model_executor/kernels/linear/scaled_mm/marlin.py
index e79809037..ef3b4e463 100644
--- a/vllm/model_executor/kernels/linear/scaled_mm/marlin.py
+++ b/vllm/model_executor/kernels/linear/scaled_mm/marlin.py
@@ -6,7 +6,6 @@ from collections.abc import Sequence
 import torch
 
 import vllm.envs as envs
-from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     process_fp8_weight_block_strategy,
 )
@@ -42,7 +41,7 @@ class MarlinFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
         # Check if platform supports FP8 Marlin
         if not is_fp8_marlin_supported():
             return False, "FP8 Marlin requires compute capability 7.5 or higher"
-        if vllm_is_batch_invariant():
+        if envs.VLLM_BATCH_INVARIANT:
             return False, "FP8 Marlin not supported for batch invariant execution."
 
         if (
             compute_capability is not None
diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py
index 5516cd329..9b5842594 100644
--- a/vllm/model_executor/layers/attention/attention.py
+++ b/vllm/model_executor/layers/attention/attention.py
@@ -15,7 +15,6 @@ from vllm.model_executor.layers.attention.kv_transfer_utils import (
     maybe_transfer_kv_layer,
 )
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
-from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
 from vllm.model_executor.layers.linear import (
     UnquantizedLinearMethod,
 )
@@ -296,7 +295,7 @@ class Attention(nn.Module, AttentionLayerBase):
         if (
             cache_config is not None
             and cache_config.enable_prefix_caching
-            and vllm_is_batch_invariant()
+            and envs.VLLM_BATCH_INVARIANT
             and (
                 self.attn_backend.get_name() == "FLASHINFER"
                 or self.attn_backend.get_name() == "TRITON_MLA"
diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py
index 9d2fa287d..dcad30a8b 100644
--- a/vllm/model_executor/layers/attention/mla_attention.py
+++ b/vllm/model_executor/layers/attention/mla_attention.py
@@ -227,7 +227,6 @@ from vllm.model_executor.layers.attention.kv_transfer_utils import (
     maybe_transfer_kv_layer,
 )
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
-from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
 )
@@ -372,7 +371,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
         if (
             cache_config is not None
             and cache_config.enable_prefix_caching
-            and vllm_is_batch_invariant()
+            and envs.VLLM_BATCH_INVARIANT
             and (
                 self.attn_backend.get_name() == "TRITON_MLA"
                 or self.attn_backend.get_name() == "FLASHINFER"
@@ -2188,7 +2187,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             # ROCm leverages the upstream flash_attn, which takes a parameter
             # called "return_attn_probs" instead of return_softmax_lse
             kwargs["return_attn_probs"] = return_softmax_lse
-        if vllm_is_batch_invariant():
+        if envs.VLLM_BATCH_INVARIANT:
             kwargs["num_splits"] = 1
 
         attn_out = self.flash_attn_varlen_func(
diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py
index 9f8b1955e..2f9450244 100644
--- a/vllm/model_executor/layers/batch_invariant.py
+++ b/vllm/model_executor/layers/batch_invariant.py
@@ -6,6 +6,7 @@ from typing import Any
 
 import torch
 
+import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
@@ -986,21 +987,6 @@ def enable_batch_invariant_mode():
     torch.backends.cuda.preferred_blas_library(backend="cublaslt")
 
 
-def _read_vllm_batch_invariant() -> bool:
-    val = os.getenv("VLLM_BATCH_INVARIANT", "0")
-    try:
-        return int(val) != 0
-    except ValueError:
-        return False
-
-
-VLLM_BATCH_INVARIANT: bool = _read_vllm_batch_invariant()
-
-
-def vllm_is_batch_invariant() -> bool:
-    return VLLM_BATCH_INVARIANT
-
-
 def override_envs_for_invariance(
     attention_backend: AttentionBackendEnum | None,
 ):
@@ -1059,7 +1045,7 @@ def init_batch_invariance(
     attention_backend: AttentionBackendEnum | None,
 ):
     # this will hit all the csrc overrides as well
-    if vllm_is_batch_invariant():
+    if envs.VLLM_BATCH_INVARIANT:
         override_envs_for_invariance(attention_backend)
         enable_batch_invariant_mode()
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index d5b8feb3c..dccdc52bc 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -14,9 +14,6 @@ import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
-from vllm.model_executor.layers.batch_invariant import (
-    vllm_is_batch_invariant,
-)
 from vllm.model_executor.layers.fused_moe.activation import (
     MoEActivation,
     apply_moe_activation,
@@ -1051,7 +1048,7 @@ def get_moe_configs(
     """
 
     # Avoid optimizing for the batch invariant case. Use default config
-    if vllm_is_batch_invariant():
+    if envs.VLLM_BATCH_INVARIANT:
         return None
 
     # First look up if an optimized configuration is available in the configs
@@ -1232,7 +1229,7 @@ def get_default_config(
     dtype: str | None,
     block_shape: list[int] | None = None,
 ) -> dict[str, int]:
-    if vllm_is_batch_invariant():
+    if envs.VLLM_BATCH_INVARIANT:
         return {
             "BLOCK_SIZE_M": 64,
             "BLOCK_SIZE_N": 64,
diff --git a/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py b/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py
index 5beb782d7..bcabb1f36 100644
--- a/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py
+++ b/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py
@@ -6,11 +6,9 @@ from collections.abc import Callable
 import torch
 
 import vllm._custom_ops as ops
+import vllm.envs as envs
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.distributed.eplb.eplb_state import EplbLayerState
-from vllm.model_executor.layers.batch_invariant import (
-    vllm_is_batch_invariant,
-)
 from vllm.model_executor.layers.fused_moe.config import (
     RoutingMethodType,
     get_routing_method_type,
@@ -160,7 +158,7 @@ def fused_topk_bias(
     ) + e_score_correction_bias.unsqueeze(0)
 
     # For batch invariance, use sorted=True to ensure deterministic expert selection
-    use_sorted = vllm_is_batch_invariant()
+    use_sorted = envs.VLLM_BATCH_INVARIANT
     topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1]
     topk_weights = scores.gather(1, topk_indices)
     if renormalize:
diff --git a/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py b/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py
index 5af2e31b2..1bf141d81 100644
--- a/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py
+++ b/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py
@@ -10,9 +10,6 @@ from vllm import envs as envs
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.distributed.eplb.eplb_state import EplbLayerState
 from vllm.model_executor.custom_op import CustomOp
-from vllm.model_executor.layers.batch_invariant import (
-    vllm_is_batch_invariant,
-)
 from vllm.model_executor.layers.fused_moe.config import (
     RoutingMethodType,
     get_routing_method_type,
@@ -135,7 +132,7 @@ def grouped_topk(
     )  # [n, n_group]
 
     # For batch invariance, use sorted=True to ensure deterministic expert selection
-    use_sorted = vllm_is_batch_invariant()
+    use_sorted = envs.VLLM_BATCH_INVARIANT
     group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[
         1
     ]  # [n, top_k_group]
diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index ecc36556c..7fa804587 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -12,7 +12,6 @@ from vllm.logger import init_logger
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.batch_invariant import (
     rms_norm_batch_invariant,
-    vllm_is_batch_invariant,
 )
 from vllm.platforms import current_platform
@@ -57,7 +56,7 @@ def rms_norm(
 ) -> torch.Tensor:
     from vllm import _custom_ops as ops
 
-    if vllm_is_batch_invariant():
+    if envs.VLLM_BATCH_INVARIANT:
         return rms_norm_batch_invariant(x, weight, variance_epsilon)
     out = torch.empty_like(x)
     ops.rms_norm(
@@ -77,7 +76,7 @@ def fused_add_rms_norm(
 ) -> tuple[torch.Tensor, torch.Tensor]:
     from vllm import _custom_ops as ops
 
-    if vllm_is_batch_invariant():
+    if envs.VLLM_BATCH_INVARIANT:
         return rms_norm_batch_invariant(
             x + residual, weight, variance_epsilon
         ), x + residual
@@ -300,7 +299,7 @@ class RMSNorm(CustomOp):
             and x.is_cuda
             and x.dim() >= 2
             and self.has_weight
-            and not vllm_is_batch_invariant()
+            and not envs.VLLM_BATCH_INVARIANT
             and self.weight.data.dtype == x.dtype
             and self.weight.data.is_contiguous()
         ):
@@ -328,7 +327,7 @@ class RMSNorm(CustomOp):
             and x.dtype == residual.dtype
             and x.dim() >= 2
             and self.has_weight
-            and not vllm_is_batch_invariant()
+            and not envs.VLLM_BATCH_INVARIANT
             and self.weight.data.dtype == x.dtype
             and self.weight.data.is_contiguous()
         ):
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 3d0430c31..9f81f8fa7 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -7,6 +7,7 @@ from abc import abstractmethod
 import torch
 from torch.nn.parameter import Parameter, UninitializedParameter
 
+import vllm.envs as envs
 from vllm.distributed import (
     divide,
     get_tensor_model_parallel_rank,
@@ -19,7 +20,6 @@ from vllm.logger import init_logger
 from vllm.model_executor.custom_op import PluggableLayer
 from vllm.model_executor.layers.batch_invariant import (
     linear_batch_invariant,
-    vllm_is_batch_invariant,
 )
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig,
@@ -223,7 +223,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        if vllm_is_batch_invariant() and current_platform.is_cuda_alike():
+        if envs.VLLM_BATCH_INVARIANT and current_platform.is_cuda_alike():
             return linear_batch_invariant(x, layer.weight, bias)
         return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index d2a23bcf2..d758edd9c 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -7,6 +7,7 @@ import torch
 from torch.nn import Module
 from torch.utils._python_dispatch import TorchDispatchMode
 
+import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
 from vllm._aiter_ops import rocm_aiter_ops
@@ -17,9 +18,6 @@ from vllm.model_executor.kernels.linear import (
 )
 from vllm.model_executor.kernels.linear.scaled_mm import MarlinFP8ScaledMMLinearKernel
 from vllm.model_executor.layers.attention import Attention
-from vllm.model_executor.layers.batch_invariant import (
-    vllm_is_batch_invariant,
-)
 from vllm.model_executor.layers.fused_moe import (
     FusedMoE,
     FusedMoEMethodBase,
@@ -441,7 +439,7 @@ class Fp8LinearMethod(LinearMethodBase):
     ) -> torch.Tensor:
         # if batch invariant mode is enabled, prefer DeepGEMM FP8 path
         # we will use BF16 dequant when DeepGEMM is not supported.
-        if vllm_is_batch_invariant():
+        if envs.VLLM_BATCH_INVARIANT:
             if self.block_quant:
                 assert self.weight_block_size is not None
                 return self.w8a8_block_fp8_linear.apply(
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index a974e2c57..3036c71ad 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -305,9 +305,7 @@ def _flashinfer_fp8_blockscale_gemm_impl(
         )
         return output
 
-    from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
-
-    if vllm_is_batch_invariant():
+    if envs.VLLM_BATCH_INVARIANT:
         return run_deepgemm(input, weight, weight_scale)
 
     condition = input.shape[0] < 32
diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py
index 0db05851b..065a9ca89 100644
--- a/vllm/utils/flashinfer.py
+++ b/vllm/utils/flashinfer.py
@@ -19,9 +19,6 @@ import torch
 
 import vllm.envs as envs
 from vllm.logger import init_logger
-from vllm.model_executor.layers.batch_invariant import (
-    vllm_is_batch_invariant,
-)
 from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
@@ -289,7 +286,7 @@ def supports_trtllm_attention() -> bool:
     NVIDIA artifactory is accessible, and batch-invariant mode is not enabled.
     """
     # Batch-invariant mode disables TRTLLM attention
-    if vllm_is_batch_invariant():
+    if envs.VLLM_BATCH_INVARIANT:
         return False
 
     # Requires SM100 and NVIDIA artifactory to be accessible to download cubins
diff --git a/vllm/v1/attention/backends/fa_utils.py b/vllm/v1/attention/backends/fa_utils.py
index cd8c46d03..a4423b301 100644
--- a/vllm/v1/attention/backends/fa_utils.py
+++ b/vllm/v1/attention/backends/fa_utils.py
@@ -3,8 +3,8 @@
 
 from typing import Any
 
+import vllm.envs as envs
 from vllm.logger import init_logger
-from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
 from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
@@ -114,7 +114,7 @@ def get_flash_attn_version(
 
         # FA4 currently uses batch-shape-dependent scheduling
         # heuristics on SM100+, which breaks batch invariance.
-        if vllm_is_batch_invariant() and fa_version == 4:
+        if envs.VLLM_BATCH_INVARIANT and fa_version == 4:
             logger.warning_once(
                 "Cannot use FA version 4 with batch invariance, "
                 "defaulting to FA version 2.",
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index f3f19f60c..245995be2 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -33,6 +33,7 @@ if is_flash_attn_varlen_func_available():
         get_scheduler_metadata,
         reshape_and_cache_flash,
     )
+import vllm.envs as envs
 from vllm.config import (
     VllmConfig,
     get_current_vllm_config,
@@ -42,9 +43,6 @@ from vllm.config import (
 from vllm.config.cache import CacheDType
 from vllm.distributed.parallel_state import get_dcp_group
 from vllm.logger import init_logger
-from vllm.model_executor.layers.batch_invariant import (
-    vllm_is_batch_invariant,
-)
 from vllm.platforms.interface import DeviceCapability
 from vllm.utils.math_utils import cdiv, round_up
 from vllm.v1.attention.backend import (
@@ -402,7 +400,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
         # we only set num_splits when using cuda graphs.
         max_num_splits = self.max_num_splits
-        if vllm_is_batch_invariant():
+        if envs.VLLM_BATCH_INVARIANT:
             max_num_splits = 1
 
         def schedule(
@@ -601,7 +599,7 @@ class FlashAttentionImpl(AttentionImpl):
                 scope="local",
             )
         # Cache the batch invariant result for use in forward passes
-        self.batch_invariant_enabled = vllm_is_batch_invariant()
+        self.batch_invariant_enabled = envs.VLLM_BATCH_INVARIANT
 
         if is_quantized_kv_cache(self.kv_cache_dtype) and not flash_attn_supports_fp8():
             raise NotImplementedError(
@@ -1124,7 +1122,7 @@ def cascade_attention(
         # s_aux is incorporated into prefix_lse inside the GPU kernel,
         # enabling its effect during the final attention merge.
         s_aux=s_aux,
-        num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
+        num_splits=1 if envs.VLLM_BATCH_INVARIANT else max_num_splits,
     )
 
     descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
@@ -1149,7 +1147,7 @@ def cascade_attention(
         q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
         k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
         v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
-        num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
+        num_splits=1 if envs.VLLM_BATCH_INVARIANT else max_num_splits,
     )
 
     # Merge prefix and suffix outputs, and store the result in output.
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index da97f612a..5b6c198e7 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -28,9 +28,6 @@ from vllm.config import (
 from vllm.config.cache import CacheDType
 from vllm.distributed.parallel_state import get_dcp_group
 from vllm.logger import init_logger
-from vllm.model_executor.layers.batch_invariant import (
-    vllm_is_batch_invariant,
-)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey,
     kFp8StaticTensorSym,
@@ -544,7 +541,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         ) = None  # Wrapper for prefill/append
         self._decode_wrapper = None  # Wrapper for decode (general shape)
 
-        if vllm_is_batch_invariant():
+        if envs.VLLM_BATCH_INVARIANT:
             self.decode_fixed_split_size = 2048
             self.prefill_fixed_split_size = 4096
             self.disable_split_kv = True
@@ -719,7 +716,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
     def _get_workspace_buffer(self):
         if self._workspace_buffer is None:
             buffer_size = envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE
-            if vllm_is_batch_invariant():
+            if envs.VLLM_BATCH_INVARIANT:
                 buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
             self._workspace_buffer = torch.zeros(
                 buffer_size, dtype=torch.uint8, device=self.device
diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py
index d76d7c94e..25bb31ffc 100644
--- a/vllm/v1/attention/backends/flex_attention.py
+++ b/vllm/v1/attention/backends/flex_attention.py
@@ -20,12 +20,10 @@ from torch.nn.attention.flex_attention import (
     or_masks,
 )
 
+import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.config.cache import CacheDType
 from vllm.logger import init_logger
-from vllm.model_executor.layers.batch_invariant import (
-    vllm_is_batch_invariant,
-)
 from vllm.platforms import current_platform
 from vllm.utils.math_utils import cdiv
 from vllm.utils.torch_utils import is_torch_equal_or_newer
@@ -995,7 +993,7 @@ def get_kernel_options(
             return block_size
         return candidate
 
-    if vllm_is_batch_invariant():
+    if envs.VLLM_BATCH_INVARIANT:
         kernel_options["BLOCK_M"] = 16
         kernel_options["BLOCK_N"] = 16
         kernel_options["IS_DIVISIBLE"] = False
diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py
index fc74a16a1..82d463dcd 100644
--- a/vllm/v1/attention/backends/mla/flashattn_mla.py
+++ b/vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -6,6 +6,7 @@ from typing import ClassVar
 
 import torch
 
+import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.config.cache import CacheDType
 from vllm.logger import init_logger
@@ -17,9 +18,6 @@ from vllm.model_executor.layers.attention.mla_attention import (
     MLACommonMetadataBuilder,
     QueryLenSupport,
 )
-from vllm.model_executor.layers.batch_invariant import (
-    vllm_is_batch_invariant,
-)
 from vllm.platforms.interface import DeviceCapability
 from vllm.utils.math_utils import round_up
 from vllm.v1.attention.backend import (
@@ -152,7 +150,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
             vllm_config.attention_config.flash_attn_max_num_splits_for_cuda_graph
         )
 
-        if vllm_is_batch_invariant():
+        if envs.VLLM_BATCH_INVARIANT:
             self.max_num_splits = 1
 
     def _schedule_decode(
@@ -209,7 +207,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
         # we only set num_splits when using cuda graphs.
         max_num_splits = self.max_num_splits
-        if vllm_is_batch_invariant():
+        if envs.VLLM_BATCH_INVARIANT:
             max_num_splits = 1
 
         scheduler_metadata = self._schedule_decode(
diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index f5440d149..76f32a54f 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -6,6 +6,7 @@ from typing import ClassVar
 
 import torch
 
+import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.config.cache import CacheDType
 from vllm.logger import init_logger
@@ -17,9 +18,6 @@ from vllm.model_executor.layers.attention.mla_attention import (
     MLACommonMetadataBuilder,
     QueryLenSupport,
 )
-from vllm.model_executor.layers.batch_invariant import (
-    vllm_is_batch_invariant,
-)
 from vllm.platforms.interface import DeviceCapability
 from vllm.utils.platform_utils import num_compute_units
 from vllm.v1.attention.backend import (
@@ -256,7 +254,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
             q = reshape_query_for_spec_decode(q, num_decodes)
 
         scheduler_metadata = attn_metadata.decode.scheduler_metadata
-        if vllm_is_batch_invariant() and not self.kv_cache_dtype.startswith("fp8"):
+        if envs.VLLM_BATCH_INVARIANT and not self.kv_cache_dtype.startswith("fp8"):
             device = q.device
             dtype = torch.int32
diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py
index b205066d6..3de5be31d 100644
--- a/vllm/v1/attention/backends/mla/triton_mla.py
+++ b/vllm/v1/attention/backends/mla/triton_mla.py
@@ -5,6 +5,7 @@ from typing import ClassVar
 
 import torch
 
+import vllm.envs as envs
 from vllm.config.cache import CacheDType
 from vllm.logger import init_logger
 from vllm.model_executor.layers.attention.mla_attention import (
@@ -12,9 +13,6 @@ from vllm.model_executor.layers.attention.mla_attention import (
     MLACommonImpl,
     MLACommonMetadata,
 )
-from vllm.model_executor.layers.batch_invariant import (
-    vllm_is_batch_invariant,
-)
 from vllm.platforms.interface import DeviceCapability
 from vllm.v1.attention.backend import (
     AttentionLayer,
@@ -151,7 +149,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
         lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device)
 
         # For batch invariance, use only 1 split to ensure deterministic reduction
-        num_kv_splits = 1 if vllm_is_batch_invariant() else 4
+        num_kv_splits = 1 if envs.VLLM_BATCH_INVARIANT else 4
 
         # TODO(lucas) Allocate ahead of time
         attn_logits = torch.empty(
diff --git a/vllm/v1/attention/ops/triton_unified_attention.py b/vllm/v1/attention/ops/triton_unified_attention.py
index 4ddd47c6d..ca5d0e336 100644
--- a/vllm/v1/attention/ops/triton_unified_attention.py
+++ b/vllm/v1/attention/ops/triton_unified_attention.py
@@ -9,13 +9,13 @@
 
 import torch
 
+import vllm.envs as envs
 from vllm.logger import init_logger
-from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
 
 logger = init_logger(__name__)
-is_batch_invariant = vllm_is_batch_invariant()
+is_batch_invariant = envs.VLLM_BATCH_INVARIANT
 
 float8_info = torch.finfo(current_platform.fp8_dtype())
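Note (not part of the patch): a minimal sketch of how the new flag is read and how tests can override it, mirroring the vllm/envs.py and tests/v1/determinism/conftest.py changes above. The helper name read_batch_invariant_flag and the fixture name enable_batch_invariant are hypothetical; the env var name and the monkeypatch pattern come from this diff.

import os

import pytest

import vllm.envs as envs


def read_batch_invariant_flag() -> bool:
    # Same parsing rule as the new envs.py entry: unset or "0" means disabled.
    return bool(int(os.getenv("VLLM_BATCH_INVARIANT", "0")))


@pytest.fixture
def enable_batch_invariant(monkeypatch: pytest.MonkeyPatch):
    # Patch both the cached module attribute and the raw environment variable,
    # as the determinism conftest in this diff does, so that code reading either
    # sees batch-invariant mode as enabled.
    monkeypatch.setattr(envs, "VLLM_BATCH_INVARIANT", True)
    monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")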