diff --git a/single_shot_inference.py b/single_shot_inference.py index 4d4d8a24..fdddb012 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -402,22 +402,23 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, # 7. Inverse RoPE attn_out = _apply_rope(attn_out, positions, rope_cos, rope_sin, rd, inverse=True) - # 8. Output: wo_a (BF16 grouped BMM) + wo_b (NVFP4 GEMM) - hpg = n_h // o_groups; gid = hpg * hd - oa_w = w.get(f"{pfx}.o_a_proj.weight") - if oa_w is not None: - oa_bf = oa_w.bfloat16().to(dev); a_flat = attn_out.reshape(T, n_h * hd) - a_grp = a_flat.reshape(T, o_groups, gid); oa_3d = oa_bf.reshape(o_groups, o_rank, gid) - g_out = torch.bmm(a_grp.permute(1, 0, 2), oa_3d.transpose(1, 2)) - g_flat = g_out.permute(1, 0, 2).reshape(T, o_groups * o_rank) - if VERBOSE >= 2 and li < 3: - print(f" L{li} wo_a: |g_flat|={g_flat.abs().max().item():.6f} shape={g_flat.shape}", flush=True) + # 8. Output: wo_a (NVFP4 grouped GEMM) + wo_b (NVFP4 GEMM) + wo_a_lin = prod_lin.get('o_a') + if wo_a_lin is not None: + # Nvfp4GroupedLinear: (T, n_h, hd) → (T, n_groups, o_rank) → flatten for o_b + g_3d = wo_a_lin.run(attn_out) # (T, n_groups, o_rank) BF16 + g_flat = g_3d.reshape(T, -1) # (T, n_groups * o_rank) BF16 F_attn = prod_lin['o_b'](g_flat) else: - # o_a_proj as full-rank BF16 linear (no low-rank) + # BF16 grouped BMM fallback (should not happen in production) + hpg_fb = n_h // o_groups; gid_fb = hpg_fb * hd oa_full = w.get(f"{pfx}.o_a_proj.weight") if oa_full is not None: - F_attn = F.linear(attn_out.reshape(T, n_h * hd), oa_full.bfloat16().to(dev)) + oa_bf = oa_full.bfloat16().to(dev); a_flat = attn_out.reshape(T, n_h * hd) + a_grp = a_flat.reshape(T, o_groups, gid_fb); oa_3d = oa_bf.reshape(o_groups, o_rank, gid_fb) + g_out = torch.bmm(a_grp.permute(1, 0, 2), oa_3d.transpose(1, 2)) + g_flat = g_out.permute(1, 0, 2).reshape(T, o_groups * o_rank) + F_attn = prod_lin['o_b'](g_flat) else: log.warning(f"L{li}: No o_a_proj weight, zero attention output") F_attn = torch.zeros(T, cfg["hidden_size"], dtype=torch.bfloat16, device=dev) @@ -638,7 +639,9 @@ def main(): # q_a_proj: (1536, 3584) uint8 -> in=7168, out=1536 # q_b_proj: (65536, 768) uint8 -> in=1536, out=65536 # kv_proj: (512, 3584) uint8 -> in=7168, out=512 + # o_a_proj: (16384, 4096) BF16 -> Nvfp4GroupedLinear (16 groups, 1024×4096 each) # o_b_proj: (7168, 8192) uint8 -> in=16384, out=7168 + from dsv4.layers.grouped_linear import Nvfp4GroupedLinear for li in range(n_layers): dev = f"cuda:{li % NUM_GPUS}"; pfx = f"model.layers.{li}.self_attn" torch.cuda.set_device(li % NUM_GPUS) @@ -646,10 +649,35 @@ def main(): pl['q_a'] = make_nvfp4_linear(7168, 1536, dev, all_w, pfx, 'q_a_proj') pl['q_b'] = make_nvfp4_linear(1536, 65536, dev, all_w, pfx, 'q_b_proj') pl['kv'] = make_nvfp4_linear(7168, 512, dev, all_w, pfx, 'kv_proj') + # o_a_proj: Nvfp4GroupedLinear (NVFP4 grouped GEMM) + n_local_groups = cfg.get('o_groups', 16) + heads_per_group = n_h // n_local_groups + o_rank_val = cfg.get('o_lora_rank', 1024) + wo_a = Nvfp4GroupedLinear( + n_local_groups=n_local_groups, + heads_per_group=heads_per_group, + head_dim=hd, + o_lora_rank=o_rank_val, + max_num_tokens=8192, + device=dev, + ) + oa_w_nvfp4, oa_ws, oa_ws2, oa_isc = get_nvfp4_weight(all_w, pfx, 'o_a_proj') + if oa_w_nvfp4 is not None and oa_ws is not None: + # Checkpoint has NVFP4 weights — load directly + # TODO: Nvfp4GroupedLinear needs a load_nvfp4_weight method + # For now, dequant and re-quantize via set_bf16_weight + oa_bf16 = dequant_nvfp4(oa_w_nvfp4, oa_ws, oa_ws2, oa_isc).to(dev) + wo_a.set_bf16_weight(oa_bf16) + else: + # BF16 checkpoint weight + oa_bf = all_w.get(f"{pfx}.o_a_proj.weight") + if oa_bf is not None: + wo_a.set_bf16_weight(oa_bf.bfloat16().to(dev)) + pl['o_a'] = wo_a pl['o_b'] = make_nvfp4_linear(16384, 7168, dev, all_w, pfx, 'o_b_proj') prod_lins[li] = pl if (li+1) % 10 == 0: print(f" {li+1}/{n_layers} layers") - print(" All attention projections: production NVFP4 GEMM") + print(" All attention projections: production NVFP4 GEMM (o_a now NVFP4 grouped)") # Routers, MoE, shared experts routers, moe_runners, se_runners = {}, {}, {}