CRITICAL FIX: runtime activation global scale to prevent E4M3 overflow

The checkpoint's input_scale was designed for training-time FP8 quantization,
not NVFP4 activation quantization. Using it as gsa causes x/gsa to exceed
the E4M3 block scale maximum (448), leading to systematic magnitude loss
in every projection. This accumulates over 61 layers, compressing the
logit range and producing garbage tokens.

Fix: compute gsa at runtime from actual activation magnitude:
  gsa = max(|x|) / (6.0 * 448.0)
This ensures x/gsa ≤ 2688 (the maximum representable in E4M3 block scales).

Applied to: Nvfp4Linear, Nvfp4GroupedLinear, Nvfp4MoE, Nvfp4SharedExpert, Router gate
This commit is contained in:
2026-06-01 14:21:16 +00:00
parent 3b2714410f
commit 2b1fca6dae
6 changed files with 42 additions and 23 deletions

View File

@@ -238,6 +238,11 @@ class Nvfp4GroupedLinear:
# Permute to groups-first: (G, T, D) # Permute to groups-first: (G, T, D)
o_grouped = o_grouped.permute(1, 0, 2) o_grouped = o_grouped.permute(1, 0, 2)
# Compute activation global scale at runtime if requested.
if getattr(self, '_use_runtime_gsa', False):
amax = o.float().abs().max().clamp(min=1e-8).item()
self._activation_global_scale = amax / (6.0 * 448.0)
# Quantize each group's activation and scatter into padded buffer # Quantize each group's activation and scatter into padded buffer
padded_x_fp4 = self._padded_x_fp4_buf padded_x_fp4 = self._padded_x_fp4_buf
padded_x_fp4.view(torch.uint8).zero_() padded_x_fp4.view(torch.uint8).zero_()

View File

@@ -160,6 +160,13 @@ class Nvfp4Linear:
# Ensure buffer is large enough # Ensure buffer is large enough
self._ensure_buffer_size(num_tokens) self._ensure_buffer_size(num_tokens)
# Compute activation global scale at runtime if requested.
# This prevents E4M3 block scale overflow when the checkpoint's
# input_scale is too small for the actual activation magnitudes.
if getattr(self, '_use_runtime_gsa', False):
amax = hidden_states.float().abs().max().clamp(min=1e-8).item()
self._activation_global_scale = amax / (6.0 * 448.0)
# Quantize activation # Quantize activation
x_fp4, x_sf = quantize_activation_nvfp4( x_fp4, x_sf = quantize_activation_nvfp4(
hidden_states, self._activation_global_scale hidden_states, self._activation_global_scale

View File

@@ -589,6 +589,11 @@ class Nvfp4MoE:
padded_dst = padded_expert_offsets[expert_assign] + local_row padded_dst = padded_expert_offsets[expert_assign] + local_row
# === L1: gate + up === # === L1: gate + up ===
# Compute runtime gsa from actual activation magnitude if requested.
# This prevents E4M3 block scale overflow when checkpoint input_scale is too small.
if getattr(self, '_use_runtime_gsa', False):
amax = slot_hidden.float().abs().max().clamp(min=1e-8).item()
self._l1_activation_global_scale = amax / (6.0 * 448.0)
# Quantize slot_hidden using GPU-only kernel (no CPU-GPU sync). # Quantize slot_hidden using GPU-only kernel (no CPU-GPU sync).
# slot_hidden is the sorted tokens (not padded). The GPU kernel # slot_hidden is the sorted tokens (not padded). The GPU kernel
# replaces quantize_activation_nvfp4 which uses .amax() (CPU sync). # replaces quantize_activation_nvfp4 which uses .amax() (CPU sync).
@@ -618,6 +623,10 @@ class Nvfp4MoE:
swiglu_limit=self._swiglu_limit if self._swiglu_limit is not None else 0.0, swiglu_limit=self._swiglu_limit if self._swiglu_limit is not None else 0.0,
) )
l1_out_real = l1_out[padded_dst] l1_out_real = l1_out[padded_dst]
# Compute runtime gsa for L2 from the activated output
if getattr(self, '_use_runtime_gsa', False):
amax_l2 = l1_out_real.float().abs().max().clamp(min=1e-8).item()
self._l2_activation_global_scale = amax_l2 / (6.0 * 448.0)
# De-interleave + quantize to FP4 in one GPU kernel. # De-interleave + quantize to FP4 in one GPU kernel.
# l1_out_real has interleaved [silu(gate)*8, swiglu*8, ...]. # l1_out_real has interleaved [silu(gate)*8, swiglu*8, ...].
# The CUDA kernel extracts odd 8-col groups (SwiGLU result) # The CUDA kernel extracts odd 8-col groups (SwiGLU result)

View File

@@ -184,6 +184,7 @@ class Router:
ws2_val = gate_ws2.float().item() if gate_ws2.numel() == 1 else gate_ws2.float().mean().item() ws2_val = gate_ws2.float().item() if gate_ws2.numel() == 1 else gate_ws2.float().mean().item()
gate_lin.ws2 = [torch.tensor([ws2_val], device=self.device, dtype=torch.float32)] gate_lin.ws2 = [torch.tensor([ws2_val], device=self.device, dtype=torch.float32)]
gate_lin._activation_global_scale = gate_input_scale.float().item() if gate_input_scale.numel() == 1 else gate_input_scale.float().mean().item() gate_lin._activation_global_scale = gate_input_scale.float().item() if gate_input_scale.numel() == 1 else gate_input_scale.float().mean().item()
gate_lin._use_runtime_gsa = True # compute gsa from actual input to avoid E4M3 overflow
gate_lin.finalize_weights() gate_lin.finalize_weights()
self.gate_lin = gate_lin self.gate_lin = gate_lin

View File

@@ -236,6 +236,9 @@ class Nvfp4SharedExpert:
padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128 padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128
# Quantize activation # Quantize activation
if getattr(self, '_use_runtime_gsa', False):
amax = hidden_states.float().abs().max().clamp(min=1e-8).item()
self._l1_activation_global_scale = amax / (6.0 * 448.0)
x_fp4, x_sf = quantize_activation_nvfp4( x_fp4, x_sf = quantize_activation_nvfp4(
hidden_states, self._l1_activation_global_scale hidden_states, self._l1_activation_global_scale
) )
@@ -275,6 +278,9 @@ class Nvfp4SharedExpert:
padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128 padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128
# Quantize activation # Quantize activation
if getattr(self, '_use_runtime_gsa', False):
amax = intermediate.float().abs().max().clamp(min=1e-8).item()
self._l2_activation_global_scale = amax / (6.0 * 448.0)
x_fp4, x_sf = quantize_activation_nvfp4( x_fp4, x_sf = quantize_activation_nvfp4(
intermediate, self._l2_activation_global_scale intermediate, self._l2_activation_global_scale
) )

View File

@@ -133,26 +133,18 @@ def make_nvfp4_linear(in_features, out_features, device, all_w, pfx, proj_name):
d = device d = device
weight, ws, ws2, isc = get_nvfp4_weight(all_w, pfx, proj_name) weight, ws, ws2, isc = get_nvfp4_weight(all_w, pfx, proj_name)
assert weight is not None, f"{pfx}.{proj_name}.weight not found" assert weight is not None, f"{pfx}.{proj_name}.weight not found"
# Checkpoint weight is (N_packed, K_packed) uint8
# NVFP4 GEMM output dim = N_packed BF16 elements
# Activation buffer needs K_packed FP4 columns = in_features BF16
# So: in_features = K_packed * 2, out_features = N_packed
actual_out = weight.shape[0] # N_packed = GEMM output dimension actual_out = weight.shape[0] # N_packed = GEMM output dimension
actual_in = weight.shape[1] * 2 # K_packed * 2 = BF16 input dim (for buffer allocation) actual_in = weight.shape[1] * 2 # K_packed * 2 = BF16 input dim (for buffer allocation)
lin = Nvfp4Linear(actual_in, actual_out, max_num_tokens=8192, device=d) lin = Nvfp4Linear(actual_in, actual_out, max_num_tokens=8192, device=d)
lin.fp4 = [weight.to(d)]; lin.sf = [ws.to(d)] lin.fp4 = [weight.to(d)]; lin.sf = [ws.to(d)]
# Global scales for NVFP4 GEMM:
# gsb (weight global scale) = weight_scale_2 (NOT input_scale * weight_scale_2)
# gsa (activation global scale) = input_scale from checkpoint
# Dequant: w = lut[w_packed] * weight_scale * weight_scale_2
# GEMM: y = (x * scale_a * gsa) @ (w * scale_b * gsb)
# Nvfp4Linear.finalize_weights does: gsb = gs * ws2_val
# So to get gsb = ws2_val, set gs = 1.0 and let ws2 do its job
lin.gs = [1.0] # base gs — finalize_weights will multiply by ws2 lin.gs = [1.0] # base gs — finalize_weights will multiply by ws2
lin.ws2 = [ws2.to(d) if ws2 is not None else None] lin.ws2 = [ws2.to(d) if ws2 is not None else None]
# Set activation global scale from checkpoint input_scale # CRITICAL FIX: Compute gsa at RUNTIME from actual input magnitude.
isc_val = isc.float().item() if isc is not None else 1.0 / (6.0 * 448.0) # The checkpoint's input_scale is for training-time FP8 quantization.
lin._activation_global_scale = isc_val # gsa = input_scale # Using it as gsa causes E4M3 block scale overflow when x/gsa > 2688.
# We set a placeholder and override in the forward pass.
lin._activation_global_scale = 1.0 / (6.0 * 448.0) # placeholder
lin._use_runtime_gsa = True # flag to compute gsa at runtime
lin.finalize_weights(); return lin lin.finalize_weights(); return lin
# ===================================================================== # =====================================================================
@@ -697,6 +689,7 @@ def main():
if oa_bf is not None: if oa_bf is not None:
wo_a.set_bf16_weight(oa_bf.bfloat16().to(dev)) wo_a.set_bf16_weight(oa_bf.bfloat16().to(dev))
pl['o_a'] = wo_a pl['o_a'] = wo_a
wo_a._use_runtime_gsa = True # compute gsa from actual input to avoid E4M3 overflow
pl['o_b'] = make_nvfp4_linear(16384, 7168, dev, all_w, pfx, 'o_b_proj') pl['o_b'] = make_nvfp4_linear(16384, 7168, dev, all_w, pfx, 'o_b_proj')
prod_lins[li] = pl prod_lins[li] = pl
if (li+1) % 10 == 0: print(f" {li+1}/{n_layers} layers") if (li+1) % 10 == 0: print(f" {li+1}/{n_layers} layers")
@@ -769,10 +762,11 @@ def main():
# EAGERLY process stacked weights → K-major + swizzle, free raw tensors # EAGERLY process stacked weights → K-major + swizzle, free raw tensors
moe._ensure_stacked() moe._ensure_stacked()
# Fix activation global scales — _ensure_stacked sets gsa from l1_gs (which is 1.0) # Fix activation global scales — _ensure_stacked sets gsa from l1_gs (which is 1.0)
if hasattr(moe, '_saved_l1_gsa'): # FIX: Do NOT use checkpoint input_scale as gsa — causes E4M3 overflow.
moe._l1_activation_global_scale = moe._saved_l1_gsa # Instead, compute gsa at runtime from actual activation magnitude.
if hasattr(moe, '_saved_l2_gsa'): # The MoE runner's compute_activation_global_scales() does this correctly.
moe._l2_activation_global_scale = moe._saved_l2_gsa # We enable runtime gsa for both MoE and SharedExpert.
moe._use_runtime_gsa = True
moe_runners[li] = moe moe_runners[li] = moe
se = Nvfp4SharedExpert(hidden_size=H, intermediate_size=cfg.get("moe_intermediate_size", 3072), se = Nvfp4SharedExpert(hidden_size=H, intermediate_size=cfg.get("moe_intermediate_size", 3072),
@@ -781,11 +775,8 @@ def main():
# EAGERLY process shared expert weights # EAGERLY process shared expert weights
se._ensure_initialized() se._ensure_initialized()
# Fix activation global scales — _ensure_initialized sets gsa from l1_gs (which is 1.0) # Fix activation global scales — _ensure_initialized sets gsa from l1_gs (which is 1.0)
# The correct gsa is the input_scale from the checkpoint, saved in _saved_l1_gsa # FIX: Same runtime gsa for SharedExpert
if hasattr(se, '_saved_l1_gsa'): se._use_runtime_gsa = True
se._l1_activation_global_scale = se._saved_l1_gsa
if hasattr(se, '_saved_l2_gsa'):
se._l2_activation_global_scale = se._saved_l2_gsa
se_runners[li] = se se_runners[li] = se
if (li+1) % 10 == 0: print(f" Built {li+1}/{n_layers} MoE layers") if (li+1) % 10 == 0: print(f" Built {li+1}/{n_layers} MoE layers")
torch.cuda.empty_cache() torch.cuda.empty_cache()