diff --git a/dsv4/kernels/attention/fmha_6warp_multirow.cuh b/dsv4/kernels/attention/fmha_6warp_multirow.cuh index 185faad3..f0ff5e32 100644 --- a/dsv4/kernels/attention/fmha_6warp_multirow.cuh +++ b/dsv4/kernels/attention/fmha_6warp_multirow.cuh @@ -51,6 +51,7 @@ struct FmhaMultiRowParams { int v_head_stride, v_batch_stride; int o_head_stride, o_batch_stride; int lse_head_stride, lse_batch_stride; + int normalize; // 1 = normalize in kernel, 0 = emit un-normalized O + LSE for multi-tile merge }; template @@ -264,16 +265,21 @@ fmha_6warp_multirow_kernel(FmhaMultiRowParams params) { __syncthreads(); // ================================================================ - // EPILOGUE: TMEM → regs → normalize → BF16 → GMEM + // EPILOGUE: TMEM → regs → normalize (optional) → BF16 → GMEM // // CRITICAL: TMEM loads (32x32b.x8) are WARP-COLLECTIVE. // ALL 32 lanes must execute them. The load MUST be outside // the my_row_active guard. Only the GMEM store is conditional. + // + // When normalize=1 (single KV tile): O_norm = O_unnorm / row_sum + // When normalize=0 (multi-tile merge): emit O_unnorm + LSE for + // Python merge: O = Σ exp(lse_i)·O_i / Σ exp(lse_i) // ================================================================ + const bool do_normalize = (params.normalize != 0); if (my_warp_active) { float rm = my_row_active ? sRowMax[my_row] : 0.0f; float rs = my_row_active ? sRowSum[my_row] : 0.0f; - float inv_rs = my_row_active ? (1.0f / rs) : 0.0f; + float inv_rs = (my_row_active && do_normalize) ? (1.0f / rs) : 1.0f; // Read O from TMEM: N_NSUB*2 groups of 8 columns // ALL lanes in the warp must execute the TMEM load (warp-collective) diff --git a/tests/unit/test_fmha_6warp_multirow.cu b/tests/unit/test_fmha_6warp_multirow.cu index 31d9164f..e0af30d1 100644 --- a/tests/unit/test_fmha_6warp_multirow.cu +++ b/tests/unit/test_fmha_6warp_multirow.cu @@ -1,6 +1,12 @@ /** * Test multi-row FMHA kernel (6-warp, T>1 prefill). * Compile with -DHD_VAL=64 etc. + * + * Tests: + * 1. Single KV tile, T=1..128 (normalized output) + * 2. Single KV tile, T=1..128 (un-normalized output + LSE) + * 3. Multi-tile KV via Python merge (s_k=256, 2 segments) + * 4. Multi-head and batched launches */ #include @@ -29,17 +35,16 @@ constexpr int MAX_T = 128; static int compute_smem() { size_t off = 0; - off += 8; - off += 128 * sizeof(float); // sRowMax - off += 128 * sizeof(float); // sRowSum + off += 4; // sTmemBase + off += 128 * sizeof(float); // sRowMax + off += 128 * sizeof(float); // sRowSum off = (off + 127) & ~(size_t)127; off += 128 * MMA_K_BF16 * sizeof(bf16_t); // sQ0 off += 128 * MMA_K_BF16 * sizeof(bf16_t); // sK0 off = (off + 127) & ~(size_t)127; off += 128 * MMA_K_BF16 * sizeof(bf16_t); // sPk off = (off + 127) & ~(size_t)127; - off += 16 * MMA_K_BF16 * sizeof(bf16_t); // sV - off += 256; + off += 16 * MMA_K_BF16 * sizeof(bf16_t); // sV return (int)off; } @@ -70,8 +75,36 @@ static void reference_attention_multirow( } } -static int test_single_T(int T, int n_h = 1, int batch = 1) { - printf("\n=== T=%d, n_h=%d, batch=%d, HD=%d ===\n", T, n_h, batch, HD); +// Reference that computes un-normalized O + LSE +static void reference_attention_multirow_unnorm( + const bf16_t* q, const bf16_t* k, const bf16_t* v, + float* o_unnorm, float* lse_ref, + int hd, int T, int s_k, float scale +) { + for (int t = 0; t < T; t++) { + float s[512]; + for (int j = 0; j < s_k; j++) { + float dot = 0.0f; + for (int d = 0; d < hd; d++) + dot += bf16_to_f32_host(q[t * hd + d]) * bf16_to_f32_host(k[j * hd + d]); + s[j] = dot * scale; + } + float mx = -INFINITY; + for (int j = 0; j < s_k; j++) mx = fmaxf(mx, s[j]); + float sm = 0.0f; + for (int j = 0; j < s_k; j++) { s[j] = expf(s[j] - mx); sm += s[j]; } + // Un-normalized: don't divide by sm + for (int d = 0; d < hd; d++) { + float ov = 0.0f; + for (int j = 0; j < s_k; j++) ov += s[j] * bf16_to_f32_host(v[d * s_k + j]); + o_unnorm[t * hd + d] = ov; // un-normalized! + } + if (lse_ref) lse_ref[t] = logf(sm) + mx; + } +} + +static int test_normalized(int T, int n_h = 1, int batch = 1) { + printf("\n=== NORMALIZED T=%d, n_h=%d, batch=%d, HD=%d ===\n", T, n_h, batch, HD); const float SCALE = 1.0f / sqrtf((float)HD); int total_heads = batch * n_h; @@ -104,6 +137,7 @@ static int test_single_T(int T, int n_h = 1, int batch = 1) { params.v_head_stride = HD * SK; params.v_batch_stride = n_h * HD * SK; params.o_head_stride = T * HD; params.o_batch_stride = n_h * T * HD; params.lse_head_stride = T; params.lse_batch_stride = n_h * T; + params.normalize = 1; int smem = compute_smem(); if (smem > 48 * 1024) @@ -150,27 +184,234 @@ static int test_single_T(int T, int n_h = 1, int batch = 1) { return failed == 0; } +static int test_unnormalized(int T) { + printf("\n=== UN-NORMALIZED T=%d, HD=%d ===\n", T, HD); + const float SCALE = 1.0f / sqrtf((float)HD); + + bf16_t* h_q = (bf16_t*)malloc(T * HD * sizeof(bf16_t)); + bf16_t* h_k = (bf16_t*)malloc(SK * HD * sizeof(bf16_t)); + bf16_t* h_v = (bf16_t*)malloc(HD * SK * sizeof(bf16_t)); + bf16_t* h_o = (bf16_t*)calloc(T * HD, sizeof(bf16_t)); + float* h_lse = (float*)calloc(T, sizeof(float)); + + srand(42 + T + 1000); + for (int i = 0; i < T * HD; i++) h_q[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f); + for (int i = 0; i < SK * HD; i++) h_k[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f); + for (int i = 0; i < HD * SK; i++) h_v[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f); + + bf16_t *d_q, *d_k, *d_v, *d_o; float *d_lse; + cudaMalloc(&d_q, T * HD * sizeof(bf16_t)); + cudaMalloc(&d_k, SK * HD * sizeof(bf16_t)); + cudaMalloc(&d_v, HD * SK * sizeof(bf16_t)); + cudaMalloc(&d_o, T * HD * sizeof(bf16_t)); + cudaMalloc(&d_lse, T * sizeof(float)); + cudaMemcpy(d_q, h_q, T * HD * sizeof(bf16_t), cudaMemcpyHostToDevice); + cudaMemcpy(d_k, h_k, SK * HD * sizeof(bf16_t), cudaMemcpyHostToDevice); + cudaMemcpy(d_v, h_v, HD * SK * sizeof(bf16_t), cudaMemcpyHostToDevice); + + FmhaMultiRowParams params; + params.q = d_q; params.k = d_k; params.v = d_v; params.o = d_o; params.lse = d_lse; + params.s_k = SK; params.T = T; params.scale = SCALE; params.head_dim = HD; + params.q_head_stride = T * HD; params.q_batch_stride = T * HD; + params.k_head_stride = SK * HD; params.k_batch_stride = SK * HD; + params.v_head_stride = HD * SK; params.v_batch_stride = HD * SK; + params.o_head_stride = T * HD; params.o_batch_stride = T * HD; + params.lse_head_stride = T; params.lse_batch_stride = T; + params.normalize = 0; // UN-NORMALIZED + + int smem = compute_smem(); + if (smem > 48 * 1024) + cudaFuncSetAttribute(fmha_6warp_multirow_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); + + dim3 grid(1, 1, 1); + fmha_6warp_multirow_kernel<<>>(params); + + cudaError_t err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + printf(" CUDA ERROR: %s\n", cudaGetErrorString(err)); + cudaFree(d_q); cudaFree(d_k); cudaFree(d_v); cudaFree(d_o); cudaFree(d_lse); + free(h_q); free(h_k); free(h_v); free(h_o); free(h_lse); + return 0; + } + + cudaMemcpy(h_o, d_o, T * HD * sizeof(bf16_t), cudaMemcpyDeviceToHost); + cudaMemcpy(h_lse, d_lse, T * sizeof(float), cudaMemcpyDeviceToHost); + + // Verify: O_normalized = O_unnorm / row_sum, where exp(LSE) = row_sum * exp(max) + float o_unnorm_ref[MAX_T * 512]; float lse_ref[MAX_T]; + reference_attention_multirow_unnorm(h_q, h_k, h_v, o_unnorm_ref, lse_ref, HD, T, SK, SCALE); + + int failed = 0; float min_cos = 1.0f; + for (int t = 0; t < T; t++) { + // Check un-normalized O matches reference + float cs=0,na=0,nb=0; + for (int d=0;d1e-4f){cs+=a*b2;na+=a*a;nb+=b2*b2;} + } + cs /= (sqrtf(na)*sqrtf(nb)+1e-10f); + if(cs 0.01f) { printf(" FAIL lse t=%d kernel=%.6f ref=%.6f err=%.6f\n",t,h_lse[t],lse_ref[t],lse_err); failed++; } + } + printf(" min_cos=%.8f %s\n", min_cos, failed==0?"PASSED":"FAILED"); + + cudaFree(d_q); cudaFree(d_k); cudaFree(d_v); cudaFree(d_o); cudaFree(d_lse); + free(h_q); free(h_k); free(h_v); free(h_o); free(h_lse); + return failed == 0; +} + +static int test_multitile_merge(int T) { + printf("\n=== MULTI-TILE MERGE T=%d, s_k=256, HD=%d ===\n", T, HD); + constexpr int SK_TOTAL = 256; // 2 KV tiles + constexpr int N_TILES = SK_TOTAL / SK; // 2 + const float SCALE = 1.0f / sqrtf((float)HD); + + bf16_t* h_q = (bf16_t*)malloc(T * HD * sizeof(bf16_t)); + bf16_t* h_k = (bf16_t*)malloc(SK_TOTAL * HD * sizeof(bf16_t)); + bf16_t* h_v = (bf16_t*)malloc(HD * SK_TOTAL * sizeof(bf16_t)); + + srand(42 + T + 2000); + for (int i = 0; i < T * HD; i++) h_q[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f); + for (int i = 0; i < SK_TOTAL * HD; i++) h_k[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f); + for (int i = 0; i < HD * SK_TOTAL; i++) h_v[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f); + + // Run kernel per KV tile, get un-normalized O + LSE, merge in Python + float* h_o_merged = (float*)calloc(T * HD, sizeof(float)); + + bf16_t *d_q, *d_k, *d_v, *d_o; float *d_lse; + cudaMalloc(&d_q, T * HD * sizeof(bf16_t)); + cudaMalloc(&d_k, SK * HD * sizeof(bf16_t)); // single tile + cudaMalloc(&d_v, HD * SK * sizeof(bf16_t)); // single tile + cudaMalloc(&d_o, T * HD * sizeof(bf16_t)); + cudaMalloc(&d_lse, T * sizeof(float)); + cudaMemcpy(d_q, h_q, T * HD * sizeof(bf16_t), cudaMemcpyHostToDevice); + + int smem = compute_smem(); + if (smem > 48 * 1024) + cudaFuncSetAttribute(fmha_6warp_multirow_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); + + float* lse_per_tile = (float*)malloc(N_TILES * T * sizeof(float)); + float* o_per_tile = (float*)malloc(N_TILES * T * HD * sizeof(float)); + + for (int tile = 0; tile < N_TILES; tile++) { + // K/V for this tile + cudaMemcpy(d_k, h_k + tile * SK * HD, SK * HD * sizeof(bf16_t), cudaMemcpyHostToDevice); + cudaMemcpy(d_v, h_v + tile * HD * SK, HD * SK * sizeof(bf16_t), cudaMemcpyHostToDevice); + + FmhaMultiRowParams params; + params.q = d_q; params.k = d_k; params.v = d_v; params.o = d_o; params.lse = d_lse; + params.s_k = SK; params.T = T; params.scale = SCALE; params.head_dim = HD; + params.q_head_stride = T * HD; params.q_batch_stride = T * HD; + params.k_head_stride = SK * HD; params.k_batch_stride = SK * HD; + params.v_head_stride = HD * SK; params.v_batch_stride = HD * SK; + params.o_head_stride = T * HD; params.o_batch_stride = T * HD; + params.lse_head_stride = T; params.lse_batch_stride = T; + params.normalize = 0; // UN-NORMALIZED for merge + + dim3 grid(1, 1, 1); + fmha_6warp_multirow_kernel<<>>(params); + cudaError_t err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + printf(" CUDA ERROR tile %d: %s\n", tile, cudaGetErrorString(err)); + cudaFree(d_q); cudaFree(d_k); cudaFree(d_v); cudaFree(d_o); cudaFree(d_lse); + free(h_q); free(h_k); free(h_v); free(h_o_merged); + free(lse_per_tile); free(o_per_tile); + return 0; + } + + bf16_t* h_o_tile = (bf16_t*)malloc(T * HD * sizeof(bf16_t)); + cudaMemcpy(h_o_tile, d_o, T * HD * sizeof(bf16_t), cudaMemcpyDeviceToHost); + cudaMemcpy(lse_per_tile + tile * T, d_lse, T * sizeof(float), cudaMemcpyDeviceToHost); + + for (int i = 0; i < T * HD; i++) + o_per_tile[tile * T * HD + i] = bf16_to_f32_host(h_o_tile[i]); + free(h_o_tile); + } + + // Python KV merge: O = Σ exp(lse_i)·O_i / Σ exp(lse_i) + for (int t = 0; t < T; t++) { + float lse_max = -INFINITY; + for (int tile = 0; tile < N_TILES; tile++) + lse_max = fmaxf(lse_max, lse_per_tile[tile * T + t]); + float sum_w = 0.0f; + for (int tile = 0; tile < N_TILES; tile++) + sum_w += expf(lse_per_tile[tile * T + t] - lse_max); + for (int d = 0; d < HD; d++) { + float ov = 0.0f; + for (int tile = 0; tile < N_TILES; tile++) + ov += expf(lse_per_tile[tile * T + t] - lse_max) * o_per_tile[tile * T * HD + t * HD + d]; + h_o_merged[t * HD + d] = ov / sum_w; + } + } + + // Compare with full reference + float o_ref[MAX_T * 512]; + reference_attention_multirow(h_q, h_k, h_v, o_ref, nullptr, HD, T, SK_TOTAL, SCALE); + + int failed = 0; float min_cos = 1.0f; + for (int t = 0; t < T; t++) { + float cs=0,na=0,nb=0; + for (int d=0;d1e-4f){cs+=a*b2;na+=a*a;nb+=b2*b2;} + } + cs /= (sqrtf(na)*sqrtf(nb)+1e-10f); + if(cs