Fix the torch version parsing logic (#15857)
This commit is contained in:
@@ -4,7 +4,6 @@ import ast
|
||||
import copy
|
||||
import enum
|
||||
import hashlib
|
||||
import importlib.metadata
|
||||
import json
|
||||
import sys
|
||||
import warnings
|
||||
@@ -18,7 +17,6 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
|
||||
Optional, Protocol, Union)
|
||||
|
||||
import torch
|
||||
from packaging.version import Version
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
from torch.distributed import ProcessGroup, ReduceOp
|
||||
from transformers import PretrainedConfig
|
||||
@@ -40,8 +38,8 @@ from vllm.transformers_utils.config import (
|
||||
from vllm.transformers_utils.s3_utils import S3Model
|
||||
from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
|
||||
from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
|
||||
get_cpu_memory, get_open_port, random_uuid,
|
||||
resolve_obj_by_qualname)
|
||||
get_cpu_memory, get_open_port, is_torch_equal_or_newer,
|
||||
random_uuid, resolve_obj_by_qualname)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
@@ -3285,7 +3283,7 @@ class CompilationConfig(BaseModel):
|
||||
# and it is not yet a priority. RFC here:
|
||||
# https://github.com/vllm-project/vllm/issues/14703
|
||||
|
||||
if Version(importlib.metadata.version('torch')) >= Version("2.6"):
|
||||
if is_torch_equal_or_newer("2.6"):
|
||||
KEY = 'enable_auto_functionalized_v2'
|
||||
if KEY not in self.inductor_compile_config:
|
||||
self.inductor_compile_config[KEY] = False
|
||||
|
||||
Reference in New Issue
Block a user