From 4445882ba7ff3d85c5234804d1f313c3ff27f70b Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 17 May 2026 09:59:57 +0000 Subject: [PATCH] Fix: return 2D scale tensor for GEMM (shape[1] access) --- vllm/nvfp4_cutedsl.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/nvfp4_cutedsl.py b/vllm/nvfp4_cutedsl.py index 9845a4b0..37d09605 100644 --- a/vllm/nvfp4_cutedsl.py +++ b/vllm/nvfp4_cutedsl.py @@ -227,11 +227,11 @@ class CuTeDSLMoERunner: blocks = padded_x_sf.view(row_blocks, 128, col_blocks, 4).permute(0, 2, 1, 3) rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) - swizzled = rearranged.flatten().view(torch.float8_e4m3fn) - # The GEMM only reads total_padded_rows worth of scale data. - # Return the full swizzled buffer; the GEMM uses expert_offsets to - # determine how many rows each expert gets. - return swizzled + swizzled_flat = rearranged.flatten() + # Return as 2D (total_rows, 32*16=512 cols in swizzled layout) + # The GEMM reads scale_a.shape[1] * sf_vec_size as hidden_padded + total_rows = row_blocks * col_blocks + return swizzled_flat.view(torch.float8_e4m3fn).reshape(total_rows, -1) def compute_activation_global_scales(self, hidden_states_sample, topk_weights, topk_ids): """Compute activation global scales from a warmup forward pass.