diff --git a/dsv4/kernels/attention/fmha_6warp_multirow.cuh b/dsv4/kernels/attention/fmha_6warp_multirow.cuh index f0ff5e32..5625f0f0 100644 --- a/dsv4/kernels/attention/fmha_6warp_multirow.cuh +++ b/dsv4/kernels/attention/fmha_6warp_multirow.cuh @@ -42,7 +42,7 @@ struct FmhaMultiRowParams { const bf16_t* __restrict__ k; const bf16_t* __restrict__ v; bf16_t* __restrict__ o; - float* __restrict__ lse; + float* __restrict__ lse; // [batch, n_h, T] — per-row LSE for multi-tile KV merge int s_k, T; float scale; int head_dim; @@ -51,7 +51,6 @@ struct FmhaMultiRowParams { int v_head_stride, v_batch_stride; int o_head_stride, o_batch_stride; int lse_head_stride, lse_batch_stride; - int normalize; // 1 = normalize in kernel, 0 = emit un-normalized O + LSE for multi-tile merge }; template @@ -265,21 +264,21 @@ fmha_6warp_multirow_kernel(FmhaMultiRowParams params) { __syncthreads(); // ================================================================ - // EPILOGUE: TMEM → regs → normalize (optional) → BF16 → GMEM + // EPILOGUE: TMEM → regs → normalize → BF16 → GMEM + LSE output // // CRITICAL: TMEM loads (32x32b.x8) are WARP-COLLECTIVE. // ALL 32 lanes must execute them. The load MUST be outside // the my_row_active guard. Only the GMEM store is conditional. // - // When normalize=1 (single KV tile): O_norm = O_unnorm / row_sum - // When normalize=0 (multi-tile merge): emit O_unnorm + LSE for - // Python merge: O = Σ exp(lse_i)·O_i / Σ exp(lse_i) + // Output: normalized O (O_unnorm / row_sum) + per-row LSE. + // LSE = ln(row_sum) + row_max, for multi-tile KV merge: + // O = Σ exp(lse_i - L) * O_i / Σ exp(lse_i - L) + // where L = max(lse_i) for numerical stability. // ================================================================ - const bool do_normalize = (params.normalize != 0); if (my_warp_active) { float rm = my_row_active ? sRowMax[my_row] : 0.0f; float rs = my_row_active ? sRowSum[my_row] : 0.0f; - float inv_rs = (my_row_active && do_normalize) ? (1.0f / rs) : 1.0f; + float inv_rs = my_row_active ? (1.0f / rs) : 0.0f; // Read O from TMEM: N_NSUB*2 groups of 8 columns // ALL lanes in the warp must execute the TMEM load (warp-collective) diff --git a/tests/unit/test_fmha_6warp_multirow.cu b/tests/unit/test_fmha_6warp_multirow.cu index e0af30d1..5539b9dc 100644 --- a/tests/unit/test_fmha_6warp_multirow.cu +++ b/tests/unit/test_fmha_6warp_multirow.cu @@ -3,10 +3,9 @@ * Compile with -DHD_VAL=64 etc. * * Tests: - * 1. Single KV tile, T=1..128 (normalized output) - * 2. Single KV tile, T=1..128 (un-normalized output + LSE) - * 3. Multi-tile KV via Python merge (s_k=256, 2 segments) - * 4. Multi-head and batched launches + * 1. Single KV tile, T=1..128 (normalized output + LSE) + * 2. Multi-tile KV via Python merge (s_k=256, 2 segments) + * 3. Multi-head and batched launches */ #include @@ -75,36 +74,8 @@ static void reference_attention_multirow( } } -// Reference that computes un-normalized O + LSE -static void reference_attention_multirow_unnorm( - const bf16_t* q, const bf16_t* k, const bf16_t* v, - float* o_unnorm, float* lse_ref, - int hd, int T, int s_k, float scale -) { - for (int t = 0; t < T; t++) { - float s[512]; - for (int j = 0; j < s_k; j++) { - float dot = 0.0f; - for (int d = 0; d < hd; d++) - dot += bf16_to_f32_host(q[t * hd + d]) * bf16_to_f32_host(k[j * hd + d]); - s[j] = dot * scale; - } - float mx = -INFINITY; - for (int j = 0; j < s_k; j++) mx = fmaxf(mx, s[j]); - float sm = 0.0f; - for (int j = 0; j < s_k; j++) { s[j] = expf(s[j] - mx); sm += s[j]; } - // Un-normalized: don't divide by sm - for (int d = 0; d < hd; d++) { - float ov = 0.0f; - for (int j = 0; j < s_k; j++) ov += s[j] * bf16_to_f32_host(v[d * s_k + j]); - o_unnorm[t * hd + d] = ov; // un-normalized! - } - if (lse_ref) lse_ref[t] = logf(sm) + mx; - } -} - -static int test_normalized(int T, int n_h = 1, int batch = 1) { - printf("\n=== NORMALIZED T=%d, n_h=%d, batch=%d, HD=%d ===\n", T, n_h, batch, HD); +static int test_single(int T, int n_h = 1, int batch = 1) { + printf("\n=== T=%d, n_h=%d, batch=%d, HD=%d ===\n", T, n_h, batch, HD); const float SCALE = 1.0f / sqrtf((float)HD); int total_heads = batch * n_h; @@ -137,7 +108,6 @@ static int test_normalized(int T, int n_h = 1, int batch = 1) { params.v_head_stride = HD * SK; params.v_batch_stride = n_h * HD * SK; params.o_head_stride = T * HD; params.o_batch_stride = n_h * T * HD; params.lse_head_stride = T; params.lse_batch_stride = n_h * T; - params.normalize = 1; int smem = compute_smem(); if (smem > 48 * 1024) @@ -184,86 +154,6 @@ static int test_normalized(int T, int n_h = 1, int batch = 1) { return failed == 0; } -static int test_unnormalized(int T) { - printf("\n=== UN-NORMALIZED T=%d, HD=%d ===\n", T, HD); - const float SCALE = 1.0f / sqrtf((float)HD); - - bf16_t* h_q = (bf16_t*)malloc(T * HD * sizeof(bf16_t)); - bf16_t* h_k = (bf16_t*)malloc(SK * HD * sizeof(bf16_t)); - bf16_t* h_v = (bf16_t*)malloc(HD * SK * sizeof(bf16_t)); - bf16_t* h_o = (bf16_t*)calloc(T * HD, sizeof(bf16_t)); - float* h_lse = (float*)calloc(T, sizeof(float)); - - srand(42 + T + 1000); - for (int i = 0; i < T * HD; i++) h_q[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f); - for (int i = 0; i < SK * HD; i++) h_k[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f); - for (int i = 0; i < HD * SK; i++) h_v[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f); - - bf16_t *d_q, *d_k, *d_v, *d_o; float *d_lse; - cudaMalloc(&d_q, T * HD * sizeof(bf16_t)); - cudaMalloc(&d_k, SK * HD * sizeof(bf16_t)); - cudaMalloc(&d_v, HD * SK * sizeof(bf16_t)); - cudaMalloc(&d_o, T * HD * sizeof(bf16_t)); - cudaMalloc(&d_lse, T * sizeof(float)); - cudaMemcpy(d_q, h_q, T * HD * sizeof(bf16_t), cudaMemcpyHostToDevice); - cudaMemcpy(d_k, h_k, SK * HD * sizeof(bf16_t), cudaMemcpyHostToDevice); - cudaMemcpy(d_v, h_v, HD * SK * sizeof(bf16_t), cudaMemcpyHostToDevice); - - FmhaMultiRowParams params; - params.q = d_q; params.k = d_k; params.v = d_v; params.o = d_o; params.lse = d_lse; - params.s_k = SK; params.T = T; params.scale = SCALE; params.head_dim = HD; - params.q_head_stride = T * HD; params.q_batch_stride = T * HD; - params.k_head_stride = SK * HD; params.k_batch_stride = SK * HD; - params.v_head_stride = HD * SK; params.v_batch_stride = HD * SK; - params.o_head_stride = T * HD; params.o_batch_stride = T * HD; - params.lse_head_stride = T; params.lse_batch_stride = T; - params.normalize = 0; // UN-NORMALIZED - - int smem = compute_smem(); - if (smem > 48 * 1024) - cudaFuncSetAttribute(fmha_6warp_multirow_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); - - dim3 grid(1, 1, 1); - fmha_6warp_multirow_kernel<<>>(params); - - cudaError_t err = cudaDeviceSynchronize(); - if (err != cudaSuccess) { - printf(" CUDA ERROR: %s\n", cudaGetErrorString(err)); - cudaFree(d_q); cudaFree(d_k); cudaFree(d_v); cudaFree(d_o); cudaFree(d_lse); - free(h_q); free(h_k); free(h_v); free(h_o); free(h_lse); - return 0; - } - - cudaMemcpy(h_o, d_o, T * HD * sizeof(bf16_t), cudaMemcpyDeviceToHost); - cudaMemcpy(h_lse, d_lse, T * sizeof(float), cudaMemcpyDeviceToHost); - - // Verify: O_normalized = O_unnorm / row_sum, where exp(LSE) = row_sum * exp(max) - float o_unnorm_ref[MAX_T * 512]; float lse_ref[MAX_T]; - reference_attention_multirow_unnorm(h_q, h_k, h_v, o_unnorm_ref, lse_ref, HD, T, SK, SCALE); - - int failed = 0; float min_cos = 1.0f; - for (int t = 0; t < T; t++) { - // Check un-normalized O matches reference - float cs=0,na=0,nb=0; - for (int d=0;d1e-4f){cs+=a*b2;na+=a*a;nb+=b2*b2;} - } - cs /= (sqrtf(na)*sqrtf(nb)+1e-10f); - if(cs 0.01f) { printf(" FAIL lse t=%d kernel=%.6f ref=%.6f err=%.6f\n",t,h_lse[t],lse_ref[t],lse_err); failed++; } - } - printf(" min_cos=%.8f %s\n", min_cos, failed==0?"PASSED":"FAILED"); - - cudaFree(d_q); cudaFree(d_k); cudaFree(d_v); cudaFree(d_o); cudaFree(d_lse); - free(h_q); free(h_k); free(h_v); free(h_o); free(h_lse); - return failed == 0; -} - static int test_multitile_merge(int T) { printf("\n=== MULTI-TILE MERGE T=%d, s_k=256, HD=%d ===\n", T, HD); constexpr int SK_TOTAL = 256; // 2 KV tiles @@ -279,9 +169,6 @@ static int test_multitile_merge(int T) { for (int i = 0; i < SK_TOTAL * HD; i++) h_k[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f); for (int i = 0; i < HD * SK_TOTAL; i++) h_v[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f); - // Run kernel per KV tile, get un-normalized O + LSE, merge in Python - float* h_o_merged = (float*)calloc(T * HD, sizeof(float)); - bf16_t *d_q, *d_k, *d_v, *d_o; float *d_lse; cudaMalloc(&d_q, T * HD * sizeof(bf16_t)); cudaMalloc(&d_k, SK * HD * sizeof(bf16_t)); // single tile @@ -310,7 +197,6 @@ static int test_multitile_merge(int T) { params.v_head_stride = HD * SK; params.v_batch_stride = HD * SK; params.o_head_stride = T * HD; params.o_batch_stride = T * HD; params.lse_head_stride = T; params.lse_batch_stride = T; - params.normalize = 0; // UN-NORMALIZED for merge dim3 grid(1, 1, 1); fmha_6warp_multirow_kernel<<>>(params); @@ -318,7 +204,7 @@ static int test_multitile_merge(int T) { if (err != cudaSuccess) { printf(" CUDA ERROR tile %d: %s\n", tile, cudaGetErrorString(err)); cudaFree(d_q); cudaFree(d_k); cudaFree(d_v); cudaFree(d_o); cudaFree(d_lse); - free(h_q); free(h_k); free(h_v); free(h_o_merged); + free(h_q); free(h_k); free(h_v); free(lse_per_tile); free(o_per_tile); return 0; } @@ -332,7 +218,10 @@ static int test_multitile_merge(int T) { free(h_o_tile); } - // Python KV merge: O = Σ exp(lse_i)·O_i / Σ exp(lse_i) + // Python KV merge with normalized O + LSE: + // O = Σ exp(lse_i - L) * O_i_norm / Σ exp(lse_i - L) + // where L = max(lse_i) for numerical stability + float* h_o_merged = (float*)calloc(T * HD, sizeof(float)); for (int t = 0; t < T; t++) { float lse_max = -INFINITY; for (int tile = 0; tile < N_TILES; tile++) @@ -376,27 +265,18 @@ int main() { int ok = 1; - // 1. Normalized output (single KV tile) - printf("\n--- Normalized output tests ---\n"); - ok &= test_normalized(1); - ok &= test_normalized(2); - ok &= test_normalized(4); - ok &= test_normalized(8); - ok &= test_normalized(16); - ok &= test_normalized(32); - ok &= test_normalized(64); - ok &= test_normalized(128); + // 1. Single KV tile, normalized output + printf("\n--- Single KV tile tests ---\n"); + ok &= test_single(1); + ok &= test_single(2); + ok &= test_single(4); + ok &= test_single(8); + ok &= test_single(16); + ok &= test_single(32); + ok &= test_single(64); + ok &= test_single(128); - // 2. Un-normalized output + LSE - printf("\n--- Un-normalized output + LSE tests ---\n"); - ok &= test_unnormalized(1); - ok &= test_unnormalized(4); - ok &= test_unnormalized(16); - ok &= test_unnormalized(32); - ok &= test_unnormalized(64); - ok &= test_unnormalized(128); - - // 3. Multi-tile KV merge + // 2. Multi-tile KV merge (s_k=256, 2 segments) printf("\n--- Multi-tile KV merge tests ---\n"); ok &= test_multitile_merge(1); ok &= test_multitile_merge(4); @@ -405,13 +285,13 @@ int main() { ok &= test_multitile_merge(64); ok &= test_multitile_merge(128); - // 4. Multi-head and batched + // 3. Multi-head and batched printf("\n--- Multi-head and batched tests ---\n"); - ok &= test_normalized(4, 4, 1); // 4 heads, T=4 - ok &= test_normalized(16, 4, 1); // 4 heads, T=16 - ok &= test_normalized(64, 4, 1); // 4 heads, T=64 - ok &= test_normalized(1, 2, 2); // 2 heads, 2 batch, T=1 - ok &= test_normalized(16, 2, 2); // 2 heads, 2 batch, T=16 + ok &= test_single(4, 4, 1); // 4 heads, T=4 + ok &= test_single(16, 4, 1); // 4 heads, T=16 + ok &= test_single(64, 4, 1); // 4 heads, T=64 + ok &= test_single(1, 2, 2); // 2 heads, 2 batch, T=1 + ok &= test_single(16, 2, 2); // 2 heads, 2 batch, T=16 printf("\n%s\n", ok ? "ALL PASSED" : "SOME FAILED"); return ok ? 0 : 1;