# Patch summary (commentary only; text ahead of the "diff --git" header is
# skipped by patch tools and is not part of the change itself).
#
# Scope: one hunk inside main() of single_shot_inference.py, in the code that
# builds the NVFP4-quantized lm_head.
#
# What the hunk changes:
#   * The quantization helper imported from dsv4.ops.quantize changes from
#     quantize_to_nvfp4 to quantize_weight_to_nvfp4. The argument is unchanged:
#     lm_w_raw.T.contiguous().
#   * The returned packed tensor and scale-factor tensor are no longer stored
#     as returned; each is passed through .permute(1, 0).contiguous() before
#     being assigned to lm_head_lin.fp4 / lm_head_lin.sf.
#   * lm_head_lin.gs, lm_head_lin.ws2 and _activation_global_scale are
#     untouched context lines.
#
# Unchanged context worth knowing when reading the hunk:
#   * lm_w_raw is all_w["lm_head.weight"] when present and otherwise falls
#     back to embed_w (the embedding weight), cast to bfloat16 on cuda:0.
#   * Nvfp4Linear is constructed with (lm_w_raw.shape[1], lm_w_raw.shape[0]).
#
# NOTE(review): the layout statements in the added comments, i.e. that
# quantize_weight_to_nvfp4 takes (K, N) and returns (K//2, N), (K//16, N), gs,
# and that Nvfp4Linear wants (N, K_packed), are the patch author's claims.
# Neither helper is visible in this hunk; confirm against dsv4.ops.quantize
# and dsv4.layers.linear.
# NOTE(review): the concrete sizes in the comments (3584, 448, "128K") look
# like one specific model's dimensions (presumably K = 7168); they are not
# enforced by any code in this hunk.
# NOTE(review): the six added comment lines read as step-by-step reasoning;
# the net procedure they describe is: transpose weight to (K, N), quantize,
# then transpose the packed result and scale factors back to N-major.
# NOTE(review): _activation_global_scale is explicitly marked "placeholder"
# in the source; presumably it is calibrated elsewhere. TODO confirm.
# NOTE(review): the indentation of the hunk lines below is assumed to be
# 4 spaces (body of main()); verify against the real file before applying.
diff --git a/single_shot_inference.py b/single_shot_inference.py
index 869ff861..a1ce5512 100644
--- a/single_shot_inference.py
+++ b/single_shot_inference.py
@@ -789,13 +789,21 @@ def main():
     embed_w = all_w.get("model.embed_tokens.weight")
     embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to('cuda:0'))
     # lm_head: quantize to NVFP4 for tensor-core acceleration
+    # Weight is (vocab_size, hidden_size) = (N, K) in BF16
+    # quantize_weight_to_nvfp4 expects (K, N), so transpose first
+    # But Nvfp4Linear expects (N_packed, K_packed) from checkpoint layout
+    # quantize_weight_to_nvfp4 returns (K//2, N) which IS (K_packed, N)
+    # So we need to transpose the weight, quantize as (K, N),
+    # then the result (K//2, N) needs to be transposed to (N, K//2) for Nvfp4Linear.
     lm_w_raw = all_w.get("lm_head.weight", embed_w).bfloat16().to('cuda:0')
     from dsv4.layers.linear import Nvfp4Linear
     lm_head_lin = Nvfp4Linear(lm_w_raw.shape[1], lm_w_raw.shape[0], max_num_tokens=8192, device='cuda:0')
-    from dsv4.ops.quantize import quantize_to_nvfp4
-    lm_fp4, lm_sf, lm_gs = quantize_to_nvfp4(lm_w_raw.T.contiguous()) # (K, N) for quantize
-    lm_head_lin.fp4 = [lm_fp4]
-    lm_head_lin.sf = [lm_sf]
+    from dsv4.ops.quantize import quantize_weight_to_nvfp4
+    # quantize_weight_to_nvfp4 takes (K, N) → returns (K//2, N), (K//16, N), gs
+    lm_fp4, lm_sf, lm_gs = quantize_weight_to_nvfp4(lm_w_raw.T.contiguous()) # (K//2, N) = (3584, 128K)
+    # Nvfp4Linear expects fp4 in (N_packed, K_packed) layout, so transpose
+    lm_head_lin.fp4 = [lm_fp4.permute(1, 0).contiguous()] # (N, K_packed) = (128K, 3584)
+    lm_head_lin.sf = [lm_sf.permute(1, 0).contiguous()] # (N, K_sf) = (128K, 448)
     lm_head_lin.gs = [lm_gs] # global scale from weight quantization
     lm_head_lin.ws2 = [None] # no separate weight_scale_2
     lm_head_lin._activation_global_scale = 1.0 / (6.0 * 448.0) # placeholder