From 2b1fca6dae94ea9846d0eefaf1a55b1afcc59599 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 1 Jun 2026 14:21:16 +0000 Subject: [PATCH] CRITICAL FIX: runtime activation global scale to prevent E4M3 overflow MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The checkpoint's input_scale was designed for training-time FP8 quantization, not NVFP4 activation quantization. Using it as gsa causes x/gsa to exceed the E4M3 block scale maximum (448), leading to systematic magnitude loss in every projection. This accumulates over 61 layers, compressing the logit range and producing garbage tokens. Fix: compute gsa at runtime from actual activation magnitude: gsa = max(|x|) / (6.0 * 448.0) This ensures x/gsa ≤ 2688 (the maximum representable in E4M3 block scales). Applied to: Nvfp4Linear, Nvfp4GroupedLinear, Nvfp4MoE, Nvfp4SharedExpert, Router gate --- dsv4/layers/grouped_linear.py | 5 +++++ dsv4/layers/linear.py | 7 +++++++ dsv4/layers/moe.py | 9 +++++++++ dsv4/layers/router.py | 1 + dsv4/layers/shared_expert.py | 6 ++++++ single_shot_inference.py | 37 +++++++++++++---------------------- 6 files changed, 42 insertions(+), 23 deletions(-) diff --git a/dsv4/layers/grouped_linear.py b/dsv4/layers/grouped_linear.py index 1142f5b2..5c144621 100644 --- a/dsv4/layers/grouped_linear.py +++ b/dsv4/layers/grouped_linear.py @@ -238,6 +238,11 @@ class Nvfp4GroupedLinear: # Permute to groups-first: (G, T, D) o_grouped = o_grouped.permute(1, 0, 2) + # Compute activation global scale at runtime if requested. + if getattr(self, '_use_runtime_gsa', False): + amax = o.float().abs().max().clamp(min=1e-8).item() + self._activation_global_scale = amax / (6.0 * 448.0) + # Quantize each group's activation and scatter into padded buffer padded_x_fp4 = self._padded_x_fp4_buf padded_x_fp4.view(torch.uint8).zero_() diff --git a/dsv4/layers/linear.py b/dsv4/layers/linear.py index 7b54e606..3c2a7860 100644 --- a/dsv4/layers/linear.py +++ b/dsv4/layers/linear.py @@ -160,6 +160,13 @@ class Nvfp4Linear: # Ensure buffer is large enough self._ensure_buffer_size(num_tokens) + # Compute activation global scale at runtime if requested. + # This prevents E4M3 block scale overflow when the checkpoint's + # input_scale is too small for the actual activation magnitudes. + if getattr(self, '_use_runtime_gsa', False): + amax = hidden_states.float().abs().max().clamp(min=1e-8).item() + self._activation_global_scale = amax / (6.0 * 448.0) + # Quantize activation x_fp4, x_sf = quantize_activation_nvfp4( hidden_states, self._activation_global_scale diff --git a/dsv4/layers/moe.py b/dsv4/layers/moe.py index 96502aaa..4e937eb6 100644 --- a/dsv4/layers/moe.py +++ b/dsv4/layers/moe.py @@ -589,6 +589,11 @@ class Nvfp4MoE: padded_dst = padded_expert_offsets[expert_assign] + local_row # === L1: gate + up === + # Compute runtime gsa from actual activation magnitude if requested. + # This prevents E4M3 block scale overflow when checkpoint input_scale is too small. + if getattr(self, '_use_runtime_gsa', False): + amax = slot_hidden.float().abs().max().clamp(min=1e-8).item() + self._l1_activation_global_scale = amax / (6.0 * 448.0) # Quantize slot_hidden using GPU-only kernel (no CPU-GPU sync). # slot_hidden is the sorted tokens (not padded). The GPU kernel # replaces quantize_activation_nvfp4 which uses .amax() (CPU sync). @@ -618,6 +623,10 @@ class Nvfp4MoE: swiglu_limit=self._swiglu_limit if self._swiglu_limit is not None else 0.0, ) l1_out_real = l1_out[padded_dst] + # Compute runtime gsa for L2 from the activated output + if getattr(self, '_use_runtime_gsa', False): + amax_l2 = l1_out_real.float().abs().max().clamp(min=1e-8).item() + self._l2_activation_global_scale = amax_l2 / (6.0 * 448.0) # De-interleave + quantize to FP4 in one GPU kernel. # l1_out_real has interleaved [silu(gate)*8, swiglu*8, ...]. # The CUDA kernel extracts odd 8-col groups (SwiGLU result) diff --git a/dsv4/layers/router.py b/dsv4/layers/router.py index 7f4cc879..fbdf3db6 100644 --- a/dsv4/layers/router.py +++ b/dsv4/layers/router.py @@ -184,6 +184,7 @@ class Router: ws2_val = gate_ws2.float().item() if gate_ws2.numel() == 1 else gate_ws2.float().mean().item() gate_lin.ws2 = [torch.tensor([ws2_val], device=self.device, dtype=torch.float32)] gate_lin._activation_global_scale = gate_input_scale.float().item() if gate_input_scale.numel() == 1 else gate_input_scale.float().mean().item() + gate_lin._use_runtime_gsa = True # compute gsa from actual input to avoid E4M3 overflow gate_lin.finalize_weights() self.gate_lin = gate_lin diff --git a/dsv4/layers/shared_expert.py b/dsv4/layers/shared_expert.py index be986354..8f318c43 100644 --- a/dsv4/layers/shared_expert.py +++ b/dsv4/layers/shared_expert.py @@ -236,6 +236,9 @@ class Nvfp4SharedExpert: padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128 # Quantize activation + if getattr(self, '_use_runtime_gsa', False): + amax = hidden_states.float().abs().max().clamp(min=1e-8).item() + self._l1_activation_global_scale = amax / (6.0 * 448.0) x_fp4, x_sf = quantize_activation_nvfp4( hidden_states, self._l1_activation_global_scale ) @@ -275,6 +278,9 @@ class Nvfp4SharedExpert: padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128 # Quantize activation + if getattr(self, '_use_runtime_gsa', False): + amax = intermediate.float().abs().max().clamp(min=1e-8).item() + self._l2_activation_global_scale = amax / (6.0 * 448.0) x_fp4, x_sf = quantize_activation_nvfp4( intermediate, self._l2_activation_global_scale ) diff --git a/single_shot_inference.py b/single_shot_inference.py index 6ca91cba..88e3973c 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -133,26 +133,18 @@ def make_nvfp4_linear(in_features, out_features, device, all_w, pfx, proj_name): d = device weight, ws, ws2, isc = get_nvfp4_weight(all_w, pfx, proj_name) assert weight is not None, f"{pfx}.{proj_name}.weight not found" - # Checkpoint weight is (N_packed, K_packed) uint8 - # NVFP4 GEMM output dim = N_packed BF16 elements - # Activation buffer needs K_packed FP4 columns = in_features BF16 - # So: in_features = K_packed * 2, out_features = N_packed actual_out = weight.shape[0] # N_packed = GEMM output dimension actual_in = weight.shape[1] * 2 # K_packed * 2 = BF16 input dim (for buffer allocation) lin = Nvfp4Linear(actual_in, actual_out, max_num_tokens=8192, device=d) lin.fp4 = [weight.to(d)]; lin.sf = [ws.to(d)] - # Global scales for NVFP4 GEMM: - # gsb (weight global scale) = weight_scale_2 (NOT input_scale * weight_scale_2) - # gsa (activation global scale) = input_scale from checkpoint - # Dequant: w = lut[w_packed] * weight_scale * weight_scale_2 - # GEMM: y = (x * scale_a * gsa) @ (w * scale_b * gsb) - # Nvfp4Linear.finalize_weights does: gsb = gs * ws2_val - # So to get gsb = ws2_val, set gs = 1.0 and let ws2 do its job lin.gs = [1.0] # base gs — finalize_weights will multiply by ws2 lin.ws2 = [ws2.to(d) if ws2 is not None else None] - # Set activation global scale from checkpoint input_scale - isc_val = isc.float().item() if isc is not None else 1.0 / (6.0 * 448.0) - lin._activation_global_scale = isc_val # gsa = input_scale + # CRITICAL FIX: Compute gsa at RUNTIME from actual input magnitude. + # The checkpoint's input_scale is for training-time FP8 quantization. + # Using it as gsa causes E4M3 block scale overflow when x/gsa > 2688. + # We set a placeholder and override in the forward pass. + lin._activation_global_scale = 1.0 / (6.0 * 448.0) # placeholder + lin._use_runtime_gsa = True # flag to compute gsa at runtime lin.finalize_weights(); return lin # ===================================================================== @@ -697,6 +689,7 @@ def main(): if oa_bf is not None: wo_a.set_bf16_weight(oa_bf.bfloat16().to(dev)) pl['o_a'] = wo_a + wo_a._use_runtime_gsa = True # compute gsa from actual input to avoid E4M3 overflow pl['o_b'] = make_nvfp4_linear(16384, 7168, dev, all_w, pfx, 'o_b_proj') prod_lins[li] = pl if (li+1) % 10 == 0: print(f" {li+1}/{n_layers} layers") @@ -769,10 +762,11 @@ def main(): # EAGERLY process stacked weights → K-major + swizzle, free raw tensors moe._ensure_stacked() # Fix activation global scales — _ensure_stacked sets gsa from l1_gs (which is 1.0) - if hasattr(moe, '_saved_l1_gsa'): - moe._l1_activation_global_scale = moe._saved_l1_gsa - if hasattr(moe, '_saved_l2_gsa'): - moe._l2_activation_global_scale = moe._saved_l2_gsa + # FIX: Do NOT use checkpoint input_scale as gsa — causes E4M3 overflow. + # Instead, compute gsa at runtime from actual activation magnitude. + # The MoE runner's compute_activation_global_scales() does this correctly. + # We enable runtime gsa for both MoE and SharedExpert. + moe._use_runtime_gsa = True moe_runners[li] = moe se = Nvfp4SharedExpert(hidden_size=H, intermediate_size=cfg.get("moe_intermediate_size", 3072), @@ -781,11 +775,8 @@ def main(): # EAGERLY process shared expert weights se._ensure_initialized() # Fix activation global scales — _ensure_initialized sets gsa from l1_gs (which is 1.0) - # The correct gsa is the input_scale from the checkpoint, saved in _saved_l1_gsa - if hasattr(se, '_saved_l1_gsa'): - se._l1_activation_global_scale = se._saved_l1_gsa - if hasattr(se, '_saved_l2_gsa'): - se._l2_activation_global_scale = se._saved_l2_gsa + # FIX: Same runtime gsa for SharedExpert + se._use_runtime_gsa = True se_runners[li] = se if (li+1) % 10 == 0: print(f" Built {li+1}/{n_layers} MoE layers") torch.cuda.empty_cache()