[Bugfix] Register VLLM_BATCH_INVARIANT in envs.py to fix spurious unknown env var warning (#35007)
Signed-off-by: Ranran <1012869439@qq.com>
Signed-off-by: Ranran <hzz5361@psu.edu>
Signed-off-by: ran <hzz5361@psu.edu>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
@@ -3,8 +3,8 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -114,7 +114,7 @@ def get_flash_attn_version(
|
||||
|
||||
# FA4 currently uses batch-shape-dependent scheduling
|
||||
# heuristics on SM100+, which breaks batch invariance.
|
||||
if vllm_is_batch_invariant() and fa_version == 4:
|
||||
if envs.VLLM_BATCH_INVARIANT and fa_version == 4:
|
||||
logger.warning_once(
|
||||
"Cannot use FA version 4 with batch invariance, "
|
||||
"defaulting to FA version 2.",
|
||||
|
||||
@@ -33,6 +33,7 @@ if is_flash_attn_varlen_func_available():
|
||||
get_scheduler_metadata,
|
||||
reshape_and_cache_flash,
|
||||
)
|
||||
import vllm.envs as envs
|
||||
from vllm.config import (
|
||||
VllmConfig,
|
||||
get_current_vllm_config,
|
||||
@@ -42,9 +43,6 @@ from vllm.config import (
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.distributed.parallel_state import get_dcp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.utils.math_utils import cdiv, round_up
|
||||
from vllm.v1.attention.backend import (
|
||||
@@ -402,7 +400,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
|
||||
# we only set num_splits when using cuda graphs.
|
||||
max_num_splits = self.max_num_splits
|
||||
|
||||
if vllm_is_batch_invariant():
|
||||
if envs.VLLM_BATCH_INVARIANT:
|
||||
max_num_splits = 1
|
||||
|
||||
def schedule(
|
||||
@@ -601,7 +599,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
scope="local",
|
||||
)
|
||||
# Cache the batch invariant result for use in forward passes
|
||||
self.batch_invariant_enabled = vllm_is_batch_invariant()
|
||||
self.batch_invariant_enabled = envs.VLLM_BATCH_INVARIANT
|
||||
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype) and not flash_attn_supports_fp8():
|
||||
raise NotImplementedError(
|
||||
@@ -1124,7 +1122,7 @@ def cascade_attention(
|
||||
# s_aux is incorporated into prefix_lse inside the GPU kernel,
|
||||
# enabling its effect during the final attention merge.
|
||||
s_aux=s_aux,
|
||||
num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
|
||||
num_splits=1 if envs.VLLM_BATCH_INVARIANT else max_num_splits,
|
||||
)
|
||||
|
||||
descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
|
||||
@@ -1149,7 +1147,7 @@ def cascade_attention(
|
||||
q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
|
||||
k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
|
||||
v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
|
||||
num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
|
||||
num_splits=1 if envs.VLLM_BATCH_INVARIANT else max_num_splits,
|
||||
)
|
||||
|
||||
# Merge prefix and suffix outputs, and store the result in output.
|
||||
|
||||
@@ -28,9 +28,6 @@ from vllm.config import (
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.distributed.parallel_state import get_dcp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8StaticTensorSym,
|
||||
@@ -544,7 +541,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
) = None # Wrapper for prefill/append
|
||||
self._decode_wrapper = None # Wrapper for decode (general shape)
|
||||
|
||||
if vllm_is_batch_invariant():
|
||||
if envs.VLLM_BATCH_INVARIANT:
|
||||
self.decode_fixed_split_size = 2048
|
||||
self.prefill_fixed_split_size = 4096
|
||||
self.disable_split_kv = True
|
||||
@@ -719,7 +716,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
||||
def _get_workspace_buffer(self):
|
||||
if self._workspace_buffer is None:
|
||||
buffer_size = envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE
|
||||
if vllm_is_batch_invariant():
|
||||
if envs.VLLM_BATCH_INVARIANT:
|
||||
buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
|
||||
self._workspace_buffer = torch.zeros(
|
||||
buffer_size, dtype=torch.uint8, device=self.device
|
||||
|
||||
@@ -20,12 +20,10 @@ from torch.nn.attention.flex_attention import (
|
||||
or_masks,
|
||||
)
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
@@ -995,7 +993,7 @@ def get_kernel_options(
|
||||
return block_size
|
||||
return candidate
|
||||
|
||||
if vllm_is_batch_invariant():
|
||||
if envs.VLLM_BATCH_INVARIANT:
|
||||
kernel_options["BLOCK_M"] = 16
|
||||
kernel_options["BLOCK_N"] = 16
|
||||
kernel_options["IS_DIVISIBLE"] = False
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
@@ -17,9 +18,6 @@ from vllm.model_executor.layers.attention.mla_attention import (
|
||||
MLACommonMetadataBuilder,
|
||||
QueryLenSupport,
|
||||
)
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.utils.math_utils import round_up
|
||||
from vllm.v1.attention.backend import (
|
||||
@@ -152,7 +150,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
|
||||
vllm_config.attention_config.flash_attn_max_num_splits_for_cuda_graph
|
||||
)
|
||||
|
||||
if vllm_is_batch_invariant():
|
||||
if envs.VLLM_BATCH_INVARIANT:
|
||||
self.max_num_splits = 1
|
||||
|
||||
def _schedule_decode(
|
||||
@@ -209,7 +207,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
|
||||
# we only set num_splits when using cuda graphs.
|
||||
max_num_splits = self.max_num_splits
|
||||
|
||||
if vllm_is_batch_invariant():
|
||||
if envs.VLLM_BATCH_INVARIANT:
|
||||
max_num_splits = 1
|
||||
|
||||
scheduler_metadata = self._schedule_decode(
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
@@ -17,9 +18,6 @@ from vllm.model_executor.layers.attention.mla_attention import (
|
||||
MLACommonMetadataBuilder,
|
||||
QueryLenSupport,
|
||||
)
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.utils.platform_utils import num_compute_units
|
||||
from vllm.v1.attention.backend import (
|
||||
@@ -256,7 +254,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||
q = reshape_query_for_spec_decode(q, num_decodes)
|
||||
|
||||
scheduler_metadata = attn_metadata.decode.scheduler_metadata
|
||||
if vllm_is_batch_invariant() and not self.kv_cache_dtype.startswith("fp8"):
|
||||
if envs.VLLM_BATCH_INVARIANT and not self.kv_cache_dtype.startswith("fp8"):
|
||||
device = q.device
|
||||
dtype = torch.int32
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention.mla_attention import (
|
||||
@@ -12,9 +13,6 @@ from vllm.model_executor.layers.attention.mla_attention import (
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
)
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionLayer,
|
||||
@@ -151,7 +149,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device)
|
||||
|
||||
# For batch invariance, use only 1 split to ensure deterministic reduction
|
||||
num_kv_splits = 1 if vllm_is_batch_invariant() else 4
|
||||
num_kv_splits = 1 if envs.VLLM_BATCH_INVARIANT else 4
|
||||
|
||||
# TODO(lucas) Allocate ahead of time
|
||||
attn_logits = torch.empty(
|
||||
|
||||
@@ -9,13 +9,13 @@
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
logger = init_logger(__name__)
|
||||
is_batch_invariant = vllm_is_batch_invariant()
|
||||
is_batch_invariant = envs.VLLM_BATCH_INVARIANT
|
||||
float8_info = torch.finfo(current_platform.fp8_dtype())
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user