[Bugfix] Register VLLM_BATCH_INVARIANT in envs.py to fix spurious unknown env var warning (#35007)

Signed-off-by: Ranran <1012869439@qq.com>
Signed-off-by: Ranran <hzz5361@psu.edu>
Signed-off-by: ran <hzz5361@psu.edu>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Ranran authored 2026-03-23 17:31:14 -05:00 · committed by GitHub
parent e85f8f0932 · commit dc6908ac6a
30 changed files with 70 additions and 130 deletions
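Note: the hunks below cover only the call sites, which switch from `vllm_is_batch_invariant()` to `envs.VLLM_BATCH_INVARIANT`; the registration in `envs.py` that the commit title refers to is not shown in this extract. vLLM resolves `envs.<NAME>` lazily through a module-level `__getattr__` over a registry of environment variables, so registering the name is what silences the "unknown env var" warning. A minimal sketch of that pattern, assuming a boolean flag that defaults to off (the parsing and default shown here are illustrative, not the committed code):

```python
# Sketch of the vllm/envs.py registration pattern (illustrative, not the exact diff).
import os
from collections.abc import Callable
from typing import Any

environment_variables: dict[str, Callable[[], Any]] = {
    # Registering the name here both documents it and stops vLLM from
    # flagging VLLM_BATCH_INVARIANT as an unknown VLLM_* variable.
    "VLLM_BATCH_INVARIANT": lambda: bool(int(os.getenv("VLLM_BATCH_INVARIANT", "0"))),
}

def __getattr__(name: str) -> Any:
    # PEP 562 module-level __getattr__: `envs.VLLM_BATCH_INVARIANT` re-reads
    # the process environment on every attribute access.
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```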

View File

@@ -3,8 +3,8 @@
 from typing import Any

+import vllm.envs as envs
 from vllm.logger import init_logger
-from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
 from vllm.platforms import current_platform

 logger = init_logger(__name__)
@@ -114,7 +114,7 @@ def get_flash_attn_version(
         # FA4 currently uses batch-shape-dependent scheduling
         # heuristics on SM100+, which breaks batch invariance.
-        if vllm_is_batch_invariant() and fa_version == 4:
+        if envs.VLLM_BATCH_INVARIANT and fa_version == 4:
             logger.warning_once(
                 "Cannot use FA version 4 with batch invariance, "
                 "defaulting to FA version 2.",

View File

@@ -33,6 +33,7 @@ if is_flash_attn_varlen_func_available():
         get_scheduler_metadata,
         reshape_and_cache_flash,
     )

+import vllm.envs as envs
 from vllm.config import (
     VllmConfig,
     get_current_vllm_config,
@@ -42,9 +43,6 @@ from vllm.config import (
 from vllm.config.cache import CacheDType
 from vllm.distributed.parallel_state import get_dcp_group
 from vllm.logger import init_logger
-from vllm.model_executor.layers.batch_invariant import (
-    vllm_is_batch_invariant,
-)
 from vllm.platforms.interface import DeviceCapability
 from vllm.utils.math_utils import cdiv, round_up
 from vllm.v1.attention.backend import (
@@ -402,7 +400,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetadata]):
         # we only set num_splits when using cuda graphs.
         max_num_splits = self.max_num_splits
-        if vllm_is_batch_invariant():
+        if envs.VLLM_BATCH_INVARIANT:
             max_num_splits = 1

         def schedule(
@@ -601,7 +599,7 @@ class FlashAttentionImpl(AttentionImpl):
             scope="local",
         )

         # Cache the batch invariant result for use in forward passes
-        self.batch_invariant_enabled = vllm_is_batch_invariant()
+        self.batch_invariant_enabled = envs.VLLM_BATCH_INVARIANT

         if is_quantized_kv_cache(self.kv_cache_dtype) and not flash_attn_supports_fp8():
             raise NotImplementedError(
@@ -1124,7 +1122,7 @@ def cascade_attention(
         # s_aux is incorporated into prefix_lse inside the GPU kernel,
         # enabling its effect during the final attention merge.
         s_aux=s_aux,
-        num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
+        num_splits=1 if envs.VLLM_BATCH_INVARIANT else max_num_splits,
     )

     descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
@@ -1149,7 +1147,7 @@ def cascade_attention(
         q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
         k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
         v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
-        num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
+        num_splits=1 if envs.VLLM_BATCH_INVARIANT else max_num_splits,
     )

     # Merge prefix and suffix outputs, and store the result in output.

View File

@@ -28,9 +28,6 @@ from vllm.config import (
 from vllm.config.cache import CacheDType
 from vllm.distributed.parallel_state import get_dcp_group
 from vllm.logger import init_logger
-from vllm.model_executor.layers.batch_invariant import (
-    vllm_is_batch_invariant,
-)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey,
     kFp8StaticTensorSym,
@@ -544,7 +541,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         ) = None  # Wrapper for prefill/append
         self._decode_wrapper = None  # Wrapper for decode (general shape)

-        if vllm_is_batch_invariant():
+        if envs.VLLM_BATCH_INVARIANT:
             self.decode_fixed_split_size = 2048
             self.prefill_fixed_split_size = 4096
             self.disable_split_kv = True
@@ -719,7 +716,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
     def _get_workspace_buffer(self):
         if self._workspace_buffer is None:
             buffer_size = envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE
-            if vllm_is_batch_invariant():
+            if envs.VLLM_BATCH_INVARIANT:
                 buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
             self._workspace_buffer = torch.zeros(
                 buffer_size, dtype=torch.uint8, device=self.device

View File

@@ -20,12 +20,10 @@ from torch.nn.attention.flex_attention import (
     or_masks,
 )

+import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.config.cache import CacheDType
 from vllm.logger import init_logger
-from vllm.model_executor.layers.batch_invariant import (
-    vllm_is_batch_invariant,
-)
 from vllm.platforms import current_platform
 from vllm.utils.math_utils import cdiv
 from vllm.utils.torch_utils import is_torch_equal_or_newer
@@ -995,7 +993,7 @@ def get_kernel_options(
             return block_size
         return candidate

-    if vllm_is_batch_invariant():
+    if envs.VLLM_BATCH_INVARIANT:
         kernel_options["BLOCK_M"] = 16
         kernel_options["BLOCK_N"] = 16
         kernel_options["IS_DIVISIBLE"] = False

View File

@@ -6,6 +6,7 @@ from typing import ClassVar

 import torch

+import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.config.cache import CacheDType
 from vllm.logger import init_logger
@@ -17,9 +18,6 @@ from vllm.model_executor.layers.attention.mla_attention import (
     MLACommonMetadataBuilder,
     QueryLenSupport,
 )
-from vllm.model_executor.layers.batch_invariant import (
-    vllm_is_batch_invariant,
-)
 from vllm.platforms.interface import DeviceCapability
 from vllm.utils.math_utils import round_up
 from vllm.v1.attention.backend import (
@@ -152,7 +150,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]):
             vllm_config.attention_config.flash_attn_max_num_splits_for_cuda_graph
         )

-        if vllm_is_batch_invariant():
+        if envs.VLLM_BATCH_INVARIANT:
             self.max_num_splits = 1

     def _schedule_decode(
@@ -209,7 +207,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]):
         # we only set num_splits when using cuda graphs.
         max_num_splits = self.max_num_splits
-        if vllm_is_batch_invariant():
+        if envs.VLLM_BATCH_INVARIANT:
             max_num_splits = 1

         scheduler_metadata = self._schedule_decode(

View File

@@ -6,6 +6,7 @@ from typing import ClassVar

 import torch

+import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.config.cache import CacheDType
 from vllm.logger import init_logger
@@ -17,9 +18,6 @@ from vllm.model_executor.layers.attention.mla_attention import (
     MLACommonMetadataBuilder,
     QueryLenSupport,
 )
-from vllm.model_executor.layers.batch_invariant import (
-    vllm_is_batch_invariant,
-)
 from vllm.platforms.interface import DeviceCapability
 from vllm.utils.platform_utils import num_compute_units
 from vllm.v1.attention.backend import (
@@ -256,7 +254,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
         q = reshape_query_for_spec_decode(q, num_decodes)
         scheduler_metadata = attn_metadata.decode.scheduler_metadata

-        if vllm_is_batch_invariant() and not self.kv_cache_dtype.startswith("fp8"):
+        if envs.VLLM_BATCH_INVARIANT and not self.kv_cache_dtype.startswith("fp8"):
             device = q.device
             dtype = torch.int32

View File

@@ -5,6 +5,7 @@ from typing import ClassVar

 import torch

+import vllm.envs as envs
 from vllm.config.cache import CacheDType
 from vllm.logger import init_logger
 from vllm.model_executor.layers.attention.mla_attention import (
@@ -12,9 +13,6 @@ from vllm.model_executor.layers.attention.mla_attention import (
     MLACommonImpl,
     MLACommonMetadata,
 )
-from vllm.model_executor.layers.batch_invariant import (
-    vllm_is_batch_invariant,
-)
 from vllm.platforms.interface import DeviceCapability
 from vllm.v1.attention.backend import (
     AttentionLayer,
@@ -151,7 +149,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
         lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device)

         # For batch invariance, use only 1 split to ensure deterministic reduction
-        num_kv_splits = 1 if vllm_is_batch_invariant() else 4
+        num_kv_splits = 1 if envs.VLLM_BATCH_INVARIANT else 4

         # TODO(lucas) Allocate ahead of time
         attn_logits = torch.empty(

View File

@@ -9,13 +9,13 @@
 import torch

+import vllm.envs as envs
 from vllm.logger import init_logger
-from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton

 logger = init_logger(__name__)

-is_batch_invariant = vllm_is_batch_invariant()
+is_batch_invariant = envs.VLLM_BATCH_INVARIANT
 float8_info = torch.finfo(current_platform.fp8_dtype())
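A design note on this last hunk: with the lazy `__getattr__` pattern sketched near the top, `envs.VLLM_BATCH_INVARIANT` inside a function re-reads the environment on each call, while this module-level `is_batch_invariant = envs.VLLM_BATCH_INVARIANT` snapshots the value once at import time, matching the old behavior of calling `vllm_is_batch_invariant()` at module scope. A small self-contained illustration of the difference (a toy stand-in, not vLLM code):

```python
import os

# Toy stand-in for the lazy envs module: each attribute access re-reads the env.
class _Envs:
    @property
    def VLLM_BATCH_INVARIANT(self) -> bool:
        return bool(int(os.getenv("VLLM_BATCH_INVARIANT", "0")))

envs = _Envs()

snapshot = envs.VLLM_BATCH_INVARIANT       # like the module-level assignment above
os.environ["VLLM_BATCH_INVARIANT"] = "1"   # flag flipped after the snapshot
print(snapshot)                    # False: the snapshot does not track later changes
print(envs.VLLM_BATCH_INVARIANT)   # True: attribute access sees the current value
```

So the flag must be set before the module holding the snapshot is first imported; flipping it afterwards only affects call sites that read `envs.VLLM_BATCH_INVARIANT` directly.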