[ez] Remove checks for torch version < 2.8 (#33209)

Signed-off-by: angelayi <yiangela7@gmail.com>
This commit is contained in:
Angela Yi
2026-01-28 13:03:56 -08:00
committed by GitHub
parent 59bcc5b6f2
commit 4197168ea5
11 changed files with 30 additions and 139 deletions

View File

@@ -8,7 +8,7 @@ from torch._ops import OpOverload
import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import (
rocm_aiter_sparse_attn_indexer,
rocm_aiter_sparse_attn_indexer_fake,
@@ -1015,12 +1015,6 @@ class rocm_aiter_ops:
def register_ops_once() -> None:
global _OPS_REGISTERED
if not _OPS_REGISTERED:
tags = (
tuple()
if is_torch_equal_or_newer("2.7.0")
else (torch.Tag.needs_fixed_stride_order,)
)
# register all the custom ops here
direct_register_custom_op(
op_name="rocm_aiter_asm_moe_tkw1",
@@ -1075,7 +1069,6 @@ class rocm_aiter_ops:
op_func=_rocm_aiter_mla_decode_fwd_impl,
mutates_args=["o"],
fake_impl=_rocm_aiter_mla_decode_fwd_fake,
tags=tags,
)
direct_register_custom_op(

View File

@@ -33,7 +33,6 @@ from vllm.logger import init_logger
from vllm.logging_utils import lazy
from vllm.platforms import current_platform
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.torch_utils import is_torch_equal_or_newer
from .compiler_interface import (
CompilerInterface,
@@ -94,10 +93,8 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
if compilation_config.backend == "inductor":
# Use standalone compile only if requested, version is new enough,
# and the symbol actually exists in this PyTorch build.
if (
envs.VLLM_USE_STANDALONE_COMPILE
and is_torch_equal_or_newer("2.8.0.dev")
and hasattr(torch._inductor, "standalone_compile")
if envs.VLLM_USE_STANDALONE_COMPILE and hasattr(
torch._inductor, "standalone_compile"
):
logger.debug("Using InductorStandaloneAdaptor")
return InductorStandaloneAdaptor(

View File

@@ -501,20 +501,19 @@ class InductorAdaptor(CompilerInterface):
# get hit.
# TODO(zou3519): we're going to replace this all with
# standalone_compile sometime.
if is_torch_equal_or_newer("2.6"):
stack.enter_context(
torch._inductor.config.patch(fx_graph_remote_cache=False)
)
# InductorAdaptor (unfortunately) requires AOTAutogradCache
# to be turned off to run. It will fail to acquire the hash_str
# and error if not.
# StandaloneInductorAdaptor (PyTorch 2.8+) fixes this problem.
stack.enter_context(
torch._functorch.config.patch(enable_autograd_cache=False)
)
stack.enter_context(
torch._functorch.config.patch(enable_remote_autograd_cache=False)
)
stack.enter_context(
torch._inductor.config.patch(fx_graph_remote_cache=False)
)
# InductorAdaptor (unfortunately) requires AOTAutogradCache
# to be turned off to run. It will fail to acquire the hash_str
# and error if not.
# StandaloneInductorAdaptor (PyTorch 2.8+) fixes this problem.
stack.enter_context(
torch._functorch.config.patch(enable_autograd_cache=False)
)
stack.enter_context(
torch._functorch.config.patch(enable_remote_autograd_cache=False)
)
compiled_graph = compile_fx(
graph,

View File

@@ -7,12 +7,11 @@ import inspect
import os
import sys
from collections.abc import Callable, Generator
from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload
from typing import TYPE_CHECKING, Any, TypeVar, overload
from unittest.mock import patch
import torch
import torch.nn as nn
from packaging import version
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
import vllm.envs as envs
@@ -540,7 +539,6 @@ def _support_torch_compile(
torch._dynamo.config.patch(**dynamo_config_patches),
maybe_use_cudagraph_partition_wrapper(self.vllm_config),
torch.fx.experimental._config.patch(**fx_config_patches),
_torch27_patch_tensor_subclasses(),
torch._inductor.config.patch(**inductor_config_patches),
):
use_aot_compile = envs.VLLM_USE_AOT_COMPILE
@@ -647,42 +645,3 @@ def maybe_use_cudagraph_partition_wrapper(
and compilation_config.use_inductor_graph_partition
):
torch._inductor.utils.set_customized_partition_wrappers(None)
@contextlib.contextmanager
def _torch27_patch_tensor_subclasses() -> Generator[None, None, None]:
"""
Add support for using tensor subclasses (ie `BasevLLMParameter`, ect) when
using torch 2.7.0. This enables using weight_loader_v2 and the use of
`BasevLLMParameters` without having to replace them with regular tensors
before `torch.compile`-time.
"""
from vllm.model_executor.parameter import (
BasevLLMParameter,
ModelWeightParameter,
RowvLLMParameter,
_ColumnvLLMParameter,
)
def return_false(*args: Any, **kwargs: Any) -> Literal[False]:
return False
if version.parse("2.7") <= version.parse(torch.__version__) < version.parse("2.8"):
yield
return
with (
torch._dynamo.config.patch(
"traceable_tensor_subclasses",
[
BasevLLMParameter,
ModelWeightParameter,
_ColumnvLLMParameter,
RowvLLMParameter,
],
),
patch(
"torch._dynamo.variables.torch.can_dispatch_torch_function", return_false
),
):
yield

View File

@@ -16,18 +16,10 @@ import torch
from torch import fx
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily
from vllm.utils.torch_utils import is_torch_equal_or_newer
if TYPE_CHECKING:
from vllm.config.utils import Range
if is_torch_equal_or_newer("2.6"):
from torch._inductor.custom_graph_pass import CustomGraphPass
else:
# CustomGraphPass is not present in 2.5 or lower, import our version
from .torch25_custom_graph_pass import (
Torch25CustomGraphPass as CustomGraphPass,
)
from torch._inductor.custom_graph_pass import CustomGraphPass
_pass_context = None
P = ParamSpec("P")

View File

@@ -777,10 +777,9 @@ class CompilationConfig:
# and it is not yet a priority. RFC here:
# https://github.com/vllm-project/vllm/issues/14703
if is_torch_equal_or_newer("2.6"):
KEY = "enable_auto_functionalized_v2"
if KEY not in self.inductor_compile_config:
self.inductor_compile_config[KEY] = False
KEY = "enable_auto_functionalized_v2"
if KEY not in self.inductor_compile_config:
self.inductor_compile_config[KEY] = False
for k, v in self.inductor_passes.items():
if not isinstance(v, str):

View File

@@ -31,7 +31,6 @@ import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils.network_utils import get_tcp_uri
from vllm.utils.system_utils import suppress_stdout
from vllm.utils.torch_utils import is_torch_equal_or_newer
logger = init_logger(__name__)
@@ -429,20 +428,11 @@ def init_gloo_process_group(
different torch versions.
"""
with suppress_stdout():
if is_torch_equal_or_newer("2.6"):
pg = ProcessGroup(
prefix_store,
group_rank,
group_size,
)
else:
options = ProcessGroup.Options(backend="gloo")
pg = ProcessGroup(
prefix_store,
group_rank,
group_size,
options,
)
pg = ProcessGroup(
prefix_store,
group_rank,
group_size,
)
from torch.distributed.distributed_c10d import ProcessGroupGloo
backend_class = ProcessGroupGloo(
@@ -450,9 +440,7 @@ def init_gloo_process_group(
)
backend_type = ProcessGroup.BackendType.GLOO
device = torch.device("cpu")
if is_torch_equal_or_newer("2.6"):
# _set_default_backend is supported in torch >= 2.6
pg._set_default_backend(backend_type)
pg._set_default_backend(backend_type)
backend_class._set_sequence_number_for_group()
pg._register_backend(device, backend_type, backend_class)
@@ -534,12 +522,5 @@ def stateless_destroy_torch_distributed_process_group(pg: ProcessGroup) -> None:
Destroy ProcessGroup returned by
stateless_init_torch_distributed_process_group().
"""
if is_torch_equal_or_newer("2.7"):
pg.shutdown()
else:
# Lazy import for non-CUDA backends.
from torch.distributed.distributed_c10d import _shutdown_backend
_shutdown_backend(pg)
pg.shutdown()
_unregister_process_group(pg.group_name)

View File

@@ -52,7 +52,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
from vllm.utils.torch_utils import direct_register_custom_op
logger = init_logger(__name__)
@@ -1406,11 +1406,6 @@ direct_register_custom_op(
op_func=inplace_fused_experts,
mutates_args=["hidden_states"],
fake_impl=inplace_fused_experts_fake,
tags=(
()
if is_torch_equal_or_newer("2.7.0")
else (torch.Tag.needs_fixed_stride_order,)
),
)
@@ -1501,11 +1496,6 @@ direct_register_custom_op(
op_name="outplace_fused_experts",
op_func=outplace_fused_experts,
fake_impl=outplace_fused_experts_fake,
tags=(
()
if is_torch_equal_or_newer("2.7.0")
else (torch.Tag.needs_fixed_stride_order,)
),
)

View File

@@ -56,7 +56,6 @@ from vllm.scalar_type import scalar_types
from vllm.utils.flashinfer import has_flashinfer
from vllm.utils.import_utils import has_triton_kernels
from vllm.utils.math_utils import round_up
from vllm.utils.torch_utils import is_torch_equal_or_newer
logger = init_logger(__name__)
@@ -89,7 +88,6 @@ def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
# If FlashInfer is not available, try either Marlin or Triton
triton_kernels_supported = (
has_triton_kernels()
and is_torch_equal_or_newer("2.8.0")
# NOTE: triton_kernels are only confirmed to work on SM90 and SM100
# SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
# SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
@@ -151,7 +149,6 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
# If FlashInfer is not available, try either Marlin or Triton
triton_kernels_supported = (
has_triton_kernels()
and is_torch_equal_or_newer("2.8.0")
# NOTE: triton_kernels are only confirmed to work on SM90 and SM100
# SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
# SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498

View File

@@ -108,20 +108,6 @@ class TorchAOConfig(QuantizationConfig):
skip_modules: list[str] | None = None,
is_checkpoint_torchao_serialized: bool = False,
) -> None:
"""
# TorchAO quantization relies on tensor subclasses. In order,
# to enable proper caching this needs standalone compile
if is_torch_equal_or_newer("2.8.0.dev"):
os.environ["VLLM_TEST_STANDALONE_COMPILE"] = "1"
logger.info(
"Using TorchAO: Setting VLLM_TEST_STANDALONE_COMPILE=1")
# TODO: remove after the torch dependency is updated to 2.8
if is_torch_equal_or_newer(
"2.7.0") and not is_torch_equal_or_newer("2.8.0.dev"):
os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"
logger.info("Using TorchAO: Setting VLLM_DISABLE_COMPILE_CACHE=1")
"""
super().__init__()
self.torchao_config = torchao_config
self.skip_modules = skip_modules or []

View File

@@ -709,9 +709,7 @@ def is_torch_equal(target: str) -> bool:
# Supports xccl with PyTorch versions >= 2.8.0.dev for XPU platform
def supports_xccl() -> bool:
return (
is_torch_equal_or_newer("2.8.0.dev") and torch.distributed.is_xccl_available()
)
return torch.distributed.is_xccl_available()
# create a library to hold the custom op