From ca1d3068906e61bb48618f2c37c70e026eb20e84 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 12 May 2026 12:23:43 +0000 Subject: [PATCH] fix: use torch.int8 for packed FP4 tensors (kPackedFP4=kInt8, not uint8) --- patches/deepseek_v4.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/patches/deepseek_v4.py b/patches/deepseek_v4.py index b829199..8102682 100644 --- a/patches/deepseek_v4.py +++ b/patches/deepseek_v4.py @@ -522,7 +522,7 @@ class DeepseekV4MegaMoEExperts(nn.Module): num_local_experts, 2 * intermediate_size, hidden_size // 2, - dtype=torch.uint8, + dtype=torch.int8, ), requires_grad=False, ) @@ -561,7 +561,7 @@ class DeepseekV4MegaMoEExperts(nn.Module): num_local_experts, hidden_size, intermediate_size // 2, - dtype=torch.uint8, + dtype=torch.int8, ), requires_grad=False, ) @@ -804,7 +804,7 @@ class DeepseekV4MegaMoEExperts(nn.Module): fp4_flat = fp4_val.reshape(E, M, K) even = fp4_flat[:, :, 0::2] odd = fp4_flat[:, :, 1::2] - w_packed = ((odd << 4) | even).to(torch.uint8) + w_packed = ((odd << 4) | even).to(torch.uint8).view(torch.int8) return w_packed, scale_exp @@ -1767,7 +1767,7 @@ class DeepseekV4Model(nn.Module): even = fp4_flat[:, 0::2] # lower nibble odd = fp4_flat[:, 1::2] # upper nibble packed = (odd << 4) | even - weight_packed = packed.to(torch.uint8) + weight_packed = packed.to(torch.uint8).view(torch.int8) # Reshape weight_scale to [out, n_blocks] weight_scale_2d = weight_scale.reshape(out_dim, n_blocks) @@ -1870,7 +1870,7 @@ class DeepseekV4Model(nn.Module): mod = getattr(attn, proj_name) if not hasattr(mod, "weight"): continue - if mod.weight.dtype == torch.uint8: + if mod.weight.dtype in (torch.uint8, torch.int8): # NVFP4 -> dequant to bf16 -> requant to FP8 if not diag_printed and layer_idx == 0: ws = getattr(mod, 'weight_scale', None) @@ -1892,7 +1892,7 @@ class DeepseekV4Model(nn.Module): if not hasattr(attn, proj_name): continue mod = getattr(attn, proj_name) - if not hasattr(mod, "weight") or mod.weight.dtype != torch.uint8: + if not hasattr(mod, "weight") or mod.weight.dtype not in (torch.uint8, torch.int8): continue if not diag_printed and layer_idx == 0: ws = getattr(mod, 'weight_scale', None) @@ -1930,7 +1930,7 @@ class DeepseekV4Model(nn.Module): if not hasattr(ffn.shared_experts, proj_name): continue mod = getattr(ffn.shared_experts, proj_name) - if not hasattr(mod, "weight") or mod.weight.dtype != torch.uint8: + if not hasattr(mod, "weight") or mod.weight.dtype not in (torch.uint8, torch.int8): continue self._dequant_nvfp4_to_bf16(mod, E2M1_LUT) bf16_converted += 1