fix: stop folding global scale into float8 block scales

The fold block_sf (float8) * global_sf (float32) -> float8 loses ~25% precision.
Product of ~56-448 block_sf * ~4.65e-05 global_sf lands in float8 low-precision
zone where step size is 25%. This makes model output garbage despite finite values.

Fix: keep block scales as original float8, return global scales separately as
float32 per-expert vectors. Apply global scale as per-expert GEMM alpha in
cutlass_grouped_nvfp4_gemm (already iterates per-expert). For L1 with separate
gate/up global scales, use gate_gs as alpha and apply up_correction ratio to
the up half post-GEMM.

weight_transform.py: no more _fold_global_scale, returns (w, sf, global_sf)
nvfp4_mega_moe.py: per-expert alpha = activation_gs * weight_gs
kernel.py: per_expert_alpha parameter in grouped GEMM
deepseek_v4.py: updated type hints and comments
This commit is contained in:
2026-05-15 12:42:53 +00:00
parent 56e62e916d
commit fd59222fc0
9 changed files with 955 additions and 157 deletions

View File

@@ -61,6 +61,7 @@ def cutlass_grouped_nvfp4_gemm(
slot_expert_ids, # (num_slots,) int32 — per-slot local expert IDs
slot_token=None, # (num_slots,) int64 — per-slot token indices (default: arange)
alpha=1.0, # fp32 scalar: D = alpha * A @ B (from stage_activation global scale)
per_expert_alpha=None, # (E_per_rank,) float32 — per-expert alpha overrides scalar alpha
):
"""Per-expert grouped GEMM for MoE dispatch using CUTLASS NVFP4.
@@ -71,6 +72,11 @@ def cutlass_grouped_nvfp4_gemm(
For L1: x_fp4 has num_tokens rows, slot_token maps slots→rows.
For L2: x_fp4 has num_slots rows, slot_token is just arange(num_slots).
If per_expert_alpha is provided, each expert uses its own alpha value
(activation_global_scale * weight_global_scale[expert]) instead of the
scalar alpha. This preserves full float32 precision — no lossy float8
folding of weight global scales.
Returns:
slot_out: (num_slots, N) bfloat16 — per-slot GEMM results
slot_token: (num_slots,) int64 — token index for each slot
@@ -100,7 +106,7 @@ def cutlass_grouped_nvfp4_gemm(
if MEGA_MOE_DEBUG:
print(f"[cutlass_grouped_gemm] slots={num_slots} K={K} N={N} "
f"experts={num_experts}")
f"experts={num_experts} per_expert_alpha={'yes' if per_expert_alpha is not None else 'no'}")
slot_out = torch.empty(num_slots, N, dtype=torch.bfloat16, device=x_fp4.device)
@@ -116,9 +122,12 @@ def cutlass_grouped_nvfp4_gemm(
expert_w_sf = weight_sf[e]
M_expert = e_idx.shape[0]
# Per-expert alpha: activation_gs * weight_gs (float32, no precision loss)
expert_alpha = float(per_expert_alpha[e]) if per_expert_alpha is not None else alpha
if MEGA_MOE_DEBUG and e < 3 and M_expert > 0:
print(f"[GEMM-IN] expert={e} M={M_expert} N={N} K={K} "
f"w shape={expert_w.shape}")
f"w shape={expert_w.shape} alpha={expert_alpha:.4e}")
# Shape/dtype contract asserts — SFB bugs hide in silent shape mismatches
assert expert_x.shape == (M_expert, K // 2), f"expert_x shape {expert_x.shape} != ({M_expert}, {K // 2})"
@@ -132,14 +141,14 @@ def cutlass_grouped_nvfp4_gemm(
expert_x, expert_x_sf,
expert_w, expert_w_sf,
M_expert, N, K,
alpha=alpha,
alpha=expert_alpha,
)
if MEGA_MOE_DEBUG:
if torch.isnan(expert_out).any() or torch.isinf(expert_out).any():
raise RuntimeError(
f"expert {e} of {num_experts}: GEMM emitted NaN/Inf. "
f"M={M_expert} N={N} K={K}")
f"M={M_expert} N={N} K={K} alpha={expert_alpha:.4e}")
slot_out[e_idx] = expert_out

View File

@@ -56,8 +56,6 @@ def unpack_ue4m3_u32(x_u32):
return out
# CUTLASS native NVFP4 block-scaled GEMM (SM100 Blackwell)
# Primary path: uses CUTLASS MainloopSm100TmaUmmaWarpSpecializedBlockScaled
# which invokes mxf8f6f4.block_scale tensor core instructions directly.
MEGA_MOE_USE_CUTLASS = int(os.environ.get("MEGA_MOE_USE_CUTLASS", "1"))
try:
@@ -97,13 +95,25 @@ def nvfp4_mega_moe_l1(
l1_scales, # (E_per_rank, sf_k_groups, 2*INTER) float8_e4m3fn, column-major
slot_expert_ids, # (num_slots,) int32 — per-slot local expert IDs
slot_token, # (num_slots,) int64 — token index per slot
l1_global_sf, # (E_per_rank, 2) or (E_per_rank,) float32 — weight global scales
alpha=1.0, # fp32 scalar from stage_activation global scale
):
"""L1 GEMM: gate_up_proj — slot-based, no routing weights.
Takes pre-built slot mapping (slot_expert_ids, slot_token) from the outer
routing logic. Returns (slot_out, slot_token) where each slot is one
(token, topk) pair.
Global scale is NOT folded into block scales. Instead, it's applied as a
per-expert multiplier to the GEMM alpha: alpha_expert = alpha * global_sf[expert].
For L1 with gate+up: gate and up share one GEMM but may have different global scales.
Since the GEMM produces gate|up in one shot, we use a single alpha per expert.
Post-GEMM, we apply the gate/up ratio correction if they differ.
Actually, for simplicity and correctness: we use the gate global scale as alpha
and correct the up portion after GEMM. But since gate and up global scales
are typically identical in practice, we just use the geometric mean.
CLEANER APPROACH: use per-expert alpha directly in the grouped GEMM.
The grouped GEMM iterates per expert, so each expert can have its own alpha.
For L1 with separate gate/up global scales, we use the geometric mean
and then apply a correction factor to the up portion.
"""
K_half = x_fp4.shape[1]
K = K_half * 2
@@ -116,13 +126,36 @@ def nvfp4_mega_moe_l1(
w_sf_fp8 = unpack_ue4m3_u32(l1_scales) if l1_scales.dtype == torch.uint32 else l1_scales
assert w_sf_fp8.dtype == torch.float8_e4m3fn, f"l1_scales after unpack dtype={w_sf_fp8.dtype}"
# Compute per-expert alpha: activation_gs * weight_gs
# For L1 with (E, 2) gate/up global scales, use geometric mean per expert
if l1_global_sf.dim() == 2 and l1_global_sf.shape[1] == 2:
# gate_gs and up_gs per expert — use gate_gs for the GEMM alpha,
# then correct the up half post-GEMM
l1_gate_gs = l1_global_sf[:, 0] # (E,) float32
l1_up_gs = l1_global_sf[:, 1] # (E,) float32
per_expert_alpha = alpha * l1_gate_gs # (E,) float32
up_correction = l1_up_gs / l1_gate_gs # (E,) float32 — ratio to apply to up half
else:
per_expert_alpha = alpha * l1_global_sf # (E,) float32
up_correction = None
slot_out, slot_token = cutlass_grouped_nvfp4_gemm(
x_fp4, x_sf_fp8,
l1_weights, w_sf_fp8,
slot_expert_ids, # 1D per-slot expert IDs
slot_token, # 1D per-slot token indices
alpha=alpha,
slot_expert_ids,
slot_token,
per_expert_alpha=per_expert_alpha,
)
# Apply up correction if gate/up global scales differ
if up_correction is not None:
gate_N = N // 2
# For each slot, apply the correction to the up half
# slot_out is (num_slots, N) — up half is [:, gate_N:]
# Correction factor is per-expert: up_correction[slot_expert_ids]
correction = up_correction[slot_expert_ids].unsqueeze(1) # (num_slots, 1)
slot_out[:, gate_N:] = slot_out[:, gate_N:] * correction.to(slot_out.dtype)
print(f"[L1-GEMM-OUT] slots={slot_out.shape[0]} N={N} amax={slot_out.abs().max().item():.4e} mean={slot_out.float().mean().item():.4e}")
return slot_out, slot_token
@@ -132,13 +165,15 @@ def nvfp4_mega_moe_l2(
x_sf, # (num_slots, sf_k_groups) float8_e4m3fn
l2_weights, # (E_per_rank, INTER//2, HIDDEN) int8, column-major for CUTLASS
l2_scales, # (E_per_rank, sf_k_groups, HIDDEN) float8_e4m3fn, column-major
slot_expert_ids, # (num_slots,) int32 — per-slot local expert IDs (from L1 routing)
slot_token, # (num_slots,) int64 — token index per slot (from L1)
slot_expert_ids, # (num_slots,) int32 — per-slot local expert IDs
slot_token, # (num_slots,) int64 — token index per slot
l2_global_sf, # (E_per_rank,) float32 — weight global scales
alpha=1.0, # fp32 scalar from stage_activation global scale
):
"""L2 GEMM: down_proj — slot-based, no routing weights.
Reuses the same slot mapping from L1 (same slot_token and slot_expert_ids).
Per-expert alpha = activation_global_scale * weight_global_scale[expert].
This preserves full float32 precision — no lossy float8 folding.
"""
K_half = x_fp4.shape[1]
K = K_half * 2
@@ -151,11 +186,14 @@ def nvfp4_mega_moe_l2(
w_sf_fp8 = unpack_ue4m3_u32(l2_scales) if l2_scales.dtype == torch.uint32 else l2_scales
assert w_sf_fp8.dtype == torch.float8_e4m3fn, f"l2_scales after unpack dtype={w_sf_fp8.dtype}"
# Per-expert alpha: activation_gs * weight_gs
per_expert_alpha = alpha * l2_global_sf # (E,) float32
slot_out, _ = cutlass_grouped_nvfp4_gemm(
x_fp4, x_sf_fp8,
l2_weights, w_sf_fp8,
slot_expert_ids, # 1D per-slot expert IDs — GEMM handles directly
alpha=alpha,
slot_expert_ids,
per_expert_alpha=per_expert_alpha,
)
return slot_out # (num_slots, HIDDEN) bfloat16
@@ -236,8 +274,8 @@ def stage_activation(x_bf16):
def nvfp4_mega_moe_full(
y, # output tensor (num_tokens, HIDDEN) bfloat16
transformed_l1_weights, # (l1_w, l1_sf) tuple from finalize_weights
transformed_l2_weights, # (l2_w, l2_sf) tuple from finalize_weights
transformed_l1_weights, # (l1_w, l1_sf, l1_global_sf) from finalize_weights
transformed_l2_weights, # (l2_w, l2_sf, l2_global_sf) from finalize_weights
symm_buffer, # SymmBuffer from get_symm_buffer
activation_clamp=None, # optional clamp value (unused in NVFP4)
fast_math=False, # fast math flag (unused in NVFP4)
@@ -246,10 +284,10 @@ def nvfp4_mega_moe_full(
Slot-based pipeline (routing weights applied ONCE at final scatter):
1. Read staged activation from symm_buffer
2. L1 GEMM → slot output (num_slots, 2*INTER) — NO routing weights
2. L1 GEMM → slot output (num_slots, 2*INTER) — per-expert alpha
3. SiLU + Mul PER SLOT (nonlinearity before combining expert paths)
4. Quantize activated slots → FP4
5. L2 GEMM → slot output (num_slots, HIDDEN) — NO routing weights
5. L2 GEMM → slot output (num_slots, HIDDEN) — per-expert alpha
6. Final scatter: y.index_add_(0, slot_token, slot_weight * l2_slots)
Single routing weight application.
"""
@@ -264,9 +302,9 @@ def nvfp4_mega_moe_full(
y.zero_()
return
# Unpack transformed weights
l1_w, l1_sf = transformed_l1_weights
l2_w, l2_sf = transformed_l2_weights
# Unpack transformed weights (now includes global_sf)
l1_w, l1_sf, l1_global_sf = transformed_l1_weights
l2_w, l2_sf, l2_global_sf = transformed_l2_weights
# Expert sanity check — are experts actually distinct?
if not getattr(nvfp4_mega_moe_full, '_expert_sanity', False):
@@ -276,6 +314,8 @@ def nvfp4_mega_moe_full(
sf_sample = l1_sf[e].to(torch.float32)[:4, :4]
print(f"[EXPERT-SANITY e={e}] w_bytes[:8,:8]={w_sample.flatten().tolist()[:16]}")
print(f"[EXPERT-SANITY e={e}] sf[:4,:4]={sf_sample.flatten().tolist()[:8]}")
print(f"[EXPERT-SANITY e={e}] l1_global_sf={l1_global_sf[e].tolist()}")
print(f"[EXPERT-SANITY e={e}] l2_global_sf={l2_global_sf[e].tolist()}")
# Step 1: Read staged activation from symm_buffer
x_fp4 = symm_buffer.x[:num_tokens]
@@ -287,7 +327,8 @@ def nvfp4_mega_moe_full(
_x_sf_f32 = x_sf.to(torch.float32)
_igs = l1_global_scale if isinstance(l1_global_scale, float) else l1_global_scale.item() if hasattr(l1_global_scale, 'item') else float(l1_global_scale)
if MEGA_MOE_DEBUG:
print(f"[ALPHA L1] alpha={_igs:.4e} x_sf range [{_x_sf_f32.min().item():.4e}, {_x_sf_f32.max().item():.4e}]")
print(f"[ALPHA L1] activation_gs={_igs:.4e} x_sf range [{_x_sf_f32.min().item():.4e}, {_x_sf_f32.max().item():.4e}]")
print(f"[ALPHA L1] l1_global_sf range [{l1_global_sf.min().item():.4e}, {l1_global_sf.max().item():.4e}]")
# Convert global expert IDs to local expert IDs
num_experts_per_rank = l1_w.shape[0]
@@ -316,10 +357,10 @@ def nvfp4_mega_moe_full(
y.zero_()
return
# Ensure alpha is a plain Python float (C extension can't handle torch scalars)
# Ensure alpha is a plain Python float for the base activation global scale
l1_alpha = float(l1_global_scale) if not isinstance(l1_global_scale, float) else l1_global_scale
# Shape consistency asserts — catch mismatched slot mappings early
# Shape consistency asserts
assert slot_expert_local.ndim == 1
assert slot_token.ndim == 1
assert slot_weight.ndim == 1
@@ -327,21 +368,11 @@ def nvfp4_mega_moe_full(
assert slot_token.numel() == num_slots
assert slot_weight.numel() == num_slots
# SFB weight scales are remapped per-expert inside CUTLASS on each call.
# ─────────────────────────────────────────────────────────────────────
# NO PREPACK CACHE — see README for rationale.
# DO NOT add a prepack cache. Previous attempts caused:
# - OOM: ~1.75 GiB per prepacked tensor × 61 layers = 214 GiB
# - Peak memory 2× during torch.stack before eviction
# - CUDA graph use-after-free on evicted entries
# - M_for_layout=128 assumption (unverified M-independence)
# The SFB remap is a small scatter kernel (~µs) — not the bottleneck.
# ─────────────────────────────────────────────────────────────────────
# Step 2: L1 GEMM — slot-based, no routing weights
# Step 2: L1 GEMM — slot-based, per-expert alpha
l1_slots, _ = nvfp4_mega_moe_l1(
x_fp4, x_sf, l1_w, l1_sf,
slot_expert_local, slot_token,
l1_global_sf=l1_global_sf,
alpha=l1_alpha,
) # (num_slots, 2*INTER) bfloat16
@@ -374,12 +405,14 @@ def nvfp4_mega_moe_full(
if MEGA_MOE_DEBUG:
_l1sf_f32 = l1_sf_out.to(torch.float32)
_l2gs = l2_global_scale if isinstance(l2_global_scale, float) else l2_global_scale.item()
print(f"[ALPHA L2] alpha={_l2gs:.4e} l1_sf range [{_l1sf_f32.min().item():.4e}, {_l1sf_f32.max().item():.4e}]")
print(f"[ALPHA L2] activation_gs={_l2gs:.4e} l1_sf range [{_l1sf_f32.min().item():.4e}, {_l1sf_f32.max().item():.4e}]")
print(f"[ALPHA L2] l2_global_sf range [{l2_global_sf.min().item():.4e}, {l2_global_sf.max().item():.4e}]")
# Step 5: L2 GEMM — slot-based, no routing weights
# Step 5: L2 GEMM — slot-based, per-expert alpha
l2_slots = nvfp4_mega_moe_l2(
l1_fp4, l1_sf_out, l2_w, l2_sf,
slot_expert_local, slot_token,
l2_global_sf=l2_global_sf,
alpha=l2_alpha,
) # (num_slots, HIDDEN) bfloat16

View File

@@ -5,12 +5,12 @@ Converts raw NVFP4 checkpoint weights (uint8 E2M1 + float8_e4m3fn UE4M3 + float3
into the format expected by the CUTLASS block-scaled GEMM kernel:
- Packed FP4 weights (int8, K-major)
- UE4M3 block scales (float8_e4m3fn, row-major — CUTLASS SF remap handles interleaving)
- float32 global scales (NOT folded into block scales — passed separately for per-expert alpha)
Weight scales are returned as float8_e4m3fn (NOT packed uint32). The CUTLASS GEMM
consumes float8 scales directly; only activation scales from the staging kernel come
as uint32 and need unpack_ue4m3_u32.
This replaces deep_gemma.mega.transform_nvfp4_weights_for_mega_moe.
Previous versions folded weight_scale_2 into block scales via float8 round-trip, which caused
25% relative error (product of ~56-448 block_sf × ~4.65e-05 global_sf lands in the low-precision
zone of float8_e4m3fn where step size is 25%). The global scale is now applied as a per-expert
multiplier to the GEMM alpha, preserving full float32 precision.
Call signature matches the nightly vLLM deepseek_v4.py finalize_weights:
transform_nvfp4_weights_for_mega_moe(
@@ -24,134 +24,80 @@ Call signature matches the nightly vLLM deepseek_v4.py finalize_weights:
import torch
def _fold_global_scale(
weight_scale: torch.Tensor, # (E, N, K//16) float8_e4m3fn
weight_scale_2: torch.Tensor, # (E,) or (E, 2) or scalar float32
) -> torch.Tensor:
"""Fold global scale into block scales: UE4M3 * FP32 → float32.
For fused projections (w13 = gate+up), weight_scale_2 is (E, 2):
scale_2[e, 0] applies to gate_proj rows, scale_2[e, 1] applies to up_proj rows.
N is split: gate = weight_scale[:, :N//2, :], up = weight_scale[:, N//2:, :]
For single projections (w2), weight_scale_2 is (E,) or scalar.
"""
sf_f32 = weight_scale.to(torch.float32)
gs = weight_scale_2.to(torch.float32)
if gs.numel() == 1:
sf_f32 = sf_f32 * gs
elif gs.dim() == 2 and gs.shape[1] == 2:
# Fused projection: (E, 2) — gate and up have separate global scales
# weight_scale is (E, N, K//16), N = gate_N + up_N
gate_N = sf_f32.shape[1] // 2
gs_gate = gs[:, 0].unsqueeze(-1) # (E, 1)
gs_up = gs[:, 1].unsqueeze(-1) # (E, 1)
sf_f32[:, :gate_N, :] = sf_f32[:, :gate_N, :] * gs_gate.unsqueeze(-1)
sf_f32[:, gate_N:, :] = sf_f32[:, gate_N:, :] * gs_up.unsqueeze(-1)
else:
# Per-expert global scale — broadcast multiply
while gs.dim() < sf_f32.dim():
gs = gs.unsqueeze(-1)
sf_f32 = sf_f32 * gs.expand_as(sf_f32)
return sf_f32
def _pack_ue4m3_to_uint32(sf: torch.Tensor) -> torch.Tensor:
"""Pack 4 UE4M3 (float8_e4m3fn) values into one uint32."""
sf_u8 = sf.view(torch.uint8)
assert sf_u8.shape[-1] % 4 == 0, f"Last dim {sf_u8.shape[-1]} not divisible by 4"
packed = (sf_u8[..., 0::4].to(torch.int32) |
(sf_u8[..., 1::4].to(torch.int32) << 8) |
(sf_u8[..., 2::4].to(torch.int32) << 16) |
(sf_u8[..., 3::4].to(torch.int32) << 24))
return packed.contiguous()
def transform_nvfp4_weights_for_mega_moe(
l1_tuple: tuple[torch.Tensor, torch.Tensor], # (weight, weight_scale)
l2_tuple: tuple[torch.Tensor, torch.Tensor], # (weight, weight_scale)
l1_weight_scale_2: torch.Tensor = None, # float32 global scale for L1
l2_weight_scale_2: torch.Tensor = None, # float32 global scale for L2
) -> tuple[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor],
tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Transform NVFP4 weights for the CUTLASS block-scaled GEMM.
Matches the call signature from nightly vLLM deepseek_v4.py finalize_weights.
NO LONGER FOLDS GLOBAL SCALES INTO BLOCK SCALES.
Folding block_sf (float8) × global_sf (float32) → float8 loses ~25% precision
because the product lands in the low-precision zone of float8_e4m3fn.
Instead, global scales are returned separately and applied as per-expert GEMM alpha.
Args:
l1_tuple: (w13_weight, w13_weight_scale) — gate_up proj
l2_tuple: (w2_weight, w2_weight_scale) — down proj
l1_weight_scale_2: global scale for L1 (float32)
Shape (E, 2) for gate+up, or (E,) per-expert, or scalar
l2_weight_scale_2: global scale for L2 (float32)
Shape (E,) per-expert, or scalar
Returns:
((l1_weight, l1_sf_packed), (l2_weight, l2_sf_packed))
((l1_weight, l1_sf, l1_global_sf), (l2_weight, l2_sf, l2_global_sf))
where global_sf is (E,) float32 — the geometric mean of gate/up for L1,
or the per-expert global scale for L2.
The caller must apply global_sf as a per-expert multiplier to the GEMM alpha.
"""
l1_weight, l1_weight_scale = l1_tuple
l2_weight, l2_weight_scale = l2_tuple
# DEBUG: check raw scales before folding
l1_sf_f32_raw = l1_weight_scale.to(torch.float32)
l1_gs_raw = l1_weight_scale_2.to(torch.float32) if l1_weight_scale_2 is not None else None
if not getattr(transform_nvfp4_weights_for_mega_moe, '_sf_debug', False):
transform_nvfp4_weights_for_mega_moe._sf_debug = True
print(f"[SF-DEBUG] raw l1_sf dtype={l1_weight_scale.dtype} range=[{l1_sf_f32_raw.min().item():.4e}, {l1_sf_f32_raw.max().item():.4e}] "
f"unique_raw={torch.unique(l1_weight_scale.view(torch.uint8)).numel()}")
if l1_gs_raw is not None:
print(f"[SF-DEBUG] l1_gs dtype={l1_weight_scale_2.dtype} shape={tuple(l1_weight_scale_2.shape)} "
f"range=[{l1_gs_raw.min().item():.4e}, {l1_gs_raw.max().item():.4e}] "
f"unique_gs={torch.unique(l1_gs_raw).numel()}")
if l1_gs_raw.dim() == 2 and l1_gs_raw.shape[1] == 2:
print(f"[SF-DEBUG] gate gs unique={torch.unique(l1_gs_raw[:, 0]).numel()} "
f"up gs unique={torch.unique(l1_gs_raw[:, 1]).numel()}")
# DEBUG: check L2 scales
l2_sf_f32_raw = l2_weight_scale.to(torch.float32)
l2_gs_raw = l2_weight_scale_2.to(torch.float32) if l2_weight_scale_2 is not None else None
if not getattr(transform_nvfp4_weights_for_mega_moe, '_sf_debug_l2', False):
transform_nvfp4_weights_for_mega_moe._sf_debug_l2 = True
print(f"[SF-DEBUG-L2] raw l2_sf dtype={l2_weight_scale.dtype} range=[{l2_sf_f32_raw.min().item():.4e}, {l2_sf_f32_raw.max().item():.4e}] "
f"unique_raw={torch.unique(l2_weight_scale.view(torch.uint8)).numel()}")
if l2_gs_raw is not None:
print(f"[SF-DEBUG-L2] l2_gs dtype={l2_weight_scale_2.dtype} shape={tuple(l2_weight_scale_2.shape)} "
f"range=[{l2_gs_raw.min().item():.4e}, {l2_gs_raw.max().item():.4e}] "
f"unique_gs={torch.unique(l2_gs_raw).numel()}")
# Post-fold diagnostics — one-time
if not getattr(transform_nvfp4_weights_for_mega_moe, '_sf_debug_fold', False):
transform_nvfp4_weights_for_mega_moe._sf_debug_fold = True
l1_sf_folded = _fold_global_scale(l1_weight_scale, l1_weight_scale_2) if l1_weight_scale_2 is not None else l1_weight_scale.to(torch.float32)
l1_sf_out_check = l1_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn)
l2_sf_folded = _fold_global_scale(l2_weight_scale, l2_weight_scale_2) if l2_weight_scale_2 is not None else l2_weight_scale.to(torch.float32)
l2_sf_out_check = l2_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn)
print(f"[SF-FOLD] l1 pre-fold unique_u8={torch.unique(l1_weight_scale.view(torch.uint8)).numel()} "
f"post-fold unique_u8={torch.unique(l1_sf_out_check.view(torch.uint8)).numel()} "
f"range=[{l1_sf_folded.min().item():.4e}, {l1_sf_folded.max().item():.4e}]")
print(f"[SF-FOLD] l2 pre-fold unique_u8={torch.unique(l2_weight_scale.view(torch.uint8)).numel()} "
f"post-fold unique_u8={torch.unique(l2_sf_out_check.view(torch.uint8)).numel()} "
f"range=[{l2_sf_folded.min().item():.4e}, {l2_sf_folded.max().item():.4e}]")
# Fold global scales into block scales
# The logical_widths branch was wrong: it treated gs as per-projection
# scalars and only used experts 0 and 1's scales for ALL experts.
# The else branch correctly broadcasts each expert's own global scale.
# Extract global scales as per-expert float32 vectors
# L1: gate/up have separate global scales — store both
# The caller (nvfp4_mega_moe_full) will apply the right one per-expert
if l1_weight_scale_2 is not None:
l1_sf_folded = _fold_global_scale(l1_weight_scale, l1_weight_scale_2)
l1_gs = l1_weight_scale_2.to(torch.float32)
if l1_gs.dim() == 2 and l1_gs.shape[1] == 2:
# (E, 2) — gate_gs and up_gs separate
# For L1 alpha, use the geometric mean (close enough since gate and up
# global scales are typically similar). Actually, we need BOTH because
# the GEMM produces gate and up in one shot.
# Better: just store (E, 2) and let the caller apply post-GEMM scaling.
l1_global_sf = l1_gs # (E, 2) float32
else:
l1_global_sf = l1_gs # (E,) float32
else:
l1_sf_folded = l1_weight_scale.to(torch.float32)
l1_global_sf = torch.ones(l1_weight.shape[0], dtype=torch.float32, device=l1_weight.device)
if l2_weight_scale_2 is not None:
l2_sf_folded = _fold_global_scale(l2_weight_scale, l2_weight_scale_2)
l2_gs = l2_weight_scale_2.to(torch.float32)
l2_global_sf = l2_gs # (E,) or scalar → broadcast to (E,)
if l2_global_sf.dim() == 0:
l2_global_sf = l2_global_sf.expand(l2_weight.shape[0])
else:
l2_sf_folded = l2_weight_scale.to(torch.float32)
l2_global_sf = torch.ones(l2_weight.shape[0], dtype=torch.float32, device=l2_weight.device)
# Clamp and convert back to UE4M3
l1_sf_out = l1_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn).contiguous()
l2_sf_out = l2_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn).contiguous()
# Debug: one-time diagnostic
if not getattr(transform_nvfp4_weights_for_mega_moe, '_diag', False):
transform_nvfp4_weights_for_mega_moe._diag = True
print(f"[WT-XFORM] L1 block_sf range=[{l1_weight_scale.float().min():.4e}, "
f"{l1_weight_scale.float().max():.4e}] unique={torch.unique(l1_weight_scale.view(torch.uint8)).numel()}")
print(f"[WT-XFORM] L1 global_sf: shape={tuple(l1_global_sf.shape)} "
f"range=[{l1_global_sf.min():.4e}, {l1_global_sf.max():.4e}]")
print(f"[WT-XFORM] L2 block_sf range=[{l2_weight_scale.float().min():.4e}, "
f"{l2_weight_scale.float().max():.4e}] unique={torch.unique(l2_weight_scale.view(torch.uint8)).numel()}")
print(f"[WT-XFORM] L2 global_sf: shape={tuple(l2_global_sf.shape)} "
f"range=[{l2_global_sf.min():.4e}, {l2_global_sf.max():.4e}]")
# Block scales stay as original float8 — NO FOLDING
l1_sf_out = l1_weight_scale.contiguous()
l2_sf_out = l2_weight_scale.contiguous()
# CUTLASS B is declared ColumnMajor — it expects (K, N) in memory.
# Checkpoint weights are (N, K_half) row-major, so we transpose to (K_half, N)
# which is column-major (N, K_half). This is a one-time cost at load time.
l1_weight_out = l1_weight.transpose(-2, -1).contiguous()
l2_weight_out = l2_weight.transpose(-2, -1).contiguous()
@@ -159,4 +105,4 @@ def transform_nvfp4_weights_for_mega_moe(
l1_sf_out = l1_sf_out.transpose(-2, -1).contiguous()
l2_sf_out = l2_sf_out.transpose(-2, -1).contiguous()
return (l1_weight_out, l1_sf_out), (l2_weight_out, l2_sf_out)
return (l1_weight_out, l1_sf_out, l1_global_sf), (l2_weight_out, l2_sf_out, l2_global_sf)