feat: Nvfp4GroupedLinear for o_a_proj (replaces BF16 grouped BMM)
The attention output projection first half (wo_a) was using BF16 grouped BMM (torch.bmm). Now uses production Nvfp4GroupedLinear which performs the same grouped GEMM with NVFP4 tensor-core acceleration on Blackwell. The weight is loaded from NVFP4 checkpoint if available, otherwise quantized from BF16 via set_bf16_weight(). Also includes: - NVFP4 gate projection for router (from previous commit) - Compressor position_bias in CUDA kernel (from earlier fix)
This commit is contained in:
@@ -402,22 +402,23 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
|
||||
# 7. Inverse RoPE
|
||||
attn_out = _apply_rope(attn_out, positions, rope_cos, rope_sin, rd, inverse=True)
|
||||
|
||||
# 8. Output: wo_a (BF16 grouped BMM) + wo_b (NVFP4 GEMM)
|
||||
hpg = n_h // o_groups; gid = hpg * hd
|
||||
oa_w = w.get(f"{pfx}.o_a_proj.weight")
|
||||
if oa_w is not None:
|
||||
oa_bf = oa_w.bfloat16().to(dev); a_flat = attn_out.reshape(T, n_h * hd)
|
||||
a_grp = a_flat.reshape(T, o_groups, gid); oa_3d = oa_bf.reshape(o_groups, o_rank, gid)
|
||||
g_out = torch.bmm(a_grp.permute(1, 0, 2), oa_3d.transpose(1, 2))
|
||||
g_flat = g_out.permute(1, 0, 2).reshape(T, o_groups * o_rank)
|
||||
if VERBOSE >= 2 and li < 3:
|
||||
print(f" L{li} wo_a: |g_flat|={g_flat.abs().max().item():.6f} shape={g_flat.shape}", flush=True)
|
||||
# 8. Output: wo_a (NVFP4 grouped GEMM) + wo_b (NVFP4 GEMM)
|
||||
wo_a_lin = prod_lin.get('o_a')
|
||||
if wo_a_lin is not None:
|
||||
# Nvfp4GroupedLinear: (T, n_h, hd) → (T, n_groups, o_rank) → flatten for o_b
|
||||
g_3d = wo_a_lin.run(attn_out) # (T, n_groups, o_rank) BF16
|
||||
g_flat = g_3d.reshape(T, -1) # (T, n_groups * o_rank) BF16
|
||||
F_attn = prod_lin['o_b'](g_flat)
|
||||
else:
|
||||
# o_a_proj as full-rank BF16 linear (no low-rank)
|
||||
# BF16 grouped BMM fallback (should not happen in production)
|
||||
hpg_fb = n_h // o_groups; gid_fb = hpg_fb * hd
|
||||
oa_full = w.get(f"{pfx}.o_a_proj.weight")
|
||||
if oa_full is not None:
|
||||
F_attn = F.linear(attn_out.reshape(T, n_h * hd), oa_full.bfloat16().to(dev))
|
||||
oa_bf = oa_full.bfloat16().to(dev); a_flat = attn_out.reshape(T, n_h * hd)
|
||||
a_grp = a_flat.reshape(T, o_groups, gid_fb); oa_3d = oa_bf.reshape(o_groups, o_rank, gid_fb)
|
||||
g_out = torch.bmm(a_grp.permute(1, 0, 2), oa_3d.transpose(1, 2))
|
||||
g_flat = g_out.permute(1, 0, 2).reshape(T, o_groups * o_rank)
|
||||
F_attn = prod_lin['o_b'](g_flat)
|
||||
else:
|
||||
log.warning(f"L{li}: No o_a_proj weight, zero attention output")
|
||||
F_attn = torch.zeros(T, cfg["hidden_size"], dtype=torch.bfloat16, device=dev)
|
||||
@@ -638,7 +639,9 @@ def main():
|
||||
# q_a_proj: (1536, 3584) uint8 -> in=7168, out=1536
|
||||
# q_b_proj: (65536, 768) uint8 -> in=1536, out=65536
|
||||
# kv_proj: (512, 3584) uint8 -> in=7168, out=512
|
||||
# o_a_proj: (16384, 4096) BF16 -> Nvfp4GroupedLinear (16 groups, 1024×4096 each)
|
||||
# o_b_proj: (7168, 8192) uint8 -> in=16384, out=7168
|
||||
from dsv4.layers.grouped_linear import Nvfp4GroupedLinear
|
||||
for li in range(n_layers):
|
||||
dev = f"cuda:{li % NUM_GPUS}"; pfx = f"model.layers.{li}.self_attn"
|
||||
torch.cuda.set_device(li % NUM_GPUS)
|
||||
@@ -646,10 +649,35 @@ def main():
|
||||
pl['q_a'] = make_nvfp4_linear(7168, 1536, dev, all_w, pfx, 'q_a_proj')
|
||||
pl['q_b'] = make_nvfp4_linear(1536, 65536, dev, all_w, pfx, 'q_b_proj')
|
||||
pl['kv'] = make_nvfp4_linear(7168, 512, dev, all_w, pfx, 'kv_proj')
|
||||
# o_a_proj: Nvfp4GroupedLinear (NVFP4 grouped GEMM)
|
||||
n_local_groups = cfg.get('o_groups', 16)
|
||||
heads_per_group = n_h // n_local_groups
|
||||
o_rank_val = cfg.get('o_lora_rank', 1024)
|
||||
wo_a = Nvfp4GroupedLinear(
|
||||
n_local_groups=n_local_groups,
|
||||
heads_per_group=heads_per_group,
|
||||
head_dim=hd,
|
||||
o_lora_rank=o_rank_val,
|
||||
max_num_tokens=8192,
|
||||
device=dev,
|
||||
)
|
||||
oa_w_nvfp4, oa_ws, oa_ws2, oa_isc = get_nvfp4_weight(all_w, pfx, 'o_a_proj')
|
||||
if oa_w_nvfp4 is not None and oa_ws is not None:
|
||||
# Checkpoint has NVFP4 weights — load directly
|
||||
# TODO: Nvfp4GroupedLinear needs a load_nvfp4_weight method
|
||||
# For now, dequant and re-quantize via set_bf16_weight
|
||||
oa_bf16 = dequant_nvfp4(oa_w_nvfp4, oa_ws, oa_ws2, oa_isc).to(dev)
|
||||
wo_a.set_bf16_weight(oa_bf16)
|
||||
else:
|
||||
# BF16 checkpoint weight
|
||||
oa_bf = all_w.get(f"{pfx}.o_a_proj.weight")
|
||||
if oa_bf is not None:
|
||||
wo_a.set_bf16_weight(oa_bf.bfloat16().to(dev))
|
||||
pl['o_a'] = wo_a
|
||||
pl['o_b'] = make_nvfp4_linear(16384, 7168, dev, all_w, pfx, 'o_b_proj')
|
||||
prod_lins[li] = pl
|
||||
if (li+1) % 10 == 0: print(f" {li+1}/{n_layers} layers")
|
||||
print(" All attention projections: production NVFP4 GEMM")
|
||||
print(" All attention projections: production NVFP4 GEMM (o_a now NVFP4 grouped)")
|
||||
|
||||
# Routers, MoE, shared experts
|
||||
routers, moe_runners, se_runners = {}, {}, {}
|
||||
|
||||
Reference in New Issue
Block a user