fix: stop folding global scale into float8 block scales
Folding block_sf (float8) * global_sf (float32) back into float8 loses ~25% precision: the product of a block_sf of ~56-448 and a global_sf of ~4.65e-05 lands in the low-precision zone of float8_e4m3fn, where the step size is 25%. This makes the model output garbage even though all values are finite.
Fix: keep the block scales as the original float8 and return the global scales separately as float32 per-expert vectors. Apply the global scale as a per-expert GEMM alpha in cutlass_grouped_nvfp4_gemm (which already iterates per expert). For L1 with separate gate/up global scales, use gate_gs as alpha and apply the up_correction ratio (up_gs / gate_gs) to the up half post-GEMM.
Changes by file:
- weight_transform.py: no more _fold_global_scale; returns (w, sf, global_sf)
- nvfp4_mega_moe.py: per-expert alpha = activation_gs * weight_gs
- kernel.py: per_expert_alpha parameter in grouped GEMM
- deepseek_v4.py: updated type hints and comments
This commit is contained in:
@@ -61,6 +61,7 @@ def cutlass_grouped_nvfp4_gemm(
|
||||
slot_expert_ids, # (num_slots,) int32 — per-slot local expert IDs
|
||||
slot_token=None, # (num_slots,) int64 — per-slot token indices (default: arange)
|
||||
alpha=1.0, # fp32 scalar: D = alpha * A @ B (from stage_activation global scale)
|
||||
per_expert_alpha=None, # (E_per_rank,) float32 — per-expert alpha overrides scalar alpha
|
||||
):
|
||||
"""Per-expert grouped GEMM for MoE dispatch using CUTLASS NVFP4.
|
||||
|
||||
@@ -71,6 +72,11 @@ def cutlass_grouped_nvfp4_gemm(
|
||||
For L1: x_fp4 has num_tokens rows, slot_token maps slots→rows.
|
||||
For L2: x_fp4 has num_slots rows, slot_token is just arange(num_slots).
|
||||
|
||||
If per_expert_alpha is provided, each expert uses its own alpha value
|
||||
(activation_global_scale * weight_global_scale[expert]) instead of the
|
||||
scalar alpha. This preserves full float32 precision — no lossy float8
|
||||
folding of weight global scales.
|
||||
|
||||
Returns:
|
||||
slot_out: (num_slots, N) bfloat16 — per-slot GEMM results
|
||||
slot_token: (num_slots,) int64 — token index for each slot
|
||||
@@ -100,7 +106,7 @@ def cutlass_grouped_nvfp4_gemm(
|
||||
|
||||
if MEGA_MOE_DEBUG:
|
||||
print(f"[cutlass_grouped_gemm] slots={num_slots} K={K} N={N} "
|
||||
f"experts={num_experts}")
|
||||
f"experts={num_experts} per_expert_alpha={'yes' if per_expert_alpha is not None else 'no'}")
|
||||
|
||||
slot_out = torch.empty(num_slots, N, dtype=torch.bfloat16, device=x_fp4.device)
|
||||
|
||||
@@ -116,9 +122,12 @@ def cutlass_grouped_nvfp4_gemm(
|
||||
expert_w_sf = weight_sf[e]
|
||||
M_expert = e_idx.shape[0]
|
||||
|
||||
# Per-expert alpha: activation_gs * weight_gs (float32, no precision loss)
|
||||
expert_alpha = float(per_expert_alpha[e]) if per_expert_alpha is not None else alpha
|
||||
|
||||
if MEGA_MOE_DEBUG and e < 3 and M_expert > 0:
|
||||
print(f"[GEMM-IN] expert={e} M={M_expert} N={N} K={K} "
|
||||
f"w shape={expert_w.shape}")
|
||||
f"w shape={expert_w.shape} alpha={expert_alpha:.4e}")
|
||||
|
||||
# Shape/dtype contract asserts — SFB bugs hide in silent shape mismatches
|
||||
assert expert_x.shape == (M_expert, K // 2), f"expert_x shape {expert_x.shape} != ({M_expert}, {K // 2})"
|
||||
@@ -132,14 +141,14 @@ def cutlass_grouped_nvfp4_gemm(
|
||||
expert_x, expert_x_sf,
|
||||
expert_w, expert_w_sf,
|
||||
M_expert, N, K,
|
||||
alpha=alpha,
|
||||
alpha=expert_alpha,
|
||||
)
|
||||
|
||||
if MEGA_MOE_DEBUG:
|
||||
if torch.isnan(expert_out).any() or torch.isinf(expert_out).any():
|
||||
raise RuntimeError(
|
||||
f"expert {e} of {num_experts}: GEMM emitted NaN/Inf. "
|
||||
f"M={M_expert} N={N} K={K}")
|
||||
f"M={M_expert} N={N} K={K} alpha={expert_alpha:.4e}")
|
||||
|
||||
slot_out[e_idx] = expert_out
|
||||
|
||||
|
||||
@@ -56,8 +56,6 @@ def unpack_ue4m3_u32(x_u32):
|
||||
return out
|
||||
|
||||
# CUTLASS native NVFP4 block-scaled GEMM (SM100 Blackwell)
|
||||
# Primary path: uses CUTLASS MainloopSm100TmaUmmaWarpSpecializedBlockScaled
|
||||
# which invokes mxf8f6f4.block_scale tensor core instructions directly.
|
||||
MEGA_MOE_USE_CUTLASS = int(os.environ.get("MEGA_MOE_USE_CUTLASS", "1"))
|
||||
|
||||
try:
|
||||
@@ -97,13 +95,25 @@ def nvfp4_mega_moe_l1(
|
||||
l1_scales, # (E_per_rank, sf_k_groups, 2*INTER) float8_e4m3fn, column-major
|
||||
slot_expert_ids, # (num_slots,) int32 — per-slot local expert IDs
|
||||
slot_token, # (num_slots,) int64 — token index per slot
|
||||
l1_global_sf, # (E_per_rank, 2) or (E_per_rank,) float32 — weight global scales
|
||||
alpha=1.0, # fp32 scalar from stage_activation global scale
|
||||
):
|
||||
"""L1 GEMM: gate_up_proj — slot-based, no routing weights.
|
||||
|
||||
Takes pre-built slot mapping (slot_expert_ids, slot_token) from the outer
|
||||
routing logic. Returns (slot_out, slot_token) where each slot is one
|
||||
(token, topk) pair.
|
||||
Global scale is NOT folded into block scales. Instead, it is applied as a
per-expert multiplier to the GEMM alpha: alpha_expert = alpha * global_sf[expert].
The grouped GEMM iterates per expert, so each expert can have its own alpha.

For L1 the gate and up projections share one GEMM (the output is gate|up) but
may have different global scales. When l1_global_sf is (E_per_rank, 2), the gate
global scale is used as the GEMM alpha and the up half of the output is
multiplied post-GEMM by the ratio up_gs / gate_gs. When l1_global_sf is
(E_per_rank,), it is used directly as the per-expert multiplier and no
correction is applied.
|
||||
"""
|
||||
K_half = x_fp4.shape[1]
|
||||
K = K_half * 2
|
||||
@@ -116,13 +126,36 @@ def nvfp4_mega_moe_l1(
|
||||
w_sf_fp8 = unpack_ue4m3_u32(l1_scales) if l1_scales.dtype == torch.uint32 else l1_scales
|
||||
assert w_sf_fp8.dtype == torch.float8_e4m3fn, f"l1_scales after unpack dtype={w_sf_fp8.dtype}"
|
||||
|
||||
# Compute per-expert alpha: activation_gs * weight_gs
|
||||
# For L1 with (E, 2) gate/up global scales, use the gate scale as the per-expert alpha
|
||||
if l1_global_sf.dim() == 2 and l1_global_sf.shape[1] == 2:
|
||||
# gate_gs and up_gs per expert — use gate_gs for the GEMM alpha,
|
||||
# then correct the up half post-GEMM
|
||||
l1_gate_gs = l1_global_sf[:, 0] # (E,) float32
|
||||
l1_up_gs = l1_global_sf[:, 1] # (E,) float32
|
||||
per_expert_alpha = alpha * l1_gate_gs # (E,) float32
|
||||
up_correction = l1_up_gs / l1_gate_gs # (E,) float32 — ratio to apply to up half
|
||||
else:
|
||||
per_expert_alpha = alpha * l1_global_sf # (E,) float32
|
||||
up_correction = None
|
||||
|
||||
slot_out, slot_token = cutlass_grouped_nvfp4_gemm(
|
||||
x_fp4, x_sf_fp8,
|
||||
l1_weights, w_sf_fp8,
|
||||
slot_expert_ids, # 1D per-slot expert IDs
|
||||
slot_token, # 1D per-slot token indices
|
||||
alpha=alpha,
|
||||
slot_expert_ids,
|
||||
slot_token,
|
||||
per_expert_alpha=per_expert_alpha,
|
||||
)
|
||||
|
||||
# Apply up correction if gate/up global scales differ
|
||||
if up_correction is not None:
|
||||
gate_N = N // 2
|
||||
# For each slot, apply the correction to the up half
|
||||
# slot_out is (num_slots, N) — up half is [:, gate_N:]
|
||||
# Correction factor is per-expert: up_correction[slot_expert_ids]
|
||||
correction = up_correction[slot_expert_ids].unsqueeze(1) # (num_slots, 1)
|
||||
slot_out[:, gate_N:] = slot_out[:, gate_N:] * correction.to(slot_out.dtype)
|
||||
|
||||
print(f"[L1-GEMM-OUT] slots={slot_out.shape[0]} N={N} amax={slot_out.abs().max().item():.4e} mean={slot_out.float().mean().item():.4e}")
|
||||
return slot_out, slot_token
|
||||
|
||||
@@ -132,13 +165,15 @@ def nvfp4_mega_moe_l2(
|
||||
x_sf, # (num_slots, sf_k_groups) float8_e4m3fn
|
||||
l2_weights, # (E_per_rank, INTER//2, HIDDEN) int8, column-major for CUTLASS
|
||||
l2_scales, # (E_per_rank, sf_k_groups, HIDDEN) float8_e4m3fn, column-major
|
||||
slot_expert_ids, # (num_slots,) int32 — per-slot local expert IDs (from L1 routing)
|
||||
slot_token, # (num_slots,) int64 — token index per slot (from L1)
|
||||
slot_expert_ids, # (num_slots,) int32 — per-slot local expert IDs
|
||||
slot_token, # (num_slots,) int64 — token index per slot
|
||||
l2_global_sf, # (E_per_rank,) float32 — weight global scales
|
||||
alpha=1.0, # fp32 scalar from stage_activation global scale
|
||||
):
|
||||
"""L2 GEMM: down_proj — slot-based, no routing weights.
|
||||
|
||||
Reuses the same slot mapping from L1 (same slot_token and slot_expert_ids).
|
||||
Per-expert alpha = activation_global_scale * weight_global_scale[expert].
|
||||
This preserves full float32 precision — no lossy float8 folding.
|
||||
"""
|
||||
K_half = x_fp4.shape[1]
|
||||
K = K_half * 2
|
||||
@@ -151,11 +186,14 @@ def nvfp4_mega_moe_l2(
|
||||
w_sf_fp8 = unpack_ue4m3_u32(l2_scales) if l2_scales.dtype == torch.uint32 else l2_scales
|
||||
assert w_sf_fp8.dtype == torch.float8_e4m3fn, f"l2_scales after unpack dtype={w_sf_fp8.dtype}"
|
||||
|
||||
# Per-expert alpha: activation_gs * weight_gs
|
||||
per_expert_alpha = alpha * l2_global_sf # (E,) float32
|
||||
|
||||
slot_out, _ = cutlass_grouped_nvfp4_gemm(
|
||||
x_fp4, x_sf_fp8,
|
||||
l2_weights, w_sf_fp8,
|
||||
slot_expert_ids, # 1D per-slot expert IDs — GEMM handles directly
|
||||
alpha=alpha,
|
||||
slot_expert_ids,
|
||||
per_expert_alpha=per_expert_alpha,
|
||||
)
|
||||
return slot_out # (num_slots, HIDDEN) bfloat16
|
||||
|
||||
@@ -236,8 +274,8 @@ def stage_activation(x_bf16):
|
||||
|
||||
def nvfp4_mega_moe_full(
|
||||
y, # output tensor (num_tokens, HIDDEN) bfloat16
|
||||
transformed_l1_weights, # (l1_w, l1_sf) tuple from finalize_weights
|
||||
transformed_l2_weights, # (l2_w, l2_sf) tuple from finalize_weights
|
||||
transformed_l1_weights, # (l1_w, l1_sf, l1_global_sf) from finalize_weights
|
||||
transformed_l2_weights, # (l2_w, l2_sf, l2_global_sf) from finalize_weights
|
||||
symm_buffer, # SymmBuffer from get_symm_buffer
|
||||
activation_clamp=None, # optional clamp value (unused in NVFP4)
|
||||
fast_math=False, # fast math flag (unused in NVFP4)
|
||||
@@ -246,10 +284,10 @@ def nvfp4_mega_moe_full(
|
||||
|
||||
Slot-based pipeline (routing weights applied ONCE at final scatter):
|
||||
1. Read staged activation from symm_buffer
|
||||
2. L1 GEMM → slot output (num_slots, 2*INTER) — NO routing weights
|
||||
2. L1 GEMM → slot output (num_slots, 2*INTER) — per-expert alpha
|
||||
3. SiLU + Mul PER SLOT (nonlinearity before combining expert paths)
|
||||
4. Quantize activated slots → FP4
|
||||
5. L2 GEMM → slot output (num_slots, HIDDEN) — NO routing weights
|
||||
5. L2 GEMM → slot output (num_slots, HIDDEN) — per-expert alpha
|
||||
6. Final scatter: y.index_add_(0, slot_token, slot_weight * l2_slots)
|
||||
Single routing weight application.
|
||||
"""
|
||||
@@ -264,9 +302,9 @@ def nvfp4_mega_moe_full(
|
||||
y.zero_()
|
||||
return
|
||||
|
||||
# Unpack transformed weights
|
||||
l1_w, l1_sf = transformed_l1_weights
|
||||
l2_w, l2_sf = transformed_l2_weights
|
||||
# Unpack transformed weights (now includes global_sf)
|
||||
l1_w, l1_sf, l1_global_sf = transformed_l1_weights
|
||||
l2_w, l2_sf, l2_global_sf = transformed_l2_weights
|
||||
|
||||
# Expert sanity check — are experts actually distinct?
|
||||
if not getattr(nvfp4_mega_moe_full, '_expert_sanity', False):
|
||||
@@ -276,6 +314,8 @@ def nvfp4_mega_moe_full(
|
||||
sf_sample = l1_sf[e].to(torch.float32)[:4, :4]
|
||||
print(f"[EXPERT-SANITY e={e}] w_bytes[:8,:8]={w_sample.flatten().tolist()[:16]}")
|
||||
print(f"[EXPERT-SANITY e={e}] sf[:4,:4]={sf_sample.flatten().tolist()[:8]}")
|
||||
print(f"[EXPERT-SANITY e={e}] l1_global_sf={l1_global_sf[e].tolist()}")
|
||||
print(f"[EXPERT-SANITY e={e}] l2_global_sf={l2_global_sf[e].tolist()}")
|
||||
|
||||
# Step 1: Read staged activation from symm_buffer
|
||||
x_fp4 = symm_buffer.x[:num_tokens]
|
||||
@@ -287,7 +327,8 @@ def nvfp4_mega_moe_full(
|
||||
_x_sf_f32 = x_sf.to(torch.float32)
|
||||
_igs = l1_global_scale if isinstance(l1_global_scale, float) else l1_global_scale.item() if hasattr(l1_global_scale, 'item') else float(l1_global_scale)
|
||||
if MEGA_MOE_DEBUG:
|
||||
print(f"[ALPHA L1] alpha={_igs:.4e} x_sf range [{_x_sf_f32.min().item():.4e}, {_x_sf_f32.max().item():.4e}]")
|
||||
print(f"[ALPHA L1] activation_gs={_igs:.4e} x_sf range [{_x_sf_f32.min().item():.4e}, {_x_sf_f32.max().item():.4e}]")
|
||||
print(f"[ALPHA L1] l1_global_sf range [{l1_global_sf.min().item():.4e}, {l1_global_sf.max().item():.4e}]")
|
||||
|
||||
# Convert global expert IDs to local expert IDs
|
||||
num_experts_per_rank = l1_w.shape[0]
|
||||
@@ -316,10 +357,10 @@ def nvfp4_mega_moe_full(
|
||||
y.zero_()
|
||||
return
|
||||
|
||||
# Ensure alpha is a plain Python float (C extension can't handle torch scalars)
|
||||
# Ensure alpha is a plain Python float for the base activation global scale
|
||||
l1_alpha = float(l1_global_scale) if not isinstance(l1_global_scale, float) else l1_global_scale
|
||||
|
||||
# Shape consistency asserts — catch mismatched slot mappings early
|
||||
# Shape consistency asserts
|
||||
assert slot_expert_local.ndim == 1
|
||||
assert slot_token.ndim == 1
|
||||
assert slot_weight.ndim == 1
|
||||
@@ -327,21 +368,11 @@ def nvfp4_mega_moe_full(
|
||||
assert slot_token.numel() == num_slots
|
||||
assert slot_weight.numel() == num_slots
|
||||
|
||||
# SFB weight scales are remapped per-expert inside CUTLASS on each call.
|
||||
# ─────────────────────────────────────────────────────────────────────
|
||||
# NO PREPACK CACHE — see README for rationale.
|
||||
# DO NOT add a prepack cache. Previous attempts caused:
|
||||
# - OOM: ~1.75 GiB per prepacked tensor × 61 layers = 214 GiB
|
||||
# - Peak memory 2× during torch.stack before eviction
|
||||
# - CUDA graph use-after-free on evicted entries
|
||||
# - M_for_layout=128 assumption (unverified M-independence)
|
||||
# The SFB remap is a small scatter kernel (~µs) — not the bottleneck.
|
||||
# ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
# Step 2: L1 GEMM — slot-based, no routing weights
|
||||
# Step 2: L1 GEMM — slot-based, per-expert alpha
|
||||
l1_slots, _ = nvfp4_mega_moe_l1(
|
||||
x_fp4, x_sf, l1_w, l1_sf,
|
||||
slot_expert_local, slot_token,
|
||||
l1_global_sf=l1_global_sf,
|
||||
alpha=l1_alpha,
|
||||
) # (num_slots, 2*INTER) bfloat16
|
||||
|
||||
@@ -374,12 +405,14 @@ def nvfp4_mega_moe_full(
|
||||
if MEGA_MOE_DEBUG:
|
||||
_l1sf_f32 = l1_sf_out.to(torch.float32)
|
||||
_l2gs = l2_global_scale if isinstance(l2_global_scale, float) else l2_global_scale.item()
|
||||
print(f"[ALPHA L2] alpha={_l2gs:.4e} l1_sf range [{_l1sf_f32.min().item():.4e}, {_l1sf_f32.max().item():.4e}]")
|
||||
print(f"[ALPHA L2] activation_gs={_l2gs:.4e} l1_sf range [{_l1sf_f32.min().item():.4e}, {_l1sf_f32.max().item():.4e}]")
|
||||
print(f"[ALPHA L2] l2_global_sf range [{l2_global_sf.min().item():.4e}, {l2_global_sf.max().item():.4e}]")
|
||||
|
||||
# Step 5: L2 GEMM — slot-based, no routing weights
|
||||
# Step 5: L2 GEMM — slot-based, per-expert alpha
|
||||
l2_slots = nvfp4_mega_moe_l2(
|
||||
l1_fp4, l1_sf_out, l2_w, l2_sf,
|
||||
slot_expert_local, slot_token,
|
||||
l2_global_sf=l2_global_sf,
|
||||
alpha=l2_alpha,
|
||||
) # (num_slots, HIDDEN) bfloat16
|
||||
|
||||
|
||||
@@ -5,12 +5,12 @@ Converts raw NVFP4 checkpoint weights (uint8 E2M1 + float8_e4m3fn UE4M3 + float3
|
||||
into the format expected by the CUTLASS block-scaled GEMM kernel:
|
||||
- Packed FP4 weights (int8, K-major)
|
||||
- UE4M3 block scales (float8_e4m3fn, row-major — CUTLASS SF remap handles interleaving)
|
||||
- float32 global scales (NOT folded into block scales — passed separately for per-expert alpha)
|
||||
|
||||
Weight scales are returned as float8_e4m3fn (NOT packed uint32). The CUTLASS GEMM
|
||||
consumes float8 scales directly; only activation scales from the staging kernel come
|
||||
as uint32 and need unpack_ue4m3_u32.
|
||||
|
||||
This replaces deep_gemma.mega.transform_nvfp4_weights_for_mega_moe.
|
||||
Previous versions folded weight_scale_2 into block scales via float8 round-trip, which caused
|
||||
25% relative error (product of ~56-448 block_sf × ~4.65e-05 global_sf lands in the low-precision
|
||||
zone of float8_e4m3fn where step size is 25%). The global scale is now applied as a per-expert
|
||||
multiplier to the GEMM alpha, preserving full float32 precision.
|
||||
|
||||
Call signature matches the nightly vLLM deepseek_v4.py finalize_weights:
|
||||
transform_nvfp4_weights_for_mega_moe(
|
||||
@@ -24,134 +24,80 @@ Call signature matches the nightly vLLM deepseek_v4.py finalize_weights:
|
||||
import torch
|
||||
|
||||
|
||||
def _fold_global_scale(
|
||||
weight_scale: torch.Tensor, # (E, N, K//16) float8_e4m3fn
|
||||
weight_scale_2: torch.Tensor, # (E,) or (E, 2) or scalar float32
|
||||
) -> torch.Tensor:
|
||||
"""Fold global scale into block scales: UE4M3 * FP32 → float32.
|
||||
|
||||
For fused projections (w13 = gate+up), weight_scale_2 is (E, 2):
|
||||
scale_2[e, 0] applies to gate_proj rows, scale_2[e, 1] applies to up_proj rows.
|
||||
N is split: gate = weight_scale[:, :N//2, :], up = weight_scale[:, N//2:, :]
|
||||
For single projections (w2), weight_scale_2 is (E,) or scalar.
|
||||
"""
|
||||
sf_f32 = weight_scale.to(torch.float32)
|
||||
gs = weight_scale_2.to(torch.float32)
|
||||
|
||||
if gs.numel() == 1:
|
||||
sf_f32 = sf_f32 * gs
|
||||
elif gs.dim() == 2 and gs.shape[1] == 2:
|
||||
# Fused projection: (E, 2) — gate and up have separate global scales
|
||||
# weight_scale is (E, N, K//16), N = gate_N + up_N
|
||||
gate_N = sf_f32.shape[1] // 2
|
||||
gs_gate = gs[:, 0].unsqueeze(-1) # (E, 1)
|
||||
gs_up = gs[:, 1].unsqueeze(-1) # (E, 1)
|
||||
sf_f32[:, :gate_N, :] = sf_f32[:, :gate_N, :] * gs_gate.unsqueeze(-1)
|
||||
sf_f32[:, gate_N:, :] = sf_f32[:, gate_N:, :] * gs_up.unsqueeze(-1)
|
||||
else:
|
||||
# Per-expert global scale — broadcast multiply
|
||||
while gs.dim() < sf_f32.dim():
|
||||
gs = gs.unsqueeze(-1)
|
||||
sf_f32 = sf_f32 * gs.expand_as(sf_f32)
|
||||
|
||||
return sf_f32
|
||||
|
||||
|
||||
def _pack_ue4m3_to_uint32(sf: torch.Tensor) -> torch.Tensor:
|
||||
"""Pack 4 UE4M3 (float8_e4m3fn) values into one uint32."""
|
||||
sf_u8 = sf.view(torch.uint8)
|
||||
assert sf_u8.shape[-1] % 4 == 0, f"Last dim {sf_u8.shape[-1]} not divisible by 4"
|
||||
packed = (sf_u8[..., 0::4].to(torch.int32) |
|
||||
(sf_u8[..., 1::4].to(torch.int32) << 8) |
|
||||
(sf_u8[..., 2::4].to(torch.int32) << 16) |
|
||||
(sf_u8[..., 3::4].to(torch.int32) << 24))
|
||||
return packed.contiguous()
|
||||
|
||||
|
||||
def transform_nvfp4_weights_for_mega_moe(
|
||||
l1_tuple: tuple[torch.Tensor, torch.Tensor], # (weight, weight_scale)
|
||||
l2_tuple: tuple[torch.Tensor, torch.Tensor], # (weight, weight_scale)
|
||||
l1_weight_scale_2: torch.Tensor = None, # float32 global scale for L1
|
||||
l2_weight_scale_2: torch.Tensor = None, # float32 global scale for L2
|
||||
) -> tuple[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
|
||||
) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
|
||||
"""Transform NVFP4 weights for the CUTLASS block-scaled GEMM.
|
||||
|
||||
Matches the call signature from nightly vLLM deepseek_v4.py finalize_weights.
|
||||
|
||||
|
||||
NO LONGER FOLDS GLOBAL SCALES INTO BLOCK SCALES.
|
||||
Folding block_sf (float8) × global_sf (float32) → float8 loses ~25% precision
|
||||
because the product lands in the low-precision zone of float8_e4m3fn.
|
||||
Instead, global scales are returned separately and applied as per-expert GEMM alpha.
|
||||
|
||||
Args:
|
||||
l1_tuple: (w13_weight, w13_weight_scale) — gate_up proj
|
||||
l2_tuple: (w2_weight, w2_weight_scale) — down proj
|
||||
l1_weight_scale_2: global scale for L1 (float32)
|
||||
Shape (E, 2) for gate+up, or (E,) per-expert, or scalar
|
||||
l2_weight_scale_2: global scale for L2 (float32)
|
||||
|
||||
Shape (E,) per-expert, or scalar
|
||||
|
||||
Returns:
|
||||
((l1_weight, l1_sf_packed), (l2_weight, l2_sf_packed))
|
||||
((l1_weight, l1_sf, l1_global_sf), (l2_weight, l2_sf, l2_global_sf))
|
||||
where l1_global_sf is the float32 L1 global scale as provided — (E, 2) when
gate and up have separate scales, otherwise per-expert (E,) — and
l2_global_sf is the (E,) float32 per-expert global scale for L2.
|
||||
The caller must apply global_sf as a per-expert multiplier to the GEMM alpha.
|
||||
"""
|
||||
l1_weight, l1_weight_scale = l1_tuple
|
||||
l2_weight, l2_weight_scale = l2_tuple
|
||||
|
||||
# DEBUG: check raw scales before folding
|
||||
l1_sf_f32_raw = l1_weight_scale.to(torch.float32)
|
||||
l1_gs_raw = l1_weight_scale_2.to(torch.float32) if l1_weight_scale_2 is not None else None
|
||||
if not getattr(transform_nvfp4_weights_for_mega_moe, '_sf_debug', False):
|
||||
transform_nvfp4_weights_for_mega_moe._sf_debug = True
|
||||
print(f"[SF-DEBUG] raw l1_sf dtype={l1_weight_scale.dtype} range=[{l1_sf_f32_raw.min().item():.4e}, {l1_sf_f32_raw.max().item():.4e}] "
|
||||
f"unique_raw={torch.unique(l1_weight_scale.view(torch.uint8)).numel()}")
|
||||
if l1_gs_raw is not None:
|
||||
print(f"[SF-DEBUG] l1_gs dtype={l1_weight_scale_2.dtype} shape={tuple(l1_weight_scale_2.shape)} "
|
||||
f"range=[{l1_gs_raw.min().item():.4e}, {l1_gs_raw.max().item():.4e}] "
|
||||
f"unique_gs={torch.unique(l1_gs_raw).numel()}")
|
||||
if l1_gs_raw.dim() == 2 and l1_gs_raw.shape[1] == 2:
|
||||
print(f"[SF-DEBUG] gate gs unique={torch.unique(l1_gs_raw[:, 0]).numel()} "
|
||||
f"up gs unique={torch.unique(l1_gs_raw[:, 1]).numel()}")
|
||||
|
||||
# DEBUG: check L2 scales
|
||||
l2_sf_f32_raw = l2_weight_scale.to(torch.float32)
|
||||
l2_gs_raw = l2_weight_scale_2.to(torch.float32) if l2_weight_scale_2 is not None else None
|
||||
if not getattr(transform_nvfp4_weights_for_mega_moe, '_sf_debug_l2', False):
|
||||
transform_nvfp4_weights_for_mega_moe._sf_debug_l2 = True
|
||||
print(f"[SF-DEBUG-L2] raw l2_sf dtype={l2_weight_scale.dtype} range=[{l2_sf_f32_raw.min().item():.4e}, {l2_sf_f32_raw.max().item():.4e}] "
|
||||
f"unique_raw={torch.unique(l2_weight_scale.view(torch.uint8)).numel()}")
|
||||
if l2_gs_raw is not None:
|
||||
print(f"[SF-DEBUG-L2] l2_gs dtype={l2_weight_scale_2.dtype} shape={tuple(l2_weight_scale_2.shape)} "
|
||||
f"range=[{l2_gs_raw.min().item():.4e}, {l2_gs_raw.max().item():.4e}] "
|
||||
f"unique_gs={torch.unique(l2_gs_raw).numel()}")
|
||||
|
||||
# Post-fold diagnostics — one-time
|
||||
if not getattr(transform_nvfp4_weights_for_mega_moe, '_sf_debug_fold', False):
|
||||
transform_nvfp4_weights_for_mega_moe._sf_debug_fold = True
|
||||
l1_sf_folded = _fold_global_scale(l1_weight_scale, l1_weight_scale_2) if l1_weight_scale_2 is not None else l1_weight_scale.to(torch.float32)
|
||||
l1_sf_out_check = l1_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn)
|
||||
l2_sf_folded = _fold_global_scale(l2_weight_scale, l2_weight_scale_2) if l2_weight_scale_2 is not None else l2_weight_scale.to(torch.float32)
|
||||
l2_sf_out_check = l2_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn)
|
||||
print(f"[SF-FOLD] l1 pre-fold unique_u8={torch.unique(l1_weight_scale.view(torch.uint8)).numel()} "
|
||||
f"post-fold unique_u8={torch.unique(l1_sf_out_check.view(torch.uint8)).numel()} "
|
||||
f"range=[{l1_sf_folded.min().item():.4e}, {l1_sf_folded.max().item():.4e}]")
|
||||
print(f"[SF-FOLD] l2 pre-fold unique_u8={torch.unique(l2_weight_scale.view(torch.uint8)).numel()} "
|
||||
f"post-fold unique_u8={torch.unique(l2_sf_out_check.view(torch.uint8)).numel()} "
|
||||
f"range=[{l2_sf_folded.min().item():.4e}, {l2_sf_folded.max().item():.4e}]")
|
||||
|
||||
# Fold global scales into block scales
|
||||
# The logical_widths branch was wrong: it treated gs as per-projection
|
||||
# scalars and only used experts 0 and 1's scales for ALL experts.
|
||||
# The else branch correctly broadcasts each expert's own global scale.
|
||||
# Extract global scales as per-expert float32 vectors
|
||||
# L1: gate/up have separate global scales — store both
|
||||
# The caller (nvfp4_mega_moe_full) will apply the right one per-expert
|
||||
if l1_weight_scale_2 is not None:
|
||||
l1_sf_folded = _fold_global_scale(l1_weight_scale, l1_weight_scale_2)
|
||||
l1_gs = l1_weight_scale_2.to(torch.float32)
|
||||
if l1_gs.dim() == 2 and l1_gs.shape[1] == 2:
|
||||
# (E, 2) — gate_gs and up_gs are separate.
# The GEMM produces gate and up in one shot, so both scales are needed:
# store (E, 2) and let the caller apply the post-GEMM scaling.
|
||||
l1_global_sf = l1_gs # (E, 2) float32
|
||||
else:
|
||||
l1_global_sf = l1_gs # (E,) float32
|
||||
else:
|
||||
l1_sf_folded = l1_weight_scale.to(torch.float32)
|
||||
l1_global_sf = torch.ones(l1_weight.shape[0], dtype=torch.float32, device=l1_weight.device)
|
||||
|
||||
if l2_weight_scale_2 is not None:
|
||||
l2_sf_folded = _fold_global_scale(l2_weight_scale, l2_weight_scale_2)
|
||||
l2_gs = l2_weight_scale_2.to(torch.float32)
|
||||
l2_global_sf = l2_gs # (E,) or scalar → broadcast to (E,)
|
||||
if l2_global_sf.dim() == 0:
|
||||
l2_global_sf = l2_global_sf.expand(l2_weight.shape[0])
|
||||
else:
|
||||
l2_sf_folded = l2_weight_scale.to(torch.float32)
|
||||
l2_global_sf = torch.ones(l2_weight.shape[0], dtype=torch.float32, device=l2_weight.device)
|
||||
|
||||
# Clamp and convert back to UE4M3
|
||||
l1_sf_out = l1_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn).contiguous()
|
||||
l2_sf_out = l2_sf_folded.clamp(0.0, 448.0).to(torch.float8_e4m3fn).contiguous()
|
||||
# Debug: one-time diagnostic
|
||||
if not getattr(transform_nvfp4_weights_for_mega_moe, '_diag', False):
|
||||
transform_nvfp4_weights_for_mega_moe._diag = True
|
||||
print(f"[WT-XFORM] L1 block_sf range=[{l1_weight_scale.float().min():.4e}, "
|
||||
f"{l1_weight_scale.float().max():.4e}] unique={torch.unique(l1_weight_scale.view(torch.uint8)).numel()}")
|
||||
print(f"[WT-XFORM] L1 global_sf: shape={tuple(l1_global_sf.shape)} "
|
||||
f"range=[{l1_global_sf.min():.4e}, {l1_global_sf.max():.4e}]")
|
||||
print(f"[WT-XFORM] L2 block_sf range=[{l2_weight_scale.float().min():.4e}, "
|
||||
f"{l2_weight_scale.float().max():.4e}] unique={torch.unique(l2_weight_scale.view(torch.uint8)).numel()}")
|
||||
print(f"[WT-XFORM] L2 global_sf: shape={tuple(l2_global_sf.shape)} "
|
||||
f"range=[{l2_global_sf.min():.4e}, {l2_global_sf.max():.4e}]")
|
||||
|
||||
# Block scales stay as original float8 — NO FOLDING
|
||||
l1_sf_out = l1_weight_scale.contiguous()
|
||||
l2_sf_out = l2_weight_scale.contiguous()
|
||||
|
||||
# CUTLASS B is declared ColumnMajor — it expects (K, N) in memory.
|
||||
# Checkpoint weights are (N, K_half) row-major, so we transpose to (K_half, N)
|
||||
# which is column-major (N, K_half). This is a one-time cost at load time.
|
||||
l1_weight_out = l1_weight.transpose(-2, -1).contiguous()
|
||||
l2_weight_out = l2_weight.transpose(-2, -1).contiguous()
|
||||
|
||||
@@ -159,4 +105,4 @@ def transform_nvfp4_weights_for_mega_moe(
|
||||
l1_sf_out = l1_sf_out.transpose(-2, -1).contiguous()
|
||||
l2_sf_out = l2_sf_out.transpose(-2, -1).contiguous()
|
||||
|
||||
return (l1_weight_out, l1_sf_out), (l2_weight_out, l2_sf_out)
|
||||
return (l1_weight_out, l1_sf_out, l1_global_sf), (l2_weight_out, l2_sf_out, l2_global_sf)
|
||||
|
||||
Reference in New Issue
Block a user