P3: ctypes loader for 6-warp FMHA (bypass torch JIT sm_100 arch issue)

- fmha_multihead_capi.cu: pure C API wrapper, no ATen/pybind11 deps
- fmha_multihead_op.py: nvcc precompile + ctypes load (sm_100a)
- Removed fmha_multihead_launch.cu (ATen approach didn't work)
- Updated test to call kernel directly via ctypes API
This commit is contained in:
2026-05-30 08:15:31 +00:00
parent 1e6adf5e01
commit adcf3e04ab
3 changed files with 310 additions and 96 deletions

View File

@@ -0,0 +1,104 @@
/**
* DSV4 FMHA Multi-Head — C API for ctypes loading.
*
* This provides a pure C API (no pybind11, no ATen) so we can compile
* with nvcc -arch=sm_100a and load via Python ctypes. PyTorch tensors
* are passed as raw device pointers + shapes + strides.
*
* The Python side handles tensor → (ptr, shape, stride) conversion
* and wraps the result back into torch tensors.
*/
#include <cuda_runtime.h>
#include <cstdint>
// Kernel template definition (fmha_6warp_multihead.cuh) and its supporting headers.
// The full definitions are included so the template instantiations used below are emitted in this translation unit.
#include "fmha_common.cuh"
#include "fmha_umma_desc.cuh"
#include "fmha_6warp_multihead.cuh"
extern "C" {
/** Round `value` up to the next multiple of `alignment` (must be a power of two). */
static inline int fmha_align_up(int value, int alignment) {
    return (value + alignment - 1) & ~(alignment - 1);
}

/**
 * Compute dynamic SMEM size, in bytes, for the 6-warp multi-head kernel.
 *
 * The size is the end of the following byte layout, rounded up to 128:
 *   header   three 4-byte slots, padded to 16 bytes
 *   sQ0      128 * hd elements of 2 bytes
 *   sK0      128 * hd elements of 2 bytes
 *   sPk      128 * 16 elements of 2 bytes, start aligned to 128 bytes
 *   sV       16 * 16 elements of 2 bytes, start aligned to 128 bytes
 *   sp_vals  128 values of 4 bytes
 *
 * NOTE(review): this layout has to match the carve-up done inside
 * fmha_6warp_multihead.cuh — confirm against the kernel when changing either.
 *
 * @param hd  Head dimension. The launcher only uses 64, 128 and 256; for
 *            those the result is 38016, 70784 and 136320 bytes respectively.
 * @return    Dynamic shared-memory size in bytes (a multiple of 128).
 */
int fmha_compute_smem(int hd) {
    const int header_bytes = fmha_align_up(4 + 4 + 4, 16);

    const int q_tile_bytes = 128 * hd * 2;
    const int k_tile_bytes = 128 * hd * 2;

    const int pk_offset = fmha_align_up(header_bytes + q_tile_bytes + k_tile_bytes, 128);
    const int pk_bytes = 128 * 16 * 2;

    const int v_offset = fmha_align_up(pk_offset + pk_bytes, 128);
    const int v_bytes = 16 * 16 * 2;

    // sp_vals follows sV directly, without extra alignment.
    const int p_vals_offset = v_offset + v_bytes;
    const int p_vals_bytes = 128 * 4;

    return fmha_align_up(p_vals_offset + p_vals_bytes, 128);
}
/**
 * Launch the 6-warp multi-head FMHA kernel for decode.
 *
 * All pointers are device pointers (BF16 for q/k/v/o, FP32 for lse).
 * Strides are in elements (not bytes).
 *
 * Launch configuration: grid = (1, n_h, batch), block = NTHREADS threads,
 * dynamic shared memory = fmha_compute_smem(hd) bytes. No stream argument is
 * passed, so the kernel is enqueued on the default stream. The call does not
 * synchronize: a return value of 0 means the launch was accepted, not that
 * the kernel has finished or ran without faults.
 *
 * Supported head dimensions are 64, 128 and 256.
 *
 * @return 0 on success; -1 for an unsupported hd or a null q/k/v/o pointer;
 *         otherwise the cudaError_t value (as int) reported by the shared
 *         memory opt-in or by the launch.
 */
int fmha_multihead_decode_launch(
    const void* q_ptr,      // Q base: (batch, n_h, 1, hd) BF16
    const void* k_ptr,      // K base: (batch, n_kv, N, hd) BF16
    const void* v_ptr,      // V base: (batch, n_kv, hd, N) BF16
    void* o_ptr,            // O base: (batch, n_h, 1, hd) BF16
    void* lse_ptr,          // LSE base: (batch, n_h, 1) FP32 (can be NULL)
    int batch, int n_h, int n_kv, int N, int hd,
    int q_head_stride, int q_batch_stride,
    int k_head_stride, int k_batch_stride,
    int v_head_stride, int v_batch_stride,
    int o_head_stride, int o_batch_stride,
    int lse_head_stride, int lse_batch_stride,
    float scale
) {
    using namespace dsv4::kernels::attention;

    // n_kv is part of the C ABI but is not forwarded to the kernel here; the
    // K/V head layout reaches the kernel only through the stride arguments.
    (void)n_kv;

    // Reject arguments that would otherwise fault on the device.
    if (q_ptr == nullptr || k_ptr == nullptr || v_ptr == nullptr || o_ptr == nullptr) {
        return -1;
    }
    if (hd != 64 && hd != 128 && hd != 256) {
        return -1;  // unsupported hd
    }

    // Value-initialize so that any field not assigned below is zero rather
    // than indeterminate.
    FmhaParams params{};
    params.q = reinterpret_cast<const bf16_t*>(q_ptr);
    params.k = reinterpret_cast<const bf16_t*>(k_ptr);
    params.v = reinterpret_cast<const bf16_t*>(v_ptr);
    params.o = reinterpret_cast<bf16_t*>(o_ptr);
    params.lse = reinterpret_cast<float*>(lse_ptr);
    params.s_k = N;
    params.scale = scale;
    params.head_dim = hd;
    params.q_head_stride = q_head_stride;
    params.q_batch_stride = q_batch_stride;
    params.k_head_stride = k_head_stride;
    params.k_batch_stride = k_batch_stride;
    params.v_head_stride = v_head_stride;
    params.v_batch_stride = v_batch_stride;
    params.o_head_stride = o_head_stride;
    params.o_batch_stride = o_batch_stride;
    params.lse_head_stride = lse_head_stride;
    params.lse_batch_stride = lse_batch_stride;

    const int smem = fmha_compute_smem(hd);
    const dim3 grid(1, n_h, batch);
    const dim3 block(NTHREADS);

    // Dynamic shared memory above 48 KB must be opted into per kernel before
    // launching. hd=128 (70784 B) and hd=256 (136320 B) both exceed that
    // limit; setting the attribute for hd=64 as well keeps the paths uniform.
    cudaError_t attr_err = cudaSuccess;
    switch (hd) {
    case 64:
        attr_err = cudaFuncSetAttribute(fmha_6warp_multihead_kernel<64, 128>,
                                        cudaFuncAttributeMaxDynamicSharedMemorySize, smem);
        if (attr_err == cudaSuccess) {
            fmha_6warp_multihead_kernel<64, 128><<<grid, block, smem>>>(params);
        }
        break;
    case 128:
        attr_err = cudaFuncSetAttribute(fmha_6warp_multihead_kernel<128, 128>,
                                        cudaFuncAttributeMaxDynamicSharedMemorySize, smem);
        if (attr_err == cudaSuccess) {
            fmha_6warp_multihead_kernel<128, 128><<<grid, block, smem>>>(params);
        }
        break;
    default:  // hd == 256, guaranteed by the check above
        attr_err = cudaFuncSetAttribute(fmha_6warp_multihead_kernel<256, 128>,
                                        cudaFuncAttributeMaxDynamicSharedMemorySize, smem);
        if (attr_err == cudaSuccess) {
            fmha_6warp_multihead_kernel<256, 128><<<grid, block, smem>>>(params);
        }
        break;
    }

    // Read and clear the last error in every case so that a failure here is
    // not left behind for unrelated CUDA calls made later by the host process.
    const cudaError_t launch_err = cudaGetLastError();
    if (attr_err != cudaSuccess) {
        return static_cast<int>(attr_err);
    }
    return static_cast<int>(launch_err);
}
} // extern "C"

View File

@@ -1,8 +1,8 @@
"""DSV4 FMHA — 6-warp multi-head decode kernel loader.
Loads the raw CUDA 6-warp multi-head FMHA kernel via
torch.utils.cpp_extension and wraps it as a torch.library.custom_op
for torch.compile compatibility.
Precompiles the raw CUDA kernel with nvcc (sm_100a) on first use,
then loads the .so via ctypes. This bypasses torch.utils.cpp_extension
which compiles with -arch=sm_100 (missing tcgen05 support).
Decode-only: T=1, single KV segment (N <= 128).
Supports MHA, MQA, and GQA attention patterns.
@@ -11,57 +11,188 @@ Supports MHA, MQA, and GQA attention patterns.
import torch
import logging
import os
import subprocess
import ctypes
from typing import Optional
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Lazy-load the CUDA extension
# ---------------------------------------------------------------------------
_ext = None
_ext_lock = False
KERNEL_DIR = os.path.dirname(os.path.abspath(__file__))
REPO_ROOT = os.path.normpath(os.path.join(KERNEL_DIR, "..", ".."))
SOURCE = os.path.join(KERNEL_DIR, "fmha_multihead_capi.cu")
BUILD_DIR = os.path.join(REPO_ROOT, "build", "fmha_multihead")
SO_NAME = "libfmha_multihead_decode.so"
_lib = None
_lib_lock = False
def _get_ext():
"""Lazy-load the JIT-compiled CUDA extension."""
global _ext, _ext_lock
if _ext is not None:
return _ext
if _ext_lock:
raise RuntimeError("Recursive extension load — check import cycle")
def _find_nvcc():
"""Find nvcc on the system."""
for c in ["/usr/local/cuda-13.2/bin/nvcc", "/usr/local/cuda/bin/nvcc"]:
if os.path.isfile(c):
return c
# Try PATH
import shutil
nvcc = shutil.which("nvcc")
if nvcc:
return nvcc
raise RuntimeError("nvcc not found — required for tcgen05 kernel compilation")
_ext_lock = True
def _ensure_built():
    """Build the shared library with nvcc if needed, then load it.

    The library is rebuilt when it does not exist yet or when the C API source
    or any of the listed kernel headers is newer than the existing ``.so``.

    The compiler writes to a per-process temporary file that is moved over the
    final path with ``os.replace`` only after nvcc succeeded, so a failed or
    interrupted build, or another process building at the same time, never
    leaves a partially written library at the path that gets loaded.

    Returns:
        The loaded ``ctypes.CDLL`` handle (also cached in the module-level
        ``_lib``).

    Raises:
        RuntimeError: if nvcc cannot be found or the compilation fails.
    """
    global _lib
    if _lib is not None:
        return _lib
    so_path = os.path.join(BUILD_DIR, SO_NAME)

    # Check if rebuild needed: compare the newest input mtime with the .so.
    need_build = True
    if os.path.isfile(so_path):
        src_mtime = os.path.getmtime(SOURCE)
        for dep in ["fmha_common.cuh", "fmha_umma_desc.cuh",
                    "fmha_6warp_multihead.cuh", "fmha_multihead_capi.cu"]:
            dep_path = os.path.join(KERNEL_DIR, dep)
            if os.path.isfile(dep_path):
                src_mtime = max(src_mtime, os.path.getmtime(dep_path))
        need_build = src_mtime > os.path.getmtime(so_path)
    if not need_build:
        logger.info(f"Using cached {so_path}")
        _lib = ctypes.CDLL(so_path)
        return _lib

    logger.info(f"Building {SO_NAME} with nvcc (sm_100a)...")
    os.makedirs(BUILD_DIR, exist_ok=True)
    nvcc = _find_nvcc()
    # Same directory as the final file so os.replace stays on one filesystem.
    tmp_path = f"{so_path}.{os.getpid()}.tmp"
    cmd = [
        nvcc,
        "-std=c++20",
        "-shared",
        "-fPIC",
        "-gencode=arch=compute_100a,code=sm_100a",
        "-gencode=arch=compute_100a,code=compute_100a",
        f"-I{KERNEL_DIR}",
        f"-I{REPO_ROOT}",
        "-O3",
        "--expt-relaxed-constexpr",
        SOURCE,
        "-o", tmp_path,
        "-lcudart",
        "-lcuda",
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            raise RuntimeError(f"nvcc compilation failed:\n{result.stderr}")
        os.replace(tmp_path, so_path)
    finally:
        # Only present here if the build failed part-way.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    if result.stderr:
        logger.debug(f"nvcc warnings:\n{result.stderr}")
    _lib = ctypes.CDLL(so_path)
    logger.info(f"Built and loaded {so_path}")
    return _lib
def _get_lib():
"""Get or build the shared library."""
global _lib_lock
if _lib is not None:
return _lib
if _lib_lock:
raise RuntimeError("Recursive build")
_lib_lock = True
try:
from torch.utils.cpp_extension import load
kernel_dir = os.path.dirname(os.path.abspath(__file__))
sources = [
os.path.join(kernel_dir, "fmha_multihead_launch.cu"),
]
extra_cflags = ["-O2"]
extra_cuda_cflags = [
"-arch=sm_100a",
"-std=c++20",
"--expt-relaxed-constexpr",
"-O3",
"--generate-code=arch=compute_100a,code=[sm_100a,compute_100a]",
]
# The .cuh includes are relative to the kernel_dir
extra_include_paths = [kernel_dir, os.path.join(kernel_dir, "..", "..")]
logger.info("JIT-compiling fmha_multihead_decode extension...")
_ext = load(
name="fmha_multihead_decode_ext",
sources=sources,
extra_cflags=extra_cflags,
extra_cuda_cflags=extra_cuda_cflags,
extra_include_paths=extra_include_paths,
verbose=False,
)
logger.info("fmha_multihead_decode extension loaded.")
return _ext
return _ensure_built()
finally:
_ext_lock = False
_lib_lock = False
# ---------------------------------------------------------------------------
# Kernel launch via ctypes
# ---------------------------------------------------------------------------
def fmha_multihead_decode_raw(
    q: torch.Tensor,  # (batch, n_h, 1, hd) BF16, contiguous
    k: torch.Tensor,  # (batch, n_kv, N, hd) BF16, contiguous
    v: torch.Tensor,  # (batch, n_kv, hd, N) BF16, contiguous
    scale: float,
    n_comp: int,
    swa_len: int,
    is_causal: bool,
    attn_sink: torch.Tensor,  # (batch, n_h) FP32 — unused by kernel currently
) -> tuple[torch.Tensor, torch.Tensor]:
    """Launch the 6-warp multi-head FMHA kernel. Returns (O, LSE).

    ``n_comp``, ``swa_len``, ``is_causal`` and ``attn_sink`` are accepted for
    signature compatibility but are not passed to the C API.

    The C API launches on the default CUDA stream. This wrapper therefore
    makes the default stream wait for PyTorch's current stream before the
    launch and makes the current stream wait for the default stream after it,
    so the kernel is ordered correctly when the caller runs on a side stream.

    Returns:
        O:   (batch, n_h, 1, hd) BF16
        LSE: (batch, n_h, 1) FP32

    Raises:
        AssertionError: T != 1, unsupported hd, N > 128, or a non-contiguous
            input.
        TypeError: q, k or v is not BF16.
        ValueError: inputs are not CUDA tensors on one device, or the shape of
            k or v does not match the layout documented above.
        RuntimeError: the library could not be built/loaded, or the launch
            returned a non-zero code.
    """
    lib = _get_lib()

    B = q.shape[0]
    n_h = q.shape[1]
    hd = q.shape[3]
    n_kv = k.shape[1]
    N = k.shape[2]

    assert q.shape[2] == 1, f"Decode requires T=1, got T={q.shape[2]}"
    assert hd in (64, 128, 256), f"Unsupported hd={hd}"
    assert N <= 128, f"Decode fast path requires N<=128, got N={N}"
    assert q.is_contiguous()
    assert k.is_contiguous()
    assert v.is_contiguous()

    # The kernel reads raw device pointers as BF16, so a wrong dtype, device
    # or shape would silently read garbage or out-of-bounds memory.
    for name, tensor in (("q", q), ("k", k), ("v", v)):
        if tensor.dtype != torch.bfloat16:
            raise TypeError(f"{name} must be torch.bfloat16, got {tensor.dtype}")
        if not tensor.is_cuda or tensor.device != q.device:
            raise ValueError(
                f"{name} must be a CUDA tensor on the same device as q, got {tensor.device}"
            )
    if tuple(k.shape) != (B, n_kv, N, hd):
        raise ValueError(
            f"k must have shape {(B, n_kv, N, hd)}, got {tuple(k.shape)}"
        )
    if tuple(v.shape) != (B, n_kv, hd, N):
        raise ValueError(
            f"v must have shape {(B, n_kv, hd, N)}, got {tuple(v.shape)}"
        )

    # The launch targets whatever device is current, so pin it to q's device.
    with torch.cuda.device(q.device):
        o = torch.zeros(B, n_h, 1, hd, dtype=torch.bfloat16, device=q.device)
        lse = torch.zeros(B, n_h, 1, dtype=torch.float32, device=q.device)

        # Compute strides in elements
        q_hs = q.stride(1)  # head stride
        q_bs = q.stride(0)  # batch stride
        k_hs = k.stride(1)
        k_bs = k.stride(0)
        v_hs = v.stride(1)
        v_bs = v.stride(0)
        o_hs = o.stride(1)
        o_bs = o.stride(0)
        lse_hs = lse.stride(1)
        lse_bs = lse.stride(0)

        # MQA: zero out KV head strides
        if n_kv == 1:
            k_hs = 0
            v_hs = 0

        current_stream = torch.cuda.current_stream()
        default_stream = torch.cuda.default_stream()
        on_side_stream = current_stream != default_stream
        if on_side_stream:
            # Inputs and the zero-fills above were enqueued on current_stream.
            default_stream.wait_stream(current_stream)

        # Call the C API
        ret = lib.fmha_multihead_decode_launch(
            ctypes.c_void_p(q.data_ptr()),
            ctypes.c_void_p(k.data_ptr()),
            ctypes.c_void_p(v.data_ptr()),
            ctypes.c_void_p(o.data_ptr()),
            ctypes.c_void_p(lse.data_ptr()),
            ctypes.c_int(B),
            ctypes.c_int(n_h),
            ctypes.c_int(n_kv),
            ctypes.c_int(N),
            ctypes.c_int(hd),
            ctypes.c_int(q_hs),
            ctypes.c_int(q_bs),
            ctypes.c_int(k_hs),
            ctypes.c_int(k_bs),
            ctypes.c_int(v_hs),
            ctypes.c_int(v_bs),
            ctypes.c_int(o_hs),
            ctypes.c_int(o_bs),
            ctypes.c_int(lse_hs),
            ctypes.c_int(lse_bs),
            ctypes.c_float(scale),
        )

        if on_side_stream:
            # Later work on current_stream must see the kernel's results.
            current_stream.wait_stream(default_stream)

    if ret != 0:
        raise RuntimeError(f"Kernel launch failed with code {ret}")
    return o, lse
# ---------------------------------------------------------------------------
@@ -70,23 +201,16 @@ def _get_ext():
@torch.library.custom_op("dsv4::fmha_multihead_decode", mutates_args=())
def fmha_multihead_decode(
q: torch.Tensor, # (batch, n_h, 1, hd) BF16
k: torch.Tensor, # (batch, n_kv, N, hd) BF16
v: torch.Tensor, # (batch, n_kv, hd, N) BF16
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
scale: float,
n_comp: int,
swa_len: int,
is_causal: bool,
attn_sink: torch.Tensor, # (batch, n_h) FP32 — zeros when unused
attn_sink: torch.Tensor,
) -> torch.Tensor:
"""6-warp multi-head FMHA decode kernel (T=1, single KV tile).
Returns normalized attention output (batch, n_h, 1, hd) BF16.
For single-segment decode, normalization is done in-kernel.
"""
ext = _get_ext()
# The C++ extension returns (O, LSE); we only need O for the fast path
o, _lse = ext.fmha_multihead_decode(
o, _ = fmha_multihead_decode_raw(
q, k, v, scale, n_comp, swa_len, is_causal, attn_sink
)
return o
@@ -97,51 +221,16 @@ def _(q, k, v, scale, n_comp, swa_len, is_causal, attn_sink):
return torch.empty_like(q)
# ---------------------------------------------------------------------------
# Convenience: run with LSE output (for multi-segment merge later)
# ---------------------------------------------------------------------------
def fmha_multihead_decode_with_lse(
q: torch.Tensor, # (batch, n_h, 1, hd) BF16
k: torch.Tensor, # (batch, n_kv, N, hd) BF16
v: torch.Tensor, # (batch, n_kv, hd, N) BF16
scale: float,
n_comp: int = 0,
swa_len: int = 0,
is_causal: bool = False,
attn_sink: Optional[torch.Tensor] = None, # (batch, n_h) FP32
q, k, v, scale, n_comp=0, swa_len=0, is_causal=False, attn_sink=None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Run 6-warp decode kernel, returning both O and LSE.
Returns:
O: (batch, n_h, 1, hd) BF16 — normalized attention output
LSE: (batch, n_h, 1) FP32 — log-sum-exp for multi-segment merge
"""
if attn_sink is None:
attn_sink = torch.zeros(q.shape[0], q.shape[1],
dtype=torch.float32, device=q.device)
ext = _get_ext()
o, lse = ext.fmha_multihead_decode(
return fmha_multihead_decode_raw(
q, k, v, scale, n_comp, swa_len, is_causal, attn_sink
)
return o, lse
# ---------------------------------------------------------------------------
# Shape check: is the 6-warp fast path applicable?
# ---------------------------------------------------------------------------
def can_use_6warp_decode(
T: int,
N: int,
hd: int,
n_segments: int,
) -> bool:
"""Check if the 6-warp decode fast path is applicable.
Fast path requires:
- T == 1 (decode)
- n_segments == 1 (single KV tile, N <= 128)
- hd in {64, 128, 256}
"""
def can_use_6warp_decode(T: int, N: int, hd: int, n_segments: int) -> bool:
return T == 1 and n_segments == 1 and hd in (64, 128, 256)

View File

@@ -86,7 +86,26 @@ def test_fast_path_matches_reference():
v = torch.randn(n_kv, N, hd, dtype=torch.bfloat16, device='cuda')
try:
o_fast = dsv4_attention(q, k, v, scale=scale)
from dsv4.kernels.attention.fmha_multihead_op import fmha_multihead_decode_raw
# Prepare tensors in the shape the kernel expects:
# Q: (1, n_q, 1, hd) BF16
# K: (1, n_kv, N, hd) BF16
# V: (1, n_kv, hd, N) BF16 (transposed!)
if n_kv == 1:
q_4d = q.unsqueeze(0).contiguous()
k_4d = k.unsqueeze(0).unsqueeze(0).contiguous()
v_4d = v.unsqueeze(0).unsqueeze(0).transpose(-1, -2).contiguous()
else:
q_4d = q.unsqueeze(0).contiguous()
k_4d = k.unsqueeze(0).contiguous()
v_4d = v.unsqueeze(0).transpose(-1, -2).contiguous()
sb = torch.zeros(1, n_q, dtype=torch.float32, device='cuda')
o_4d, lse_4d = fmha_multihead_decode_raw(
q_4d, k_4d, v_4d, scale, 0, 0, False, sb
)
o_fast = o_4d.squeeze(0) # (n_q, 1, hd)
o_ref = reference_attention(q, k, v, scale)
cos = cosine_sim(o_ref, o_fast).item()
status = "PASS" if cos >= 0.999998 else "FAIL"
@@ -94,7 +113,9 @@ def test_fast_path_matches_reference():
all_pass = False
print(f" {status} {desc}: cos={cos:.6f}")
except Exception as e:
import traceback
print(f" FAIL {desc}: {e}")
traceback.print_exc()
all_pass = False
return all_pass