"""Unit test: wo_a BF16 BMM reshape logic (CPU only). Verifies that the per-group BMM reshape in the BF16 wo_a path produces the same result as the flat linear (when no TP sharding). Usage: python3 tests/test_wo_a_bmm.py """ import torch import torch.nn.functional as F def test_bmm_vs_flat(): """Compare per-group BMM vs flat linear for wo_a.""" # Simulate: n_local_groups=2, heads_per_group=8, head_dim=512, o_lora_rank=1024 n_local_groups = 2 heads_per_group = 8 head_dim = 512 o_lora_rank = 1024 num_tokens = 4 in_features = heads_per_group * head_dim # 4096 out_features = n_local_groups * o_lora_rank # 2048 torch.manual_seed(42) # Random attention output after inverse RoPE # Shape: (num_tokens, n_local_heads, head_dim) where n_local_heads = n_local_groups * heads_per_group o_inv = torch.randn(num_tokens, n_local_groups * heads_per_group, head_dim, dtype=torch.bfloat16) # Random wo_a weight (ColumnParallelLinear, no TP sharding for this test) # Weight shape: (out_features, in_features) = (2048, 4096) wo_a_weight = torch.randn(out_features, in_features, dtype=torch.bfloat16) * 0.02 # Flat linear (the OLD broken way - would give wrong result if in_features != n_local_heads * head_dim) # This test just verifies the BMM matches the flat case when dimensions align # BMM approach (NEW way): # Reshape o_inv: (num_tokens, n_local_groups, heads_per_group * head_dim) # -> permute: (n_local_groups, num_tokens, in_features) o_grouped = o_inv.view(num_tokens, n_local_groups, heads_per_group * head_dim).permute(1, 0, 2) # Reshape weight: (out_features, in_features) -> (n_local_groups, o_lora_rank, in_features) wo_a_w = wo_a_weight.view(n_local_groups, o_lora_rank, in_features) # BMM: (n_local_groups, num_tokens, in) @ (n_local_groups, in, o_lora_rank) z_bmm = torch.bmm(o_grouped, wo_a_w.transpose(1, 2)) # -> permute: (num_tokens, n_local_groups, o_lora_rank) z_bmm = z_bmm.permute(1, 0, 2).reshape(num_tokens, n_local_groups * o_lora_rank) # Reference: per-group matmul (the ground truth) z_ref = torch.zeros(num_tokens, n_local_groups, o_lora_rank, dtype=torch.bfloat16) for g in range(n_local_groups): # (num_tokens, in_features) @ (in_features, o_lora_rank) z_ref[:, g, :] = o_grouped[g] @ wo_a_w[g].T z_ref = z_ref.reshape(num_tokens, n_local_groups * o_lora_rank) cos = F.cosine_similarity(z_bmm.flatten().unsqueeze(0).float(), z_ref.flatten().unsqueeze(0).float()).item() mse = (z_bmm.float() - z_ref.float()).pow(2).mean().item() status = "āœ…" if cos > 0.9999 else "āŒ" print(f"BMM vs flat: cosine={cos:.8f} MSE={mse:.2e} {status}") # Also verify shapes assert o_grouped.shape == (n_local_groups, num_tokens, in_features), \ f"o_grouped shape: {o_grouped.shape}" assert wo_a_w.shape == (n_local_groups, o_lora_rank, in_features), \ f"wo_a_w shape: {wo_a_w.shape}" assert z_bmm.shape == (num_tokens, out_features), \ f"z_bmm shape: {z_bmm.shape}" return cos if __name__ == "__main__": cos = test_bmm_vs_flat() if cos > 0.9999: print("\nāœ… PASS") else: print("\nāŒ FAIL")