[Refactor] Remove Unused Func in Batch Invariant (#28881)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Author: Wentao Ye
Date: 2025-11-17 23:29:29 -05:00
Committed by: GitHub
Parent: d0a73620cc
Commit: 3ddcf46011


@@ -1,8 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import contextlib
 import os
-from collections import namedtuple
 from collections.abc import Callable
 from functools import cache
 from typing import Any
@@ -725,10 +723,6 @@ _original_cublas_workspace_cfg = None
 _original_cublaslt_workspace_size = None
 
 
-def is_batch_invariant_mode_enabled():
-    return _batch_invariant_MODE
-
-
 def enable_batch_invariant_mode():
     global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
     global _original_fp16_reduction_precision, _original_bf16_reduction_precision
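The helper removed above merely echoed the module-level _batch_invariant_MODE flag; callers are served by the cached environment check vllm_is_batch_invariant(), which survives the refactor (see the last hunk below). A minimal sketch of that surviving pattern; the exact truthy spellings accepted are an assumption:

    import os
    from functools import cache


    @cache
    def vllm_is_batch_invariant() -> bool:
        # Read the opt-in flag once and memoize the result; subsequent
        # calls never touch os.environ again.
        # Assumption: which spellings count as "on" is illustrative only.
        return os.getenv("VLLM_BATCH_INVARIANT", "0").strip().lower() in ("1", "true", "yes")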
@@ -791,73 +785,6 @@ def enable_batch_invariant_mode():
     torch.backends.cuda.preferred_blas_library(backend="cublaslt")
 
 
-def disable_batch_invariant_mode():
-    global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
-    global _original_fp16_reduction_precision, _original_bf16_reduction_precision
-    global _original_cublas_workspace_cfg, _original_cublaslt_workspace_size
-
-    if not _batch_invariant_MODE:
-        return
-
-    if _batch_invariant_LIB is not None:
-        _batch_invariant_LIB._destroy()
-    if _original_torch_bmm is not None:
-        torch.bmm = _original_torch_bmm
-        _original_torch_bmm = None
-
-    if _original_bf16_reduction_precision is not None:
-        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
-            _original_bf16_reduction_precision
-        )
-        _original_bf16_reduction_precision = None
-    if _original_fp16_reduction_precision is not None:
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
-            _original_fp16_reduction_precision
-        )
-        _original_fp16_reduction_precision = None
-
-    torch.backends.cuda.preferred_blas_library(backend="default")
-
-    if not is_torch_equal_or_newer("2.10.0.dev"):
-        # Set cublas env vars to previous results. If previous results are None,
-        # that means the env vars were not set, so we should remove them.
-        if _original_cublas_workspace_cfg:
-            os.environ["CUBLAS_WORKSPACE_CONFIG"] = _original_cublas_workspace_cfg
-        elif "CUBLAS_WORKSPACE_CONFIG" in os.environ:
-            del os.environ["CUBLAS_WORKSPACE_CONFIG"]
-        if _original_cublaslt_workspace_size:
-            os.environ["CUBLASLT_WORKSPACE_SIZE"] = _original_cublaslt_workspace_size
-        elif "CUBLASLT_WORKSPACE_SIZE" in os.environ:
-            del os.environ["CUBLASLT_WORKSPACE_SIZE"]
-        _original_cublas_workspace_cfg = None
-        _original_cublaslt_workspace_size = None
-
-    _batch_invariant_MODE = False
-    _batch_invariant_LIB = None
-
-
-@contextlib.contextmanager
-def set_batch_invariant_mode(enabled: bool = True):
-    global _batch_invariant_MODE, _batch_invariant_LIB
-    old_data = (_batch_invariant_MODE, _batch_invariant_LIB)
-    if enabled:
-        enable_batch_invariant_mode()
-    else:
-        disable_batch_invariant_mode()
-
-    yield
-    if _batch_invariant_LIB is not None:
-        _batch_invariant_LIB._destroy()
-    _batch_invariant_MODE, _batch_invariant_LIB = old_data
-
-
-AttentionBlockSize = namedtuple("AttentionBlockSize", ["block_m", "block_n"])
-
-
-def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
-    return AttentionBlockSize(block_m=16, block_n=16)
-
-
 @cache
 def vllm_is_batch_invariant():
     env_key = "VLLM_BATCH_INVARIANT"
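
For context, the deleted disable_batch_invariant_mode undid everything enable_batch_invariant_mode had changed, and its subtlest part was restoring CUBLAS_WORKSPACE_CONFIG and CUBLASLT_WORKSPACE_SIZE: a saved value of None records that the variable was absent before, so the override must be deleted rather than rewritten. A generic sketch of that save/override/restore pattern; the helper names and the ":4096:8" value are illustrative, not part of the removed code:

    import os


    def save_and_set_env(key: str, value: str) -> str | None:
        # Remember the prior value; None records that the variable was absent.
        original = os.environ.get(key)
        os.environ[key] = value
        return original


    def restore_env(key: str, original: str | None) -> None:
        # Put the prior value back, or remove the override entirely if the
        # variable was never set, instead of leaving a stale setting behind.
        if original is not None:
            os.environ[key] = original
        else:
            os.environ.pop(key, None)


    prev = save_and_set_env("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    try:
        ...  # run work that needs the fixed cuBLAS workspace
    finally:
        restore_env("CUBLAS_WORKSPACE_CONFIG", prev)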
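
Likewise, the deleted set_batch_invariant_mode context manager snapshotted the two globals, switched the mode on entry, and restored the snapshot after yield. Note that the removed version ran its restore after a bare yield, so an exception inside the with body would have skipped it. A hedged sketch of the same idea hardened with try/finally; the globals are reused from the diff for illustration and the enable/disable calls are stubbed out:

    import contextlib

    _batch_invariant_MODE = False
    _batch_invariant_LIB = None  # stand-in for the registered op-library handle


    @contextlib.contextmanager
    def set_batch_invariant_mode(enabled: bool = True):
        global _batch_invariant_MODE, _batch_invariant_LIB
        old_data = (_batch_invariant_MODE, _batch_invariant_LIB)
        _batch_invariant_MODE = enabled  # the real code called enable/disable here
        try:
            yield
        finally:
            # Restore the snapshot even if the body raised.
            _batch_invariant_MODE, _batch_invariant_LIB = old_data


    with set_batch_invariant_mode(True):
        pass  # run batch-invariant work here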