Fix: wo_a as batched matmul (grouped linear for output projection)

This commit is contained in:
2026-05-30 22:48:31 +00:00
parent 19240608d7
commit 52df3bc26c

View File

@@ -245,61 +245,45 @@ def forward_layer(x, w, li, cfg, rope_cos, rope_sin):
attn_out = dsv4_attention(q_heads, k, v) # (n_h, T, hd)
attn_out = attn_out.permute(1, 0, 2).reshape(T, n_h * hd) # (T, n_h*hd)
# ---- Output projection: o_a (BF16 grouped) → o_b (NVFP4) ----
# o_a_proj: grouped linear — input is (T, n_h*hd) reshaped as (T*o_groups, heads_per_group*hd)
# Each group: (heads_per_group * hd) → o_lora_rank
# Then concatenate: (T*o_groups * o_rank) → o_b → (T, H)
oa_w = w[f"{pre}.o_a_proj.weight"] # BF16
# ---- Output projection: wo_a (BF16 batched matmul) → wo_b (NVFP4) ----
# wo_a: grouped linear — input per group: (heads_per_group * hd) → o_lora_rank
# Implemented as batched matmul: (n_groups, heads_per_group*hd) × (n_groups, heads_per_group*hd, o_rank)
oa_w = w[f"{pre}.o_a_proj.weight"] # BF16, stored as (n_groups*o_rank, heads_per_group*hd) or similar
ob_w = w[f"{pre}.o_b_proj.weight"]
ob_s = w[f"{pre}.o_b_proj.weight_scale"]
ob_s2 = w[f"{pre}.o_b_proj.weight_scale_2"]
heads_per_group = n_h // o_groups # 128/16 = 8
group_input_dim = heads_per_group * hd # 8 * 512 = 4096
heads_per_group = n_h // o_groups # 8
group_input_dim = heads_per_group * hd # 4096
# Reshape attention output for grouped projection
# attn_out: (T, n_h * hd) → (T, o_groups, heads_per_group * hd) → (T*o_groups, group_input_dim)
attn_grouped = attn_out.reshape(T, o_groups, group_input_dim).reshape(T * o_groups, group_input_dim)
attn_grouped = attn_out.reshape(T, o_groups, heads_per_group, hd)
attn_grouped = attn_grouped.reshape(T, o_groups, group_input_dim) # (1, 16, 4096)
# o_a: (o_rank, group_input_dim) per group → total (o_groups * o_rank, o_groups * group_input_dim)
# But in the checkpoint it's stored as (o_groups * o_rank, n_h * hd) or similar
# The actual shape tells us: (4096, 16384) BF16
# 4096 = o_rank * o_groups? No, o_rank=1024, o_groups=16 → 16384
# 16384 = n_h * hd / 4? No.
# Let's just check: 4096 output, 16384 input
# 16384 = 4 * 4096 = ??? That doesn't match n_h*hd=65536
# wo_a weight: (n_groups * o_rank, heads_per_group * hd) = (16384, 4096) BF16
# Reshape to (n_groups, o_rank, heads_per_group * hd) for batched matmul
oa_w_bf16 = oa_w.cuda().bfloat16()
oa_shape = oa_w_bf16.shape
# Actually: o_a_proj weight is (4096, 16384). Since it's BF16, no FP4 packing.
# So actual out_features=4096, in_features=16384.
# 16384 = ??? Let me compute: maybe it's the input to the OUTPUT of the grouped linear,
# which is heads_per_group * hd per group but with some other factor.
# Actually 16384 = 32 * 512 = 32 heads * hd? That would be 1/4 of the heads.
# Or: 16384 = o_groups * (heads_per_group * hd) / something
# The weight might be stored transposed or in grouped format
# Try: reshape to (o_groups, o_rank, group_input_dim) for BMM
if oa_shape[0] == o_groups * o_rank and oa_shape[1] == group_input_dim:
# (o_groups * o_rank, group_input_dim) → (o_groups, o_rank, group_input_dim)
oa_3d = oa_w_bf16.reshape(o_groups, o_rank, group_input_dim)
elif oa_shape[1] == o_groups * o_rank and oa_shape[0] == group_input_dim:
# Transposed: (group_input_dim, o_groups * o_rank) → (o_groups, group_input_dim, o_rank) → (o_groups, o_rank, group_input_dim)
oa_3d = oa_w_bf16.reshape(group_input_dim, o_groups, o_rank).permute(1, 2, 0)
else:
# Fallback: just try dense linear
oa_3d = oa_w_bf16.reshape(o_groups, -1, group_input_dim) if oa_w_bf16.shape[-1] == group_input_dim else oa_w_bf16.T.reshape(o_groups, -1, group_input_dim)
# The simplest approach: just try both possible reshapes
try:
# Try 1: treat as standard linear with the full attention output
if oa_w.shape[1] * 1 == n_h * hd: # BF16, no packing
grouped = bf16_linear(attn_out, oa_w.cuda())
elif oa_w.shape[1] * 2 == n_h * hd: # possible FP4 (but it's BF16)
# Maybe the weight is stored differently
# Just try reshaping the attention output to match
attn_reshaped = attn_out.reshape(T, -1)[:, :oa_w.shape[1]]
grouped = bf16_linear(attn_reshaped, oa_w.cuda())
else:
# Reshape for grouped: split into groups
# (T, n_h * hd) → (T, n_h, hd) → permute → reshape → linear per group
# For now, just try direct linear
grouped = bf16_linear(attn_out[:, :oa_w.shape[1]], oa_w.cuda())
except RuntimeError:
# Fallback: pad or truncate
if attn_out.shape[-1] < oa_w.shape[1]:
padded = torch.nn.functional.pad(attn_out, (0, oa_w.shape[1] - attn_out.shape[-1]))
grouped = bf16_linear(padded, oa_w.cuda())
else:
grouped = bf16_linear(attn_out[:, :oa_w.shape[1]], oa_w.cuda())
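# Intended batched matmul: (1, 16, 4096) × (16, 1024, 4096)ᵀ → (1, 16, 1024)
# torch.bmm needs (B, M, K) × (B, K, N) → (B, M, N) with the same B on both operands.
# NOTE(review): attn_grouped has B=T while oa_3d.transpose(1, 2) has B=o_groups; the group axis likely needs to move to dim 0 before bmm — confirm.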
grouped_out = torch.bmm(attn_grouped, oa_3d.transpose(1, 2)) # (1, 16, o_rank)
grouped_flat = grouped_out.reshape(T, o_groups * o_rank) # (1, 16384)
attn_proj = nvfp4_linear(grouped, ob_w, ob_s, ob_s2) # (1, H)
attn_proj = nvfp4_linear(grouped_flat, ob_w, ob_s, ob_s2) # (1, H)
# ---- Residual ----
x = x + attn_proj
@@ -411,6 +395,13 @@ def main():
# lm_head
logits = torch.nn.functional.linear(x, lm_head_w)
# lm_head
logits = torch.nn.functional.linear(x, lm_head_w)
# Debug: check logits
if step == 0:
print(f" logits: shape={logits.shape}, min={logits.float().min().item():.4f}, max={logits.float().max().item():.4f}, has_nan={torch.isnan(logits.float()).any().item()}")
next_id = torch.argmax(logits, dim=-1).item()
generated.append(next_id)