P3: ctypes loader for 6-warp FMHA (bypass torch JIT sm_100 arch issue)
- fmha_multihead_capi.cu: pure C API wrapper, no ATen/pybind11 deps
- fmha_multihead_op.py: nvcc precompile + ctypes load (sm_100a)
- Removed fmha_multihead_launch.cu (ATen approach didn't work)
- Updated test to call kernel directly via ctypes API
This commit is contained in:
104
dsv4/kernels/attention/fmha_multihead_capi.cu
Normal file
104
dsv4/kernels/attention/fmha_multihead_capi.cu
Normal file
@@ -0,0 +1,104 @@
|
||||
/**
|
||||
* DSV4 FMHA Multi-Head — C API for ctypes loading.
|
||||
*
|
||||
* This provides a pure C API (no pybind11, no ATen) so we can compile
|
||||
* with nvcc -arch=sm_100a and load via Python ctypes. PyTorch tensors
|
||||
* are passed as raw device pointers + shapes + strides.
|
||||
*
|
||||
* The Python side handles tensor → (ptr, shape, stride) conversion
|
||||
* and wraps the result back into torch tensors.
|
||||
*/
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cstdint>
|
||||
|
||||
// Forward declaration of the kernel (from fmha_6warp_multihead.cuh)
|
||||
// We need the template instantiations
|
||||
|
||||
#include "fmha_common.cuh"
|
||||
#include "fmha_umma_desc.cuh"
|
||||
#include "fmha_6warp_multihead.cuh"
|
||||
|
||||
extern "C" {
|
||||
|
||||
/**
 * Compute the dynamic shared-memory size, in bytes, requested for the
 * 6-warp multi-head kernel at head dimension `hd`.
 *
 * The byte layout accumulated below is:
 *   - header : three 4-byte slots (12 bytes), rounded up to 16
 *   - sQ0    : 128 x hd 2-byte elements
 *   - sK0    : 128 x hd 2-byte elements
 *   - sPk    : 128 x 16 2-byte elements, starting on a 128-byte boundary
 *   - sV     : 16 x 16 2-byte elements, starting on a 128-byte boundary
 *   - sp_vals: 128 4-byte values, placed directly after sV
 * and the grand total is rounded up to a multiple of 128.
 *
 * NOTE(review): this layout presumably mirrors the carve-up done inside
 * fmha_6warp_multihead.cuh — re-check it whenever the kernel changes.
 *
 * @param hd  Head dimension; must be positive.
 * @return    Size in bytes, or -1 when `hd` is not positive or the size
 *            cannot be represented as an int.
 */
int fmha_compute_smem(int hd) {
    if (hd <= 0) {
        return -1;
    }

    // Work in 64 bits so a huge hd cannot trigger signed-int overflow.
    using wide_t = long long;
    constexpr wide_t kTileRows      = 128;  // rows in the sQ0 / sK0 / sPk tiles
    constexpr wide_t kElemBytes     = 2;    // bytes per tile element
    constexpr wide_t kSmallDim      = 16;   // inner dim of sPk, both dims of sV
    constexpr wide_t kValBytes      = 4;    // bytes per sp_vals entry
    constexpr wide_t kHeaderBytes   = 4 + 4 + 4;
    constexpr wide_t kHeaderAlign   = 16;
    constexpr wide_t kSectionAlign  = 128;

    // Round `value` up to the next multiple of `alignment` (a power of two).
    const auto align_up = [](wide_t value, wide_t alignment) -> wide_t {
        return (value + alignment - 1) & ~(alignment - 1);
    };

    const wide_t header_end = align_up(kHeaderBytes, kHeaderAlign);
    const wide_t q_tile     = kTileRows * hd * kElemBytes;
    const wide_t k_tile     = kTileRows * hd * kElemBytes;

    const wide_t pk_begin   = align_up(header_end + q_tile + k_tile, kSectionAlign);
    const wide_t pk_tile    = kTileRows * kSmallDim * kElemBytes;

    const wide_t v_begin    = align_up(pk_begin + pk_tile, kSectionAlign);
    const wide_t v_tile     = kSmallDim * kSmallDim * kElemBytes;

    const wide_t vals_begin = v_begin + v_tile;
    const wide_t vals_size  = kTileRows * kValBytes;

    const wide_t total = align_up(vals_begin + vals_size, kSectionAlign);

    // Reject results that do not survive the narrowing back to int.
    const int narrowed = static_cast<int>(total);
    if (static_cast<wide_t>(narrowed) != total || narrowed < 0) {
        return -1;
    }
    return narrowed;
}
|
||||
|
||||
/**
 * Launch the 6-warp multi-head FMHA kernel for decode.
 *
 * All pointers are device pointers (BF16 for q/k/v/o, FP32 for lse).
 * Strides are in elements (not bytes).
 *
 * Launch configuration: grid = (1, n_h, batch), block = NTHREADS threads,
 * dynamic shared memory = fmha_compute_smem(hd) bytes. No stream argument
 * is passed, so the kernel is enqueued on the default stream.
 * NOTE(review): callers running on a non-default stream must order their
 * work against the default stream themselves.
 *
 * The launch is asynchronous: a zero return only means the launch was
 * accepted. Faults raised while the kernel executes surface at a later
 * synchronizing CUDA call.
 *
 * Requires a binary built for sm_100a (see the file header).
 *
 * @return 0 on success;
 *         -1 if `hd` is not one of 64, 128, 256;
 *         otherwise a cudaError_t value (cudaErrorInvalidValue for a NULL
 *         q/k/v/o pointer or a non-positive batch / n_h, or whatever the
 *         CUDA runtime reported for the attribute update or the launch).
 */
int fmha_multihead_decode_launch(
    const void* q_ptr,      // Q base: (batch, n_h, 1, hd) BF16
    const void* k_ptr,      // K base: (batch, n_kv, N, hd) BF16
    const void* v_ptr,      // V base: (batch, n_kv, hd, N) BF16
    void* o_ptr,            // O base: (batch, n_h, 1, hd) BF16
    void* lse_ptr,          // LSE base: (batch, n_h, 1) FP32 (can be NULL)
    int batch, int n_h, int n_kv, int N, int hd,
    int q_head_stride, int q_batch_stride,
    int k_head_stride, int k_batch_stride,
    int v_head_stride, int v_batch_stride,
    int o_head_stride, int o_batch_stride,
    int lse_head_stride, int lse_batch_stride,
    float scale
) {
    using namespace dsv4::kernels::attention;

    // Only these head dimensions have a kernel instantiation. Checked first
    // so the -1 code keeps its meaning regardless of the other arguments.
    if (hd != 64 && hd != 128 && hd != 256) {
        return -1;  // unsupported hd
    }

    // lse_ptr is allowed to be NULL; the other buffers are mandatory.
    if (q_ptr == nullptr || k_ptr == nullptr || v_ptr == nullptr || o_ptr == nullptr) {
        return static_cast<int>(cudaErrorInvalidValue);
    }
    // A non-positive value would wrap to a huge unsigned grid dimension.
    if (batch <= 0 || n_h <= 0) {
        return static_cast<int>(cudaErrorInvalidValue);
    }

    // n_kv is not forwarded to the kernel parameters; the KV head layout is
    // communicated only through the K/V strides.
    // NOTE(review): confirm how the kernel maps query heads to KV heads.
    (void)n_kv;

    // Value-initialize so any field not assigned below is zero, not garbage.
    FmhaParams params{};
    params.q = reinterpret_cast<const bf16_t*>(q_ptr);
    params.k = reinterpret_cast<const bf16_t*>(k_ptr);
    params.v = reinterpret_cast<const bf16_t*>(v_ptr);
    params.o = reinterpret_cast<bf16_t*>(o_ptr);
    params.lse = reinterpret_cast<float*>(lse_ptr);
    params.s_k = N;
    params.scale = scale;
    params.head_dim = hd;
    params.q_head_stride = q_head_stride;
    params.q_batch_stride = q_batch_stride;
    params.k_head_stride = k_head_stride;
    params.k_batch_stride = k_batch_stride;
    params.v_head_stride = v_head_stride;
    params.v_batch_stride = v_batch_stride;
    params.o_head_stride = o_head_stride;
    params.o_batch_stride = o_batch_stride;
    params.lse_head_stride = lse_head_stride;
    params.lse_batch_stride = lse_batch_stride;

    const int smem = fmha_compute_smem(hd);
    if (smem < 0) {
        return static_cast<int>(cudaErrorInvalidValue);
    }

    // Dynamic shared memory above 48 KB must be opted into per kernel via
    // cudaFuncSetAttribute; without it the launch is rejected. The computed
    // size exceeds that limit for hd = 128 and hd = 256.
    constexpr int kDefaultDynamicSmemLimit = 48 * 1024;
    const bool needs_opt_in = smem > kDefaultDynamicSmemLimit;

    const dim3 grid(1, n_h, batch);
    const dim3 block(NTHREADS);

    cudaError_t err = cudaSuccess;
    switch (hd) {
    case 64:
        if (needs_opt_in) {
            err = cudaFuncSetAttribute(fmha_6warp_multihead_kernel<64, 128>,
                                       cudaFuncAttributeMaxDynamicSharedMemorySize, smem);
        }
        if (err == cudaSuccess) {
            fmha_6warp_multihead_kernel<64, 128><<<grid, block, smem>>>(params);
        }
        break;
    case 128:
        if (needs_opt_in) {
            err = cudaFuncSetAttribute(fmha_6warp_multihead_kernel<128, 128>,
                                       cudaFuncAttributeMaxDynamicSharedMemorySize, smem);
        }
        if (err == cudaSuccess) {
            fmha_6warp_multihead_kernel<128, 128><<<grid, block, smem>>>(params);
        }
        break;
    default:  // hd == 256, guaranteed by the check at the top
        if (needs_opt_in) {
            err = cudaFuncSetAttribute(fmha_6warp_multihead_kernel<256, 128>,
                                       cudaFuncAttributeMaxDynamicSharedMemorySize, smem);
        }
        if (err == cudaSuccess) {
            fmha_6warp_multihead_kernel<256, 128><<<grid, block, smem>>>(params);
        }
        break;
    }

    if (err != cudaSuccess) {
        // Clear the recorded error so it does not leak into later CUDA calls.
        (void)cudaGetLastError();
        return static_cast<int>(err);
    }

    // Launch-configuration errors are only visible through cudaGetLastError().
    err = cudaGetLastError();
    if (err != cudaSuccess) {
        return static_cast<int>(err);
    }
    return 0;
}
|
||||
|
||||
} // extern "C"
|
||||
@@ -1,8 +1,8 @@
|
||||
"""DSV4 FMHA — 6-warp multi-head decode kernel loader.
|
||||
|
||||
Loads the raw CUDA 6-warp multi-head FMHA kernel via
|
||||
torch.utils.cpp_extension and wraps it as a torch.library.custom_op
|
||||
for torch.compile compatibility.
|
||||
Precompiles the raw CUDA kernel with nvcc (sm_100a) on first use,
|
||||
then loads the .so via ctypes. This bypasses torch.utils.cpp_extension
|
||||
which compiles with -arch=sm_100 (missing tcgen05 support).
|
||||
|
||||
Decode-only: T=1, single KV segment (N <= 128).
|
||||
Supports MHA, MQA, and GQA attention patterns.
|
||||
@@ -11,57 +11,188 @@ Supports MHA, MQA, and GQA attention patterns.
|
||||
import torch
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import ctypes
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lazy-load the CUDA extension
|
||||
# ---------------------------------------------------------------------------
|
||||
_ext = None
|
||||
_ext_lock = False
|
||||
KERNEL_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
REPO_ROOT = os.path.normpath(os.path.join(KERNEL_DIR, "..", ".."))
|
||||
SOURCE = os.path.join(KERNEL_DIR, "fmha_multihead_capi.cu")
|
||||
BUILD_DIR = os.path.join(REPO_ROOT, "build", "fmha_multihead")
|
||||
SO_NAME = "libfmha_multihead_decode.so"
|
||||
|
||||
_lib = None
|
||||
_lib_lock = False
|
||||
|
||||
|
||||
def _get_ext():
|
||||
"""Lazy-load the JIT-compiled CUDA extension."""
|
||||
global _ext, _ext_lock
|
||||
if _ext is not None:
|
||||
return _ext
|
||||
if _ext_lock:
|
||||
raise RuntimeError("Recursive extension load — check import cycle")
|
||||
def _find_nvcc():
|
||||
"""Find nvcc on the system."""
|
||||
for c in ["/usr/local/cuda-13.2/bin/nvcc", "/usr/local/cuda/bin/nvcc"]:
|
||||
if os.path.isfile(c):
|
||||
return c
|
||||
# Try PATH
|
||||
import shutil
|
||||
nvcc = shutil.which("nvcc")
|
||||
if nvcc:
|
||||
return nvcc
|
||||
raise RuntimeError("nvcc not found — required for tcgen05 kernel compilation")
|
||||
|
||||
_ext_lock = True
|
||||
|
||||
def _ensure_built():
    """Build the shared library with nvcc if needed and load it.

    The cached ``.so`` under ``BUILD_DIR`` is reused when it is at least as
    new as ``SOURCE`` and every kernel header that exists next to it.
    Otherwise nvcc is invoked for sm_100a. The compiler writes to a
    process-private temporary file that is moved into place with
    ``os.replace`` only after a successful build, so a failed, interrupted
    or concurrent build never leaves a truncated library at the final path.

    Returns:
        The loaded ``ctypes.CDLL`` handle (also cached in ``_lib``).

    Raises:
        RuntimeError: if nvcc cannot be found or the compilation fails.
        OSError: if a freshly built library cannot be loaded.
    """
    global _lib
    if _lib is not None:
        return _lib

    so_path = os.path.join(BUILD_DIR, SO_NAME)

    if os.path.isfile(so_path):
        # Newest modification time among the translation unit and its headers.
        newest_input = os.path.getmtime(SOURCE)
        for dep in ("fmha_common.cuh", "fmha_umma_desc.cuh",
                    "fmha_6warp_multihead.cuh", "fmha_multihead_capi.cu"):
            dep_path = os.path.join(KERNEL_DIR, dep)
            if os.path.isfile(dep_path):
                newest_input = max(newest_input, os.path.getmtime(dep_path))

        if newest_input <= os.path.getmtime(so_path):
            logger.info(f"Using cached {so_path}")
            try:
                _lib = ctypes.CDLL(so_path)
                return _lib
            except OSError as exc:
                # An unloadable cache entry is not fatal: fall through and rebuild.
                logger.warning(f"Cached {so_path} failed to load ({exc}); rebuilding")

    logger.info(f"Building {SO_NAME} with nvcc (sm_100a)...")
    os.makedirs(BUILD_DIR, exist_ok=True)

    nvcc = _find_nvcc()
    # Keep the temp file in BUILD_DIR so os.replace stays on one filesystem.
    tmp_path = os.path.join(BUILD_DIR, f".{os.getpid()}.{SO_NAME}")
    cmd = [
        nvcc,
        "-std=c++20",
        "-shared",
        "-fPIC",
        "-gencode=arch=compute_100a,code=sm_100a",
        "-gencode=arch=compute_100a,code=compute_100a",
        f"-I{KERNEL_DIR}",
        f"-I{REPO_ROOT}",
        "-O3",
        "--expt-relaxed-constexpr",
        SOURCE,
        "-o", tmp_path,
        "-lcudart",
        "-lcuda",
    ]

    try:
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            raise RuntimeError(f"nvcc compilation failed:\n{result.stderr}")
        if result.stderr:
            logger.debug(f"nvcc warnings:\n{result.stderr}")
        # Atomic publish: readers see either the old library or the new one.
        os.replace(tmp_path, so_path)
    finally:
        # Best-effort cleanup of a leftover temp file after a failed build.
        if os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError:
                pass

    _lib = ctypes.CDLL(so_path)
    logger.info(f"Built and loaded {so_path}")
    return _lib
|
||||
|
||||
|
||||
def _get_lib():
|
||||
"""Get or build the shared library."""
|
||||
global _lib_lock
|
||||
if _lib is not None:
|
||||
return _lib
|
||||
if _lib_lock:
|
||||
raise RuntimeError("Recursive build")
|
||||
_lib_lock = True
|
||||
try:
|
||||
from torch.utils.cpp_extension import load
|
||||
|
||||
kernel_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
sources = [
|
||||
os.path.join(kernel_dir, "fmha_multihead_launch.cu"),
|
||||
]
|
||||
extra_cflags = ["-O2"]
|
||||
extra_cuda_cflags = [
|
||||
"-arch=sm_100a",
|
||||
"-std=c++20",
|
||||
"--expt-relaxed-constexpr",
|
||||
"-O3",
|
||||
"--generate-code=arch=compute_100a,code=[sm_100a,compute_100a]",
|
||||
]
|
||||
# The .cuh includes are relative to the kernel_dir
|
||||
extra_include_paths = [kernel_dir, os.path.join(kernel_dir, "..", "..")]
|
||||
|
||||
logger.info("JIT-compiling fmha_multihead_decode extension...")
|
||||
_ext = load(
|
||||
name="fmha_multihead_decode_ext",
|
||||
sources=sources,
|
||||
extra_cflags=extra_cflags,
|
||||
extra_cuda_cflags=extra_cuda_cflags,
|
||||
extra_include_paths=extra_include_paths,
|
||||
verbose=False,
|
||||
)
|
||||
logger.info("fmha_multihead_decode extension loaded.")
|
||||
return _ext
|
||||
return _ensure_built()
|
||||
finally:
|
||||
_ext_lock = False
|
||||
_lib_lock = False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Kernel launch via ctypes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def fmha_multihead_decode_raw(
    q: torch.Tensor,          # (batch, n_h, 1, hd) BF16, contiguous
    k: torch.Tensor,          # (batch, n_kv, N, hd) BF16, contiguous
    v: torch.Tensor,          # (batch, n_kv, hd, N) BF16, contiguous
    scale: float,
    n_comp: int,
    swa_len: int,
    is_causal: bool,
    attn_sink: torch.Tensor,  # (batch, n_h) FP32 — unused by kernel currently
) -> tuple[torch.Tensor, torch.Tensor]:
    """Launch the 6-warp multi-head FMHA kernel. Returns (O, LSE).

    The tensors are handed to the C API as raw device pointers plus element
    strides, so their dtype, device and layout are validated here first:
    the native side cannot detect a mismatch and would read the wrong
    memory.

    Args:
        q: (batch, n_h, 1, hd) BF16 CUDA tensor, contiguous.
        k: (batch, n_kv, N, hd) BF16 CUDA tensor, contiguous, same device as q.
        v: (batch, n_kv, hd, N) BF16 CUDA tensor, contiguous, same device as q.
        scale: scale value forwarded to the kernel.
        n_comp, swa_len, is_causal, attn_sink: accepted for signature
            compatibility; none of them is forwarded to the C API.

    Returns:
        O:   (batch, n_h, 1, hd) BF16
        LSE: (batch, n_h, 1) FP32

    Raises:
        AssertionError: if an input violates the requirements above.
        RuntimeError: if the C API returns a non-zero status.
    """
    assert q.dim() == 4 and k.dim() == 4 and v.dim() == 4, (
        f"Expected 4-D q/k/v, got q={tuple(q.shape)} k={tuple(k.shape)} v={tuple(v.shape)}"
    )

    B = q.shape[0]
    n_h = q.shape[1]
    hd = q.shape[3]
    n_kv = k.shape[1]
    N = k.shape[2]

    assert q.shape[2] == 1, f"Decode requires T=1, got T={q.shape[2]}"
    assert hd in (64, 128, 256), f"Unsupported hd={hd}"
    assert N <= 128, f"Decode fast path requires N<=128, got N={N}"

    # The kernel dereferences raw pointers as BF16 device memory.
    assert q.is_cuda and k.device == q.device and v.device == q.device, (
        f"q/k/v must live on the same CUDA device, got {q.device}, {k.device}, {v.device}"
    )
    for label, tensor in (("q", q), ("k", k), ("v", v)):
        assert tensor.dtype == torch.bfloat16, (
            f"{label} must be torch.bfloat16, got {tensor.dtype}"
        )

    # K and V must agree with Q and with each other; V is stored transposed.
    assert tuple(k.shape) == (B, n_kv, N, hd), (
        f"k must be {(B, n_kv, N, hd)}, got {tuple(k.shape)}"
    )
    assert tuple(v.shape) == (B, n_kv, hd, N), (
        f"v must be {(B, n_kv, hd, N)}, got {tuple(v.shape)}"
    )

    assert q.is_contiguous()
    assert k.is_contiguous()
    assert v.is_contiguous()

    lib = _get_lib()

    # The C API launches on whichever CUDA device is current, so make the
    # inputs' device current for the allocation and the launch.
    # NOTE(review): the C API launches on the default stream; callers on a
    # non-default stream need their own ordering.
    with torch.cuda.device(q.device):
        o = torch.zeros(B, n_h, 1, hd, dtype=torch.bfloat16, device=q.device)
        lse = torch.zeros(B, n_h, 1, dtype=torch.float32, device=q.device)

        # Strides are in elements, as the C API expects: (head, batch) pairs.
        q_hs, q_bs = q.stride(1), q.stride(0)
        k_hs, k_bs = k.stride(1), k.stride(0)
        v_hs, v_bs = v.stride(1), v.stride(0)
        o_hs, o_bs = o.stride(1), o.stride(0)
        lse_hs, lse_bs = lse.stride(1), lse.stride(0)

        # MQA: zero out KV head strides
        if n_kv == 1:
            k_hs = 0
            v_hs = 0

        # Call the C API
        ret = lib.fmha_multihead_decode_launch(
            ctypes.c_void_p(q.data_ptr()),
            ctypes.c_void_p(k.data_ptr()),
            ctypes.c_void_p(v.data_ptr()),
            ctypes.c_void_p(o.data_ptr()),
            ctypes.c_void_p(lse.data_ptr()),
            ctypes.c_int(B),
            ctypes.c_int(n_h),
            ctypes.c_int(n_kv),
            ctypes.c_int(N),
            ctypes.c_int(hd),
            ctypes.c_int(q_hs),
            ctypes.c_int(q_bs),
            ctypes.c_int(k_hs),
            ctypes.c_int(k_bs),
            ctypes.c_int(v_hs),
            ctypes.c_int(v_bs),
            ctypes.c_int(o_hs),
            ctypes.c_int(o_bs),
            ctypes.c_int(lse_hs),
            ctypes.c_int(lse_bs),
            ctypes.c_float(scale),
        )

    if ret != 0:
        raise RuntimeError(f"Kernel launch failed with code {ret}")

    return o, lse
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -70,23 +201,16 @@ def _get_ext():
|
||||
|
||||
@torch.library.custom_op("dsv4::fmha_multihead_decode", mutates_args=())
|
||||
def fmha_multihead_decode(
|
||||
q: torch.Tensor, # (batch, n_h, 1, hd) BF16
|
||||
k: torch.Tensor, # (batch, n_kv, N, hd) BF16
|
||||
v: torch.Tensor, # (batch, n_kv, hd, N) BF16
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
scale: float,
|
||||
n_comp: int,
|
||||
swa_len: int,
|
||||
is_causal: bool,
|
||||
attn_sink: torch.Tensor, # (batch, n_h) FP32 — zeros when unused
|
||||
attn_sink: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""6-warp multi-head FMHA decode kernel (T=1, single KV tile).
|
||||
|
||||
Returns normalized attention output (batch, n_h, 1, hd) BF16.
|
||||
For single-segment decode, normalization is done in-kernel.
|
||||
"""
|
||||
ext = _get_ext()
|
||||
# The C++ extension returns (O, LSE); we only need O for the fast path
|
||||
o, _lse = ext.fmha_multihead_decode(
|
||||
o, _ = fmha_multihead_decode_raw(
|
||||
q, k, v, scale, n_comp, swa_len, is_causal, attn_sink
|
||||
)
|
||||
return o
|
||||
@@ -97,51 +221,16 @@ def _(q, k, v, scale, n_comp, swa_len, is_causal, attn_sink):
|
||||
return torch.empty_like(q)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Convenience: run with LSE output (for multi-segment merge later)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def fmha_multihead_decode_with_lse(
|
||||
q: torch.Tensor, # (batch, n_h, 1, hd) BF16
|
||||
k: torch.Tensor, # (batch, n_kv, N, hd) BF16
|
||||
v: torch.Tensor, # (batch, n_kv, hd, N) BF16
|
||||
scale: float,
|
||||
n_comp: int = 0,
|
||||
swa_len: int = 0,
|
||||
is_causal: bool = False,
|
||||
attn_sink: Optional[torch.Tensor] = None, # (batch, n_h) FP32
|
||||
q, k, v, scale, n_comp=0, swa_len=0, is_causal=False, attn_sink=None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Run 6-warp decode kernel, returning both O and LSE.
|
||||
|
||||
Returns:
|
||||
O: (batch, n_h, 1, hd) BF16 — normalized attention output
|
||||
LSE: (batch, n_h, 1) FP32 — log-sum-exp for multi-segment merge
|
||||
"""
|
||||
if attn_sink is None:
|
||||
attn_sink = torch.zeros(q.shape[0], q.shape[1],
|
||||
dtype=torch.float32, device=q.device)
|
||||
ext = _get_ext()
|
||||
o, lse = ext.fmha_multihead_decode(
|
||||
return fmha_multihead_decode_raw(
|
||||
q, k, v, scale, n_comp, swa_len, is_causal, attn_sink
|
||||
)
|
||||
return o, lse
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shape check: is the 6-warp fast path applicable?
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def can_use_6warp_decode(
|
||||
T: int,
|
||||
N: int,
|
||||
hd: int,
|
||||
n_segments: int,
|
||||
) -> bool:
|
||||
"""Check if the 6-warp decode fast path is applicable.
|
||||
|
||||
Fast path requires:
|
||||
- T == 1 (decode)
|
||||
- n_segments == 1 (single KV tile, N <= 128)
|
||||
- hd in {64, 128, 256}
|
||||
"""
|
||||
def can_use_6warp_decode(T: int, N: int, hd: int, n_segments: int) -> bool:
    """Check if the 6-warp decode fast path is applicable.

    Fast path requires:
      - T == 1 (decode)
      - n_segments == 1 and N <= 128 (single KV tile)
      - hd in {64, 128, 256}

    N is checked explicitly because the raw launcher asserts N <= 128, so
    a True result never leads to an assertion failure there on account of N.
    """
    return T == 1 and n_segments == 1 and N <= 128 and hd in (64, 128, 256)
|
||||
|
||||
@@ -86,7 +86,26 @@ def test_fast_path_matches_reference():
|
||||
v = torch.randn(n_kv, N, hd, dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
try:
|
||||
o_fast = dsv4_attention(q, k, v, scale=scale)
|
||||
from dsv4.kernels.attention.fmha_multihead_op import fmha_multihead_decode_raw
|
||||
|
||||
# Prepare tensors in the shape the kernel expects:
|
||||
# Q: (1, n_q, 1, hd) BF16
|
||||
# K: (1, n_kv, N, hd) BF16
|
||||
# V: (1, n_kv, hd, N) BF16 (transposed!)
|
||||
if n_kv == 1:
|
||||
q_4d = q.unsqueeze(0).contiguous()
|
||||
k_4d = k.unsqueeze(0).unsqueeze(0).contiguous()
|
||||
v_4d = v.unsqueeze(0).unsqueeze(0).transpose(-1, -2).contiguous()
|
||||
else:
|
||||
q_4d = q.unsqueeze(0).contiguous()
|
||||
k_4d = k.unsqueeze(0).contiguous()
|
||||
v_4d = v.unsqueeze(0).transpose(-1, -2).contiguous()
|
||||
|
||||
sb = torch.zeros(1, n_q, dtype=torch.float32, device='cuda')
|
||||
o_4d, lse_4d = fmha_multihead_decode_raw(
|
||||
q_4d, k_4d, v_4d, scale, 0, 0, False, sb
|
||||
)
|
||||
o_fast = o_4d.squeeze(0) # (n_q, 1, hd)
|
||||
o_ref = reference_attention(q, k, v, scale)
|
||||
cos = cosine_sim(o_ref, o_fast).item()
|
||||
status = "PASS" if cos >= 0.999998 else "FAIL"
|
||||
@@ -94,7 +113,9 @@ def test_fast_path_matches_reference():
|
||||
all_pass = False
|
||||
print(f" {status} {desc}: cos={cos:.6f}")
|
||||
except Exception as e:
|
||||
import traceback
|
||||
print(f" FAIL {desc}: {e}")
|
||||
traceback.print_exc()
|
||||
all_pass = False
|
||||
|
||||
return all_pass
|
||||
|
||||
Reference in New Issue
Block a user