Fix: wo_a as batched matmul (grouped linear for output projection)
This commit is contained in:
@@ -245,61 +245,45 @@ def forward_layer(x, w, li, cfg, rope_cos, rope_sin):
|
||||
attn_out = dsv4_attention(q_heads, k, v) # (n_h, T, hd)
|
||||
attn_out = attn_out.permute(1, 0, 2).reshape(T, n_h * hd) # (T, n_h*hd)
|
||||
|
||||
# ---- Output projection: o_a (BF16 grouped) → o_b (NVFP4) ----
|
||||
# o_a_proj: grouped linear — input is (T, n_h*hd) reshaped as (T*o_groups, heads_per_group*hd)
|
||||
# Each group: (heads_per_group * hd) → o_lora_rank
|
||||
# Then concatenate: (T*o_groups * o_rank) → o_b → (T, H)
|
||||
oa_w = w[f"{pre}.o_a_proj.weight"] # BF16
|
||||
# ---- Output projection: wo_a (BF16 batched matmul) → wo_b (NVFP4) ----
|
||||
# wo_a: grouped linear — input per group: (heads_per_group * hd) → o_lora_rank
|
||||
# Implemented as batched matmul: (n_groups, heads_per_group*hd) × (n_groups, heads_per_group*hd, o_rank)
|
||||
oa_w = w[f"{pre}.o_a_proj.weight"] # BF16, stored as (n_groups*o_rank, heads_per_group*hd) or similar
|
||||
ob_w = w[f"{pre}.o_b_proj.weight"]
|
||||
ob_s = w[f"{pre}.o_b_proj.weight_scale"]
|
||||
ob_s2 = w[f"{pre}.o_b_proj.weight_scale_2"]
|
||||
|
||||
heads_per_group = n_h // o_groups # 128/16 = 8
|
||||
group_input_dim = heads_per_group * hd # 8 * 512 = 4096
|
||||
heads_per_group = n_h // o_groups # 8
|
||||
group_input_dim = heads_per_group * hd # 4096
|
||||
|
||||
# Reshape attention output for grouped projection
|
||||
# attn_out: (T, n_h * hd) → (T, o_groups, heads_per_group * hd) → (T*o_groups, group_input_dim)
|
||||
attn_grouped = attn_out.reshape(T, o_groups, group_input_dim).reshape(T * o_groups, group_input_dim)
|
||||
attn_grouped = attn_out.reshape(T, o_groups, heads_per_group, hd)
|
||||
attn_grouped = attn_grouped.reshape(T, o_groups, group_input_dim) # (1, 16, 4096)
|
||||
|
||||
# o_a: (o_rank, group_input_dim) per group → total (o_groups * o_rank, o_groups * group_input_dim)
|
||||
# But in the checkpoint it's stored as (o_groups * o_rank, n_h * hd) or similar
|
||||
# The actual shape tells us: (4096, 16384) BF16
|
||||
# 4096 = o_rank * o_groups? No, o_rank=1024, o_groups=16 → 16384
|
||||
# 16384 = n_h * hd / 4? No.
|
||||
# Let's just check: 4096 output, 16384 input
|
||||
# 16384 = 4 * 4096 = ??? That doesn't match n_h*hd=65536
|
||||
# wo_a weight: (n_groups * o_rank, heads_per_group * hd) = (16384, 4096) BF16
|
||||
# Reshape to (n_groups, o_rank, heads_per_group * hd) for batched matmul
|
||||
oa_w_bf16 = oa_w.cuda().bfloat16()
|
||||
oa_shape = oa_w_bf16.shape
|
||||
|
||||
# Actually: o_a_proj weight is (4096, 16384). Since it's BF16, no FP4 packing.
|
||||
# So actual out_features=4096, in_features=16384.
|
||||
# 16384 = ??? Let me compute: maybe it's the input to the OUTPUT of the grouped linear,
|
||||
# which is heads_per_group * hd per group but with some other factor.
|
||||
# Actually 16384 = 32 * 512 = 32 heads * hd? That would be 1/4 of the heads.
|
||||
# Or: 16384 = o_groups * (heads_per_group * hd) / something
|
||||
# The weight might be stored transposed or in grouped format
|
||||
# Try: reshape to (o_groups, o_rank, group_input_dim) for BMM
|
||||
if oa_shape[0] == o_groups * o_rank and oa_shape[1] == group_input_dim:
|
||||
# (o_groups * o_rank, group_input_dim) → (o_groups, o_rank, group_input_dim)
|
||||
oa_3d = oa_w_bf16.reshape(o_groups, o_rank, group_input_dim)
|
||||
elif oa_shape[1] == o_groups * o_rank and oa_shape[0] == group_input_dim:
|
||||
# Transposed: (group_input_dim, o_groups * o_rank) → (o_groups, group_input_dim, o_rank) → (o_groups, o_rank, group_input_dim)
|
||||
oa_3d = oa_w_bf16.reshape(group_input_dim, o_groups, o_rank).permute(1, 2, 0)
|
||||
else:
|
||||
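# Fallback: unrecognized layout — reshape (transposing if needed) to (o_groups, -1, group_input_dim) and let the per-group rank be inferred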
|
||||
oa_3d = oa_w_bf16.reshape(o_groups, -1, group_input_dim) if oa_w_bf16.shape[-1] == group_input_dim else oa_w_bf16.T.reshape(o_groups, -1, group_input_dim)
|
||||
|
||||
# The simplest approach: just try both possible reshapes
|
||||
try:
|
||||
# Try 1: treat as standard linear with the full attention output
|
||||
if oa_w.shape[1] * 1 == n_h * hd: # BF16, no packing
|
||||
grouped = bf16_linear(attn_out, oa_w.cuda())
|
||||
elif oa_w.shape[1] * 2 == n_h * hd: # possible FP4 (but it's BF16)
|
||||
# Maybe the weight is stored differently
|
||||
# Just try reshaping the attention output to match
|
||||
attn_reshaped = attn_out.reshape(T, -1)[:, :oa_w.shape[1]]
|
||||
grouped = bf16_linear(attn_reshaped, oa_w.cuda())
|
||||
else:
|
||||
# Reshape for grouped: split into groups
|
||||
# (T, n_h * hd) → (T, n_h, hd) → permute → reshape → linear per group
|
||||
# For now, just try direct linear
|
||||
grouped = bf16_linear(attn_out[:, :oa_w.shape[1]], oa_w.cuda())
|
||||
except RuntimeError:
|
||||
# Fallback: pad or truncate
|
||||
if attn_out.shape[-1] < oa_w.shape[1]:
|
||||
padded = torch.nn.functional.pad(attn_out, (0, oa_w.shape[1] - attn_out.shape[-1]))
|
||||
grouped = bf16_linear(padded, oa_w.cuda())
|
||||
else:
|
||||
grouped = bf16_linear(attn_out[:, :oa_w.shape[1]], oa_w.cuda())
|
||||
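# Batched matmul: (1, 16, 4096) × (16, 1024, 4096)ᵀ → (1, 16, 1024) — NOTE(review): batch dims differ here (T vs o_groups); torch.bmm does not broadcast, so attn_grouped likely needs to be group-major (o_groups, T, 4096) — confirm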
|
||||
# torch.bmm needs (B, M, K) × (B, K, N) → (B, M, N)
|
||||
grouped_out = torch.bmm(attn_grouped, oa_3d.transpose(1, 2)) # (1, 16, o_rank)
|
||||
grouped_flat = grouped_out.reshape(T, o_groups * o_rank) # (1, 16384)
|
||||
|
||||
attn_proj = nvfp4_linear(grouped, ob_w, ob_s, ob_s2) # (1, H)
|
||||
attn_proj = nvfp4_linear(grouped_flat, ob_w, ob_s, ob_s2) # (1, H)
|
||||
|
||||
# ---- Residual ----
|
||||
x = x + attn_proj
|
||||
@@ -411,6 +395,13 @@ def main():
|
||||
|
||||
# lm_head
|
||||
logits = torch.nn.functional.linear(x, lm_head_w)
|
||||
# lm_head
|
||||
logits = torch.nn.functional.linear(x, lm_head_w)
|
||||
|
||||
# Debug: check logits
|
||||
if step == 0:
|
||||
print(f" logits: shape={logits.shape}, min={logits.float().min().item():.4f}, max={logits.float().max().item():.4f}, has_nan={torch.isnan(logits.float()).any().item()}")
|
||||
|
||||
next_id = torch.argmax(logits, dim=-1).item()
|
||||
generated.append(next_id)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user