fix: cast uint8 weights to int8 (kPackedFP4) for DeepGEMM compatibility
This commit is contained in:
@@ -175,9 +175,10 @@ def transform_nvfp4_weights_for_mega_moe(
|
||||
|
||||
# L1: interleave gate/up, then pack + transpose SF for UTCCP
|
||||
l1_interleaved = _interleave_l1_weights((l1_weights[0], l1_sf))
|
||||
l1_out = (l1_interleaved[0], _pack_nvfp4_sf_for_utccp(l1_interleaved[1]))
|
||||
# DeepGEMM expects int8 (kPackedFP4 = torch.kInt8), but NVFP4 weights are uint8
|
||||
l1_out = (l1_interleaved[0].view(torch.int8), _pack_nvfp4_sf_for_utccp(l1_interleaved[1]))
|
||||
# L2: only pack + transpose SF for UTCCP
|
||||
l2_out = (l2_weights[0], _pack_nvfp4_sf_for_utccp(l2_sf))
|
||||
l2_out = (l2_weights[0].view(torch.int8), _pack_nvfp4_sf_for_utccp(l2_sf))
|
||||
return l1_out, l2_out
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user