From a9d5e09f4cdcab9d6a64a75f365b4637b79fd557 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 2 Jun 2026 22:53:14 +0000 Subject: [PATCH] B1: mixed FP8/BF16 decode FMHA integration - New: fmha_mixed_fp8_decode.cuh (Blackwell FP8 tensor-core FMHA kernel) - New: fmha_mixed_fp8_capi.cu (C ABI launcher) - New: fmha_mixed_fp8_op.py (Python ctypes/nvcc bridge) - New: fp8_attention_io.cu (Q quantize + mixed KV gather kernels) - New: fmha_umma_desc.cuh additions (f8f6f4 UMMA + idesc helpers) - Modified: production.py (dsv4_attention_mixed_fp8_decode API) - Modified: single_shot_inference.py (B1 gather + FMHA path) - Modified: __init__.py (export mixed FP8 API) - New: docs/B1_MIXED_FP8_FMHA.md, FINAL_STRETCH.md noPE KV stays FP8_E4M3 + per-row scale, RoPE stays BF16. No global FP8->BF16 KV staging before FMHA. Decode-only (T==1), specialized HD=512/NOPE=448/ROPE=64. CUDA compile/runtime validation pending on B200. --- .gitignore | 1 + FINAL_STRETCH.md | 3 + docs/B1_MIXED_FP8_FMHA.md | 55 +++ dsv4/kernels/attention/__init__.py | 1 + dsv4/kernels/attention/fmha_mixed_fp8_capi.cu | 79 ++++ .../attention/fmha_mixed_fp8_decode.cuh | 374 ++++++++++++++++++ dsv4/kernels/attention/fmha_mixed_fp8_op.py | 148 +++++++ dsv4/kernels/attention/fmha_umma_desc.cuh | 27 ++ dsv4/kernels/attention/production.py | 38 ++ dsv4/kernels/cuda/fp8_attention_io.cu | 254 ++++++++++++ single_shot_inference.py | 152 +++++-- 11 files changed, 1095 insertions(+), 37 deletions(-) create mode 100644 docs/B1_MIXED_FP8_FMHA.md create mode 100644 dsv4/kernels/attention/fmha_mixed_fp8_capi.cu create mode 100644 dsv4/kernels/attention/fmha_mixed_fp8_decode.cuh create mode 100644 dsv4/kernels/attention/fmha_mixed_fp8_op.py create mode 100644 dsv4/kernels/cuda/fp8_attention_io.cu diff --git a/.gitignore b/.gitignore index 9f7983e8..1e794850 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ __pycache__/ *.pyc *.egg-info/ +nvfp4-megamoe-kernel-*.zip diff --git a/FINAL_STRETCH.md b/FINAL_STRETCH.md index 
6662f750..11341cfc 100644 --- a/FINAL_STRETCH.md +++ b/FINAL_STRETCH.md @@ -25,6 +25,9 @@ Goal: native NVFP4 where the math allows, FP8_E4M3 where it doesn't, BF16/FP32 o - Prevents infinite spin after crash/kill during CUDA kernel compilation ## B1 — FP8_E4M3 FMHA (BIG win; perf + memory + native Blackwell) + +> Implementation note from ChatGPT B1 pass: a decode-only mixed FP8/BF16 FMHA path has been added. See `docs/B1_MIXED_FP8_FMHA.md`. CUDA compile/runtime validation still needs to be run on a Blackwell box with `nvcc`. + Today: KV is *stored* mixed (FP8 nope + BF16 rope), then in "5. Gather KV" it's **dequantized to BF16** into `gbuf`, and the FMHA runs in **BF16**. That throws away the FP8 you stored and runs the heaviest kernel at half the tensor-core throughput Blackwell offers. NVFP4 KV is correctly ruled out — your own `KVCache` docstring shows 4-bit KV values cost ~0.4%/round-trip that compounds fatally over 61 layers. **FP8_E4M3 is the right target**, and you already store the nope dims in it. Plan: diff --git a/docs/B1_MIXED_FP8_FMHA.md b/docs/B1_MIXED_FP8_FMHA.md new file mode 100644 index 00000000..23e6705a --- /dev/null +++ b/docs/B1_MIXED_FP8_FMHA.md @@ -0,0 +1,55 @@ +# B1 Mixed FP8/BF16 FMHA first pass + +Implemented a decode-only DeepSeek-V4 attention path that keeps the cache in the paper/native storage format: + +- noPE KV: FP8_E4M3 bytes plus per-row FP32 scale +- RoPE KV: BF16 +- Q noPE: quantized BF16 -> FP8_E4M3 immediately before FMHA +- Q RoPE: BF16 + +The live `forward_attention` path now gathers compressed rows and the SWA tail into mixed buffers and calls `dsv4_attention_mixed_fp8_decode`; it no longer dequantizes noPE KV into `gather_buf` before attention. 
+ +## New files + +- `dsv4/kernels/cuda/fp8_attention_io.cu` + - `quantize_q_fp8_split` + - `gather_mixed_selective_` + - `gather_mixed_all_` + - `gather_mixed_swa_only_` +- `dsv4/kernels/attention/fmha_mixed_fp8_decode.cuh` + - decode kernel, specialized for `HD=512`, `NOPE=448`, `ROPE=64` +- `dsv4/kernels/attention/fmha_mixed_fp8_capi.cu` + - C ABI launcher +- `dsv4/kernels/attention/fmha_mixed_fp8_op.py` + - Python ctypes/nvcc bridge + +## Modified files + +- `dsv4/kernels/attention/fmha_umma_desc.cuh` + - added `.kind::f8f6f4` UMMA wrapper and E4M3/E4M3 instruction descriptor helper +- `dsv4/kernels/attention/production.py` + - added `dsv4_attention_mixed_fp8_decode` +- `dsv4/kernels/attention/__init__.py` + - exported mixed FP8 API +- `single_shot_inference.py` + - added mixed gather buffers/methods to `KVCache` + - changed step 5 gather to preserve FP8 noPE globally + - changed step 6 FMHA to call the mixed FP8 decode path + +## Intentional first-pass limits + +- Decode only (`T == 1`). The launcher hard-errors for prefill. +- Specialized to DeepSeek-V4 attention dimensions (`512/448/64`). +- noPE QK uses Blackwell FP8 tensor cores; RoPE QK and PV use BF16 tensor cores. +- noPE V is dequantized only inside shared memory immediately before the PV BF16 tensor-core multiply. There is no global BF16 KV staging. + +## Validation status + +The sandbox used to make this patch does not have `nvcc`, so CUDA compilation/runtime validation was not possible here. Python syntax was checked with: + +```bash +python3 -m py_compile single_shot_inference.py \ + dsv4/kernels/attention/production.py \ + dsv4/kernels/attention/fmha_mixed_fp8_op.py +``` + diff --git a/dsv4/kernels/attention/__init__.py b/dsv4/kernels/attention/__init__.py index e13a048a..1401e883 100644 --- a/dsv4/kernels/attention/__init__.py +++ b/dsv4/kernels/attention/__init__.py @@ -4,3 +4,4 @@ The live inference path uses dsv4.kernels.attention.production directly. 
See production.py for the dsv4_attention function used by single_shot_inference.py. """ from dsv4.kernels.attention.production import dsv4_attention +from dsv4.kernels.attention.production import dsv4_attention_mixed_fp8_decode diff --git a/dsv4/kernels/attention/fmha_mixed_fp8_capi.cu b/dsv4/kernels/attention/fmha_mixed_fp8_capi.cu new file mode 100644 index 00000000..f5d507e5 --- /dev/null +++ b/dsv4/kernels/attention/fmha_mixed_fp8_capi.cu @@ -0,0 +1,79 @@ +#include +#include +#include +#include "fmha_common.cuh" +#include "fmha_umma_desc.cuh" +#include "fmha_mixed_fp8_decode.cuh" + +using namespace dsv4::kernels::attention; + +extern "C" { + +int fmha_mixed_fp8_decode_launch( + const void* q_nope_fp8, + const float* q_nope_scale, + const void* q_rope_bf16, + const void* k_nope_fp8, + const float* k_nope_scale, + const void* k_rope_bf16, + void* o_ptr, + void* lse_ptr, + const float* sink_bias_ptr, + int B, int H, int T, int N, int HD, int NOPE, int ROPE, + int q_nope_head_stride, int q_nope_batch_stride, + int q_scale_head_stride, int q_scale_batch_stride, + int q_rope_head_stride, int q_rope_batch_stride, + int o_head_stride, int o_batch_stride, + int lse_head_stride, int lse_batch_stride, + float scale +) { + if (T != 1 || HD != 512 || NOPE != 448 || ROPE != 64) return -2; + + FmhaMixedFp8DecodeParams p; + p.q_nope_fp8 = (const uint8_t*)q_nope_fp8; + p.q_nope_scale = q_nope_scale; + p.q_rope_bf16 = (const bf16_t*)q_rope_bf16; + p.k_nope_fp8 = (const uint8_t*)k_nope_fp8; + p.k_nope_scale = k_nope_scale; + p.k_rope_bf16 = (const bf16_t*)k_rope_bf16; + p.o = (bf16_t*)o_ptr; + p.lse = (float*)lse_ptr; + p.sink_bias = sink_bias_ptr; + p.B = B; p.H = H; p.N = N; p.HD = HD; p.NOPE = NOPE; p.ROPE = ROPE; + p.q_nope_head_stride = q_nope_head_stride; + p.q_nope_batch_stride = q_nope_batch_stride; + p.q_scale_head_stride = q_scale_head_stride; + p.q_scale_batch_stride = q_scale_batch_stride; + p.q_rope_head_stride = q_rope_head_stride; + p.q_rope_batch_stride = 
q_rope_batch_stride; + p.o_head_stride = o_head_stride; + p.o_batch_stride = o_batch_stride; + p.lse_head_stride = lse_head_stride; + p.lse_batch_stride = lse_batch_stride; + p.scale = scale; + + // Static shared memory size for fmha_mixed_fp8_decode_kernel<512,448,64>. + // Keep this mirrored with the header layout and aligned up generously. + int smem = 0; + smem += 4; smem = (smem + 127) & ~127; + smem += 128 * 32; smem = (smem + 127) & ~127; // sQ8 + smem += 128 * 32; smem = (smem + 127) & ~127; // sK8 + smem += 128 * 16 * 2; smem = (smem + 127) & ~127; // sQ16 + smem += 128 * 16 * 2; smem = (smem + 127) & ~127; // sK16 + smem += 128 * 16 * 2; smem = (smem + 127) & ~127; // sPk + smem += 16 * 16 * 2; smem = (smem + 127) & ~127; // sV + smem += 128 * 4; // sLogits + smem += 128 * 4; // sP + smem += 512 * 4; // sOacc + smem += 512 * 2; // sOepi + smem = (smem + 127) & ~127; + + cudaFuncSetAttribute(fmha_mixed_fp8_decode_kernel<512,448,64>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); + dim3 grid(1, H, B); + dim3 block(192); + fmha_mixed_fp8_decode_kernel<512,448,64><<<grid, block, smem>>>(p); // one CTA per (head, batch); smem is the dynamic SMEM byte count computed above + cudaError_t err = cudaGetLastError(); + return err == cudaSuccess ? 0 : (int)err; +} + +} // extern C diff --git a/dsv4/kernels/attention/fmha_mixed_fp8_decode.cuh b/dsv4/kernels/attention/fmha_mixed_fp8_decode.cuh new file mode 100644 index 00000000..65cd4b04 --- /dev/null +++ b/dsv4/kernels/attention/fmha_mixed_fp8_decode.cuh @@ -0,0 +1,374 @@ +/** + * DSV4 B1 — mixed FP8/BF16 decode FMHA for DeepSeek-V4 attention KV. + * + * Inputs are the storage-native DSV4 layout: + * Q noPE: FP8_E4M3 + per-row FP32 scale, Q RoPE: BF16 + * KV noPE: FP8_E4M3 + per-row FP32 scale, KV RoPE: BF16 + * + * This first B1 kernel targets the decode hot path (T == 1) and HD=512, + * NOPE=448, ROPE=64.
It removes the global FP8->BF16 KV dequant/gather and + * uses Blackwell tcgen05 tensor cores for: + * - noPE QK: f8f6f4 E4M3 x E4M3 -> FP32 + * - RoPE QK: f16 BF16 x BF16 -> FP32 + * - PV: f16 BF16 x BF16 -> FP32, with noPE V dequantized only into SMEM + * + * The noPE KV is never materialized as a global BF16 buffer. + */ +#pragma once + +#include +#include +#include +#include +#include +#include "fmha_common.cuh" +#include "fmha_umma_desc.cuh" + +namespace dsv4::kernels::attention { + +struct FmhaMixedFp8DecodeParams { + const uint8_t* __restrict__ q_nope_fp8; // (B,H,1,NOPE) + const float* __restrict__ q_nope_scale; // (B,H,1) + const bf16_t* __restrict__ q_rope_bf16; // (B,H,1,ROPE) + + const uint8_t* __restrict__ k_nope_fp8; // (N,NOPE), MQA shared + const float* __restrict__ k_nope_scale; // (N,) + const bf16_t* __restrict__ k_rope_bf16; // (N,ROPE) + + bf16_t* __restrict__ o; // (B,H,1,HD) + float* __restrict__ lse; // (B,H,1), optional + const float* __restrict__ sink_bias; // (B,H), optional + + int B, H, N, HD, NOPE, ROPE; + int q_nope_head_stride, q_nope_batch_stride; + int q_scale_head_stride, q_scale_batch_stride; + int q_rope_head_stride, q_rope_batch_stride; + int o_head_stride, o_batch_stride; + int lse_head_stride, lse_batch_stride; + float scale; +}; + +__device__ __forceinline__ float fp8_e4m3_to_f32(uint8_t byte) { + __nv_fp8_e4m3 v; + *reinterpret_cast(&v) = byte; + return static_cast(v); +} + +// FP8 canonical K-major layout for tcgen05.mma.kind::f8f6f4. +// Logical matrix shape is (128, 32): 8 row groups x 16 FP8 columns per 128B atom. 
+__device__ __forceinline__ int canon_idx_fp8_128x32(int r, int c) { + constexpr int CORES_MN = 16; // 128 / 8 + int core_mn = r >> 3; + int core_k = c >> 4; // 16 FP8 values = 16B atom width + int local_r = r & 7; + int local_c = c & 15; + return core_k * CORES_MN * 128 + core_mn * 128 + local_r * 16 + local_c; +} + +__device__ __forceinline__ int canon_idx_bf16_128x16(int r, int c) { + constexpr int CORES_MN = 16; + int core_mn = r >> 3; + int core_k = c >> 3; + int local_r = r & 7; + int local_c = c & 7; + return core_k * CORES_MN * 64 + core_mn * 64 + local_r * 8 + local_c; +} + +__device__ __forceinline__ int canon_idx_bf16_16x16(int r, int c) { + constexpr int CORES_MN = 2; // 16 / 8 + int core_mn = r >> 3; + int core_k = c >> 3; + int local_r = r & 7; + int local_c = c & 7; + return core_k * CORES_MN * 64 + core_mn * 64 + local_r * 8 + local_c; +} + +__device__ __forceinline__ bf16_t f32_to_bf16_bits(float x) { return f32_to_bf16(x); } + +// Read row 0 of a 128-wide TMEM result. Must be called by a full warp; +// lane 0 receives row 0, lanes 1..31 receive rows 1..31 and are ignored. 
+__device__ __forceinline__ void read_tmem_row0_128(uint32_t tb, float* out128, bool lane0) { + for (int n = 0; n < 16; n++) { + float tmp[8]; + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" + : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]), + "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) + : "r"(tb + n * 8)); + asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory"); + if (lane0) { + #pragma unroll + for (int c = 0; c < 8; c++) out128[n * 8 + c] = tmp[c]; + } + } +} + +template +__global__ void __launch_bounds__(192) +fmha_mixed_fp8_decode_kernel(FmhaMixedFp8DecodeParams p) { + static_assert(HD == 512 && NOPE == 448 && ROPE == 64, "B1 first pass is specialized for DSV4 HD=512/NOPE=448/ROPE=64"); + constexpr int MMA_K_F8 = 32; + constexpr int MMA_K_F16 = 16; + constexpr int NKT_NOPE = NOPE / MMA_K_F8; + constexpr int NKT_ROPE = ROPE / MMA_K_F16; + constexpr int NKT_PV = SK_TILE / MMA_K_F16; + constexpr int N_SUB = HD / 16; + constexpr int TILE_F8 = 128 * MMA_K_F8; // bytes + constexpr int TILE_F16 = 128 * MMA_K_F16; // bf16 elements + constexpr int V_SUB_SZ = 16 * MMA_K_F16; // bf16 elements + constexpr int TMEM_COLS = 512; + + const int head_idx = blockIdx.y; + const int batch_idx = blockIdx.z; + const int tid = threadIdx.x; + const int wid = tid >> 5; + const int lane = tid & 31; + const bool is_mma_warp = (wid == 4); + const bool is_lane0 = (wid == 0 && lane == 0); + const int n_kv_tiles = (p.N + SK_TILE - 1) / SK_TILE; + + const uint8_t* q8 = p.q_nope_fp8 + batch_idx * p.q_nope_batch_stride + head_idx * p.q_nope_head_stride; + const float q8_scale = p.q_nope_scale[batch_idx * p.q_scale_batch_stride + head_idx * p.q_scale_head_stride]; + const bf16_t* qrope = p.q_rope_bf16 + batch_idx * p.q_rope_batch_stride + head_idx * p.q_rope_head_stride; + bf16_t* out = p.o + batch_idx * p.o_batch_stride + head_idx * p.o_head_stride; + float* lse = p.lse ? 
p.lse + batch_idx * p.lse_batch_stride + head_idx * p.lse_head_stride : nullptr; + + extern __shared__ __align__(128) char sbuf[]; + size_t off = 0; + uint32_t* sTmemBase = (uint32_t*)(sbuf + off); off += 4; + off = (off + 127) & ~(size_t)127; + uint8_t* sQ8 = (uint8_t*)(sbuf + off); off += TILE_F8; + off = (off + 127) & ~(size_t)127; + uint8_t* sK8 = (uint8_t*)(sbuf + off); off += TILE_F8; + off = (off + 127) & ~(size_t)127; + bf16_t* sQ16 = (bf16_t*)(sbuf + off); off += TILE_F16 * sizeof(bf16_t); + off = (off + 127) & ~(size_t)127; + bf16_t* sK16 = (bf16_t*)(sbuf + off); off += TILE_F16 * sizeof(bf16_t); + off = (off + 127) & ~(size_t)127; + bf16_t* sPk = (bf16_t*)(sbuf + off); off += TILE_F16 * sizeof(bf16_t); + off = (off + 127) & ~(size_t)127; + bf16_t* sV = (bf16_t*)(sbuf + off); off += V_SUB_SZ * sizeof(bf16_t); + off = (off + 127) & ~(size_t)127; + float* sLogits = (float*)(sbuf + off); off += SK_TILE * sizeof(float); + float* sP = (float*)(sbuf + off); off += SK_TILE * sizeof(float); + float* sOacc = (float*)(sbuf + off); off += HD * sizeof(float); + bf16_t* sOepi = (bf16_t*)(sbuf + off); off += HD * sizeof(bf16_t); + + if (is_mma_warp) tmem_alloc((uint32_t)__cvta_generic_to_shared(sTmemBase), TMEM_COLS); + asm volatile("fence.proxy.async.shared::cta;" ::: "memory"); + __syncthreads(); + uint32_t tb = *sTmemBase; + + for (int d = tid; d < HD; d += blockDim.x) sOacc[d] = 0.0f; // HD (512) exceeds the 192-thread block: stride so every accumulator is zeroed + if (tid < SK_TILE) { sLogits[tid] = -INFINITY; sP[tid] = 0.0f; } + __syncthreads(); + + float running_max = -INFINITY; + float running_sum = 0.0f; + const uint32_t idesc_f8_qk = make_idesc_f8_e4m3(128, 128); + const uint32_t idesc_f16_qk = make_idesc(128, 128); + const uint32_t idesc_pv = make_idesc(128, 16); + + for (int kv_tile = 0; kv_tile < n_kv_tiles; kv_tile++) { + const int kv_start = kv_tile * SK_TILE; + const int kv_len = min(SK_TILE, p.N - kv_start); + + // ------------------------------------------------------------ + // QK noPE: FP8 tensor cores, raw logits in TMEM.
+ // ------------------------------------------------------------ + for (int kt = 0; kt < NKT_NOPE; kt++) { + for (int i = tid; i < TILE_F8; i += blockDim.x) { sQ8[i] = 0; sK8[i] = 0; } + __syncthreads(); + for (int c = tid; c < MMA_K_F8; c += blockDim.x) { + int d = kt * MMA_K_F8 + c; + sQ8[canon_idx_fp8_128x32(0, c)] = q8[d]; + } + for (int i = tid; i < kv_len * MMA_K_F8; i += blockDim.x) { + int r = i / MMA_K_F8, c = i % MMA_K_F8; + int d = kt * MMA_K_F8 + c; + sK8[canon_idx_fp8_128x32(r, c)] = p.k_nope_fp8[(int64_t)(kv_start + r) * NOPE + d]; + } + __syncthreads(); + if (is_mma_warp && lane == 0) { + uint64_t dq = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sQ8), 128); + uint64_t dk = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sK8), 128); + umma_ss_f8f6f4(tb, dq, dk, idesc_f8_qk, kt > 0); + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + } + __syncthreads(); + } + asm volatile("fence.sc.gpu;" ::: "memory"); + __syncthreads(); + + if (wid == 0) read_tmem_row0_128(tb, sLogits, lane == 0); + __syncthreads(); + if (is_lane0) { + #pragma unroll + for (int c = 0; c < SK_TILE; c++) { + if (c < kv_len) { + float ks = p.k_nope_scale[kv_start + c]; + sLogits[c] = sLogits[c] * q8_scale * ks; + } else { + sLogits[c] = -INFINITY; + } + } + } + __syncthreads(); + + // ------------------------------------------------------------ + // QK RoPE: BF16 tensor cores, then add to scaled noPE logits. 
+ // ------------------------------------------------------------ + for (int kt = 0; kt < NKT_ROPE; kt++) { + for (int i = tid; i < TILE_F16; i += blockDim.x) { sQ16[i] = 0; sK16[i] = 0; } + __syncthreads(); + for (int c = tid; c < MMA_K_F16; c += blockDim.x) { + int d = kt * MMA_K_F16 + c; + sQ16[canon_idx_bf16_128x16(0, c)] = qrope[d]; + } + for (int i = tid; i < kv_len * MMA_K_F16; i += blockDim.x) { + int r = i / MMA_K_F16, c = i % MMA_K_F16; + int d = kt * MMA_K_F16 + c; + sK16[canon_idx_bf16_128x16(r, c)] = p.k_rope_bf16[(int64_t)(kv_start + r) * ROPE + d]; + } + __syncthreads(); + if (is_mma_warp && lane == 0) { + uint64_t dq = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sQ16), 128); + uint64_t dk = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sK16), 128); + umma_ss_f16(tb, dq, dk, idesc_f16_qk, kt > 0); + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + } + __syncthreads(); + } + asm volatile("fence.sc.gpu;" ::: "memory"); + __syncthreads(); + + // Use sP as a temporary row buffer here; probabilities are formed later. + if (wid == 0) read_tmem_row0_128(tb, sP, lane == 0); + __syncthreads(); + if (is_lane0) { + for (int c = 0; c < kv_len; c++) sLogits[c] += sP[c]; + } + __syncthreads(); + + // ------------------------------------------------------------ + // Softmax tile probabilities for row 0. + // ------------------------------------------------------------ + float tile_max = -INFINITY; + if (is_lane0) { + for (int c = 0; c < kv_len; c++) tile_max = fmaxf(tile_max, sLogits[c] * p.scale); + float tile_sum = 0.0f; + for (int c = 0; c < kv_len; c++) { + float pv = expf(sLogits[c] * p.scale - tile_max); + sP[c] = pv; + tile_sum += pv; + } + for (int c = kv_len; c < SK_TILE; c++) sP[c] = 0.0f; + + float new_max = fmaxf(running_max, tile_max); + float rescale_old = (running_max > -INFINITY) ? 
expf(running_max - new_max) : 0.0f; + for (int d = 0; d < HD; d++) sOacc[d] *= rescale_old; + running_sum = running_sum * rescale_old + tile_sum * expf(tile_max - new_max); + running_max = new_max; + } + __syncthreads(); + + // ------------------------------------------------------------ + // PV: probabilities BF16 x V BF16. noPE V is dequantized into SMEM only. + // ------------------------------------------------------------ + for (int n_sub = 0; n_sub < N_SUB; n_sub++) { + int d_base = n_sub * 16; + for (int pv_kt = 0; pv_kt < NKT_PV; pv_kt++) { + const int col_start = pv_kt * MMA_K_F16; + for (int i = tid; i < TILE_F16; i += blockDim.x) sPk[i] = 0; + for (int i = tid; i < V_SUB_SZ; i += blockDim.x) sV[i] = 0; + __syncthreads(); + + // P matrix: only row 0 non-zero. + for (int c = tid; c < MMA_K_F16; c += blockDim.x) { + int gc = col_start + c; + sPk[canon_idx_bf16_128x16(0, c)] = f32_to_bf16_bits(sP[gc]); + } + + // V matrix B: logical (16 K rows, 16 N cols) in BF16 canonical layout. + for (int i = tid; i < 16 * MMA_K_F16; i += blockDim.x) { + int dd = i / MMA_K_F16; + int kk = i % MMA_K_F16; + int row = col_start + kk; + int g_row = kv_start + row; + int d = d_base + dd; + bf16_t vbits = 0; + if (row < kv_len) { + if (d < NOPE) { + uint8_t b = p.k_nope_fp8[(int64_t)g_row * NOPE + d]; + float v = fp8_e4m3_to_f32(b) * p.k_nope_scale[g_row]; + vbits = f32_to_bf16_bits(v); + } else { + vbits = p.k_rope_bf16[(int64_t)g_row * ROPE + (d - NOPE)]; + } + } + // B is (K=16 rows, N=16 cols). Reuse BF16 canonical with rows=16 + // by embedding into the first 16 rows of a 128-row tile; MMA_N=16. 
+ sV[canon_idx_bf16_16x16(kk, dd)] = vbits; + } + __syncthreads(); + + if (is_mma_warp && lane == 0) { + uint64_t dp = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sPk), 128); + uint64_t dv = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sV), 16); + umma_ss_f16(tb + n_sub * 16, dp, dv, idesc_pv, pv_kt > 0); + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + } + __syncthreads(); + } + } + asm volatile("fence.sc.gpu;" ::: "memory"); + __syncthreads(); + + // Accumulate PV tile contribution after applying exp(tile_max-new_max). + if (wid == 0) { + float rescale_new = 0.0f; + if (lane == 0) { + // running_max is already the post-tile max. Recompute tile scale. + float tile_max2 = -INFINITY; + for (int c = 0; c < kv_len; c++) tile_max2 = fmaxf(tile_max2, sLogits[c] * p.scale); + rescale_new = expf(tile_max2 - running_max); + } + for (int n = 0; n < HD / 8; n++) { + float tmp[8]; + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" + : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]), + "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) + : "r"(tb + n * 8)); + asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory"); + if (lane == 0) { + #pragma unroll + for (int c = 0; c < 8; c++) sOacc[n * 8 + c] += tmp[c] * rescale_new; + } + } + } + __syncthreads(); + } + + // Attention sink: denominator-only logit. + if (is_lane0 && p.sink_bias != nullptr) { + float sb = p.sink_bias[batch_idx * p.H + head_idx]; + float new_max = fmaxf(running_max, sb); + float rescale_old = (running_max > -INFINITY) ? 
expf(running_max - new_max) : 0.0f; + for (int d = 0; d < HD; d++) sOacc[d] *= rescale_old; + running_sum = running_sum * rescale_old + expf(sb - new_max); + running_max = new_max; + } + __syncthreads(); + + if (is_lane0) { + float inv_sum = 1.0f / running_sum; + for (int d = 0; d < HD; d++) sOepi[d] = f32_to_bf16_bits(sOacc[d] * inv_sum); + if (lse) lse[0] = logf(running_sum) + running_max; + } + __syncthreads(); + for (int d = tid; d < HD; d += blockDim.x) out[d] = sOepi[d]; + __syncthreads(); + + if (is_mma_warp) tmem_dealloc(tb, TMEM_COLS); +} + +} // namespace dsv4::kernels::attention diff --git a/dsv4/kernels/attention/fmha_mixed_fp8_op.py b/dsv4/kernels/attention/fmha_mixed_fp8_op.py new file mode 100644 index 00000000..ba2bd1c1 --- /dev/null +++ b/dsv4/kernels/attention/fmha_mixed_fp8_op.py @@ -0,0 +1,148 @@ +"""DSV4 B1 mixed FP8/BF16 decode FMHA loader. + +This path is intentionally hard-error only: it does not fall back to PyTorch or to +BF16 FMHA if the mixed FP8 kernel is requested. 
+""" +import ctypes +import logging +import os +import subprocess +from typing import Optional + +import torch + +logger = logging.getLogger(__name__) + +KERNEL_DIR = os.path.dirname(os.path.abspath(__file__)) +REPO_ROOT = os.path.normpath(os.path.join(KERNEL_DIR, "..", "..")) +SOURCE = os.path.join(KERNEL_DIR, "fmha_mixed_fp8_capi.cu") +BUILD_DIR = os.path.join(REPO_ROOT, "build", "fmha_mixed_fp8") +SO_NAME = "libfmha_mixed_fp8.so" + +_lib = None +_lib_lock = False + + +def _find_nvcc(): + import shutil + for c in ["/usr/local/cuda-13.2/bin/nvcc", "/usr/local/cuda/bin/nvcc"]: + if os.path.isfile(c): + return c + nvcc = shutil.which("nvcc") + if nvcc: + return nvcc + raise RuntimeError("nvcc not found") + + +def _ensure_built(): + global _lib, _lib_lock + if _lib is not None: + return _lib + if _lib_lock: + raise RuntimeError("Recursive mixed-FP8 FMHA build") + _lib_lock = True + try: + so_path = os.path.join(BUILD_DIR, SO_NAME) + deps = [ + SOURCE, + os.path.join(KERNEL_DIR, "fmha_common.cuh"), + os.path.join(KERNEL_DIR, "fmha_umma_desc.cuh"), + os.path.join(KERNEL_DIR, "fmha_mixed_fp8_decode.cuh"), + ] + src_mtime = max(os.path.getmtime(p) for p in deps if os.path.exists(p)) + need_build = not os.path.isfile(so_path) or src_mtime > os.path.getmtime(so_path) + if not need_build: + _lib = ctypes.CDLL(so_path) + return _lib + + os.makedirs(BUILD_DIR, exist_ok=True) + nvcc = _find_nvcc() + cmd = [ + nvcc, "-std=c++20", "-shared", "-Xcompiler", "-fPIC", + "-gencode=arch=compute_100a,code=sm_100a", + "-gencode=arch=compute_100a,code=compute_100a", + f"-I{KERNEL_DIR}", f"-I{REPO_ROOT}", + "-O3", "--use_fast_math", "--expt-relaxed-constexpr", + SOURCE, "-o", so_path, "-lcudart", "-lcuda", + ] + logger.info("Building libfmha_mixed_fp8.so (sm_100a)...") + res = subprocess.run(cmd, capture_output=True, text=True) + if res.returncode != 0: + raise RuntimeError(f"mixed FP8 FMHA nvcc failed:\nSTDOUT:\n{res.stdout}\nSTDERR:\n{res.stderr}") + _lib = ctypes.CDLL(so_path) + return 
_lib + finally: + _lib_lock = False + + +def _quantize_q_split(q: torch.Tensor, rope_dim: int): + from dsv4.kernels.cuda.loader import get_cuda_module + mod = get_cuda_module("fp8_attention_io", ["fp8_attention_io.cu"], + extra_cuda_cflags=[ + "-gencode=arch=compute_100a,code=sm_100a", + "-O3", "--use_fast_math", "--expt-relaxed-constexpr", + ]) + return mod.quantize_q_fp8_split(q, rope_dim) + + +def fmha_mixed_fp8_decode_raw( + q: torch.Tensor, # (B,H,1,HD) BF16 + k_nope_fp8: torch.Tensor, # (N,NOPE) uint8/float8_e4m3fn + k_nope_scale: torch.Tensor, # (N,) FP32 + k_rope_bf16: torch.Tensor, # (N,ROPE) BF16 + scale: float, + attn_sink: Optional[torch.Tensor] = None, + rope_dim: int = 64, +): + if q.dim() != 4: + raise RuntimeError("q must be (B,H,T,HD)") + B, H, T, HD = q.shape + if T != 1: + raise RuntimeError("mixed FP8 FMHA supports decode T==1 only") + NOPE = HD - rope_dim + if HD != 512 or NOPE != 448 or rope_dim != 64: + raise RuntimeError(f"mixed FP8 FMHA first pass supports HD=512/NOPE=448/ROPE=64, got {HD}/{NOPE}/{rope_dim}") + + q = q.contiguous() + k_nope_fp8 = k_nope_fp8.contiguous() + k_nope_scale = k_nope_scale.contiguous() + k_rope_bf16 = k_rope_bf16.contiguous() + q_nope_fp8, q_nope_scale, q_rope = _quantize_q_split(q, rope_dim) + + N = k_nope_fp8.shape[0] + o = torch.empty((B, H, T, HD), dtype=torch.bfloat16, device=q.device) + lse = torch.empty((B, H, T), dtype=torch.float32, device=q.device) + + sink_ptr = ctypes.c_void_p(0) + sb = None + if attn_sink is not None: + sb = attn_sink.float().contiguous() + if sb.dim() == 1: + sb = sb.unsqueeze(0).expand(B, -1).contiguous() + if tuple(sb.shape) != (B, H): + raise RuntimeError(f"sink bias shape {tuple(sb.shape)} != {(B,H)}") + sink_ptr = ctypes.c_void_p(sb.data_ptr()) + + lib = _ensure_built() + ret = lib.fmha_mixed_fp8_decode_launch( + ctypes.c_void_p(q_nope_fp8.data_ptr()), + ctypes.c_void_p(q_nope_scale.data_ptr()), + ctypes.c_void_p(q_rope.data_ptr()), + ctypes.c_void_p(k_nope_fp8.data_ptr()), + 
ctypes.c_void_p(k_nope_scale.data_ptr()), + ctypes.c_void_p(k_rope_bf16.data_ptr()), + ctypes.c_void_p(o.data_ptr()), + ctypes.c_void_p(lse.data_ptr()), + sink_ptr, + ctypes.c_int(B), ctypes.c_int(H), ctypes.c_int(T), ctypes.c_int(N), + ctypes.c_int(HD), ctypes.c_int(NOPE), ctypes.c_int(rope_dim), + ctypes.c_int(q_nope_fp8.stride(1)), ctypes.c_int(q_nope_fp8.stride(0)), + ctypes.c_int(q_nope_scale.stride(1)), ctypes.c_int(q_nope_scale.stride(0)), + ctypes.c_int(q_rope.stride(1)), ctypes.c_int(q_rope.stride(0)), + ctypes.c_int(o.stride(1)), ctypes.c_int(o.stride(0)), + ctypes.c_int(lse.stride(1)), ctypes.c_int(lse.stride(0)), + ctypes.c_float(scale), + ) + if ret != 0: + raise RuntimeError(f"mixed FP8 FMHA launch failed: return code {ret}") + return o, lse diff --git a/dsv4/kernels/attention/fmha_umma_desc.cuh b/dsv4/kernels/attention/fmha_umma_desc.cuh index 0ecb2a0c..7ee6f200 100644 --- a/dsv4/kernels/attention/fmha_umma_desc.cuh +++ b/dsv4/kernels/attention/fmha_umma_desc.cuh @@ -340,4 +340,31 @@ __device__ __forceinline__ uint32_t make_idesc(int block_m, int block_n) { | ((uint32_t)(block_m >> 4) << 24); // MMA_M } +/** + * tcgen05.mma SS for .kind::f8f6f4 with E4M3xE4M3 -> FP32. + * A and B element types are encoded in idesc. For B1 we use E4M3/E4M3. + */ +__device__ void umma_ss_f8f6f4( + uint32_t tmem_c, uint64_t desc_a, uint64_t desc_b, + uint32_t i_desc, bool accumulate = false +) { + uint32_t scaleC_bits = accumulate ? 0x3F800000u : 0u; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, p;\n\t" + "}" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), + "r"(i_desc), "r"(scaleC_bits) + ); +} + +/** Instruction descriptor for .kind::f8f6f4 E4M3 x E4M3 -> FP32. 
*/ +__device__ __forceinline__ uint32_t make_idesc_f8_e4m3(int block_m, int block_n) { + return (1U << 4) // dtype = F32 + | ((uint32_t)(block_n >> 3) << 17) // MMA_N + | ((uint32_t)(block_m >> 4) << 24); // MMA_M +} + } // namespace dsv4::kernels::attention diff --git a/dsv4/kernels/attention/production.py b/dsv4/kernels/attention/production.py index ba27e1b0..001535fb 100644 --- a/dsv4/kernels/attention/production.py +++ b/dsv4/kernels/attention/production.py @@ -195,3 +195,41 @@ def dsv4_attention_per_head( output[q_idx] = o return output + + +# --------------------------------------------------------------------------- +# B1: mixed FP8/BF16 DeepSeek-V4 decode attention +# --------------------------------------------------------------------------- + +def dsv4_attention_mixed_fp8_decode( + q: torch.Tensor, # (n_q_heads,T,HD) or (B,n_q_heads,T,HD) BF16 + k_nope_fp8: torch.Tensor, # (N,NOPE) uint8/float8_e4m3fn + k_nope_scale: torch.Tensor, # (N,) FP32 + k_rope_bf16: torch.Tensor, # (N,ROPE) BF16 + scale: Optional[float] = None, + sink_bias: Optional[torch.Tensor] = None, + rope_dim: int = 64, +) -> torch.Tensor: + """B1 production path: storage-native FP8/BF16 KV decode FMHA. + + This intentionally has no PyTorch/BF16 fallback. It is the decode-only path + for DeepSeek-V4 attention where noPE KV is already stored as FP8_E4M3 with + per-row FP32 scales and RoPE KV is BF16. 
+ """ + from dsv4.kernels.attention.fmha_mixed_fp8_op import fmha_mixed_fp8_decode_raw + + has_batch = q.dim() == 4 + if q.dim() == 3: + q4 = q.unsqueeze(0).contiguous() + elif q.dim() == 4: + q4 = q.contiguous() + else: + raise RuntimeError("q must be (H,T,HD) or (B,H,T,HD)") + + hd = q4.shape[-1] + scale = scale or (1.0 / math.sqrt(hd)) + o4, _lse = fmha_mixed_fp8_decode_raw( + q4, k_nope_fp8, k_nope_scale, k_rope_bf16, + scale, attn_sink=sink_bias, rope_dim=rope_dim, + ) + return o4 if has_batch else o4.squeeze(0) diff --git a/dsv4/kernels/cuda/fp8_attention_io.cu b/dsv4/kernels/cuda/fp8_attention_io.cu new file mode 100644 index 00000000..853cae76 --- /dev/null +++ b/dsv4/kernels/cuda/fp8_attention_io.cu @@ -0,0 +1,254 @@ +/** + * DSV4 B1 — FP8 attention input/output preparation kernels. + * + * These are deliberately tiny launch-count reducers for the mixed-precision + * FMHA path: + * - quantize Q noPE dims BF16 -> FP8_E4M3 with a per-(batch,head,row) scale + * - keep Q RoPE dims BF16 + * - gather compressed KV noPE bytes/scales and RoPE BF16 without global dequant + * - quantize the SWA noPE tail BF16 -> FP8_E4M3 in the same gather kernel + * + * No PyTorch fallback and no FP8->BF16 global staging for noPE KV. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static constexpr float E4M3_MAX = 448.0f; + +__device__ __forceinline__ float bf16_load(const __nv_bfloat16* p) { + return __bfloat162float(*p); +} + +__device__ __forceinline__ uint8_t fp8_e4m3_from_f32(float x) { + x = fminf(fmaxf(x, -E4M3_MAX), E4M3_MAX); + __nv_fp8_e4m3 v(x); + return *reinterpret_cast(&v); +} + +__global__ void quantize_q_fp8_split_kernel( + const __nv_bfloat16* __restrict__ q, // (B,H,T,HD) + uint8_t* __restrict__ q_nope_fp8, // (B,H,T,NOPE) + float* __restrict__ q_nope_scale, // (B,H,T) + __nv_bfloat16* __restrict__ q_rope, // (B,H,T,ROPE) + int rows, int hd, int nope, int rope +) { + int row = blockIdx.x; + if (row >= rows) return; + + const __nv_bfloat16* q_row = q + (int64_t)row * hd; + uint8_t* out8 = q_nope_fp8 + (int64_t)row * nope; + __nv_bfloat16* outrope = q_rope + (int64_t)row * rope; + + float local_max = 0.0f; + for (int c = threadIdx.x; c < nope; c += blockDim.x) { + local_max = fmaxf(local_max, fabsf(bf16_load(q_row + c))); + } + + // block reduction over 256 threads + for (int off = 16; off > 0; off >>= 1) + local_max = fmaxf(local_max, __shfl_down_sync(0xffffffff, local_max, off)); + __shared__ float warp_max[8]; + if ((threadIdx.x & 31) == 0) warp_max[threadIdx.x >> 5] = local_max; + __syncthreads(); + float amax = 0.0f; + if (threadIdx.x < 32) { + amax = (threadIdx.x < (blockDim.x + 31) / 32) ? 
warp_max[threadIdx.x] : 0.0f; + for (int off = 16; off > 0; off >>= 1) + amax = fmaxf(amax, __shfl_down_sync(0xffffffff, amax, off)); + if (threadIdx.x == 0) { + float scale = amax / E4M3_MAX; + if (scale < 1e-8f) scale = 1e-8f; + q_nope_scale[row] = scale; + } + } + __syncthreads(); + + float scale = q_nope_scale[row]; + float inv_scale = 1.0f / scale; + for (int c = threadIdx.x; c < nope; c += blockDim.x) { + out8[c] = fp8_e4m3_from_f32(bf16_load(q_row + c) * inv_scale); + } + for (int c = threadIdx.x; c < rope; c += blockDim.x) { + outrope[c] = q_row[nope + c]; + } +} + +__global__ void copy_comp_rows_kernel( + const uint8_t* __restrict__ comp_nope_fp8, + const float* __restrict__ comp_nope_scale, + const __nv_bfloat16* __restrict__ comp_rope, + const int32_t* __restrict__ indices, // optional; nullptr => row i + uint8_t* __restrict__ out_nope_fp8, + float* __restrict__ out_nope_scale, + __nv_bfloat16* __restrict__ out_rope, + int K, int nope, int rope +) { + int row = blockIdx.y; + int col = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= K) return; + int src = indices ? 
indices[row] : row; + if (col < nope) out_nope_fp8[(int64_t)row * nope + col] = comp_nope_fp8[(int64_t)src * nope + col]; + if (col < rope) out_rope[(int64_t)row * rope + col] = comp_rope[(int64_t)src * rope + col]; + if (blockIdx.x == 0 && threadIdx.x == 0) out_nope_scale[row] = comp_nope_scale[src]; +} + +__global__ void quantize_swa_tail_kernel( + const __nv_bfloat16* __restrict__ swa, // (S, HD), BF16 + uint8_t* __restrict__ out_nope_fp8, // (K+S, NOPE) + float* __restrict__ out_nope_scale, // (K+S) + __nv_bfloat16* __restrict__ out_rope, // (K+S, ROPE) + int K, int S, int hd, int nope, int rope +) { + int s = blockIdx.x; + if (s >= S) return; + int out_row = K + s; + const __nv_bfloat16* src = swa + (int64_t)s * hd; + uint8_t* out8 = out_nope_fp8 + (int64_t)out_row * nope; + __nv_bfloat16* outrope = out_rope + (int64_t)out_row * rope; + + float local_max = 0.0f; + for (int c = threadIdx.x; c < nope; c += blockDim.x) { + local_max = fmaxf(local_max, fabsf(bf16_load(src + c))); + } + for (int off = 16; off > 0; off >>= 1) + local_max = fmaxf(local_max, __shfl_down_sync(0xffffffff, local_max, off)); + __shared__ float warp_max[8]; + if ((threadIdx.x & 31) == 0) warp_max[threadIdx.x >> 5] = local_max; + __syncthreads(); + float amax = 0.0f; + if (threadIdx.x < 32) { + amax = (threadIdx.x < (blockDim.x + 31) / 32) ? 
warp_max[threadIdx.x] : 0.0f; + for (int off = 16; off > 0; off >>= 1) + amax = fmaxf(amax, __shfl_down_sync(0xffffffff, amax, off)); + if (threadIdx.x == 0) { + float scale = amax / E4M3_MAX; + if (scale < 1e-8f) scale = 1e-8f; + out_nope_scale[out_row] = scale; + } + } + __syncthreads(); + + float inv_scale = 1.0f / out_nope_scale[out_row]; + for (int c = threadIdx.x; c < nope; c += blockDim.x) { + out8[c] = fp8_e4m3_from_f32(bf16_load(src + c) * inv_scale); + } + for (int c = threadIdx.x; c < rope; c += blockDim.x) { + outrope[c] = src[nope + c]; + } +} + +std::tuple quantize_q_fp8_split_cuda( + torch::Tensor q, int64_t rope_dim +) { + TORCH_CHECK(q.is_cuda(), "q must be CUDA"); + TORCH_CHECK(q.scalar_type() == torch::kBFloat16, "q must be BF16"); + TORCH_CHECK(q.dim() == 4, "q must be (B,H,T,HD)"); + q = q.contiguous(); + int B = q.size(0), H = q.size(1), T = q.size(2), HD = q.size(3); + int rope = (int)rope_dim; + int nope = HD - rope; + TORCH_CHECK(nope > 0 && rope > 0, "invalid rope_dim"); + auto q8 = torch::empty({B, H, T, nope}, q.options().dtype(torch::kUInt8)); + auto qs = torch::empty({B, H, T}, q.options().dtype(torch::kFloat32)); + auto qr = torch::empty({B, H, T, rope}, q.options().dtype(torch::kBFloat16)); + int rows = B * H * T; + quantize_q_fp8_split_kernel<<>>( + reinterpret_cast(q.data_ptr()), + q8.data_ptr(), qs.data_ptr(), + reinterpret_cast<__nv_bfloat16*>(qr.data_ptr()), + rows, HD, nope, rope); + return {q8.view(torch::kFloat8_e4m3fn), qs, qr}; +} + +void gather_mixed_selective_cuda( + torch::Tensor comp_nope_fp8, torch::Tensor comp_nope_scale, torch::Tensor comp_rope, + torch::Tensor swa, torch::Tensor indices, + torch::Tensor out_nope_fp8, torch::Tensor out_nope_scale, torch::Tensor out_rope +) { + TORCH_CHECK(indices.scalar_type() == torch::kInt32, "indices must be int32"); + int K = indices.size(0); + int S = swa.size(0); + int nope = comp_nope_fp8.size(1); + int rope = comp_rope.size(1); + int hd = nope + rope; + if (K > 0) { + dim3 
grid(((nope > rope ? nope : rope) + 255) / 256, K); + copy_comp_rows_kernel<<>>( + comp_nope_fp8.data_ptr(), comp_nope_scale.data_ptr(), + reinterpret_cast(comp_rope.data_ptr()), + indices.data_ptr(), + out_nope_fp8.data_ptr(), out_nope_scale.data_ptr(), + reinterpret_cast<__nv_bfloat16*>(out_rope.data_ptr()), + K, nope, rope); + } + if (S > 0) { + quantize_swa_tail_kernel<<>>( + reinterpret_cast(swa.data_ptr()), + out_nope_fp8.data_ptr(), out_nope_scale.data_ptr(), + reinterpret_cast<__nv_bfloat16*>(out_rope.data_ptr()), + K, S, hd, nope, rope); + } +} + +void gather_mixed_all_cuda( + torch::Tensor comp_nope_fp8, torch::Tensor comp_nope_scale, torch::Tensor comp_rope, + torch::Tensor swa, torch::Tensor out_nope_fp8, torch::Tensor out_nope_scale, torch::Tensor out_rope +) { + int K = comp_nope_fp8.size(0); + int S = swa.size(0); + int nope = comp_nope_fp8.size(1); + int rope = comp_rope.size(1); + int hd = nope + rope; + if (K > 0) { + dim3 grid(((nope > rope ? nope : rope) + 255) / 256, K); + copy_comp_rows_kernel<<>>( + comp_nope_fp8.data_ptr(), comp_nope_scale.data_ptr(), + reinterpret_cast(comp_rope.data_ptr()), + nullptr, + out_nope_fp8.data_ptr(), out_nope_scale.data_ptr(), + reinterpret_cast<__nv_bfloat16*>(out_rope.data_ptr()), + K, nope, rope); + } + if (S > 0) { + quantize_swa_tail_kernel<<>>( + reinterpret_cast(swa.data_ptr()), + out_nope_fp8.data_ptr(), out_nope_scale.data_ptr(), + reinterpret_cast<__nv_bfloat16*>(out_rope.data_ptr()), + K, S, hd, nope, rope); + } +} + +void gather_mixed_swa_only_cuda(torch::Tensor swa, torch::Tensor out_nope_fp8, + torch::Tensor out_nope_scale, torch::Tensor out_rope, + int64_t rope_dim) { + int S = swa.size(0); + int hd = swa.size(1); + int rope = (int)rope_dim; + int nope = hd - rope; + if (S > 0) { + quantize_swa_tail_kernel<<>>( + reinterpret_cast(swa.data_ptr()), + out_nope_fp8.data_ptr(), out_nope_scale.data_ptr(), + reinterpret_cast<__nv_bfloat16*>(out_rope.data_ptr()), + 0, S, hd, nope, rope); + } +} + 
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("quantize_q_fp8_split", &quantize_q_fp8_split_cuda, + "Split Q into FP8_E4M3 noPE + BF16 RoPE"); + m.def("gather_mixed_selective_", &gather_mixed_selective_cuda, + "In-place mixed KV gather for selected compressed rows + SWA tail"); + m.def("gather_mixed_all_", &gather_mixed_all_cuda, + "In-place mixed KV gather for all compressed rows + SWA tail"); + m.def("gather_mixed_swa_only_", &gather_mixed_swa_only_cuda, + "In-place mixed KV gather for SWA-only attention"); +} diff --git a/single_shot_inference.py b/single_shot_inference.py index b9742290..b09dd409 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -519,13 +519,29 @@ class KVCache: self.comp_idx_fp8 = torch.zeros(max_comp, indexer_key_dim, dtype=torch.uint8, device=device) self.comp_idx_scale = torch.zeros(max_comp, dtype=torch.float32, device=device) - # Pre-allocated gather buffer — top_k compressed + SWA window + # Pre-allocated mixed gather buffers. + # CSA needs top_k + SWA; HCA is dense over compressed blocks, so it needs + # max_comp + SWA. These buffers preserve the paper/native storage layout: + # noPE stays FP8_E4M3 + scale, RoPE stays BF16. + if compress_ratio > 4: + self.mixed_gather_cap = max_comp + window_size + elif compress_ratio == 4: + self.mixed_gather_cap = indexer_top_k + window_size + else: + self.mixed_gather_cap = window_size + self.gather_nope_fp8 = torch.zeros(self.mixed_gather_cap, self.nope_dim, dtype=torch.uint8, device=device) + self.gather_nope_scale = torch.zeros(self.mixed_gather_cap, dtype=torch.float32, device=device) + self.gather_rope_bf16 = torch.zeros(self.mixed_gather_cap, rope_dim, dtype=torch.bfloat16, device=device) + + # Legacy BF16 gather buffer kept only for non-B1 experiments; the live + # B1 path below does not materialize noPE KV as global BF16. 
self.gather_buf = torch.zeros(indexer_top_k + window_size, head_dim, dtype=torch.bfloat16, device=device) self.n_comp = 0 self._has_idx = False - # Cache dequant modules (loaded once) + # Cache extension modules (loaded once) self._kv_quant_mod = None + self._fp8_attn_io_mod = None def _get_kv_quant_mod(self): if self._kv_quant_mod is None: @@ -533,6 +549,18 @@ class KVCache: self._kv_quant_mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"]) return self._kv_quant_mod + def _get_fp8_attn_io_mod(self): + if self._fp8_attn_io_mod is None: + from dsv4.kernels.cuda.loader import get_cuda_module + self._fp8_attn_io_mod = get_cuda_module( + "fp8_attention_io", ["fp8_attention_io.cu"], + extra_cuda_cflags=[ + "-gencode=arch=compute_100a,code=sm_100a", + "-O3", "--use_fast_math", "--expt-relaxed-constexpr", + ], + ) + return self._fp8_attn_io_mod + def append_swa(self, kv, pos): """Vectorized SWA append — 2 kernel launches instead of 2T.""" T = kv.shape[0] @@ -605,6 +633,53 @@ class KVCache: self.comp_idx_fp8[:self.n_comp], self.comp_idx_scale[:self.n_comp]) + def gather_mixed_selective(self, indices): + """Gather selected compressed KV + SWA into mixed FP8/BF16 buffers. + + Returns (nope_fp8, nope_scale, rope_bf16), each sliced to total length. + noPE is not dequantized to global BF16. 
+ """ + mod = self._get_fp8_attn_io_mod() + swa_kv, _ = self.get_swa() + idx = indices.int().contiguous() + total = idx.numel() + swa_kv.shape[0] + if total > self.mixed_gather_cap: + raise RuntimeError(f"mixed gather capacity {self.mixed_gather_cap} < requested {total}") + mod.gather_mixed_selective_( + self.comp_nope_fp8, self.comp_nope_scale, self.comp_rope_bf16, + swa_kv, idx, self.gather_nope_fp8, self.gather_nope_scale, self.gather_rope_bf16) + return (self.gather_nope_fp8[:total], + self.gather_nope_scale[:total], + self.gather_rope_bf16[:total]) + + def gather_mixed_all(self): + """Gather all compressed KV + SWA in mixed FP8/BF16 storage for HCA.""" + mod = self._get_fp8_attn_io_mod() + swa_kv, _ = self.get_swa() + n_comp = int(self.n_comp) + total = n_comp + swa_kv.shape[0] + if total > self.mixed_gather_cap: + raise RuntimeError(f"mixed gather capacity {self.mixed_gather_cap} < requested {total}") + mod.gather_mixed_all_( + self.comp_nope_fp8[:n_comp], self.comp_nope_scale[:n_comp], self.comp_rope_bf16[:n_comp], + swa_kv, self.gather_nope_fp8, self.gather_nope_scale, self.gather_rope_bf16) + return (self.gather_nope_fp8[:total], + self.gather_nope_scale[:total], + self.gather_rope_bf16[:total]) + + def gather_mixed_swa_only(self): + """Quantize SWA noPE tail to FP8 and keep SWA RoPE as BF16.""" + mod = self._get_fp8_attn_io_mod() + swa_kv, _ = self.get_swa() + total = swa_kv.shape[0] + if total > self.mixed_gather_cap: + raise RuntimeError(f"mixed gather capacity {self.mixed_gather_cap} < requested {total}") + mod.gather_mixed_swa_only_( + swa_kv, self.gather_nope_fp8, self.gather_nope_scale, self.gather_rope_bf16, self.rope_dim) + return (self.gather_nope_fp8[:total], + self.gather_nope_scale[:total], + self.gather_rope_bf16[:total]) + def get_swa(self): """Return SWA KV and positions as views (no clone).""" if self.swa_len == 0: @@ -648,6 +723,28 @@ def _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w attn_out = 
dsv4_attention(q=q, k=k, v=v, scale=scale, n_comp=0, sink_bias=sink_bias) return attn_out.permute(1, 0, 2) # (T, n_h, hd) + +def _run_production_fmha_mixed(q_heads, kv_nope_fp8, kv_nope_scale, kv_rope_bf16, + n_h, hd, T, seq_len, scale, dev, li, w, pfx, rope_dim): + """B1 storage-native mixed FP8/BF16 decode FMHA. No BF16 KV staging.""" + if T != 1: + raise RuntimeError(f"B1 mixed FP8 FMHA is decode-only (T==1); got T={T}") + from dsv4.kernels.attention.production import dsv4_attention_mixed_fp8_decode + q = q_heads.permute(1, 0, 2).contiguous() # (n_h, 1, hd) + sinks = w.get(f"{pfx}.sinks"); sink_bias = None + if sinks is not None: + sink_bias = sinks.to(device=dev).float().reshape(n_h) + attn_out = dsv4_attention_mixed_fp8_decode( + q=q, + k_nope_fp8=kv_nope_fp8, + k_nope_scale=kv_nope_scale, + k_rope_bf16=kv_rope_bf16, + scale=scale, + sink_bias=sink_bias, + rope_dim=rope_dim, + ) + return attn_out.permute(1, 0, 2) # (T, n_h, hd) + # ===================================================================== # Attention — ALL production kernels # ===================================================================== @@ -739,57 +836,38 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, if indexer is not None and ratio == 4: topk_idx = indexer.forward(q_a, x_normed, kv_cache.comp_idx_kv, positions, layer_idx=li) - # 5. Gather KV — mixed storage: FP8 nope dequant + BF16 rope concat + # 5. Gather KV — B1 storage-native mixed path. + # noPE remains FP8_E4M3 + per-row scale; RoPE remains BF16. + # There is no global FP8->BF16 noPE materialization before FMHA. _pt('gather_start') swa_kv, _swa_pos = kv_cache.get_swa() swa_len = swa_kv.shape[0] - gbuf = kv_cache.gather_buf # (max_len, hd) pre-allocated BF16 if kv_cache.n_comp > 0: if ratio == 4: - # CSA: dequant only top-k entries + # CSA: gather top-k compressed rows + SWA tail without dequantizing noPE. 
assert topk_idx is not None, f"CSA layer {li}: indexer returned no top-k — indexer is broken" tk = topk_idx[0].clamp(0, kv_cache.n_comp - 1).int() - n_tk = tk.shape[0] - # Dequant FP8 nope + gather BF16 rope for top-k - nope_bf16 = kv_cache.comp_nope_selective(tk) # FP8→BF16 (n_tk, 448) - rope_bf16 = kv_cache.comp_rope_selective(tk) # BF16 gather (n_tk, 64) - gbuf[:n_tk, :nope_dim] = nope_bf16 - gbuf[:n_tk, nope_dim:] = rope_bf16 - gbuf[n_tk:n_tk + swa_len] = swa_kv - all_kv = gbuf[:n_tk + swa_len] + kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = kv_cache.gather_mixed_selective(tk) elif ratio > 4: - # HCA: dequant all entries - n_comp = kv_cache.n_comp - nope_bf16 = kv_cache.comp_nope_all # FP8→BF16 (n_comp, 448) - rope_bf16 = kv_cache.comp_rope_all # BF16 (n_comp, 64) - gbuf[:n_comp, :nope_dim] = nope_bf16 - gbuf[:n_comp, nope_dim:] = rope_bf16 - gbuf[n_comp:n_comp + swa_len] = swa_kv - all_kv = gbuf[:n_comp + swa_len] + # HCA: dense over compressed rows, still mixed storage. + kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = kv_cache.gather_mixed_all() else: - gbuf[:swa_len] = swa_kv - all_kv = gbuf[:swa_len] + kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = kv_cache.gather_mixed_swa_only() else: - gbuf[:swa_len] = swa_kv - all_kv = gbuf[:swa_len] - seq_len = all_kv.shape[0] + kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = kv_cache.gather_mixed_swa_only() + seq_len = kv_nope_scale.shape[0] if seq_len == 0: return torch.zeros(T, cfg["hidden_size"], dtype=torch.bfloat16, device=dev), q_a - # 6. Production FMHA + # 6. Production FMHA — B1 mixed FP8/BF16 decode path. 
_pt('fmha_start') if VERBOSE >= 2 and li < 3: - print(f" L{li} FMHA input: T={T} seq_len={seq_len} hd={hd} n_h={n_h} n_comp={kv_cache.n_comp} swa_len={swa_len}", flush=True) - attn_out = _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w, pfx) + print(f" L{li} FMHA mixed input: T={T} seq_len={seq_len} hd={hd} n_h={n_h} n_comp={kv_cache.n_comp} swa_len={swa_len}", flush=True) + attn_out = _run_production_fmha_mixed( + q_heads, kv_nope_fp8, kv_nope_scale, kv_rope_bf16, + n_h, hd, T, seq_len, scale, dev, li, w, pfx, rd) _pt('fmha_end') if VERBOSE >= 2 and li < 3: - # Compare with PyTorch reference - k_exp = all_kv.unsqueeze(0).expand(n_h, -1, -1).contiguous() - v_exp = k_exp.clone() - q_in = q_heads.permute(1, 0, 2) - ref_scores = torch.matmul(q_in, k_exp.transpose(-1, -2)) * scale - ref_attn = torch.matmul(torch.softmax(ref_scores.float(), -1).bfloat16(), v_exp).permute(1, 0, 2) - cos_sim = torch.nn.functional.cosine_similarity(attn_out.flatten().float(), ref_attn.flatten().float(), dim=0).item() - print(f" L{li} FMHA: |prod|={attn_out.abs().max().item():.6f} |ref|={ref_attn.abs().max().item():.6f} cos={cos_sim:.6f}", flush=True) + print(f" L{li} FMHA mixed: |prod|={attn_out.abs().max().item():.6f} (reference disabled: B1 forbids global BF16 KV staging)", flush=True) # 7. Inverse RoPE _pt('inv_rope_start') attn_out = _apply_rope(attn_out, positions, rope_cos, rope_sin, rd, inverse=True)