fix: cast uint8 weights to int8 (kPackedFP4) for DeepGEMM compatibility

This commit is contained in:
2026-05-11 06:36:32 +00:00
parent aa9e53d5b2
commit 5cb4fcaef3

View File

@@ -175,9 +175,10 @@ def transform_nvfp4_weights_for_mega_moe(
# L1: interleave gate/up, then pack + transpose SF for UTCCP
l1_interleaved = _interleave_l1_weights((l1_weights[0], l1_sf))
l1_out = (l1_interleaved[0], _pack_nvfp4_sf_for_utccp(l1_interleaved[1]))
# DeepGEMM expects int8 (kPackedFP4 = torch.kInt8), but NVFP4 weights are uint8
l1_out = (l1_interleaved[0].view(torch.int8), _pack_nvfp4_sf_for_utccp(l1_interleaved[1]))
# L2: only pack + transpose SF for UTCCP
l2_out = (l2_weights[0], _pack_nvfp4_sf_for_utccp(l2_sf))
l2_out = (l2_weights[0].view(torch.int8), _pack_nvfp4_sf_for_utccp(l2_sf))
return l1_out, l2_out