[Bugfix][ROCm] Include float8_e4m3fnuz in NCCL Dtype Dispatching (#33713)

Signed-off-by: Micah Williamson <micah.williamson@amd.com>
This commit is contained in:
Micah Williamson
2026-02-04 07:36:29 -06:00
committed by GitHub
parent 32a02c7ca2
commit 1d367a738e

View File

@@ -93,7 +93,7 @@ class ncclDataTypeEnum:
return cls.ncclFloat64
if dtype == torch.bfloat16:
return cls.ncclBfloat16
if dtype == torch.float8_e4m3fn:
if dtype == current_platform.fp8_dtype():
return cls.ncclFloat8e4m3
raise ValueError(
f"Unsupported dtype {dtype}: should be one of "