The BF16 wo_a path called self.wo_a(o_inv.reshape(num_tokens, -1)),
which flattens across groups to (num_tokens, n_local_heads*head_dim) = (num_tokens, 8192).
But wo_a is a BMM with per-group in_features = n_heads*head_dim/n_groups = 4096,
so the flattened 8192-wide input does not match what wo_a expects.
The FP8 path handles this via einsum 'bhr,hdr->bhd' with per-group shapes.
The BF16 path now does the same: it reshapes o_inv to the per-group layout,
runs torch.bmm, then reshapes the output and handles the TP all-gather manually.
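
A minimal sketch of the per-group BMM, checked against the einsum form. The function names, the weight layout (n_groups, out_per_group, in_per_group), and the small sizes are assumptions for illustration only, not the actual module code. It runs in float32 on CPU and leaves out the TP all-gather.

```python
import torch


def grouped_proj_bmm(o, weight):
    """Per-group projection via torch.bmm (hypothetical helper).

    o:      (num_tokens, n_groups * in_per_group), flattened across groups
    weight: (n_groups, out_per_group, in_per_group), assumed layout
    """
    n_groups, out_per_group, in_per_group = weight.shape
    num_tokens = o.shape[0]
    # Split the flattened feature dim into groups and make the group dim the batch dim.
    x = o.reshape(num_tokens, n_groups, in_per_group).transpose(0, 1)  # (g, t, r)
    # One matmul per group: (g, t, r) @ (g, r, d) -> (g, t, d)
    y = torch.bmm(x, weight.transpose(1, 2))
    # Back to token-major and flatten the per-group outputs.
    return y.transpose(0, 1).reshape(num_tokens, n_groups * out_per_group)


def grouped_proj_einsum(o, weight):
    """Same projection written as einsum 'bhr,hdr->bhd' (hypothetical helper)."""
    n_groups, _, in_per_group = weight.shape
    x = o.reshape(o.shape[0], n_groups, in_per_group)  # (b, h, r)
    return torch.einsum("bhr,hdr->bhd", x, weight).reshape(o.shape[0], -1)


torch.manual_seed(0)
o_inv = torch.randn(5, 2 * 8)   # 5 tokens, 2 groups, 8 features per group
w = torch.randn(2, 3, 8)        # 2 groups, 3 outputs per group
out_bmm = grouped_proj_bmm(o_inv, w)
out_einsum = grouped_proj_einsum(o_inv, w)
```

Both forms should agree up to floating-point tolerance. In a tensor-parallel run, each rank's output would still need to be gathered afterwards.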