FMHA SM100: Fix launch wrapper to match new kernel API
This commit is contained in:
@@ -1,115 +1,60 @@
|
||||
/**
|
||||
* DSV4 FMHA Decode — C++ launch wrapper and PyTorch binding.
|
||||
* DSV4 FMHA Decode — Launch wrapper and PyTorch binding.
|
||||
*/
|
||||
|
||||
#include "fmha_sm100.cuh"
|
||||
#include <ATen/ATen.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
namespace dsv4 {
|
||||
namespace kernels {
|
||||
namespace attention {
|
||||
namespace dsv4::kernels::attention {
|
||||
|
||||
/** Compute SMEM size for a given head_dim. */
|
||||
static int compute_smem(int D) {
|
||||
int kvs = (D > 128) ? 1 : 2;
|
||||
int q = 128 * D; // Q (1 stage)
|
||||
int k = 128 * D * kvs; // K
|
||||
int v = 128 * D * kvs; // V
|
||||
int c = 128 * D; // C (epilogue)
|
||||
return (q + k + v + c) * sizeof(bf16_t);
|
||||
}
|
||||
|
||||
/**
|
||||
* Launch FMHA decode kernel.
|
||||
*
|
||||
* Args:
|
||||
* q: (batch, num_heads, T, head_dim) BF16
|
||||
* k: (batch, s_k, head_dim) BF16
|
||||
* v: (batch, head_dim, s_k) BF16 (transposed for PV)
|
||||
* scale_softmax: 1 / sqrt(head_dim)
|
||||
* n_comp: number of compressed KV entries (D3 mask offset)
|
||||
* swa_len: SWA window length (D3 mask)
|
||||
* is_causal: apply causal mask on SWA (D4)
|
||||
* attn_sink: (batch, num_heads, T) float, nullable (D5c)
|
||||
*
|
||||
* Returns:
|
||||
* o: (batch, num_heads, T, head_dim) BF16
|
||||
* lse: (batch, num_heads, T) float
|
||||
*/
|
||||
std::tuple<torch::Tensor, torch::Tensor> fmha_decode_cuda(
|
||||
torch::Tensor q, // (B, H, T, D) BF16
|
||||
torch::Tensor k, // (B, s_k, D) BF16
|
||||
torch::Tensor v, // (B, D, s_k) BF16
|
||||
double scale_softmax,
|
||||
int64_t n_comp,
|
||||
int64_t swa_len,
|
||||
bool is_causal,
|
||||
c10::optional<torch::Tensor> attn_sink // nullable
|
||||
torch::Tensor q, torch::Tensor k, torch::Tensor v,
|
||||
double scale, int64_t n_comp, int64_t swa_len,
|
||||
bool is_causal, c10::optional<torch::Tensor> attn_sink
|
||||
) {
|
||||
auto opts = q.options();
|
||||
int B = q.size(0);
|
||||
int H = q.size(1);
|
||||
int T = q.size(2);
|
||||
int D = q.size(3);
|
||||
int s_k_val = k.size(1);
|
||||
int B = q.size(0), H = q.size(1), D = q.size(3);
|
||||
int sk = k.size(1);
|
||||
auto o = torch::zeros({B, H, 1, D}, q.options());
|
||||
auto lse = torch::zeros({B, H, 1}, q.options().dtype(torch::kFloat32));
|
||||
|
||||
auto o = torch::zeros({B, H, T, D}, opts);
|
||||
auto lse = torch::zeros({B, H, T}, opts.dtype(torch::kFloat32));
|
||||
int smem = compute_smem(D);
|
||||
smem = (smem + 127) & ~127;
|
||||
|
||||
// Strides
|
||||
int batch_stride_q = q.stride(0) * q.element_size() / 2; // in BF16 elements
|
||||
int batch_stride_kv = k.stride(0) * k.element_size() / 2;
|
||||
int batch_stride_o = o.stride(0) * o.element_size() / 2;
|
||||
|
||||
// SMEM size (dynamic)
|
||||
int smem_bytes = TOTAL_SMEM(D);
|
||||
// Add some extra for alignment
|
||||
smem_bytes = (smem_bytes + 127) & ~127;
|
||||
|
||||
// Grid: (1, H, B)
|
||||
dim3 grid(1, H, B);
|
||||
dim3 block(THREADS_PER_CTA);
|
||||
dim3 block(NTHREADS);
|
||||
|
||||
const float* sink_ptr = attn_sink.has_value() ? attn_sink->data_ptr<float>() : nullptr;
|
||||
const float* sp = attn_sink.has_value() ? attn_sink->data_ptr<float>() : nullptr;
|
||||
|
||||
// Launch kernel (head_dim template specialization)
|
||||
#define LAUNCH_FMHA(HD, CAUSAL, SINK) \
|
||||
fmha_decode_kernel<HD, 1, CAUSAL, SINK><<<grid, block, smem_bytes>>>( \
|
||||
reinterpret_cast<const bf16_t*>(q.data_ptr<at::BFloat16>()), \
|
||||
reinterpret_cast<const bf16_t*>(k.data_ptr<at::BFloat16>()), \
|
||||
reinterpret_cast<const bf16_t*>(v.data_ptr<at::BFloat16>()), \
|
||||
reinterpret_cast<bf16_t*>(o.data_ptr<at::BFloat16>()), \
|
||||
batch_stride_q, batch_stride_kv, batch_stride_o, \
|
||||
s_k_val, n_comp, swa_len, (float)scale_softmax, \
|
||||
sink_ptr, lse.data_ptr<float>() \
|
||||
)
|
||||
|
||||
// Dispatch based on head_dim, causal, sink
|
||||
if (D == 64) {
|
||||
if (is_causal && sink_ptr) LAUNCH_FMHA(64, true, true);
|
||||
else if (is_causal) LAUNCH_FMHA(64, true, false);
|
||||
else if (sink_ptr) LAUNCH_FMHA(64, false, true);
|
||||
else LAUNCH_FMHA(64, false, false);
|
||||
} else if (D == 128) {
|
||||
if (is_causal && sink_ptr) LAUNCH_FMHA(128, true, true);
|
||||
else if (is_causal) LAUNCH_FMHA(128, true, false);
|
||||
else if (sink_ptr) LAUNCH_FMHA(128, false, true);
|
||||
else LAUNCH_FMHA(128, false, false);
|
||||
} else if (D == 256) {
|
||||
if (is_causal && sink_ptr) LAUNCH_FMHA(256, true, true);
|
||||
else if (is_causal) LAUNCH_FMHA(256, true, false);
|
||||
else if (sink_ptr) LAUNCH_FMHA(256, false, true);
|
||||
else LAUNCH_FMHA(256, false, false);
|
||||
} else if (D == 512) {
|
||||
if (is_causal && sink_ptr) LAUNCH_FMHA(512, true, true);
|
||||
else if (is_causal) LAUNCH_FMHA(512, true, false);
|
||||
else if (sink_ptr) LAUNCH_FMHA(512, false, true);
|
||||
else LAUNCH_FMHA(512, false, false);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported head_dim: ", D);
|
||||
}
|
||||
|
||||
#undef LAUNCH_FMHA
|
||||
#define L(D, C, S) fmha_decode<D,C,S><<<grid,block,smem>>>( \
|
||||
(bf16_t*)q.data_ptr<at::BFloat16>(), \
|
||||
(bf16_t*)k.data_ptr<at::BFloat16>(), \
|
||||
(bf16_t*)v.data_ptr<at::BFloat16>(), \
|
||||
(bf16_t*)o.data_ptr<at::BFloat16>(), \
|
||||
q.stride(0), k.stride(0), o.stride(0), \
|
||||
sk, n_comp, swa_len, (float)scale, sp, lse.data_ptr<float>())
|
||||
|
||||
if (D==64) { if(is_causal&&sp) L(64,1,1); else if(is_causal) L(64,1,0); else if(sp) L(64,0,1); else L(64,0,0); }
|
||||
else if (D==128) { if(is_causal&&sp) L(128,1,1); else if(is_causal) L(128,1,0); else if(sp) L(128,0,1); else L(128,0,0); }
|
||||
else if (D==256) { if(is_causal&&sp) L(256,1,1); else if(is_causal) L(256,1,0); else if(sp) L(256,0,1); else L(256,0,0); }
|
||||
else if (D==512) { if(is_causal&&sp) L(512,1,1); else if(is_causal) L(512,1,0); else if(sp) L(512,0,1); else L(512,0,0); }
|
||||
else { TORCH_CHECK(false, "Unsupported head_dim: ", D); }
|
||||
#undef L
|
||||
return {o, lse};
|
||||
}
|
||||
|
||||
} // namespace attention
|
||||
} // namespace kernels
|
||||
} // namespace dsv4
|
||||
} // namespace dsv4::kernels::attention
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("fmha_decode", &dsv4::kernels::attention::fmha_decode_cuda,
|
||||
"DSV4 FMHA Decode (Blackwell SM100, raw CUDA)");
|
||||
m.def("fmha_decode", &dsv4::kernels::attention::fmha_decode_cuda);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user