Fix: scatter FP4 data through a uint8 view (torch.float4_e2m1fn_x2 doesn't support index_put)

This commit is contained in:
2026-05-17 21:28:04 +00:00
parent 364f8372bb
commit 2796bd81e8
2 changed files with 9 additions and 7 deletions

View File

@@ -166,9 +166,10 @@ def main():
slot_x_fp4, slot_x_sf, l1_gs = quantize_to_nvfp4(slot_hidden)
print(f" L1 gs (dynamic): {l1_gs:.6f}")
# Scatter x_fp4 into padded layout
padded_x_fp4 = torch.zeros(total_padded, HIDDEN_SIZE // 2, dtype=torch.uint8, device=DEVICE).view(torch.float4_e2m1fn_x2)
padded_x_fp4[padded_dst] = slot_x_fp4
# Scatter x_fp4 into padded layout (scatter as uint8, then view the result as float4_e2m1fn_x2)
padded_x_fp4_uint8 = torch.zeros(total_padded, HIDDEN_SIZE // 2, dtype=torch.uint8, device=DEVICE)
padded_x_fp4_uint8[padded_dst] = slot_x_fp4.view(torch.uint8)
padded_x_fp4 = padded_x_fp4_uint8.view(torch.float4_e2m1fn_x2)
# For scale_a, we need to use the runner's assembly approach.
# Use the same _assemble_scales_cudagraph_safe function

View File

@@ -433,9 +433,10 @@ class CuTeDSLMoERunner:
slot_hidden, self._l1_activation_global_scale
)
# Scatter x_fp4 into padded layout for the GEMM
# Must scatter as uint8 (float4_e2m1fn_x2 doesn't support index_put)
padded_x_fp4 = self._shared_bufs['hidden_fp4']
padded_x_fp4.zero_()
padded_x_fp4[padded_dst] = slot_x_fp4
padded_x_fp4.view(torch.uint8).zero_()
padded_x_fp4.view(torch.uint8)[padded_dst] = slot_x_fp4.view(torch.uint8)
l1_scale_a = self._assemble_scales_cudagraph_safe(
slot_x_sf, expert_offsets[:self.num_experts + 1],
@@ -470,8 +471,8 @@ class CuTeDSLMoERunner:
activated, self._l2_activation_global_scale
)
padded_activated_fp4 = self._shared_bufs['activated_fp4']
padded_activated_fp4.zero_()
padded_activated_fp4[padded_dst] = slot_l2_x_fp4
padded_activated_fp4.view(torch.uint8).zero_()
padded_activated_fp4.view(torch.uint8)[padded_dst] = slot_l2_x_fp4.view(torch.uint8)
l2_scale_a = self._assemble_scales_cudagraph_safe(
slot_l2_x_sf, expert_offsets[:self.num_experts + 1],