From 43a013c3a29194f7b88b1b611b3b0067592b8c67 Mon Sep 17 00:00:00 2001 From: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Date: Mon, 26 Jan 2026 15:09:32 -0500 Subject: [PATCH] [Bugfix] Fix Dtypes for Pynccl Wrapper (#33030) Signed-off-by: Robert Shaw Co-authored-by: Robert Shaw --- vllm/distributed/device_communicators/pynccl_wrapper.py | 8 ++++++-- .../fused_moe/flashinfer_cutlass_prepare_finalize.py | 2 ++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py index 67b6cbe53..3b11595b4 100644 --- a/vllm/distributed/device_communicators/pynccl_wrapper.py +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -72,7 +72,8 @@ class ncclDataTypeEnum: ncclFloat64 = 8 ncclDouble = 8 ncclBfloat16 = 9 - ncclNumTypes = 10 + ncclFloat8e4m3 = 10 + ncclNumTypes = 11 @classmethod def from_torch(cls, dtype: torch.dtype) -> int: @@ -92,9 +93,12 @@ class ncclDataTypeEnum: return cls.ncclFloat64 if dtype == torch.bfloat16: return cls.ncclBfloat16 + if dtype == torch.float8_e4m3fn: + return cls.ncclFloat8e4m3 raise ValueError( f"Unsupported dtype {dtype}: should be one of " - f"int8, uint8, int32, int64, float16, float32, float64, bfloat16." + f"int8, uint8, int32, int64, float16, float32, float64, bfloat16," + " float8e4m3." ) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index dfff86075..a56d68566 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -224,6 +224,8 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin a1q_scale = None if is_nvfp4 and a1q_scale is not None: + if a1q_scale.element_size() == 1: + a1q_scale = a1q_scale.view(torch.uint8) a1q_scale = nvfp4_block_scale_interleave(a1q_scale) return a1q, a1q_scale, None, topk_ids, topk_weights