B1: mixed FP8/BF16 decode FMHA integration

- New: fmha_mixed_fp8_decode.cuh (Blackwell FP8 tensor-core FMHA kernel)
- New: fmha_mixed_fp8_capi.cu (C ABI launcher)
- New: fmha_mixed_fp8_op.py (Python ctypes/nvcc bridge)
- New: fp8_attention_io.cu (Q quantize + mixed KV gather kernels)
- New: fmha_umma_desc.cuh additions (f8f6f4 UMMA + idesc helpers)
- Modified: production.py (dsv4_attention_mixed_fp8_decode API)
- Modified: single_shot_inference.py (B1 gather + FMHA path)
- Modified: __init__.py (export mixed FP8 API)
- New: docs/B1_MIXED_FP8_FMHA.md, FINAL_STRETCH.md

noPE KV stays FP8_E4M3 + per-row scale, RoPE stays BF16.
No global FP8->BF16 KV staging before FMHA.
Decode-only (T==1), specialized HD=512/NOPE=448/ROPE=64.
CUDA compile/runtime validation pending on B200.
2026-06-02 22:53:14 +00:00
parent 2eb4f0886e
commit a9d5e09f4c
11 changed files with 1095 additions and 37 deletions

1
.gitignore vendored

@@ -1,3 +1,4 @@
__pycache__/
*.pyc
*.egg-info/
nvfp4-megamoe-kernel-*.zip

@@ -25,6 +25,9 @@ Goal: native NVFP4 where the math allows, FP8_E4M3 where it doesn't, BF16/FP32 o
- Prevents infinite spin after crash/kill during CUDA kernel compilation
## B1 — FP8_E4M3 FMHA (BIG win; perf + memory + native Blackwell)
> Implementation note from ChatGPT B1 pass: a decode-only mixed FP8/BF16 FMHA path has been added. See `docs/B1_MIXED_FP8_FMHA.md`. CUDA compile/runtime validation still needs to be run on a Blackwell box with `nvcc`.
Today: KV is *stored* mixed (FP8 nope + BF16 rope), then in "5. Gather KV" it's **dequantized to BF16** into `gbuf`, and the FMHA runs in **BF16**. That throws away the FP8 you stored and runs the heaviest kernel at half the tensor-core throughput Blackwell offers.
NVFP4 KV is correctly ruled out — your own `KVCache` docstring shows 4-bit KV values cost ~0.4%/round-trip that compounds fatally over 61 layers. **FP8_E4M3 is the right target**, and you already store the nope dims in it. Plan:

55
docs/B1_MIXED_FP8_FMHA.md Normal file

@@ -0,0 +1,55 @@
# B1 Mixed FP8/BF16 FMHA first pass
Implemented a decode-only DeepSeek-V4 attention path that keeps the cache in the paper/native storage format:
- noPE KV: FP8_E4M3 bytes plus per-row FP32 scale
- RoPE KV: BF16
- Q noPE: quantized BF16 -> FP8_E4M3 immediately before FMHA
- Q RoPE: BF16
The live `forward_attention` path now gathers compressed rows and the SWA tail into mixed buffers and calls `dsv4_attention_mixed_fp8_decode`; it no longer dequantizes noPE KV into `gather_buf` before attention.
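As a reading aid for the storage format above, here is a minimal sketch of how one noPE row (FP8_E4M3 bytes plus a per-row FP32 scale) maps back to real values. It assumes the standard E4M3 encoding without infinities (bias 7, NaN at `0x7F`/`0xFF`, maximum 448); the helper names `fp8_e4m3_to_float` and `dequantize_row` are hypothetical and are not part of this commit.

```python
import math

def fp8_e4m3_to_float(byte: int) -> float:
    """Decode one FP8 E4M3 byte (1 sign, 4 exponent, 3 mantissa bits, bias 7)."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:
        return math.nan                      # the only NaN pattern; no infinities
    if exp == 0:
        return sign * (man / 8.0) * 2.0 ** -6  # subnormal
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)

def dequantize_row(row_bytes, scale: float):
    """Hypothetical helper: value = decoded FP8 byte * per-row scale."""
    return [fp8_e4m3_to_float(b) * scale for b in row_bytes]
```

In the kernel this decode happens only in registers and shared memory; the sketch is only meant to make the byte-plus-scale convention concrete.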
## New files
- `dsv4/kernels/cuda/fp8_attention_io.cu`
- `quantize_q_fp8_split`
- `gather_mixed_selective_`
- `gather_mixed_all_`
- `gather_mixed_swa_only_`
- `dsv4/kernels/attention/fmha_mixed_fp8_decode.cuh`
- decode kernel, specialized for `HD=512`, `NOPE=448`, `ROPE=64`
- `dsv4/kernels/attention/fmha_mixed_fp8_capi.cu`
- C ABI launcher
- `dsv4/kernels/attention/fmha_mixed_fp8_op.py`
- Python ctypes/nvcc bridge
## Modified files
- `dsv4/kernels/attention/fmha_umma_desc.cuh`
- added `.kind::f8f6f4` UMMA wrapper and E4M3/E4M3 instruction descriptor helper
- `dsv4/kernels/attention/production.py`
- added `dsv4_attention_mixed_fp8_decode`
- `dsv4/kernels/attention/__init__.py`
- exported mixed FP8 API
- `single_shot_inference.py`
- added mixed gather buffers/methods to `KVCache`
- changed step 5 gather to preserve FP8 noPE globally
- changed step 6 FMHA to call the mixed FP8 decode path
## Intentional first-pass limits
- Decode only (`T == 1`). The launcher hard-errors for prefill.
- Specialized to DeepSeek-V4 attention dimensions (`512/448/64`).
- noPE QK uses Blackwell FP8 tensor cores; RoPE QK and PV use BF16 tensor cores.
- noPE V is dequantized only inside shared memory immediately before the PV BF16 tensor-core multiply. There is no global BF16 KV staging.
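The split between the FP8 noPE part and the BF16 RoPE part of a logit relies on the per-row scales factoring out of the dot product. The sketch below illustrates that identity in plain Python. It deliberately omits FP8 rounding (values are only rescaled, not rounded to E4M3), and the function names are hypothetical rather than taken from this commit.

```python
E4M3_MAX = 448.0

def row_scale(row):
    """Per-row scale: amax / 448, floored at 1e-8 as in the quantize kernels."""
    amax = max(abs(x) for x in row)
    return max(amax / E4M3_MAX, 1e-8)

def scaled_row(row):
    """Rescale a row into the E4M3 range (rounding to FP8 is omitted here)."""
    s = row_scale(row)
    return [x / s for x in row], s

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def mixed_logit(q_nope, q_rope, k_nope, k_rope):
    """noPE dot product on rescaled values times both scales, plus the RoPE dot product."""
    qn, qs = scaled_row(q_nope)
    kn, ks = scaled_row(k_nope)
    return dot(qn, kn) * qs * ks + dot(q_rope, k_rope)
```

With real FP8 rounding the noPE term is only approximately equal to the unquantized dot product; the scale factoring itself is exact.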
## Validation status
The sandbox used to make this patch does not have `nvcc`, so CUDA compilation/runtime validation was not possible here. Python syntax was checked with:
```bash
python3 -m py_compile single_shot_inference.py \
dsv4/kernels/attention/production.py \
dsv4/kernels/attention/fmha_mixed_fp8_op.py
```

@@ -4,3 +4,4 @@ The live inference path uses dsv4.kernels.attention.production directly.
See production.py for the dsv4_attention function used by single_shot_inference.py.
"""
from dsv4.kernels.attention.production import dsv4_attention
from dsv4.kernels.attention.production import dsv4_attention_mixed_fp8_decode

@@ -0,0 +1,79 @@
#include <cuda.h>
#include <cuda_runtime.h>
#include <cstdint>
#include "fmha_common.cuh"
#include "fmha_umma_desc.cuh"
#include "fmha_mixed_fp8_decode.cuh"
using namespace dsv4::kernels::attention;
extern "C" {
int fmha_mixed_fp8_decode_launch(
const void* q_nope_fp8,
const float* q_nope_scale,
const void* q_rope_bf16,
const void* k_nope_fp8,
const float* k_nope_scale,
const void* k_rope_bf16,
void* o_ptr,
void* lse_ptr,
const float* sink_bias_ptr,
int B, int H, int T, int N, int HD, int NOPE, int ROPE,
int q_nope_head_stride, int q_nope_batch_stride,
int q_scale_head_stride, int q_scale_batch_stride,
int q_rope_head_stride, int q_rope_batch_stride,
int o_head_stride, int o_batch_stride,
int lse_head_stride, int lse_batch_stride,
float scale
) {
if (T != 1 || HD != 512 || NOPE != 448 || ROPE != 64) return -2;
FmhaMixedFp8DecodeParams p;
p.q_nope_fp8 = (const uint8_t*)q_nope_fp8;
p.q_nope_scale = q_nope_scale;
p.q_rope_bf16 = (const bf16_t*)q_rope_bf16;
p.k_nope_fp8 = (const uint8_t*)k_nope_fp8;
p.k_nope_scale = k_nope_scale;
p.k_rope_bf16 = (const bf16_t*)k_rope_bf16;
p.o = (bf16_t*)o_ptr;
p.lse = (float*)lse_ptr;
p.sink_bias = sink_bias_ptr;
p.B = B; p.H = H; p.N = N; p.HD = HD; p.NOPE = NOPE; p.ROPE = ROPE;
p.q_nope_head_stride = q_nope_head_stride;
p.q_nope_batch_stride = q_nope_batch_stride;
p.q_scale_head_stride = q_scale_head_stride;
p.q_scale_batch_stride = q_scale_batch_stride;
p.q_rope_head_stride = q_rope_head_stride;
p.q_rope_batch_stride = q_rope_batch_stride;
p.o_head_stride = o_head_stride;
p.o_batch_stride = o_batch_stride;
p.lse_head_stride = lse_head_stride;
p.lse_batch_stride = lse_batch_stride;
p.scale = scale;
// Dynamic shared memory size for fmha_mixed_fp8_decode_kernel<512,448,64>.
// Must mirror the carve-up order and 128-byte alignment in fmha_mixed_fp8_decode.cuh.
int smem = 0;
smem += 4; smem = (smem + 127) & ~127;
smem += 128 * 32; smem = (smem + 127) & ~127; // sQ8
smem += 128 * 32; smem = (smem + 127) & ~127; // sK8
smem += 128 * 16 * 2; smem = (smem + 127) & ~127; // sQ16
smem += 128 * 16 * 2; smem = (smem + 127) & ~127; // sK16
smem += 128 * 16 * 2; smem = (smem + 127) & ~127; // sPk
smem += 16 * 16 * 2; smem = (smem + 127) & ~127; // sV
smem += 128 * 4; // sLogits
smem += 128 * 4; // sP
smem += 512 * 4; // sOacc
smem += 512 * 2; // sOepi
smem = (smem + 127) & ~127;
cudaError_t err = cudaFuncSetAttribute(fmha_mixed_fp8_decode_kernel<512,448,64>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem);
if (err != cudaSuccess) return (int)err;
dim3 grid(1, H, B);
dim3 block(192);
fmha_mixed_fp8_decode_kernel<512,448,64><<<grid, block, smem>>>(p);
err = cudaGetLastError();
return err == cudaSuccess ? 0 : (int)err;
}
} // extern C
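The launcher's shared-memory byte count has to stay in step with the carve-up in the kernel header. A small host-side sketch of the same arithmetic can serve as a cross-check; it assumes the layout order shown in the launcher (TMEM base slot, sQ8, sK8, sQ16, sK16, sPk, sV, then sLogits, sP, sOacc, sOepi), and the function names are hypothetical.

```python
def align_up(n: int, alignment: int = 128) -> int:
    """Round n up to the next multiple of a power-of-two alignment."""
    return (n + alignment - 1) & ~(alignment - 1)

def mixed_fp8_decode_smem_bytes(hd=512, sk_tile=128, mma_k_f8=32, mma_k_f16=16):
    bf16, f32 = 2, 4                                # element sizes in bytes
    off = align_up(4)                               # sTmemBase (one uint32)
    off = align_up(off + 128 * mma_k_f8)            # sQ8
    off = align_up(off + 128 * mma_k_f8)            # sK8
    off = align_up(off + 128 * mma_k_f16 * bf16)    # sQ16
    off = align_up(off + 128 * mma_k_f16 * bf16)    # sK16
    off = align_up(off + 128 * mma_k_f16 * bf16)    # sPk
    off = align_up(off + 16 * mma_k_f16 * bf16)     # sV
    off += sk_tile * f32                            # sLogits
    off += sk_tile * f32                            # sP
    off += hd * f32                                 # sOacc
    off += hd * bf16                                # sOepi
    return align_up(off)

print(mixed_fp8_decode_smem_bytes())  # 25216
```

For the default dimensions the total is 25216 bytes, which is below the 48 KiB threshold above which the dynamic shared memory attribute must be raised.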

@@ -0,0 +1,374 @@
/**
* DSV4 B1 — mixed FP8/BF16 decode FMHA for DeepSeek-V4 attention KV.
*
* Inputs are the storage-native DSV4 layout:
* Q noPE: FP8_E4M3 + per-row FP32 scale, Q RoPE: BF16
* KV noPE: FP8_E4M3 + per-row FP32 scale, KV RoPE: BF16
*
* This first B1 kernel targets the decode hot path (T == 1) and HD=512,
* NOPE=448, ROPE=64. It removes the global FP8->BF16 KV dequant/gather and
* uses Blackwell tcgen05 tensor cores for:
* - noPE QK: f8f6f4 E4M3 x E4M3 -> FP32
* - RoPE QK: f16 BF16 x BF16 -> FP32
* - PV: f16 BF16 x BF16 -> FP32, with noPE V dequantized only into SMEM
*
* The noPE KV is never materialized as a global BF16 buffer.
*/
#pragma once
#include <cuda_runtime.h>
#include <cuda_fp8.h>
#include <cuda_fp8.hpp>
#include <cstdint>
#include <cmath>
#include "fmha_common.cuh"
#include "fmha_umma_desc.cuh"
namespace dsv4::kernels::attention {
struct FmhaMixedFp8DecodeParams {
const uint8_t* __restrict__ q_nope_fp8; // (B,H,1,NOPE)
const float* __restrict__ q_nope_scale; // (B,H,1)
const bf16_t* __restrict__ q_rope_bf16; // (B,H,1,ROPE)
const uint8_t* __restrict__ k_nope_fp8; // (N,NOPE), MQA shared
const float* __restrict__ k_nope_scale; // (N,)
const bf16_t* __restrict__ k_rope_bf16; // (N,ROPE)
bf16_t* __restrict__ o; // (B,H,1,HD)
float* __restrict__ lse; // (B,H,1), optional
const float* __restrict__ sink_bias; // (B,H), optional
int B, H, N, HD, NOPE, ROPE;
int q_nope_head_stride, q_nope_batch_stride;
int q_scale_head_stride, q_scale_batch_stride;
int q_rope_head_stride, q_rope_batch_stride;
int o_head_stride, o_batch_stride;
int lse_head_stride, lse_batch_stride;
float scale;
};
__device__ __forceinline__ float fp8_e4m3_to_f32(uint8_t byte) {
__nv_fp8_e4m3 v;
*reinterpret_cast<uint8_t*>(&v) = byte;
return static_cast<float>(v);
}
// FP8 canonical K-major layout for tcgen05.mma.kind::f8f6f4.
// Logical matrix shape is (128, 32); each 128 B core matrix holds 8 rows x 16 FP8 columns.
__device__ __forceinline__ int canon_idx_fp8_128x32(int r, int c) {
constexpr int CORES_MN = 16; // 128 / 8
int core_mn = r >> 3;
int core_k = c >> 4; // 16 FP8 values = 16B atom width
int local_r = r & 7;
int local_c = c & 15;
return core_k * CORES_MN * 128 + core_mn * 128 + local_r * 16 + local_c;
}
__device__ __forceinline__ int canon_idx_bf16_128x16(int r, int c) {
constexpr int CORES_MN = 16;
int core_mn = r >> 3;
int core_k = c >> 3;
int local_r = r & 7;
int local_c = c & 7;
return core_k * CORES_MN * 64 + core_mn * 64 + local_r * 8 + local_c;
}
__device__ __forceinline__ int canon_idx_bf16_16x16(int r, int c) {
constexpr int CORES_MN = 2; // 16 / 8
int core_mn = r >> 3;
int core_k = c >> 3;
int local_r = r & 7;
int local_c = c & 7;
return core_k * CORES_MN * 64 + core_mn * 64 + local_r * 8 + local_c;
}
__device__ __forceinline__ bf16_t f32_to_bf16_bits(float x) { return f32_to_bf16(x); }
// Read row 0 of a 128-wide TMEM result. Must be called by a full warp;
// lane 0 receives row 0, lanes 1..31 receive rows 1..31 and are ignored.
__device__ __forceinline__ void read_tmem_row0_128(uint32_t tb, float* out128, bool lane0) {
for (int n = 0; n < 16; n++) {
float tmp[8];
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
: "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
: "r"(tb + n * 8));
asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory");
if (lane0) {
#pragma unroll
for (int c = 0; c < 8; c++) out128[n * 8 + c] = tmp[c];
}
}
}
template<int HD=512, int NOPE=448, int ROPE=64, int SK_TILE=128>
__global__ void __launch_bounds__(192)
fmha_mixed_fp8_decode_kernel(FmhaMixedFp8DecodeParams p) {
static_assert(HD == 512 && NOPE == 448 && ROPE == 64, "B1 first pass is specialized for DSV4 HD=512/NOPE=448/ROPE=64");
constexpr int MMA_K_F8 = 32;
constexpr int MMA_K_F16 = 16;
constexpr int NKT_NOPE = NOPE / MMA_K_F8;
constexpr int NKT_ROPE = ROPE / MMA_K_F16;
constexpr int NKT_PV = SK_TILE / MMA_K_F16;
constexpr int N_SUB = HD / 16;
constexpr int TILE_F8 = 128 * MMA_K_F8; // bytes
constexpr int TILE_F16 = 128 * MMA_K_F16; // bf16 elements
constexpr int V_SUB_SZ = 16 * MMA_K_F16; // bf16 elements
constexpr int TMEM_COLS = 512;
const int head_idx = blockIdx.y;
const int batch_idx = blockIdx.z;
const int tid = threadIdx.x;
const int wid = tid >> 5;
const int lane = tid & 31;
const bool is_mma_warp = (wid == 4);
const bool is_lane0 = (wid == 0 && lane == 0);
const int n_kv_tiles = (p.N + SK_TILE - 1) / SK_TILE;
const uint8_t* q8 = p.q_nope_fp8 + batch_idx * p.q_nope_batch_stride + head_idx * p.q_nope_head_stride;
const float q8_scale = p.q_nope_scale[batch_idx * p.q_scale_batch_stride + head_idx * p.q_scale_head_stride];
const bf16_t* qrope = p.q_rope_bf16 + batch_idx * p.q_rope_batch_stride + head_idx * p.q_rope_head_stride;
bf16_t* out = p.o + batch_idx * p.o_batch_stride + head_idx * p.o_head_stride;
float* lse = p.lse ? p.lse + batch_idx * p.lse_batch_stride + head_idx * p.lse_head_stride : nullptr;
extern __shared__ __align__(128) char sbuf[];
size_t off = 0;
uint32_t* sTmemBase = (uint32_t*)(sbuf + off); off += 4;
off = (off + 127) & ~(size_t)127;
uint8_t* sQ8 = (uint8_t*)(sbuf + off); off += TILE_F8;
off = (off + 127) & ~(size_t)127;
uint8_t* sK8 = (uint8_t*)(sbuf + off); off += TILE_F8;
off = (off + 127) & ~(size_t)127;
bf16_t* sQ16 = (bf16_t*)(sbuf + off); off += TILE_F16 * sizeof(bf16_t);
off = (off + 127) & ~(size_t)127;
bf16_t* sK16 = (bf16_t*)(sbuf + off); off += TILE_F16 * sizeof(bf16_t);
off = (off + 127) & ~(size_t)127;
bf16_t* sPk = (bf16_t*)(sbuf + off); off += TILE_F16 * sizeof(bf16_t);
off = (off + 127) & ~(size_t)127;
bf16_t* sV = (bf16_t*)(sbuf + off); off += V_SUB_SZ * sizeof(bf16_t);
off = (off + 127) & ~(size_t)127;
float* sLogits = (float*)(sbuf + off); off += SK_TILE * sizeof(float);
float* sP = (float*)(sbuf + off); off += SK_TILE * sizeof(float);
float* sOacc = (float*)(sbuf + off); off += HD * sizeof(float);
bf16_t* sOepi = (bf16_t*)(sbuf + off); off += HD * sizeof(bf16_t);
if (is_mma_warp) tmem_alloc((uint32_t)__cvta_generic_to_shared(sTmemBase), TMEM_COLS);
asm volatile("fence.proxy.async.shared::cta;" ::: "memory");
__syncthreads();
uint32_t tb = *sTmemBase;
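// The block has 192 threads but HD is 512, so zero the accumulator with a strided loop.
for (int d = tid; d < HD; d += blockDim.x) sOacc[d] = 0.0f;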
if (tid < SK_TILE) { sLogits[tid] = -INFINITY; sP[tid] = 0.0f; }
__syncthreads();
float running_max = -INFINITY;
float running_sum = 0.0f;
const uint32_t idesc_f8_qk = make_idesc_f8_e4m3(128, 128);
const uint32_t idesc_f16_qk = make_idesc(128, 128);
const uint32_t idesc_pv = make_idesc(128, 16);
for (int kv_tile = 0; kv_tile < n_kv_tiles; kv_tile++) {
const int kv_start = kv_tile * SK_TILE;
const int kv_len = min(SK_TILE, p.N - kv_start);
// ------------------------------------------------------------
// QK noPE: FP8 tensor cores, raw logits in TMEM.
// ------------------------------------------------------------
for (int kt = 0; kt < NKT_NOPE; kt++) {
for (int i = tid; i < TILE_F8; i += blockDim.x) { sQ8[i] = 0; sK8[i] = 0; }
__syncthreads();
for (int c = tid; c < MMA_K_F8; c += blockDim.x) {
int d = kt * MMA_K_F8 + c;
sQ8[canon_idx_fp8_128x32(0, c)] = q8[d];
}
for (int i = tid; i < kv_len * MMA_K_F8; i += blockDim.x) {
int r = i / MMA_K_F8, c = i % MMA_K_F8;
int d = kt * MMA_K_F8 + c;
sK8[canon_idx_fp8_128x32(r, c)] = p.k_nope_fp8[(int64_t)(kv_start + r) * NOPE + d];
}
__syncthreads();
if (is_mma_warp && lane == 0) {
uint64_t dq = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sQ8), 128);
uint64_t dk = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sK8), 128);
umma_ss_f8f6f4(tb, dq, dk, idesc_f8_qk, kt > 0);
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
}
__syncthreads();
}
asm volatile("fence.sc.gpu;" ::: "memory");
__syncthreads();
if (wid == 0) read_tmem_row0_128(tb, sLogits, lane == 0);
__syncthreads();
if (is_lane0) {
#pragma unroll
for (int c = 0; c < SK_TILE; c++) {
if (c < kv_len) {
float ks = p.k_nope_scale[kv_start + c];
sLogits[c] = sLogits[c] * q8_scale * ks;
} else {
sLogits[c] = -INFINITY;
}
}
}
__syncthreads();
// ------------------------------------------------------------
// QK RoPE: BF16 tensor cores, then add to scaled noPE logits.
// ------------------------------------------------------------
for (int kt = 0; kt < NKT_ROPE; kt++) {
for (int i = tid; i < TILE_F16; i += blockDim.x) { sQ16[i] = 0; sK16[i] = 0; }
__syncthreads();
for (int c = tid; c < MMA_K_F16; c += blockDim.x) {
int d = kt * MMA_K_F16 + c;
sQ16[canon_idx_bf16_128x16(0, c)] = qrope[d];
}
for (int i = tid; i < kv_len * MMA_K_F16; i += blockDim.x) {
int r = i / MMA_K_F16, c = i % MMA_K_F16;
int d = kt * MMA_K_F16 + c;
sK16[canon_idx_bf16_128x16(r, c)] = p.k_rope_bf16[(int64_t)(kv_start + r) * ROPE + d];
}
__syncthreads();
if (is_mma_warp && lane == 0) {
uint64_t dq = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sQ16), 128);
uint64_t dk = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sK16), 128);
umma_ss_f16(tb, dq, dk, idesc_f16_qk, kt > 0);
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
}
__syncthreads();
}
asm volatile("fence.sc.gpu;" ::: "memory");
__syncthreads();
// Use sP as a temporary row buffer here; probabilities are formed later.
if (wid == 0) read_tmem_row0_128(tb, sP, lane == 0);
__syncthreads();
if (is_lane0) {
for (int c = 0; c < kv_len; c++) sLogits[c] += sP[c];
}
__syncthreads();
// ------------------------------------------------------------
// Softmax tile probabilities for row 0.
// ------------------------------------------------------------
float tile_max = -INFINITY;
if (is_lane0) {
for (int c = 0; c < kv_len; c++) tile_max = fmaxf(tile_max, sLogits[c] * p.scale);
float tile_sum = 0.0f;
for (int c = 0; c < kv_len; c++) {
float pv = expf(sLogits[c] * p.scale - tile_max);
sP[c] = pv;
tile_sum += pv;
}
for (int c = kv_len; c < SK_TILE; c++) sP[c] = 0.0f;
float new_max = fmaxf(running_max, tile_max);
float rescale_old = (running_max > -INFINITY) ? expf(running_max - new_max) : 0.0f;
for (int d = 0; d < HD; d++) sOacc[d] *= rescale_old;
running_sum = running_sum * rescale_old + tile_sum * expf(tile_max - new_max);
running_max = new_max;
}
__syncthreads();
// ------------------------------------------------------------
// PV: probabilities BF16 x V BF16. noPE V is dequantized into SMEM only.
// ------------------------------------------------------------
for (int n_sub = 0; n_sub < N_SUB; n_sub++) {
int d_base = n_sub * 16;
for (int pv_kt = 0; pv_kt < NKT_PV; pv_kt++) {
const int col_start = pv_kt * MMA_K_F16;
for (int i = tid; i < TILE_F16; i += blockDim.x) sPk[i] = 0;
for (int i = tid; i < V_SUB_SZ; i += blockDim.x) sV[i] = 0;
__syncthreads();
// P matrix: only row 0 non-zero.
for (int c = tid; c < MMA_K_F16; c += blockDim.x) {
int gc = col_start + c;
sPk[canon_idx_bf16_128x16(0, c)] = f32_to_bf16_bits(sP[gc]);
}
// V matrix B: logical (16 K rows, 16 N cols) in BF16 canonical layout.
for (int i = tid; i < 16 * MMA_K_F16; i += blockDim.x) {
int dd = i / MMA_K_F16;
int kk = i % MMA_K_F16;
int row = col_start + kk;
int g_row = kv_start + row;
int d = d_base + dd;
bf16_t vbits = 0;
if (row < kv_len) {
if (d < NOPE) {
uint8_t b = p.k_nope_fp8[(int64_t)g_row * NOPE + d];
float v = fp8_e4m3_to_f32(b) * p.k_nope_scale[g_row];
vbits = f32_to_bf16_bits(v);
} else {
vbits = p.k_rope_bf16[(int64_t)g_row * ROPE + (d - NOPE)];
}
}
// B is (K=16 rows, N=16 cols). Reuse BF16 canonical with rows=16
// by embedding into the first 16 rows of a 128-row tile; MMA_N=16.
sV[canon_idx_bf16_16x16(kk, dd)] = vbits;
}
__syncthreads();
if (is_mma_warp && lane == 0) {
uint64_t dp = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sPk), 128);
uint64_t dv = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sV), 16);
umma_ss_f16(tb + n_sub * 16, dp, dv, idesc_pv, pv_kt > 0);
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
}
__syncthreads();
}
}
asm volatile("fence.sc.gpu;" ::: "memory");
__syncthreads();
// Accumulate PV tile contribution after applying exp(tile_max-new_max).
if (wid == 0) {
float rescale_new = 0.0f;
if (lane == 0) {
// running_max is already the post-tile max. Recompute tile scale.
float tile_max2 = -INFINITY;
for (int c = 0; c < kv_len; c++) tile_max2 = fmaxf(tile_max2, sLogits[c] * p.scale);
rescale_new = expf(tile_max2 - running_max);
}
for (int n = 0; n < HD / 8; n++) {
float tmp[8];
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
: "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
: "r"(tb + n * 8));
asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory");
if (lane == 0) {
#pragma unroll
for (int c = 0; c < 8; c++) sOacc[n * 8 + c] += tmp[c] * rescale_new;
}
}
}
__syncthreads();
}
// Attention sink: denominator-only logit.
if (is_lane0 && p.sink_bias != nullptr) {
float sb = p.sink_bias[batch_idx * p.H + head_idx];
float new_max = fmaxf(running_max, sb);
float rescale_old = (running_max > -INFINITY) ? expf(running_max - new_max) : 0.0f;
for (int d = 0; d < HD; d++) sOacc[d] *= rescale_old;
running_sum = running_sum * rescale_old + expf(sb - new_max);
running_max = new_max;
}
__syncthreads();
if (is_lane0) {
float inv_sum = 1.0f / running_sum;
for (int d = 0; d < HD; d++) sOepi[d] = f32_to_bf16_bits(sOacc[d] * inv_sum);
if (lse) lse[0] = logf(running_sum) + running_max;
}
__syncthreads();
for (int d = tid; d < HD; d += blockDim.x) out[d] = sOepi[d];
__syncthreads();
if (is_mma_warp) tmem_dealloc(tb, TMEM_COLS);
}
} // namespace dsv4::kernels::attention
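The tile loop in the kernel keeps a running max and running sum and rescales the output accumulator whenever a new tile raises the max; the attention sink then contributes to the denominator only. The following pure-Python sketch mirrors that bookkeeping for a single query row and compares it against a direct softmax. It is an illustration under simplified assumptions (logits already scaled, plain float lists, hypothetical function names), not a transcription of the kernel.

```python
import math

def decode_attention_tiled(logits, values, tile=4, sink=None):
    """Single-query attention, processed tile by tile with online rescaling."""
    hd = len(values[0])
    acc = [0.0] * hd
    running_max, running_sum = -math.inf, 0.0
    for start in range(0, len(logits), tile):
        t_logits = logits[start:start + tile]
        t_values = values[start:start + tile]
        tile_max = max(t_logits)
        probs = [math.exp(x - tile_max) for x in t_logits]   # relative to tile max
        new_max = max(running_max, tile_max)
        rescale_old = math.exp(running_max - new_max) if running_max > -math.inf else 0.0
        rescale_new = math.exp(tile_max - new_max)
        acc = [a * rescale_old for a in acc]
        for p, v in zip(probs, t_values):
            for d in range(hd):
                acc[d] += p * rescale_new * v[d]
        running_sum = running_sum * rescale_old + sum(probs) * rescale_new
        running_max = new_max
    if sink is not None:                                     # denominator-only logit
        new_max = max(running_max, sink)
        rescale_old = math.exp(running_max - new_max) if running_max > -math.inf else 0.0
        acc = [a * rescale_old for a in acc]
        running_sum = running_sum * rescale_old + math.exp(sink - new_max)
    return [a / running_sum for a in acc]

def decode_attention_reference(logits, values, sink=None):
    """Direct softmax over all logits; the sink only enlarges the denominator."""
    all_logits = logits + ([sink] if sink is not None else [])
    m = max(all_logits)
    denom = sum(math.exp(x - m) for x in all_logits)
    hd = len(values[0])
    return [sum(math.exp(x - m) * v[d] for x, v in zip(logits, values)) / denom
            for d in range(hd)]
```

A reference of this shape could also serve as a numerical oracle once the kernel can be compiled and run, with tolerances widened to account for FP8 and BF16 rounding.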

@@ -0,0 +1,148 @@
"""DSV4 B1 mixed FP8/BF16 decode FMHA loader.
This path intentionally fails hard: it never falls back to PyTorch or to
BF16 FMHA when the mixed FP8 kernel is requested.
"""
import ctypes
import logging
import os
import subprocess
from typing import Optional
import torch
logger = logging.getLogger(__name__)
KERNEL_DIR = os.path.dirname(os.path.abspath(__file__))
REPO_ROOT = os.path.normpath(os.path.join(KERNEL_DIR, "..", ".."))
SOURCE = os.path.join(KERNEL_DIR, "fmha_mixed_fp8_capi.cu")
BUILD_DIR = os.path.join(REPO_ROOT, "build", "fmha_mixed_fp8")
SO_NAME = "libfmha_mixed_fp8.so"
_lib = None
_lib_lock = False
def _find_nvcc():
import shutil
for c in ["/usr/local/cuda-13.2/bin/nvcc", "/usr/local/cuda/bin/nvcc"]:
if os.path.isfile(c):
return c
nvcc = shutil.which("nvcc")
if nvcc:
return nvcc
raise RuntimeError("nvcc not found")
def _ensure_built():
global _lib, _lib_lock
if _lib is not None:
return _lib
if _lib_lock:
raise RuntimeError("Recursive mixed-FP8 FMHA build")
_lib_lock = True
try:
so_path = os.path.join(BUILD_DIR, SO_NAME)
deps = [
SOURCE,
os.path.join(KERNEL_DIR, "fmha_common.cuh"),
os.path.join(KERNEL_DIR, "fmha_umma_desc.cuh"),
os.path.join(KERNEL_DIR, "fmha_mixed_fp8_decode.cuh"),
]
src_mtime = max(os.path.getmtime(p) for p in deps if os.path.exists(p))
need_build = not os.path.isfile(so_path) or src_mtime > os.path.getmtime(so_path)
if not need_build:
_lib = ctypes.CDLL(so_path)
return _lib
os.makedirs(BUILD_DIR, exist_ok=True)
nvcc = _find_nvcc()
cmd = [
nvcc, "-std=c++20", "-shared", "-Xcompiler", "-fPIC",
"-gencode=arch=compute_100a,code=sm_100a",
"-gencode=arch=compute_100a,code=compute_100a",
f"-I{KERNEL_DIR}", f"-I{REPO_ROOT}",
"-O3", "--use_fast_math", "--expt-relaxed-constexpr",
SOURCE, "-o", so_path, "-lcudart", "-lcuda",
]
logger.info("Building libfmha_mixed_fp8.so (sm_100a)...")
res = subprocess.run(cmd, capture_output=True, text=True)
if res.returncode != 0:
raise RuntimeError(f"mixed FP8 FMHA nvcc failed:\nSTDOUT:\n{res.stdout}\nSTDERR:\n{res.stderr}")
_lib = ctypes.CDLL(so_path)
return _lib
finally:
_lib_lock = False
def _quantize_q_split(q: torch.Tensor, rope_dim: int):
from dsv4.kernels.cuda.loader import get_cuda_module
mod = get_cuda_module("fp8_attention_io", ["fp8_attention_io.cu"],
extra_cuda_cflags=[
"-gencode=arch=compute_100a,code=sm_100a",
"-O3", "--use_fast_math", "--expt-relaxed-constexpr",
])
return mod.quantize_q_fp8_split(q, rope_dim)
def fmha_mixed_fp8_decode_raw(
q: torch.Tensor, # (B,H,1,HD) BF16
k_nope_fp8: torch.Tensor, # (N,NOPE) uint8/float8_e4m3fn
k_nope_scale: torch.Tensor, # (N,) FP32
k_rope_bf16: torch.Tensor, # (N,ROPE) BF16
scale: float,
attn_sink: Optional[torch.Tensor] = None,
rope_dim: int = 64,
):
if q.dim() != 4:
raise RuntimeError("q must be (B,H,T,HD)")
B, H, T, HD = q.shape
if T != 1:
raise RuntimeError("mixed FP8 FMHA supports decode T==1 only")
NOPE = HD - rope_dim
if HD != 512 or NOPE != 448 or rope_dim != 64:
raise RuntimeError(f"mixed FP8 FMHA first pass supports HD=512/NOPE=448/ROPE=64, got {HD}/{NOPE}/{rope_dim}")
q = q.contiguous()
k_nope_fp8 = k_nope_fp8.contiguous()
k_nope_scale = k_nope_scale.contiguous()
k_rope_bf16 = k_rope_bf16.contiguous()
q_nope_fp8, q_nope_scale, q_rope = _quantize_q_split(q, rope_dim)
N = k_nope_fp8.shape[0]
o = torch.empty((B, H, T, HD), dtype=torch.bfloat16, device=q.device)
lse = torch.empty((B, H, T), dtype=torch.float32, device=q.device)
sink_ptr = ctypes.c_void_p(0)
sb = None
if attn_sink is not None:
sb = attn_sink.float().contiguous()
if sb.dim() == 1:
sb = sb.unsqueeze(0).expand(B, -1).contiguous()
if tuple(sb.shape) != (B, H):
raise RuntimeError(f"sink bias shape {tuple(sb.shape)} != {(B,H)}")
sink_ptr = ctypes.c_void_p(sb.data_ptr())
lib = _ensure_built()
ret = lib.fmha_mixed_fp8_decode_launch(
ctypes.c_void_p(q_nope_fp8.data_ptr()),
ctypes.c_void_p(q_nope_scale.data_ptr()),
ctypes.c_void_p(q_rope.data_ptr()),
ctypes.c_void_p(k_nope_fp8.data_ptr()),
ctypes.c_void_p(k_nope_scale.data_ptr()),
ctypes.c_void_p(k_rope_bf16.data_ptr()),
ctypes.c_void_p(o.data_ptr()),
ctypes.c_void_p(lse.data_ptr()),
sink_ptr,
ctypes.c_int(B), ctypes.c_int(H), ctypes.c_int(T), ctypes.c_int(N),
ctypes.c_int(HD), ctypes.c_int(NOPE), ctypes.c_int(rope_dim),
ctypes.c_int(q_nope_fp8.stride(1)), ctypes.c_int(q_nope_fp8.stride(0)),
ctypes.c_int(q_nope_scale.stride(1)), ctypes.c_int(q_nope_scale.stride(0)),
ctypes.c_int(q_rope.stride(1)), ctypes.c_int(q_rope.stride(0)),
ctypes.c_int(o.stride(1)), ctypes.c_int(o.stride(0)),
ctypes.c_int(lse.stride(1)), ctypes.c_int(lse.stride(0)),
ctypes.c_float(scale),
)
if ret != 0:
raise RuntimeError(f"mixed FP8 FMHA launch failed: return code {ret}")
return o, lse
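The loader rebuilds the shared object when any dependency is newer than the existing `.so`. A standalone sketch of that staleness check is shown below. The name `needs_rebuild` is hypothetical, and raising when no dependency exists is an added guard that the loader above does not have.

```python
import os

def needs_rebuild(target: str, deps) -> bool:
    """Return True if `target` is missing or older than the newest existing dep."""
    existing = [p for p in deps if os.path.exists(p)]
    if not existing:
        # Assumption: treat "no sources found" as an error instead of letting max() fail.
        raise FileNotFoundError("none of the build dependencies exist")
    if not os.path.isfile(target):
        return True
    newest = max(os.path.getmtime(p) for p in existing)
    return newest > os.path.getmtime(target)
```

A modification-time check is cheap but coarse: it does not notice changes to compiler flags or the CUDA toolkit version, so a manual clean of the build directory may still be needed after such changes.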

@@ -340,4 +340,31 @@ __device__ __forceinline__ uint32_t make_idesc(int block_m, int block_n) {
| ((uint32_t)(block_m >> 4) << 24); // MMA_M
}
/**
* tcgen05.mma SS for .kind::f8f6f4 with E4M3xE4M3 -> FP32.
* A and B element types are encoded in idesc. For B1 we use E4M3/E4M3.
*/
__device__ __forceinline__ void umma_ss_f8f6f4(
uint32_t tmem_c, uint64_t desc_a, uint64_t desc_b,
uint32_t i_desc, bool accumulate = false
) {
uint32_t scaleC_bits = accumulate ? 0x3F800000u : 0u;
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, p;\n\t"
"}"
:: "r"(tmem_c), "l"(desc_a), "l"(desc_b),
"r"(i_desc), "r"(scaleC_bits)
);
}
/** Instruction descriptor for .kind::f8f6f4 E4M3 x E4M3 -> FP32. */
__device__ __forceinline__ uint32_t make_idesc_f8_e4m3(int block_m, int block_n) {
return (1U << 4) // dtype = F32
| ((uint32_t)(block_n >> 3) << 17) // MMA_N
| ((uint32_t)(block_m >> 4) << 24); // MMA_M
}
} // namespace dsv4::kernels::attention
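To make the bit packing in `make_idesc_f8_e4m3` easier to inspect, the same expression can be evaluated on the host. The sketch below reproduces only the three fields that the helper sets (accumulator dtype at bit 4, `MMA_N >> 3` at bit 17, `MMA_M >> 4` at bit 24). The field widths used in `unpack_idesc` are assumptions made for illustration, and `unpack_idesc` itself is a hypothetical helper.

```python
def make_idesc_f8_e4m3(block_m: int, block_n: int) -> int:
    """Host-side mirror of the device helper's bit packing."""
    return (1 << 4) | ((block_n >> 3) << 17) | ((block_m >> 4) << 24)

def unpack_idesc(idesc: int) -> dict:
    """Hypothetical inverse; the masks (2, 6 and 5 bits) are assumed widths."""
    return {
        "dtype": (idesc >> 4) & 0x3,
        "mma_n": ((idesc >> 17) & 0x3F) << 3,
        "mma_m": ((idesc >> 24) & 0x1F) << 4,
    }

print(hex(make_idesc_f8_e4m3(128, 128)))  # 0x8200010
```

Because the A and B format fields are left at zero by the helper, this packing relies on zero selecting E4M3 for both operands in the `.kind::f8f6f4` descriptor.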

@@ -195,3 +195,41 @@ def dsv4_attention_per_head(
output[q_idx] = o
return output
# ---------------------------------------------------------------------------
# B1: mixed FP8/BF16 DeepSeek-V4 decode attention
# ---------------------------------------------------------------------------
def dsv4_attention_mixed_fp8_decode(
q: torch.Tensor, # (n_q_heads,T,HD) or (B,n_q_heads,T,HD) BF16
k_nope_fp8: torch.Tensor, # (N,NOPE) uint8/float8_e4m3fn
k_nope_scale: torch.Tensor, # (N,) FP32
k_rope_bf16: torch.Tensor, # (N,ROPE) BF16
scale: Optional[float] = None,
sink_bias: Optional[torch.Tensor] = None,
rope_dim: int = 64,
) -> torch.Tensor:
"""B1 production path: storage-native FP8/BF16 KV decode FMHA.
This intentionally has no PyTorch/BF16 fallback. It is the decode-only path
for DeepSeek-V4 attention where noPE KV is already stored as FP8_E4M3 with
per-row FP32 scales and RoPE KV is BF16.
"""
from dsv4.kernels.attention.fmha_mixed_fp8_op import fmha_mixed_fp8_decode_raw
has_batch = q.dim() == 4
if q.dim() == 3:
q4 = q.unsqueeze(0).contiguous()
elif q.dim() == 4:
q4 = q.contiguous()
else:
raise RuntimeError("q must be (H,T,HD) or (B,H,T,HD)")
hd = q4.shape[-1]
scale = (1.0 / math.sqrt(hd)) if scale is None else scale
o4, _lse = fmha_mixed_fp8_decode_raw(
q4, k_nope_fp8, k_nope_scale, k_rope_bf16,
scale, attn_sink=sink_bias, rope_dim=rope_dim,
)
return o4 if has_batch else o4.squeeze(0)
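The default softmax scale for this wrapper is `1 / sqrt(HD)`. A minimal sketch of resolving that default with an explicit `None` check follows, so that a caller-supplied value is never silently replaced; the helper name `resolve_softmax_scale` is hypothetical.

```python
import math
from typing import Optional

def resolve_softmax_scale(scale: Optional[float], head_dim: int) -> float:
    """Use 1/sqrt(head_dim) only when no scale was passed."""
    if scale is None:
        return 1.0 / math.sqrt(head_dim)
    return float(scale)
```

Checking `is None` rather than truthiness matters only for edge cases such as an explicit `0.0`, but it keeps the default and the override paths unambiguous.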

@@ -0,0 +1,254 @@
/**
* DSV4 B1 — FP8 attention input/output preparation kernels.
*
* These are deliberately tiny launch-count reducers for the mixed-precision
* FMHA path:
* - quantize Q noPE dims BF16 -> FP8_E4M3 with a per-(batch,head,row) scale
* - keep Q RoPE dims BF16
* - gather compressed KV noPE bytes/scales and RoPE BF16 without global dequant
* - quantize the SWA noPE tail BF16 -> FP8_E4M3 in the same gather kernel
*
* No PyTorch fallback and no FP8->BF16 global staging for noPE KV.
*/
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp8.h>
#include <cuda_fp8.hpp>
#include <ATen/ATen.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
#include <cstdint>
#include <cfloat>
static constexpr float E4M3_MAX = 448.0f;
__device__ __forceinline__ float bf16_load(const __nv_bfloat16* p) {
return __bfloat162float(*p);
}
__device__ __forceinline__ uint8_t fp8_e4m3_from_f32(float x) {
x = fminf(fmaxf(x, -E4M3_MAX), E4M3_MAX);
__nv_fp8_e4m3 v(x);
return *reinterpret_cast<uint8_t*>(&v);
}
__global__ void quantize_q_fp8_split_kernel(
const __nv_bfloat16* __restrict__ q, // (B,H,T,HD)
uint8_t* __restrict__ q_nope_fp8, // (B,H,T,NOPE)
float* __restrict__ q_nope_scale, // (B,H,T)
__nv_bfloat16* __restrict__ q_rope, // (B,H,T,ROPE)
int rows, int hd, int nope, int rope
) {
int row = blockIdx.x;
if (row >= rows) return;
const __nv_bfloat16* q_row = q + (int64_t)row * hd;
uint8_t* out8 = q_nope_fp8 + (int64_t)row * nope;
__nv_bfloat16* outrope = q_rope + (int64_t)row * rope;
float local_max = 0.0f;
for (int c = threadIdx.x; c < nope; c += blockDim.x) {
local_max = fmaxf(local_max, fabsf(bf16_load(q_row + c)));
}
// warp-level max via shuffles, then a cross-warp pass through shared memory
// (warp_max[8] assumes blockDim.x <= 256)
for (int off = 16; off > 0; off >>= 1)
local_max = fmaxf(local_max, __shfl_down_sync(0xffffffff, local_max, off));
__shared__ float warp_max[8];
if ((threadIdx.x & 31) == 0) warp_max[threadIdx.x >> 5] = local_max;
__syncthreads();
float amax = 0.0f;
if (threadIdx.x < 32) {
amax = (threadIdx.x < (blockDim.x + 31) / 32) ? warp_max[threadIdx.x] : 0.0f;
for (int off = 16; off > 0; off >>= 1)
amax = fmaxf(amax, __shfl_down_sync(0xffffffff, amax, off));
if (threadIdx.x == 0) {
float scale = amax / E4M3_MAX;
if (scale < 1e-8f) scale = 1e-8f;
q_nope_scale[row] = scale;
}
}
__syncthreads();
float scale = q_nope_scale[row];
float inv_scale = 1.0f / scale;
for (int c = threadIdx.x; c < nope; c += blockDim.x) {
out8[c] = fp8_e4m3_from_f32(bf16_load(q_row + c) * inv_scale);
}
for (int c = threadIdx.x; c < rope; c += blockDim.x) {
outrope[c] = q_row[nope + c];
}
}
__global__ void copy_comp_rows_kernel(
const uint8_t* __restrict__ comp_nope_fp8,
const float* __restrict__ comp_nope_scale,
const __nv_bfloat16* __restrict__ comp_rope,
const int32_t* __restrict__ indices, // optional; nullptr => row i
uint8_t* __restrict__ out_nope_fp8,
float* __restrict__ out_nope_scale,
__nv_bfloat16* __restrict__ out_rope,
int K, int nope, int rope
) {
int row = blockIdx.y;
int col = blockIdx.x * blockDim.x + threadIdx.x;
if (row >= K) return;
int src = indices ? indices[row] : row;
if (col < nope) out_nope_fp8[(int64_t)row * nope + col] = comp_nope_fp8[(int64_t)src * nope + col];
if (col < rope) out_rope[(int64_t)row * rope + col] = comp_rope[(int64_t)src * rope + col];
if (blockIdx.x == 0 && threadIdx.x == 0) out_nope_scale[row] = comp_nope_scale[src];
}
__global__ void quantize_swa_tail_kernel(
const __nv_bfloat16* __restrict__ swa, // (S, HD), BF16
uint8_t* __restrict__ out_nope_fp8, // (K+S, NOPE)
float* __restrict__ out_nope_scale, // (K+S)
__nv_bfloat16* __restrict__ out_rope, // (K+S, ROPE)
int K, int S, int hd, int nope, int rope
) {
int s = blockIdx.x;
if (s >= S) return;
int out_row = K + s;
const __nv_bfloat16* src = swa + (int64_t)s * hd;
uint8_t* out8 = out_nope_fp8 + (int64_t)out_row * nope;
__nv_bfloat16* outrope = out_rope + (int64_t)out_row * rope;
float local_max = 0.0f;
for (int c = threadIdx.x; c < nope; c += blockDim.x) {
local_max = fmaxf(local_max, fabsf(bf16_load(src + c)));
}
for (int off = 16; off > 0; off >>= 1)
local_max = fmaxf(local_max, __shfl_down_sync(0xffffffff, local_max, off));
__shared__ float warp_max[8];
if ((threadIdx.x & 31) == 0) warp_max[threadIdx.x >> 5] = local_max;
__syncthreads();
float amax = 0.0f;
if (threadIdx.x < 32) {
amax = (threadIdx.x < (blockDim.x + 31) / 32) ? warp_max[threadIdx.x] : 0.0f;
for (int off = 16; off > 0; off >>= 1)
amax = fmaxf(amax, __shfl_down_sync(0xffffffff, amax, off));
if (threadIdx.x == 0) {
float scale = amax / E4M3_MAX;
if (scale < 1e-8f) scale = 1e-8f;
out_nope_scale[out_row] = scale;
}
}
__syncthreads();
float inv_scale = 1.0f / out_nope_scale[out_row];
for (int c = threadIdx.x; c < nope; c += blockDim.x) {
out8[c] = fp8_e4m3_from_f32(bf16_load(src + c) * inv_scale);
}
for (int c = threadIdx.x; c < rope; c += blockDim.x) {
outrope[c] = src[nope + c];
}
}
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> quantize_q_fp8_split_cuda(
torch::Tensor q, int64_t rope_dim
) {
TORCH_CHECK(q.is_cuda(), "q must be CUDA");
TORCH_CHECK(q.scalar_type() == torch::kBFloat16, "q must be BF16");
TORCH_CHECK(q.dim() == 4, "q must be (B,H,T,HD)");
q = q.contiguous();
int B = q.size(0), H = q.size(1), T = q.size(2), HD = q.size(3);
int rope = (int)rope_dim;
int nope = HD - rope;
TORCH_CHECK(nope > 0 && rope > 0, "invalid rope_dim");
auto q8 = torch::empty({B, H, T, nope}, q.options().dtype(torch::kUInt8));
auto qs = torch::empty({B, H, T}, q.options().dtype(torch::kFloat32));
auto qr = torch::empty({B, H, T, rope}, q.options().dtype(torch::kBFloat16));
int rows = B * H * T;
quantize_q_fp8_split_kernel<<<rows, 256, 0, c10::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<const __nv_bfloat16*>(q.data_ptr<at::BFloat16>()),
q8.data_ptr<uint8_t>(), qs.data_ptr<float>(),
reinterpret_cast<__nv_bfloat16*>(qr.data_ptr<at::BFloat16>()),
rows, HD, nope, rope);
return {q8.view(torch::kFloat8_e4m3fn), qs, qr};
}
void gather_mixed_selective_cuda(
torch::Tensor comp_nope_fp8, torch::Tensor comp_nope_scale, torch::Tensor comp_rope,
torch::Tensor swa, torch::Tensor indices,
torch::Tensor out_nope_fp8, torch::Tensor out_nope_scale, torch::Tensor out_rope
) {
TORCH_CHECK(indices.scalar_type() == torch::kInt32, "indices must be int32");
int K = indices.size(0);
int S = swa.size(0);
int nope = comp_nope_fp8.size(1);
int rope = comp_rope.size(1);
int hd = nope + rope;
if (K > 0) {
dim3 grid(((nope > rope ? nope : rope) + 255) / 256, K);
copy_comp_rows_kernel<<<grid, 256, 0, c10::cuda::getCurrentCUDAStream()>>>(
comp_nope_fp8.data_ptr<uint8_t>(), comp_nope_scale.data_ptr<float>(),
reinterpret_cast<const __nv_bfloat16*>(comp_rope.data_ptr<at::BFloat16>()),
indices.data_ptr<int32_t>(),
out_nope_fp8.data_ptr<uint8_t>(), out_nope_scale.data_ptr<float>(),
reinterpret_cast<__nv_bfloat16*>(out_rope.data_ptr<at::BFloat16>()),
K, nope, rope);
}
if (S > 0) {
quantize_swa_tail_kernel<<<S, 256, 0, c10::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<const __nv_bfloat16*>(swa.data_ptr<at::BFloat16>()),
out_nope_fp8.data_ptr<uint8_t>(), out_nope_scale.data_ptr<float>(),
reinterpret_cast<__nv_bfloat16*>(out_rope.data_ptr<at::BFloat16>()),
K, S, hd, nope, rope);
}
}
void gather_mixed_all_cuda(
torch::Tensor comp_nope_fp8, torch::Tensor comp_nope_scale, torch::Tensor comp_rope,
torch::Tensor swa, torch::Tensor out_nope_fp8, torch::Tensor out_nope_scale, torch::Tensor out_rope
) {
int K = comp_nope_fp8.size(0);
int S = swa.size(0);
int nope = comp_nope_fp8.size(1);
int rope = comp_rope.size(1);
int hd = nope + rope;
if (K > 0) {
dim3 grid(((nope > rope ? nope : rope) + 255) / 256, K);
copy_comp_rows_kernel<<<grid, 256, 0, c10::cuda::getCurrentCUDAStream()>>>(
comp_nope_fp8.data_ptr<uint8_t>(), comp_nope_scale.data_ptr<float>(),
reinterpret_cast<const __nv_bfloat16*>(comp_rope.data_ptr<at::BFloat16>()),
nullptr,
out_nope_fp8.data_ptr<uint8_t>(), out_nope_scale.data_ptr<float>(),
reinterpret_cast<__nv_bfloat16*>(out_rope.data_ptr<at::BFloat16>()),
K, nope, rope);
}
if (S > 0) {
quantize_swa_tail_kernel<<<S, 256, 0, c10::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<const __nv_bfloat16*>(swa.data_ptr<at::BFloat16>()),
out_nope_fp8.data_ptr<uint8_t>(), out_nope_scale.data_ptr<float>(),
reinterpret_cast<__nv_bfloat16*>(out_rope.data_ptr<at::BFloat16>()),
K, S, hd, nope, rope);
}
}
void gather_mixed_swa_only_cuda(torch::Tensor swa, torch::Tensor out_nope_fp8,
torch::Tensor out_nope_scale, torch::Tensor out_rope,
int64_t rope_dim) {
int S = swa.size(0);
int hd = swa.size(1);
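int rope = (int)rope_dim;
int nope = hd - rope;
// Same guard as quantize_q_fp8_split_cuda: reject a rope_dim that leaves no noPE columns.
TORCH_CHECK(nope > 0 && rope > 0, "invalid rope_dim");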
if (S > 0) {
quantize_swa_tail_kernel<<<S, 256, 0, c10::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<const __nv_bfloat16*>(swa.data_ptr<at::BFloat16>()),
out_nope_fp8.data_ptr<uint8_t>(), out_nope_scale.data_ptr<float>(),
reinterpret_cast<__nv_bfloat16*>(out_rope.data_ptr<at::BFloat16>()),
0, S, hd, nope, rope);
}
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("quantize_q_fp8_split", &quantize_q_fp8_split_cuda,
"Split Q into FP8_E4M3 noPE + BF16 RoPE");
m.def("gather_mixed_selective_", &gather_mixed_selective_cuda,
"In-place mixed KV gather for selected compressed rows + SWA tail");
m.def("gather_mixed_all_", &gather_mixed_all_cuda,
"In-place mixed KV gather for all compressed rows + SWA tail");
m.def("gather_mixed_swa_only_", &gather_mixed_swa_only_cuda,
"In-place mixed KV gather for SWA-only attention");
}
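
For review purposes, the per-row quantization performed by quantize_q_fp8_split_kernel and quantize_swa_tail_kernel can be sketched on the host in pure Python. This is a minimal sketch, not the kernel itself: `round_to_e4m3` and `quantize_row` are hypothetical names, and the rounding helper assumes that `fp8_e4m3_from_f32` rounds to nearest-even and saturates at 448, which this diff does not show.

```python
import math

E4M3_MAX = 448.0  # largest finite magnitude in FP8 E4M3 (fn variant)


def round_to_e4m3(x: float) -> float:
    """Hypothetical stand-in for fp8_e4m3_from_f32 (assumed: nearest-even, saturating)."""
    if x == 0.0:
        return 0.0
    mag = min(abs(x), E4M3_MAX)           # saturate instead of overflowing
    _, exp = math.frexp(mag)              # mag = m * 2**exp with 0.5 <= m < 1
    step_exp = max(exp - 4, -9)           # 3 mantissa bits; subnormal step is 2**-9
    step = math.ldexp(1.0, step_exp)
    q = min(round(mag / step) * step, E4M3_MAX)  # round() breaks ties to even
    return math.copysign(q, x)


def quantize_row(row, nope):
    """Split one row into (quantized noPE values, per-row scale, untouched RoPE values)."""
    nope_part, rope_part = row[:nope], row[nope:]
    amax = max((abs(v) for v in nope_part), default=0.0)
    scale = max(amax / E4M3_MAX, 1e-8)    # same floor as the kernels, avoids divide-by-zero
    inv_scale = 1.0 / scale
    q = [round_to_e4m3(v * inv_scale) for v in nope_part]
    return q, scale, list(rope_part)


if __name__ == "__main__":
    q, scale, rope = quantize_row([1.0, -2.0, 0.5, 4.0, 0.25, -0.75], nope=4)
    print(q, scale, rope)
    # Dequantization is q[i] * scale, which the FMHA applies per row.
    print([v * scale for v in q])
```

A reference like this can be used to spot-check the scale and the dequantized noPE values of a few rows against the CUDA output once the kernels have been validated on hardware.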


@@ -519,13 +519,29 @@ class KVCache:
self.comp_idx_fp8 = torch.zeros(max_comp, indexer_key_dim, dtype=torch.uint8, device=device)
self.comp_idx_scale = torch.zeros(max_comp, dtype=torch.float32, device=device)
# Pre-allocated gather buffer — top_k compressed + SWA window
# Pre-allocated mixed gather buffers.
# CSA needs top_k + SWA; HCA is dense over compressed blocks, so it needs
# max_comp + SWA. These buffers preserve the paper/native storage layout:
# noPE stays FP8_E4M3 + scale, RoPE stays BF16.
if compress_ratio > 4:
self.mixed_gather_cap = max_comp + window_size
elif compress_ratio == 4:
self.mixed_gather_cap = indexer_top_k + window_size
else:
self.mixed_gather_cap = window_size
self.gather_nope_fp8 = torch.zeros(self.mixed_gather_cap, self.nope_dim, dtype=torch.uint8, device=device)
self.gather_nope_scale = torch.zeros(self.mixed_gather_cap, dtype=torch.float32, device=device)
self.gather_rope_bf16 = torch.zeros(self.mixed_gather_cap, rope_dim, dtype=torch.bfloat16, device=device)
# Legacy BF16 gather buffer kept only for non-B1 experiments; the live
# B1 path below does not materialize noPE KV as global BF16.
self.gather_buf = torch.zeros(indexer_top_k + window_size, head_dim, dtype=torch.bfloat16, device=device)
self.n_comp = 0
self._has_idx = False
# Cache dequant modules (loaded once)
# Cache extension modules (loaded once)
self._kv_quant_mod = None
self._fp8_attn_io_mod = None
def _get_kv_quant_mod(self):
if self._kv_quant_mod is None:
@@ -533,6 +549,18 @@ class KVCache:
self._kv_quant_mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"])
return self._kv_quant_mod
def _get_fp8_attn_io_mod(self):
if self._fp8_attn_io_mod is None:
from dsv4.kernels.cuda.loader import get_cuda_module
self._fp8_attn_io_mod = get_cuda_module(
"fp8_attention_io", ["fp8_attention_io.cu"],
extra_cuda_cflags=[
"-gencode=arch=compute_100a,code=sm_100a",
"-O3", "--use_fast_math", "--expt-relaxed-constexpr",
],
)
return self._fp8_attn_io_mod
def append_swa(self, kv, pos):
"""Vectorized SWA append — 2 kernel launches instead of 2T."""
T = kv.shape[0]
@@ -605,6 +633,53 @@ class KVCache:
self.comp_idx_fp8[:self.n_comp],
self.comp_idx_scale[:self.n_comp])
def gather_mixed_selective(self, indices):
"""Gather selected compressed KV + SWA into mixed FP8/BF16 buffers.
Returns (nope_fp8, nope_scale, rope_bf16), each sliced to total length.
noPE is not dequantized to global BF16.
"""
mod = self._get_fp8_attn_io_mod()
swa_kv, _ = self.get_swa()
idx = indices.int().contiguous()
total = idx.numel() + swa_kv.shape[0]
if total > self.mixed_gather_cap:
raise RuntimeError(f"mixed gather capacity {self.mixed_gather_cap} < requested {total}")
mod.gather_mixed_selective_(
self.comp_nope_fp8, self.comp_nope_scale, self.comp_rope_bf16,
swa_kv, idx, self.gather_nope_fp8, self.gather_nope_scale, self.gather_rope_bf16)
return (self.gather_nope_fp8[:total],
self.gather_nope_scale[:total],
self.gather_rope_bf16[:total])
def gather_mixed_all(self):
"""Gather all compressed KV + SWA in mixed FP8/BF16 storage for HCA."""
mod = self._get_fp8_attn_io_mod()
swa_kv, _ = self.get_swa()
n_comp = int(self.n_comp)
total = n_comp + swa_kv.shape[0]
if total > self.mixed_gather_cap:
raise RuntimeError(f"mixed gather capacity {self.mixed_gather_cap} < requested {total}")
mod.gather_mixed_all_(
self.comp_nope_fp8[:n_comp], self.comp_nope_scale[:n_comp], self.comp_rope_bf16[:n_comp],
swa_kv, self.gather_nope_fp8, self.gather_nope_scale, self.gather_rope_bf16)
return (self.gather_nope_fp8[:total],
self.gather_nope_scale[:total],
self.gather_rope_bf16[:total])
def gather_mixed_swa_only(self):
"""Quantize SWA noPE tail to FP8 and keep SWA RoPE as BF16."""
mod = self._get_fp8_attn_io_mod()
swa_kv, _ = self.get_swa()
total = swa_kv.shape[0]
if total > self.mixed_gather_cap:
raise RuntimeError(f"mixed gather capacity {self.mixed_gather_cap} < requested {total}")
mod.gather_mixed_swa_only_(
swa_kv, self.gather_nope_fp8, self.gather_nope_scale, self.gather_rope_bf16, self.rope_dim)
return (self.gather_nope_fp8[:total],
self.gather_nope_scale[:total],
self.gather_rope_bf16[:total])
def get_swa(self):
"""Return SWA KV and positions as views (no clone)."""
if self.swa_len == 0:
@@ -648,6 +723,28 @@ def _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w
attn_out = dsv4_attention(q=q, k=k, v=v, scale=scale, n_comp=0, sink_bias=sink_bias)
return attn_out.permute(1, 0, 2) # (T, n_h, hd)
def _run_production_fmha_mixed(q_heads, kv_nope_fp8, kv_nope_scale, kv_rope_bf16,
n_h, hd, T, seq_len, scale, dev, li, w, pfx, rope_dim):
"""B1 storage-native mixed FP8/BF16 decode FMHA. No BF16 KV staging."""
if T != 1:
raise RuntimeError(f"B1 mixed FP8 FMHA is decode-only (T==1); got T={T}")
from dsv4.kernels.attention.production import dsv4_attention_mixed_fp8_decode
q = q_heads.permute(1, 0, 2).contiguous() # (n_h, 1, hd)
sinks = w.get(f"{pfx}.sinks"); sink_bias = None
if sinks is not None:
sink_bias = sinks.to(device=dev).float().reshape(n_h)
attn_out = dsv4_attention_mixed_fp8_decode(
q=q,
k_nope_fp8=kv_nope_fp8,
k_nope_scale=kv_nope_scale,
k_rope_bf16=kv_rope_bf16,
scale=scale,
sink_bias=sink_bias,
rope_dim=rope_dim,
)
return attn_out.permute(1, 0, 2) # (T, n_h, hd)
# =====================================================================
# Attention — ALL production kernels
# =====================================================================
@@ -739,57 +836,38 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
if indexer is not None and ratio == 4:
topk_idx = indexer.forward(q_a, x_normed, kv_cache.comp_idx_kv, positions, layer_idx=li)
# 5. Gather KV — mixed storage: FP8 nope dequant + BF16 rope concat
# 5. Gather KV — B1 storage-native mixed path.
# noPE remains FP8_E4M3 + per-row scale; RoPE remains BF16.
# There is no global FP8->BF16 noPE materialization before FMHA.
_pt('gather_start')
swa_kv, _swa_pos = kv_cache.get_swa()
swa_len = swa_kv.shape[0]
gbuf = kv_cache.gather_buf # (max_len, hd) pre-allocated BF16
if kv_cache.n_comp > 0:
if ratio == 4:
# CSA: dequant only top-k entries
# CSA: gather top-k compressed rows + SWA tail without dequantizing noPE.
assert topk_idx is not None, f"CSA layer {li}: indexer returned no top-k — indexer is broken"
tk = topk_idx[0].clamp(0, kv_cache.n_comp - 1).int()
n_tk = tk.shape[0]
# Dequant FP8 nope + gather BF16 rope for top-k
nope_bf16 = kv_cache.comp_nope_selective(tk) # FP8→BF16 (n_tk, 448)
rope_bf16 = kv_cache.comp_rope_selective(tk) # BF16 gather (n_tk, 64)
gbuf[:n_tk, :nope_dim] = nope_bf16
gbuf[:n_tk, nope_dim:] = rope_bf16
gbuf[n_tk:n_tk + swa_len] = swa_kv
all_kv = gbuf[:n_tk + swa_len]
kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = kv_cache.gather_mixed_selective(tk)
elif ratio > 4:
# HCA: dequant all entries
n_comp = kv_cache.n_comp
nope_bf16 = kv_cache.comp_nope_all # FP8→BF16 (n_comp, 448)
rope_bf16 = kv_cache.comp_rope_all # BF16 (n_comp, 64)
gbuf[:n_comp, :nope_dim] = nope_bf16
gbuf[:n_comp, nope_dim:] = rope_bf16
gbuf[n_comp:n_comp + swa_len] = swa_kv
all_kv = gbuf[:n_comp + swa_len]
# HCA: dense over compressed rows, still mixed storage.
kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = kv_cache.gather_mixed_all()
else:
gbuf[:swa_len] = swa_kv
all_kv = gbuf[:swa_len]
kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = kv_cache.gather_mixed_swa_only()
else:
gbuf[:swa_len] = swa_kv
all_kv = gbuf[:swa_len]
seq_len = all_kv.shape[0]
kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = kv_cache.gather_mixed_swa_only()
seq_len = kv_nope_scale.shape[0]
if seq_len == 0: return torch.zeros(T, cfg["hidden_size"], dtype=torch.bfloat16, device=dev), q_a
# 6. Production FMHA
# 6. Production FMHA — B1 mixed FP8/BF16 decode path.
_pt('fmha_start')
if VERBOSE >= 2 and li < 3:
print(f" L{li} FMHA input: T={T} seq_len={seq_len} hd={hd} n_h={n_h} n_comp={kv_cache.n_comp} swa_len={swa_len}", flush=True)
attn_out = _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w, pfx)
print(f" L{li} FMHA mixed input: T={T} seq_len={seq_len} hd={hd} n_h={n_h} n_comp={kv_cache.n_comp} swa_len={swa_len}", flush=True)
attn_out = _run_production_fmha_mixed(
q_heads, kv_nope_fp8, kv_nope_scale, kv_rope_bf16,
n_h, hd, T, seq_len, scale, dev, li, w, pfx, rd)
_pt('fmha_end')
if VERBOSE >= 2 and li < 3:
# Compare with PyTorch reference
k_exp = all_kv.unsqueeze(0).expand(n_h, -1, -1).contiguous()
v_exp = k_exp.clone()
q_in = q_heads.permute(1, 0, 2)
ref_scores = torch.matmul(q_in, k_exp.transpose(-1, -2)) * scale
ref_attn = torch.matmul(torch.softmax(ref_scores.float(), -1).bfloat16(), v_exp).permute(1, 0, 2)
cos_sim = torch.nn.functional.cosine_similarity(attn_out.flatten().float(), ref_attn.flatten().float(), dim=0).item()
print(f" L{li} FMHA: |prod|={attn_out.abs().max().item():.6f} |ref|={ref_attn.abs().max().item():.6f} cos={cos_sim:.6f}", flush=True)
print(f" L{li} FMHA mixed: |prod|={attn_out.abs().max().item():.6f} (reference disabled: B1 forbids global BF16 KV staging)", flush=True)
# 7. Inverse RoPE
_pt('inv_rope_start')
attn_out = _apply_rope(attn_out, positions, rope_cos, rope_sin, rd, inverse=True)