[Bugfix][ROCm] Include float8_e4m3fnuz in NCCL Dtype Dispatching (#33713)
Signed-off-by: Micah Williamson <micah.williamson@amd.com>
This commit is contained in:
@@ -93,7 +93,7 @@ class ncclDataTypeEnum:
|
||||
return cls.ncclFloat64
|
||||
if dtype == torch.bfloat16:
|
||||
return cls.ncclBfloat16
|
||||
if dtype == torch.float8_e4m3fn:
|
||||
if dtype == current_platform.fp8_dtype():
|
||||
return cls.ncclFloat8e4m3
|
||||
raise ValueError(
|
||||
f"Unsupported dtype {dtype}: should be one of "
|
||||
|
||||
Reference in New Issue
Block a user