fix: use torch.int8 for packed FP4 tensors (kPackedFP4=kInt8, not uint8)
This commit is contained in:
@@ -522,7 +522,7 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
num_local_experts,
|
||||
2 * intermediate_size,
|
||||
hidden_size // 2,
|
||||
dtype=torch.uint8,
|
||||
dtype=torch.int8,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
@@ -561,7 +561,7 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
num_local_experts,
|
||||
hidden_size,
|
||||
intermediate_size // 2,
|
||||
dtype=torch.uint8,
|
||||
dtype=torch.int8,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
@@ -804,7 +804,7 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
fp4_flat = fp4_val.reshape(E, M, K)
|
||||
even = fp4_flat[:, :, 0::2]
|
||||
odd = fp4_flat[:, :, 1::2]
|
||||
w_packed = ((odd << 4) | even).to(torch.uint8)
|
||||
w_packed = ((odd << 4) | even).to(torch.uint8).view(torch.int8)
|
||||
|
||||
return w_packed, scale_exp
|
||||
|
||||
@@ -1767,7 +1767,7 @@ class DeepseekV4Model(nn.Module):
|
||||
even = fp4_flat[:, 0::2] # lower nibble
|
||||
odd = fp4_flat[:, 1::2] # upper nibble
|
||||
packed = (odd << 4) | even
|
||||
weight_packed = packed.to(torch.uint8)
|
||||
weight_packed = packed.to(torch.uint8).view(torch.int8)
|
||||
|
||||
# Reshape weight_scale to [out, n_blocks]
|
||||
weight_scale_2d = weight_scale.reshape(out_dim, n_blocks)
|
||||
@@ -1870,7 +1870,7 @@ class DeepseekV4Model(nn.Module):
|
||||
mod = getattr(attn, proj_name)
|
||||
if not hasattr(mod, "weight"):
|
||||
continue
|
||||
if mod.weight.dtype == torch.uint8:
|
||||
if mod.weight.dtype in (torch.uint8, torch.int8):
|
||||
# NVFP4 -> dequant to bf16 -> requant to FP8
|
||||
if not diag_printed and layer_idx == 0:
|
||||
ws = getattr(mod, 'weight_scale', None)
|
||||
@@ -1892,7 +1892,7 @@ class DeepseekV4Model(nn.Module):
|
||||
if not hasattr(attn, proj_name):
|
||||
continue
|
||||
mod = getattr(attn, proj_name)
|
||||
if not hasattr(mod, "weight") or mod.weight.dtype != torch.uint8:
|
||||
if not hasattr(mod, "weight") or mod.weight.dtype not in (torch.uint8, torch.int8):
|
||||
continue
|
||||
if not diag_printed and layer_idx == 0:
|
||||
ws = getattr(mod, 'weight_scale', None)
|
||||
@@ -1930,7 +1930,7 @@ class DeepseekV4Model(nn.Module):
|
||||
if not hasattr(ffn.shared_experts, proj_name):
|
||||
continue
|
||||
mod = getattr(ffn.shared_experts, proj_name)
|
||||
if not hasattr(mod, "weight") or mod.weight.dtype != torch.uint8:
|
||||
if not hasattr(mod, "weight") or mod.weight.dtype not in (torch.uint8, torch.int8):
|
||||
continue
|
||||
self._dequant_nvfp4_to_bf16(mod, E2M1_LUT)
|
||||
bf16_converted += 1
|
||||
|
||||
Reference in New Issue
Block a user