[Misc] rename torch_dtype to dtype (#26695)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -500,8 +500,8 @@ class CudaPlatformBase(Platform):
|
||||
return supported
|
||||
|
||||
@classmethod
|
||||
def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
|
||||
if torch_dtype == torch.bfloat16: # noqa: SIM102
|
||||
def check_if_supports_dtype(cls, dtype: torch.dtype):
|
||||
if dtype == torch.bfloat16: # noqa: SIM102
|
||||
if not cls.has_device_capability(80):
|
||||
capability = cls.get_device_capability()
|
||||
gpu_name = cls.get_device_name()
|
||||
|
||||
@@ -563,7 +563,7 @@ class Platform:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
|
||||
def check_if_supports_dtype(cls, dtype: torch.dtype):
|
||||
"""
|
||||
Check if the dtype is supported by the current platform.
|
||||
"""
|
||||
|
||||
@@ -484,8 +484,8 @@ class RocmPlatform(Platform):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
|
||||
if torch_dtype == torch.bfloat16: # noqa: SIM102
|
||||
def check_if_supports_dtype(cls, dtype: torch.dtype):
|
||||
if dtype == torch.bfloat16: # noqa: SIM102
|
||||
if not cls.has_device_capability(80):
|
||||
capability = cls.get_device_capability()
|
||||
gpu_name = cls.get_device_name()
|
||||
|
||||
@@ -236,8 +236,8 @@ class XPUPlatform(Platform):
|
||||
return torch.xpu.device_count()
|
||||
|
||||
@classmethod
|
||||
def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
|
||||
if torch_dtype == torch.bfloat16: # noqa: SIM102
|
||||
def check_if_supports_dtype(cls, dtype: torch.dtype):
|
||||
if dtype == torch.bfloat16: # noqa: SIM102
|
||||
device_name = cls.get_device_name().lower()
|
||||
# client gpu a770
|
||||
if device_name.count("a770") > 0:
|
||||
|
||||
Reference in New Issue
Block a user