[Misc] rename torch_dtype to dtype (#26695)

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan
2025-10-15 20:11:48 +08:00
committed by GitHub
parent f93e348010
commit 8f4b313c37
30 changed files with 52 additions and 55 deletions

View File

@@ -500,8 +500,8 @@ class CudaPlatformBase(Platform):
return supported
@classmethod
def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
if torch_dtype == torch.bfloat16: # noqa: SIM102
def check_if_supports_dtype(cls, dtype: torch.dtype):
if dtype == torch.bfloat16: # noqa: SIM102
if not cls.has_device_capability(80):
capability = cls.get_device_capability()
gpu_name = cls.get_device_name()

View File

@@ -563,7 +563,7 @@ class Platform:
return False
@classmethod
def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
def check_if_supports_dtype(cls, dtype: torch.dtype):
"""
Check if the dtype is supported by the current platform.
"""

View File

@@ -484,8 +484,8 @@ class RocmPlatform(Platform):
return True
@classmethod
def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
if torch_dtype == torch.bfloat16: # noqa: SIM102
def check_if_supports_dtype(cls, dtype: torch.dtype):
if dtype == torch.bfloat16: # noqa: SIM102
if not cls.has_device_capability(80):
capability = cls.get_device_capability()
gpu_name = cls.get_device_name()

View File

@@ -236,8 +236,8 @@ class XPUPlatform(Platform):
return torch.xpu.device_count()
@classmethod
def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
if torch_dtype == torch.bfloat16: # noqa: SIM102
def check_if_supports_dtype(cls, dtype: torch.dtype):
if dtype == torch.bfloat16: # noqa: SIM102
device_name = cls.get_device_name().lower()
# client gpu a770
if device_name.count("a770") > 0: