fix: PV GEMM — V canonical uses CORES_MN_V=2 (block_mn=16), not 16
V is the B operand with block_mn=16 in the PV MMA. Its canonical layout
uses CORES_MN=16/8=2, not 128/8=16. The previous code used CORES_MN=16
which produced wrong canonical indexing → garbage PV output.
Also:
- V SMEM size is (16,16) canonical = 256 BF16, not (128,16) = 2048
- P written as 16 elements at row 0 (T=1 decode)
- V loaded from TMA (16,128) and sub-sampled to (16,16) canonical
- V TMA coord: {col_start, d_base} for (HD,s_k) tensor
This commit is contained in:
@@ -91,7 +91,7 @@ fmha_tma_kernel(FmhaTmaParams params) {
|
||||
// sPk and sV for PV GEMM
|
||||
off = (off + 127) & ~(size_t)127;
|
||||
bf16_t* sPk = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t);
|
||||
bf16_t* sV = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t);
|
||||
bf16_t* sV = (bf16_t*)(sbuf + off); off += 16 * MMA_K_BF16 * sizeof(bf16_t); // (16,16) canonical for PV
|
||||
|
||||
// ==================================================================
|
||||
// Initialize
|
||||
@@ -218,8 +218,13 @@ fmha_tma_kernel(FmhaTmaParams params) {
|
||||
__syncthreads();
|
||||
|
||||
// ==================================================================
|
||||
// PV GEMM
|
||||
// PV GEMM: N=16 sub-tiles
|
||||
// V canonical uses CORES_MN_V = 16/8 = 2 (NOT 16!)
|
||||
// V SMEM size = 16 * 16 BF16 = 256 (not 128*16 = 2048)
|
||||
// ==================================================================
|
||||
static constexpr int V_SUB_SZ = 16 * MMA_K_BF16; // (16, 16) canonical
|
||||
static constexpr int CORES_MN_V = 16 / 8; // 2
|
||||
|
||||
for (int n_sub = 0; n_sub < N_NSUB; n_sub++) {
|
||||
int d_base = n_sub * 16;
|
||||
for (int pv_kt = 0; pv_kt < NKT_PV; pv_kt++) {
|
||||
@@ -229,19 +234,14 @@ fmha_tma_kernel(FmhaTmaParams params) {
|
||||
for (int i = tid; i < TILE_SZ; i += 128) sPk[i] = 0;
|
||||
__syncthreads();
|
||||
|
||||
// Write P values to canonical sPk
|
||||
if (my_row_active) {
|
||||
for (int c = 0; c < MMA_K_BF16; c++) {
|
||||
int gc = col_start + c;
|
||||
int ck = c/8, lc = c%8, cm = my_row/8, lr = my_row%8;
|
||||
sPk[ck*CORES_MN*64 + cm*64 + lr*8 + lc] = f32_to_bf16(my_p_vals[gc]);
|
||||
}
|
||||
// Write P (only row 0 for T=1 decode, 16 elements)
|
||||
for (int c = tid; c < MMA_K_BF16; c += 128) {
|
||||
int ck = c / 8, lc = c % 8;
|
||||
sPk[ck * CORES_MN * 64 + 0 * 64 + 0 * 8 + lc] = f32_to_bf16(my_p_vals[col_start + c]);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// V sub-tile: TMA load + canonical
|
||||
// V is (HD, s_k). TMA coord: {col_start, d_base}
|
||||
// We load a (16, 128) tile at position (d_base, col_start) in V
|
||||
// V sub-tile: TMA load
|
||||
if (wid == 0 && lane == 0) {
|
||||
tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)tma_v, mbar_addr, col_start, d_base);
|
||||
tma_mbarrier_arrive_expect_tx(mbar_addr, TMA_TILE_BYTES);
|
||||
@@ -249,12 +249,22 @@ fmha_tma_kernel(FmhaTmaParams params) {
|
||||
tma_mbarrier_wait(mbar_addr, phase); phase ^= 1;
|
||||
__syncthreads();
|
||||
|
||||
// Convert V from (16, 128) row-major to (128, 16) canonical
|
||||
for (int i = tid; i < TILE_SZ; i += 128) sV[i] = 0;
|
||||
for (int i = tid; i < 16 * 128; i += 128) {
|
||||
int d = i / 128, r = i % 128;
|
||||
int ck = d / 8, lc = d % 8, tmn = r / 8, lr = r % 8;
|
||||
sV[ck * CORES_MN * 64 + tmn * 64 + lr * 8 + lc] = sTmaBuf[i];
|
||||
// Convert V: TMA loaded (16, 128) row-major → (16, 16) canonical with CORES_MN_V=2
|
||||
// V in GMEM is (HD, s_k). The TMA tile covers rows [d_base, d_base+16) and all cols.
|
||||
// We need V_sub = V[d_base:d_base+16, col_start:col_start+16]
|
||||
// From sTmaBuf (16, 128): element at (dd, r) = sTmaBuf[dd * 128 + r]
|
||||
// where dd is the head-dim index (0..15) and r is the sequence index (0..127)
|
||||
// We only need r in [col_start, col_start+16)
|
||||
for (int i = tid; i < V_SUB_SZ; i += 128) sV[i] = 0;
|
||||
for (int dd = tid / 32; dd < 16; dd += 4) { // 4 warps x 32 lanes
|
||||
for (int lr = lane; lr < MMA_K_BF16; lr += 32) {
|
||||
int r = col_start + lr; // sequence index
|
||||
if (r < s_k) {
|
||||
int g_mn = dd / 8, g_k = lr / 8;
|
||||
int llr = dd % 8, lc = lr % 8;
|
||||
sV[g_k * CORES_MN_V * 64 + g_mn * 64 + llr * 8 + lc] = sTmaBuf[dd * 128 + r];
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user