diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py index 67b6cbe53..3b11595b4 100644 --- a/vllm/distributed/device_communicators/pynccl_wrapper.py +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -72,7 +72,8 @@ class ncclDataTypeEnum: ncclFloat64 = 8 ncclDouble = 8 ncclBfloat16 = 9 - ncclNumTypes = 10 + ncclFloat8e4m3 = 10 + ncclNumTypes = 11 @classmethod def from_torch(cls, dtype: torch.dtype) -> int: @@ -92,9 +93,12 @@ class ncclDataTypeEnum: return cls.ncclFloat64 if dtype == torch.bfloat16: return cls.ncclBfloat16 + if dtype == torch.float8_e4m3fn: + return cls.ncclFloat8e4m3 raise ValueError( f"Unsupported dtype {dtype}: should be one of " - f"int8, uint8, int32, int64, float16, float32, float64, bfloat16." + f"int8, uint8, int32, int64, float16, float32, float64, bfloat16," + " float8e4m3." ) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index dfff86075..a56d68566 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -224,6 +224,8 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin a1q_scale = None if is_nvfp4 and a1q_scale is not None: + if a1q_scale.element_size() == 1: + a1q_scale = a1q_scale.view(torch.uint8) a1q_scale = nvfp4_block_scale_interleave(a1q_scale) return a1q, a1q_scale, None, topk_ids, topk_weights