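The grouped GEMM expects mat_a to be laid out contiguously per group: [all tokens for group0, all tokens for group1, ...]. A plain reshape of (T, G, D) → (T*G, D) keeps the token-major order, so rows from different groups end up interleaved (row index t*G + g), which is wrong. Fix: permute to (G, T, D) before flattening to (G*T, D), so that row index becomes g*T + t. The output needs the inverse: reshape to (G, T, R), then permute (G, T, R) → (T, G, R).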
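
The layout issue can be reproduced with a small NumPy sketch. It is illustrative only: `grouped_gemm_reference` is a hypothetical stand-in for the real kernel and assumes every group has exactly T rows, and the sizes and the weight tensor `w` of shape (G, D, R) are made up for the example.

```python
import numpy as np

T, G, D, R = 3, 2, 4, 5  # tokens, groups, input dim, output dim (illustrative sizes)

rng = np.random.default_rng(0)
x = rng.standard_normal((T, G, D))  # activations in token-major order
w = rng.standard_normal((G, D, R))  # assumed: one (D, R) weight matrix per group


def grouped_gemm_reference(mat_a, mat_b, rows_per_group):
    """Hypothetical stand-in for the grouped GEMM.

    Rows [g * rows_per_group, (g + 1) * rows_per_group) of mat_a are
    multiplied by mat_b[g], i.e. mat_a must be contiguous per group.
    """
    blocks = [
        mat_a[g * rows_per_group:(g + 1) * rows_per_group] @ mat_b[g]
        for g in range(mat_b.shape[0])
    ]
    return np.concatenate(blocks, axis=0)


# Ground truth: each (token, group) row is multiplied by its own group's weights.
expected = np.einsum("tgd,gdr->tgr", x, w)

# Wrong: flattening (T, G, D) directly interleaves groups (row index t*G + g).
wrong_a = x.reshape(T * G, D)
wrong = grouped_gemm_reference(wrong_a, w, T).reshape(T, G, R)

# Right: permute to (G, T, D) first, so rows are grouped (row index g*T + t).
right_a = x.transpose(1, 0, 2).reshape(G * T, D)
right = grouped_gemm_reference(right_a, w, T)
# Undo the permutation on the output: (G*T, R) -> (G, T, R) -> (T, G, R).
right = right.reshape(G, T, R).transpose(1, 0, 2)

layout_fixed = np.allclose(right, expected)
layout_broken = not np.allclose(wrong, expected)
```

In PyTorch the same idea would presumably be written with `permute` followed by `reshape` (or `.contiguous().view(...)`), since `view` alone fails on a permuted, non-contiguous tensor.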