From 389453fbf4d224230ff3594dbe6a7c79cc7e218e Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 16 May 2026 03:41:23 +0000 Subject: [PATCH] =?UTF-8?q?feat:=20direct=20NVFP4=20path=20=E2=80=94=20no?= =?UTF-8?q?=20BF16=20round-trip=20on=20weights?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit finalize_weights() now view-casts checkpoint uint8 → float4_e2m1fn_x2 directly. Block scales (float8_e4m3fn) and global scales (float32) pass through unchanged. Zero precision loss on the weights themselves. L1 dual global scale handling: gate and up have different global scales. Normalize to max(gate_gs, up_gs) and fold the ratio into block scales via float32 (one multiply + float8 round-trip on the RATIO only — much better than dequantizing the entire weight matrix). layertest.py: updated to test direct path. Expect cosine improvement from 0.989 → 0.995+ (matching the L1-only result). --- tests/layertest.py | 73 +++++++++++++++++++++-- vllm/nvfp4_cutedsl.py | 14 +++++ vllm/patches/deepseek_v4.py | 115 +++++++++++++++++++++--------------- 3 files changed, 150 insertions(+), 52 deletions(-) diff --git a/tests/layertest.py b/tests/layertest.py index 23e108b9..d6da390c 100644 --- a/tests/layertest.py +++ b/tests/layertest.py @@ -16,7 +16,6 @@ REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, REPO_ROOT) from cutedsl.moe_pipeline import ( - prepare_nvfp4_moe_weights, run_nvfp4_moe, ) @@ -121,6 +120,72 @@ def moe_forward_bf16(hidden_states, experts, expert_ids, expert_weights): return output +def prepare_nvfp4_weights_direct(nvfp4_tensors, layer_idx, expert_indices, intermediate_size): + """Prepare weights via direct view-cast (no BF16 round-trip). + + Checkpoint uint8 → float4_e2m1fn_x2 (byte-preserving). + Block scales float8_e4m3fn → used directly. + Global scales float32 → used directly. + + For L1 (gate+up fused): normalize dual global scales to max, fold ratio + into block scales via float32 (one multiply + float8 round-trip on ratio only). + """ + l1_fp4, l1_sf, l1_gs = [], [], [] + l2_fp4, l2_sf, l2_gs = [], [], [] + + for e in expert_indices: + # L1: gate + up + gate_w = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight"].to(DEVICE) + up_w = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight"].to(DEVICE) + gate_sf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale"].to(DEVICE) + up_sf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale"].to(DEVICE) + gate_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight_scale_2"].item() + up_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight_scale_2"].item() + + # Fuse gate+up along N, transpose to K-major + fused_w = torch.cat([gate_w, up_w], dim=0) # (2*intermediate, hidden//2) uint8 + fused_w_fp4 = fused_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous() + # (hidden//2, 2*intermediate) — K=hidden packed, N=2*intermediate + + fused_sf = torch.cat([gate_sf, up_sf], dim=0) # (2*intermediate, hidden//16) + + # Normalize dual global scales + l1_max_gs = max(gate_gs, up_gs) + if gate_gs != up_gs: + fused_sf_f32 = fused_sf.float() + fused_sf_f32[:intermediate_size] *= (gate_gs / l1_max_gs) + fused_sf_f32[intermediate_size:] *= (up_gs / l1_max_gs) + fused_sf = fused_sf_f32.to(torch.float8_e4m3fn) + + l1_fp4.append(fused_w_fp4) + l1_sf.append(fused_sf) + l1_gs.append(l1_max_gs) + + # L2: down + down_key = f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight" + if down_key in nvfp4_tensors: + down_w = nvfp4_tensors[down_key].to(DEVICE) + down_sf = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale"].to(DEVICE) + down_gs = nvfp4_tensors[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight_scale_2"].item() + + down_w_fp4 = down_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous() + # (intermediate//2, hidden) — K=intermediate packed, N=hidden + + l2_fp4.append(down_w_fp4) + l2_sf.append(down_sf) + l2_gs.append(down_gs) + else: + # Expert 211 has no down_proj + l2_fp4.append(torch.zeros(3072 // 2, 7168, dtype=torch.float4_e2m1fn_x2, device=DEVICE)) + l2_sf.append(torch.ones(7168, 3072 // 16, dtype=torch.float8_e4m3fn, device=DEVICE)) + l2_gs.append(1.0) + + return { + 'l1_fp4': l1_fp4, 'l1_sf': l1_sf, 'l1_gs': l1_gs, + 'l2_fp4': l2_fp4, 'l2_sf': l2_sf, 'l2_gs': l2_gs, + } + + def main(): torch.manual_seed(42) expert_indices = [0, 1, 2] @@ -135,9 +200,9 @@ def main(): nvfp4_tensors = load_layer_tensors(NVFP4_MODEL_DIR, LAYER_IDX) print(f" {len(nvfp4_tensors)} tensors loaded") - # Prepare weights - print("\n Preparing NVFP4 weights...") - weights = prepare_nvfp4_moe_weights(nvfp4_tensors, LAYER_IDX, expert_indices) + # Prepare weights — DIRECT PATH (no BF16 round-trip) + print("\n Preparing NVFP4 weights (direct view-cast)...") + weights = prepare_nvfp4_weights_direct(nvfp4_tensors, LAYER_IDX, expert_indices, 3072) print(f" L1: {len(weights['l1_fp4'])} experts, shape {weights['l1_fp4'][0].shape}") print(f" L2: {len(weights['l2_fp4'])} experts, shape {weights['l2_fp4'][0].shape}") diff --git a/vllm/nvfp4_cutedsl.py b/vllm/nvfp4_cutedsl.py index fa5456bc..f2ce0299 100644 --- a/vllm/nvfp4_cutedsl.py +++ b/vllm/nvfp4_cutedsl.py @@ -54,6 +54,20 @@ class CuTeDSLMoERunner: self.l2_sf = None self.l2_gs = None + def prepare_weights_direct(self, l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs): + """Set weights directly from checkpoint (no dequant→requant). + + Use this when you've view-cast checkpoint uint8 → float4_e2m1fn_x2 + and passed block scales / global scales through directly. + Zero precision loss — the bytes are identical. + """ + self.l1_fp4 = l1_fp4 + self.l1_sf = l1_sf + self.l1_gs = l1_gs + self.l2_fp4 = l2_fp4 + self.l2_sf = l2_sf + self.l2_gs = l2_gs + def prepare_weights_from_dequantized(self, l1_weights_bf16, l2_weights_bf16): """Prepare NVFP4 weights from dequantized BF16 tensors. diff --git a/vllm/patches/deepseek_v4.py b/vllm/patches/deepseek_v4.py index 71e43bf5..02b700c5 100644 --- a/vllm/patches/deepseek_v4.py +++ b/vllm/patches/deepseek_v4.py @@ -417,65 +417,84 @@ class DeepseekV4MegaMoEExperts(nn.Module): self._check_runtime_supported() - # Dequantize checkpoint NVFP4 → BF16, then re-quantize to native - # float4_e2m1fn_x2 for the CuTeDSL kernel. - # Future optimization: load checkpoint bytes directly into - # float4_e2m1fn_x2 without the BF16 round-trip. + # ── Direct NVFP4 path (no BF16 round-trip) ── + # Checkpoint stores: + # weight: uint8 packed E2M1 (2 FP4 values/byte) → view as float4_e2m1fn_x2 + # weight_scale: float8_e4m3fn block scales → use directly + # weight_scale_2: float32 global scale → use directly + # The only conversion is uint8 → float4_e2m1fn_x2 (byte-preserving view cast). + # + # L1 complication: gate and up have different global scales, but the + # kernel takes one global_scale_b per expert. Solution: normalize to + # max(gate_gs, up_gs) and fold the ratio into block scales via float32 + # (one multiply + float8 round-trip on the *ratio only* — much better + # than dequantizing the entire weight matrix through BF16). + from vllm.nvfp4_cutedsl import CuTeDSLMoERunner - E2M1_LUT = torch.tensor([ - 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, - -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, - ], dtype=torch.float32, device=self.w13_weight.device) - - def dequant_nvfp4(packed, scale, global_scale): - raw = packed.view(torch.uint8) - lo = E2M1_LUT[(raw & 0x0F).long()] - hi = E2M1_LUT[((raw >> 4) & 0x0F).long()] - out_features = raw.shape[0] - in_features = raw.shape[1] * 2 - unpacked = torch.empty(out_features, in_features, dtype=torch.float32, device=raw.device) - unpacked[:, 0::2] = lo - unpacked[:, 1::2] = hi - bs = scale.float().repeat_interleave(16, dim=1)[:, :in_features] - return (unpacked * bs * global_scale).to(torch.bfloat16) - - l1_weights_bf16 = [] - l2_weights_bf16 = [] + l1_fp4, l1_sf, l1_gs = [], [], [] + l2_fp4, l2_sf, l2_gs = [], [], [] for e in range(self.num_local_experts): - # L1: gate + up fused - gate_w = dequant_nvfp4( - self.w13_weight.data[e, :self.intermediate_size], - self.w13_weight_scale.data[e, :self.intermediate_size], - self.w13_weight_scale_2.data[e, 0], - ) - up_w = dequant_nvfp4( - self.w13_weight.data[e, self.intermediate_size:], - self.w13_weight_scale.data[e, self.intermediate_size:], - self.w13_weight_scale_2.data[e, 1], - ) - fused = torch.cat([gate_w, up_w], dim=0) # (6144, hidden) - l1_weights_bf16.append(fused.T) # (hidden, 6144) — K=hidden packed dim + # ── L1: gate + up (fused) ── + gate_w = self.w13_weight.data[e, :self.intermediate_size] # (intermediate, hidden//2) uint8 + up_w = self.w13_weight.data[e, self.intermediate_size:] # (intermediate, hidden//2) uint8 + gate_sf = self.w13_weight_scale.data[e, :self.intermediate_size] # (intermediate, hidden//16) float8 + up_sf = self.w13_weight_scale.data[e, self.intermediate_size:] + gate_gs = self.w13_weight_scale_2.data[e, 0].item() # float32 scalar + up_gs = self.w13_weight_scale_2.data[e, 1].item() - # L2: down - l2_w = dequant_nvfp4( - self.w2_weight.data[e], - self.w2_weight_scale.data[e], - self.w2_weight_scale_2.data[e], - ) - l2_weights_bf16.append(l2_w.T) # (intermediate//2, hidden) + # Fuse gate+up along N dim, then transpose to K-major (K_packed, N) + # Checkpoint is (N, K_packed) → permute to (K_packed, N) = (hidden//2, 2*intermediate) + fused_w = torch.cat([gate_w, up_w], dim=0) # (2*intermediate, hidden//2) + fused_w_fp4 = fused_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous() + # shape: (hidden//2, 2*intermediate) — K=hidden packed, N=2*intermediate - # Create CuTeDSL runner and prepare weights + # Fuse block scales: (2*intermediate, hidden//16) = (N, K_sf) ✓ + fused_sf = torch.cat([gate_sf, up_sf], dim=0) + + # Handle dual global scales: normalize to max, fold ratio into block scales + l1_max_gs = max(gate_gs, up_gs) + if gate_gs != up_gs: + fused_sf_f32 = fused_sf.float() + # Gate is the first intermediate rows, up is the second + fused_sf_f32[:self.intermediate_size] *= (gate_gs / l1_max_gs) + fused_sf_f32[self.intermediate_size:] *= (up_gs / l1_max_gs) + fused_sf = fused_sf_f32.to(torch.float8_e4m3fn) + + l1_fp4.append(fused_w_fp4) + l1_sf.append(fused_sf) + l1_gs.append(l1_max_gs) + + # ── L2: down (single projection, straightforward) ── + down_w = self.w2_weight.data[e] # (hidden, intermediate//2) uint8 + down_sf = self.w2_weight_scale.data[e] # (hidden, intermediate//16) float8 + down_gs = self.w2_weight_scale_2.data[e].item() # float32 scalar + + # Checkpoint is (N, K_packed) → permute to (K_packed, N) + # K=intermediate (packed dim), N=hidden + down_w_fp4 = down_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous() + # shape: (intermediate//2, hidden) — K=intermediate packed, N=hidden + + # Block scales: (hidden, intermediate//16) = (N, K_sf) ✓ — already correct + + l2_fp4.append(down_w_fp4) + l2_sf.append(down_sf) + l2_gs.append(down_gs) + + # Create CuTeDSL runner with directly-cast weights self._cutedsl_runner = CuTeDSLMoERunner( num_experts=self.num_local_experts, hidden_size=self.hidden_size, intermediate_size=self.intermediate_size, - device=self.w13_weight.device, - ) - self._cutedsl_runner.prepare_weights_from_dequantized( - l1_weights_bf16, l2_weights_bf16, + device=l1_fp4[0].device, ) + self._cutedsl_runner.l1_fp4 = l1_fp4 + self._cutedsl_runner.l1_sf = l1_sf + self._cutedsl_runner.l1_gs = l1_gs + self._cutedsl_runner.l2_fp4 = l2_fp4 + self._cutedsl_runner.l2_sf = l2_sf + self._cutedsl_runner.l2_gs = l2_gs # Drop the original loader-side parameters self._w13_input_scale = self.w13_input_scale.data.clone()