From 4f6217acb96f8319ea366733c93028a1cbde49c2 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 17 May 2026 09:58:09 +0000 Subject: [PATCH] Fix padded_cols calculation in scale assembly --- vllm/nvfp4_cutedsl.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/nvfp4_cutedsl.py b/vllm/nvfp4_cutedsl.py index 5a6e7086..ab6dc76a 100644 --- a/vllm/nvfp4_cutedsl.py +++ b/vllm/nvfp4_cutedsl.py @@ -236,8 +236,11 @@ class CuTeDSLMoERunner: all_flat = torch.cat([p.view(torch.uint8) for p in swizzled_parts], dim=0) all_flat = all_flat.view(torch.float8_e4m3fn) # Reshape to 2D: (total_padded_rows, padded_cols) + # padded_cols comes from the swizzle: ceil_div(K_sf, 4) * 4 * 4 + # (128 rows per row_block, 4 cols per col_block, 32 sub-rows * 16 sub-cols per block) + # Simpler: total elements / total_padded_rows total_padded_rows = padded_expert_offsets[num_experts].item() - padded_cols = swizzled_parts[0].shape[1] if swizzled_parts else padded_x_sf_buf.shape[1] + padded_cols = all_flat.shape[0] // total_padded_rows if total_padded_rows > 0 else 0 return all_flat.reshape(total_padded_rows, padded_cols) def compute_activation_global_scales(self, hidden_states_sample, topk_weights, topk_ids):