diff --git a/dsv4/layers/grouped_linear.py b/dsv4/layers/grouped_linear.py index 1142f5b2..5c144621 100644 --- a/dsv4/layers/grouped_linear.py +++ b/dsv4/layers/grouped_linear.py @@ -238,6 +238,11 @@ class Nvfp4GroupedLinear: # Permute to groups-first: (G, T, D) o_grouped = o_grouped.permute(1, 0, 2) + # Compute activation global scale at runtime if requested. + if getattr(self, '_use_runtime_gsa', False): + amax = o.float().abs().max().clamp(min=1e-8).item() + self._activation_global_scale = amax / (6.0 * 448.0) + # Quantize each group's activation and scatter into padded buffer padded_x_fp4 = self._padded_x_fp4_buf padded_x_fp4.view(torch.uint8).zero_() diff --git a/dsv4/layers/linear.py b/dsv4/layers/linear.py index 7b54e606..3c2a7860 100644 --- a/dsv4/layers/linear.py +++ b/dsv4/layers/linear.py @@ -160,6 +160,13 @@ class Nvfp4Linear: # Ensure buffer is large enough self._ensure_buffer_size(num_tokens) + # Compute activation global scale at runtime if requested. + # This prevents E4M3 block scale overflow when the checkpoint's + # input_scale is too small for the actual activation magnitudes. + if getattr(self, '_use_runtime_gsa', False): + amax = hidden_states.float().abs().max().clamp(min=1e-8).item() + self._activation_global_scale = amax / (6.0 * 448.0) + # Quantize activation x_fp4, x_sf = quantize_activation_nvfp4( hidden_states, self._activation_global_scale diff --git a/dsv4/layers/moe.py b/dsv4/layers/moe.py index 96502aaa..4e937eb6 100644 --- a/dsv4/layers/moe.py +++ b/dsv4/layers/moe.py @@ -589,6 +589,11 @@ class Nvfp4MoE: padded_dst = padded_expert_offsets[expert_assign] + local_row # === L1: gate + up === + # Compute runtime gsa from actual activation magnitude if requested. + # This prevents E4M3 block scale overflow when checkpoint input_scale is too small. + if getattr(self, '_use_runtime_gsa', False): + amax = slot_hidden.float().abs().max().clamp(min=1e-8).item() + self._l1_activation_global_scale = amax / (6.0 * 448.0) # Quantize slot_hidden using GPU-only kernel (no CPU-GPU sync). # slot_hidden is the sorted tokens (not padded). The GPU kernel # replaces quantize_activation_nvfp4 which uses .amax() (CPU sync). @@ -618,6 +623,10 @@ class Nvfp4MoE: swiglu_limit=self._swiglu_limit if self._swiglu_limit is not None else 0.0, ) l1_out_real = l1_out[padded_dst] + # Compute runtime gsa for L2 from the activated output + if getattr(self, '_use_runtime_gsa', False): + amax_l2 = l1_out_real.float().abs().max().clamp(min=1e-8).item() + self._l2_activation_global_scale = amax_l2 / (6.0 * 448.0) # De-interleave + quantize to FP4 in one GPU kernel. # l1_out_real has interleaved [silu(gate)*8, swiglu*8, ...]. # The CUDA kernel extracts odd 8-col groups (SwiGLU result) diff --git a/dsv4/layers/router.py b/dsv4/layers/router.py index 7f4cc879..fbdf3db6 100644 --- a/dsv4/layers/router.py +++ b/dsv4/layers/router.py @@ -184,6 +184,7 @@ class Router: ws2_val = gate_ws2.float().item() if gate_ws2.numel() == 1 else gate_ws2.float().mean().item() gate_lin.ws2 = [torch.tensor([ws2_val], device=self.device, dtype=torch.float32)] gate_lin._activation_global_scale = gate_input_scale.float().item() if gate_input_scale.numel() == 1 else gate_input_scale.float().mean().item() + gate_lin._use_runtime_gsa = True # compute gsa from actual input to avoid E4M3 overflow gate_lin.finalize_weights() self.gate_lin = gate_lin diff --git a/dsv4/layers/shared_expert.py b/dsv4/layers/shared_expert.py index be986354..8f318c43 100644 --- a/dsv4/layers/shared_expert.py +++ b/dsv4/layers/shared_expert.py @@ -236,6 +236,9 @@ class Nvfp4SharedExpert: padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128 # Quantize activation + if getattr(self, '_use_runtime_gsa', False): + amax = hidden_states.float().abs().max().clamp(min=1e-8).item() + self._l1_activation_global_scale = amax / (6.0 * 448.0) x_fp4, x_sf = quantize_activation_nvfp4( hidden_states, self._l1_activation_global_scale ) @@ -275,6 +278,9 @@ class Nvfp4SharedExpert: padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128 # Quantize activation + if getattr(self, '_use_runtime_gsa', False): + amax = intermediate.float().abs().max().clamp(min=1e-8).item() + self._l2_activation_global_scale = amax / (6.0 * 448.0) x_fp4, x_sf = quantize_activation_nvfp4( intermediate, self._l2_activation_global_scale ) diff --git a/single_shot_inference.py b/single_shot_inference.py index 6ca91cba..88e3973c 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -133,26 +133,18 @@ def make_nvfp4_linear(in_features, out_features, device, all_w, pfx, proj_name): d = device weight, ws, ws2, isc = get_nvfp4_weight(all_w, pfx, proj_name) assert weight is not None, f"{pfx}.{proj_name}.weight not found" - # Checkpoint weight is (N_packed, K_packed) uint8 - # NVFP4 GEMM output dim = N_packed BF16 elements - # Activation buffer needs K_packed FP4 columns = in_features BF16 - # So: in_features = K_packed * 2, out_features = N_packed actual_out = weight.shape[0] # N_packed = GEMM output dimension actual_in = weight.shape[1] * 2 # K_packed * 2 = BF16 input dim (for buffer allocation) lin = Nvfp4Linear(actual_in, actual_out, max_num_tokens=8192, device=d) lin.fp4 = [weight.to(d)]; lin.sf = [ws.to(d)] - # Global scales for NVFP4 GEMM: - # gsb (weight global scale) = weight_scale_2 (NOT input_scale * weight_scale_2) - # gsa (activation global scale) = input_scale from checkpoint - # Dequant: w = lut[w_packed] * weight_scale * weight_scale_2 - # GEMM: y = (x * scale_a * gsa) @ (w * scale_b * gsb) - # Nvfp4Linear.finalize_weights does: gsb = gs * ws2_val - # So to get gsb = ws2_val, set gs = 1.0 and let ws2 do its job lin.gs = [1.0] # base gs — finalize_weights will multiply by ws2 lin.ws2 = [ws2.to(d) if ws2 is not None else None] - # Set activation global scale from checkpoint input_scale - isc_val = isc.float().item() if isc is not None else 1.0 / (6.0 * 448.0) - lin._activation_global_scale = isc_val # gsa = input_scale + # CRITICAL FIX: Compute gsa at RUNTIME from actual input magnitude. + # The checkpoint's input_scale is for training-time FP8 quantization. + # Using it as gsa causes E4M3 block scale overflow when x/gsa > 2688. + # We set a placeholder and override in the forward pass. + lin._activation_global_scale = 1.0 / (6.0 * 448.0) # placeholder + lin._use_runtime_gsa = True # flag to compute gsa at runtime lin.finalize_weights(); return lin # ===================================================================== @@ -697,6 +689,7 @@ def main(): if oa_bf is not None: wo_a.set_bf16_weight(oa_bf.bfloat16().to(dev)) pl['o_a'] = wo_a + wo_a._use_runtime_gsa = True # compute gsa from actual input to avoid E4M3 overflow pl['o_b'] = make_nvfp4_linear(16384, 7168, dev, all_w, pfx, 'o_b_proj') prod_lins[li] = pl if (li+1) % 10 == 0: print(f" {li+1}/{n_layers} layers") @@ -769,10 +762,11 @@ def main(): # EAGERLY process stacked weights → K-major + swizzle, free raw tensors moe._ensure_stacked() # Fix activation global scales — _ensure_stacked sets gsa from l1_gs (which is 1.0) - if hasattr(moe, '_saved_l1_gsa'): - moe._l1_activation_global_scale = moe._saved_l1_gsa - if hasattr(moe, '_saved_l2_gsa'): - moe._l2_activation_global_scale = moe._saved_l2_gsa + # FIX: Do NOT use checkpoint input_scale as gsa — causes E4M3 overflow. + # Instead, compute gsa at runtime from actual activation magnitude. + # The MoE runner's compute_activation_global_scales() does this correctly. + # We enable runtime gsa for both MoE and SharedExpert. + moe._use_runtime_gsa = True moe_runners[li] = moe se = Nvfp4SharedExpert(hidden_size=H, intermediate_size=cfg.get("moe_intermediate_size", 3072), @@ -781,11 +775,8 @@ def main(): # EAGERLY process shared expert weights se._ensure_initialized() # Fix activation global scales — _ensure_initialized sets gsa from l1_gs (which is 1.0) - # The correct gsa is the input_scale from the checkpoint, saved in _saved_l1_gsa - if hasattr(se, '_saved_l1_gsa'): - se._l1_activation_global_scale = se._saved_l1_gsa - if hasattr(se, '_saved_l2_gsa'): - se._l2_activation_global_scale = se._saved_l2_gsa + # FIX: Same runtime gsa for SharedExpert + se._use_runtime_gsa = True se_runners[li] = se if (li+1) % 10 == 0: print(f" Built {li+1}/{n_layers} MoE layers") torch.cuda.empty_cache()