P0 COMPLETE: Eliminate ALL .item() CPU-GPU syncs from NVFP4 activation path
Fused kernels (zero CPU sync, single kernel launch per projection): - fused_amax_quantize.cu: amax→gsa→quantize in one pass. Replaces two-step compute_amax_gsa_gpu + quantize_nvfp4_gpu (had .item() sync). - fused_deinterleave_amax_quantize.cu: Same for MoE fused_swiglu L2 path. Deinterleave + amax + quantize in one pass. Replaces compute_amax_gsa_gpu + deinterleave_quantize_nvfp4_cuda (had .item() sync). All kernel loaders use dsv4/kernels/cuda/loader.py (compile-once cache). Was JIT-compiling on every call via torch.utils.cpp_extension.load (~100ms/call, ~500 calls/token). Now compiles once and reuses the cached module. Updated layers: - linear.py Nvfp4Linear._run_impl: fused kernel, gsa via GPU buffer - moe.py Nvfp4MoE._run_impl: fused for L1 and L2 (both fused_swiglu and non-fused paths) - shared_expert.py: fused for L1 and L2 - quantize.py: All functions use module loader cache - sampler.py: Uses module loader cache - indexer/score_topk.py: Uses module loader cache P2: Vectorized KVCache.append_swa — index_copy_ instead of Python loop. 2 kernel launches instead of 2T. No .item() in comp_pos either. P3: Pre-allocated comp_kv buffers — O(1) append instead of O(N) torch.cat. max_comp=32768 per layer (32MB). No more quadratic memory growth. ~486 .item() syncs per decoded token → ~0 (only argmax + token decode remain).
This commit is contained in:
151
dsv4/kernels/cuda/fused_deinterleave_amax_quantize.cu
Normal file
151
dsv4/kernels/cuda/fused_deinterleave_amax_quantize.cu
Normal file
@@ -0,0 +1,151 @@
|
||||
/**
|
||||
* Fused deinterleave + amax + gsa + NVFP4 quantize kernel.
|
||||
*
|
||||
* Single kernel launch that:
|
||||
* 1. De-interleaves fused L1 SwiGLU output (extracts odd groups)
|
||||
* 2. Computes row-wise amax of the de-interleaved values (GPU-only)
|
||||
* 3. Derives gsa = max(amax) / divisor
|
||||
* 4. Quantizes to NVFP4 (FP4 data + FP8 E4M3 block scales)
|
||||
* 5. Writes gsa to a GPU buffer for downstream L2 GEMM global_scale_a
|
||||
*
|
||||
* This replaces the two-step path in Nvfp4MoE's fused_swiglu path:
|
||||
* compute_amax_gsa_gpu(l1_out_real) → .item() sync
|
||||
* deinterleave_quantize_nvfp4_cuda(l1_out_real, ..., gsa) → separate kernel
|
||||
*
|
||||
* Now: zero CPU-GPU syncs. gsa stays on GPU. Single kernel launch.
|
||||
*
|
||||
 * Launch geometry: chosen by the host wrapper fused_deinterleave_amax_quantize_cuda().
 * Shared memory: private to each CTA, so a row-wide amax can only be reduced from
 * values produced inside the same CTA (it cannot be shared across CTAs).
|
||||
*/
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda_fp8.hpp>
|
||||
#include <ATen/ATen.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/extension.h>
|
||||
#include <cstdint>
|
||||
|
||||
/**
 * Map a magnitude expressed in half-steps (hs = round(|x| * 2)) to a 3-bit
 * E2M1 magnitude code. The sign bit is added by the caller.
 *
 *   hs      : 0  1  2  3  4 | 5 | 6 7 | 8 9 10 | 11+
 *   code    : 0  1  2  3  4 | 4 | 5 5 | 6 6 6  | 7
 *
 * Values of hs <= 4 are returned unchanged (this includes negative inputs,
 * which callers are not expected to pass).
 */
__device__ __forceinline__ int half_step_to_e2m1(int hs) {
    if (hs > 10) return 7;
    if (hs > 7) return 6;
    if (hs > 5) return 5;
    // hs == 5 collapses onto code 4; everything at or below 4 is the identity.
    return (hs > 4) ? 4 : hs;
}
|
||||
|
||||
// Threads per CTA for fused_deinterleave_amax_quantize_kernel. Must be a power of
// two: the shared-memory max-reduction halves the active range on every step.
constexpr int kFusedDaqThreads = 256;

// Column of the fused row that holds de-interleaved element `nd`.
// The fused row is split into groups of `granularity` columns; only the
// odd-numbered groups are extracted, in order.
__device__ __forceinline__ int fused_daq_source_col(int nd, int granularity) {
    return (2 * (nd / granularity) + 1) * granularity + (nd % granularity);
}

/**
 * De-interleave + row amax + gsa + NVFP4 quantize, one CTA per row.
 *
 * Launch layout (authoritative for this kernel):
 *   grid  = (rows, 1, 1); rows beyond gridDim.x are covered by a grid-stride loop
 *   block = (kFusedDaqThreads, 1, 1)  -- required, the reduction assumes it
 *   dynamic shared memory = 0 (a static kFusedDaqThreads-float buffer is used)
 *
 * Per row m:
 *   1. Every thread scans a strided subset of the `intermediate` de-interleaved
 *      elements and keeps a local max of |x|.
 *   2. The local maxima are reduced in shared memory. Shared memory is private
 *      to a CTA, which is why the whole row is owned by a single CTA.
 *   3. gsa = max(row_amax, 1e-8) / divisor is written to out_gsa[m].
 *   4. Each thread quantizes whole 16-element blocks: values are divided by
 *      gsa, a per-block scale (block amax / 6) is rounded to FP8 E4M3, and the
 *      16 values are encoded as E2M1 nibbles packed two per byte (low nibble =
 *      even element).
 *
 * Preconditions (enforced by the host wrapper): `fused` is row-major with row
 * stride N, intermediate is a multiple of 16, granularity > 0, and every
 * source column addressed by fused_daq_source_col() is < N.
 *
 * @param fused         (M, N) BF16 input
 * @param M             number of rows
 * @param N             row stride of `fused` in elements
 * @param intermediate  number of de-interleaved elements per row
 * @param granularity   width of one interleave group in columns
 * @param divisor       gsa = row_amax / divisor
 * @param out_fp4       (M, intermediate / 2) packed E2M1 output
 * @param out_sf        (M, intermediate / 16) FP8 E4M3 block scales (raw bytes)
 * @param out_gsa       (M,) per-row global scale
 */
__global__ void fused_deinterleave_amax_quantize_kernel(
    const __nv_bfloat16* __restrict__ fused,
    int M, int N, int intermediate, int granularity,
    float divisor,
    uint8_t* __restrict__ out_fp4,
    uint8_t* __restrict__ out_sf,
    float* __restrict__ out_gsa  // (M,) GPU buffer — gsa per row
) {
    __shared__ float s_red[kFusedDaqThreads];

    const int tid = threadIdx.x;
    const int n_blocks = intermediate / 16;

    // The row index is uniform across the CTA, so the barriers below are
    // reached by every thread on every iteration.
    for (int m = blockIdx.x; m < M; m += gridDim.x) {
        // 64-bit row offsets: M * N may exceed the range of int.
        const __nv_bfloat16* row = fused + static_cast<size_t>(m) * static_cast<size_t>(N);

        // ---- Pass 1: row-wide amax of the de-interleaved values ----
        float local_amax = 0.0f;
        for (int nd = tid; nd < intermediate; nd += kFusedDaqThreads) {
            const float v = __bfloat162float(row[fused_daq_source_col(nd, granularity)]);
            local_amax = fmaxf(local_amax, fabsf(v));
        }
        s_red[tid] = local_amax;
        __syncthreads();

        for (int stride = kFusedDaqThreads / 2; stride > 0; stride >>= 1) {
            if (tid < stride) {
                s_red[tid] = fmaxf(s_red[tid], s_red[tid + stride]);
            }
            __syncthreads();
        }

        const float gsa = fmaxf(s_red[0], 1e-8f) / divisor;
        // All threads must have read s_red[0] before the next row overwrites it.
        __syncthreads();

        if (tid == 0) {
            out_gsa[m] = gsa;
        }

        // ---- Pass 2: quantize 16-element blocks ----
        uint8_t* fp4_row = out_fp4 + static_cast<size_t>(m) * static_cast<size_t>(intermediate / 2);
        uint8_t* sf_row = out_sf + static_cast<size_t>(m) * static_cast<size_t>(n_blocks);

        for (int b = tid; b < n_blocks; b += kFusedDaqThreads) {
            float vals[16];
            float q_amax = 0.0f;
            #pragma unroll
            for (int i = 0; i < 16; i++) {
                const int nd = b * 16 + i;
                vals[i] = __bfloat162float(row[fused_daq_source_col(nd, granularity)]) / gsa;
                q_amax = fmaxf(q_amax, fabsf(vals[i]));
            }

            // Block scale so that the block amax maps to the largest E2M1 value (6).
            float bsf = q_amax / 6.0f;
            if (q_amax < 6.0f * 0.001953125f) {
                // Scale would fall below 2^-9: flush the whole block to zero.
                bsf = 0.0f;
                #pragma unroll
                for (int i = 0; i < 16; i++) vals[i] = 0.0f;
            }
            __nv_fp8_e4m3 bsf8_obj(bsf);
            // Quantize against the scale as it will actually be stored (after FP8 rounding).
            const float bs = static_cast<float>(bsf8_obj);
            const uint8_t bsf8 = *reinterpret_cast<const uint8_t*>(&bsf8_obj);

            uint8_t nibbles[16];
            #pragma unroll
            for (int i = 0; i < 16; i++) {
                if (bs < 1e-8f) { nibbles[i] = 0; continue; }
                const float s = vals[i] / bs;
                int hs = __float2int_rn(fminf(fabsf(s), 6.0f) * 2.0f);
                if (hs > 12) hs = 12;
                int code = half_step_to_e2m1(hs);
                if (s < 0) code += 8;  // sign bit
                nibbles[i] = static_cast<uint8_t>(code);
            }

            #pragma unroll
            for (int i = 0; i < 8; i++) {
                fp4_row[b * 8 + i] =
                    static_cast<uint8_t>((nibbles[2 * i + 1] << 4) | nibbles[2 * i]);
            }
            sf_row[b] = bsf8;
        }
    }
}

/**
 * Host entry point: validates the input, allocates the outputs and launches
 * fused_deinterleave_amax_quantize_kernel on the current CUDA stream. No
 * host-device synchronization is performed.
 *
 * @param fused_bf16    (M, N) BF16 CUDA tensor; made contiguous if it is not
 * @param intermediate  de-interleaved width per row; must be a multiple of 16
 * @param granularity   interleave group width in columns; must be > 0
 * @param divisor       gsa = row_amax / divisor; must be > 0
 * @return (x_fp4, x_sf, gsa):
 *           x_fp4 (M, intermediate/2) viewed as float4_e2m1fn_x2,
 *           x_sf  (M, intermediate/16) viewed as float8_e4m3fn,
 *           gsa   (M,) float32, one global scale per row.
 *
 * NOTE(review): each row is quantized with its own gsa. A caller that feeds
 * only gsa[0] to a GEMM as a single global scale is exact only for M == 1 —
 * confirm against the callers.
 */
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> fused_deinterleave_amax_quantize_cuda(
    torch::Tensor fused_bf16, int64_t intermediate, int64_t granularity, double divisor
) {
    TORCH_CHECK(fused_bf16.is_cuda(), "fused_deinterleave_amax_quantize: input must be a CUDA tensor");
    TORCH_CHECK(fused_bf16.dim() == 2, "fused_deinterleave_amax_quantize: input must be 2-D (M, N)");
    TORCH_CHECK(fused_bf16.scalar_type() == at::kBFloat16,
                "fused_deinterleave_amax_quantize: input must be bfloat16");
    TORCH_CHECK(intermediate >= 0 && intermediate % 16 == 0,
                "fused_deinterleave_amax_quantize: intermediate must be a non-negative multiple of 16, got ",
                intermediate);
    TORCH_CHECK(granularity > 0,
                "fused_deinterleave_amax_quantize: granularity must be positive, got ", granularity);
    TORCH_CHECK(divisor > 0.0,
                "fused_deinterleave_amax_quantize: divisor must be positive, got ", divisor);

    const int64_t M64 = fused_bf16.size(0);
    const int64_t N64 = fused_bf16.size(1);
    const int64_t int_max = static_cast<int64_t>(INT32_MAX);
    TORCH_CHECK(M64 <= int_max && N64 <= int_max && intermediate <= int_max && granularity <= int_max,
                "fused_deinterleave_amax_quantize: dimensions exceed the 32-bit range used by the kernel");

    if (intermediate > 0) {
        // Highest source column the kernel will read (the mapping is monotonic in nd).
        const int64_t last = intermediate - 1;
        const int64_t max_col = (2 * (last / granularity) + 1) * granularity + last % granularity;
        TORCH_CHECK(max_col < N64,
                    "fused_deinterleave_amax_quantize: input has ", N64,
                    " columns but de-interleaving reads column ", max_col);
    }

    // The kernel addresses rows with stride N, so it needs a dense layout.
    auto src = fused_bf16.contiguous();
    auto opts = src.options();

    // Every element of the outputs is written by the kernel, so no zero-fill is needed
    // (except gsa when there is nothing to reduce).
    auto out_fp4 = torch::empty({M64, intermediate / 2}, opts.dtype(torch::kUInt8));
    auto out_sf = torch::empty({M64, intermediate / 16}, opts.dtype(torch::kUInt8));
    auto out_gsa = (intermediate > 0) ? torch::empty({M64}, opts.dtype(torch::kFloat32))
                                      : torch::zeros({M64}, opts.dtype(torch::kFloat32));

    // A zero-sized grid is an invalid launch configuration; nothing to do anyway.
    if (M64 > 0 && intermediate > 0) {
        const int M = static_cast<int>(M64);
        const int N = static_cast<int>(N64);
        dim3 grid(M);
        dim3 block(kFusedDaqThreads);

        fused_deinterleave_amax_quantize_kernel<<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>(
            reinterpret_cast<const __nv_bfloat16*>(src.data_ptr<at::BFloat16>()),
            M, N, static_cast<int>(intermediate), static_cast<int>(granularity),
            static_cast<float>(divisor),
            out_fp4.data_ptr<uint8_t>(), out_sf.data_ptr<uint8_t>(),
            out_gsa.data_ptr<float>()
        );
        // Surfaces launch-configuration errors without synchronizing the stream.
        const cudaError_t launch_err = cudaGetLastError();
        TORCH_CHECK(launch_err == cudaSuccess,
                    "fused_deinterleave_amax_quantize: kernel launch failed: ",
                    cudaGetErrorString(launch_err));
    }

    return {out_fp4.view(torch::kFloat4_e2m1fn_x2), out_sf.view(torch::kFloat8_e4m3fn), out_gsa};
}
|
||||
|
||||
// Python binding: exposes the host wrapper as
// <module>.fused_deinterleave_amax_quantize(fused_bf16, intermediate, granularity, divisor).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fused_deinterleave_amax_quantize", &fused_deinterleave_amax_quantize_cuda,
          "Fused de-interleave + amax + gsa + NVFP4 quantize; returns (x_fp4, x_sf, gsa)");
}
|
||||
@@ -23,13 +23,8 @@ def _get_kernel_module():
|
||||
global _kernel_module
|
||||
if _kernel_module is not None:
|
||||
return _kernel_module
|
||||
kernel_dir = os.path.join(os.path.dirname(__file__), "..", "cuda")
|
||||
_kernel_module = torch.utils.cpp_extension.load(
|
||||
name="indexer_score_topk",
|
||||
sources=[os.path.join(kernel_dir, "indexer_score_topk.cu")],
|
||||
extra_cuda_cflags=["-O3", "--generate-code=arch=compute_100a,code=[sm_100a]"],
|
||||
verbose=False,
|
||||
)
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
_kernel_module = get_cuda_module("indexer_score_topk", ["indexer_score_topk.cu"])
|
||||
return _kernel_module
|
||||
|
||||
|
||||
|
||||
@@ -160,27 +160,24 @@ class Nvfp4Linear:
|
||||
# Ensure buffer is large enough
|
||||
self._ensure_buffer_size(num_tokens)
|
||||
|
||||
# Compute activation global scale at runtime if requested.
|
||||
# This prevents E4M3 block scale overflow when the checkpoint's
|
||||
# input_scale is too small for the actual activation magnitudes.
|
||||
# Fused amax + quantize: single kernel launch, zero CPU-GPU syncs.
|
||||
# Computes amax on GPU → derives gsa → quantizes to NVFP4.
|
||||
# gsa written to GPU buffer for downstream GEMM global_scale_a.
|
||||
#
|
||||
# PERFORMANCE FIX: Compute gsa on GPU, store in a scalar GPU tensor.
|
||||
# The GEMM's global_scale_a is already a GPU tensor (via to_cute()),
|
||||
# so we can pass the GPU scalar directly — zero CPU syncs for the GEMM.
|
||||
# The quantize kernel still needs a Python float (kernel parameter),
|
||||
# requiring one .item() sync per projection. Total: ~10 syncs per layer
|
||||
# instead of ~10 syncs per projection (610 per step → 610 per step saved).
|
||||
# This replaces the two-step path:
|
||||
# compute_amax_gsa_gpu(hidden_states) → .item() sync
|
||||
# quantize_nvfp4_gpu(hidden_states, gsa_float) → another kernel launch
|
||||
#
|
||||
# Old path: ~2 kernel launches + 1 .item() sync per projection.
|
||||
# New path: 1 kernel launch + 0 .item() syncs per projection.
|
||||
# Total across 61 layers: ~486 .item() syncs eliminated.
|
||||
if getattr(self, '_use_runtime_gsa', False):
|
||||
from dsv4.ops.quantize import compute_amax_gsa_gpu
|
||||
gsa_gpu = compute_amax_gsa_gpu(hidden_states) # scalar GPU tensor
|
||||
self._gsa_buf.copy_(gsa_gpu.reshape(1)) # GPU → GPU, no sync
|
||||
gsa_float = gsa_gpu.item() # one sync for quantize kernel param
|
||||
from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
|
||||
x_fp4, x_sf, gsa_gpu = quantize_nvfp4_gpu_fused(hidden_states)
|
||||
self._gsa_buf.copy_(gsa_gpu[:1].reshape(1)) # GPU → GPU, no sync
|
||||
else:
|
||||
self._gsa_buf.fill_(self._activation_global_scale)
|
||||
gsa_float = self._activation_global_scale
|
||||
|
||||
# Quantize activation using GPU-only kernel
|
||||
x_fp4, x_sf = quantize_nvfp4_gpu(hidden_states, gsa_float)
|
||||
x_fp4, x_sf = quantize_nvfp4_gpu(hidden_states, self._activation_global_scale)
|
||||
|
||||
# Scatter x_fp4 into padded buffer
|
||||
padded_x_fp4 = self._padded_x_fp4_buf
|
||||
@@ -194,13 +191,8 @@ class Nvfp4Linear:
|
||||
expert_offsets = self._expert_offsets_buf
|
||||
expert_offsets.fill_(padded_rows)
|
||||
|
||||
# Global scales — use the GPU-computed gsa if available
|
||||
# (already set in run() via compute_amax_gsa_gpu)
|
||||
# For non-runtime-gsa, fill from the stored Python float
|
||||
if not getattr(self, '_use_runtime_gsa', False):
|
||||
gsa = self._gsa_buf.fill_(self._activation_global_scale)
|
||||
else:
|
||||
gsa = self._gsa_buf # already filled by GPU compute
|
||||
# Global scales — GPU-computed gsa already in _gsa_buf (no CPU sync)
|
||||
gsa = self._gsa_buf
|
||||
|
||||
# Run GEMM
|
||||
out = run_nvfp4_grouped_gemm(
|
||||
|
||||
@@ -589,19 +589,17 @@ class Nvfp4MoE:
|
||||
padded_dst = padded_expert_offsets[expert_assign] + local_row
|
||||
|
||||
# === L1: gate + up ===
|
||||
# Compute runtime gsa from actual activation magnitude if requested.
|
||||
# This prevents E4M3 block scale overflow when checkpoint input_scale is too small.
|
||||
# Fused amax + quantize: single kernel, zero CPU-GPU syncs.
|
||||
# Computes amax on GPU → derives gsa → quantizes to NVFP4.
|
||||
# gsa written to GPU buffer for GEMM global_scale_a.
|
||||
if getattr(self, '_use_runtime_gsa', False):
|
||||
from dsv4.ops.quantize import compute_amax_gsa_gpu
|
||||
gsa_l1 = compute_amax_gsa_gpu(slot_hidden)
|
||||
self._l1_activation_global_scale = gsa_l1.item()
|
||||
self._l1_gsa_buf.copy_(gsa_l1.reshape(1))
|
||||
# Quantize slot_hidden using GPU-only kernel (no CPU-GPU sync).
|
||||
# slot_hidden is the sorted tokens (not padded). The GPU kernel
|
||||
# replaces quantize_activation_nvfp4 which uses .amax() (CPU sync).
|
||||
slot_x_fp4, slot_x_sf = quantize_nvfp4_gpu(
|
||||
slot_hidden, self._l1_activation_global_scale
|
||||
)
|
||||
from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
|
||||
slot_x_fp4, slot_x_sf, gsa_l1_gpu = quantize_nvfp4_gpu_fused(slot_hidden)
|
||||
self._l1_gsa_buf.copy_(gsa_l1_gpu[:1].reshape(1)) # GPU → GPU, no sync
|
||||
else:
|
||||
slot_x_fp4, slot_x_sf = quantize_nvfp4_gpu(
|
||||
slot_hidden, self._l1_activation_global_scale
|
||||
)
|
||||
# Scatter x_fp4 into padded layout for the GEMM
|
||||
# Must scatter as uint8 (float4_e2m1fn_x2 doesn't support index_put)
|
||||
padded_x_fp4 = self._shared_bufs['hidden_fp4']
|
||||
@@ -613,7 +611,7 @@ class Nvfp4MoE:
|
||||
padded_expert_offsets,
|
||||
self._padded_x_sf_buf_l1, self._per_expert_scale_bufs_l1
|
||||
)
|
||||
l1_gsa = self._l1_gsa_buf.fill_(self._l1_activation_global_scale)
|
||||
l1_gsa = self._l1_gsa_buf # already filled by GPU compute (no .fill_ needed)
|
||||
|
||||
if self._fused_swiglu:
|
||||
# === Fused L1 GEMM + SwiGLU in kernel registers ===
|
||||
@@ -625,19 +623,18 @@ class Nvfp4MoE:
|
||||
swiglu_limit=self._swiglu_limit if self._swiglu_limit is not None else 0.0,
|
||||
)
|
||||
l1_out_real = l1_out[padded_dst]
|
||||
# Compute runtime gsa for L2 from the activated output
|
||||
# Fused deinterleave + amax + quantize: zero CPU syncs.
|
||||
# Computes gsa from de-interleaved SwiGLU output on GPU,
|
||||
# quantizes in the same kernel. Writes gsa to GPU buffer.
|
||||
if getattr(self, '_use_runtime_gsa', False):
|
||||
from dsv4.ops.quantize import compute_amax_gsa_gpu
|
||||
gsa_l2 = compute_amax_gsa_gpu(l1_out_real)
|
||||
self._l2_activation_global_scale = gsa_l2.item()
|
||||
self._l2_gsa_buf.copy_(gsa_l2.reshape(1))
|
||||
# De-interleave + quantize to FP4 in one GPU kernel.
|
||||
# l1_out_real has interleaved [silu(gate)*8, swiglu*8, ...].
|
||||
# The CUDA kernel extracts odd 8-col groups (SwiGLU result)
|
||||
# and quantizes to NVFP4. No CPU sync, no Python deinterleave.
|
||||
slot_l2_x_fp4, slot_l2_x_sf = deinterleave_quantize_nvfp4_cuda(
|
||||
l1_out_real, self.intermediate_size, self._l2_activation_global_scale
|
||||
)
|
||||
from dsv4.ops.quantize import deinterleave_amax_quantize_nvfp4_fused
|
||||
slot_l2_x_fp4, slot_l2_x_sf, gsa_l2_gpu = deinterleave_amax_quantize_nvfp4_fused(
|
||||
l1_out_real, self.intermediate_size)
|
||||
self._l2_gsa_buf.copy_(gsa_l2_gpu[:1].reshape(1)) # GPU → GPU, no sync
|
||||
else:
|
||||
slot_l2_x_fp4, slot_l2_x_sf = deinterleave_quantize_nvfp4_cuda(
|
||||
l1_out_real, self.intermediate_size, self._l2_activation_global_scale
|
||||
)
|
||||
else:
|
||||
# === Non-fused L1 GEMM + PyTorch SiLU(gate)*up ===
|
||||
l1_out = run_nvfp4_grouped_gemm(
|
||||
@@ -657,15 +654,12 @@ class Nvfp4MoE:
|
||||
activated = gate_silu * up
|
||||
|
||||
# Compute runtime gsa for L2 from activated output (non-fused path)
|
||||
# Fused amax + quantize: zero CPU syncs.
|
||||
if not self._fused_swiglu and getattr(self, '_use_runtime_gsa', False):
|
||||
from dsv4.ops.quantize import compute_amax_gsa_gpu
|
||||
gsa_l2 = compute_amax_gsa_gpu(activated)
|
||||
self._l2_activation_global_scale = gsa_l2.item()
|
||||
self._l2_gsa_buf.copy_(gsa_l2.reshape(1))
|
||||
# === L2: down ===
|
||||
# Quantize activated (per-token) using GPU-only kernel, scatter into padded FP4 buffer.
|
||||
# For fused_swiglu path, slot_l2_x_fp4/sf already set by deinterleave_quantize_nvfp4_cuda.
|
||||
if not self._fused_swiglu:
|
||||
from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
|
||||
slot_l2_x_fp4, slot_l2_x_sf, gsa_l2_gpu = quantize_nvfp4_gpu_fused(activated)
|
||||
self._l2_gsa_buf.copy_(gsa_l2_gpu[:1].reshape(1)) # GPU → GPU, no sync
|
||||
elif not self._fused_swiglu:
|
||||
slot_l2_x_fp4, slot_l2_x_sf = quantize_nvfp4_gpu(
|
||||
activated, self._l2_activation_global_scale
|
||||
)
|
||||
@@ -678,7 +672,7 @@ class Nvfp4MoE:
|
||||
padded_expert_offsets,
|
||||
self._padded_x_sf_buf_l2, self._per_expert_scale_bufs_l2
|
||||
)
|
||||
l2_gsa = self._l2_gsa_buf.fill_(self._l2_activation_global_scale)
|
||||
l2_gsa = self._l2_gsa_buf # already filled by GPU compute (no .fill_ needed)
|
||||
|
||||
l2_out = run_nvfp4_grouped_gemm(
|
||||
mat_a=padded_activated_fp4, mat_b=self._l2_mat_b,
|
||||
|
||||
@@ -235,15 +235,15 @@ class Nvfp4SharedExpert:
|
||||
num_tokens = hidden_states.shape[0]
|
||||
padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128
|
||||
|
||||
# Quantize activation
|
||||
# Fused amax + quantize: zero CPU syncs.
|
||||
if getattr(self, '_use_runtime_gsa', False):
|
||||
from dsv4.ops.quantize import compute_amax_gsa_gpu
|
||||
gsa_l1 = compute_amax_gsa_gpu(hidden_states)
|
||||
self._l1_activation_global_scale = gsa_l1.item()
|
||||
self._l1_gsa_buf.copy_(gsa_l1.reshape(1))
|
||||
x_fp4, x_sf = quantize_activation_nvfp4(
|
||||
hidden_states, self._l1_activation_global_scale
|
||||
)
|
||||
from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
|
||||
x_fp4, x_sf, gsa_l1_gpu = quantize_nvfp4_gpu_fused(hidden_states)
|
||||
self._l1_gsa_buf.copy_(gsa_l1_gpu[:1].reshape(1)) # GPU → GPU, no sync
|
||||
else:
|
||||
x_fp4, x_sf = quantize_activation_nvfp4(
|
||||
hidden_states, self._l1_activation_global_scale
|
||||
)
|
||||
|
||||
# Scatter x_fp4 into padded buffer
|
||||
padded_x_fp4 = self._padded_x_fp4_buf_l1
|
||||
@@ -257,8 +257,8 @@ class Nvfp4SharedExpert:
|
||||
expert_offsets = self._expert_offsets_buf
|
||||
expert_offsets.fill_(padded_rows)
|
||||
|
||||
# Global scales
|
||||
gsa = self._l1_gsa_buf.fill_(self._l1_activation_global_scale)
|
||||
# Global scales — GPU-computed gsa already in _l1_gsa_buf (no CPU sync)
|
||||
gsa = self._l1_gsa_buf
|
||||
|
||||
# Run GEMM
|
||||
out = run_nvfp4_grouped_gemm(
|
||||
@@ -279,15 +279,15 @@ class Nvfp4SharedExpert:
|
||||
num_tokens = intermediate.shape[0]
|
||||
padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128
|
||||
|
||||
# Quantize activation
|
||||
# Fused amax + quantize: zero CPU syncs.
|
||||
if getattr(self, '_use_runtime_gsa', False):
|
||||
from dsv4.ops.quantize import compute_amax_gsa_gpu
|
||||
gsa_l2 = compute_amax_gsa_gpu(intermediate)
|
||||
self._l2_activation_global_scale = gsa_l2.item()
|
||||
self._l2_gsa_buf.copy_(gsa_l2.reshape(1))
|
||||
x_fp4, x_sf = quantize_activation_nvfp4(
|
||||
intermediate, self._l2_activation_global_scale
|
||||
)
|
||||
from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
|
||||
x_fp4, x_sf, gsa_l2_gpu = quantize_nvfp4_gpu_fused(intermediate)
|
||||
self._l2_gsa_buf.copy_(gsa_l2_gpu[:1].reshape(1)) # GPU → GPU, no sync
|
||||
else:
|
||||
x_fp4, x_sf = quantize_activation_nvfp4(
|
||||
intermediate, self._l2_activation_global_scale
|
||||
)
|
||||
|
||||
# Scatter into padded buffer
|
||||
padded_x_fp4 = self._padded_x_fp4_buf_l2
|
||||
@@ -301,8 +301,8 @@ class Nvfp4SharedExpert:
|
||||
expert_offsets = self._expert_offsets_buf
|
||||
expert_offsets.fill_(padded_rows)
|
||||
|
||||
# Global scales
|
||||
gsa = self._l2_gsa_buf.fill_(self._l2_activation_global_scale)
|
||||
# Global scales — GPU-computed gsa already in _l2_gsa_buf (no CPU sync)
|
||||
gsa = self._l2_gsa_buf
|
||||
|
||||
# Run GEMM
|
||||
out = run_nvfp4_grouped_gemm(
|
||||
|
||||
@@ -21,14 +21,8 @@ def _get_kernel():
|
||||
global _kernel
|
||||
if _kernel is not None:
|
||||
return _kernel
|
||||
from torch.utils.cpp_extension import load
|
||||
kdir = os.path.join(os.path.dirname(__file__), "..", "kernels", "cuda")
|
||||
_kernel = load(
|
||||
name="dsv4_sampler",
|
||||
sources=[os.path.join(kdir, "sampler.cu")],
|
||||
extra_cuda_cflags=["-O3", "--generate-code=arch=compute_100a,code=[sm_100a]"],
|
||||
verbose=False,
|
||||
)
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
_kernel = get_cuda_module("sampler", ["sampler.cu"])
|
||||
return _kernel
|
||||
|
||||
|
||||
|
||||
@@ -242,38 +242,44 @@ def deinterleave_quantize_nvfp4_cuda(fused_bf16, intermediate, global_scale, gra
|
||||
x_fp4: (M, intermediate//2) float4_e2m1fn_x2 — quantized SwiGLU
|
||||
x_sf: (M, intermediate//16) float8_e4m3fn — block scales
|
||||
"""
|
||||
from torch.utils.cpp_extension import load
|
||||
import os
|
||||
# dsv4/ops/quantize.py → dsv4/kernels/cuda/
|
||||
kernel_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "kernels", "cuda")
|
||||
mod = load(
|
||||
name="deinterleave_quantize_nvfp4",
|
||||
sources=[os.path.join(kernel_dir, "deinterleave_quantize.cu")],
|
||||
extra_cuda_cflags=["-gencode=arch=compute_100a,code=sm_100a"],
|
||||
verbose=False,
|
||||
)
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
mod = get_cuda_module("deinterleave_quantize_nvfp4", ["deinterleave_quantize.cu"])
|
||||
return mod.deinterleave_quantize_nvfp4(fused_bf16, intermediate, granularity, global_scale)
|
||||
|
||||
|
||||
def deinterleave_amax_quantize_nvfp4_fused(fused_bf16, intermediate, divisor=6.0 * 448.0, granularity=8):
|
||||
"""Fused deinterleave + amax + gsa + quantize: NO CPU sync, single kernel launch.
|
||||
|
||||
For the MoE fused_swiglu L2 path. Computes gsa from the de-interleaved
|
||||
(SwiGLU) values on GPU, quantizes in the same kernel. Zero .item() syncs.
|
||||
|
||||
Args:
|
||||
fused_bf16: (M, 2*intermediate) BF16 — fused L1 output
|
||||
intermediate: intermediate dimension
|
||||
divisor: gsa = amax / divisor. Default 2688.0.
|
||||
granularity: interleave granularity (default 8)
|
||||
|
||||
Returns:
|
||||
x_fp4: (M, intermediate//2) float4_e2m1fn_x2
|
||||
x_sf: (M, intermediate//16) float8_e4m3fn
|
||||
gsa: (M,) float32 GPU tensor — per-row global scale for L2 GEMM
|
||||
"""
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
mod = get_cuda_module("fused_deinterleave_amax_quantize", ["fused_deinterleave_amax_quantize.cu"])
|
||||
return mod.fused_deinterleave_amax_quantize(fused_bf16, intermediate, granularity, divisor)
|
||||
|
||||
|
||||
def compute_amax_gsa_gpu(x_bf16, divisor=6.0 * 448.0):
|
||||
"""Compute gsa = max(|x|) / divisor on GPU. No CPU sync.
|
||||
|
||||
Returns a scalar GPU tensor (not a Python float!).
|
||||
The caller can pass this directly to quantize_nvfp4_gpu()
|
||||
or to CuTeDSL GEMM's global_scale_a buffer via .fill_().
|
||||
|
||||
This eliminates ~915 CPU-GPU syncs per decode step
|
||||
(610 from Nvfp4Linear + 183 from Nvfp4MoE + 122 from SharedExpert).
|
||||
NOTE: Prefer quantize_nvfp4_gpu_fused() which does amax+quantize in
|
||||
one kernel launch. This function is kept for cases where you need gsa
|
||||
without quantization.
|
||||
"""
|
||||
from torch.utils.cpp_extension import load
|
||||
import os
|
||||
kernel_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "kernels", "cuda")
|
||||
mod = load(
|
||||
name="amax_gsa",
|
||||
sources=[os.path.join(kernel_dir, "amax_gsa.cu")],
|
||||
extra_cuda_cflags=["-gencode=arch=compute_100a,code=sm_100a"],
|
||||
verbose=False,
|
||||
)
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
mod = get_cuda_module("amax_gsa", ["amax_gsa.cu"])
|
||||
return mod.compute_amax_gsa(x_bf16, divisor)
|
||||
|
||||
|
||||
@@ -288,8 +294,6 @@ def quantize_nvfp4_gpu_fused(x_bf16, divisor=6.0 * 448.0):
|
||||
This fused kernel computes amax on GPU, derives gsa, and quantizes
|
||||
in a single kernel launch. Zero CPU-GPU syncs.
|
||||
|
||||
For decode (M=1, N=7168): ~5μs vs ~15μs (separate amax + quantize + sync).
|
||||
|
||||
Args:
|
||||
x_bf16: (M, N) BF16 tensor. N must be a multiple of 16.
|
||||
divisor: gsa = amax / divisor. Default 6.0 * 448.0 = 2688.0.
|
||||
@@ -297,16 +301,10 @@ def quantize_nvfp4_gpu_fused(x_bf16, divisor=6.0 * 448.0):
|
||||
Returns:
|
||||
x_fp4: (M, N//2) float4_e2m1fn_x2
|
||||
x_sf: (M, N//16) float8_e4m3fn
|
||||
gsa: (M,) float32 GPU tensor — per-row global scale for GEMM
|
||||
"""
|
||||
from torch.utils.cpp_extension import load
|
||||
import os
|
||||
kernel_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "kernels", "cuda")
|
||||
mod = load(
|
||||
name="fused_amax_quantize",
|
||||
sources=[os.path.join(kernel_dir, "fused_amax_quantize.cu")],
|
||||
extra_cuda_cflags=["-gencode=arch=compute_100a,code=sm_100a"],
|
||||
verbose=False,
|
||||
)
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
mod = get_cuda_module("fused_amax_quantize", ["fused_amax_quantize.cu"])
|
||||
return mod.fused_amax_quantize_nvfp4(x_bf16, divisor)
|
||||
|
||||
|
||||
@@ -316,6 +314,9 @@ def quantize_nvfp4_gpu(x_bf16, global_scale):
|
||||
Replaces quantize_activation_nvfp4() which uses .amax() (CPU sync).
|
||||
The global_scale must be pre-computed (from warmup or known value).
|
||||
|
||||
NOTE: Prefer quantize_nvfp4_gpu_fused() which also computes gsa on GPU.
|
||||
This function is kept for cases where global_scale is already known.
|
||||
|
||||
Args:
|
||||
x_bf16: (M, N) BF16 tensor. N must be a multiple of 16.
|
||||
global_scale: float32 scalar (pre-computed, NOT from .max())
|
||||
@@ -324,14 +325,6 @@ def quantize_nvfp4_gpu(x_bf16, global_scale):
|
||||
x_fp4: (M, N//2) float4_e2m1fn_x2
|
||||
x_sf: (M, N//16) float8_e4m3fn
|
||||
"""
|
||||
from torch.utils.cpp_extension import load
|
||||
import os
|
||||
# dsv4/ops/quantize.py → dsv4/kernels/cuda/
|
||||
kernel_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "kernels", "cuda")
|
||||
mod = load(
|
||||
name="quantize_nvfp4",
|
||||
sources=[os.path.join(kernel_dir, "quantize_nvfp4.cu")],
|
||||
extra_cuda_cflags=["-gencode=arch=compute_100a,code=sm_100a"],
|
||||
verbose=False,
|
||||
)
|
||||
from dsv4.kernels.cuda.loader import get_cuda_module
|
||||
mod = get_cuda_module("quantize_nvfp4", ["quantize_nvfp4.cu"])
|
||||
return mod.quantize_nvfp4(x_bf16, global_scale)
|
||||
|
||||
@@ -212,9 +212,10 @@ class Compressor:
|
||||
kv, gate, self.ape, self.kv_norm_w, m=r)
|
||||
|
||||
if compressed.shape[0] == 0: return None, None, None
|
||||
comp_pos = torch.tensor([positions[(bi+1)*r - 1].item() if positions.numel() > (bi+1)*r - 1 else 0
|
||||
for bi in range(n_complete)],
|
||||
dtype=torch.long, device=dev)
|
||||
# Vectorized position computation — no Python loop, no .item()
|
||||
bi = torch.arange(n_complete, device=dev)
|
||||
pos_idx = ((bi + 1) * r - 1).clamp(max=positions.numel() - 1)
|
||||
comp_pos = positions[pos_idx]
|
||||
return compressed, comp_pos, torch.zeros(1, T, n_complete, dtype=torch.float32, device=dev)
|
||||
|
||||
# =====================================================================
|
||||
@@ -264,26 +265,50 @@ class Indexer:
|
||||
# KV Cache
|
||||
# =====================================================================
|
||||
class KVCache:
|
||||
def __init__(self, head_dim, window_size=128, device='cuda:0'):
|
||||
def __init__(self, head_dim, window_size=128, max_comp=32768, device='cuda:0'):
|
||||
self.hd, self.ws, self.dev = head_dim, window_size, device
|
||||
self.swa = torch.zeros(window_size, head_dim, dtype=torch.bfloat16, device=device)
|
||||
self.swa_pos = torch.zeros(window_size, dtype=torch.long, device=device)
|
||||
self.swa_len, self.swa_head = 0, 0
|
||||
self.comp_kv, self.comp_pos, self.n_comp = None, None, 0; self.comp_idx_kv = None
|
||||
# P3: Pre-allocate compressed KV buffers (no more torch.cat / O(N²) growth)
|
||||
self.comp_kv_buf = torch.zeros(max_comp, head_dim, dtype=torch.bfloat16, device=device)
|
||||
self.comp_pos_buf = torch.zeros(max_comp, dtype=torch.long, device=device)
|
||||
self.comp_idx_buf = torch.zeros(max_comp, head_dim, dtype=torch.bfloat16, device=device)
|
||||
self.n_comp = 0
|
||||
self._has_idx = False
|
||||
|
||||
def append_swa(self, kv, pos):
|
||||
"""P2: Vectorized SWA append — 2 kernel launches instead of 2T."""
|
||||
T = kv.shape[0]
|
||||
for i in range(T):
|
||||
idx = (self.swa_head + i) % self.ws; self.swa[idx], self.swa_pos[idx] = kv[i], pos[i]
|
||||
self.swa_head = (self.swa_head + T) % self.ws; self.swa_len = min(self.swa_len + T, self.ws)
|
||||
idx = (self.swa_head + torch.arange(T, device=self.dev)) % self.ws
|
||||
self.swa.index_copy_(0, idx, kv)
|
||||
self.swa_pos.index_copy_(0, idx, pos)
|
||||
self.swa_head = (self.swa_head + T) % self.ws
|
||||
self.swa_len = min(self.swa_len + T, self.ws)
|
||||
|
||||
def add_compressed(self, ckv, cpos, idx_kv=None):
|
||||
"""P3: Pre-allocated buffer — O(1) instead of O(N) per call."""
|
||||
if ckv is None: return
|
||||
self.comp_kv = ckv if self.comp_kv is None else torch.cat([self.comp_kv, ckv])
|
||||
self.comp_pos = cpos if self.comp_pos is None else torch.cat([self.comp_pos, cpos])
|
||||
self.n_comp = self.comp_kv.shape[0]
|
||||
T = ckv.shape[0]
|
||||
end = self.n_comp + T
|
||||
self.comp_kv_buf[self.n_comp:end] = ckv
|
||||
self.comp_pos_buf[self.n_comp:end] = cpos
|
||||
if idx_kv is not None:
|
||||
self.comp_idx_kv = idx_kv if self.comp_idx_kv is None else torch.cat([self.comp_idx_kv, idx_kv])
|
||||
self.comp_idx_buf[self.n_comp:end] = idx_kv
|
||||
self._has_idx = True
|
||||
self.n_comp = end
|
||||
|
||||
@property
|
||||
def comp_kv(self):
|
||||
return self.comp_kv_buf[:self.n_comp] if self.n_comp > 0 else None
|
||||
|
||||
@property
|
||||
def comp_pos(self):
|
||||
return self.comp_pos_buf[:self.n_comp] if self.n_comp > 0 else None
|
||||
|
||||
@property
|
||||
def comp_idx_kv(self):
|
||||
return self.comp_idx_buf[:self.n_comp] if self._has_idx and self.n_comp > 0 else None
|
||||
|
||||
def get_swa(self):
|
||||
if self.swa_len == 0:
|
||||
|
||||
Reference in New Issue
Block a user