feat: Nvfp4GroupedLinear for o_a_proj (replaces BF16 grouped BMM)

The first half of the attention output projection (wo_a) previously used
a BF16 grouped BMM (torch.bmm). It now uses the production
Nvfp4GroupedLinear, which performs the same grouped GEMM with NVFP4
tensor-core acceleration on Blackwell.

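If the checkpoint provides NVFP4 weights for o_a_proj, they are currently
dequantized to BF16 and re-quantized via set_bf16_weight(), because
Nvfp4GroupedLinear does not yet have a direct NVFP4 load method. Otherwise
the BF16 checkpoint weight is quantized via set_bf16_weight().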

Also includes:
- NVFP4 gate projection for router (from previous commit)
- Compressor position_bias in CUDA kernel (from earlier fix)
This commit is contained in:
2026-06-01 06:00:36 +00:00
parent cf2b7ab7ec
commit 6c28c57b6a

View File

@@ -402,22 +402,23 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
# 7. Inverse RoPE
attn_out = _apply_rope(attn_out, positions, rope_cos, rope_sin, rd, inverse=True)
# 8. Output: wo_a (BF16 grouped BMM) + wo_b (NVFP4 GEMM)
hpg = n_h // o_groups; gid = hpg * hd
oa_w = w.get(f"{pfx}.o_a_proj.weight")
if oa_w is not None:
oa_bf = oa_w.bfloat16().to(dev); a_flat = attn_out.reshape(T, n_h * hd)
a_grp = a_flat.reshape(T, o_groups, gid); oa_3d = oa_bf.reshape(o_groups, o_rank, gid)
g_out = torch.bmm(a_grp.permute(1, 0, 2), oa_3d.transpose(1, 2))
g_flat = g_out.permute(1, 0, 2).reshape(T, o_groups * o_rank)
if VERBOSE >= 2 and li < 3:
print(f" L{li} wo_a: |g_flat|={g_flat.abs().max().item():.6f} shape={g_flat.shape}", flush=True)
# 8. Output: wo_a (NVFP4 grouped GEMM) + wo_b (NVFP4 GEMM)
wo_a_lin = prod_lin.get('o_a')
if wo_a_lin is not None:
# Nvfp4GroupedLinear: (T, n_h, hd) → (T, n_groups, o_rank) → flatten for o_b
g_3d = wo_a_lin.run(attn_out) # (T, n_groups, o_rank) BF16
g_flat = g_3d.reshape(T, -1) # (T, n_groups * o_rank) BF16
F_attn = prod_lin['o_b'](g_flat)
else:
# o_a_proj as full-rank BF16 linear (no low-rank)
# BF16 grouped BMM fallback (should not happen in production)
hpg_fb = n_h // o_groups; gid_fb = hpg_fb * hd
oa_full = w.get(f"{pfx}.o_a_proj.weight")
if oa_full is not None:
F_attn = F.linear(attn_out.reshape(T, n_h * hd), oa_full.bfloat16().to(dev))
oa_bf = oa_full.bfloat16().to(dev); a_flat = attn_out.reshape(T, n_h * hd)
a_grp = a_flat.reshape(T, o_groups, gid_fb); oa_3d = oa_bf.reshape(o_groups, o_rank, gid_fb)
g_out = torch.bmm(a_grp.permute(1, 0, 2), oa_3d.transpose(1, 2))
g_flat = g_out.permute(1, 0, 2).reshape(T, o_groups * o_rank)
F_attn = prod_lin['o_b'](g_flat)
else:
log.warning(f"L{li}: No o_a_proj weight, zero attention output")
F_attn = torch.zeros(T, cfg["hidden_size"], dtype=torch.bfloat16, device=dev)
@@ -638,7 +639,9 @@ def main():
# q_a_proj: (1536, 3584) uint8 -> in=7168, out=1536
# q_b_proj: (65536, 768) uint8 -> in=1536, out=65536
# kv_proj: (512, 3584) uint8 -> in=7168, out=512
# o_a_proj: (16384, 4096) BF16 -> Nvfp4GroupedLinear (16 groups, 1024×4096 each)
# o_b_proj: (7168, 8192) uint8 -> in=16384, out=7168
from dsv4.layers.grouped_linear import Nvfp4GroupedLinear
for li in range(n_layers):
dev = f"cuda:{li % NUM_GPUS}"; pfx = f"model.layers.{li}.self_attn"
torch.cuda.set_device(li % NUM_GPUS)
@@ -646,10 +649,35 @@ def main():
pl['q_a'] = make_nvfp4_linear(7168, 1536, dev, all_w, pfx, 'q_a_proj')
pl['q_b'] = make_nvfp4_linear(1536, 65536, dev, all_w, pfx, 'q_b_proj')
pl['kv'] = make_nvfp4_linear(7168, 512, dev, all_w, pfx, 'kv_proj')
# o_a_proj: Nvfp4GroupedLinear (NVFP4 grouped GEMM)
n_local_groups = cfg.get('o_groups', 16)
heads_per_group = n_h // n_local_groups
o_rank_val = cfg.get('o_lora_rank', 1024)
wo_a = Nvfp4GroupedLinear(
n_local_groups=n_local_groups,
heads_per_group=heads_per_group,
head_dim=hd,
o_lora_rank=o_rank_val,
max_num_tokens=8192,
device=dev,
)
oa_w_nvfp4, oa_ws, oa_ws2, oa_isc = get_nvfp4_weight(all_w, pfx, 'o_a_proj')
if oa_w_nvfp4 is not None and oa_ws is not None:
# Checkpoint has NVFP4 weights — direct loading is not supported yet (see TODO below)
# TODO: Nvfp4GroupedLinear needs a load_nvfp4_weight method
# For now, dequant and re-quantize via set_bf16_weight
oa_bf16 = dequant_nvfp4(oa_w_nvfp4, oa_ws, oa_ws2, oa_isc).to(dev)
wo_a.set_bf16_weight(oa_bf16)
else:
# BF16 checkpoint weight
oa_bf = all_w.get(f"{pfx}.o_a_proj.weight")
if oa_bf is not None:
wo_a.set_bf16_weight(oa_bf.bfloat16().to(dev))
pl['o_a'] = wo_a
pl['o_b'] = make_nvfp4_linear(16384, 7168, dev, all_w, pfx, 'o_b_proj')
prod_lins[li] = pl
if (li+1) % 10 == 0: print(f" {li+1}/{n_layers} layers")
print(" All attention projections: production NVFP4 GEMM")
print(" All attention projections: production NVFP4 GEMM (o_a now NVFP4 grouped)")
# Routers, MoE, shared experts
routers, moe_runners, se_runners = {}, {}, {}