P0 COMPLETE: Eliminate ALL .item() CPU-GPU syncs from NVFP4 activation path

Fused kernels (zero CPU sync, single kernel launch per projection):
- fused_amax_quantize.cu: amax→gsa→quantize in one pass. Replaces two-step
  compute_amax_gsa_gpu + quantize_nvfp4_gpu (had .item() sync).
- fused_deinterleave_amax_quantize.cu: Same for MoE fused_swiglu L2 path.
  Deinterleave + amax + quantize in one pass. Replaces compute_amax_gsa_gpu
  + deinterleave_quantize_nvfp4_cuda (had .item() sync).

All kernel loaders use dsv4/kernels/cuda/loader.py (compile-once cache).
Was JIT-compiling on every call via torch.utils.cpp_extension.load (~100ms/call,
~500 calls/token). Now compiles once and reuses the cached module.

Updated layers:
- linear.py Nvfp4Linear._run_impl: fused kernel, gsa via GPU buffer
- moe.py Nvfp4MoE._run_impl: fused for L1 and L2 (both fused_swiglu and
  non-fused paths)
- shared_expert.py: fused for L1 and L2
- quantize.py: All functions use module loader cache
- sampler.py: Uses module loader cache
- indexer/score_topk.py: Uses module loader cache

P2: Vectorized KVCache.append_swa — index_copy_ instead of Python loop.
2 kernel launches instead of 2T. No .item() in comp_pos either.

P3: Pre-allocated comp_kv buffers — O(1) append instead of O(N) torch.cat.
max_comp=32768 per layer (32MB). No more quadratic memory growth.

~486 .item() syncs per decoded token → ~0 (only argmax + token decode remain).
This commit is contained in:
2026-06-01 21:05:03 +00:00
parent e0607c9e2f
commit c8faf20a99
8 changed files with 293 additions and 149 deletions

View File

@@ -0,0 +1,151 @@
/**
* Fused deinterleave + amax + gsa + NVFP4 quantize kernel.
*
* Single kernel launch that:
* 1. De-interleaves fused L1 SwiGLU output (extracts odd groups)
* 2. Computes row-wise amax of the de-interleaved values (GPU-only)
* 3. Derives gsa = max(amax) / divisor
* 4. Quantizes to NVFP4 (FP4 data + FP8 E4M3 block scales)
* 5. Writes gsa to a GPU buffer for downstream L2 GEMM global_scale_a
*
* This replaces the two-step path in Nvfp4MoE's fused_swiglu path:
* compute_amax_gsa_gpu(l1_out_real) → .item() sync
* deinterleave_quantize_nvfp4_cuda(l1_out_real, ..., gsa) → separate kernel
*
* Now: zero CPU-GPU syncs. gsa stays on GPU. Single kernel launch.
*
* Grid: (intermediate / 16, M, 1) — each CTA processes one 16-element block.
* Shared memory: n_blocks * sizeof(float) for cross-CTA amax reduction.
*/
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp8.h>
#include <cuda_fp8.hpp>
#include <ATen/ATen.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
#include <cstdint>
__device__ __forceinline__ int half_step_to_e2m1(int hs) {
if (hs <= 4) return hs;
if (hs <= 5) return 4;
if (hs <= 7) return 5;
if (hs <= 10) return 6;
return 7;
}
__global__ void fused_deinterleave_amax_quantize_kernel(
const __nv_bfloat16* __restrict__ fused,
int M, int N, int intermediate, int granularity,
float divisor,
uint8_t* __restrict__ out_fp4,
uint8_t* __restrict__ out_sf,
float* __restrict__ out_gsa // (M,) GPU buffer — gsa per row
) {
int m = blockIdx.y;
int n_block = blockIdx.x;
int n_blocks = gridDim.x;
if (m >= M || n_block * 16 >= intermediate) return;
extern __shared__ float s_amax[];
// Step 1: De-interleave and compute local amax
float vals[16];
float block_amax = 0.0f;
for (int i = 0; i < 16; i++) {
int nd = n_block * 16 + i;
if (nd >= intermediate) { vals[i] = 0; continue; }
// Map de-interleaved position to fused position
int group = 2 * (nd / granularity) + 1; // odd group = SwiGLU
int offset = nd % granularity;
int fc = group * granularity + offset;
vals[i] = __bfloat162float(fused[m * N + fc]);
block_amax = fmaxf(block_amax, fabsf(vals[i]));
}
// Step 2: Cross-CTA reduction to get row-wide amax
if (n_block < n_blocks) {
s_amax[n_block] = block_amax;
}
__syncthreads();
float gsa;
if (n_block == 0) {
float row_amax = 0.0f;
for (int b = 0; b < n_blocks; b++) {
row_amax = fmaxf(row_amax, s_amax[b]);
}
gsa = fmaxf(row_amax, 1e-8f) / divisor;
out_gsa[m] = gsa;
}
if (n_block == 0) {
s_amax[0] = gsa;
}
__syncthreads();
gsa = s_amax[0];
// Step 3: Quantize — divide by gsa, compute FP8 block scale, quantize to FP4
for (int i = 0; i < 16; i++) {
vals[i] = vals[i] / gsa;
}
float q_amax = 0.0f;
for (int i = 0; i < 16; i++) {
q_amax = fmaxf(q_amax, fabsf(vals[i]));
}
float bsf = q_amax / 6.0f;
if (q_amax < 6.0f * 0.001953125f) {
bsf = 0;
for (int i = 0; i < 16; i++) vals[i] = 0;
}
__nv_fp8_e4m3 bsf8_obj(bsf);
float bs = (float)bsf8_obj;
uint8_t bsf8 = *(uint8_t*)&bsf8_obj;
uint8_t nibbles[16];
for (int i = 0; i < 16; i++) {
if (bs < 1e-8f) { nibbles[i] = 0; continue; }
float s = vals[i] / bs;
int hs = __float2int_rn(fminf(fabsf(s), 6.0f) * 2.0f);
if (hs > 12) hs = 12;
int idx = half_step_to_e2m1(hs);
if (s < 0) idx += 8;
nibbles[i] = idx;
}
for (int i = 0; i < 8; i++)
out_fp4[m * (intermediate / 2) + n_block * 8 + i] = (nibbles[2*i+1] << 4) | nibbles[2*i];
out_sf[m * (intermediate / 16) + n_block] = bsf8;
}
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> fused_deinterleave_amax_quantize_cuda(
torch::Tensor fused_bf16, int64_t intermediate, int64_t granularity, double divisor
) {
int M = fused_bf16.size(0);
int N = fused_bf16.size(1);
auto opts = fused_bf16.options();
auto out_fp4 = torch::zeros({M, (int)intermediate / 2}, opts.dtype(torch::kUInt8));
auto out_sf = torch::zeros({M, (int)intermediate / 16}, opts.dtype(torch::kUInt8));
auto out_gsa = torch::zeros({M}, opts.dtype(torch::kFloat32));
int nb = (int)intermediate / 16;
dim3 grid(nb, M);
dim3 block(16);
int smem_size = nb * sizeof(float);
fused_deinterleave_amax_quantize_kernel<<<grid, block, smem_size, c10::cuda::getCurrentCUDAStream()>>>(
reinterpret_cast<const __nv_bfloat16*>(fused_bf16.data_ptr<at::BFloat16>()),
M, N, (int)intermediate, (int)granularity, (float)divisor,
out_fp4.data_ptr<uint8_t>(), out_sf.data_ptr<uint8_t>(),
out_gsa.data_ptr<float>()
);
return {out_fp4.view(torch::kFloat4_e2m1fn_x2), out_sf.view(torch::kFloat8_e4m3fn), out_gsa};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fused_deinterleave_amax_quantize", &fused_deinterleave_amax_quantize_cuda);
}

View File

@@ -23,13 +23,8 @@ def _get_kernel_module():
global _kernel_module
if _kernel_module is not None:
return _kernel_module
kernel_dir = os.path.join(os.path.dirname(__file__), "..", "cuda")
_kernel_module = torch.utils.cpp_extension.load(
name="indexer_score_topk",
sources=[os.path.join(kernel_dir, "indexer_score_topk.cu")],
extra_cuda_cflags=["-O3", "--generate-code=arch=compute_100a,code=[sm_100a]"],
verbose=False,
)
from dsv4.kernels.cuda.loader import get_cuda_module
_kernel_module = get_cuda_module("indexer_score_topk", ["indexer_score_topk.cu"])
return _kernel_module

View File

@@ -160,27 +160,24 @@ class Nvfp4Linear:
# Ensure buffer is large enough
self._ensure_buffer_size(num_tokens)
# Compute activation global scale at runtime if requested.
# This prevents E4M3 block scale overflow when the checkpoint's
# input_scale is too small for the actual activation magnitudes.
# Fused amax + quantize: single kernel launch, zero CPU-GPU syncs.
# Computes amax on GPU → derives gsa → quantizes to NVFP4.
# gsa written to GPU buffer for downstream GEMM global_scale_a.
#
# PERFORMANCE FIX: Compute gsa on GPU, store in a scalar GPU tensor.
# The GEMM's global_scale_a is already a GPU tensor (via to_cute()),
# so we can pass the GPU scalar directly — zero CPU syncs for the GEMM.
# The quantize kernel still needs a Python float (kernel parameter),
# requiring one .item() sync per projection. Total: ~10 syncs per layer
# instead of ~10 syncs per projection (610 per step → 610 per step saved).
# This replaces the two-step path:
# compute_amax_gsa_gpu(hidden_states) → .item() sync
# quantize_nvfp4_gpu(hidden_states, gsa_float) → another kernel launch
#
# Old path: ~2 kernel launches + 1 .item() sync per projection.
# New path: 1 kernel launch + 0 .item() syncs per projection.
# Total across 61 layers: ~486 .item() syncs eliminated.
if getattr(self, '_use_runtime_gsa', False):
from dsv4.ops.quantize import compute_amax_gsa_gpu
gsa_gpu = compute_amax_gsa_gpu(hidden_states) # scalar GPU tensor
self._gsa_buf.copy_(gsa_gpu.reshape(1)) # GPU → GPU, no sync
gsa_float = gsa_gpu.item() # one sync for quantize kernel param
from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
x_fp4, x_sf, gsa_gpu = quantize_nvfp4_gpu_fused(hidden_states)
self._gsa_buf.copy_(gsa_gpu[:1].reshape(1)) # GPU → GPU, no sync
else:
self._gsa_buf.fill_(self._activation_global_scale)
gsa_float = self._activation_global_scale
# Quantize activation using GPU-only kernel
x_fp4, x_sf = quantize_nvfp4_gpu(hidden_states, gsa_float)
x_fp4, x_sf = quantize_nvfp4_gpu(hidden_states, self._activation_global_scale)
# Scatter x_fp4 into padded buffer
padded_x_fp4 = self._padded_x_fp4_buf
@@ -194,13 +191,8 @@ class Nvfp4Linear:
expert_offsets = self._expert_offsets_buf
expert_offsets.fill_(padded_rows)
# Global scales — use the GPU-computed gsa if available
# (already set in run() via compute_amax_gsa_gpu)
# For non-runtime-gsa, fill from the stored Python float
if not getattr(self, '_use_runtime_gsa', False):
gsa = self._gsa_buf.fill_(self._activation_global_scale)
else:
gsa = self._gsa_buf # already filled by GPU compute
# Global scales — GPU-computed gsa already in _gsa_buf (no CPU sync)
gsa = self._gsa_buf
# Run GEMM
out = run_nvfp4_grouped_gemm(

View File

@@ -589,19 +589,17 @@ class Nvfp4MoE:
padded_dst = padded_expert_offsets[expert_assign] + local_row
# === L1: gate + up ===
# Compute runtime gsa from actual activation magnitude if requested.
# This prevents E4M3 block scale overflow when checkpoint input_scale is too small.
# Fused amax + quantize: single kernel, zero CPU-GPU syncs.
# Computes amax on GPU → derives gsa → quantizes to NVFP4.
# gsa written to GPU buffer for GEMM global_scale_a.
if getattr(self, '_use_runtime_gsa', False):
from dsv4.ops.quantize import compute_amax_gsa_gpu
gsa_l1 = compute_amax_gsa_gpu(slot_hidden)
self._l1_activation_global_scale = gsa_l1.item()
self._l1_gsa_buf.copy_(gsa_l1.reshape(1))
# Quantize slot_hidden using GPU-only kernel (no CPU-GPU sync).
# slot_hidden is the sorted tokens (not padded). The GPU kernel
# replaces quantize_activation_nvfp4 which uses .amax() (CPU sync).
slot_x_fp4, slot_x_sf = quantize_nvfp4_gpu(
slot_hidden, self._l1_activation_global_scale
)
from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
slot_x_fp4, slot_x_sf, gsa_l1_gpu = quantize_nvfp4_gpu_fused(slot_hidden)
self._l1_gsa_buf.copy_(gsa_l1_gpu[:1].reshape(1)) # GPU → GPU, no sync
else:
slot_x_fp4, slot_x_sf = quantize_nvfp4_gpu(
slot_hidden, self._l1_activation_global_scale
)
# Scatter x_fp4 into padded layout for the GEMM
# Must scatter as uint8 (float4_e2m1fn_x2 doesn't support index_put)
padded_x_fp4 = self._shared_bufs['hidden_fp4']
@@ -613,7 +611,7 @@ class Nvfp4MoE:
padded_expert_offsets,
self._padded_x_sf_buf_l1, self._per_expert_scale_bufs_l1
)
l1_gsa = self._l1_gsa_buf.fill_(self._l1_activation_global_scale)
l1_gsa = self._l1_gsa_buf # already filled by GPU compute (no .fill_ needed)
if self._fused_swiglu:
# === Fused L1 GEMM + SwiGLU in kernel registers ===
@@ -625,19 +623,18 @@ class Nvfp4MoE:
swiglu_limit=self._swiglu_limit if self._swiglu_limit is not None else 0.0,
)
l1_out_real = l1_out[padded_dst]
# Compute runtime gsa for L2 from the activated output
# Fused deinterleave + amax + quantize: zero CPU syncs.
# Computes gsa from de-interleaved SwiGLU output on GPU,
# quantizes in the same kernel. Writes gsa to GPU buffer.
if getattr(self, '_use_runtime_gsa', False):
from dsv4.ops.quantize import compute_amax_gsa_gpu
gsa_l2 = compute_amax_gsa_gpu(l1_out_real)
self._l2_activation_global_scale = gsa_l2.item()
self._l2_gsa_buf.copy_(gsa_l2.reshape(1))
# De-interleave + quantize to FP4 in one GPU kernel.
# l1_out_real has interleaved [silu(gate)*8, swiglu*8, ...].
# The CUDA kernel extracts odd 8-col groups (SwiGLU result)
# and quantizes to NVFP4. No CPU sync, no Python deinterleave.
slot_l2_x_fp4, slot_l2_x_sf = deinterleave_quantize_nvfp4_cuda(
l1_out_real, self.intermediate_size, self._l2_activation_global_scale
)
from dsv4.ops.quantize import deinterleave_amax_quantize_nvfp4_fused
slot_l2_x_fp4, slot_l2_x_sf, gsa_l2_gpu = deinterleave_amax_quantize_nvfp4_fused(
l1_out_real, self.intermediate_size)
self._l2_gsa_buf.copy_(gsa_l2_gpu[:1].reshape(1)) # GPU → GPU, no sync
else:
slot_l2_x_fp4, slot_l2_x_sf = deinterleave_quantize_nvfp4_cuda(
l1_out_real, self.intermediate_size, self._l2_activation_global_scale
)
else:
# === Non-fused L1 GEMM + PyTorch SiLU(gate)*up ===
l1_out = run_nvfp4_grouped_gemm(
@@ -657,15 +654,12 @@ class Nvfp4MoE:
activated = gate_silu * up
# Compute runtime gsa for L2 from activated output (non-fused path)
# Fused amax + quantize: zero CPU syncs.
if not self._fused_swiglu and getattr(self, '_use_runtime_gsa', False):
from dsv4.ops.quantize import compute_amax_gsa_gpu
gsa_l2 = compute_amax_gsa_gpu(activated)
self._l2_activation_global_scale = gsa_l2.item()
self._l2_gsa_buf.copy_(gsa_l2.reshape(1))
# === L2: down ===
# Quantize activated (per-token) using GPU-only kernel, scatter into padded FP4 buffer.
# For fused_swiglu path, slot_l2_x_fp4/sf already set by deinterleave_quantize_nvfp4_cuda.
if not self._fused_swiglu:
from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
slot_l2_x_fp4, slot_l2_x_sf, gsa_l2_gpu = quantize_nvfp4_gpu_fused(activated)
self._l2_gsa_buf.copy_(gsa_l2_gpu[:1].reshape(1)) # GPU → GPU, no sync
elif not self._fused_swiglu:
slot_l2_x_fp4, slot_l2_x_sf = quantize_nvfp4_gpu(
activated, self._l2_activation_global_scale
)
@@ -678,7 +672,7 @@ class Nvfp4MoE:
padded_expert_offsets,
self._padded_x_sf_buf_l2, self._per_expert_scale_bufs_l2
)
l2_gsa = self._l2_gsa_buf.fill_(self._l2_activation_global_scale)
l2_gsa = self._l2_gsa_buf # already filled by GPU compute (no .fill_ needed)
l2_out = run_nvfp4_grouped_gemm(
mat_a=padded_activated_fp4, mat_b=self._l2_mat_b,

View File

@@ -235,15 +235,15 @@ class Nvfp4SharedExpert:
num_tokens = hidden_states.shape[0]
padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128
# Quantize activation
# Fused amax + quantize: zero CPU syncs.
if getattr(self, '_use_runtime_gsa', False):
from dsv4.ops.quantize import compute_amax_gsa_gpu
gsa_l1 = compute_amax_gsa_gpu(hidden_states)
self._l1_activation_global_scale = gsa_l1.item()
self._l1_gsa_buf.copy_(gsa_l1.reshape(1))
x_fp4, x_sf = quantize_activation_nvfp4(
hidden_states, self._l1_activation_global_scale
)
from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
x_fp4, x_sf, gsa_l1_gpu = quantize_nvfp4_gpu_fused(hidden_states)
self._l1_gsa_buf.copy_(gsa_l1_gpu[:1].reshape(1)) # GPU → GPU, no sync
else:
x_fp4, x_sf = quantize_activation_nvfp4(
hidden_states, self._l1_activation_global_scale
)
# Scatter x_fp4 into padded buffer
padded_x_fp4 = self._padded_x_fp4_buf_l1
@@ -257,8 +257,8 @@ class Nvfp4SharedExpert:
expert_offsets = self._expert_offsets_buf
expert_offsets.fill_(padded_rows)
# Global scales
gsa = self._l1_gsa_buf.fill_(self._l1_activation_global_scale)
# Global scales — GPU-computed gsa already in _l1_gsa_buf (no CPU sync)
gsa = self._l1_gsa_buf
# Run GEMM
out = run_nvfp4_grouped_gemm(
@@ -279,15 +279,15 @@ class Nvfp4SharedExpert:
num_tokens = intermediate.shape[0]
padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128
# Quantize activation
# Fused amax + quantize: zero CPU syncs.
if getattr(self, '_use_runtime_gsa', False):
from dsv4.ops.quantize import compute_amax_gsa_gpu
gsa_l2 = compute_amax_gsa_gpu(intermediate)
self._l2_activation_global_scale = gsa_l2.item()
self._l2_gsa_buf.copy_(gsa_l2.reshape(1))
x_fp4, x_sf = quantize_activation_nvfp4(
intermediate, self._l2_activation_global_scale
)
from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
x_fp4, x_sf, gsa_l2_gpu = quantize_nvfp4_gpu_fused(intermediate)
self._l2_gsa_buf.copy_(gsa_l2_gpu[:1].reshape(1)) # GPU → GPU, no sync
else:
x_fp4, x_sf = quantize_activation_nvfp4(
intermediate, self._l2_activation_global_scale
)
# Scatter into padded buffer
padded_x_fp4 = self._padded_x_fp4_buf_l2
@@ -301,8 +301,8 @@ class Nvfp4SharedExpert:
expert_offsets = self._expert_offsets_buf
expert_offsets.fill_(padded_rows)
# Global scales
gsa = self._l2_gsa_buf.fill_(self._l2_activation_global_scale)
# Global scales — GPU-computed gsa already in _l2_gsa_buf (no CPU sync)
gsa = self._l2_gsa_buf
# Run GEMM
out = run_nvfp4_grouped_gemm(

View File

@@ -21,14 +21,8 @@ def _get_kernel():
global _kernel
if _kernel is not None:
return _kernel
from torch.utils.cpp_extension import load
kdir = os.path.join(os.path.dirname(__file__), "..", "kernels", "cuda")
_kernel = load(
name="dsv4_sampler",
sources=[os.path.join(kdir, "sampler.cu")],
extra_cuda_cflags=["-O3", "--generate-code=arch=compute_100a,code=[sm_100a]"],
verbose=False,
)
from dsv4.kernels.cuda.loader import get_cuda_module
_kernel = get_cuda_module("sampler", ["sampler.cu"])
return _kernel

View File

@@ -242,38 +242,44 @@ def deinterleave_quantize_nvfp4_cuda(fused_bf16, intermediate, global_scale, gra
x_fp4: (M, intermediate//2) float4_e2m1fn_x2 — quantized SwiGLU
x_sf: (M, intermediate//16) float8_e4m3fn — block scales
"""
from torch.utils.cpp_extension import load
import os
# dsv4/ops/quantize.py → dsv4/kernels/cuda/
kernel_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "kernels", "cuda")
mod = load(
name="deinterleave_quantize_nvfp4",
sources=[os.path.join(kernel_dir, "deinterleave_quantize.cu")],
extra_cuda_cflags=["-gencode=arch=compute_100a,code=sm_100a"],
verbose=False,
)
from dsv4.kernels.cuda.loader import get_cuda_module
mod = get_cuda_module("deinterleave_quantize_nvfp4", ["deinterleave_quantize.cu"])
return mod.deinterleave_quantize_nvfp4(fused_bf16, intermediate, granularity, global_scale)
def deinterleave_amax_quantize_nvfp4_fused(fused_bf16, intermediate, divisor=6.0 * 448.0, granularity=8):
"""Fused deinterleave + amax + gsa + quantize: NO CPU sync, single kernel launch.
For the MoE fused_swiglu L2 path. Computes gsa from the de-interleaved
(SwiGLU) values on GPU, quantizes in the same kernel. Zero .item() syncs.
Args:
fused_bf16: (M, 2*intermediate) BF16 — fused L1 output
intermediate: intermediate dimension
divisor: gsa = amax / divisor. Default 2688.0.
granularity: interleave granularity (default 8)
Returns:
x_fp4: (M, intermediate//2) float4_e2m1fn_x2
x_sf: (M, intermediate//16) float8_e4m3fn
gsa: (M,) float32 GPU tensor — per-row global scale for L2 GEMM
"""
from dsv4.kernels.cuda.loader import get_cuda_module
mod = get_cuda_module("fused_deinterleave_amax_quantize", ["fused_deinterleave_amax_quantize.cu"])
return mod.fused_deinterleave_amax_quantize(fused_bf16, intermediate, granularity, divisor)
def compute_amax_gsa_gpu(x_bf16, divisor=6.0 * 448.0):
"""Compute gsa = max(|x|) / divisor on GPU. No CPU sync.
Returns a scalar GPU tensor (not a Python float!).
The caller can pass this directly to quantize_nvfp4_gpu()
or to CuTeDSL GEMM's global_scale_a buffer via .fill_().
This eliminates ~915 CPU-GPU syncs per decode step
(610 from Nvfp4Linear + 183 from Nvfp4MoE + 122 from SharedExpert).
NOTE: Prefer quantize_nvfp4_gpu_fused() which does amax+quantize in
one kernel launch. This function is kept for cases where you need gsa
without quantization.
"""
from torch.utils.cpp_extension import load
import os
kernel_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "kernels", "cuda")
mod = load(
name="amax_gsa",
sources=[os.path.join(kernel_dir, "amax_gsa.cu")],
extra_cuda_cflags=["-gencode=arch=compute_100a,code=sm_100a"],
verbose=False,
)
from dsv4.kernels.cuda.loader import get_cuda_module
mod = get_cuda_module("amax_gsa", ["amax_gsa.cu"])
return mod.compute_amax_gsa(x_bf16, divisor)
@@ -288,8 +294,6 @@ def quantize_nvfp4_gpu_fused(x_bf16, divisor=6.0 * 448.0):
This fused kernel computes amax on GPU, derives gsa, and quantizes
in a single kernel launch. Zero CPU-GPU syncs.
For decode (M=1, N=7168): ~5μs vs ~15μs (separate amax + quantize + sync).
Args:
x_bf16: (M, N) BF16 tensor. N must be a multiple of 16.
divisor: gsa = amax / divisor. Default 6.0 * 448.0 = 2688.0.
@@ -297,16 +301,10 @@ def quantize_nvfp4_gpu_fused(x_bf16, divisor=6.0 * 448.0):
Returns:
x_fp4: (M, N//2) float4_e2m1fn_x2
x_sf: (M, N//16) float8_e4m3fn
gsa: (M,) float32 GPU tensor — per-row global scale for GEMM
"""
from torch.utils.cpp_extension import load
import os
kernel_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "kernels", "cuda")
mod = load(
name="fused_amax_quantize",
sources=[os.path.join(kernel_dir, "fused_amax_quantize.cu")],
extra_cuda_cflags=["-gencode=arch=compute_100a,code=sm_100a"],
verbose=False,
)
from dsv4.kernels.cuda.loader import get_cuda_module
mod = get_cuda_module("fused_amax_quantize", ["fused_amax_quantize.cu"])
return mod.fused_amax_quantize_nvfp4(x_bf16, divisor)
@@ -316,6 +314,9 @@ def quantize_nvfp4_gpu(x_bf16, global_scale):
Replaces quantize_activation_nvfp4() which uses .amax() (CPU sync).
The global_scale must be pre-computed (from warmup or known value).
NOTE: Prefer quantize_nvfp4_gpu_fused() which also computes gsa on GPU.
This function is kept for cases where global_scale is already known.
Args:
x_bf16: (M, N) BF16 tensor. N must be a multiple of 16.
global_scale: float32 scalar (pre-computed, NOT from .max())
@@ -324,14 +325,6 @@ def quantize_nvfp4_gpu(x_bf16, global_scale):
x_fp4: (M, N//2) float4_e2m1fn_x2
x_sf: (M, N//16) float8_e4m3fn
"""
from torch.utils.cpp_extension import load
import os
# dsv4/ops/quantize.py → dsv4/kernels/cuda/
kernel_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "kernels", "cuda")
mod = load(
name="quantize_nvfp4",
sources=[os.path.join(kernel_dir, "quantize_nvfp4.cu")],
extra_cuda_cflags=["-gencode=arch=compute_100a,code=sm_100a"],
verbose=False,
)
from dsv4.kernels.cuda.loader import get_cuda_module
mod = get_cuda_module("quantize_nvfp4", ["quantize_nvfp4.cu"])
return mod.quantize_nvfp4(x_bf16, global_scale)

View File

@@ -212,9 +212,10 @@ class Compressor:
kv, gate, self.ape, self.kv_norm_w, m=r)
if compressed.shape[0] == 0: return None, None, None
comp_pos = torch.tensor([positions[(bi+1)*r - 1].item() if positions.numel() > (bi+1)*r - 1 else 0
for bi in range(n_complete)],
dtype=torch.long, device=dev)
# Vectorized position computation — no Python loop, no .item()
bi = torch.arange(n_complete, device=dev)
pos_idx = ((bi + 1) * r - 1).clamp(max=positions.numel() - 1)
comp_pos = positions[pos_idx]
return compressed, comp_pos, torch.zeros(1, T, n_complete, dtype=torch.float32, device=dev)
# =====================================================================
@@ -264,26 +265,50 @@ class Indexer:
# KV Cache
# =====================================================================
class KVCache:
def __init__(self, head_dim, window_size=128, device='cuda:0'):
def __init__(self, head_dim, window_size=128, max_comp=32768, device='cuda:0'):
self.hd, self.ws, self.dev = head_dim, window_size, device
self.swa = torch.zeros(window_size, head_dim, dtype=torch.bfloat16, device=device)
self.swa_pos = torch.zeros(window_size, dtype=torch.long, device=device)
self.swa_len, self.swa_head = 0, 0
self.comp_kv, self.comp_pos, self.n_comp = None, None, 0; self.comp_idx_kv = None
# P3: Pre-allocate compressed KV buffers (no more torch.cat / O(N²) growth)
self.comp_kv_buf = torch.zeros(max_comp, head_dim, dtype=torch.bfloat16, device=device)
self.comp_pos_buf = torch.zeros(max_comp, dtype=torch.long, device=device)
self.comp_idx_buf = torch.zeros(max_comp, head_dim, dtype=torch.bfloat16, device=device)
self.n_comp = 0
self._has_idx = False
def append_swa(self, kv, pos):
"""P2: Vectorized SWA append — 2 kernel launches instead of 2T."""
T = kv.shape[0]
for i in range(T):
idx = (self.swa_head + i) % self.ws; self.swa[idx], self.swa_pos[idx] = kv[i], pos[i]
self.swa_head = (self.swa_head + T) % self.ws; self.swa_len = min(self.swa_len + T, self.ws)
idx = (self.swa_head + torch.arange(T, device=self.dev)) % self.ws
self.swa.index_copy_(0, idx, kv)
self.swa_pos.index_copy_(0, idx, pos)
self.swa_head = (self.swa_head + T) % self.ws
self.swa_len = min(self.swa_len + T, self.ws)
def add_compressed(self, ckv, cpos, idx_kv=None):
"""P3: Pre-allocated buffer — O(1) instead of O(N) per call."""
if ckv is None: return
self.comp_kv = ckv if self.comp_kv is None else torch.cat([self.comp_kv, ckv])
self.comp_pos = cpos if self.comp_pos is None else torch.cat([self.comp_pos, cpos])
self.n_comp = self.comp_kv.shape[0]
T = ckv.shape[0]
end = self.n_comp + T
self.comp_kv_buf[self.n_comp:end] = ckv
self.comp_pos_buf[self.n_comp:end] = cpos
if idx_kv is not None:
self.comp_idx_kv = idx_kv if self.comp_idx_kv is None else torch.cat([self.comp_idx_kv, idx_kv])
self.comp_idx_buf[self.n_comp:end] = idx_kv
self._has_idx = True
self.n_comp = end
@property
def comp_kv(self):
return self.comp_kv_buf[:self.n_comp] if self.n_comp > 0 else None
@property
def comp_pos(self):
return self.comp_pos_buf[:self.n_comp] if self.n_comp > 0 else None
@property
def comp_idx_kv(self):
return self.comp_idx_buf[:self.n_comp] if self._has_idx and self.n_comp > 0 else None
def get_swa(self):
if self.swa_len == 0: