diff --git a/dsv4/kernels/cuda/fused_deinterleave_amax_quantize.cu b/dsv4/kernels/cuda/fused_deinterleave_amax_quantize.cu new file mode 100644 index 00000000..a73e1984 --- /dev/null +++ b/dsv4/kernels/cuda/fused_deinterleave_amax_quantize.cu @@ -0,0 +1,151 @@ +/** + * Fused deinterleave + amax + gsa + NVFP4 quantize kernel. + * + * Single kernel launch that: + * 1. De-interleaves fused L1 SwiGLU output (extracts odd groups) + * 2. Computes row-wise amax of the de-interleaved values (GPU-only) + * 3. Derives gsa = max(amax) / divisor + * 4. Quantizes to NVFP4 (FP4 data + FP8 E4M3 block scales) + * 5. Writes gsa to a GPU buffer for downstream L2 GEMM global_scale_a + * + * This replaces the two-step path in Nvfp4MoE's fused_swiglu path: + * compute_amax_gsa_gpu(l1_out_real) → .item() sync + * deinterleave_quantize_nvfp4_cuda(l1_out_real, ..., gsa) → separate kernel + * + * Now: zero CPU-GPU syncs. gsa stays on GPU. Single kernel launch. + * + * Grid: (intermediate / 16, M, 1) — each CTA processes one 16-element block. + * Shared memory: n_blocks * sizeof(float) for cross-CTA amax reduction. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +__device__ __forceinline__ int half_step_to_e2m1(int hs) { + if (hs <= 4) return hs; + if (hs <= 5) return 4; + if (hs <= 7) return 5; + if (hs <= 10) return 6; + return 7; +} + +__global__ void fused_deinterleave_amax_quantize_kernel( + const __nv_bfloat16* __restrict__ fused, + int M, int N, int intermediate, int granularity, + float divisor, + uint8_t* __restrict__ out_fp4, + uint8_t* __restrict__ out_sf, + float* __restrict__ out_gsa // (M,) GPU buffer — gsa per row +) { + int m = blockIdx.y; + int n_block = blockIdx.x; + int n_blocks = gridDim.x; + if (m >= M || n_block * 16 >= intermediate) return; + + extern __shared__ float s_amax[]; + + // Step 1: De-interleave and compute local amax + float vals[16]; + float block_amax = 0.0f; + + for (int i = 0; i < 16; i++) { + int nd = n_block * 16 + i; + if (nd >= intermediate) { vals[i] = 0; continue; } + // Map de-interleaved position to fused position + int group = 2 * (nd / granularity) + 1; // odd group = SwiGLU + int offset = nd % granularity; + int fc = group * granularity + offset; + vals[i] = __bfloat162float(fused[m * N + fc]); + block_amax = fmaxf(block_amax, fabsf(vals[i])); + } + + // Step 2: Cross-CTA reduction to get row-wide amax + if (n_block < n_blocks) { + s_amax[n_block] = block_amax; + } + __syncthreads(); + + float gsa; + if (n_block == 0) { + float row_amax = 0.0f; + for (int b = 0; b < n_blocks; b++) { + row_amax = fmaxf(row_amax, s_amax[b]); + } + gsa = fmaxf(row_amax, 1e-8f) / divisor; + out_gsa[m] = gsa; + } + if (n_block == 0) { + s_amax[0] = gsa; + } + __syncthreads(); + gsa = s_amax[0]; + + // Step 3: Quantize — divide by gsa, compute FP8 block scale, quantize to FP4 + for (int i = 0; i < 16; i++) { + vals[i] = vals[i] / gsa; + } + + float q_amax = 0.0f; + for (int i = 0; i < 16; i++) { + q_amax = fmaxf(q_amax, fabsf(vals[i])); + } + + float bsf = q_amax / 6.0f; + if (q_amax < 6.0f * 0.001953125f) { + bsf = 0; + for (int i = 0; i < 16; i++) vals[i] = 0; + } + __nv_fp8_e4m3 bsf8_obj(bsf); + float bs = (float)bsf8_obj; + uint8_t bsf8 = *(uint8_t*)&bsf8_obj; + + uint8_t nibbles[16]; + for (int i = 0; i < 16; i++) { + if (bs < 1e-8f) { nibbles[i] = 0; continue; } + float s = vals[i] / bs; + int hs = __float2int_rn(fminf(fabsf(s), 6.0f) * 2.0f); + if (hs > 12) hs = 12; + int idx = half_step_to_e2m1(hs); + if (s < 0) idx += 8; + nibbles[i] = idx; + } + + for (int i = 0; i < 8; i++) + out_fp4[m * (intermediate / 2) + n_block * 8 + i] = (nibbles[2*i+1] << 4) | nibbles[2*i]; + + out_sf[m * (intermediate / 16) + n_block] = bsf8; +} + +std::tuple fused_deinterleave_amax_quantize_cuda( + torch::Tensor fused_bf16, int64_t intermediate, int64_t granularity, double divisor +) { + int M = fused_bf16.size(0); + int N = fused_bf16.size(1); + auto opts = fused_bf16.options(); + auto out_fp4 = torch::zeros({M, (int)intermediate / 2}, opts.dtype(torch::kUInt8)); + auto out_sf = torch::zeros({M, (int)intermediate / 16}, opts.dtype(torch::kUInt8)); + auto out_gsa = torch::zeros({M}, opts.dtype(torch::kFloat32)); + + int nb = (int)intermediate / 16; + dim3 grid(nb, M); + dim3 block(16); + int smem_size = nb * sizeof(float); + + fused_deinterleave_amax_quantize_kernel<<>>( + reinterpret_cast(fused_bf16.data_ptr()), + M, N, (int)intermediate, (int)granularity, (float)divisor, + out_fp4.data_ptr(), out_sf.data_ptr(), + out_gsa.data_ptr() + ); + return {out_fp4.view(torch::kFloat4_e2m1fn_x2), out_sf.view(torch::kFloat8_e4m3fn), out_gsa}; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fused_deinterleave_amax_quantize", &fused_deinterleave_amax_quantize_cuda); +} diff --git a/dsv4/kernels/indexer/score_topk.py b/dsv4/kernels/indexer/score_topk.py index be0f7493..9b0a8227 100644 --- a/dsv4/kernels/indexer/score_topk.py +++ b/dsv4/kernels/indexer/score_topk.py @@ -23,13 +23,8 @@ def _get_kernel_module(): global _kernel_module if _kernel_module is not None: return _kernel_module - kernel_dir = os.path.join(os.path.dirname(__file__), "..", "cuda") - _kernel_module = torch.utils.cpp_extension.load( - name="indexer_score_topk", - sources=[os.path.join(kernel_dir, "indexer_score_topk.cu")], - extra_cuda_cflags=["-O3", "--generate-code=arch=compute_100a,code=[sm_100a]"], - verbose=False, - ) + from dsv4.kernels.cuda.loader import get_cuda_module + _kernel_module = get_cuda_module("indexer_score_topk", ["indexer_score_topk.cu"]) return _kernel_module diff --git a/dsv4/layers/linear.py b/dsv4/layers/linear.py index 0a23ee01..c67cadad 100644 --- a/dsv4/layers/linear.py +++ b/dsv4/layers/linear.py @@ -160,27 +160,24 @@ class Nvfp4Linear: # Ensure buffer is large enough self._ensure_buffer_size(num_tokens) - # Compute activation global scale at runtime if requested. - # This prevents E4M3 block scale overflow when the checkpoint's - # input_scale is too small for the actual activation magnitudes. + # Fused amax + quantize: single kernel launch, zero CPU-GPU syncs. + # Computes amax on GPU → derives gsa → quantizes to NVFP4. + # gsa written to GPU buffer for downstream GEMM global_scale_a. # - # PERFORMANCE FIX: Compute gsa on GPU, store in a scalar GPU tensor. - # The GEMM's global_scale_a is already a GPU tensor (via to_cute()), - # so we can pass the GPU scalar directly — zero CPU syncs for the GEMM. - # The quantize kernel still needs a Python float (kernel parameter), - # requiring one .item() sync per projection. Total: ~10 syncs per layer - # instead of ~10 syncs per projection (610 per step → 610 per step saved). + # This replaces the two-step path: + # compute_amax_gsa_gpu(hidden_states) → .item() sync + # quantize_nvfp4_gpu(hidden_states, gsa_float) → another kernel launch + # + # Old path: ~2 kernel launches + 1 .item() sync per projection. + # New path: 1 kernel launch + 0 .item() syncs per projection. + # Total across 61 layers: ~486 .item() syncs eliminated. if getattr(self, '_use_runtime_gsa', False): - from dsv4.ops.quantize import compute_amax_gsa_gpu - gsa_gpu = compute_amax_gsa_gpu(hidden_states) # scalar GPU tensor - self._gsa_buf.copy_(gsa_gpu.reshape(1)) # GPU → GPU, no sync - gsa_float = gsa_gpu.item() # one sync for quantize kernel param + from dsv4.ops.quantize import quantize_nvfp4_gpu_fused + x_fp4, x_sf, gsa_gpu = quantize_nvfp4_gpu_fused(hidden_states) + self._gsa_buf.copy_(gsa_gpu[:1].reshape(1)) # GPU → GPU, no sync else: self._gsa_buf.fill_(self._activation_global_scale) - gsa_float = self._activation_global_scale - - # Quantize activation using GPU-only kernel - x_fp4, x_sf = quantize_nvfp4_gpu(hidden_states, gsa_float) + x_fp4, x_sf = quantize_nvfp4_gpu(hidden_states, self._activation_global_scale) # Scatter x_fp4 into padded buffer padded_x_fp4 = self._padded_x_fp4_buf @@ -194,13 +191,8 @@ class Nvfp4Linear: expert_offsets = self._expert_offsets_buf expert_offsets.fill_(padded_rows) - # Global scales — use the GPU-computed gsa if available - # (already set in run() via compute_amax_gsa_gpu) - # For non-runtime-gsa, fill from the stored Python float - if not getattr(self, '_use_runtime_gsa', False): - gsa = self._gsa_buf.fill_(self._activation_global_scale) - else: - gsa = self._gsa_buf # already filled by GPU compute + # Global scales — GPU-computed gsa already in _gsa_buf (no CPU sync) + gsa = self._gsa_buf # Run GEMM out = run_nvfp4_grouped_gemm( diff --git a/dsv4/layers/moe.py b/dsv4/layers/moe.py index 503ae56f..82b6ac58 100644 --- a/dsv4/layers/moe.py +++ b/dsv4/layers/moe.py @@ -589,19 +589,17 @@ class Nvfp4MoE: padded_dst = padded_expert_offsets[expert_assign] + local_row # === L1: gate + up === - # Compute runtime gsa from actual activation magnitude if requested. - # This prevents E4M3 block scale overflow when checkpoint input_scale is too small. + # Fused amax + quantize: single kernel, zero CPU-GPU syncs. + # Computes amax on GPU → derives gsa → quantizes to NVFP4. + # gsa written to GPU buffer for GEMM global_scale_a. if getattr(self, '_use_runtime_gsa', False): - from dsv4.ops.quantize import compute_amax_gsa_gpu - gsa_l1 = compute_amax_gsa_gpu(slot_hidden) - self._l1_activation_global_scale = gsa_l1.item() - self._l1_gsa_buf.copy_(gsa_l1.reshape(1)) - # Quantize slot_hidden using GPU-only kernel (no CPU-GPU sync). - # slot_hidden is the sorted tokens (not padded). The GPU kernel - # replaces quantize_activation_nvfp4 which uses .amax() (CPU sync). - slot_x_fp4, slot_x_sf = quantize_nvfp4_gpu( - slot_hidden, self._l1_activation_global_scale - ) + from dsv4.ops.quantize import quantize_nvfp4_gpu_fused + slot_x_fp4, slot_x_sf, gsa_l1_gpu = quantize_nvfp4_gpu_fused(slot_hidden) + self._l1_gsa_buf.copy_(gsa_l1_gpu[:1].reshape(1)) # GPU → GPU, no sync + else: + slot_x_fp4, slot_x_sf = quantize_nvfp4_gpu( + slot_hidden, self._l1_activation_global_scale + ) # Scatter x_fp4 into padded layout for the GEMM # Must scatter as uint8 (float4_e2m1fn_x2 doesn't support index_put) padded_x_fp4 = self._shared_bufs['hidden_fp4'] @@ -613,7 +611,7 @@ class Nvfp4MoE: padded_expert_offsets, self._padded_x_sf_buf_l1, self._per_expert_scale_bufs_l1 ) - l1_gsa = self._l1_gsa_buf.fill_(self._l1_activation_global_scale) + l1_gsa = self._l1_gsa_buf # already filled by GPU compute (no .fill_ needed) if self._fused_swiglu: # === Fused L1 GEMM + SwiGLU in kernel registers === @@ -625,19 +623,18 @@ class Nvfp4MoE: swiglu_limit=self._swiglu_limit if self._swiglu_limit is not None else 0.0, ) l1_out_real = l1_out[padded_dst] - # Compute runtime gsa for L2 from the activated output + # Fused deinterleave + amax + quantize: zero CPU syncs. + # Computes gsa from de-interleaved SwiGLU output on GPU, + # quantizes in the same kernel. Writes gsa to GPU buffer. if getattr(self, '_use_runtime_gsa', False): - from dsv4.ops.quantize import compute_amax_gsa_gpu - gsa_l2 = compute_amax_gsa_gpu(l1_out_real) - self._l2_activation_global_scale = gsa_l2.item() - self._l2_gsa_buf.copy_(gsa_l2.reshape(1)) - # De-interleave + quantize to FP4 in one GPU kernel. - # l1_out_real has interleaved [silu(gate)*8, swiglu*8, ...]. - # The CUDA kernel extracts odd 8-col groups (SwiGLU result) - # and quantizes to NVFP4. No CPU sync, no Python deinterleave. - slot_l2_x_fp4, slot_l2_x_sf = deinterleave_quantize_nvfp4_cuda( - l1_out_real, self.intermediate_size, self._l2_activation_global_scale - ) + from dsv4.ops.quantize import deinterleave_amax_quantize_nvfp4_fused + slot_l2_x_fp4, slot_l2_x_sf, gsa_l2_gpu = deinterleave_amax_quantize_nvfp4_fused( + l1_out_real, self.intermediate_size) + self._l2_gsa_buf.copy_(gsa_l2_gpu[:1].reshape(1)) # GPU → GPU, no sync + else: + slot_l2_x_fp4, slot_l2_x_sf = deinterleave_quantize_nvfp4_cuda( + l1_out_real, self.intermediate_size, self._l2_activation_global_scale + ) else: # === Non-fused L1 GEMM + PyTorch SiLU(gate)*up === l1_out = run_nvfp4_grouped_gemm( @@ -657,15 +654,12 @@ class Nvfp4MoE: activated = gate_silu * up # Compute runtime gsa for L2 from activated output (non-fused path) + # Fused amax + quantize: zero CPU syncs. if not self._fused_swiglu and getattr(self, '_use_runtime_gsa', False): - from dsv4.ops.quantize import compute_amax_gsa_gpu - gsa_l2 = compute_amax_gsa_gpu(activated) - self._l2_activation_global_scale = gsa_l2.item() - self._l2_gsa_buf.copy_(gsa_l2.reshape(1)) - # === L2: down === - # Quantize activated (per-token) using GPU-only kernel, scatter into padded FP4 buffer. - # For fused_swiglu path, slot_l2_x_fp4/sf already set by deinterleave_quantize_nvfp4_cuda. - if not self._fused_swiglu: + from dsv4.ops.quantize import quantize_nvfp4_gpu_fused + slot_l2_x_fp4, slot_l2_x_sf, gsa_l2_gpu = quantize_nvfp4_gpu_fused(activated) + self._l2_gsa_buf.copy_(gsa_l2_gpu[:1].reshape(1)) # GPU → GPU, no sync + elif not self._fused_swiglu: slot_l2_x_fp4, slot_l2_x_sf = quantize_nvfp4_gpu( activated, self._l2_activation_global_scale ) @@ -678,7 +672,7 @@ class Nvfp4MoE: padded_expert_offsets, self._padded_x_sf_buf_l2, self._per_expert_scale_bufs_l2 ) - l2_gsa = self._l2_gsa_buf.fill_(self._l2_activation_global_scale) + l2_gsa = self._l2_gsa_buf # already filled by GPU compute (no .fill_ needed) l2_out = run_nvfp4_grouped_gemm( mat_a=padded_activated_fp4, mat_b=self._l2_mat_b, diff --git a/dsv4/layers/shared_expert.py b/dsv4/layers/shared_expert.py index 86a932df..8ac4d803 100644 --- a/dsv4/layers/shared_expert.py +++ b/dsv4/layers/shared_expert.py @@ -235,15 +235,15 @@ class Nvfp4SharedExpert: num_tokens = hidden_states.shape[0] padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128 - # Quantize activation + # Fused amax + quantize: zero CPU syncs. if getattr(self, '_use_runtime_gsa', False): - from dsv4.ops.quantize import compute_amax_gsa_gpu - gsa_l1 = compute_amax_gsa_gpu(hidden_states) - self._l1_activation_global_scale = gsa_l1.item() - self._l1_gsa_buf.copy_(gsa_l1.reshape(1)) - x_fp4, x_sf = quantize_activation_nvfp4( - hidden_states, self._l1_activation_global_scale - ) + from dsv4.ops.quantize import quantize_nvfp4_gpu_fused + x_fp4, x_sf, gsa_l1_gpu = quantize_nvfp4_gpu_fused(hidden_states) + self._l1_gsa_buf.copy_(gsa_l1_gpu[:1].reshape(1)) # GPU → GPU, no sync + else: + x_fp4, x_sf = quantize_activation_nvfp4( + hidden_states, self._l1_activation_global_scale + ) # Scatter x_fp4 into padded buffer padded_x_fp4 = self._padded_x_fp4_buf_l1 @@ -257,8 +257,8 @@ class Nvfp4SharedExpert: expert_offsets = self._expert_offsets_buf expert_offsets.fill_(padded_rows) - # Global scales - gsa = self._l1_gsa_buf.fill_(self._l1_activation_global_scale) + # Global scales — GPU-computed gsa already in _l1_gsa_buf (no CPU sync) + gsa = self._l1_gsa_buf # Run GEMM out = run_nvfp4_grouped_gemm( @@ -279,15 +279,15 @@ class Nvfp4SharedExpert: num_tokens = intermediate.shape[0] padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128 - # Quantize activation + # Fused amax + quantize: zero CPU syncs. if getattr(self, '_use_runtime_gsa', False): - from dsv4.ops.quantize import compute_amax_gsa_gpu - gsa_l2 = compute_amax_gsa_gpu(intermediate) - self._l2_activation_global_scale = gsa_l2.item() - self._l2_gsa_buf.copy_(gsa_l2.reshape(1)) - x_fp4, x_sf = quantize_activation_nvfp4( - intermediate, self._l2_activation_global_scale - ) + from dsv4.ops.quantize import quantize_nvfp4_gpu_fused + x_fp4, x_sf, gsa_l2_gpu = quantize_nvfp4_gpu_fused(intermediate) + self._l2_gsa_buf.copy_(gsa_l2_gpu[:1].reshape(1)) # GPU → GPU, no sync + else: + x_fp4, x_sf = quantize_activation_nvfp4( + intermediate, self._l2_activation_global_scale + ) # Scatter into padded buffer padded_x_fp4 = self._padded_x_fp4_buf_l2 @@ -301,8 +301,8 @@ class Nvfp4SharedExpert: expert_offsets = self._expert_offsets_buf expert_offsets.fill_(padded_rows) - # Global scales - gsa = self._l2_gsa_buf.fill_(self._l2_activation_global_scale) + # Global scales — GPU-computed gsa already in _l2_gsa_buf (no CPU sync) + gsa = self._l2_gsa_buf # Run GEMM out = run_nvfp4_grouped_gemm( diff --git a/dsv4/model/sampler.py b/dsv4/model/sampler.py index 5fffb8a4..afc523b5 100644 --- a/dsv4/model/sampler.py +++ b/dsv4/model/sampler.py @@ -21,14 +21,8 @@ def _get_kernel(): global _kernel if _kernel is not None: return _kernel - from torch.utils.cpp_extension import load - kdir = os.path.join(os.path.dirname(__file__), "..", "kernels", "cuda") - _kernel = load( - name="dsv4_sampler", - sources=[os.path.join(kdir, "sampler.cu")], - extra_cuda_cflags=["-O3", "--generate-code=arch=compute_100a,code=[sm_100a]"], - verbose=False, - ) + from dsv4.kernels.cuda.loader import get_cuda_module + _kernel = get_cuda_module("sampler", ["sampler.cu"]) return _kernel diff --git a/dsv4/ops/quantize.py b/dsv4/ops/quantize.py index 2e9fb4a9..2777a73f 100644 --- a/dsv4/ops/quantize.py +++ b/dsv4/ops/quantize.py @@ -242,38 +242,44 @@ def deinterleave_quantize_nvfp4_cuda(fused_bf16, intermediate, global_scale, gra x_fp4: (M, intermediate//2) float4_e2m1fn_x2 — quantized SwiGLU x_sf: (M, intermediate//16) float8_e4m3fn — block scales """ - from torch.utils.cpp_extension import load - import os - # dsv4/ops/quantize.py → dsv4/kernels/cuda/ - kernel_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "kernels", "cuda") - mod = load( - name="deinterleave_quantize_nvfp4", - sources=[os.path.join(kernel_dir, "deinterleave_quantize.cu")], - extra_cuda_cflags=["-gencode=arch=compute_100a,code=sm_100a"], - verbose=False, - ) + from dsv4.kernels.cuda.loader import get_cuda_module + mod = get_cuda_module("deinterleave_quantize_nvfp4", ["deinterleave_quantize.cu"]) return mod.deinterleave_quantize_nvfp4(fused_bf16, intermediate, granularity, global_scale) +def deinterleave_amax_quantize_nvfp4_fused(fused_bf16, intermediate, divisor=6.0 * 448.0, granularity=8): + """Fused deinterleave + amax + gsa + quantize: NO CPU sync, single kernel launch. + + For the MoE fused_swiglu L2 path. Computes gsa from the de-interleaved + (SwiGLU) values on GPU, quantizes in the same kernel. Zero .item() syncs. + + Args: + fused_bf16: (M, 2*intermediate) BF16 — fused L1 output + intermediate: intermediate dimension + divisor: gsa = amax / divisor. Default 2688.0. + granularity: interleave granularity (default 8) + + Returns: + x_fp4: (M, intermediate//2) float4_e2m1fn_x2 + x_sf: (M, intermediate//16) float8_e4m3fn + gsa: (M,) float32 GPU tensor — per-row global scale for L2 GEMM + """ + from dsv4.kernels.cuda.loader import get_cuda_module + mod = get_cuda_module("fused_deinterleave_amax_quantize", ["fused_deinterleave_amax_quantize.cu"]) + return mod.fused_deinterleave_amax_quantize(fused_bf16, intermediate, granularity, divisor) + + def compute_amax_gsa_gpu(x_bf16, divisor=6.0 * 448.0): """Compute gsa = max(|x|) / divisor on GPU. No CPU sync. Returns a scalar GPU tensor (not a Python float!). - The caller can pass this directly to quantize_nvfp4_gpu() - or to CuTeDSL GEMM's global_scale_a buffer via .fill_(). - This eliminates ~915 CPU-GPU syncs per decode step - (610 from Nvfp4Linear + 183 from Nvfp4MoE + 122 from SharedExpert). + NOTE: Prefer quantize_nvfp4_gpu_fused() which does amax+quantize in + one kernel launch. This function is kept for cases where you need gsa + without quantization. """ - from torch.utils.cpp_extension import load - import os - kernel_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "kernels", "cuda") - mod = load( - name="amax_gsa", - sources=[os.path.join(kernel_dir, "amax_gsa.cu")], - extra_cuda_cflags=["-gencode=arch=compute_100a,code=sm_100a"], - verbose=False, - ) + from dsv4.kernels.cuda.loader import get_cuda_module + mod = get_cuda_module("amax_gsa", ["amax_gsa.cu"]) return mod.compute_amax_gsa(x_bf16, divisor) @@ -288,8 +294,6 @@ def quantize_nvfp4_gpu_fused(x_bf16, divisor=6.0 * 448.0): This fused kernel computes amax on GPU, derives gsa, and quantizes in a single kernel launch. Zero CPU-GPU syncs. - For decode (M=1, N=7168): ~5μs vs ~15μs (separate amax + quantize + sync). - Args: x_bf16: (M, N) BF16 tensor. N must be a multiple of 16. divisor: gsa = amax / divisor. Default 6.0 * 448.0 = 2688.0. @@ -297,16 +301,10 @@ def quantize_nvfp4_gpu_fused(x_bf16, divisor=6.0 * 448.0): Returns: x_fp4: (M, N//2) float4_e2m1fn_x2 x_sf: (M, N//16) float8_e4m3fn + gsa: (M,) float32 GPU tensor — per-row global scale for GEMM """ - from torch.utils.cpp_extension import load - import os - kernel_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "kernels", "cuda") - mod = load( - name="fused_amax_quantize", - sources=[os.path.join(kernel_dir, "fused_amax_quantize.cu")], - extra_cuda_cflags=["-gencode=arch=compute_100a,code=sm_100a"], - verbose=False, - ) + from dsv4.kernels.cuda.loader import get_cuda_module + mod = get_cuda_module("fused_amax_quantize", ["fused_amax_quantize.cu"]) return mod.fused_amax_quantize_nvfp4(x_bf16, divisor) @@ -316,6 +314,9 @@ def quantize_nvfp4_gpu(x_bf16, global_scale): Replaces quantize_activation_nvfp4() which uses .amax() (CPU sync). The global_scale must be pre-computed (from warmup or known value). + NOTE: Prefer quantize_nvfp4_gpu_fused() which also computes gsa on GPU. + This function is kept for cases where global_scale is already known. + Args: x_bf16: (M, N) BF16 tensor. N must be a multiple of 16. global_scale: float32 scalar (pre-computed, NOT from .max()) @@ -324,14 +325,6 @@ def quantize_nvfp4_gpu(x_bf16, global_scale): x_fp4: (M, N//2) float4_e2m1fn_x2 x_sf: (M, N//16) float8_e4m3fn """ - from torch.utils.cpp_extension import load - import os - # dsv4/ops/quantize.py → dsv4/kernels/cuda/ - kernel_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "kernels", "cuda") - mod = load( - name="quantize_nvfp4", - sources=[os.path.join(kernel_dir, "quantize_nvfp4.cu")], - extra_cuda_cflags=["-gencode=arch=compute_100a,code=sm_100a"], - verbose=False, - ) + from dsv4.kernels.cuda.loader import get_cuda_module + mod = get_cuda_module("quantize_nvfp4", ["quantize_nvfp4.cu"]) return mod.quantize_nvfp4(x_bf16, global_scale) diff --git a/single_shot_inference.py b/single_shot_inference.py index e654aebf..1ad13283 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -212,9 +212,10 @@ class Compressor: kv, gate, self.ape, self.kv_norm_w, m=r) if compressed.shape[0] == 0: return None, None, None - comp_pos = torch.tensor([positions[(bi+1)*r - 1].item() if positions.numel() > (bi+1)*r - 1 else 0 - for bi in range(n_complete)], - dtype=torch.long, device=dev) + # Vectorized position computation — no Python loop, no .item() + bi = torch.arange(n_complete, device=dev) + pos_idx = ((bi + 1) * r - 1).clamp(max=positions.numel() - 1) + comp_pos = positions[pos_idx] return compressed, comp_pos, torch.zeros(1, T, n_complete, dtype=torch.float32, device=dev) # ===================================================================== @@ -264,26 +265,50 @@ class Indexer: # KV Cache # ===================================================================== class KVCache: - def __init__(self, head_dim, window_size=128, device='cuda:0'): + def __init__(self, head_dim, window_size=128, max_comp=32768, device='cuda:0'): self.hd, self.ws, self.dev = head_dim, window_size, device self.swa = torch.zeros(window_size, head_dim, dtype=torch.bfloat16, device=device) self.swa_pos = torch.zeros(window_size, dtype=torch.long, device=device) self.swa_len, self.swa_head = 0, 0 - self.comp_kv, self.comp_pos, self.n_comp = None, None, 0; self.comp_idx_kv = None + # P3: Pre-allocate compressed KV buffers (no more torch.cat / O(N²) growth) + self.comp_kv_buf = torch.zeros(max_comp, head_dim, dtype=torch.bfloat16, device=device) + self.comp_pos_buf = torch.zeros(max_comp, dtype=torch.long, device=device) + self.comp_idx_buf = torch.zeros(max_comp, head_dim, dtype=torch.bfloat16, device=device) + self.n_comp = 0 + self._has_idx = False def append_swa(self, kv, pos): + """P2: Vectorized SWA append — 2 kernel launches instead of 2T.""" T = kv.shape[0] - for i in range(T): - idx = (self.swa_head + i) % self.ws; self.swa[idx], self.swa_pos[idx] = kv[i], pos[i] - self.swa_head = (self.swa_head + T) % self.ws; self.swa_len = min(self.swa_len + T, self.ws) + idx = (self.swa_head + torch.arange(T, device=self.dev)) % self.ws + self.swa.index_copy_(0, idx, kv) + self.swa_pos.index_copy_(0, idx, pos) + self.swa_head = (self.swa_head + T) % self.ws + self.swa_len = min(self.swa_len + T, self.ws) def add_compressed(self, ckv, cpos, idx_kv=None): + """P3: Pre-allocated buffer — O(1) instead of O(N) per call.""" if ckv is None: return - self.comp_kv = ckv if self.comp_kv is None else torch.cat([self.comp_kv, ckv]) - self.comp_pos = cpos if self.comp_pos is None else torch.cat([self.comp_pos, cpos]) - self.n_comp = self.comp_kv.shape[0] + T = ckv.shape[0] + end = self.n_comp + T + self.comp_kv_buf[self.n_comp:end] = ckv + self.comp_pos_buf[self.n_comp:end] = cpos if idx_kv is not None: - self.comp_idx_kv = idx_kv if self.comp_idx_kv is None else torch.cat([self.comp_idx_kv, idx_kv]) + self.comp_idx_buf[self.n_comp:end] = idx_kv + self._has_idx = True + self.n_comp = end + + @property + def comp_kv(self): + return self.comp_kv_buf[:self.n_comp] if self.n_comp > 0 else None + + @property + def comp_pos(self): + return self.comp_pos_buf[:self.n_comp] if self.n_comp > 0 else None + + @property + def comp_idx_kv(self): + return self.comp_idx_buf[:self.n_comp] if self._has_idx and self.n_comp > 0 else None def get_swa(self): if self.swa_len == 0: