116 lines
4.1 KiB
C++
116 lines
4.1 KiB
C++
/**
 * DSV4 FMHA decode: host-side launch wrapper and PyTorch (pybind11) binding
 * for the SM100 decode kernel from fmha_sm100.cuh.
 */
|
|
|
|
#include "fmha_sm100.cuh"

#include <cstdint>
#include <limits>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>
|
|
|
|
namespace dsv4 {
|
|
namespace kernels {
|
|
namespace attention {
|
|
|
|
/**
|
|
* Launch FMHA decode kernel.
|
|
*
|
|
* Args:
|
|
* q: (batch, num_heads, T, head_dim) BF16
|
|
* k: (batch, s_k, head_dim) BF16
|
|
* v: (batch, head_dim, s_k) BF16 (transposed for PV)
|
|
* scale_softmax: 1 / sqrt(head_dim)
|
|
* n_comp: number of compressed KV entries (D3 mask offset)
|
|
* swa_len: SWA window length (D3 mask)
|
|
* is_causal: apply causal mask on SWA (D4)
|
|
* attn_sink: (batch, num_heads, T) float, nullable (D5c)
|
|
*
|
|
* Returns:
|
|
* o: (batch, num_heads, T, head_dim) BF16
|
|
* lse: (batch, num_heads, T) float
|
|
*/
|
|
std::tuple<torch::Tensor, torch::Tensor> fmha_decode_cuda(
|
|
torch::Tensor q, // (B, H, T, D) BF16
|
|
torch::Tensor k, // (B, s_k, D) BF16
|
|
torch::Tensor v, // (B, D, s_k) BF16
|
|
double scale_softmax,
|
|
int64_t n_comp,
|
|
int64_t swa_len,
|
|
bool is_causal,
|
|
c10::optional<torch::Tensor> attn_sink // nullable
|
|
) {
|
|
auto opts = q.options();
|
|
int B = q.size(0);
|
|
int H = q.size(1);
|
|
int T = q.size(2);
|
|
int D = q.size(3);
|
|
int s_k_val = k.size(1);
|
|
|
|
auto o = torch::zeros({B, H, T, D}, opts);
|
|
auto lse = torch::zeros({B, H, T}, opts.dtype(torch::kFloat32));
|
|
|
|
// Strides
|
|
int batch_stride_q = q.stride(0) * q.element_size() / 2; // in BF16 elements
|
|
int batch_stride_kv = k.stride(0) * k.element_size() / 2;
|
|
int batch_stride_o = o.stride(0) * o.element_size() / 2;
|
|
|
|
// SMEM size (dynamic)
|
|
int smem_bytes = TOTAL_SMEM(D);
|
|
// Add some extra for alignment
|
|
smem_bytes = (smem_bytes + 127) & ~127;
|
|
|
|
// Grid: (1, H, B)
|
|
dim3 grid(1, H, B);
|
|
dim3 block(THREADS_PER_CTA);
|
|
|
|
const float* sink_ptr = attn_sink.has_value() ? attn_sink->data_ptr<float>() : nullptr;
|
|
|
|
// Launch kernel (head_dim template specialization)
|
|
#define LAUNCH_FMHA(HD, CAUSAL, SINK) \
|
|
fmha_decode_kernel<HD, 1, CAUSAL, SINK><<<grid, block, smem_bytes>>>( \
|
|
reinterpret_cast<const __nv_bfloat16*>(q.data_ptr<at::BFloat16>()), \
|
|
reinterpret_cast<const __nv_bfloat16*>(k.data_ptr<at::BFloat16>()), \
|
|
reinterpret_cast<const __nv_bfloat16*>(v.data_ptr<at::BFloat16>()), \
|
|
reinterpret_cast<__nv_bfloat16*>(o.data_ptr<at::BFloat16>()), \
|
|
batch_stride_q, batch_stride_kv, batch_stride_o, \
|
|
s_k_val, n_comp, swa_len, (float)scale_softmax, \
|
|
sink_ptr, lse.data_ptr<float>() \
|
|
)
|
|
|
|
// Dispatch based on head_dim, causal, sink
|
|
if (D == 64) {
|
|
if (is_causal && sink_ptr) LAUNCH_FMHA(64, true, true);
|
|
else if (is_causal) LAUNCH_FMHA(64, true, false);
|
|
else if (sink_ptr) LAUNCH_FMHA(64, false, true);
|
|
else LAUNCH_FMHA(64, false, false);
|
|
} else if (D == 128) {
|
|
if (is_causal && sink_ptr) LAUNCH_FMHA(128, true, true);
|
|
else if (is_causal) LAUNCH_FMHA(128, true, false);
|
|
else if (sink_ptr) LAUNCH_FMHA(128, false, true);
|
|
else LAUNCH_FMHA(128, false, false);
|
|
} else if (D == 256) {
|
|
if (is_causal && sink_ptr) LAUNCH_FMHA(256, true, true);
|
|
else if (is_causal) LAUNCH_FMHA(256, true, false);
|
|
else if (sink_ptr) LAUNCH_FMHA(256, false, true);
|
|
else LAUNCH_FMHA(256, false, false);
|
|
} else if (D == 512) {
|
|
if (is_causal && sink_ptr) LAUNCH_FMHA(512, true, true);
|
|
else if (is_causal) LAUNCH_FMHA(512, true, false);
|
|
else if (sink_ptr) LAUNCH_FMHA(512, false, true);
|
|
else LAUNCH_FMHA(512, false, false);
|
|
} else {
|
|
TORCH_CHECK(false, "Unsupported head_dim: ", D);
|
|
}
|
|
|
|
#undef LAUNCH_FMHA
|
|
|
|
return {o, lse};
|
|
}
|
|
|
|
} // namespace attention
|
|
} // namespace kernels
|
|
} // namespace dsv4
|
|
|
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Expose fmha_decode with keyword names that mirror the C++ parameters.
  // attn_sink defaults to None, so callers without an attention sink may omit
  // it; existing positional calls behave exactly as before.
  m.def("fmha_decode", &dsv4::kernels::attention::fmha_decode_cuda,
        "DSV4 FMHA Decode (Blackwell SM100, raw CUDA)",
        py::arg("q"), py::arg("k"), py::arg("v"),
        py::arg("scale_softmax"), py::arg("n_comp"), py::arg("swa_len"),
        py::arg("is_causal"), py::arg("attn_sink") = py::none());
}
|