Fix: return 2D scale tensor for GEMM (shape[1] access)

This commit is contained in:
2026-05-17 09:59:57 +00:00
parent 3cd910193c
commit 4445882ba7

View File

@@ -227,11 +227,11 @@ class CuTeDSLMoERunner:
blocks = padded_x_sf.view(row_blocks, 128, col_blocks, 4).permute(0, 2, 1, 3)
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
swizzled = rearranged.flatten().view(torch.float8_e4m3fn)
# The GEMM only reads total_padded_rows worth of scale data.
# Return the full swizzled buffer; the GEMM uses expert_offsets to
# determine how many rows each expert gets.
return swizzled
swizzled_flat = rearranged.flatten()
# Return as 2D (total_rows, 32*16=512 cols in swizzled layout)
# The GEMM reads scale_a.shape[1] * sf_vec_size as hidden_padded
total_rows = row_blocks * col_blocks
return swizzled_flat.view(torch.float8_e4m3fn).reshape(total_rows, -1)
def compute_activation_global_scales(self, hidden_states_sample, topk_weights, topk_ids):
"""Compute activation global scales from a warmup forward pass.