The BF16 wo_a path called `self.wo_a(o_inv.reshape(num_tokens, -1))`, which flattens all groups into a single row per token: (num_tokens, n_local_heads * head_dim) = (num_tokens, 8192). But wo_a is a grouped (batched) projection with in_features = n_heads * head_dim / n_groups = 4096. The FP8 path already handles this with einsum('bhr,hdr->bhd') on per-group shapes. The BF16 path now does the same: it reshapes o_inv into per-group form, runs torch.bmm, then reshapes the output and performs the TP all-gather manually.
80 lines
3.2 KiB
Python
80 lines
3.2 KiB
Python
"""Unit test: wo_a BF16 BMM reshape logic (CPU only).
|
|
|
|
Verifies that the per-group BMM reshape in the BF16 wo_a path
produces the same result as a per-group reference projection (no TP sharding).
|
|
|
|
Usage: python3 tests/test_wo_a_bmm.py
|
|
"""
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
|
|
def test_bmm_vs_flat(
    *,
    n_local_groups=2,
    heads_per_group=8,
    head_dim=512,
    o_lora_rank=1024,
    num_tokens=4,
    min_cos=0.9999,
    seed=42,
):
    """Check the per-group BMM reshape used by the BF16 wo_a path.

    ``o_inv`` has shape ``(num_tokens, n_local_groups * heads_per_group,
    head_dim)``. It is split into ``n_local_groups`` groups of
    ``heads_per_group * head_dim`` features. Each group is projected by its
    own ``(o_lora_rank, heads_per_group * head_dim)`` slice of the wo_a
    weight (no TP sharding).

    The BMM output is compared against two references that do not reuse the
    BMM code path:

    * an fp32 ``einsum`` over the un-permuted tensors (ground truth), and
    * one flat ``F.linear`` over ``o_inv.reshape(num_tokens, -1)`` with a
      block-diagonal weight. This is the flat formulation that is actually
      equivalent to the grouped one. The plain ``(out_features, in_features)``
      weight cannot be applied flat: its in_features is
      ``heads_per_group * head_dim``, not ``n_local_heads * head_dim``.

    All arguments are keyword-only. Their defaults reproduce the original
    configuration: 2 groups x 8 heads x 512 head_dim, o_lora_rank 1024,
    4 tokens, seed 42.

    Returns:
        Cosine similarity between the BMM output and the fp32 reference.

    Raises:
        AssertionError: if an intermediate shape is wrong, or if either cosine
            similarity is not above ``min_cos``.
    """
    in_features = heads_per_group * head_dim
    out_features = n_local_groups * o_lora_rank
    n_local_heads = n_local_groups * heads_per_group

    torch.manual_seed(seed)

    # Random attention output after inverse RoPE: (num_tokens, n_local_heads, head_dim).
    o_inv = torch.randn(num_tokens, n_local_heads, head_dim, dtype=torch.bfloat16)
    # Random unsharded wo_a weight: (out_features, in_features).
    wo_a_weight = torch.randn(out_features, in_features, dtype=torch.bfloat16) * 0.02

    # --- BMM approach (new BF16 path) ---
    # (num_tokens, n_local_groups, in_features) -> (n_local_groups, num_tokens, in_features)
    o_grouped = o_inv.view(num_tokens, n_local_groups, in_features).permute(1, 0, 2)
    # (out_features, in_features) -> (n_local_groups, o_lora_rank, in_features)
    wo_a_w = wo_a_weight.view(n_local_groups, o_lora_rank, in_features)
    # (G, T, in) @ (G, in, rank) -> (G, T, rank)
    z_bmm = torch.bmm(o_grouped, wo_a_w.transpose(1, 2))
    # -> (T, G, rank) -> (T, G * rank): group-major columns, matching out_features.
    z_bmm = z_bmm.permute(1, 0, 2).reshape(num_tokens, out_features)

    assert o_grouped.shape == (n_local_groups, num_tokens, in_features), \
        f"o_grouped shape: {o_grouped.shape}"
    assert wo_a_w.shape == (n_local_groups, o_lora_rank, in_features), \
        f"wo_a_w shape: {wo_a_w.shape}"
    assert z_bmm.shape == (num_tokens, out_features), \
        f"z_bmm shape: {z_bmm.shape}"

    # --- Reference 1: fp32 per-group contraction on the un-permuted layout ---
    # Upcasting keeps the reference independent of the bf16 kernel under test.
    z_ref = torch.einsum(
        "tgi,goi->tgo",
        o_inv.float().view(num_tokens, n_local_groups, in_features),
        wo_a_w.float(),
    ).reshape(num_tokens, out_features)

    # --- Reference 2: flat linear with a block-diagonal weight ---
    # block_diag places group g's (o_lora_rank, in_features) slice at rows
    # starting at g * o_lora_rank and columns starting at g * in_features.
    # One flat matmul over all n_local_heads * head_dim inputs therefore
    # reproduces the grouped projection.
    w_flat = torch.block_diag(*wo_a_w)
    z_flat = F.linear(o_inv.reshape(num_tokens, -1), w_flat)

    def _cos(a, b):
        # Cosine similarity of the two tensors viewed as flat fp32 vectors.
        return F.cosine_similarity(a.flatten().float(), b.flatten().float(), dim=0).item()

    cos = _cos(z_bmm, z_ref)
    cos_flat = _cos(z_bmm, z_flat)
    mse = (z_bmm.float() - z_ref).pow(2).mean().item()

    status = "✅" if cos > min_cos else "❌"
    status_flat = "✅" if cos_flat > min_cos else "❌"
    print(f"BMM vs fp32 per-group ref: cosine={cos:.8f} MSE={mse:.2e} {status}")
    print(f"BMM vs block-diag flat:    cosine={cos_flat:.8f} {status_flat}")

    # Assert rather than only print, so a mismatch fails the test under pytest.
    assert cos > min_cos, f"BMM vs fp32 reference: cosine {cos:.8f} <= {min_cos}"
    assert cos_flat > min_cos, \
        f"BMM vs block-diag flat: cosine {cos_flat:.8f} <= {min_cos}"

    return cos
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the check as a script. Printing "FAIL" alone would still exit 0,
    # so exit non-zero on failure to let shells/CI detect it.
    cos = test_bmm_vs_flat()
    if cos > 0.9999:
        print("\n✅ PASS")
    else:
        print("\n❌ FAIL")
        raise SystemExit(1)
|