fix: remove .item() sync in scale reshape — use padded_scales.shape[0] instead
This commit is contained in:
@@ -183,8 +183,9 @@ class CuTeDSLMoERunner:
|
||||
padded_scales[dst_rows, :K_sf] = x_sf
|
||||
|
||||
# Apply swizzle to the whole padded tensor, return 2D for 2D-side scale_a
|
||||
- swizzled_flat = pad_and_swizzle_single(padded_scales)
|
||||
- return swizzled_flat.reshape(total_padded_rows.item(), -1)
|
||||
+ # to_blocked preserves element count, so reshape to match padded shape
|
||||
+ swizzled = pad_and_swizzle_single(padded_scales)
|
||||
+ return swizzled.reshape(padded_scales.shape[0], -1)
|
||||
|
||||
def run(self, hidden_states, topk_weights, topk_ids, expert_indices=None):
|
||||
num_tokens, hidden_size = hidden_states.shape
|
||||
|
||||
Reference in New Issue
Block a user