CRITICAL FIX: runtime activation global scale to prevent E4M3 overflow
The checkpoint's input_scale was designed for training-time FP8 quantization, not NVFP4 activation quantization. Using it as gsa causes x/gsa to exceed the E4M3 block scale maximum (448), leading to systematic magnitude loss in every projection. This accumulates over 61 layers, compressing the logit range and producing garbage tokens. Fix: compute gsa at runtime from actual activation magnitude: gsa = max(|x|) / (6.0 * 448.0) This ensures x/gsa ≤ 2688 (the maximum representable in E4M3 block scales). Applied to: Nvfp4Linear, Nvfp4GroupedLinear, Nvfp4MoE, Nvfp4SharedExpert, Router gate
This commit is contained in:
@@ -238,6 +238,11 @@ class Nvfp4GroupedLinear:
|
||||
# Permute to groups-first: (G, T, D)
|
||||
o_grouped = o_grouped.permute(1, 0, 2)
|
||||
|
||||
# Compute activation global scale at runtime if requested.
|
||||
if getattr(self, '_use_runtime_gsa', False):
|
||||
amax = o.float().abs().max().clamp(min=1e-8).item()
|
||||
self._activation_global_scale = amax / (6.0 * 448.0)
|
||||
|
||||
# Quantize each group's activation and scatter into padded buffer
|
||||
padded_x_fp4 = self._padded_x_fp4_buf
|
||||
padded_x_fp4.view(torch.uint8).zero_()
|
||||
|
||||
@@ -160,6 +160,13 @@ class Nvfp4Linear:
|
||||
# Ensure buffer is large enough
|
||||
self._ensure_buffer_size(num_tokens)
|
||||
|
||||
# Compute activation global scale at runtime if requested.
|
||||
# This prevents E4M3 block scale overflow when the checkpoint's
|
||||
# input_scale is too small for the actual activation magnitudes.
|
||||
if getattr(self, '_use_runtime_gsa', False):
|
||||
amax = hidden_states.float().abs().max().clamp(min=1e-8).item()
|
||||
self._activation_global_scale = amax / (6.0 * 448.0)
|
||||
|
||||
# Quantize activation
|
||||
x_fp4, x_sf = quantize_activation_nvfp4(
|
||||
hidden_states, self._activation_global_scale
|
||||
|
||||
@@ -589,6 +589,11 @@ class Nvfp4MoE:
|
||||
padded_dst = padded_expert_offsets[expert_assign] + local_row
|
||||
|
||||
# === L1: gate + up ===
|
||||
# Compute runtime gsa from actual activation magnitude if requested.
|
||||
# This prevents E4M3 block scale overflow when checkpoint input_scale is too small.
|
||||
if getattr(self, '_use_runtime_gsa', False):
|
||||
amax = slot_hidden.float().abs().max().clamp(min=1e-8).item()
|
||||
self._l1_activation_global_scale = amax / (6.0 * 448.0)
|
||||
# Quantize slot_hidden using GPU-only kernel (no CPU-GPU sync).
|
||||
# slot_hidden is the sorted tokens (not padded). The GPU kernel
|
||||
# replaces quantize_activation_nvfp4 which uses .amax() (CPU sync).
|
||||
@@ -618,6 +623,10 @@ class Nvfp4MoE:
|
||||
swiglu_limit=self._swiglu_limit if self._swiglu_limit is not None else 0.0,
|
||||
)
|
||||
l1_out_real = l1_out[padded_dst]
|
||||
# Compute runtime gsa for L2 from the activated output
|
||||
if getattr(self, '_use_runtime_gsa', False):
|
||||
amax_l2 = l1_out_real.float().abs().max().clamp(min=1e-8).item()
|
||||
self._l2_activation_global_scale = amax_l2 / (6.0 * 448.0)
|
||||
# De-interleave + quantize to FP4 in one GPU kernel.
|
||||
# l1_out_real has interleaved [silu(gate)*8, swiglu*8, ...].
|
||||
# The CUDA kernel extracts odd 8-col groups (SwiGLU result)
|
||||
|
||||
@@ -184,6 +184,7 @@ class Router:
|
||||
ws2_val = gate_ws2.float().item() if gate_ws2.numel() == 1 else gate_ws2.float().mean().item()
|
||||
gate_lin.ws2 = [torch.tensor([ws2_val], device=self.device, dtype=torch.float32)]
|
||||
gate_lin._activation_global_scale = gate_input_scale.float().item() if gate_input_scale.numel() == 1 else gate_input_scale.float().mean().item()
|
||||
gate_lin._use_runtime_gsa = True # compute gsa from actual input to avoid E4M3 overflow
|
||||
gate_lin.finalize_weights()
|
||||
self.gate_lin = gate_lin
|
||||
|
||||
|
||||
@@ -236,6 +236,9 @@ class Nvfp4SharedExpert:
|
||||
padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128
|
||||
|
||||
# Quantize activation
|
||||
if getattr(self, '_use_runtime_gsa', False):
|
||||
amax = hidden_states.float().abs().max().clamp(min=1e-8).item()
|
||||
self._l1_activation_global_scale = amax / (6.0 * 448.0)
|
||||
x_fp4, x_sf = quantize_activation_nvfp4(
|
||||
hidden_states, self._l1_activation_global_scale
|
||||
)
|
||||
@@ -275,6 +278,9 @@ class Nvfp4SharedExpert:
|
||||
padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128
|
||||
|
||||
# Quantize activation
|
||||
if getattr(self, '_use_runtime_gsa', False):
|
||||
amax = intermediate.float().abs().max().clamp(min=1e-8).item()
|
||||
self._l2_activation_global_scale = amax / (6.0 * 448.0)
|
||||
x_fp4, x_sf = quantize_activation_nvfp4(
|
||||
intermediate, self._l2_activation_global_scale
|
||||
)
|
||||
|
||||
@@ -133,26 +133,18 @@ def make_nvfp4_linear(in_features, out_features, device, all_w, pfx, proj_name):
|
||||
d = device
|
||||
weight, ws, ws2, isc = get_nvfp4_weight(all_w, pfx, proj_name)
|
||||
assert weight is not None, f"{pfx}.{proj_name}.weight not found"
|
||||
# Checkpoint weight is (N_packed, K_packed) uint8
|
||||
# NVFP4 GEMM output dim = N_packed BF16 elements
|
||||
# Activation buffer needs K_packed FP4 columns = in_features BF16
|
||||
# So: in_features = K_packed * 2, out_features = N_packed
|
||||
actual_out = weight.shape[0] # N_packed = GEMM output dimension
|
||||
actual_in = weight.shape[1] * 2 # K_packed * 2 = BF16 input dim (for buffer allocation)
|
||||
lin = Nvfp4Linear(actual_in, actual_out, max_num_tokens=8192, device=d)
|
||||
lin.fp4 = [weight.to(d)]; lin.sf = [ws.to(d)]
|
||||
# Global scales for NVFP4 GEMM:
|
||||
# gsb (weight global scale) = weight_scale_2 (NOT input_scale * weight_scale_2)
|
||||
# gsa (activation global scale) = input_scale from checkpoint
|
||||
# Dequant: w = lut[w_packed] * weight_scale * weight_scale_2
|
||||
# GEMM: y = (x * scale_a * gsa) @ (w * scale_b * gsb)
|
||||
# Nvfp4Linear.finalize_weights does: gsb = gs * ws2_val
|
||||
# So to get gsb = ws2_val, set gs = 1.0 and let ws2 do its job
|
||||
lin.gs = [1.0] # base gs — finalize_weights will multiply by ws2
|
||||
lin.ws2 = [ws2.to(d) if ws2 is not None else None]
|
||||
# Set activation global scale from checkpoint input_scale
|
||||
isc_val = isc.float().item() if isc is not None else 1.0 / (6.0 * 448.0)
|
||||
lin._activation_global_scale = isc_val # gsa = input_scale
|
||||
# CRITICAL FIX: Compute gsa at RUNTIME from actual input magnitude.
|
||||
# The checkpoint's input_scale is for training-time FP8 quantization.
|
||||
# Using it as gsa causes E4M3 block scale overflow when x/gsa > 2688.
|
||||
# We set a placeholder and override in the forward pass.
|
||||
lin._activation_global_scale = 1.0 / (6.0 * 448.0) # placeholder
|
||||
lin._use_runtime_gsa = True # flag to compute gsa at runtime
|
||||
lin.finalize_weights(); return lin
|
||||
|
||||
# =====================================================================
|
||||
@@ -697,6 +689,7 @@ def main():
|
||||
if oa_bf is not None:
|
||||
wo_a.set_bf16_weight(oa_bf.bfloat16().to(dev))
|
||||
pl['o_a'] = wo_a
|
||||
wo_a._use_runtime_gsa = True # compute gsa from actual input to avoid E4M3 overflow
|
||||
pl['o_b'] = make_nvfp4_linear(16384, 7168, dev, all_w, pfx, 'o_b_proj')
|
||||
prod_lins[li] = pl
|
||||
if (li+1) % 10 == 0: print(f" {li+1}/{n_layers} layers")
|
||||
@@ -769,10 +762,11 @@ def main():
|
||||
# EAGERLY process stacked weights → K-major + swizzle, free raw tensors
|
||||
moe._ensure_stacked()
|
||||
# Fix activation global scales — _ensure_stacked sets gsa from l1_gs (which is 1.0)
|
||||
if hasattr(moe, '_saved_l1_gsa'):
|
||||
moe._l1_activation_global_scale = moe._saved_l1_gsa
|
||||
if hasattr(moe, '_saved_l2_gsa'):
|
||||
moe._l2_activation_global_scale = moe._saved_l2_gsa
|
||||
# FIX: Do NOT use checkpoint input_scale as gsa — causes E4M3 overflow.
|
||||
# Instead, compute gsa at runtime from actual activation magnitude.
|
||||
# The MoE runner's compute_activation_global_scales() does this correctly.
|
||||
# We enable runtime gsa for both MoE and SharedExpert.
|
||||
moe._use_runtime_gsa = True
|
||||
moe_runners[li] = moe
|
||||
|
||||
se = Nvfp4SharedExpert(hidden_size=H, intermediate_size=cfg.get("moe_intermediate_size", 3072),
|
||||
@@ -781,11 +775,8 @@ def main():
|
||||
# EAGERLY process shared expert weights
|
||||
se._ensure_initialized()
|
||||
# Fix activation global scales — _ensure_initialized sets gsa from l1_gs (which is 1.0)
|
||||
# The correct gsa is the input_scale from the checkpoint, saved in _saved_l1_gsa
|
||||
if hasattr(se, '_saved_l1_gsa'):
|
||||
se._l1_activation_global_scale = se._saved_l1_gsa
|
||||
if hasattr(se, '_saved_l2_gsa'):
|
||||
se._l2_activation_global_scale = se._saved_l2_gsa
|
||||
# FIX: Same runtime gsa for SharedExpert
|
||||
se._use_runtime_gsa = True
|
||||
se_runners[li] = se
|
||||
if (li+1) % 10 == 0: print(f" Built {li+1}/{n_layers} MoE layers")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
Reference in New Issue
Block a user