Fix the torch version parsing logic (#15857)

This commit is contained in:
Lu Fang
2025-04-10 07:37:47 -07:00
committed by GitHub
parent 8661c0241d
commit 7678fcd5b6
4 changed files with 26 additions and 11 deletions

View File

@@ -2,7 +2,6 @@
import contextlib
import copy
import hashlib
import importlib.metadata
import os
from contextlib import ExitStack
from typing import Any, Callable, Dict, List, Optional, Tuple
@@ -11,9 +10,9 @@ from unittest.mock import patch
import torch
import torch._inductor.compile_fx
import torch.fx as fx
from packaging.version import Version
from vllm.config import VllmConfig
from vllm.utils import is_torch_equal_or_newer
class CompilerInterface:
@@ -379,7 +378,7 @@ class InductorAdaptor(CompilerInterface):
manually setting up internal contexts. But we also rely on non-public
APIs which might not provide these guarantees.
"""
if Version(importlib.metadata.version('torch')) >= Version("2.6"):
if is_torch_equal_or_newer("2.6"):
import torch._dynamo.utils
return torch._dynamo.utils.get_metrics_context()
else: