[XPU] Delay BF16 check to worker init for spawn compatibility (#22979)

Signed-off-by: chzhang <chaojun.zhang@intel.com>
This commit is contained in:
Chaojun Zhang
2025-08-26 04:09:26 +08:00
committed by GitHub
parent 9188ae7cb5
commit 8a044754bd
6 changed files with 60 additions and 47 deletions

View File

@@ -462,3 +462,23 @@ class RocmPlatform(Platform):
def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
                                model_config: "ModelConfig") -> bool:
    """Report whether `kv_cache_dtype` is usable on this platform.

    ROCm places no restriction here, so every requested kv-cache dtype
    is accepted regardless of `model_config`.
    """
    return True
@classmethod
def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
    """Validate that the current GPU can run models in `torch_dtype`.

    Only bfloat16 is gated: it requires compute capability >= 8.0.
    Raises ValueError (with a hint to use --dtype=half) when the device
    falls short; any other dtype passes silently.
    """
    if torch_dtype != torch.bfloat16:
        return
    if cls.has_device_capability(80):
        return
    # Below SM 8.0 (or capability unknown): build a device-specific
    # message and reject the request.
    capability = cls.get_device_capability()
    gpu_name = cls.get_device_name()
    if capability is None:
        compute_str = "does not have a compute capability"
    else:
        compute_str = f"has compute capability {capability.as_version_str()}"
    raise ValueError(
        "Bfloat16 is only supported on GPUs "
        "with compute capability of at least 8.0. "
        f"Your {gpu_name} GPU {compute_str}. "
        "You can use float16 instead by explicitly setting the "
        "`dtype` flag in CLI, for example: --dtype=half.")