From 2796bd81e8e594a932b9fb7466fa2d2ee0af4136 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 17 May 2026 21:28:04 +0000 Subject: [PATCH] Fix: scatter FP4 as uint8 (float4 doesn't support index_put) --- tests/test_pipeline_real_weights.py | 7 ++++--- vllm/nvfp4_cutedsl.py | 9 +++++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/test_pipeline_real_weights.py b/tests/test_pipeline_real_weights.py index f1760c4a..aaba97d7 100644 --- a/tests/test_pipeline_real_weights.py +++ b/tests/test_pipeline_real_weights.py @@ -166,9 +166,10 @@ def main(): slot_x_fp4, slot_x_sf, l1_gs = quantize_to_nvfp4(slot_hidden) print(f" L1 gs (dynamic): {l1_gs:.6f}") - # Scatter x_fp4 into padded layout - padded_x_fp4 = torch.zeros(total_padded, HIDDEN_SIZE // 2, dtype=torch.uint8, device=DEVICE).view(torch.float4_e2m1fn_x2) - padded_x_fp4[padded_dst] = slot_x_fp4 + # Scatter x_fp4 into padded layout (use uint8 for scatter, then view as float4) + padded_x_fp4_uint8 = torch.zeros(total_padded, HIDDEN_SIZE // 2, dtype=torch.uint8, device=DEVICE) + padded_x_fp4_uint8[padded_dst] = slot_x_fp4.view(torch.uint8) + padded_x_fp4 = padded_x_fp4_uint8.view(torch.float4_e2m1fn_x2) # For scale_a, we need to use the runner's assembly approach. # Use the same _assemble_scales_cudagraph_safe function diff --git a/vllm/nvfp4_cutedsl.py b/vllm/nvfp4_cutedsl.py index 8594208c..c2b9fd98 100644 --- a/vllm/nvfp4_cutedsl.py +++ b/vllm/nvfp4_cutedsl.py @@ -433,9 +433,10 @@ class CuTeDSLMoERunner: slot_hidden, self._l1_activation_global_scale ) # Scatter x_fp4 into padded layout for the GEMM + # Must scatter as uint8 (float4_e2m1fn_x2 doesn't support index_put) padded_x_fp4 = self._shared_bufs['hidden_fp4'] - padded_x_fp4.zero_() - padded_x_fp4[padded_dst] = slot_x_fp4 + padded_x_fp4.view(torch.uint8).zero_() + padded_x_fp4.view(torch.uint8)[padded_dst] = slot_x_fp4.view(torch.uint8) l1_scale_a = self._assemble_scales_cudagraph_safe( slot_x_sf, expert_offsets[:self.num_experts + 1], @@ -470,8 +471,8 @@ class CuTeDSLMoERunner: activated, self._l2_activation_global_scale ) padded_activated_fp4 = self._shared_bufs['activated_fp4'] - padded_activated_fp4.zero_() - padded_activated_fp4[padded_dst] = slot_l2_x_fp4 + padded_activated_fp4.view(torch.uint8).zero_() + padded_activated_fp4.view(torch.uint8)[padded_dst] = slot_l2_x_fp4.view(torch.uint8) l2_scale_a = self._assemble_scales_cudagraph_safe( slot_l2_x_sf, expert_offsets[:self.num_experts + 1],