Fix: return 2D scale tensor for GEMM (shape[1] access)
This commit is contained in:
@@ -227,11 +227,11 @@ class CuTeDSLMoERunner:
|
||||
|
||||
blocks = padded_x_sf.view(row_blocks, 128, col_blocks, 4).permute(0, 2, 1, 3)
|
||||
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
|
||||
swizzled = rearranged.flatten().view(torch.float8_e4m3fn)
|
||||
# The GEMM only reads total_padded_rows worth of scale data.
|
||||
# Return the full swizzled buffer; the GEMM uses expert_offsets to
|
||||
# determine how many rows each expert gets.
|
||||
return swizzled
|
||||
swizzled_flat = rearranged.flatten()
|
||||
# Return as 2D (total_rows, 32*16=512 cols in swizzled layout)
|
||||
# The GEMM reads scale_a.shape[1] * sf_vec_size as hidden_padded
|
||||
total_rows = row_blocks * col_blocks
|
||||
return swizzled_flat.view(torch.float8_e4m3fn).reshape(total_rows, -1)
|
||||
|
||||
def compute_activation_global_scales(self, hidden_states_sample, topk_weights, topk_ids):
|
||||
"""Compute activation global scales from a warmup forward pass.
|
||||
|
||||
Reference in New Issue
Block a user