Files
nvfp4-megamoe-kernel/tests/test_wo_a_bmm.py
biondizzle c289c44920 Fix BF16 wo_a: per-group BMM instead of flat linear
The BF16 wo_a path was calling self.wo_a(o_inv.reshape(num_tokens, -1))
which flattens across groups: (num_tokens, n_local_heads*head_dim)=(tokens, 8192).
But wo_a is a BMM with in_features=n_heads*head_dim/n_groups=4096.

The FP8 path handles this via einsum 'bhr,hdr->bhd' with per-group shapes.
The BF16 path now does the same: reshape o_inv to per-group format,
do torch.bmm, then reshape output and handle TP all-gather manually.
2026-05-19 04:10:02 +00:00

80 lines
3.2 KiB
Python

"""Unit test: wo_a BF16 BMM reshape logic (CPU only).
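Verifies that the batched per-group BMM reshape in the BF16 wo_a path
produces the same result as an explicit per-group matmul loop (no TP sharding).
A flat linear cannot serve as the reference: it flattens across groups, so its
input width no longer matches wo_a's per-group in_features.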
Usage: python3 tests/test_wo_a_bmm.py
"""
import torch
import torch.nn.functional as F
def test_bmm_vs_flat():
    """Check the batched per-group wo_a BMM against a per-group matmul loop.

    Builds random BF16 activations and a random BF16 weight on CPU with a
    fixed seed, runs the reshape -> ``torch.bmm`` -> reshape pipeline, and
    compares the result with a reference that multiplies each group
    separately. A one-line summary (cosine similarity, MSE, pass/fail marker)
    is printed before any assertion fires.

    Returns:
        float: cosine similarity between the flattened BMM output and the
        flattened reference output.

    Raises:
        AssertionError: if an intermediate shape is unexpected, or if the
            cosine similarity is not above 0.9999.
    """
    # Simulated sizes: 2 local groups of 8 heads, head_dim=512, o_lora_rank=1024.
    n_local_groups = 2
    heads_per_group = 8
    head_dim = 512
    o_lora_rank = 1024
    num_tokens = 4
    in_features = heads_per_group * head_dim  # 4096
    out_features = n_local_groups * o_lora_rank  # 2048
    cos_threshold = 0.9999

    torch.manual_seed(42)

    # Stand-in for the attention output after inverse RoPE.
    # Shape: (num_tokens, n_local_heads, head_dim), where
    # n_local_heads = n_local_groups * heads_per_group.
    o_inv = torch.randn(
        num_tokens, n_local_groups * heads_per_group, head_dim, dtype=torch.bfloat16
    )

    # Stand-in wo_a weight with no TP sharding.
    # Shape: (out_features, in_features) = (2048, 4096).
    wo_a_weight = torch.randn(out_features, in_features, dtype=torch.bfloat16) * 0.02

    # A flat linear over (num_tokens, n_local_heads * head_dim) is deliberately
    # not computed: it would flatten across groups and no longer line up with
    # the per-group in_features of the weight.

    # BMM approach:
    # (num_tokens, n_local_heads, head_dim)
    #   -> (num_tokens, n_local_groups, in_features)
    #   -> (n_local_groups, num_tokens, in_features)
    o_grouped = o_inv.view(
        num_tokens, n_local_groups, heads_per_group * head_dim
    ).permute(1, 0, 2)

    # (out_features, in_features) -> (n_local_groups, o_lora_rank, in_features)
    wo_a_w = wo_a_weight.view(n_local_groups, o_lora_rank, in_features)

    # (n_local_groups, num_tokens, in) @ (n_local_groups, in, o_lora_rank)
    z_bmm = torch.bmm(o_grouped, wo_a_w.transpose(1, 2))
    # Back to token-major layout, then merge (group, rank) into one axis.
    z_bmm = z_bmm.permute(1, 0, 2).reshape(num_tokens, n_local_groups * o_lora_rank)

    # Reference: one plain matmul per group (the ground truth).
    z_ref = torch.zeros(num_tokens, n_local_groups, o_lora_rank, dtype=torch.bfloat16)
    for g in range(n_local_groups):
        # (num_tokens, in_features) @ (in_features, o_lora_rank)
        z_ref[:, g, :] = o_grouped[g] @ wo_a_w[g].T
    z_ref = z_ref.reshape(num_tokens, n_local_groups * o_lora_rank)

    # Compare in float32 so the metrics are not limited by BF16 precision.
    cos = F.cosine_similarity(
        z_bmm.flatten().unsqueeze(0).float(),
        z_ref.flatten().unsqueeze(0).float(),
    ).item()
    mse = (z_bmm.float() - z_ref.float()).pow(2).mean().item()

    # Print diagnostics before asserting so a failing run still reports numbers.
    status = "✅" if cos > cos_threshold else "❌"
    print(f"BMM vs flat: cosine={cos:.8f} MSE={mse:.2e} {status}")

    # Shape checks for every stage of the reshape pipeline.
    assert o_grouped.shape == (n_local_groups, num_tokens, in_features), \
        f"o_grouped shape: {o_grouped.shape}"
    assert wo_a_w.shape == (n_local_groups, o_lora_rank, in_features), \
        f"wo_a_w shape: {wo_a_w.shape}"
    assert z_bmm.shape == (num_tokens, out_features), \
        f"z_bmm shape: {z_bmm.shape}"

    # Numerical check: without it a test runner would pass on a wrong result.
    assert cos > cos_threshold, \
        f"cosine similarity {cos:.8f} is not above {cos_threshold}"

    return cos
if __name__ == "__main__":
    # Script entry point (see "Usage" in the module docstring): run the check
    # once and report the verdict based on the returned cosine similarity.
    cos = test_bmm_vs_flat()
    if cos > 0.9999:
        print("\n✅ PASS")
    else:
        print("\n❌ FAIL")
        # Exit non-zero so shells and CI jobs can detect the failure;
        # printing FAIL alone would still leave the exit status at 0.
        raise SystemExit(1)