Fix padded_cols calculation in scale assembly
This commit is contained in:
@@ -236,8 +236,11 @@ class CuTeDSLMoERunner:
|
||||
all_flat = torch.cat([p.view(torch.uint8) for p in swizzled_parts], dim=0)
|
||||
all_flat = all_flat.view(torch.float8_e4m3fn)
|
||||
# Reshape to 2D: (total_padded_rows, padded_cols)
|
||||
# padded_cols comes from the swizzle: ceil_div(K_sf, 4) * 4 * 4
|
||||
# (128 rows per row_block, 4 cols per col_block, 32 sub-rows * 16 sub-cols per block)
|
||||
# Simpler: total elements // total_padded_rows (0 when there are no padded rows)
|
||||
total_padded_rows = padded_expert_offsets[num_experts].item()
|
||||
padded_cols = swizzled_parts[0].shape[1] if swizzled_parts else padded_x_sf_buf.shape[1]
|
||||
padded_cols = all_flat.shape[0] // total_padded_rows if total_padded_rows > 0 else 0
|
||||
return all_flat.reshape(total_padded_rows, padded_cols)
|
||||
|
||||
def compute_activation_global_scales(self, hidden_states_sample, topk_weights, topk_ids):
|
||||
|
||||
Reference in New Issue
Block a user