diff --git a/deep_gemm/mega/__init__.py b/deep_gemm/mega/__init__.py index 824fcd9..ea2a103 100644 --- a/deep_gemm/mega/__init__.py +++ b/deep_gemm/mega/__init__.py @@ -196,6 +196,10 @@ def transform_nvfp4_weights_for_mega_moe( l1_sf_packed = pack_ue4m3_to_int32(l1_sf_32) l2_sf_packed = pack_ue4m3_to_int32(l2_sf_32) + print(f"[NVFP4-MoE] l1_sf_32: shape={l1_sf_32.shape}, l1_sf_packed: shape={l1_sf_packed.shape}") + print(f"[NVFP4-MoE] l2_sf_32: shape={l2_sf_32.shape}, l2_sf_packed: shape={l2_sf_packed.shape}") + print(f"[NVFP4-MoE] l1_n={l1_n} l1_k={l1_k} l2_n={l2_n} l2_k={l2_k}") + # Transpose to MN-major layout (stride(-2)=1) and make contiguous # transform_sf_into_required_layout expects MN-major input for TMA stride checks l1_sf_mn = l1_sf_packed.transpose(-2, -1).contiguous().transpose(-2, -1)