B1: mixed FP8/BF16 decode FMHA integration
- New: fmha_mixed_fp8_decode.cuh (Blackwell FP8 tensor-core FMHA kernel)
- New: fmha_mixed_fp8_capi.cu (C ABI launcher)
- New: fmha_mixed_fp8_op.py (Python ctypes/nvcc bridge)
- New: fp8_attention_io.cu (Q quantize + mixed KV gather kernels)
- New: fmha_umma_desc.cuh additions (f8f6f4 UMMA + idesc helpers)
- Modified: production.py (dsv4_attention_mixed_fp8_decode API)
- Modified: single_shot_inference.py (B1 gather + FMHA path)
- Modified: __init__.py (export mixed FP8 API)
- New: docs/B1_MIXED_FP8_FMHA.md, FINAL_STRETCH.md

noPE KV stays FP8_E4M3 + per-row scale, RoPE stays BF16.
No global FP8->BF16 KV staging before FMHA.
Decode-only (T==1), specialized HD=512/NOPE=448/ROPE=64.
CUDA compile/runtime validation pending on B200.
1
.gitignore
vendored
@@ -1,3 +1,4 @@
__pycache__/
*.pyc
*.egg-info/
nvfp4-megamoe-kernel-*.zip
@@ -25,6 +25,9 @@ Goal: native NVFP4 where the math allows, FP8_E4M3 where it doesn't, BF16/FP32 o
- Prevents infinite spin after crash/kill during CUDA kernel compilation

## B1: FP8_E4M3 FMHA (BIG win; perf + memory + native Blackwell)

> Implementation note from ChatGPT B1 pass: a decode-only mixed FP8/BF16 FMHA path has been added. See `docs/B1_MIXED_FP8_FMHA.md`. CUDA compile/runtime validation still needs to be run on a Blackwell box with `nvcc`.

Today: KV is *stored* mixed (FP8 nope + BF16 rope), then in "5. Gather KV" it's **dequantized to BF16** into `gbuf`, and the FMHA runs in **BF16**. That throws away the FP8 you stored and runs the heaviest kernel at half the tensor-core throughput Blackwell offers.

NVFP4 KV is correctly ruled out: your own `KVCache` docstring shows 4-bit KV values cost ~0.4%/round-trip that compounds fatally over 61 layers. **FP8_E4M3 is the right target**, and you already store the nope dims in it. Plan:
55
docs/B1_MIXED_FP8_FMHA.md
Normal file
@@ -0,0 +1,55 @@
# B1 Mixed FP8/BF16 FMHA first pass

Implemented a decode-only DeepSeek-V4 attention path that keeps the cache in the paper/native storage format:

- noPE KV: FP8_E4M3 bytes plus per-row FP32 scale
- RoPE KV: BF16
- Q noPE: quantized BF16 -> FP8_E4M3 immediately before FMHA
- Q RoPE: BF16

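
As a rough illustration of the noPE format listed above, the sketch below quantizes one row to E4M3 with a single FP32 scale and dequantizes it again. It is plain Python for readability, not the CUDA code in `fp8_attention_io.cu`. The names `round_to_e4m3`, `quantize_row`, and `dequantize_row` are hypothetical, and the `amax / 448` scale rule is an assumption about how the per-row scale is chosen; this document does not state it.

```python
import math

E4M3_MAX = 448.0  # largest finite FP8 E4M3 magnitude


def round_to_e4m3(x: float) -> float:
    """Round x to the nearest E4M3-representable value (saturating, ties to even)."""
    if x == 0.0:
        return 0.0
    sign = -1.0 if x < 0.0 else 1.0
    a = min(abs(x), E4M3_MAX)
    _, exp = math.frexp(a)      # a = m * 2**exp with 0.5 <= m < 1
    e = max(exp - 1, -6)        # clamp to the smallest normal exponent
    step = 2.0 ** (e - 3)       # 3 mantissa bits -> 8 steps per binade
    return sign * min(round(a / step) * step, E4M3_MAX)


def quantize_row(row):
    """Return (codes, scale); codes are E4M3-representable floats, not raw bytes."""
    amax = max((abs(v) for v in row), default=0.0)
    scale = amax / E4M3_MAX if amax > 0.0 else 1.0  # assumed scale rule
    return [round_to_e4m3(v / scale) for v in row], scale


def dequantize_row(codes, scale):
    """Undo the per-row scaling."""
    return [c * scale for c in codes]
```

With this rule the largest element of a row maps to 448 and survives the round trip almost exactly, while smaller elements carry a rounding error of at most half an E4M3 step times the row scale.
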
The live `forward_attention` path now gathers compressed rows and the SWA tail into mixed buffers and calls `dsv4_attention_mixed_fp8_decode`; it no longer dequantizes noPE KV into `gather_buf` before attention.

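
To put a number on what skipping the BF16 staging saves, here is a back-of-the-envelope count of bytes per gathered KV row for the dimensions this path is specialized for (`NOPE=448`, `ROPE=64`). The helper names are hypothetical, and the figures cover payload only, ignoring alignment and padding.

```python
NOPE, ROPE = 448, 64  # DeepSeek-V4 attention split used by this path
FP8_BYTES, BF16_BYTES, FP32_BYTES = 1, 2, 4


def mixed_row_bytes(nope: int = NOPE, rope: int = ROPE) -> int:
    """FP8 noPE payload + one FP32 per-row scale + BF16 RoPE payload."""
    return nope * FP8_BYTES + FP32_BYTES + rope * BF16_BYTES


def bf16_row_bytes(nope: int = NOPE, rope: int = ROPE) -> int:
    """A fully dequantized BF16 row, as a BF16 gather buffer would hold it."""
    return (nope + rope) * BF16_BYTES
```

For these dimensions a mixed row is 580 bytes against 1024 bytes in BF16, so the gathered buffer carries roughly 57% of the BF16 traffic per row.
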
## New files

- `dsv4/kernels/cuda/fp8_attention_io.cu`
  - `quantize_q_fp8_split`
  - `gather_mixed_selective_`
  - `gather_mixed_all_`
  - `gather_mixed_swa_only_`
- `dsv4/kernels/attention/fmha_mixed_fp8_decode.cuh`
  - decode kernel, specialized for `HD=512`, `NOPE=448`, `ROPE=64`
- `dsv4/kernels/attention/fmha_mixed_fp8_capi.cu`
  - C ABI launcher
- `dsv4/kernels/attention/fmha_mixed_fp8_op.py`
  - Python ctypes/nvcc bridge

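
The ctypes/nvcc bridge in this commit rebuilds the shared library only when the `.so` is missing or a source file is newer than it. A minimal sketch of that staleness check follows, pulled out into a hypothetical helper `needs_rebuild` (the commit inlines the logic rather than defining such a function). Unlike the inlined version, the sketch returns `False` instead of raising when none of the listed dependencies exist.

```python
import os


def needs_rebuild(so_path: str, deps: list[str]) -> bool:
    """True if the shared library is missing or older than any existing dependency."""
    if not os.path.isfile(so_path):
        return True
    existing = [p for p in deps if os.path.exists(p)]
    if not existing:
        return False  # nothing to compare against
    newest_src = max(os.path.getmtime(p) for p in existing)
    return newest_src > os.path.getmtime(so_path)
```

Comparing modification times is cheap but coarse: it will not notice a change in compiler flags, so deleting the build directory remains the reliable way to force a rebuild.
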
## Modified files

- `dsv4/kernels/attention/fmha_umma_desc.cuh`
  - added `.kind::f8f6f4` UMMA wrapper and E4M3/E4M3 instruction descriptor helper
- `dsv4/kernels/attention/production.py`
  - added `dsv4_attention_mixed_fp8_decode`
- `dsv4/kernels/attention/__init__.py`
  - exported mixed FP8 API
- `single_shot_inference.py`
  - added mixed gather buffers/methods to `KVCache`
  - changed step 5 gather to preserve FP8 noPE globally
  - changed step 6 FMHA to call the mixed FP8 decode path

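
The instruction descriptor helper mentioned above packs three fields into a 32-bit word. The sketch below mirrors the bit packing of `make_idesc_f8_e4m3` as it appears in this commit, so the resulting values can be checked on a machine without `nvcc`. Note that the helper leaves the A/B element-format fields at zero; reading that as "E4M3 for both operands" is an assumption that should be confirmed against the PTX ISA before relying on it.

```python
def make_idesc_f8_e4m3(block_m: int, block_n: int) -> int:
    """Python mirror of the CUDA helper's bit packing (fields as named in its comments)."""
    d_format_f32 = 1 << 4             # accumulator dtype field = F32
    n_field = (block_n >> 3) << 17    # MMA_N, stored in units of 8
    m_field = (block_m >> 4) << 24    # MMA_M, stored in units of 16
    return (d_format_f32 | n_field | m_field) & 0xFFFFFFFF
```

For the QK shape used by the kernel, `make_idesc_f8_e4m3(128, 128)` evaluates to `0x08200010`.
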
## Intentional first-pass limits

- Decode only (`T == 1`). The launcher hard-errors for prefill.
- Specialized to DeepSeek-V4 attention dimensions (`512/448/64`).
- noPE QK uses Blackwell FP8 tensor cores; RoPE QK and PV use BF16 tensor cores.
- noPE V is dequantized only inside shared memory immediately before the PV BF16 tensor-core multiply. There is no global BF16 KV staging.

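
The dataflow described in these limits can be written as a scalar reference, which may be useful as an oracle once the kernel runs on hardware. The sketch below follows the kernel's structure: scaled noPE logits plus RoPE logits, a tile-by-tile online softmax, V read from the same rows as K, and an optional sink logit that only enters the denominator. The function name and argument layout are hypothetical, the inputs are plain Python lists of already-decoded values, and the BF16 rounding of probabilities and V that the kernel performs is not modelled.

```python
import math


def decode_attention_reference(q_nope, q_scale, q_rope,
                               k_nope, k_scale, k_rope,
                               scale, sink_bias=None, tile=128):
    """Single-query attention over N mixed rows, processed one tile at a time.

    q_nope and k_nope[i] hold decoded FP8 values as floats; q_scale and
    k_scale[i] are the per-row scales. V is the same row as K (shared KV).
    """
    n = len(k_nope)
    acc = [0.0] * (len(q_nope) + len(q_rope))
    run_max, run_sum = -math.inf, 0.0
    for start in range(0, n, tile):
        rows = range(start, min(start + tile, n))
        logits = []
        for i in rows:
            nope = sum(a * b for a, b in zip(q_nope, k_nope[i])) * q_scale * k_scale[i]
            rope = sum(a * b for a, b in zip(q_rope, k_rope[i]))
            logits.append((nope + rope) * scale)
        # Online softmax: rescale what was accumulated under the old maximum.
        new_max = max(run_max, max(logits))
        rescale = math.exp(run_max - new_max) if run_max > -math.inf else 0.0
        acc = [a * rescale for a in acc]
        run_sum *= rescale
        for i, logit in zip(rows, logits):
            p = math.exp(logit - new_max)
            run_sum += p
            v = [x * k_scale[i] for x in k_nope[i]] + list(k_rope[i])
            acc = [a + p * x for a, x in zip(acc, v)]
        run_max = new_max
    if sink_bias is not None:
        # The sink contributes to the denominator only, never to the output sum.
        new_max = max(run_max, sink_bias)
        rescale = math.exp(run_max - new_max) if run_max > -math.inf else 0.0
        acc = [a * rescale for a in acc]
        run_sum = run_sum * rescale + math.exp(sink_bias - new_max)
    return [a / run_sum for a in acc]
```

Running it with a small `tile` (for example 2) and comparing against a single-pass softmax is a quick way to check that the rescaling bookkeeping is right before comparing against the GPU output.
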
## Validation status

The sandbox used to make this patch does not have `nvcc`, so CUDA compilation/runtime validation was not possible here. Python syntax was checked with:

```bash
python3 -m py_compile single_shot_inference.py \
  dsv4/kernels/attention/production.py \
  dsv4/kernels/attention/fmha_mixed_fp8_op.py
```
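
The same syntax check can be driven from Python, for instance from a pre-commit script. This is a small sketch using the standard `py_compile` module; the helper name `check_syntax` is hypothetical. Like the command above, it only proves the files parse and says nothing about the CUDA side.

```python
import py_compile


def check_syntax(paths):
    """Byte-compile each file; return a list of (path, error message) for failures."""
    failures = []
    for path in paths:
        try:
            py_compile.compile(path, doraise=True)
        except py_compile.PyCompileError as exc:
            failures.append((path, exc.msg))
    return failures
```

An empty return value means every file compiled; a missing file raises `OSError` rather than being reported as a failure.
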
@@ -4,3 +4,4 @@ The live inference path uses dsv4.kernels.attention.production directly.
See production.py for the dsv4_attention function used by single_shot_inference.py.
"""
from dsv4.kernels.attention.production import dsv4_attention
from dsv4.kernels.attention.production import dsv4_attention_mixed_fp8_decode
79
dsv4/kernels/attention/fmha_mixed_fp8_capi.cu
Normal file
@@ -0,0 +1,79 @@
#include <cuda.h>
#include <cuda_runtime.h>
#include <cstdint>
#include "fmha_common.cuh"
#include "fmha_umma_desc.cuh"
#include "fmha_mixed_fp8_decode.cuh"

using namespace dsv4::kernels::attention;

extern "C" {

int fmha_mixed_fp8_decode_launch(
    const void* q_nope_fp8,
    const float* q_nope_scale,
    const void* q_rope_bf16,
    const void* k_nope_fp8,
    const float* k_nope_scale,
    const void* k_rope_bf16,
    void* o_ptr,
    void* lse_ptr,
    const float* sink_bias_ptr,
    int B, int H, int T, int N, int HD, int NOPE, int ROPE,
    int q_nope_head_stride, int q_nope_batch_stride,
    int q_scale_head_stride, int q_scale_batch_stride,
    int q_rope_head_stride, int q_rope_batch_stride,
    int o_head_stride, int o_batch_stride,
    int lse_head_stride, int lse_batch_stride,
    float scale
) {
    // Decode-only, DSV4-specialized path: reject anything else up front.
    if (T != 1 || HD != 512 || NOPE != 448 || ROPE != 64) return -2;

    FmhaMixedFp8DecodeParams p;
    p.q_nope_fp8 = (const uint8_t*)q_nope_fp8;
    p.q_nope_scale = q_nope_scale;
    p.q_rope_bf16 = (const bf16_t*)q_rope_bf16;
    p.k_nope_fp8 = (const uint8_t*)k_nope_fp8;
    p.k_nope_scale = k_nope_scale;
    p.k_rope_bf16 = (const bf16_t*)k_rope_bf16;
    p.o = (bf16_t*)o_ptr;
    p.lse = (float*)lse_ptr;
    p.sink_bias = sink_bias_ptr;
    p.B = B; p.H = H; p.N = N; p.HD = HD; p.NOPE = NOPE; p.ROPE = ROPE;
    p.q_nope_head_stride = q_nope_head_stride;
    p.q_nope_batch_stride = q_nope_batch_stride;
    p.q_scale_head_stride = q_scale_head_stride;
    p.q_scale_batch_stride = q_scale_batch_stride;
    p.q_rope_head_stride = q_rope_head_stride;
    p.q_rope_batch_stride = q_rope_batch_stride;
    p.o_head_stride = o_head_stride;
    p.o_batch_stride = o_batch_stride;
    p.lse_head_stride = lse_head_stride;
    p.lse_batch_stride = lse_batch_stride;
    p.scale = scale;

    // Dynamic shared memory size for fmha_mixed_fp8_decode_kernel<512,448,64>.
    // Keep this mirrored with the header layout and aligned up generously.
    int smem = 0;
    smem += 4;            smem = (smem + 127) & ~127;  // sTmemBase
    smem += 128 * 32;     smem = (smem + 127) & ~127;  // sQ8
    smem += 128 * 32;     smem = (smem + 127) & ~127;  // sK8
    smem += 128 * 16 * 2; smem = (smem + 127) & ~127;  // sQ16
    smem += 128 * 16 * 2; smem = (smem + 127) & ~127;  // sK16
    smem += 128 * 16 * 2; smem = (smem + 127) & ~127;  // sPk
    smem += 16 * 16 * 2;  smem = (smem + 127) & ~127;  // sV
    smem += 128 * 4;                                   // sLogits
    smem += 128 * 4;                                   // sP
    smem += 512 * 4;                                   // sOacc
    smem += 512 * 2;                                   // sOepi
    smem = (smem + 127) & ~127;

    // Surface attribute failures instead of silently launching anyway.
    cudaError_t err = cudaFuncSetAttribute(
        fmha_mixed_fp8_decode_kernel<512,448,64>,
        cudaFuncAttributeMaxDynamicSharedMemorySize, smem);
    if (err != cudaSuccess) return (int)err;

    dim3 grid(1, H, B);
    dim3 block(192);
    fmha_mixed_fp8_decode_kernel<512,448,64><<<grid, block, smem>>>(p);
    err = cudaGetLastError();
    return err == cudaSuccess ? 0 : (int)err;
}

}  // extern C
374
dsv4/kernels/attention/fmha_mixed_fp8_decode.cuh
Normal file
@@ -0,0 +1,374 @@
|
||||
/**
|
||||
* DSV4 B1 — mixed FP8/BF16 decode FMHA for DeepSeek-V4 attention KV.
|
||||
*
|
||||
* Inputs are the storage-native DSV4 layout:
|
||||
* Q noPE: FP8_E4M3 + per-row FP32 scale, Q RoPE: BF16
|
||||
* KV noPE: FP8_E4M3 + per-row FP32 scale, KV RoPE: BF16
|
||||
*
|
||||
* This first B1 kernel targets the decode hot path (T == 1) and HD=512,
|
||||
* NOPE=448, ROPE=64. It removes the global FP8->BF16 KV dequant/gather and
|
||||
* uses Blackwell tcgen05 tensor cores for:
|
||||
* - noPE QK: f8f6f4 E4M3 x E4M3 -> FP32
|
||||
* - RoPE QK: f16 BF16 x BF16 -> FP32
|
||||
* - PV: f16 BF16 x BF16 -> FP32, with noPE V dequantized only into SMEM
|
||||
*
|
||||
* The noPE KV is never materialized as a global BF16 buffer.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda_fp8.hpp>
|
||||
#include <cstdint>
|
||||
#include <cmath>
|
||||
#include "fmha_common.cuh"
|
||||
#include "fmha_umma_desc.cuh"
|
||||
|
||||
namespace dsv4::kernels::attention {
|
||||
|
||||
struct FmhaMixedFp8DecodeParams {
|
||||
const uint8_t* __restrict__ q_nope_fp8; // (B,H,1,NOPE)
|
||||
const float* __restrict__ q_nope_scale; // (B,H,1)
|
||||
const bf16_t* __restrict__ q_rope_bf16; // (B,H,1,ROPE)
|
||||
|
||||
const uint8_t* __restrict__ k_nope_fp8; // (N,NOPE), MQA shared
|
||||
const float* __restrict__ k_nope_scale; // (N,)
|
||||
const bf16_t* __restrict__ k_rope_bf16; // (N,ROPE)
|
||||
|
||||
bf16_t* __restrict__ o; // (B,H,1,HD)
|
||||
float* __restrict__ lse; // (B,H,1), optional
|
||||
const float* __restrict__ sink_bias; // (B,H), optional
|
||||
|
||||
int B, H, N, HD, NOPE, ROPE;
|
||||
int q_nope_head_stride, q_nope_batch_stride;
|
||||
int q_scale_head_stride, q_scale_batch_stride;
|
||||
int q_rope_head_stride, q_rope_batch_stride;
|
||||
int o_head_stride, o_batch_stride;
|
||||
int lse_head_stride, lse_batch_stride;
|
||||
float scale;
|
||||
};
|
||||
|
||||
__device__ __forceinline__ float fp8_e4m3_to_f32(uint8_t byte) {
|
||||
__nv_fp8_e4m3 v;
|
||||
*reinterpret_cast<uint8_t*>(&v) = byte;
|
||||
return static_cast<float>(v);
|
||||
}
|
||||
|
||||
// FP8 canonical K-major layout for tcgen05.mma.kind::f8f6f4.
|
||||
// Logical matrix shape is (128, 32): 8 row groups x 16 FP8 columns per 128B atom.
|
||||
__device__ __forceinline__ int canon_idx_fp8_128x32(int r, int c) {
|
||||
constexpr int CORES_MN = 16; // 128 / 8
|
||||
int core_mn = r >> 3;
|
||||
int core_k = c >> 4; // 16 FP8 values = 16B atom width
|
||||
int local_r = r & 7;
|
||||
int local_c = c & 15;
|
||||
return core_k * CORES_MN * 128 + core_mn * 128 + local_r * 16 + local_c;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int canon_idx_bf16_128x16(int r, int c) {
|
||||
constexpr int CORES_MN = 16;
|
||||
int core_mn = r >> 3;
|
||||
int core_k = c >> 3;
|
||||
int local_r = r & 7;
|
||||
int local_c = c & 7;
|
||||
return core_k * CORES_MN * 64 + core_mn * 64 + local_r * 8 + local_c;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int canon_idx_bf16_16x16(int r, int c) {
|
||||
constexpr int CORES_MN = 2; // 16 / 8
|
||||
int core_mn = r >> 3;
|
||||
int core_k = c >> 3;
|
||||
int local_r = r & 7;
|
||||
int local_c = c & 7;
|
||||
return core_k * CORES_MN * 64 + core_mn * 64 + local_r * 8 + local_c;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ bf16_t f32_to_bf16_bits(float x) { return f32_to_bf16(x); }
|
||||
|
||||
// Read row 0 of a 128-wide TMEM result. Must be called by a full warp;
|
||||
// lane 0 receives row 0, lanes 1..31 receive rows 1..31 and are ignored.
|
||||
__device__ __forceinline__ void read_tmem_row0_128(uint32_t tb, float* out128, bool lane0) {
|
||||
for (int n = 0; n < 16; n++) {
|
||||
float tmp[8];
|
||||
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
|
||||
: "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
|
||||
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
|
||||
: "r"(tb + n * 8));
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory");
|
||||
if (lane0) {
|
||||
#pragma unroll
|
||||
for (int c = 0; c < 8; c++) out128[n * 8 + c] = tmp[c];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<int HD=512, int NOPE=448, int ROPE=64, int SK_TILE=128>
|
||||
__global__ void __launch_bounds__(192)
|
||||
fmha_mixed_fp8_decode_kernel(FmhaMixedFp8DecodeParams p) {
|
||||
static_assert(HD == 512 && NOPE == 448 && ROPE == 64, "B1 first pass is specialized for DSV4 HD=512/NOPE=448/ROPE=64");
|
||||
constexpr int MMA_K_F8 = 32;
|
||||
constexpr int MMA_K_F16 = 16;
|
||||
constexpr int NKT_NOPE = NOPE / MMA_K_F8;
|
||||
constexpr int NKT_ROPE = ROPE / MMA_K_F16;
|
||||
constexpr int NKT_PV = SK_TILE / MMA_K_F16;
|
||||
constexpr int N_SUB = HD / 16;
|
||||
constexpr int TILE_F8 = 128 * MMA_K_F8; // bytes
|
||||
constexpr int TILE_F16 = 128 * MMA_K_F16; // bf16 elements
|
||||
constexpr int V_SUB_SZ = 16 * MMA_K_F16; // bf16 elements
|
||||
constexpr int TMEM_COLS = 512;
|
||||
|
||||
const int head_idx = blockIdx.y;
|
||||
const int batch_idx = blockIdx.z;
|
||||
const int tid = threadIdx.x;
|
||||
const int wid = tid >> 5;
|
||||
const int lane = tid & 31;
|
||||
const bool is_mma_warp = (wid == 4);
|
||||
const bool is_lane0 = (wid == 0 && lane == 0);
|
||||
const int n_kv_tiles = (p.N + SK_TILE - 1) / SK_TILE;
|
||||
|
||||
const uint8_t* q8 = p.q_nope_fp8 + batch_idx * p.q_nope_batch_stride + head_idx * p.q_nope_head_stride;
|
||||
const float q8_scale = p.q_nope_scale[batch_idx * p.q_scale_batch_stride + head_idx * p.q_scale_head_stride];
|
||||
const bf16_t* qrope = p.q_rope_bf16 + batch_idx * p.q_rope_batch_stride + head_idx * p.q_rope_head_stride;
|
||||
bf16_t* out = p.o + batch_idx * p.o_batch_stride + head_idx * p.o_head_stride;
|
||||
float* lse = p.lse ? p.lse + batch_idx * p.lse_batch_stride + head_idx * p.lse_head_stride : nullptr;
|
||||
|
||||
extern __shared__ __align__(128) char sbuf[];
|
||||
size_t off = 0;
|
||||
uint32_t* sTmemBase = (uint32_t*)(sbuf + off); off += 4;
|
||||
off = (off + 127) & ~(size_t)127;
|
||||
uint8_t* sQ8 = (uint8_t*)(sbuf + off); off += TILE_F8;
|
||||
off = (off + 127) & ~(size_t)127;
|
||||
uint8_t* sK8 = (uint8_t*)(sbuf + off); off += TILE_F8;
|
||||
off = (off + 127) & ~(size_t)127;
|
||||
bf16_t* sQ16 = (bf16_t*)(sbuf + off); off += TILE_F16 * sizeof(bf16_t);
|
||||
off = (off + 127) & ~(size_t)127;
|
||||
bf16_t* sK16 = (bf16_t*)(sbuf + off); off += TILE_F16 * sizeof(bf16_t);
|
||||
off = (off + 127) & ~(size_t)127;
|
||||
bf16_t* sPk = (bf16_t*)(sbuf + off); off += TILE_F16 * sizeof(bf16_t);
|
||||
off = (off + 127) & ~(size_t)127;
|
||||
bf16_t* sV = (bf16_t*)(sbuf + off); off += V_SUB_SZ * sizeof(bf16_t);
|
||||
off = (off + 127) & ~(size_t)127;
|
||||
float* sLogits = (float*)(sbuf + off); off += SK_TILE * sizeof(float);
|
||||
float* sP = (float*)(sbuf + off); off += SK_TILE * sizeof(float);
|
||||
float* sOacc = (float*)(sbuf + off); off += HD * sizeof(float);
|
||||
bf16_t* sOepi = (bf16_t*)(sbuf + off); off += HD * sizeof(bf16_t);
|
||||
|
||||
if (is_mma_warp) tmem_alloc((uint32_t)__cvta_generic_to_shared(sTmemBase), TMEM_COLS);
|
||||
asm volatile("fence.proxy.async.shared::cta;" ::: "memory");
|
||||
__syncthreads();
|
||||
uint32_t tb = *sTmemBase;
|
||||
|
||||
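// The block has 192 threads but HD == 512, so stride over every accumulator slot.
for (int d = tid; d < HD; d += blockDim.x) sOacc[d] = 0.0f;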
|
||||
if (tid < SK_TILE) { sLogits[tid] = -INFINITY; sP[tid] = 0.0f; }
|
||||
__syncthreads();
|
||||
|
||||
float running_max = -INFINITY;
|
||||
float running_sum = 0.0f;
|
||||
const uint32_t idesc_f8_qk = make_idesc_f8_e4m3(128, 128);
|
||||
const uint32_t idesc_f16_qk = make_idesc(128, 128);
|
||||
const uint32_t idesc_pv = make_idesc(128, 16);
|
||||
|
||||
for (int kv_tile = 0; kv_tile < n_kv_tiles; kv_tile++) {
|
||||
const int kv_start = kv_tile * SK_TILE;
|
||||
const int kv_len = min(SK_TILE, p.N - kv_start);
|
||||
|
||||
// ------------------------------------------------------------
|
||||
// QK noPE: FP8 tensor cores, raw logits in TMEM.
|
||||
// ------------------------------------------------------------
|
||||
for (int kt = 0; kt < NKT_NOPE; kt++) {
|
||||
for (int i = tid; i < TILE_F8; i += blockDim.x) { sQ8[i] = 0; sK8[i] = 0; }
|
||||
__syncthreads();
|
||||
for (int c = tid; c < MMA_K_F8; c += blockDim.x) {
|
||||
int d = kt * MMA_K_F8 + c;
|
||||
sQ8[canon_idx_fp8_128x32(0, c)] = q8[d];
|
||||
}
|
||||
for (int i = tid; i < kv_len * MMA_K_F8; i += blockDim.x) {
|
||||
int r = i / MMA_K_F8, c = i % MMA_K_F8;
|
||||
int d = kt * MMA_K_F8 + c;
|
||||
sK8[canon_idx_fp8_128x32(r, c)] = p.k_nope_fp8[(int64_t)(kv_start + r) * NOPE + d];
|
||||
}
|
||||
__syncthreads();
|
||||
if (is_mma_warp && lane == 0) {
|
||||
uint64_t dq = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sQ8), 128);
|
||||
uint64_t dk = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sK8), 128);
|
||||
umma_ss_f8f6f4(tb, dq, dk, idesc_f8_qk, kt > 0);
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
asm volatile("fence.sc.gpu;" ::: "memory");
|
||||
__syncthreads();
|
||||
|
||||
if (wid == 0) read_tmem_row0_128(tb, sLogits, lane == 0);
|
||||
__syncthreads();
|
||||
if (is_lane0) {
|
||||
#pragma unroll
|
||||
for (int c = 0; c < SK_TILE; c++) {
|
||||
if (c < kv_len) {
|
||||
float ks = p.k_nope_scale[kv_start + c];
|
||||
sLogits[c] = sLogits[c] * q8_scale * ks;
|
||||
} else {
|
||||
sLogits[c] = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// ------------------------------------------------------------
|
||||
// QK RoPE: BF16 tensor cores, then add to scaled noPE logits.
|
||||
// ------------------------------------------------------------
|
||||
for (int kt = 0; kt < NKT_ROPE; kt++) {
|
||||
for (int i = tid; i < TILE_F16; i += blockDim.x) { sQ16[i] = 0; sK16[i] = 0; }
|
||||
__syncthreads();
|
||||
for (int c = tid; c < MMA_K_F16; c += blockDim.x) {
|
||||
int d = kt * MMA_K_F16 + c;
|
||||
sQ16[canon_idx_bf16_128x16(0, c)] = qrope[d];
|
||||
}
|
||||
for (int i = tid; i < kv_len * MMA_K_F16; i += blockDim.x) {
|
||||
int r = i / MMA_K_F16, c = i % MMA_K_F16;
|
||||
int d = kt * MMA_K_F16 + c;
|
||||
sK16[canon_idx_bf16_128x16(r, c)] = p.k_rope_bf16[(int64_t)(kv_start + r) * ROPE + d];
|
||||
}
|
||||
__syncthreads();
|
||||
if (is_mma_warp && lane == 0) {
|
||||
uint64_t dq = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sQ16), 128);
|
||||
uint64_t dk = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sK16), 128);
|
||||
umma_ss_f16(tb, dq, dk, idesc_f16_qk, kt > 0);
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
asm volatile("fence.sc.gpu;" ::: "memory");
|
||||
__syncthreads();
|
||||
|
||||
// Use sP as a temporary row buffer here; probabilities are formed later.
|
||||
if (wid == 0) read_tmem_row0_128(tb, sP, lane == 0);
|
||||
__syncthreads();
|
||||
if (is_lane0) {
|
||||
for (int c = 0; c < kv_len; c++) sLogits[c] += sP[c];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// ------------------------------------------------------------
|
||||
// Softmax tile probabilities for row 0.
|
||||
// ------------------------------------------------------------
|
||||
float tile_max = -INFINITY;
|
||||
if (is_lane0) {
|
||||
for (int c = 0; c < kv_len; c++) tile_max = fmaxf(tile_max, sLogits[c] * p.scale);
|
||||
float tile_sum = 0.0f;
|
||||
for (int c = 0; c < kv_len; c++) {
|
||||
float pv = expf(sLogits[c] * p.scale - tile_max);
|
||||
sP[c] = pv;
|
||||
tile_sum += pv;
|
||||
}
|
||||
for (int c = kv_len; c < SK_TILE; c++) sP[c] = 0.0f;
|
||||
|
||||
float new_max = fmaxf(running_max, tile_max);
|
||||
float rescale_old = (running_max > -INFINITY) ? expf(running_max - new_max) : 0.0f;
|
||||
for (int d = 0; d < HD; d++) sOacc[d] *= rescale_old;
|
||||
running_sum = running_sum * rescale_old + tile_sum * expf(tile_max - new_max);
|
||||
running_max = new_max;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// ------------------------------------------------------------
|
||||
// PV: probabilities BF16 x V BF16. noPE V is dequantized into SMEM only.
|
||||
// ------------------------------------------------------------
|
||||
for (int n_sub = 0; n_sub < N_SUB; n_sub++) {
|
||||
int d_base = n_sub * 16;
|
||||
for (int pv_kt = 0; pv_kt < NKT_PV; pv_kt++) {
|
||||
const int col_start = pv_kt * MMA_K_F16;
|
||||
for (int i = tid; i < TILE_F16; i += blockDim.x) sPk[i] = 0;
|
||||
for (int i = tid; i < V_SUB_SZ; i += blockDim.x) sV[i] = 0;
|
||||
__syncthreads();
|
||||
|
||||
// P matrix: only row 0 non-zero.
|
||||
for (int c = tid; c < MMA_K_F16; c += blockDim.x) {
|
||||
int gc = col_start + c;
|
||||
sPk[canon_idx_bf16_128x16(0, c)] = f32_to_bf16_bits(sP[gc]);
|
||||
}
|
||||
|
||||
// V matrix B: logical (16 K rows, 16 N cols) in BF16 canonical layout.
|
||||
for (int i = tid; i < 16 * MMA_K_F16; i += blockDim.x) {
|
||||
int dd = i / MMA_K_F16;
|
||||
int kk = i % MMA_K_F16;
|
||||
int row = col_start + kk;
|
||||
int g_row = kv_start + row;
|
||||
int d = d_base + dd;
|
||||
bf16_t vbits = 0;
|
||||
if (row < kv_len) {
|
||||
if (d < NOPE) {
|
||||
uint8_t b = p.k_nope_fp8[(int64_t)g_row * NOPE + d];
|
||||
float v = fp8_e4m3_to_f32(b) * p.k_nope_scale[g_row];
|
||||
vbits = f32_to_bf16_bits(v);
|
||||
} else {
|
||||
vbits = p.k_rope_bf16[(int64_t)g_row * ROPE + (d - NOPE)];
|
||||
}
|
||||
}
|
||||
// B is (K=16 rows, N=16 cols). Reuse BF16 canonical with rows=16
|
||||
// by embedding into the first 16 rows of a 128-row tile; MMA_N=16.
|
||||
sV[canon_idx_bf16_16x16(kk, dd)] = vbits;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (is_mma_warp && lane == 0) {
|
||||
uint64_t dp = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sPk), 128);
|
||||
uint64_t dv = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sV), 16);
|
||||
umma_ss_f16(tb + n_sub * 16, dp, dv, idesc_pv, pv_kt > 0);
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
asm volatile("fence.sc.gpu;" ::: "memory");
|
||||
__syncthreads();
|
||||
|
||||
// Accumulate PV tile contribution after applying exp(tile_max-new_max).
|
||||
if (wid == 0) {
|
||||
float rescale_new = 0.0f;
|
||||
if (lane == 0) {
|
||||
// running_max is already the post-tile max. Recompute tile scale.
|
||||
float tile_max2 = -INFINITY;
|
||||
for (int c = 0; c < kv_len; c++) tile_max2 = fmaxf(tile_max2, sLogits[c] * p.scale);
|
||||
rescale_new = expf(tile_max2 - running_max);
|
||||
}
|
||||
for (int n = 0; n < HD / 8; n++) {
|
||||
float tmp[8];
|
||||
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
|
||||
: "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
|
||||
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
|
||||
: "r"(tb + n * 8));
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory");
|
||||
if (lane == 0) {
|
||||
#pragma unroll
|
||||
for (int c = 0; c < 8; c++) sOacc[n * 8 + c] += tmp[c] * rescale_new;
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Attention sink: denominator-only logit.
|
||||
if (is_lane0 && p.sink_bias != nullptr) {
|
||||
float sb = p.sink_bias[batch_idx * p.H + head_idx];
|
||||
float new_max = fmaxf(running_max, sb);
|
||||
float rescale_old = (running_max > -INFINITY) ? expf(running_max - new_max) : 0.0f;
|
||||
for (int d = 0; d < HD; d++) sOacc[d] *= rescale_old;
|
||||
running_sum = running_sum * rescale_old + expf(sb - new_max);
|
||||
running_max = new_max;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (is_lane0) {
|
||||
float inv_sum = 1.0f / running_sum;
|
||||
for (int d = 0; d < HD; d++) sOepi[d] = f32_to_bf16_bits(sOacc[d] * inv_sum);
|
||||
if (lse) lse[0] = logf(running_sum) + running_max;
|
||||
}
|
||||
__syncthreads();
|
||||
for (int d = tid; d < HD; d += blockDim.x) out[d] = sOepi[d];
|
||||
__syncthreads();
|
||||
|
||||
if (is_mma_warp) tmem_dealloc(tb, TMEM_COLS);
|
||||
}
|
||||
|
||||
} // namespace dsv4::kernels::attention
|
||||
148
dsv4/kernels/attention/fmha_mixed_fp8_op.py
Normal file
@@ -0,0 +1,148 @@
"""DSV4 B1 mixed FP8/BF16 decode FMHA loader.

This path is intentionally hard-error only: it does not fall back to PyTorch or to
BF16 FMHA if the mixed FP8 kernel is requested.
"""
import ctypes
import logging
import os
import subprocess
from typing import Optional

import torch

logger = logging.getLogger(__name__)

KERNEL_DIR = os.path.dirname(os.path.abspath(__file__))
REPO_ROOT = os.path.normpath(os.path.join(KERNEL_DIR, "..", ".."))
SOURCE = os.path.join(KERNEL_DIR, "fmha_mixed_fp8_capi.cu")
BUILD_DIR = os.path.join(REPO_ROOT, "build", "fmha_mixed_fp8")
SO_NAME = "libfmha_mixed_fp8.so"

_lib = None
_lib_lock = False


def _find_nvcc():
    import shutil
    for c in ["/usr/local/cuda-13.2/bin/nvcc", "/usr/local/cuda/bin/nvcc"]:
        if os.path.isfile(c):
            return c
    nvcc = shutil.which("nvcc")
    if nvcc:
        return nvcc
    raise RuntimeError("nvcc not found")


def _ensure_built():
    global _lib, _lib_lock
    if _lib is not None:
        return _lib
    if _lib_lock:
        raise RuntimeError("Recursive mixed-FP8 FMHA build")
    _lib_lock = True
    try:
        so_path = os.path.join(BUILD_DIR, SO_NAME)
        deps = [
            SOURCE,
            os.path.join(KERNEL_DIR, "fmha_common.cuh"),
            os.path.join(KERNEL_DIR, "fmha_umma_desc.cuh"),
            os.path.join(KERNEL_DIR, "fmha_mixed_fp8_decode.cuh"),
        ]
        # Rebuild when the .so is missing or older than any source it depends on.
        src_mtime = max(os.path.getmtime(p) for p in deps if os.path.exists(p))
        need_build = not os.path.isfile(so_path) or src_mtime > os.path.getmtime(so_path)
        if not need_build:
            _lib = ctypes.CDLL(so_path)
            return _lib

        os.makedirs(BUILD_DIR, exist_ok=True)
        nvcc = _find_nvcc()
        cmd = [
            nvcc, "-std=c++20", "-shared", "-Xcompiler", "-fPIC",
            "-gencode=arch=compute_100a,code=sm_100a",
            "-gencode=arch=compute_100a,code=compute_100a",
            f"-I{KERNEL_DIR}", f"-I{REPO_ROOT}",
            "-O3", "--use_fast_math", "--expt-relaxed-constexpr",
            SOURCE, "-o", so_path, "-lcudart", "-lcuda",
        ]
        logger.info("Building libfmha_mixed_fp8.so (sm_100a)...")
        res = subprocess.run(cmd, capture_output=True, text=True)
        if res.returncode != 0:
            raise RuntimeError(f"mixed FP8 FMHA nvcc failed:\nSTDOUT:\n{res.stdout}\nSTDERR:\n{res.stderr}")
        _lib = ctypes.CDLL(so_path)
        return _lib
    finally:
        _lib_lock = False


def _quantize_q_split(q: torch.Tensor, rope_dim: int):
    from dsv4.kernels.cuda.loader import get_cuda_module
    mod = get_cuda_module("fp8_attention_io", ["fp8_attention_io.cu"],
                          extra_cuda_cflags=[
                              "-gencode=arch=compute_100a,code=sm_100a",
                              "-O3", "--use_fast_math", "--expt-relaxed-constexpr",
                          ])
    return mod.quantize_q_fp8_split(q, rope_dim)


def fmha_mixed_fp8_decode_raw(
    q: torch.Tensor,             # (B,H,1,HD) BF16
    k_nope_fp8: torch.Tensor,    # (N,NOPE) uint8/float8_e4m3fn
    k_nope_scale: torch.Tensor,  # (N,) FP32
    k_rope_bf16: torch.Tensor,   # (N,ROPE) BF16
    scale: float,
    attn_sink: Optional[torch.Tensor] = None,
    rope_dim: int = 64,
):
    if q.dim() != 4:
        raise RuntimeError("q must be (B,H,T,HD)")
    B, H, T, HD = q.shape
    if T != 1:
        raise RuntimeError("mixed FP8 FMHA supports decode T==1 only")
    NOPE = HD - rope_dim
    if HD != 512 or NOPE != 448 or rope_dim != 64:
        raise RuntimeError(f"mixed FP8 FMHA first pass supports HD=512/NOPE=448/ROPE=64, got {HD}/{NOPE}/{rope_dim}")

    q = q.contiguous()
    k_nope_fp8 = k_nope_fp8.contiguous()
    k_nope_scale = k_nope_scale.contiguous()
    k_rope_bf16 = k_rope_bf16.contiguous()
    q_nope_fp8, q_nope_scale, q_rope = _quantize_q_split(q, rope_dim)

    N = k_nope_fp8.shape[0]
    o = torch.empty((B, H, T, HD), dtype=torch.bfloat16, device=q.device)
    lse = torch.empty((B, H, T), dtype=torch.float32, device=q.device)

    sink_ptr = ctypes.c_void_p(0)
    sb = None  # keeps the FP32 sink tensor alive for the duration of the launch
    if attn_sink is not None:
        sb = attn_sink.float().contiguous()
        if sb.dim() == 1:
            sb = sb.unsqueeze(0).expand(B, -1).contiguous()
        if tuple(sb.shape) != (B, H):
            raise RuntimeError(f"sink bias shape {tuple(sb.shape)} != {(B,H)}")
        sink_ptr = ctypes.c_void_p(sb.data_ptr())

    lib = _ensure_built()
    ret = lib.fmha_mixed_fp8_decode_launch(
        ctypes.c_void_p(q_nope_fp8.data_ptr()),
        ctypes.c_void_p(q_nope_scale.data_ptr()),
        ctypes.c_void_p(q_rope.data_ptr()),
        ctypes.c_void_p(k_nope_fp8.data_ptr()),
        ctypes.c_void_p(k_nope_scale.data_ptr()),
        ctypes.c_void_p(k_rope_bf16.data_ptr()),
        ctypes.c_void_p(o.data_ptr()),
        ctypes.c_void_p(lse.data_ptr()),
        sink_ptr,
        ctypes.c_int(B), ctypes.c_int(H), ctypes.c_int(T), ctypes.c_int(N),
        ctypes.c_int(HD), ctypes.c_int(NOPE), ctypes.c_int(rope_dim),
        ctypes.c_int(q_nope_fp8.stride(1)), ctypes.c_int(q_nope_fp8.stride(0)),
        ctypes.c_int(q_nope_scale.stride(1)), ctypes.c_int(q_nope_scale.stride(0)),
        ctypes.c_int(q_rope.stride(1)), ctypes.c_int(q_rope.stride(0)),
        ctypes.c_int(o.stride(1)), ctypes.c_int(o.stride(0)),
        ctypes.c_int(lse.stride(1)), ctypes.c_int(lse.stride(0)),
        ctypes.c_float(scale),
    )
    if ret != 0:
        raise RuntimeError(f"mixed FP8 FMHA launch failed: return code {ret}")
    return o, lse
@@ -340,4 +340,31 @@ __device__ __forceinline__ uint32_t make_idesc(int block_m, int block_n) {
         | ((uint32_t)(block_m >> 4) << 24);  // MMA_M
}

/**
 * tcgen05.mma SS for .kind::f8f6f4 with E4M3xE4M3 -> FP32.
 * A and B element types are encoded in idesc. For B1 we use E4M3/E4M3.
 */
__device__ void umma_ss_f8f6f4(
    uint32_t tmem_c, uint64_t desc_a, uint64_t desc_b,
    uint32_t i_desc, bool accumulate = false
) {
    uint32_t scaleC_bits = accumulate ? 0x3F800000u : 0u;
    asm volatile(
        "{\n\t"
        ".reg .pred p;\n\t"
        "setp.ne.b32 p, %4, 0;\n\t"
        "tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, p;\n\t"
        "}"
        :: "r"(tmem_c), "l"(desc_a), "l"(desc_b),
           "r"(i_desc), "r"(scaleC_bits)
    );
}

/** Instruction descriptor for .kind::f8f6f4 E4M3 x E4M3 -> FP32. */
__device__ __forceinline__ uint32_t make_idesc_f8_e4m3(int block_m, int block_n) {
    return (1U << 4)                             // dtype = F32
         | ((uint32_t)(block_n >> 3) << 17)      // MMA_N
         | ((uint32_t)(block_m >> 4) << 24);     // MMA_M
}

}  // namespace dsv4::kernels::attention
@@ -195,3 +195,41 @@ def dsv4_attention_per_head(
         output[q_idx] = o
 
     return output
+
+
+# ---------------------------------------------------------------------------
+# B1: mixed FP8/BF16 DeepSeek-V4 decode attention
+# ---------------------------------------------------------------------------
+
+def dsv4_attention_mixed_fp8_decode(
+    q: torch.Tensor,             # (n_q_heads,T,HD) or (B,n_q_heads,T,HD) BF16
+    k_nope_fp8: torch.Tensor,    # (N,NOPE) uint8/float8_e4m3fn
+    k_nope_scale: torch.Tensor,  # (N,) FP32
+    k_rope_bf16: torch.Tensor,   # (N,ROPE) BF16
+    scale: Optional[float] = None,
+    sink_bias: Optional[torch.Tensor] = None,
+    rope_dim: int = 64,
+) -> torch.Tensor:
+    """B1 production path: storage-native FP8/BF16 KV decode FMHA.
+
+    This intentionally has no PyTorch/BF16 fallback. It is the decode-only path
+    for DeepSeek-V4 attention where noPE KV is already stored as FP8_E4M3 with
+    per-row FP32 scales and RoPE KV is BF16.
+    """
+    from dsv4.kernels.attention.fmha_mixed_fp8_op import fmha_mixed_fp8_decode_raw
+
+    has_batch = q.dim() == 4
+    if q.dim() == 3:
+        q4 = q.unsqueeze(0).contiguous()
+    elif q.dim() == 4:
+        q4 = q.contiguous()
+    else:
+        raise RuntimeError("q must be (H,T,HD) or (B,H,T,HD)")
+
+    hd = q4.shape[-1]
+    scale = (1.0 / math.sqrt(hd)) if scale is None else scale  # keep an explicit 0.0
+    o4, _lse = fmha_mixed_fp8_decode_raw(
+        q4, k_nope_fp8, k_nope_scale, k_rope_bf16,
+        scale, attn_sink=sink_bias, rope_dim=rope_dim,
+    )
+    return o4 if has_batch else o4.squeeze(0)
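The wrapper's two small conventions, promoting a 3-D query to a batch of one and defaulting the softmax scale to 1/sqrt(HD), can be illustrated without torch. The sketch below works on shape tuples only; `normalize_q_shape` and `default_softmax_scale` are hypothetical helper names, and the explicit `None` test is a choice made here so that a caller-supplied 0.0 is not silently replaced.

```python
import math


def normalize_q_shape(shape):
    """Return (4-D shape, has_batch) using the wrapper's 3-D/4-D rule."""
    if len(shape) == 3:
        return (1, *shape), False       # (H,T,HD) -> (1,H,T,HD)
    if len(shape) == 4:
        return tuple(shape), True       # already (B,H,T,HD)
    raise RuntimeError("q must be (H,T,HD) or (B,H,T,HD)")


def default_softmax_scale(head_dim, scale=None):
    """1/sqrt(head_dim) unless the caller passed an explicit scale."""
    return 1.0 / math.sqrt(head_dim) if scale is None else scale


q4_shape, has_batch = normalize_q_shape((64, 1, 512))
print(q4_shape, has_batch)              # (1, 64, 1, 512) False
print(default_softmax_scale(512))       # 1/sqrt(512), roughly 0.0442
```

The head count of 64 is an arbitrary example; HD=512 matches the specialization named in the commit message.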
254
dsv4/kernels/cuda/fp8_attention_io.cu
Normal file
@@ -0,0 +1,254 @@
+/**
+ * DSV4 B1 — FP8 attention input/output preparation kernels.
+ *
+ * These are deliberately tiny launch-count reducers for the mixed-precision
+ * FMHA path:
+ *   - quantize Q noPE dims BF16 -> FP8_E4M3 with a per-(batch,head,row) scale
+ *   - keep Q RoPE dims BF16
+ *   - gather compressed KV noPE bytes/scales and RoPE BF16 without global dequant
+ *   - quantize the SWA noPE tail BF16 -> FP8_E4M3 in the same gather call
+ *
+ * No PyTorch fallback and no FP8->BF16 global staging for noPE KV.
+ */
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cuda_fp8.h>
+#include <cuda_bf16.h>   // __nv_bfloat16 is used directly below
+#include <ATen/ATen.h>
+#include <c10/cuda/CUDAStream.h>
+#include <torch/extension.h>
+#include <cstdint>
+#include <cfloat>
+
+static constexpr float E4M3_MAX = 448.0f;
+
+__device__ __forceinline__ float bf16_load(const __nv_bfloat16* p) {
+    return __bfloat162float(*p);
+}
+
+__device__ __forceinline__ uint8_t fp8_e4m3_from_f32(float x) {
+    x = fminf(fmaxf(x, -E4M3_MAX), E4M3_MAX);
+    __nv_fp8_e4m3 v(x);
+    return *reinterpret_cast<uint8_t*>(&v);
+}
+
+__global__ void quantize_q_fp8_split_kernel(
+    const __nv_bfloat16* __restrict__ q,   // (B,H,T,HD)
+    uint8_t* __restrict__ q_nope_fp8,      // (B,H,T,NOPE)
+    float* __restrict__ q_nope_scale,      // (B,H,T)
+    __nv_bfloat16* __restrict__ q_rope,    // (B,H,T,ROPE)
+    int rows, int hd, int nope, int rope
+) {
+    int row = blockIdx.x;
+    if (row >= rows) return;
+
+    const __nv_bfloat16* q_row = q + (int64_t)row * hd;
+    uint8_t* out8 = q_nope_fp8 + (int64_t)row * nope;
+    __nv_bfloat16* outrope = q_rope + (int64_t)row * rope;
+
+    float local_max = 0.0f;
+    for (int c = threadIdx.x; c < nope; c += blockDim.x) {
+        local_max = fmaxf(local_max, fabsf(bf16_load(q_row + c)));
+    }
+
+    // Block-wide max: warp shuffle, then one partial per warp in shared memory (warp_max[8] assumes blockDim.x <= 256).
+    for (int off = 16; off > 0; off >>= 1)
+        local_max = fmaxf(local_max, __shfl_down_sync(0xffffffff, local_max, off));
+    __shared__ float warp_max[8];
+    if ((threadIdx.x & 31) == 0) warp_max[threadIdx.x >> 5] = local_max;
+    __syncthreads();
+    float amax = 0.0f;
+    if (threadIdx.x < 32) {
+        amax = (threadIdx.x < (blockDim.x + 31) / 32) ? warp_max[threadIdx.x] : 0.0f;
+        for (int off = 16; off > 0; off >>= 1)
+            amax = fmaxf(amax, __shfl_down_sync(0xffffffff, amax, off));
+        if (threadIdx.x == 0) {
+            float scale = amax / E4M3_MAX;
+            if (scale < 1e-8f) scale = 1e-8f;
+            q_nope_scale[row] = scale;
+        }
+    }
+    __syncthreads();
+
+    float scale = q_nope_scale[row];
+    float inv_scale = 1.0f / scale;
+    for (int c = threadIdx.x; c < nope; c += blockDim.x) {
+        out8[c] = fp8_e4m3_from_f32(bf16_load(q_row + c) * inv_scale);
+    }
+    for (int c = threadIdx.x; c < rope; c += blockDim.x) {
+        outrope[c] = q_row[nope + c];
+    }
+}
+
+__global__ void copy_comp_rows_kernel(
+    const uint8_t* __restrict__ comp_nope_fp8,
+    const float* __restrict__ comp_nope_scale,
+    const __nv_bfloat16* __restrict__ comp_rope,
+    const int32_t* __restrict__ indices,  // optional; nullptr => row i
+    uint8_t* __restrict__ out_nope_fp8,
+    float* __restrict__ out_nope_scale,
+    __nv_bfloat16* __restrict__ out_rope,
+    int K, int nope, int rope
+) {
+    int row = blockIdx.y;
+    int col = blockIdx.x * blockDim.x + threadIdx.x;
+    if (row >= K) return;
+    int src = indices ? indices[row] : row;
+    if (col < nope) out_nope_fp8[(int64_t)row * nope + col] = comp_nope_fp8[(int64_t)src * nope + col];
+    if (col < rope) out_rope[(int64_t)row * rope + col] = comp_rope[(int64_t)src * rope + col];
+    if (blockIdx.x == 0 && threadIdx.x == 0) out_nope_scale[row] = comp_nope_scale[src];
+}
+
+__global__ void quantize_swa_tail_kernel(
+    const __nv_bfloat16* __restrict__ swa,   // (S, HD), BF16
+    uint8_t* __restrict__ out_nope_fp8,      // (K+S, NOPE)
+    float* __restrict__ out_nope_scale,      // (K+S)
+    __nv_bfloat16* __restrict__ out_rope,    // (K+S, ROPE)
+    int K, int S, int hd, int nope, int rope
+) {
+    int s = blockIdx.x;
+    if (s >= S) return;
+    int out_row = K + s;
+    const __nv_bfloat16* src = swa + (int64_t)s * hd;
+    uint8_t* out8 = out_nope_fp8 + (int64_t)out_row * nope;
+    __nv_bfloat16* outrope = out_rope + (int64_t)out_row * rope;
+
+    float local_max = 0.0f;
+    for (int c = threadIdx.x; c < nope; c += blockDim.x) {
+        local_max = fmaxf(local_max, fabsf(bf16_load(src + c)));
+    }
+    for (int off = 16; off > 0; off >>= 1)
+        local_max = fmaxf(local_max, __shfl_down_sync(0xffffffff, local_max, off));
+    __shared__ float warp_max[8];
+    if ((threadIdx.x & 31) == 0) warp_max[threadIdx.x >> 5] = local_max;
+    __syncthreads();
+    float amax = 0.0f;
+    if (threadIdx.x < 32) {
+        amax = (threadIdx.x < (blockDim.x + 31) / 32) ? warp_max[threadIdx.x] : 0.0f;
+        for (int off = 16; off > 0; off >>= 1)
+            amax = fmaxf(amax, __shfl_down_sync(0xffffffff, amax, off));
+        if (threadIdx.x == 0) {
+            float scale = amax / E4M3_MAX;
+            if (scale < 1e-8f) scale = 1e-8f;
+            out_nope_scale[out_row] = scale;
+        }
+    }
+    __syncthreads();
+
+    float inv_scale = 1.0f / out_nope_scale[out_row];
+    for (int c = threadIdx.x; c < nope; c += blockDim.x) {
+        out8[c] = fp8_e4m3_from_f32(bf16_load(src + c) * inv_scale);
+    }
+    for (int c = threadIdx.x; c < rope; c += blockDim.x) {
+        outrope[c] = src[nope + c];
+    }
+}
+
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> quantize_q_fp8_split_cuda(
+    torch::Tensor q, int64_t rope_dim
+) {
+    TORCH_CHECK(q.is_cuda(), "q must be CUDA");
+    TORCH_CHECK(q.scalar_type() == torch::kBFloat16, "q must be BF16");
+    TORCH_CHECK(q.dim() == 4, "q must be (B,H,T,HD)");
+    q = q.contiguous();
+    int B = q.size(0), H = q.size(1), T = q.size(2), HD = q.size(3);
+    int rope = (int)rope_dim;
+    int nope = HD - rope;
+    TORCH_CHECK(nope > 0 && rope > 0, "invalid rope_dim");
+    auto q8 = torch::empty({B, H, T, nope}, q.options().dtype(torch::kUInt8));
+    auto qs = torch::empty({B, H, T}, q.options().dtype(torch::kFloat32));
+    auto qr = torch::empty({B, H, T, rope}, q.options().dtype(torch::kBFloat16));
+    int rows = B * H * T;
+    quantize_q_fp8_split_kernel<<<rows, 256, 0, c10::cuda::getCurrentCUDAStream()>>>(
+        reinterpret_cast<const __nv_bfloat16*>(q.data_ptr<at::BFloat16>()),
+        q8.data_ptr<uint8_t>(), qs.data_ptr<float>(),
+        reinterpret_cast<__nv_bfloat16*>(qr.data_ptr<at::BFloat16>()),
+        rows, HD, nope, rope);
+    return {q8.view(torch::kFloat8_e4m3fn), qs, qr};
+}
+
+void gather_mixed_selective_cuda(
+    torch::Tensor comp_nope_fp8, torch::Tensor comp_nope_scale, torch::Tensor comp_rope,
+    torch::Tensor swa, torch::Tensor indices,
+    torch::Tensor out_nope_fp8, torch::Tensor out_nope_scale, torch::Tensor out_rope
+) {
+    TORCH_CHECK(indices.scalar_type() == torch::kInt32, "indices must be int32");
+    int K = indices.size(0);
+    int S = swa.size(0);
+    int nope = comp_nope_fp8.size(1);
+    int rope = comp_rope.size(1);
+    int hd = nope + rope;
+    if (K > 0) {
+        dim3 grid(((nope > rope ? nope : rope) + 255) / 256, K);
+        copy_comp_rows_kernel<<<grid, 256, 0, c10::cuda::getCurrentCUDAStream()>>>(
+            comp_nope_fp8.data_ptr<uint8_t>(), comp_nope_scale.data_ptr<float>(),
+            reinterpret_cast<const __nv_bfloat16*>(comp_rope.data_ptr<at::BFloat16>()),
+            indices.data_ptr<int32_t>(),
+            out_nope_fp8.data_ptr<uint8_t>(), out_nope_scale.data_ptr<float>(),
+            reinterpret_cast<__nv_bfloat16*>(out_rope.data_ptr<at::BFloat16>()),
+            K, nope, rope);
+    }
+    if (S > 0) {
+        quantize_swa_tail_kernel<<<S, 256, 0, c10::cuda::getCurrentCUDAStream()>>>(
+            reinterpret_cast<const __nv_bfloat16*>(swa.data_ptr<at::BFloat16>()),
+            out_nope_fp8.data_ptr<uint8_t>(), out_nope_scale.data_ptr<float>(),
+            reinterpret_cast<__nv_bfloat16*>(out_rope.data_ptr<at::BFloat16>()),
+            K, S, hd, nope, rope);
+    }
+}
+
+void gather_mixed_all_cuda(
+    torch::Tensor comp_nope_fp8, torch::Tensor comp_nope_scale, torch::Tensor comp_rope,
+    torch::Tensor swa, torch::Tensor out_nope_fp8, torch::Tensor out_nope_scale, torch::Tensor out_rope
+) {
+    int K = comp_nope_fp8.size(0);
+    int S = swa.size(0);
+    int nope = comp_nope_fp8.size(1);
+    int rope = comp_rope.size(1);
+    int hd = nope + rope;
+    if (K > 0) {
+        dim3 grid(((nope > rope ? nope : rope) + 255) / 256, K);
+        copy_comp_rows_kernel<<<grid, 256, 0, c10::cuda::getCurrentCUDAStream()>>>(
+            comp_nope_fp8.data_ptr<uint8_t>(), comp_nope_scale.data_ptr<float>(),
+            reinterpret_cast<const __nv_bfloat16*>(comp_rope.data_ptr<at::BFloat16>()),
+            nullptr,
+            out_nope_fp8.data_ptr<uint8_t>(), out_nope_scale.data_ptr<float>(),
+            reinterpret_cast<__nv_bfloat16*>(out_rope.data_ptr<at::BFloat16>()),
+            K, nope, rope);
+    }
+    if (S > 0) {
+        quantize_swa_tail_kernel<<<S, 256, 0, c10::cuda::getCurrentCUDAStream()>>>(
+            reinterpret_cast<const __nv_bfloat16*>(swa.data_ptr<at::BFloat16>()),
+            out_nope_fp8.data_ptr<uint8_t>(), out_nope_scale.data_ptr<float>(),
+            reinterpret_cast<__nv_bfloat16*>(out_rope.data_ptr<at::BFloat16>()),
+            K, S, hd, nope, rope);
+    }
+}
+
+void gather_mixed_swa_only_cuda(torch::Tensor swa, torch::Tensor out_nope_fp8,
+                                torch::Tensor out_nope_scale, torch::Tensor out_rope,
+                                int64_t rope_dim) {
+    int S = swa.size(0);
+    int hd = swa.size(1);
+    int rope = (int)rope_dim;
+    int nope = hd - rope;
+    if (S > 0) {
+        quantize_swa_tail_kernel<<<S, 256, 0, c10::cuda::getCurrentCUDAStream()>>>(
+            reinterpret_cast<const __nv_bfloat16*>(swa.data_ptr<at::BFloat16>()),
+            out_nope_fp8.data_ptr<uint8_t>(), out_nope_scale.data_ptr<float>(),
+            reinterpret_cast<__nv_bfloat16*>(out_rope.data_ptr<at::BFloat16>()),
+            0, S, hd, nope, rope);
+    }
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("quantize_q_fp8_split", &quantize_q_fp8_split_cuda,
+          "Split Q into FP8_E4M3 noPE + BF16 RoPE");
+    m.def("gather_mixed_selective_", &gather_mixed_selective_cuda,
+          "In-place mixed KV gather for selected compressed rows + SWA tail");
+    m.def("gather_mixed_all_", &gather_mixed_all_cuda,
+          "In-place mixed KV gather for all compressed rows + SWA tail");
+    m.def("gather_mixed_swa_only_", &gather_mixed_swa_only_cuda,
+          "In-place mixed KV gather for SWA-only attention");
+}
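The per-row scale rule used by `quantize_q_fp8_split_kernel` and `quantize_swa_tail_kernel` (scale = amax / 448, floored at 1e-8, values multiplied by 1/scale and clamped to the E4M3 range) can be sketched in plain Python. The sketch deliberately stops before the FP8 rounding step, which the kernels delegate to `__nv_fp8_e4m3`; the function names are hypothetical and exist only for this illustration.

```python
E4M3_MAX = 448.0     # largest finite E4M3 magnitude, as in the .cu file
MIN_SCALE = 1e-8     # floor that keeps 1/scale finite for all-zero rows


def row_scale(row):
    """Per-row scale: absolute maximum mapped onto E4M3_MAX."""
    amax = max((abs(x) for x in row), default=0.0)
    return max(amax / E4M3_MAX, MIN_SCALE)


def scale_and_clamp(row):
    """Scale a row into the E4M3 range; FP8 rounding itself is NOT modelled."""
    scale = row_scale(row)
    inv_scale = 1.0 / scale
    scaled = [min(max(x * inv_scale, -E4M3_MAX), E4M3_MAX) for x in row]
    return scaled, scale


def dequantize(scaled, scale):
    """Inverse mapping a consumer applies: value = stored * scale."""
    return [v * scale for v in scaled]


scaled, scale = scale_and_clamp([0.5, -2.0, 1.0])
restored = dequantize(scaled, scale)
print(scale, scaled, restored)
```

With real FP8 storage the restored values differ from the inputs by the E4M3 rounding error; here they match up to float rounding because that step is omitted.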
@@ -519,13 +519,29 @@ class KVCache:
         self.comp_idx_fp8 = torch.zeros(max_comp, indexer_key_dim, dtype=torch.uint8, device=device)
         self.comp_idx_scale = torch.zeros(max_comp, dtype=torch.float32, device=device)
 
-        # Pre-allocated gather buffer — top_k compressed + SWA window
+        # Pre-allocated mixed gather buffers.
+        # CSA needs top_k + SWA; HCA is dense over compressed blocks, so it needs
+        # max_comp + SWA. These buffers preserve the paper/native storage layout:
+        # noPE stays FP8_E4M3 + scale, RoPE stays BF16.
+        if compress_ratio > 4:
+            self.mixed_gather_cap = max_comp + window_size
+        elif compress_ratio == 4:
+            self.mixed_gather_cap = indexer_top_k + window_size
+        else:
+            self.mixed_gather_cap = window_size
+        self.gather_nope_fp8 = torch.zeros(self.mixed_gather_cap, self.nope_dim, dtype=torch.uint8, device=device)
+        self.gather_nope_scale = torch.zeros(self.mixed_gather_cap, dtype=torch.float32, device=device)
+        self.gather_rope_bf16 = torch.zeros(self.mixed_gather_cap, rope_dim, dtype=torch.bfloat16, device=device)
+
+        # Legacy BF16 gather buffer kept only for non-B1 experiments; the live
+        # B1 path below does not materialize noPE KV as global BF16.
         self.gather_buf = torch.zeros(indexer_top_k + window_size, head_dim, dtype=torch.bfloat16, device=device)
         self.n_comp = 0
         self._has_idx = False
 
-        # Cache dequant modules (loaded once)
+        # Cache extension modules (loaded once)
         self._kv_quant_mod = None
+        self._fp8_attn_io_mod = None
 
     def _get_kv_quant_mod(self):
         if self._kv_quant_mod is None:
@@ -533,6 +549,18 @@ class KVCache:
             self._kv_quant_mod = get_cuda_module("kv_quantize", ["kv_quantize.cu"])
         return self._kv_quant_mod
 
+    def _get_fp8_attn_io_mod(self):
+        if self._fp8_attn_io_mod is None:
+            from dsv4.kernels.cuda.loader import get_cuda_module
+            self._fp8_attn_io_mod = get_cuda_module(
+                "fp8_attention_io", ["fp8_attention_io.cu"],
+                extra_cuda_cflags=[
+                    "-gencode=arch=compute_100a,code=sm_100a",
+                    "-O3", "--use_fast_math", "--expt-relaxed-constexpr",
+                ],
+            )
+        return self._fp8_attn_io_mod
+
     def append_swa(self, kv, pos):
         """Vectorized SWA append — 2 kernel launches instead of 2T."""
         T = kv.shape[0]
@@ -605,6 +633,53 @@ class KVCache:
             self.comp_idx_fp8[:self.n_comp],
             self.comp_idx_scale[:self.n_comp])
 
+    def gather_mixed_selective(self, indices):
+        """Gather selected compressed KV + SWA into mixed FP8/BF16 buffers.
+
+        Returns (nope_fp8, nope_scale, rope_bf16), each sliced to total length.
+        noPE is not dequantized to global BF16.
+        """
+        mod = self._get_fp8_attn_io_mod()
+        swa_kv, _ = self.get_swa()
+        idx = indices.int().contiguous()
+        total = idx.numel() + swa_kv.shape[0]
+        if total > self.mixed_gather_cap:
+            raise RuntimeError(f"mixed gather capacity {self.mixed_gather_cap} < requested {total}")
+        mod.gather_mixed_selective_(
+            self.comp_nope_fp8, self.comp_nope_scale, self.comp_rope_bf16,
+            swa_kv, idx, self.gather_nope_fp8, self.gather_nope_scale, self.gather_rope_bf16)
+        return (self.gather_nope_fp8[:total],
+                self.gather_nope_scale[:total],
+                self.gather_rope_bf16[:total])
+
+    def gather_mixed_all(self):
+        """Gather all compressed KV + SWA in mixed FP8/BF16 storage for HCA."""
+        mod = self._get_fp8_attn_io_mod()
+        swa_kv, _ = self.get_swa()
+        n_comp = int(self.n_comp)
+        total = n_comp + swa_kv.shape[0]
+        if total > self.mixed_gather_cap:
+            raise RuntimeError(f"mixed gather capacity {self.mixed_gather_cap} < requested {total}")
+        mod.gather_mixed_all_(
+            self.comp_nope_fp8[:n_comp], self.comp_nope_scale[:n_comp], self.comp_rope_bf16[:n_comp],
+            swa_kv, self.gather_nope_fp8, self.gather_nope_scale, self.gather_rope_bf16)
+        return (self.gather_nope_fp8[:total],
+                self.gather_nope_scale[:total],
+                self.gather_rope_bf16[:total])
+
+    def gather_mixed_swa_only(self):
+        """Quantize SWA noPE tail to FP8 and keep SWA RoPE as BF16."""
+        mod = self._get_fp8_attn_io_mod()
+        swa_kv, _ = self.get_swa()
+        total = swa_kv.shape[0]
+        if total > self.mixed_gather_cap:
+            raise RuntimeError(f"mixed gather capacity {self.mixed_gather_cap} < requested {total}")
+        mod.gather_mixed_swa_only_(
+            swa_kv, self.gather_nope_fp8, self.gather_nope_scale, self.gather_rope_bf16, self.rope_dim)
+        return (self.gather_nope_fp8[:total],
+                self.gather_nope_scale[:total],
+                self.gather_rope_bf16[:total])
+
     def get_swa(self):
         """Return SWA KV and positions as views (no clone)."""
         if self.swa_len == 0:
@@ -648,6 +723,28 @@ def _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w
     attn_out = dsv4_attention(q=q, k=k, v=v, scale=scale, n_comp=0, sink_bias=sink_bias)
     return attn_out.permute(1, 0, 2)  # (T, n_h, hd)
 
+
+def _run_production_fmha_mixed(q_heads, kv_nope_fp8, kv_nope_scale, kv_rope_bf16,
+                               n_h, hd, T, seq_len, scale, dev, li, w, pfx, rope_dim):
+    """B1 storage-native mixed FP8/BF16 decode FMHA. No BF16 KV staging."""
+    if T != 1:
+        raise RuntimeError(f"B1 mixed FP8 FMHA is decode-only (T==1); got T={T}")
+    from dsv4.kernels.attention.production import dsv4_attention_mixed_fp8_decode
+    q = q_heads.permute(1, 0, 2).contiguous()  # (n_h, 1, hd)
+    sinks = w.get(f"{pfx}.sinks"); sink_bias = None
+    if sinks is not None:
+        sink_bias = sinks.to(device=dev).float().reshape(n_h)
+    attn_out = dsv4_attention_mixed_fp8_decode(
+        q=q,
+        k_nope_fp8=kv_nope_fp8,
+        k_nope_scale=kv_nope_scale,
+        k_rope_bf16=kv_rope_bf16,
+        scale=scale,
+        sink_bias=sink_bias,
+        rope_dim=rope_dim,
+    )
+    return attn_out.permute(1, 0, 2)  # (T, n_h, hd)
+
 # =====================================================================
 # Attention — ALL production kernels
 # =====================================================================
@@ -739,57 +836,38 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
     if indexer is not None and ratio == 4:
         topk_idx = indexer.forward(q_a, x_normed, kv_cache.comp_idx_kv, positions, layer_idx=li)
 
-    # 5. Gather KV — mixed storage: FP8 nope dequant + BF16 rope concat
+    # 5. Gather KV — B1 storage-native mixed path.
+    # noPE remains FP8_E4M3 + per-row scale; RoPE remains BF16.
+    # There is no global FP8->BF16 noPE materialization before FMHA.
     _pt('gather_start')
     swa_kv, _swa_pos = kv_cache.get_swa()
     swa_len = swa_kv.shape[0]
-    gbuf = kv_cache.gather_buf  # (max_len, hd) pre-allocated BF16
     if kv_cache.n_comp > 0:
         if ratio == 4:
-            # CSA: dequant only top-k entries
+            # CSA: gather top-k compressed rows + SWA tail without dequantizing noPE.
             assert topk_idx is not None, f"CSA layer {li}: indexer returned no top-k — indexer is broken"
             tk = topk_idx[0].clamp(0, kv_cache.n_comp - 1).int()
-            n_tk = tk.shape[0]
-            # Dequant FP8 nope + gather BF16 rope for top-k
-            nope_bf16 = kv_cache.comp_nope_selective(tk)  # FP8→BF16 (n_tk, 448)
-            rope_bf16 = kv_cache.comp_rope_selective(tk)  # BF16 gather (n_tk, 64)
-            gbuf[:n_tk, :nope_dim] = nope_bf16
-            gbuf[:n_tk, nope_dim:] = rope_bf16
-            gbuf[n_tk:n_tk + swa_len] = swa_kv
-            all_kv = gbuf[:n_tk + swa_len]
+            kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = kv_cache.gather_mixed_selective(tk)
         elif ratio > 4:
-            # HCA: dequant all entries
-            n_comp = kv_cache.n_comp
-            nope_bf16 = kv_cache.comp_nope_all  # FP8→BF16 (n_comp, 448)
-            rope_bf16 = kv_cache.comp_rope_all  # BF16 (n_comp, 64)
-            gbuf[:n_comp, :nope_dim] = nope_bf16
-            gbuf[:n_comp, nope_dim:] = rope_bf16
-            gbuf[n_comp:n_comp + swa_len] = swa_kv
-            all_kv = gbuf[:n_comp + swa_len]
+            # HCA: dense over compressed rows, still mixed storage.
+            kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = kv_cache.gather_mixed_all()
         else:
-            gbuf[:swa_len] = swa_kv
-            all_kv = gbuf[:swa_len]
+            kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = kv_cache.gather_mixed_swa_only()
     else:
-        gbuf[:swa_len] = swa_kv
-        all_kv = gbuf[:swa_len]
-    seq_len = all_kv.shape[0]
+        kv_nope_fp8, kv_nope_scale, kv_rope_bf16 = kv_cache.gather_mixed_swa_only()
+    seq_len = kv_nope_scale.shape[0]
     if seq_len == 0: return torch.zeros(T, cfg["hidden_size"], dtype=torch.bfloat16, device=dev), q_a
 
-    # 6. Production FMHA
+    # 6. Production FMHA — B1 mixed FP8/BF16 decode path.
     _pt('fmha_start')
     if VERBOSE >= 2 and li < 3:
-        print(f" L{li} FMHA input: T={T} seq_len={seq_len} hd={hd} n_h={n_h} n_comp={kv_cache.n_comp} swa_len={swa_len}", flush=True)
-    attn_out = _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w, pfx)
+        print(f" L{li} FMHA mixed input: T={T} seq_len={seq_len} hd={hd} n_h={n_h} n_comp={kv_cache.n_comp} swa_len={swa_len}", flush=True)
+    attn_out = _run_production_fmha_mixed(
+        q_heads, kv_nope_fp8, kv_nope_scale, kv_rope_bf16,
+        n_h, hd, T, seq_len, scale, dev, li, w, pfx, rd)
     _pt('fmha_end')
     if VERBOSE >= 2 and li < 3:
-        # Compare with PyTorch reference
-        k_exp = all_kv.unsqueeze(0).expand(n_h, -1, -1).contiguous()
-        v_exp = k_exp.clone()
-        q_in = q_heads.permute(1, 0, 2)
-        ref_scores = torch.matmul(q_in, k_exp.transpose(-1, -2)) * scale
-        ref_attn = torch.matmul(torch.softmax(ref_scores.float(), -1).bfloat16(), v_exp).permute(1, 0, 2)
-        cos_sim = torch.nn.functional.cosine_similarity(attn_out.flatten().float(), ref_attn.flatten().float(), dim=0).item()
-        print(f" L{li} FMHA: |prod|={attn_out.abs().max().item():.6f} |ref|={ref_attn.abs().max().item():.6f} cos={cos_sim:.6f}", flush=True)
+        print(f" L{li} FMHA mixed: |prod|={attn_out.abs().max().item():.6f} (reference disabled: B1 forbids global BF16 KV staging)", flush=True)
     # 7. Inverse RoPE
     _pt('inv_rope_start')
     attn_out = _apply_rope(attn_out, positions, rope_cos, rope_sin, rd, inverse=True)
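The buffer sizing and row layout that the gather path relies on can be summarized in a short host-side sketch: capacity depends on the compression ratio, compressed rows fill output rows [0, K), and the SWA tail fills [K, K+S). The function names and the tiny numbers below are hypothetical and serve only to make the layout concrete; no tensors or kernels are involved.

```python
def mixed_gather_capacity(compress_ratio, max_comp, indexer_top_k, window_size):
    """Row capacity of the mixed gather buffers, following the KVCache rule."""
    if compress_ratio > 4:       # HCA: dense over all compressed rows
        return max_comp + window_size
    if compress_ratio == 4:      # CSA: only the indexer's top-k rows
        return indexer_top_k + window_size
    return window_size           # no compressed rows: SWA only


def gather_row_plan(selected, n_swa, capacity):
    """Source of each output row: selected compressed rows, then the SWA tail."""
    total = len(selected) + n_swa
    if total > capacity:
        raise RuntimeError(f"mixed gather capacity {capacity} < requested {total}")
    return [("comp", src) for src in selected] + [("swa", s) for s in range(n_swa)]


# Illustrative sizes only: 3 top-k rows and a 2-row sliding window.
cap = mixed_gather_capacity(compress_ratio=4, max_comp=1024, indexer_top_k=3, window_size=2)
plan = gather_row_plan(selected=[7, 0, 5], n_swa=2, capacity=cap)
print(cap, plan)
```

In the real path the "comp" rows are byte-copied together with their scales, while the "swa" rows are quantized on the fly; the plan only shows where each row lands.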