From 52df3bc26c9fbff5d4ffb2e7b4643331455b27e0 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 30 May 2026 22:48:31 +0000 Subject: [PATCH] Fix: wo_a as batched matmul (grouped linear for output projection) --- single_shot_inference.py | 79 ++++++++++++++++++---------------------- 1 file changed, 35 insertions(+), 44 deletions(-) diff --git a/single_shot_inference.py b/single_shot_inference.py index 974ec96e..b629f51b 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -245,61 +245,45 @@ def forward_layer(x, w, li, cfg, rope_cos, rope_sin): attn_out = dsv4_attention(q_heads, k, v) # (n_h, T, hd) attn_out = attn_out.permute(1, 0, 2).reshape(T, n_h * hd) # (T, n_h*hd) - # ---- Output projection: o_a (BF16 grouped) → o_b (NVFP4) ---- - # o_a_proj: grouped linear — input is (T, n_h*hd) reshaped as (T*o_groups, heads_per_group*hd) - # Each group: (heads_per_group * hd) → o_lora_rank - # Then concatenate: (T*o_groups * o_rank) → o_b → (T, H) - oa_w = w[f"{pre}.o_a_proj.weight"] # BF16 + # ---- Output projection: wo_a (BF16 batched matmul) → wo_b (NVFP4) ---- + # wo_a: grouped linear — input per group: (heads_per_group * hd) → o_lora_rank + # Implemented as batched matmul: (n_groups, heads_per_group*hd) × (n_groups, heads_per_group*hd, o_rank) + oa_w = w[f"{pre}.o_a_proj.weight"] # BF16, stored as (n_groups*o_rank, heads_per_group*hd) or similar ob_w = w[f"{pre}.o_b_proj.weight"] ob_s = w[f"{pre}.o_b_proj.weight_scale"] ob_s2 = w[f"{pre}.o_b_proj.weight_scale_2"] - heads_per_group = n_h // o_groups # 128/16 = 8 - group_input_dim = heads_per_group * hd # 8 * 512 = 4096 + heads_per_group = n_h // o_groups # 8 + group_input_dim = heads_per_group * hd # 4096 # Reshape attention output for grouped projection # attn_out: (T, n_h * hd) → (T, o_groups, heads_per_group * hd) → (T*o_groups, group_input_dim) - attn_grouped = attn_out.reshape(T, o_groups, group_input_dim).reshape(T * o_groups, group_input_dim) + attn_grouped = attn_out.reshape(T, o_groups, heads_per_group, hd) + attn_grouped = attn_grouped.reshape(T, o_groups, group_input_dim) # (1, 16, 4096) - # o_a: (o_rank, group_input_dim) per group → total (o_groups * o_rank, o_groups * group_input_dim) - # But in the checkpoint it's stored as (o_groups * o_rank, n_h * hd) or similar - # The actual shape tells us: (4096, 16384) BF16 - # 4096 = o_rank * o_groups? No, o_rank=1024, o_groups=16 → 16384 - # 16384 = n_h * hd / 4? No. - # Let's just check: 4096 output, 16384 input - # 16384 = 4 * 4096 = ??? That doesn't match n_h*hd=65536 + # wo_a weight: (n_groups * o_rank, heads_per_group * hd) = (16384, 4096) BF16 + # Reshape to (n_groups, o_rank, heads_per_group * hd) for batched matmul + oa_w_bf16 = oa_w.cuda().bfloat16() + oa_shape = oa_w_bf16.shape - # Actually: o_a_proj weight is (4096, 16384). Since it's BF16, no FP4 packing. - # So actual out_features=4096, in_features=16384. - # 16384 = ??? Let me compute: maybe it's the input to the OUTPUT of the grouped linear, - # which is heads_per_group * hd per group but with some other factor. - # Actually 16384 = 32 * 512 = 32 heads * hd? That would be 1/4 of the heads. - # Or: 16384 = o_groups * (heads_per_group * hd) / something + # The weight might be stored transposed or in grouped format + # Try: reshape to (o_groups, o_rank, group_input_dim) for BMM + if oa_shape[0] == o_groups * o_rank and oa_shape[1] == group_input_dim: + # (o_groups * o_rank, group_input_dim) → (o_groups, o_rank, group_input_dim) + oa_3d = oa_w_bf16.reshape(o_groups, o_rank, group_input_dim) + elif oa_shape[1] == o_groups * o_rank and oa_shape[0] == group_input_dim: + # Transposed: (group_input_dim, o_groups * o_rank) → (o_groups, group_input_dim, o_rank) → (o_groups, o_rank, group_input_dim) + oa_3d = oa_w_bf16.reshape(group_input_dim, o_groups, o_rank).permute(1, 2, 0) + else: + # Fallback: just try dense linear + oa_3d = oa_w_bf16.reshape(o_groups, -1, group_input_dim) if oa_w_bf16.shape[-1] == group_input_dim else oa_w_bf16.T.reshape(o_groups, -1, group_input_dim) - # The simplest approach: just try both possible reshapes - try: - # Try 1: treat as standard linear with the full attention output - if oa_w.shape[1] * 1 == n_h * hd: # BF16, no packing - grouped = bf16_linear(attn_out, oa_w.cuda()) - elif oa_w.shape[1] * 2 == n_h * hd: # possible FP4 (but it's BF16) - # Maybe the weight is stored differently - # Just try reshaping the attention output to match - attn_reshaped = attn_out.reshape(T, -1)[:, :oa_w.shape[1]] - grouped = bf16_linear(attn_reshaped, oa_w.cuda()) - else: - # Reshape for grouped: split into groups - # (T, n_h * hd) → (T, n_h, hd) → permute → reshape → linear per group - # For now, just try direct linear - grouped = bf16_linear(attn_out[:, :oa_w.shape[1]], oa_w.cuda()) - except RuntimeError: - # Fallback: pad or truncate - if attn_out.shape[-1] < oa_w.shape[1]: - padded = torch.nn.functional.pad(attn_out, (0, oa_w.shape[1] - attn_out.shape[-1])) - grouped = bf16_linear(padded, oa_w.cuda()) - else: - grouped = bf16_linear(attn_out[:, :oa_w.shape[1]], oa_w.cuda()) + # Batched matmul: (1, 16, 4096) × (16, 1024, 4096)ᵀ → (1, 16, 1024) + # torch.bmm needs (B, M, K) × (B, K, N) → (B, M, N) + grouped_out = torch.bmm(attn_grouped, oa_3d.transpose(1, 2)) # (1, 16, o_rank) + grouped_flat = grouped_out.reshape(T, o_groups * o_rank) # (1, 16384) - attn_proj = nvfp4_linear(grouped, ob_w, ob_s, ob_s2) # (1, H) + attn_proj = nvfp4_linear(grouped_flat, ob_w, ob_s, ob_s2) # (1, H) # ---- Residual ---- x = x + attn_proj @@ -411,6 +395,13 @@ def main(): # lm_head logits = torch.nn.functional.linear(x, lm_head_w) + # lm_head + logits = torch.nn.functional.linear(x, lm_head_w) + + # Debug: check logits + if step == 0: + print(f" logits: shape={logits.shape}, min={logits.float().min().item():.4f}, max={logits.float().max().item():.4f}, has_nan={torch.isnan(logits.float()).any().item()}") + next_id = torch.argmax(logits, dim=-1).item() generated.append(next_id)