Compare commits
1 Commits
v0.17.2rc0
...
v0.15.0rc0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cf1167e50b |
@@ -72,7 +72,8 @@ class ncclDataTypeEnum:
|
||||
ncclFloat64 = 8
|
||||
ncclDouble = 8
|
||||
ncclBfloat16 = 9
|
||||
ncclNumTypes = 10
|
||||
ncclFloat8e4m3 = 10
|
||||
ncclNumTypes = 11
|
||||
|
||||
@classmethod
|
||||
def from_torch(cls, dtype: torch.dtype) -> int:
|
||||
@@ -92,9 +93,12 @@ class ncclDataTypeEnum:
|
||||
return cls.ncclFloat64
|
||||
if dtype == torch.bfloat16:
|
||||
return cls.ncclBfloat16
|
||||
if dtype == torch.float8_e4m3fn:
|
||||
return cls.ncclFloat8e4m3
|
||||
raise ValueError(
|
||||
f"Unsupported dtype {dtype}: should be one of "
|
||||
f"int8, uint8, int32, int64, float16, float32, float64, bfloat16."
|
||||
f"int8, uint8, int32, int64, float16, float32, float64, bfloat16,"
|
||||
" float8e4m3."
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -224,6 +224,8 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
|
||||
a1q_scale = None
|
||||
|
||||
if is_nvfp4 and a1q_scale is not None:
|
||||
if a1q_scale.element_size() == 1:
|
||||
a1q_scale = a1q_scale.view(torch.uint8)
|
||||
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
|
||||
|
||||
return a1q, a1q_scale, None, topk_ids, topk_weights
|
||||
|
||||
Reference in New Issue
Block a user