Compare commits
1 Commit
v0.17.2rc0
...
v0.15.0rc0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cf1167e50b |
@@ -72,7 +72,8 @@ class ncclDataTypeEnum:
|
|||||||
ncclFloat64 = 8
|
ncclFloat64 = 8
|
||||||
ncclDouble = 8
|
ncclDouble = 8
|
||||||
ncclBfloat16 = 9
|
ncclBfloat16 = 9
|
||||||
ncclNumTypes = 10
|
ncclFloat8e4m3 = 10
|
||||||
|
ncclNumTypes = 11
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_torch(cls, dtype: torch.dtype) -> int:
|
def from_torch(cls, dtype: torch.dtype) -> int:
|
||||||
@@ -92,9 +93,12 @@ class ncclDataTypeEnum:
|
|||||||
return cls.ncclFloat64
|
return cls.ncclFloat64
|
||||||
if dtype == torch.bfloat16:
|
if dtype == torch.bfloat16:
|
||||||
return cls.ncclBfloat16
|
return cls.ncclBfloat16
|
||||||
|
if dtype == torch.float8_e4m3fn:
|
||||||
|
return cls.ncclFloat8e4m3
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported dtype {dtype}: should be one of "
|
f"Unsupported dtype {dtype}: should be one of "
|
||||||
f"int8, uint8, int32, int64, float16, float32, float64, bfloat16."
|
f"int8, uint8, int32, int64, float16, float32, float64, bfloat16,"
|
||||||
|
" float8e4m3."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -224,6 +224,8 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
|
|||||||
a1q_scale = None
|
a1q_scale = None
|
||||||
|
|
||||||
if is_nvfp4 and a1q_scale is not None:
|
if is_nvfp4 and a1q_scale is not None:
|
||||||
|
if a1q_scale.element_size() == 1:
|
||||||
|
a1q_scale = a1q_scale.view(torch.uint8)
|
||||||
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
|
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
|
||||||
|
|
||||||
return a1q, a1q_scale, None, topk_ids, topk_weights
|
return a1q, a1q_scale, None, topk_ids, topk_weights
|
||||||
|
|||||||
Reference in New Issue
Block a user