From 5cb4fcaef35ba928fd103e3eaef34cdb20e84f79 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 11 May 2026 06:36:32 +0000 Subject: [PATCH] fix: cast uint8 weights to int8 (kPackedFP4) for DeepGEMM compatibility --- deep_gemm/mega/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/deep_gemm/mega/__init__.py b/deep_gemm/mega/__init__.py index 4178dd7..ced23f8 100644 --- a/deep_gemm/mega/__init__.py +++ b/deep_gemm/mega/__init__.py @@ -175,9 +175,10 @@ def transform_nvfp4_weights_for_mega_moe( # L1: interleave gate/up, then pack + transpose SF for UTCCP l1_interleaved = _interleave_l1_weights((l1_weights[0], l1_sf)) - l1_out = (l1_interleaved[0], _pack_nvfp4_sf_for_utccp(l1_interleaved[1])) + # DeepGEMM expects int8 (kPackedFP4 = torch.kInt8), but NVFP4 weights are uint8 + l1_out = (l1_interleaved[0].view(torch.int8), _pack_nvfp4_sf_for_utccp(l1_interleaved[1])) # L2: only pack + transpose SF for UTCCP - l2_out = (l2_weights[0], _pack_nvfp4_sf_for_utccp(l2_sf)) + l2_out = (l2_weights[0].view(torch.int8), _pack_nvfp4_sf_for_utccp(l2_sf)) return l1_out, l2_out