Fix scale assembly output shape: reshape to 2D for GEMM

This commit is contained in:
2026-05-17 09:57:27 +00:00
parent d9bae6d770
commit 918aa8aede

View File

@@ -235,7 +235,10 @@ class CuTeDSLMoERunner:
all_flat = torch.cat([p.view(torch.uint8) for p in swizzled_parts], dim=0)
all_flat = all_flat.view(torch.float8_e4m3fn)
return all_flat
# Reshape to 2D: (total_padded_rows, padded_cols)
total_padded_rows = padded_expert_offsets[num_experts].item()
padded_cols = swizzled_parts[0].shape[1] if swizzled_parts else padded_x_sf_buf.shape[1]
return all_flat.reshape(total_padded_rows, padded_cols)
def compute_activation_global_scales(self, hidden_states_sample, topk_weights, topk_ids):
"""Compute activation global scales from a warmup forward pass.