fix: transpose SF to MN-major layout before TMA stride checks
transform_sf_into_required_layout expects MN-major input (stride(-2)=1). Our packed int32 SF is K-major (stride(-1)=1). Transpose the last two dims, make contiguous, then transpose back so data is in MN-major order.
This commit is contained in:
@@ -182,18 +182,17 @@ def transform_nvfp4_weights_for_mega_moe(
|
||||
l1_sf_packed = pack_ue4m3_to_int32(l1_sf)
|
||||
l2_sf_packed = pack_ue4m3_to_int32(l2_sf)
|
||||
|
||||
# Reshape to 2D for transform_sf_into_required_layout
|
||||
# (experts, mn, K//64) → (experts * mn, K//64)
|
||||
# The C++ function expects 2D or properly-strided 3D tensors
|
||||
l1_sf_2d = l1_sf_packed.reshape(-1, l1_sf_packed.shape[-1])
|
||||
l2_sf_2d = l2_sf_packed.reshape(-1, l2_sf_packed.shape[-1])
|
||||
# Transpose to MN-major layout (stride(-2)=1) and make contiguous
|
||||
# transform_sf_into_required_layout expects MN-major input for TMA stride checks
|
||||
l1_sf_mn = l1_sf_packed.transpose(-2, -1).contiguous().transpose(-2, -1)
|
||||
l2_sf_mn = l2_sf_packed.transpose(-2, -1).contiguous().transpose(-2, -1)
|
||||
|
||||
# Transform SF into TMA-aligned UTCCP layout using DeepGEMM's C++ function
|
||||
# recipe (1, 16): gran_mn=1, gran_k=16
|
||||
l1_sf_transformed = transform_sf_into_required_layout(
|
||||
l1_sf_2d, l1_n, l1_k, (1, 16), num_experts)
|
||||
l1_sf_mn, l1_n, l1_k, (1, 16), num_experts)
|
||||
l2_sf_transformed = transform_sf_into_required_layout(
|
||||
l2_sf_2d, l2_n, l2_k, (1, 16), num_experts)
|
||||
l2_sf_mn, l2_n, l2_k, (1, 16), num_experts)
|
||||
|
||||
# L1: interleave gate/up
|
||||
l1_interleaved = _interleave_l1_weights((l1_weights[0], l1_sf_packed))
|
||||
|
||||
Reference in New Issue
Block a user