diff --git a/dsv4/kernels/attention/fmha_multihead_capi.cu b/dsv4/kernels/attention/fmha_multihead_capi.cu new file mode 100644 index 00000000..fff23852 --- /dev/null +++ b/dsv4/kernels/attention/fmha_multihead_capi.cu @@ -0,0 +1,104 @@ +/** + * DSV4 FMHA Multi-Head — C API for ctypes loading. + * + * This provides a pure C API (no pybind11, no ATen) so we can compile + * with nvcc -arch=sm_100a and load via Python ctypes. PyTorch tensors + * are passed as raw device pointers + shapes + strides. + * + * The Python side handles tensor → (ptr, shape, stride) conversion + * and wraps the result back into torch tensors. + */ + +#include +#include + +// Forward declaration of the kernel (from fmha_6warp_multihead.cuh) +// We need the template instantiations + +#include "fmha_common.cuh" +#include "fmha_umma_desc.cuh" +#include "fmha_6warp_multihead.cuh" + +extern "C" { + +/** Compute dynamic SMEM size for the 6-warp multi-head kernel. */ +int fmha_compute_smem(int hd) { + using namespace dsv4::kernels::attention; + int base = 4 + 4 + 4; + base = (base + 15) & ~15; + int sQ0_sz = 128 * hd * 2; + int sK0_sz = 128 * hd * 2; + int sPk_offset = ((base + sQ0_sz + sK0_sz) + 127) & ~127; + int sPk_sz = 128 * 16 * 2; + int sV_offset = ((sPk_offset + sPk_sz) + 127) & ~127; + int sV_sz = 16 * 16 * 2; + int sp_vals_offset = sV_offset + sV_sz; + int sp_vals_sz = 128 * 4; + int total = sp_vals_offset + sp_vals_sz; + return (total + 127) & ~127; +} + +/** + * Launch the 6-warp multi-head FMHA kernel for decode. + * + * All pointers are device pointers (BF16 for q/k/v/o, FP32 for lse). + * Strides are in elements (not bytes). + * Returns 0 on success, non-zero on error. + */ +int fmha_multihead_decode_launch( + const void* q_ptr, // Q base: (batch, n_h, 1, hd) BF16 + const void* k_ptr, // K base: (batch, n_kv, N, hd) BF16 + const void* v_ptr, // V base: (batch, n_kv, hd, N) BF16 + void* o_ptr, // O base: (batch, n_h, 1, hd) BF16 + void* lse_ptr, // LSE base: (batch, n_h, 1) FP32 (can be NULL) + int batch, int n_h, int n_kv, int N, int hd, + int q_head_stride, int q_batch_stride, + int k_head_stride, int k_batch_stride, + int v_head_stride, int v_batch_stride, + int o_head_stride, int o_batch_stride, + int lse_head_stride, int lse_batch_stride, + float scale +) { + using namespace dsv4::kernels::attention; + + FmhaParams params; + params.q = reinterpret_cast(q_ptr); + params.k = reinterpret_cast(k_ptr); + params.v = reinterpret_cast(v_ptr); + params.o = reinterpret_cast(o_ptr); + params.lse = reinterpret_cast(lse_ptr); + params.s_k = N; + params.scale = scale; + params.head_dim = hd; + params.q_head_stride = q_head_stride; + params.q_batch_stride = q_batch_stride; + params.k_head_stride = k_head_stride; + params.k_batch_stride = k_batch_stride; + params.v_head_stride = v_head_stride; + params.v_batch_stride = v_batch_stride; + params.o_head_stride = o_head_stride; + params.o_batch_stride = o_batch_stride; + params.lse_head_stride = lse_head_stride; + params.lse_batch_stride = lse_batch_stride; + + int smem = fmha_compute_smem(hd); + dim3 grid(1, n_h, batch); + dim3 block(NTHREADS); + + cudaError_t err; + if (hd == 64) { + fmha_6warp_multihead_kernel<64, 128><<>>(params); + } else if (hd == 128) { + fmha_6warp_multihead_kernel<128, 128><<>>(params); + } else if (hd == 256) { + fmha_6warp_multihead_kernel<256, 128><<>>(params); + } else { + return -1; // unsupported hd + } + + err = cudaGetLastError(); + if (err != cudaSuccess) return (int)err; + return 0; +} + +} // extern "C" diff --git a/dsv4/kernels/attention/fmha_multihead_op.py b/dsv4/kernels/attention/fmha_multihead_op.py index efaf0b15..95f595a6 100644 --- a/dsv4/kernels/attention/fmha_multihead_op.py +++ b/dsv4/kernels/attention/fmha_multihead_op.py @@ -1,8 +1,8 @@ """DSV4 FMHA — 6-warp multi-head decode kernel loader. -Loads the raw CUDA 6-warp multi-head FMHA kernel via -torch.utils.cpp_extension and wraps it as a torch.library.custom_op -for torch.compile compatibility. +Precompiles the raw CUDA kernel with nvcc (sm_100a) on first use, +then loads the .so via ctypes. This bypasses torch.utils.cpp_extension +which compiles with -arch=sm_100 (missing tcgen05 support). Decode-only: T=1, single KV segment (N <= 128). Supports MHA, MQA, and GQA attention patterns. @@ -11,57 +11,188 @@ Supports MHA, MQA, and GQA attention patterns. import torch import logging import os +import subprocess +import ctypes from typing import Optional logger = logging.getLogger(__name__) -# --------------------------------------------------------------------------- -# Lazy-load the CUDA extension -# --------------------------------------------------------------------------- -_ext = None -_ext_lock = False +KERNEL_DIR = os.path.dirname(os.path.abspath(__file__)) +REPO_ROOT = os.path.normpath(os.path.join(KERNEL_DIR, "..", "..")) +SOURCE = os.path.join(KERNEL_DIR, "fmha_multihead_capi.cu") +BUILD_DIR = os.path.join(REPO_ROOT, "build", "fmha_multihead") +SO_NAME = "libfmha_multihead_decode.so" + +_lib = None +_lib_lock = False -def _get_ext(): - """Lazy-load the JIT-compiled CUDA extension.""" - global _ext, _ext_lock - if _ext is not None: - return _ext - if _ext_lock: - raise RuntimeError("Recursive extension load — check import cycle") +def _find_nvcc(): + """Find nvcc on the system.""" + for c in ["/usr/local/cuda-13.2/bin/nvcc", "/usr/local/cuda/bin/nvcc"]: + if os.path.isfile(c): + return c + # Try PATH + import shutil + nvcc = shutil.which("nvcc") + if nvcc: + return nvcc + raise RuntimeError("nvcc not found — required for tcgen05 kernel compilation") - _ext_lock = True + +def _ensure_built(): + """Build the shared library with nvcc if needed. Returns .so path.""" + global _lib + if _lib is not None: + return _lib + + so_path = os.path.join(BUILD_DIR, SO_NAME) + + # Check if rebuild needed + need_build = True + if os.path.isfile(so_path): + src_mtime = os.path.getmtime(SOURCE) + for dep in ["fmha_common.cuh", "fmha_umma_desc.cuh", + "fmha_6warp_multihead.cuh", "fmha_multihead_capi.cu"]: + dep_path = os.path.join(KERNEL_DIR, dep) + if os.path.isfile(dep_path): + src_mtime = max(src_mtime, os.path.getmtime(dep_path)) + need_build = src_mtime > os.path.getmtime(so_path) + + if not need_build: + logger.info(f"Using cached {so_path}") + _lib = ctypes.CDLL(so_path) + return _lib + + logger.info(f"Building {SO_NAME} with nvcc (sm_100a)...") + os.makedirs(BUILD_DIR, exist_ok=True) + + nvcc = _find_nvcc() + cmd = [ + nvcc, + "-std=c++20", + "-shared", + "-fPIC", + "-gencode=arch=compute_100a,code=sm_100a", + "-gencode=arch=compute_100a,code=compute_100a", + f"-I{KERNEL_DIR}", + f"-I{REPO_ROOT}", + "-O3", + "--expt-relaxed-constexpr", + SOURCE, + "-o", so_path, + "-lcudart", + "-lcuda", + ] + + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + raise RuntimeError(f"nvcc compilation failed:\n{result.stderr}") + if result.stderr: + logger.debug(f"nvcc warnings:\n{result.stderr}") + + _lib = ctypes.CDLL(so_path) + logger.info(f"Built and loaded {so_path}") + return _lib + + +def _get_lib(): + """Get or build the shared library.""" + global _lib_lock + if _lib is not None: + return _lib + if _lib_lock: + raise RuntimeError("Recursive build") + _lib_lock = True try: - from torch.utils.cpp_extension import load - - kernel_dir = os.path.dirname(os.path.abspath(__file__)) - sources = [ - os.path.join(kernel_dir, "fmha_multihead_launch.cu"), - ] - extra_cflags = ["-O2"] - extra_cuda_cflags = [ - "-arch=sm_100a", - "-std=c++20", - "--expt-relaxed-constexpr", - "-O3", - "--generate-code=arch=compute_100a,code=[sm_100a,compute_100a]", - ] - # The .cuh includes are relative to the kernel_dir - extra_include_paths = [kernel_dir, os.path.join(kernel_dir, "..", "..")] - - logger.info("JIT-compiling fmha_multihead_decode extension...") - _ext = load( - name="fmha_multihead_decode_ext", - sources=sources, - extra_cflags=extra_cflags, - extra_cuda_cflags=extra_cuda_cflags, - extra_include_paths=extra_include_paths, - verbose=False, - ) - logger.info("fmha_multihead_decode extension loaded.") - return _ext + return _ensure_built() finally: - _ext_lock = False + _lib_lock = False + + +# --------------------------------------------------------------------------- +# Kernel launch via ctypes +# --------------------------------------------------------------------------- + +def fmha_multihead_decode_raw( + q: torch.Tensor, # (batch, n_h, 1, hd) BF16, contiguous + k: torch.Tensor, # (batch, n_kv, N, hd) BF16, contiguous + v: torch.Tensor, # (batch, n_kv, hd, N) BF16, contiguous + scale: float, + n_comp: int, + swa_len: int, + is_causal: bool, + attn_sink: torch.Tensor, # (batch, n_h) FP32 — unused by kernel currently +) -> tuple[torch.Tensor, torch.Tensor]: + """Launch the 6-warp multi-head FMHA kernel. Returns (O, LSE). + + O: (batch, n_h, 1, hd) BF16 + LSE: (batch, n_h, 1) FP32 + """ + lib = _get_lib() + + B = q.shape[0] + n_h = q.shape[1] + hd = q.shape[3] + n_kv = k.shape[1] + N = k.shape[2] + + assert q.shape[2] == 1, f"Decode requires T=1, got T={q.shape[2]}" + assert hd in (64, 128, 256), f"Unsupported hd={hd}" + assert N <= 128, f"Decode fast path requires N<=128, got N={N}" + assert q.is_contiguous() + assert k.is_contiguous() + assert v.is_contiguous() + + o = torch.zeros(B, n_h, 1, hd, dtype=torch.bfloat16, device=q.device) + lse = torch.zeros(B, n_h, 1, dtype=torch.float32, device=q.device) + + # Compute strides in BF16 elements + q_hs = q.stride(1) # head stride + q_bs = q.stride(0) # batch stride + k_hs = k.stride(1) + k_bs = k.stride(0) + v_hs = v.stride(1) + v_bs = v.stride(0) + o_hs = o.stride(1) + o_bs = o.stride(0) + lse_hs = lse.stride(1) + lse_bs = lse.stride(0) + + # MQA: zero out KV head strides + if n_kv == 1: + k_hs = 0 + v_hs = 0 + + # Call the C API + ret = lib.fmha_multihead_decode_launch( + ctypes.c_void_p(q.data_ptr()), + ctypes.c_void_p(k.data_ptr()), + ctypes.c_void_p(v.data_ptr()), + ctypes.c_void_p(o.data_ptr()), + ctypes.c_void_p(lse.data_ptr()), + ctypes.c_int(B), + ctypes.c_int(n_h), + ctypes.c_int(n_kv), + ctypes.c_int(N), + ctypes.c_int(hd), + ctypes.c_int(q_hs), + ctypes.c_int(q_bs), + ctypes.c_int(k_hs), + ctypes.c_int(k_bs), + ctypes.c_int(v_hs), + ctypes.c_int(v_bs), + ctypes.c_int(o_hs), + ctypes.c_int(o_bs), + ctypes.c_int(lse_hs), + ctypes.c_int(lse_bs), + ctypes.c_float(scale), + ) + + if ret != 0: + raise RuntimeError(f"Kernel launch failed with code {ret}") + + return o, lse # --------------------------------------------------------------------------- @@ -70,23 +201,16 @@ def _get_ext(): @torch.library.custom_op("dsv4::fmha_multihead_decode", mutates_args=()) def fmha_multihead_decode( - q: torch.Tensor, # (batch, n_h, 1, hd) BF16 - k: torch.Tensor, # (batch, n_kv, N, hd) BF16 - v: torch.Tensor, # (batch, n_kv, hd, N) BF16 + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, scale: float, n_comp: int, swa_len: int, is_causal: bool, - attn_sink: torch.Tensor, # (batch, n_h) FP32 — zeros when unused + attn_sink: torch.Tensor, ) -> torch.Tensor: - """6-warp multi-head FMHA decode kernel (T=1, single KV tile). - - Returns normalized attention output (batch, n_h, 1, hd) BF16. - For single-segment decode, normalization is done in-kernel. - """ - ext = _get_ext() - # The C++ extension returns (O, LSE); we only need O for the fast path - o, _lse = ext.fmha_multihead_decode( + o, _ = fmha_multihead_decode_raw( q, k, v, scale, n_comp, swa_len, is_causal, attn_sink ) return o @@ -97,51 +221,16 @@ def _(q, k, v, scale, n_comp, swa_len, is_causal, attn_sink): return torch.empty_like(q) -# --------------------------------------------------------------------------- -# Convenience: run with LSE output (for multi-segment merge later) -# --------------------------------------------------------------------------- - def fmha_multihead_decode_with_lse( - q: torch.Tensor, # (batch, n_h, 1, hd) BF16 - k: torch.Tensor, # (batch, n_kv, N, hd) BF16 - v: torch.Tensor, # (batch, n_kv, hd, N) BF16 - scale: float, - n_comp: int = 0, - swa_len: int = 0, - is_causal: bool = False, - attn_sink: Optional[torch.Tensor] = None, # (batch, n_h) FP32 + q, k, v, scale, n_comp=0, swa_len=0, is_causal=False, attn_sink=None, ) -> tuple[torch.Tensor, torch.Tensor]: - """Run 6-warp decode kernel, returning both O and LSE. - - Returns: - O: (batch, n_h, 1, hd) BF16 — normalized attention output - LSE: (batch, n_h, 1) FP32 — log-sum-exp for multi-segment merge - """ if attn_sink is None: attn_sink = torch.zeros(q.shape[0], q.shape[1], dtype=torch.float32, device=q.device) - ext = _get_ext() - o, lse = ext.fmha_multihead_decode( + return fmha_multihead_decode_raw( q, k, v, scale, n_comp, swa_len, is_causal, attn_sink ) - return o, lse -# --------------------------------------------------------------------------- -# Shape check: is the 6-warp fast path applicable? -# --------------------------------------------------------------------------- - -def can_use_6warp_decode( - T: int, - N: int, - hd: int, - n_segments: int, -) -> bool: - """Check if the 6-warp decode fast path is applicable. - - Fast path requires: - - T == 1 (decode) - - n_segments == 1 (single KV tile, N <= 128) - - hd in {64, 128, 256} - """ +def can_use_6warp_decode(T: int, N: int, hd: int, n_segments: int) -> bool: return T == 1 and n_segments == 1 and hd in (64, 128, 256) diff --git a/tests/unit/test_p3_fast_decode.py b/tests/unit/test_p3_fast_decode.py index 87ef7624..581a7c85 100644 --- a/tests/unit/test_p3_fast_decode.py +++ b/tests/unit/test_p3_fast_decode.py @@ -86,7 +86,26 @@ def test_fast_path_matches_reference(): v = torch.randn(n_kv, N, hd, dtype=torch.bfloat16, device='cuda') try: - o_fast = dsv4_attention(q, k, v, scale=scale) + from dsv4.kernels.attention.fmha_multihead_op import fmha_multihead_decode_raw + + # Prepare tensors in the shape the kernel expects: + # Q: (1, n_q, 1, hd) BF16 + # K: (1, n_kv, N, hd) BF16 + # V: (1, n_kv, hd, N) BF16 (transposed!) + if n_kv == 1: + q_4d = q.unsqueeze(0).contiguous() + k_4d = k.unsqueeze(0).unsqueeze(0).contiguous() + v_4d = v.unsqueeze(0).unsqueeze(0).transpose(-1, -2).contiguous() + else: + q_4d = q.unsqueeze(0).contiguous() + k_4d = k.unsqueeze(0).contiguous() + v_4d = v.unsqueeze(0).transpose(-1, -2).contiguous() + + sb = torch.zeros(1, n_q, dtype=torch.float32, device='cuda') + o_4d, lse_4d = fmha_multihead_decode_raw( + q_4d, k_4d, v_4d, scale, 0, 0, False, sb + ) + o_fast = o_4d.squeeze(0) # (n_q, 1, hd) o_ref = reference_attention(q, k, v, scale) cos = cosine_sim(o_ref, o_fast).item() status = "PASS" if cos >= 0.999998 else "FAIL" @@ -94,7 +113,9 @@ def test_fast_path_matches_reference(): all_pass = False print(f" {status} {desc}: cos={cos:.6f}") except Exception as e: + import traceback print(f" FAIL {desc}: {e}") + traceback.print_exc() all_pass = False return all_pass