diff --git a/vllm/nvfp4_cutedsl.py b/vllm/nvfp4_cutedsl.py index 00c1260d..5a6e7086 100644 --- a/vllm/nvfp4_cutedsl.py +++ b/vllm/nvfp4_cutedsl.py @@ -235,7 +235,10 @@ class CuTeDSLMoERunner: all_flat = torch.cat([p.view(torch.uint8) for p in swizzled_parts], dim=0) all_flat = all_flat.view(torch.float8_e4m3fn) - return all_flat + # Reshape to 2D: (total_padded_rows, padded_cols) + total_padded_rows = padded_expert_offsets[num_experts].item() + padded_cols = swizzled_parts[0].shape[1] if swizzled_parts else padded_x_sf_buf.shape[1] + return all_flat.reshape(total_padded_rows, padded_cols) def compute_activation_global_scales(self, hidden_states_sample, topk_weights, topk_ids): """Compute activation global scales from a warmup forward pass.