Add cache to cuda get_device_capability (#19436)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin
2025-06-11 05:37:05 -04:00
committed by GitHub
parent a2142f0196
commit 7484e1fce2

View File

@@ -6,7 +6,7 @@ pynvml. However, it should not initialize cuda context.
import os
from datetime import timedelta
from functools import wraps
from functools import cache, wraps
from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
import torch
@@ -389,6 +389,7 @@ class CudaPlatformBase(Platform):
class NvmlCudaPlatform(CudaPlatformBase):
@classmethod
@cache
@with_nvml_context
def get_device_capability(cls,
device_id: int = 0
@@ -486,6 +487,7 @@ class NvmlCudaPlatform(CudaPlatformBase):
class NonNvmlCudaPlatform(CudaPlatformBase):
@classmethod
@cache
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
major, minor = torch.cuda.get_device_capability(device_id)
return DeviceCapability(major=major, minor=minor)