Fix the torch version parsing logic (#15857)

This commit is contained in:
Lu Fang
2025-04-10 07:37:47 -07:00
committed by GitHub
parent 8661c0241d
commit 7678fcd5b6
4 changed files with 26 additions and 11 deletions

View File

@@ -53,6 +53,7 @@ import torch.types
import yaml
import zmq
import zmq.asyncio
from packaging import version
from packaging.version import Version
from torch.library import Library
from typing_extensions import Never, ParamSpec, TypeIs, assert_never
@@ -2580,3 +2581,20 @@ def sha256(input) -> int:
input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
return int.from_bytes(hashlib.sha256(input_bytes).digest(),
byteorder="big")
def is_torch_equal_or_newer(target: str) -> bool:
"""Check if the installed torch version is >= the target version.
Args:
target: a version string, like "2.6.0".
Returns:
Whether the condition meets.
"""
try:
torch_version = version.parse(str(torch.__version__))
return torch_version >= version.parse(target)
except Exception:
# Fallback to PKG-INFO to load the package info, needed by the doc gen.
return Version(importlib.metadata.version('torch')) >= Version(target)