fix: remove .item() sync in scale reshape — use padded_scales.shape[0] instead

This commit is contained in:
2026-05-16 18:29:12 +00:00
parent 5a79065b2b
commit 4300775bfe

View File

@@ -183,8 +183,9 @@ class CuTeDSLMoERunner:
padded_scales[dst_rows, :K_sf] = x_sf
# Apply swizzle to the whole padded tensor, return 2D for 2D-side scale_a
swizzled_flat = pad_and_swizzle_single(padded_scales)
return swizzled_flat.reshape(total_padded_rows.item(), -1)
# to_blocked preserves element count, so reshape to match padded shape
swizzled = pad_and_swizzle_single(padded_scales)
return swizzled.reshape(padded_scales.shape[0], -1)
def run(self, hidden_states, topk_weights, topk_ids, expert_indices=None):
num_tokens, hidden_size = hidden_states.shape