Fix lm_head NVFP4: transpose weight and scales to match Nvfp4Linear checkpoint layout
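quantize_weight_to_nvfp4 returns its packed outputs in (K_packed, N) layout, but Nvfp4Linear expects the checkpoint layout (N, K_packed). Transpose both the packed FP4 weight (fp4) and its scales (sf) before assigning them to the lm_head layer.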
This commit is contained in:
@@ -789,13 +789,21 @@ def main():
|
||||
embed_w = all_w.get("model.embed_tokens.weight")
|
||||
embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to('cuda:0'))
|
||||
# lm_head: quantize to NVFP4 for tensor-core acceleration
|
||||
# Weight is (vocab_size, hidden_size) = (N, K) in BF16
|
||||
# quantize_weight_to_nvfp4 expects (K, N), so transpose first
|
||||
# But Nvfp4Linear expects (N, K_packed) from checkpoint layout (only K is packed)
|
||||
# quantize_weight_to_nvfp4 returns (K//2, N) which IS (K_packed, N)
|
||||
# So we need to transpose the weight, quantize as (K, N),
|
||||
# then the result (K//2, N) needs to be transposed to (N, K//2) for Nvfp4Linear.
|
||||
lm_w_raw = all_w.get("lm_head.weight", embed_w).bfloat16().to('cuda:0')
|
||||
from dsv4.layers.linear import Nvfp4Linear
|
||||
lm_head_lin = Nvfp4Linear(lm_w_raw.shape[1], lm_w_raw.shape[0], max_num_tokens=8192, device='cuda:0')
|
||||
from dsv4.ops.quantize import quantize_to_nvfp4
|
||||
lm_fp4, lm_sf, lm_gs = quantize_to_nvfp4(lm_w_raw.T.contiguous()) # (K, N) for quantize
|
||||
lm_head_lin.fp4 = [lm_fp4]
|
||||
lm_head_lin.sf = [lm_sf]
|
||||
from dsv4.ops.quantize import quantize_weight_to_nvfp4
|
||||
# quantize_weight_to_nvfp4 takes (K, N) → returns (K//2, N), (K//16, N), gs
|
||||
lm_fp4, lm_sf, lm_gs = quantize_weight_to_nvfp4(lm_w_raw.T.contiguous()) # (K//2, N) = (3584, 128K)
|
||||
# Nvfp4Linear expects fp4 in (N, K_packed) layout, so transpose
|
||||
lm_head_lin.fp4 = [lm_fp4.permute(1, 0).contiguous()] # (N, K_packed) = (128K, 3584)
|
||||
lm_head_lin.sf = [lm_sf.permute(1, 0).contiguous()] # (N, K_sf) = (128K, 448)
|
||||
lm_head_lin.gs = [lm_gs] # global scale from weight quantization
|
||||
lm_head_lin.ws2 = [None] # no separate weight_scale_2
|
||||
lm_head_lin._activation_global_scale = 1.0 / (6.0 * 448.0) # placeholder
|
||||
|
||||
Reference in New Issue
Block a user