feat: direct NVFP4 path — no BF16 round-trip on weights

finalize_weights() now view-casts checkpoint uint8 → float4_e2m1fn_x2
directly. Block scales (float8_e4m3fn) and global scales (float32)
pass through unchanged. Zero precision loss on the weights themselves.

L1 dual global scale handling: gate and up have different global scales.
Normalize to max(gate_gs, up_gs) and fold the ratio into block scales
via float32 (one multiply + float8 round-trip on the RATIO only —
much better than dequantizing the entire weight matrix).

layertest.py: updated to test direct path. Expect cosine improvement
from 0.989 → 0.995+ (matching the L1-only result).
This commit is contained in:
2026-05-16 03:41:23 +00:00
parent 8fd9579127
commit 389453fbf4
3 changed files with 150 additions and 52 deletions

View File

@@ -16,7 +16,6 @@ REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, REPO_ROOT)
from cutedsl.moe_pipeline import (
prepare_nvfp4_moe_weights,
run_nvfp4_moe,
)
@@ -121,6 +120,72 @@ def moe_forward_bf16(hidden_states, experts, expert_ids, expert_weights):
return output
def prepare_nvfp4_weights_direct(nvfp4_tensors, layer_idx, expert_indices, intermediate_size):
"""Prepare weights via direct view-cast (no BF16 round-trip).
Checkpoint uint8 → float4_e2m1fn_x2 (byte-preserving).
Block scales float8_e4m3fn → used directly.
Global scales float32 → used directly.
For L1 (gate+up fused): normalize dual global scales to max, fold ratio
into block scales via float32 (one multiply + float8 round-trip on ratio only).
"""
l1_fp4, l1_sf, l1_gs = [], [], []
l2_fp4, l2_sf, l2_gs = [], [], []
for e in expert_indices:
# L1: gate + up
gate_w = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight"].to(DEVICE)
up_w = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight"].to(DEVICE)
gate_sf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale"].to(DEVICE)
up_sf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale"].to(DEVICE)
gate_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale_2"].item()
up_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale_2"].item()
# Fuse gate+up along N, transpose to K-major
fused_w = torch.cat([gate_w, up_w], dim=0) # (2*intermediate, hidden//2) uint8
fused_w_fp4 = fused_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
# (hidden//2, 2*intermediate) — K=hidden packed, N=2*intermediate
fused_sf = torch.cat([gate_sf, up_sf], dim=0) # (2*intermediate, hidden//16)
# Normalize dual global scales
l1_max_gs = max(gate_gs, up_gs)
if gate_gs != up_gs:
fused_sf_f32 = fused_sf.float()
fused_sf_f32[:intermediate_size] *= (gate_gs / l1_max_gs)
fused_sf_f32[intermediate_size:] *= (up_gs / l1_max_gs)
fused_sf = fused_sf_f32.to(torch.float8_e4m3fn)
l1_fp4.append(fused_w_fp4)
l1_sf.append(fused_sf)
l1_gs.append(l1_max_gs)
# L2: down
down_key = f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight"
if down_key in nvfp4_tensors:
down_w = nvfp4_tensors[down_key].to(DEVICE)
down_sf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale"].to(DEVICE)
down_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale_2"].item()
down_w_fp4 = down_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
# (intermediate//2, hidden) — K=intermediate packed, N=hidden
l2_fp4.append(down_w_fp4)
l2_sf.append(down_sf)
l2_gs.append(down_gs)
else:
# Expert 211 has no down_proj
l2_fp4.append(torch.zeros(3072 // 2, 7168, dtype=torch.float4_e2m1fn_x2, device=DEVICE))
l2_sf.append(torch.ones(7168, 3072 // 16, dtype=torch.float8_e4m3fn, device=DEVICE))
l2_gs.append(1.0)
return {
'l1_fp4': l1_fp4, 'l1_sf': l1_sf, 'l1_gs': l1_gs,
'l2_fp4': l2_fp4, 'l2_sf': l2_sf, 'l2_gs': l2_gs,
}
def main():
torch.manual_seed(42)
expert_indices = [0, 1, 2]
@@ -135,9 +200,9 @@ def main():
nvfp4_tensors = load_layer_tensors(NVFP4_MODEL_DIR, LAYER_IDX)
print(f" {len(nvfp4_tensors)} tensors loaded")
# Prepare weights
print("\n Preparing NVFP4 weights...")
weights = prepare_nvfp4_moe_weights(nvfp4_tensors, LAYER_IDX, expert_indices)
# Prepare weights — DIRECT PATH (no BF16 round-trip)
print("\n Preparing NVFP4 weights (direct view-cast)...")
weights = prepare_nvfp4_weights_direct(nvfp4_tensors, LAYER_IDX, expert_indices, 3072)
print(f" L1: {len(weights['l1_fp4'])} experts, shape {weights['l1_fp4'][0].shape}")
print(f" L2: {len(weights['l2_fp4'])} experts, shape {weights['l2_fp4'][0].shape}")

View File

@@ -54,6 +54,20 @@ class CuTeDSLMoERunner:
self.l2_sf = None
self.l2_gs = None
def prepare_weights_direct(self, l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs):
"""Set weights directly from checkpoint (no dequant→requant).
Use this when you've view-cast checkpoint uint8 → float4_e2m1fn_x2
and passed block scales / global scales through directly.
Zero precision loss — the bytes are identical.
"""
self.l1_fp4 = l1_fp4
self.l1_sf = l1_sf
self.l1_gs = l1_gs
self.l2_fp4 = l2_fp4
self.l2_sf = l2_sf
self.l2_gs = l2_gs
def prepare_weights_from_dequantized(self, l1_weights_bf16, l2_weights_bf16):
"""Prepare NVFP4 weights from dequantized BF16 tensors.

View File

@@ -417,65 +417,84 @@ class DeepseekV4MegaMoEExperts(nn.Module):
self._check_runtime_supported()
# Dequantize checkpoint NVFP4 → BF16, then re-quantize to native
# float4_e2m1fn_x2 for the CuTeDSL kernel.
# Future optimization: load checkpoint bytes directly into
# float4_e2m1fn_x2 without the BF16 round-trip.
# ── Direct NVFP4 path (no BF16 round-trip) ──
# Checkpoint stores:
# weight: uint8 packed E2M1 (2 FP4 values/byte) → view as float4_e2m1fn_x2
# weight_scale: float8_e4m3fn block scales → use directly
# weight_scale_2: float32 global scale → use directly
# The only conversion is uint8 → float4_e2m1fn_x2 (byte-preserving view cast).
#
# L1 complication: gate and up have different global scales, but the
# kernel takes one global_scale_b per expert. Solution: normalize to
# max(gate_gs, up_gs) and fold the ratio into block scales via float32
# (one multiply + float8 round-trip on the *ratio only* — much better
# than dequantizing the entire weight matrix through BF16).
from vllm.nvfp4_cutedsl import CuTeDSLMoERunner
E2M1_LUT = torch.tensor([
0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
-0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
], dtype=torch.float32, device=self.w13_weight.device)
def dequant_nvfp4(packed, scale, global_scale):
raw = packed.view(torch.uint8)
lo = E2M1_LUT[(raw & 0x0F).long()]
hi = E2M1_LUT[((raw >> 4) & 0x0F).long()]
out_features = raw.shape[0]
in_features = raw.shape[1] * 2
unpacked = torch.empty(out_features, in_features, dtype=torch.float32, device=raw.device)
unpacked[:, 0::2] = lo
unpacked[:, 1::2] = hi
bs = scale.float().repeat_interleave(16, dim=1)[:, :in_features]
return (unpacked * bs * global_scale).to(torch.bfloat16)
l1_weights_bf16 = []
l2_weights_bf16 = []
l1_fp4, l1_sf, l1_gs = [], [], []
l2_fp4, l2_sf, l2_gs = [], [], []
for e in range(self.num_local_experts):
# L1: gate + up fused
gate_w = dequant_nvfp4(
self.w13_weight.data[e, :self.intermediate_size],
self.w13_weight_scale.data[e, :self.intermediate_size],
self.w13_weight_scale_2.data[e, 0],
)
up_w = dequant_nvfp4(
self.w13_weight.data[e, self.intermediate_size:],
self.w13_weight_scale.data[e, self.intermediate_size:],
self.w13_weight_scale_2.data[e, 1],
)
fused = torch.cat([gate_w, up_w], dim=0) # (6144, hidden)
l1_weights_bf16.append(fused.T) # (hidden, 6144) — K=hidden packed dim
# ── L1: gate + up (fused) ──
gate_w = self.w13_weight.data[e, :self.intermediate_size] # (intermediate, hidden//2) uint8
up_w = self.w13_weight.data[e, self.intermediate_size:] # (intermediate, hidden//2) uint8
gate_sf = self.w13_weight_scale.data[e, :self.intermediate_size] # (intermediate, hidden//16) float8
up_sf = self.w13_weight_scale.data[e, self.intermediate_size:]
gate_gs = self.w13_weight_scale_2.data[e, 0].item() # float32 scalar
up_gs = self.w13_weight_scale_2.data[e, 1].item()
# L2: down
l2_w = dequant_nvfp4(
self.w2_weight.data[e],
self.w2_weight_scale.data[e],
self.w2_weight_scale_2.data[e],
)
l2_weights_bf16.append(l2_w.T) # (intermediate//2, hidden)
# Fuse gate+up along N dim, then transpose to K-major (K_packed, N)
# Checkpoint is (N, K_packed) → permute to (K_packed, N) = (hidden//2, 2*intermediate)
fused_w = torch.cat([gate_w, up_w], dim=0) # (2*intermediate, hidden//2)
fused_w_fp4 = fused_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
# shape: (hidden//2, 2*intermediate) — K=hidden packed, N=2*intermediate
# Create CuTeDSL runner and prepare weights
# Fuse block scales: (2*intermediate, hidden//16) = (N, K_sf) ✓
fused_sf = torch.cat([gate_sf, up_sf], dim=0)
# Handle dual global scales: normalize to max, fold ratio into block scales
l1_max_gs = max(gate_gs, up_gs)
if gate_gs != up_gs:
fused_sf_f32 = fused_sf.float()
# Gate is the first intermediate rows, up is the second
fused_sf_f32[:self.intermediate_size] *= (gate_gs / l1_max_gs)
fused_sf_f32[self.intermediate_size:] *= (up_gs / l1_max_gs)
fused_sf = fused_sf_f32.to(torch.float8_e4m3fn)
l1_fp4.append(fused_w_fp4)
l1_sf.append(fused_sf)
l1_gs.append(l1_max_gs)
# ── L2: down (single projection, straightforward) ──
down_w = self.w2_weight.data[e] # (hidden, intermediate//2) uint8
down_sf = self.w2_weight_scale.data[e] # (hidden, intermediate//16) float8
down_gs = self.w2_weight_scale_2.data[e].item() # float32 scalar
# Checkpoint is (N, K_packed) → permute to (K_packed, N)
# K=intermediate (packed dim), N=hidden
down_w_fp4 = down_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
# shape: (intermediate//2, hidden) — K=intermediate packed, N=hidden
# Block scales: (hidden, intermediate//16) = (N, K_sf) ✓ — already correct
l2_fp4.append(down_w_fp4)
l2_sf.append(down_sf)
l2_gs.append(down_gs)
# Create CuTeDSL runner with directly-cast weights
self._cutedsl_runner = CuTeDSLMoERunner(
num_experts=self.num_local_experts,
hidden_size=self.hidden_size,
intermediate_size=self.intermediate_size,
device=self.w13_weight.device,
)
self._cutedsl_runner.prepare_weights_from_dequantized(
l1_weights_bf16, l2_weights_bf16,
device=l1_fp4[0].device,
)
self._cutedsl_runner.l1_fp4 = l1_fp4
self._cutedsl_runner.l1_sf = l1_sf
self._cutedsl_runner.l1_gs = l1_gs
self._cutedsl_runner.l2_fp4 = l2_fp4
self._cutedsl_runner.l2_sf = l2_sf
self._cutedsl_runner.l2_gs = l2_gs
# Drop the original loader-side parameters
self._w13_input_scale = self.w13_input_scale.data.clone()