diff --git a/deep_gemm/mega/__init__.py b/deep_gemm/mega/__init__.py index 4178dd7..ced23f8 100644 --- a/deep_gemm/mega/__init__.py +++ b/deep_gemm/mega/__init__.py @@ -175,9 +175,10 @@ def transform_nvfp4_weights_for_mega_moe( # L1: interleave gate/up, then pack + transpose SF for UTCCP l1_interleaved = _interleave_l1_weights((l1_weights[0], l1_sf)) - l1_out = (l1_interleaved[0], _pack_nvfp4_sf_for_utccp(l1_interleaved[1])) + # DeepGEMM expects int8 (kPackedFP4 = torch.kInt8), but NVFP4 weights are uint8 + l1_out = (l1_interleaved[0].view(torch.int8), _pack_nvfp4_sf_for_utccp(l1_interleaved[1])) # L2: only pack + transpose SF for UTCCP - l2_out = (l2_weights[0], _pack_nvfp4_sf_for_utccp(l2_sf)) + l2_out = (l2_weights[0].view(torch.int8), _pack_nvfp4_sf_for_utccp(l2_sf)) return l1_out, l2_out