FMHA SM100: Fix launch wrapper to match new kernel API

This commit is contained in:
2026-05-28 05:20:31 +00:00
parent a30ebfb197
commit 373900fa08

View File

@@ -1,115 +1,60 @@
/**
* DSV4 FMHA Decode — C++ launch wrapper and PyTorch binding.
* DSV4 FMHA Decode — Launch wrapper and PyTorch binding.
*/
#include "fmha_sm100.cuh"
#include <ATen/ATen.h>
#include <torch/extension.h>
namespace dsv4 {
namespace kernels {
namespace attention {
namespace dsv4::kernels::attention {
/**
 * Shared-memory byte count requested for head dimension D.
 *
 * The total is the sum of four buffers of 128 * D bf16_t elements each:
 * Q and the epilogue buffer C are counted once, while K and V are each
 * multiplied by a factor that is 2 when D <= 128 and 1 otherwise
 * (presumably the K/V pipeline stage count — confirm against the kernel).
 */
static int compute_smem(int D) {
    const int tile_elems = 128 * D;
    // Larger head dims drop to a single K/V multiplier.
    const int kv_factor = (D <= 128) ? 2 : 1;
    const int q_elems = tile_elems;              // Q (1 stage)
    const int k_elems = tile_elems * kv_factor;  // K
    const int v_elems = tile_elems * kv_factor;  // V
    const int c_elems = tile_elems;              // C (epilogue)
    const int total_elems = q_elems + k_elems + v_elems + c_elems;
    return static_cast<int>(total_elems * sizeof(bf16_t));
}
/**
* Launch the FMHA decode kernel and return its outputs.
*
* NOTE(review): this listing appears to interleave the pre-commit and
* post-commit versions of the wrapper (a diff rendered without +/- markers).
* As a result the parameter list, the B/H/D locals, the o/lse allocations,
* the block dimension, the sink pointer and the launch macro each appear
* twice. The Args/Returns below use the longer-form names (scale_softmax, T).
*
* Args:
* q: (batch, num_heads, T, head_dim) BF16
* k: (batch, s_k, head_dim) BF16
* v: (batch, head_dim, s_k) BF16 (transposed for PV)
* scale_softmax: 1 / sqrt(head_dim)
* n_comp: number of compressed KV entries (D3 mask offset)
* swa_len: SWA window length (D3 mask)
* is_causal: apply causal mask on SWA (D4)
* attn_sink: (batch, num_heads, T) float, nullable (D5c)
*
* Returns:
* o: (batch, num_heads, T, head_dim) BF16
* lse: (batch, num_heads, T) float
*
* Head dims 64, 128, 256 and 512 are dispatched to a template
* specialization; any other value fails a TORCH_CHECK.
*
* NOTE(review): apart from the head_dim check, no validation of the inputs
* (dtype, device, contiguity, rank) is visible in this function — confirm
* whether callers guarantee these.
*/
std::tuple<torch::Tensor, torch::Tensor> fmha_decode_cuda(
torch::Tensor q, // (B, H, T, D) BF16
torch::Tensor k, // (B, s_k, D) BF16
torch::Tensor v, // (B, D, s_k) BF16
double scale_softmax,
int64_t n_comp,
int64_t swa_len,
bool is_causal,
c10::optional<torch::Tensor> attn_sink // nullable
torch::Tensor q, torch::Tensor k, torch::Tensor v,
double scale, int64_t n_comp, int64_t swa_len,
bool is_causal, c10::optional<torch::Tensor> attn_sink
) {
// Shape extraction: q is indexed as a 4-D tensor and k.size(1) supplies s_k.
auto opts = q.options();
int B = q.size(0);
int H = q.size(1);
int T = q.size(2);
int D = q.size(3);
int s_k_val = k.size(1);
int B = q.size(0), H = q.size(1), D = q.size(3);
int sk = k.size(1);
// Zero-initialised outputs; this pair hard-codes a query length of 1.
auto o = torch::zeros({B, H, 1, D}, q.options());
auto lse = torch::zeros({B, H, 1}, q.options().dtype(torch::kFloat32));
// Zero-initialised outputs; this pair sizes them by T = q.size(2) instead.
auto o = torch::zeros({B, H, T, D}, opts);
auto lse = torch::zeros({B, H, T}, opts.dtype(torch::kFloat32));
// Dynamic shared-memory request, rounded up to a multiple of 128 bytes.
// NOTE(review): assuming bf16_t is 2 bytes, compute_smem(64) is already
// 96 KiB, above the 48 KiB default dynamic shared-memory limit. No
// cudaFuncSetAttribute(..., cudaFuncAttributeMaxDynamicSharedMemorySize, ...)
// call is visible here — confirm the opt-in happens elsewhere.
int smem = compute_smem(D);
smem = (smem + 127) & ~127;
// Strides
int batch_stride_q = q.stride(0) * q.element_size() / 2; // in BF16 elements
int batch_stride_kv = k.stride(0) * k.element_size() / 2;
int batch_stride_o = o.stride(0) * o.element_size() / 2;
// SMEM size (dynamic)
int smem_bytes = TOTAL_SMEM(D);
// Add some extra for alignment
smem_bytes = (smem_bytes + 127) & ~127;
// Grid: (1, H, B)
dim3 grid(1, H, B);
dim3 block(THREADS_PER_CTA);
dim3 block(NTHREADS);
// Null when attn_sink is absent; the dispatch below uses the pointer's
// nullness to choose the sink specialization.
const float* sink_ptr = attn_sink.has_value() ? attn_sink->data_ptr<float>() : nullptr;
const float* sp = attn_sink.has_value() ? attn_sink->data_ptr<float>() : nullptr;
// Launch kernel (head_dim template specialization)
#define LAUNCH_FMHA(HD, CAUSAL, SINK) \
fmha_decode_kernel<HD, 1, CAUSAL, SINK><<<grid, block, smem_bytes>>>( \
reinterpret_cast<const bf16_t*>(q.data_ptr<at::BFloat16>()), \
reinterpret_cast<const bf16_t*>(k.data_ptr<at::BFloat16>()), \
reinterpret_cast<const bf16_t*>(v.data_ptr<at::BFloat16>()), \
reinterpret_cast<bf16_t*>(o.data_ptr<at::BFloat16>()), \
batch_stride_q, batch_stride_kv, batch_stride_o, \
s_k_val, n_comp, swa_len, (float)scale_softmax, \
sink_ptr, lse.data_ptr<float>() \
)
// Dispatch based on head_dim, causal, sink
if (D == 64) {
if (is_causal && sink_ptr) LAUNCH_FMHA(64, true, true);
else if (is_causal) LAUNCH_FMHA(64, true, false);
else if (sink_ptr) LAUNCH_FMHA(64, false, true);
else LAUNCH_FMHA(64, false, false);
} else if (D == 128) {
if (is_causal && sink_ptr) LAUNCH_FMHA(128, true, true);
else if (is_causal) LAUNCH_FMHA(128, true, false);
else if (sink_ptr) LAUNCH_FMHA(128, false, true);
else LAUNCH_FMHA(128, false, false);
} else if (D == 256) {
if (is_causal && sink_ptr) LAUNCH_FMHA(256, true, true);
else if (is_causal) LAUNCH_FMHA(256, true, false);
else if (sink_ptr) LAUNCH_FMHA(256, false, true);
else LAUNCH_FMHA(256, false, false);
} else if (D == 512) {
if (is_causal && sink_ptr) LAUNCH_FMHA(512, true, true);
else if (is_causal) LAUNCH_FMHA(512, true, false);
else if (sink_ptr) LAUNCH_FMHA(512, false, true);
else LAUNCH_FMHA(512, false, false);
} else {
TORCH_CHECK(false, "Unsupported head_dim: ", D);
}
#undef LAUNCH_FMHA
// Compact launch macro: D/C/S are the template arguments of fmha_decode.
// NOTE(review): the launch passes no stream argument and is not followed by
// a cudaGetLastError() check, so a rejected launch configuration would go
// unreported here — confirm whether the current PyTorch stream should be used.
// NOTE(review): q.stride(0)/k.stride(0)/o.stride(0) are 64-bit element
// strides passed straight through; the kernel's parameter types are not
// visible here — verify they match.
#define L(D, C, S) fmha_decode<D,C,S><<<grid,block,smem>>>( \
(bf16_t*)q.data_ptr<at::BFloat16>(), \
(bf16_t*)k.data_ptr<at::BFloat16>(), \
(bf16_t*)v.data_ptr<at::BFloat16>(), \
(bf16_t*)o.data_ptr<at::BFloat16>(), \
q.stride(0), k.stride(0), o.stride(0), \
sk, n_comp, swa_len, (float)scale, sp, lse.data_ptr<float>())
// One row per supported head_dim; within a row the (causal, sink) pair picks
// the specialization, with 1/0 standing for the template flags.
if (D==64) { if(is_causal&&sp) L(64,1,1); else if(is_causal) L(64,1,0); else if(sp) L(64,0,1); else L(64,0,0); }
else if (D==128) { if(is_causal&&sp) L(128,1,1); else if(is_causal) L(128,1,0); else if(sp) L(128,0,1); else L(128,0,0); }
else if (D==256) { if(is_causal&&sp) L(256,1,1); else if(is_causal) L(256,1,0); else if(sp) L(256,0,1); else L(256,0,0); }
else if (D==512) { if(is_causal&&sp) L(512,1,1); else if(is_causal) L(512,1,0); else if(sp) L(512,0,1); else L(512,0,0); }
else { TORCH_CHECK(false, "Unsupported head_dim: ", D); }
#undef L
// Both tensors are returned immediately; no synchronization is performed here.
return {o, lse};
}
} // namespace attention
} // namespace kernels
} // namespace dsv4
} // namespace dsv4::kernels::attention
// Python binding: exposes dsv4::kernels::attention::fmha_decode_cuda on the
// extension module under the name "fmha_decode".
// NOTE(review): two m.def registrations for the same name appear below — they
// look like the pre- and post-commit forms of one line (the first carries a
// docstring, the second does not); confirm only one is intended.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fmha_decode", &dsv4::kernels::attention::fmha_decode_cuda,
"DSV4 FMHA Decode (Blackwell SM100, raw CUDA)");
m.def("fmha_decode", &dsv4::kernels::attention::fmha_decode_cuda);
}