From 373900fa0886fd53dc021dae56dd5565c59d0d4d Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 28 May 2026 05:20:31 +0000 Subject: [PATCH] FMHA SM100: Fix launch wrapper to match new kernel API --- dsv4/kernels/attention/fmha_sm100_launch.cu | 131 ++++++-------------- 1 file changed, 38 insertions(+), 93 deletions(-) diff --git a/dsv4/kernels/attention/fmha_sm100_launch.cu b/dsv4/kernels/attention/fmha_sm100_launch.cu index 1b199c23..aeb4390b 100644 --- a/dsv4/kernels/attention/fmha_sm100_launch.cu +++ b/dsv4/kernels/attention/fmha_sm100_launch.cu @@ -1,115 +1,60 @@ /** - * DSV4 FMHA Decode — C++ launch wrapper and PyTorch binding. + * DSV4 FMHA Decode — Launch wrapper and PyTorch binding. */ #include "fmha_sm100.cuh" #include #include -namespace dsv4 { -namespace kernels { -namespace attention { +namespace dsv4::kernels::attention { + +/** Compute SMEM size for a given head_dim. */ +static int compute_smem(int D) { + int kvs = (D > 128) ? 1 : 2; + int q = 128 * D; // Q (1 stage) + int k = 128 * D * kvs; // K + int v = 128 * D * kvs; // V + int c = 128 * D; // C (epilogue) + return (q + k + v + c) * sizeof(bf16_t); +} -/** - * Launch FMHA decode kernel. - * - * Args: - * q: (batch, num_heads, T, head_dim) BF16 - * k: (batch, s_k, head_dim) BF16 - * v: (batch, head_dim, s_k) BF16 (transposed for PV) - * scale_softmax: 1 / sqrt(head_dim) - * n_comp: number of compressed KV entries (D3 mask offset) - * swa_len: SWA window length (D3 mask) - * is_causal: apply causal mask on SWA (D4) - * attn_sink: (batch, num_heads, T) float, nullable (D5c) - * - * Returns: - * o: (batch, num_heads, T, head_dim) BF16 - * lse: (batch, num_heads, T) float - */ std::tuple fmha_decode_cuda( - torch::Tensor q, // (B, H, T, D) BF16 - torch::Tensor k, // (B, s_k, D) BF16 - torch::Tensor v, // (B, D, s_k) BF16 - double scale_softmax, - int64_t n_comp, - int64_t swa_len, - bool is_causal, - c10::optional attn_sink // nullable + torch::Tensor q, torch::Tensor k, torch::Tensor v, + double scale, int64_t n_comp, int64_t swa_len, + bool is_causal, c10::optional attn_sink ) { - auto opts = q.options(); - int B = q.size(0); - int H = q.size(1); - int T = q.size(2); - int D = q.size(3); - int s_k_val = k.size(1); + int B = q.size(0), H = q.size(1), D = q.size(3); + int sk = k.size(1); + auto o = torch::zeros({B, H, 1, D}, q.options()); + auto lse = torch::zeros({B, H, 1}, q.options().dtype(torch::kFloat32)); - auto o = torch::zeros({B, H, T, D}, opts); - auto lse = torch::zeros({B, H, T}, opts.dtype(torch::kFloat32)); + int smem = compute_smem(D); + smem = (smem + 127) & ~127; - // Strides - int batch_stride_q = q.stride(0) * q.element_size() / 2; // in BF16 elements - int batch_stride_kv = k.stride(0) * k.element_size() / 2; - int batch_stride_o = o.stride(0) * o.element_size() / 2; - - // SMEM size (dynamic) - int smem_bytes = TOTAL_SMEM(D); - // Add some extra for alignment - smem_bytes = (smem_bytes + 127) & ~127; - - // Grid: (1, H, B) dim3 grid(1, H, B); - dim3 block(THREADS_PER_CTA); + dim3 block(NTHREADS); - const float* sink_ptr = attn_sink.has_value() ? attn_sink->data_ptr() : nullptr; + const float* sp = attn_sink.has_value() ? attn_sink->data_ptr() : nullptr; - // Launch kernel (head_dim template specialization) - #define LAUNCH_FMHA(HD, CAUSAL, SINK) \ - fmha_decode_kernel<<>>( \ - reinterpret_cast(q.data_ptr()), \ - reinterpret_cast(k.data_ptr()), \ - reinterpret_cast(v.data_ptr()), \ - reinterpret_cast(o.data_ptr()), \ - batch_stride_q, batch_stride_kv, batch_stride_o, \ - s_k_val, n_comp, swa_len, (float)scale_softmax, \ - sink_ptr, lse.data_ptr() \ - ) - - // Dispatch based on head_dim, causal, sink - if (D == 64) { - if (is_causal && sink_ptr) LAUNCH_FMHA(64, true, true); - else if (is_causal) LAUNCH_FMHA(64, true, false); - else if (sink_ptr) LAUNCH_FMHA(64, false, true); - else LAUNCH_FMHA(64, false, false); - } else if (D == 128) { - if (is_causal && sink_ptr) LAUNCH_FMHA(128, true, true); - else if (is_causal) LAUNCH_FMHA(128, true, false); - else if (sink_ptr) LAUNCH_FMHA(128, false, true); - else LAUNCH_FMHA(128, false, false); - } else if (D == 256) { - if (is_causal && sink_ptr) LAUNCH_FMHA(256, true, true); - else if (is_causal) LAUNCH_FMHA(256, true, false); - else if (sink_ptr) LAUNCH_FMHA(256, false, true); - else LAUNCH_FMHA(256, false, false); - } else if (D == 512) { - if (is_causal && sink_ptr) LAUNCH_FMHA(512, true, true); - else if (is_causal) LAUNCH_FMHA(512, true, false); - else if (sink_ptr) LAUNCH_FMHA(512, false, true); - else LAUNCH_FMHA(512, false, false); - } else { - TORCH_CHECK(false, "Unsupported head_dim: ", D); - } - - #undef LAUNCH_FMHA + #define L(D, C, S) fmha_decode<<>>( \ + (bf16_t*)q.data_ptr(), \ + (bf16_t*)k.data_ptr(), \ + (bf16_t*)v.data_ptr(), \ + (bf16_t*)o.data_ptr(), \ + q.stride(0), k.stride(0), o.stride(0), \ + sk, n_comp, swa_len, (float)scale, sp, lse.data_ptr()) + if (D==64) { if(is_causal&&sp) L(64,1,1); else if(is_causal) L(64,1,0); else if(sp) L(64,0,1); else L(64,0,0); } + else if (D==128) { if(is_causal&&sp) L(128,1,1); else if(is_causal) L(128,1,0); else if(sp) L(128,0,1); else L(128,0,0); } + else if (D==256) { if(is_causal&&sp) L(256,1,1); else if(is_causal) L(256,1,0); else if(sp) L(256,0,1); else L(256,0,0); } + else if (D==512) { if(is_causal&&sp) L(512,1,1); else if(is_causal) L(512,1,0); else if(sp) L(512,0,1); else L(512,0,0); } + else { TORCH_CHECK(false, "Unsupported head_dim: ", D); } + #undef L return {o, lse}; } -} // namespace attention -} // namespace kernels -} // namespace dsv4 +} // namespace dsv4::kernels::attention PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("fmha_decode", &dsv4::kernels::attention::fmha_decode_cuda, - "DSV4 FMHA Decode (Blackwell SM100, raw CUDA)"); + m.def("fmha_decode", &dsv4::kernels::attention::fmha_decode_cuda); }