Fix the torch version parsing logic (#15857)
This commit is contained in:
@@ -2,7 +2,6 @@
|
||||
import contextlib
|
||||
import copy
|
||||
import hashlib
|
||||
import importlib.metadata
|
||||
import os
|
||||
from contextlib import ExitStack
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
@@ -11,9 +10,9 @@ from unittest.mock import patch
|
||||
import torch
|
||||
import torch._inductor.compile_fx
|
||||
import torch.fx as fx
|
||||
from packaging.version import Version
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils import is_torch_equal_or_newer
|
||||
|
||||
|
||||
class CompilerInterface:
|
||||
@@ -379,7 +378,7 @@ class InductorAdaptor(CompilerInterface):
|
||||
manually setting up internal contexts. But we also rely on non-public
|
||||
APIs which might not provide these guarantees.
|
||||
"""
|
||||
if Version(importlib.metadata.version('torch')) >= Version("2.6"):
|
||||
if is_torch_equal_or_newer("2.6"):
|
||||
import torch._dynamo.utils
|
||||
return torch._dynamo.utils.get_metrics_context()
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user