fix: use torch.int8 for packed FP4 tensors (kPackedFP4=kInt8, not uint8)

This commit is contained in:
2026-05-12 12:23:43 +00:00
parent b8f95ffad3
commit ca1d306890

View File

@@ -522,7 +522,7 @@ class DeepseekV4MegaMoEExperts(nn.Module):
num_local_experts,
2 * intermediate_size,
hidden_size // 2,
dtype=torch.uint8,
dtype=torch.int8,
),
requires_grad=False,
)
@@ -561,7 +561,7 @@ class DeepseekV4MegaMoEExperts(nn.Module):
num_local_experts,
hidden_size,
intermediate_size // 2,
dtype=torch.uint8,
dtype=torch.int8,
),
requires_grad=False,
)
@@ -804,7 +804,7 @@ class DeepseekV4MegaMoEExperts(nn.Module):
fp4_flat = fp4_val.reshape(E, M, K)
even = fp4_flat[:, :, 0::2]
odd = fp4_flat[:, :, 1::2]
w_packed = ((odd << 4) | even).to(torch.uint8)
w_packed = ((odd << 4) | even).to(torch.uint8).view(torch.int8)
return w_packed, scale_exp
@@ -1767,7 +1767,7 @@ class DeepseekV4Model(nn.Module):
even = fp4_flat[:, 0::2] # lower nibble
odd = fp4_flat[:, 1::2] # upper nibble
packed = (odd << 4) | even
weight_packed = packed.to(torch.uint8)
weight_packed = packed.to(torch.uint8).view(torch.int8)
# Reshape weight_scale to [out, n_blocks]
weight_scale_2d = weight_scale.reshape(out_dim, n_blocks)
@@ -1870,7 +1870,7 @@ class DeepseekV4Model(nn.Module):
mod = getattr(attn, proj_name)
if not hasattr(mod, "weight"):
continue
if mod.weight.dtype == torch.uint8:
if mod.weight.dtype in (torch.uint8, torch.int8):
# NVFP4 -> dequant to bf16 -> requant to FP8
if not diag_printed and layer_idx == 0:
ws = getattr(mod, 'weight_scale', None)
@@ -1892,7 +1892,7 @@ class DeepseekV4Model(nn.Module):
if not hasattr(attn, proj_name):
continue
mod = getattr(attn, proj_name)
if not hasattr(mod, "weight") or mod.weight.dtype != torch.uint8:
if not hasattr(mod, "weight") or mod.weight.dtype not in (torch.uint8, torch.int8):
continue
if not diag_printed and layer_idx == 0:
ws = getattr(mod, 'weight_scale', None)
@@ -1930,7 +1930,7 @@ class DeepseekV4Model(nn.Module):
if not hasattr(ffn.shared_experts, proj_name):
continue
mod = getattr(ffn.shared_experts, proj_name)
if not hasattr(mod, "weight") or mod.weight.dtype != torch.uint8:
if not hasattr(mod, "weight") or mod.weight.dtype not in (torch.uint8, torch.int8):
continue
self._dequant_nvfp4_to_bf16(mod, E2M1_LUT)
bf16_converted += 1