[Bugfix][ROCm] Include float8_e4m3fnuz in NCCL Dtype Dispatching (#33713)
Signed-off-by: Micah Williamson <micah.williamson@amd.com>
This commit is contained in:
@@ -93,7 +93,7 @@ class ncclDataTypeEnum:
|
|||||||
return cls.ncclFloat64
|
return cls.ncclFloat64
|
||||||
if dtype == torch.bfloat16:
|
if dtype == torch.bfloat16:
|
||||||
return cls.ncclBfloat16
|
return cls.ncclBfloat16
|
||||||
if dtype == torch.float8_e4m3fn:
|
if dtype == current_platform.fp8_dtype():
|
||||||
return cls.ncclFloat8e4m3
|
return cls.ncclFloat8e4m3
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported dtype {dtype}: should be one of "
|
f"Unsupported dtype {dtype}: should be one of "
|
||||||
|
|||||||
Reference in New Issue
Block a user