[ez] Remove checks for torch version <= 2.8 (#33209)
Signed-off-by: angelayi <yiangela7@gmail.com>
This commit is contained in:
@@ -8,7 +8,7 @@ from torch._ops import OpOverload
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import (
|
||||
rocm_aiter_sparse_attn_indexer,
|
||||
rocm_aiter_sparse_attn_indexer_fake,
|
||||
@@ -1015,12 +1015,6 @@ class rocm_aiter_ops:
|
||||
def register_ops_once() -> None:
|
||||
global _OPS_REGISTERED
|
||||
if not _OPS_REGISTERED:
|
||||
tags = (
|
||||
tuple()
|
||||
if is_torch_equal_or_newer("2.7.0")
|
||||
else (torch.Tag.needs_fixed_stride_order,)
|
||||
)
|
||||
|
||||
# register all the custom ops here
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_aiter_asm_moe_tkw1",
|
||||
@@ -1075,7 +1069,6 @@ class rocm_aiter_ops:
|
||||
op_func=_rocm_aiter_mla_decode_fwd_impl,
|
||||
mutates_args=["o"],
|
||||
fake_impl=_rocm_aiter_mla_decode_fwd_fake,
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
direct_register_custom_op(
|
||||
|
||||
@@ -33,7 +33,6 @@ from vllm.logger import init_logger
|
||||
from vllm.logging_utils import lazy
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
from .compiler_interface import (
|
||||
CompilerInterface,
|
||||
@@ -94,10 +93,8 @@ def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
|
||||
if compilation_config.backend == "inductor":
|
||||
# Use standalone compile only if requested
|
||||
# and the symbol actually exists in this PyTorch build.
|
||||
if (
|
||||
envs.VLLM_USE_STANDALONE_COMPILE
|
||||
and is_torch_equal_or_newer("2.8.0.dev")
|
||||
and hasattr(torch._inductor, "standalone_compile")
|
||||
if envs.VLLM_USE_STANDALONE_COMPILE and hasattr(
|
||||
torch._inductor, "standalone_compile"
|
||||
):
|
||||
logger.debug("Using InductorStandaloneAdaptor")
|
||||
return InductorStandaloneAdaptor(
|
||||
|
||||
@@ -501,20 +501,19 @@ class InductorAdaptor(CompilerInterface):
|
||||
# get hit.
|
||||
# TODO(zou3519): we're going to replace this all with
|
||||
# standalone_compile sometime.
|
||||
if is_torch_equal_or_newer("2.6"):
|
||||
stack.enter_context(
|
||||
torch._inductor.config.patch(fx_graph_remote_cache=False)
|
||||
)
|
||||
# InductorAdaptor (unfortunately) requires AOTAutogradCache
|
||||
# to be turned off to run. It will fail to acquire the hash_str
|
||||
# and error if not.
|
||||
# InductorStandaloneAdaptor (PyTorch 2.8+) fixes this problem.
|
||||
stack.enter_context(
|
||||
torch._functorch.config.patch(enable_autograd_cache=False)
|
||||
)
|
||||
stack.enter_context(
|
||||
torch._functorch.config.patch(enable_remote_autograd_cache=False)
|
||||
)
|
||||
stack.enter_context(
|
||||
torch._inductor.config.patch(fx_graph_remote_cache=False)
|
||||
)
|
||||
# InductorAdaptor (unfortunately) requires AOTAutogradCache
|
||||
# to be turned off to run. It will fail to acquire the hash_str
|
||||
# and error if not.
|
||||
# InductorStandaloneAdaptor (PyTorch 2.8+) fixes this problem.
|
||||
stack.enter_context(
|
||||
torch._functorch.config.patch(enable_autograd_cache=False)
|
||||
)
|
||||
stack.enter_context(
|
||||
torch._functorch.config.patch(enable_remote_autograd_cache=False)
|
||||
)
|
||||
|
||||
compiled_graph = compile_fx(
|
||||
graph,
|
||||
|
||||
@@ -7,12 +7,11 @@ import inspect
|
||||
import os
|
||||
import sys
|
||||
from collections.abc import Callable, Generator
|
||||
from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload
|
||||
from typing import TYPE_CHECKING, Any, TypeVar, overload
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from packaging import version
|
||||
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
|
||||
|
||||
import vllm.envs as envs
|
||||
@@ -540,7 +539,6 @@ def _support_torch_compile(
|
||||
torch._dynamo.config.patch(**dynamo_config_patches),
|
||||
maybe_use_cudagraph_partition_wrapper(self.vllm_config),
|
||||
torch.fx.experimental._config.patch(**fx_config_patches),
|
||||
_torch27_patch_tensor_subclasses(),
|
||||
torch._inductor.config.patch(**inductor_config_patches),
|
||||
):
|
||||
use_aot_compile = envs.VLLM_USE_AOT_COMPILE
|
||||
@@ -647,42 +645,3 @@ def maybe_use_cudagraph_partition_wrapper(
|
||||
and compilation_config.use_inductor_graph_partition
|
||||
):
|
||||
torch._inductor.utils.set_customized_partition_wrappers(None)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _torch27_patch_tensor_subclasses() -> Generator[None, None, None]:
|
||||
"""
|
||||
Add support for using tensor subclasses (e.g. `BasevLLMParameter`, etc.) when
|
||||
using torch 2.7.0. This enables using weight_loader_v2 and the use of
|
||||
`BasevLLMParameters` without having to replace them with regular tensors
|
||||
before `torch.compile`-time.
|
||||
"""
|
||||
from vllm.model_executor.parameter import (
|
||||
BasevLLMParameter,
|
||||
ModelWeightParameter,
|
||||
RowvLLMParameter,
|
||||
_ColumnvLLMParameter,
|
||||
)
|
||||
|
||||
def return_false(*args: Any, **kwargs: Any) -> Literal[False]:
|
||||
return False
|
||||
|
||||
if version.parse("2.7") <= version.parse(torch.__version__) < version.parse("2.8"):
|
||||
yield
|
||||
return
|
||||
|
||||
with (
|
||||
torch._dynamo.config.patch(
|
||||
"traceable_tensor_subclasses",
|
||||
[
|
||||
BasevLLMParameter,
|
||||
ModelWeightParameter,
|
||||
_ColumnvLLMParameter,
|
||||
RowvLLMParameter,
|
||||
],
|
||||
),
|
||||
patch(
|
||||
"torch._dynamo.variables.torch.can_dispatch_torch_function", return_false
|
||||
),
|
||||
):
|
||||
yield
|
||||
|
||||
@@ -16,18 +16,10 @@ import torch
|
||||
from torch import fx
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily
|
||||
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config.utils import Range
|
||||
|
||||
if is_torch_equal_or_newer("2.6"):
|
||||
from torch._inductor.custom_graph_pass import CustomGraphPass
|
||||
else:
|
||||
# CustomGraphPass is not present in 2.5 or lower, import our version
|
||||
from .torch25_custom_graph_pass import (
|
||||
Torch25CustomGraphPass as CustomGraphPass,
|
||||
)
|
||||
from torch._inductor.custom_graph_pass import CustomGraphPass
|
||||
|
||||
_pass_context = None
|
||||
P = ParamSpec("P")
|
||||
|
||||
@@ -777,10 +777,9 @@ class CompilationConfig:
|
||||
# and it is not yet a priority. RFC here:
|
||||
# https://github.com/vllm-project/vllm/issues/14703
|
||||
|
||||
if is_torch_equal_or_newer("2.6"):
|
||||
KEY = "enable_auto_functionalized_v2"
|
||||
if KEY not in self.inductor_compile_config:
|
||||
self.inductor_compile_config[KEY] = False
|
||||
KEY = "enable_auto_functionalized_v2"
|
||||
if KEY not in self.inductor_compile_config:
|
||||
self.inductor_compile_config[KEY] = False
|
||||
|
||||
for k, v in self.inductor_passes.items():
|
||||
if not isinstance(v, str):
|
||||
|
||||
@@ -31,7 +31,6 @@ import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.network_utils import get_tcp_uri
|
||||
from vllm.utils.system_utils import suppress_stdout
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -429,20 +428,11 @@ def init_gloo_process_group(
|
||||
different torch versions.
|
||||
"""
|
||||
with suppress_stdout():
|
||||
if is_torch_equal_or_newer("2.6"):
|
||||
pg = ProcessGroup(
|
||||
prefix_store,
|
||||
group_rank,
|
||||
group_size,
|
||||
)
|
||||
else:
|
||||
options = ProcessGroup.Options(backend="gloo")
|
||||
pg = ProcessGroup(
|
||||
prefix_store,
|
||||
group_rank,
|
||||
group_size,
|
||||
options,
|
||||
)
|
||||
pg = ProcessGroup(
|
||||
prefix_store,
|
||||
group_rank,
|
||||
group_size,
|
||||
)
|
||||
from torch.distributed.distributed_c10d import ProcessGroupGloo
|
||||
|
||||
backend_class = ProcessGroupGloo(
|
||||
@@ -450,9 +440,7 @@ def init_gloo_process_group(
|
||||
)
|
||||
backend_type = ProcessGroup.BackendType.GLOO
|
||||
device = torch.device("cpu")
|
||||
if is_torch_equal_or_newer("2.6"):
|
||||
# _set_default_backend is supported in torch >= 2.6
|
||||
pg._set_default_backend(backend_type)
|
||||
pg._set_default_backend(backend_type)
|
||||
backend_class._set_sequence_number_for_group()
|
||||
|
||||
pg._register_backend(device, backend_type, backend_class)
|
||||
@@ -534,12 +522,5 @@ def stateless_destroy_torch_distributed_process_group(pg: ProcessGroup) -> None:
|
||||
Destroy ProcessGroup returned by
|
||||
stateless_init_torch_distributed_process_group().
|
||||
"""
|
||||
if is_torch_equal_or_newer("2.7"):
|
||||
pg.shutdown()
|
||||
else:
|
||||
# Lazy import for non-CUDA backends.
|
||||
from torch.distributed.distributed_c10d import _shutdown_backend
|
||||
|
||||
_shutdown_backend(pg)
|
||||
|
||||
pg.shutdown()
|
||||
_unregister_process_group(pg.group_name)
|
||||
|
||||
@@ -52,7 +52,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -1406,11 +1406,6 @@ direct_register_custom_op(
|
||||
op_func=inplace_fused_experts,
|
||||
mutates_args=["hidden_states"],
|
||||
fake_impl=inplace_fused_experts_fake,
|
||||
tags=(
|
||||
()
|
||||
if is_torch_equal_or_newer("2.7.0")
|
||||
else (torch.Tag.needs_fixed_stride_order,)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -1501,11 +1496,6 @@ direct_register_custom_op(
|
||||
op_name="outplace_fused_experts",
|
||||
op_func=outplace_fused_experts,
|
||||
fake_impl=outplace_fused_experts_fake,
|
||||
tags=(
|
||||
()
|
||||
if is_torch_equal_or_newer("2.7.0")
|
||||
else (torch.Tag.needs_fixed_stride_order,)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -56,7 +56,6 @@ from vllm.scalar_type import scalar_types
|
||||
from vllm.utils.flashinfer import has_flashinfer
|
||||
from vllm.utils.import_utils import has_triton_kernels
|
||||
from vllm.utils.math_utils import round_up
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -89,7 +88,6 @@ def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
|
||||
# If FlashInfer is not available, try either Marlin or Triton
|
||||
triton_kernels_supported = (
|
||||
has_triton_kernels()
|
||||
and is_torch_equal_or_newer("2.8.0")
|
||||
# NOTE: triton_kernels are only confirmed to work on SM90 and SM100
|
||||
# SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
|
||||
# SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
|
||||
@@ -151,7 +149,6 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
|
||||
# If FlashInfer is not available, try either Marlin or Triton
|
||||
triton_kernels_supported = (
|
||||
has_triton_kernels()
|
||||
and is_torch_equal_or_newer("2.8.0")
|
||||
# NOTE: triton_kernels are only confirmed to work on SM90 and SM100
|
||||
# SM110 fails with this error: https://github.com/vllm-project/vllm/issues/29317
|
||||
# SM120 needs this fix: https://github.com/triton-lang/triton/pull/8498
|
||||
|
||||
@@ -108,20 +108,6 @@ class TorchAOConfig(QuantizationConfig):
|
||||
skip_modules: list[str] | None = None,
|
||||
is_checkpoint_torchao_serialized: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
# TorchAO quantization relies on tensor subclasses. In order,
|
||||
# to enable proper caching this needs standalone compile
|
||||
if is_torch_equal_or_newer("2.8.0.dev"):
|
||||
os.environ["VLLM_TEST_STANDALONE_COMPILE"] = "1"
|
||||
logger.info(
|
||||
"Using TorchAO: Setting VLLM_TEST_STANDALONE_COMPILE=1")
|
||||
|
||||
# TODO: remove after the torch dependency is updated to 2.8
|
||||
if is_torch_equal_or_newer(
|
||||
"2.7.0") and not is_torch_equal_or_newer("2.8.0.dev"):
|
||||
os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"
|
||||
logger.info("Using TorchAO: Setting VLLM_DISABLE_COMPILE_CACHE=1")
|
||||
"""
|
||||
super().__init__()
|
||||
self.torchao_config = torchao_config
|
||||
self.skip_modules = skip_modules or []
|
||||
|
||||
@@ -709,9 +709,7 @@ def is_torch_equal(target: str) -> bool:
|
||||
|
||||
# Supports xccl with PyTorch versions >= 2.8.0.dev for XPU platform
|
||||
def supports_xccl() -> bool:
|
||||
return (
|
||||
is_torch_equal_or_newer("2.8.0.dev") and torch.distributed.is_xccl_available()
|
||||
)
|
||||
return torch.distributed.is_xccl_available()
|
||||
|
||||
|
||||
# create a library to hold the custom op
|
||||
|
||||
Reference in New Issue
Block a user