Files
nvfp4-megamoe-kernel/dsv4/kernels/attention/fmha_sm100.cpp

116 lines
4.1 KiB
C++

/**
* DSV4 FMHA Decode — C++ launch wrapper and PyTorch binding.
*/
#include "fmha_sm100.cuh"

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>

#include <cstdint>
#include <limits>
#include <tuple>
namespace dsv4 {
namespace kernels {
namespace attention {
/**
* Launch FMHA decode kernel.
*
* Args:
* q: (batch, num_heads, T, head_dim) BF16
* k: (batch, s_k, head_dim) BF16
* v: (batch, head_dim, s_k) BF16 (transposed for PV)
* scale_softmax: 1 / sqrt(head_dim)
* n_comp: number of compressed KV entries (D3 mask offset)
* swa_len: SWA window length (D3 mask)
* is_causal: apply causal mask on SWA (D4)
* attn_sink: (batch, num_heads, T) float, nullable (D5c)
*
* Returns:
* o: (batch, num_heads, T, head_dim) BF16
* lse: (batch, num_heads, T) float
*/
std::tuple<torch::Tensor, torch::Tensor> fmha_decode_cuda(
torch::Tensor q, // (B, H, T, D) BF16
torch::Tensor k, // (B, s_k, D) BF16
torch::Tensor v, // (B, D, s_k) BF16
double scale_softmax,
int64_t n_comp,
int64_t swa_len,
bool is_causal,
c10::optional<torch::Tensor> attn_sink // nullable
) {
auto opts = q.options();
int B = q.size(0);
int H = q.size(1);
int T = q.size(2);
int D = q.size(3);
int s_k_val = k.size(1);
auto o = torch::zeros({B, H, T, D}, opts);
auto lse = torch::zeros({B, H, T}, opts.dtype(torch::kFloat32));
// Strides
int batch_stride_q = q.stride(0) * q.element_size() / 2; // in BF16 elements
int batch_stride_kv = k.stride(0) * k.element_size() / 2;
int batch_stride_o = o.stride(0) * o.element_size() / 2;
// SMEM size (dynamic)
int smem_bytes = TOTAL_SMEM(D);
// Add some extra for alignment
smem_bytes = (smem_bytes + 127) & ~127;
// Grid: (1, H, B)
dim3 grid(1, H, B);
dim3 block(THREADS_PER_CTA);
const float* sink_ptr = attn_sink.has_value() ? attn_sink->data_ptr<float>() : nullptr;
// Launch kernel (head_dim template specialization)
#define LAUNCH_FMHA(HD, CAUSAL, SINK) \
fmha_decode_kernel<HD, 1, CAUSAL, SINK><<<grid, block, smem_bytes>>>( \
reinterpret_cast<const __nv_bfloat16*>(q.data_ptr<at::BFloat16>()), \
reinterpret_cast<const __nv_bfloat16*>(k.data_ptr<at::BFloat16>()), \
reinterpret_cast<const __nv_bfloat16*>(v.data_ptr<at::BFloat16>()), \
reinterpret_cast<__nv_bfloat16*>(o.data_ptr<at::BFloat16>()), \
batch_stride_q, batch_stride_kv, batch_stride_o, \
s_k_val, n_comp, swa_len, (float)scale_softmax, \
sink_ptr, lse.data_ptr<float>() \
)
// Dispatch based on head_dim, causal, sink
if (D == 64) {
if (is_causal && sink_ptr) LAUNCH_FMHA(64, true, true);
else if (is_causal) LAUNCH_FMHA(64, true, false);
else if (sink_ptr) LAUNCH_FMHA(64, false, true);
else LAUNCH_FMHA(64, false, false);
} else if (D == 128) {
if (is_causal && sink_ptr) LAUNCH_FMHA(128, true, true);
else if (is_causal) LAUNCH_FMHA(128, true, false);
else if (sink_ptr) LAUNCH_FMHA(128, false, true);
else LAUNCH_FMHA(128, false, false);
} else if (D == 256) {
if (is_causal && sink_ptr) LAUNCH_FMHA(256, true, true);
else if (is_causal) LAUNCH_FMHA(256, true, false);
else if (sink_ptr) LAUNCH_FMHA(256, false, true);
else LAUNCH_FMHA(256, false, false);
} else if (D == 512) {
if (is_causal && sink_ptr) LAUNCH_FMHA(512, true, true);
else if (is_causal) LAUNCH_FMHA(512, true, false);
else if (sink_ptr) LAUNCH_FMHA(512, false, true);
else LAUNCH_FMHA(512, false, false);
} else {
TORCH_CHECK(false, "Unsupported head_dim: ", D);
}
#undef LAUNCH_FMHA
return {o, lse};
}
} // namespace attention
} // namespace kernels
} // namespace dsv4
// Python binding for the FMHA decode launcher.
// Named arguments allow keyword calls from Python. attn_sink defaults to None,
// so callers without an attention sink may omit it. Positional calls with all
// eight arguments behave exactly as before.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("fmha_decode", &dsv4::kernels::attention::fmha_decode_cuda,
        "DSV4 FMHA Decode (Blackwell SM100, raw CUDA)",
        pybind11::arg("q"),
        pybind11::arg("k"),
        pybind11::arg("v"),
        pybind11::arg("scale_softmax"),
        pybind11::arg("n_comp"),
        pybind11::arg("swa_len"),
        pybind11::arg("is_causal"),
        pybind11::arg("attn_sink") = pybind11::none());
}