Fix the torch version parsing logic (#15857)
This commit is contained in:
@@ -53,6 +53,7 @@ import torch.types
|
||||
import yaml
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
from packaging import version
|
||||
from packaging.version import Version
|
||||
from torch.library import Library
|
||||
from typing_extensions import Never, ParamSpec, TypeIs, assert_never
|
||||
@@ -2580,3 +2581,20 @@ def sha256(input) -> int:
|
||||
input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
return int.from_bytes(hashlib.sha256(input_bytes).digest(),
|
||||
byteorder="big")
|
||||
|
||||
|
||||
def is_torch_equal_or_newer(target: str) -> bool:
|
||||
"""Check if the installed torch version is >= the target version.
|
||||
|
||||
Args:
|
||||
target: a version string, like "2.6.0".
|
||||
|
||||
Returns:
|
||||
Whether the condition meets.
|
||||
"""
|
||||
try:
|
||||
torch_version = version.parse(str(torch.__version__))
|
||||
return torch_version >= version.parse(target)
|
||||
except Exception:
|
||||
# Fallback to PKG-INFO to load the package info, needed by the doc gen.
|
||||
return Version(importlib.metadata.version('torch')) >= Version(target)
|
||||
|
||||
Reference in New Issue
Block a user