diff --git a/dsv4/kernels/attention/fmha_6warp_tma.cuh b/dsv4/kernels/attention/fmha_6warp_tma.cuh index 910a9269..2a46d697 100644 --- a/dsv4/kernels/attention/fmha_6warp_tma.cuh +++ b/dsv4/kernels/attention/fmha_6warp_tma.cuh @@ -91,7 +91,7 @@ fmha_tma_kernel(FmhaTmaParams params) { // sPk and sV for PV GEMM off = (off + 127) & ~(size_t)127; bf16_t* sPk = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t); - bf16_t* sV = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t); + bf16_t* sV = (bf16_t*)(sbuf + off); off += 16 * MMA_K_BF16 * sizeof(bf16_t); // (16,16) canonical for PV // ================================================================== // Initialize @@ -218,8 +218,13 @@ fmha_tma_kernel(FmhaTmaParams params) { __syncthreads(); // ================================================================== - // PV GEMM + // PV GEMM: N=16 sub-tiles + // V canonical uses CORES_MN_V = 16/8 = 2 (NOT 16!) + // V SMEM size = 16 * 16 BF16 = 256 (not 128*16 = 2048) // ================================================================== + static constexpr int V_SUB_SZ = 16 * MMA_K_BF16; // (16, 16) canonical + static constexpr int CORES_MN_V = 16 / 8; // 2 + for (int n_sub = 0; n_sub < N_NSUB; n_sub++) { int d_base = n_sub * 16; for (int pv_kt = 0; pv_kt < NKT_PV; pv_kt++) { @@ -229,19 +234,14 @@ fmha_tma_kernel(FmhaTmaParams params) { for (int i = tid; i < TILE_SZ; i += 128) sPk[i] = 0; __syncthreads(); - // Write P values to canonical sPk - if (my_row_active) { - for (int c = 0; c < MMA_K_BF16; c++) { - int gc = col_start + c; - int ck = c/8, lc = c%8, cm = my_row/8, lr = my_row%8; - sPk[ck*CORES_MN*64 + cm*64 + lr*8 + lc] = f32_to_bf16(my_p_vals[gc]); - } + // Write P (only row 0 for T=1 decode, 16 elements) + for (int c = tid; c < MMA_K_BF16; c += 128) { + int ck = c / 8, lc = c % 8; + sPk[ck * CORES_MN * 64 + 0 * 64 + 0 * 8 + lc] = f32_to_bf16(my_p_vals[col_start + c]); } __syncthreads(); - // V sub-tile: TMA load + canonical - // V is (HD, s_k). TMA coord: {col_start, d_base} - // We load a (16, 128) tile at position (d_base, col_start) in V + // V sub-tile: TMA load if (wid == 0 && lane == 0) { tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)tma_v, mbar_addr, col_start, d_base); tma_mbarrier_arrive_expect_tx(mbar_addr, TMA_TILE_BYTES); @@ -249,12 +249,22 @@ fmha_tma_kernel(FmhaTmaParams params) { tma_mbarrier_wait(mbar_addr, phase); phase ^= 1; __syncthreads(); - // Convert V from (16, 128) row-major to (128, 16) canonical - for (int i = tid; i < TILE_SZ; i += 128) sV[i] = 0; - for (int i = tid; i < 16 * 128; i += 128) { - int d = i / 128, r = i % 128; - int ck = d / 8, lc = d % 8, tmn = r / 8, lr = r % 8; - sV[ck * CORES_MN * 64 + tmn * 64 + lr * 8 + lc] = sTmaBuf[i]; + // Convert V: TMA loaded (16, 128) row-major → (16, 16) canonical with CORES_MN_V=2 + // V in GMEM is (HD, s_k). The TMA tile covers rows [d_base, d_base+16) and all cols. + // We need V_sub = V[d_base:d_base+16, col_start:col_start+16] + // From sTmaBuf (16, 128): element at (dd, r) = sTmaBuf[dd * 128 + r] + // where dd is the head-dim index (0..15) and r is the sequence index (0..127) + // We only need r in [col_start, col_start+16) + for (int i = tid; i < V_SUB_SZ; i += 128) sV[i] = 0; + for (int dd = tid / 32; dd < 16; dd += 4) { // 4 warps x 32 lanes + for (int lr = lane; lr < MMA_K_BF16; lr += 32) { + int r = col_start + lr; // sequence index + if (r < s_k) { + int g_mn = dd / 8, g_k = lr / 8; + int llr = dd % 8, lc = lr % 8; + sV[g_k * CORES_MN_V * 64 + g_mn * 64 + llr * 8 + lc] = sTmaBuf[dd * 128 + r]; + } + } } __syncthreads();