From f8533197f22c94f19e290fc67b3d33a3587bf75f Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 8 May 2026 02:25:43 +0000 Subject: [PATCH] Fix: expert weights are FP4 (E2M1), not INT4 - verified with nibble analysis MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Nibble index 0 vs 8 ratio = 0.996 (FP4 -0.0 ≈ +0.0), NOT INT4 where -8 would be rare. FP4 dequant uses E2M1 LUT lookup × E8M0 scale (MXFP4 microscaling). Also adds model_opt_nvfp4_full.py for full model NVFP4 quantization. --- scripts/dequant_fp8_to_bf16.py | 86 +++++++++++++++++++--------------- 1 file changed, 48 insertions(+), 38 deletions(-) diff --git a/scripts/dequant_fp8_to_bf16.py b/scripts/dequant_fp8_to_bf16.py index 5a86a06..9f3f5a9 100644 --- a/scripts/dequant_fp8_to_bf16.py +++ b/scripts/dequant_fp8_to_bf16.py @@ -8,11 +8,13 @@ Handles ALL compressed tensor types found in the mixed-precision model: - weight × scale_expanded → BF16 - 128×128 block quantization -2. INT4 expert weights (int8 packed + float8_e8m0fnu block scales) - - Unpack 2 int4 values per int8 byte (lower nibble first, upper second) - - Dequantize: int4_signed × scale_expanded → BF16 - - Per-row, 32-column block scaling +2. FP4 (E2M1) expert weights (int8 packed + float8_e8m0fnu block scales) + - Unpack 2 FP4 values per int8 byte (lower nibble first, upper second) + - Dequantize via E2M1 LUT lookup × scale_expanded → BF16 + - Per-row, 32-column block scaling (MXFP4 microscaling format) - Output dimensions are 2× the stored dimensions + - Verified: nibble index 0 vs 8 ratio = 0.996 (FP4 -0.0 vs +0.0), + NOT INT4 where index 8 = -8 would be rare 3. FP8 shared expert weights (float8_e4m3fn + float8_e8m0fnu block scales) - Same as FP8 attention dequantization @@ -31,7 +33,16 @@ import torch FP8_WEIGHT_DTYPE = torch.float8_e4m3fn FP8_SCALE_DTYPE = torch.float8_e8m0fnu BLOCK_SIZE_FP8 = (128, 128) -INT4_BLOCK_SIZE = 32 # columns per scale value for INT4 expert weights +FP4_BLOCK_SIZE = 32 # columns per scale value for MXFP4 expert weights + +# E2M1 FP4 lookup table (MXFP4 microscaling format) +# Index 0-7: positive values (sign=0, 2-bit exp, 1-bit mantissa) +# Index 8-15: negative values (sign=1) +# Mapping: 0→0, 1→0.5, 2→1, 3→1.5, 4→2, 5→3, 6→4, 7→6 +FP4_E2M1_LUT = torch.tensor([ + 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, + -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, +], dtype=torch.float32) def dequantize_fp8_weight(fp8_weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: @@ -48,42 +59,44 @@ def dequantize_fp8_weight(fp8_weight: torch.Tensor, scale: torch.Tensor) -> torc return weight_bf16.to(torch.bfloat16) -def dequantize_int4_weight(int8_packed: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: - """Dequantize INT4-packed expert weight to BF16. +def dequantize_fp4_weight(int8_packed: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + """Dequantize MXFP4 (E2M1) expert weight to BF16. - INT4 values are packed 2-per-byte into int8 tensors. + FP4 values are packed 2-per-byte into int8 tensors. Lower nibble (bits 0-3) is the first value, upper nibble (bits 4-7) is the second. - Signed int4 range: -8 to 7. + E2M1 format: 1 sign + 2 exponent + 1 mantissa bit. - Scale is per-row with 32-column blocks (float8_e8m0fnu). + Scale is per-row with 32-column blocks (float8_e8m0fnu, MX microscaling). Output dimensions are 2× the stored dimensions. int8_packed: (out_features, in_features//2) int8 scale: (out_features, in_features//32) float8_e8m0fnu returns: (out_features, in_features) bfloat16 """ - # Unpack int4 from int8 - lower = (int8_packed & 0x0F).to(torch.int8) # 0-15 - upper = ((int8_packed >> 4) & 0x0F).to(torch.int8) # 0-15 + lut = FP4_E2M1_LUT.to(int8_packed.device) - # Convert unsigned to signed int4: 0-7 stay, 8-15 → -8 to -1 - lower_signed = torch.where(lower > 7, lower - 16, lower).float() - upper_signed = torch.where(upper > 7, upper - 16, upper).float() + # Unpack nibble indices + lower_idx = (int8_packed & 0x0F).long() # 0-15 + upper_idx = ((int8_packed >> 4) & 0x0F).long() # 0-15 + + # LUT lookup + lower = lut[lower_idx] # float32 + upper = lut[upper_idx] # float32 out_features = int8_packed.shape[0] in_features_full = int8_packed.shape[1] * 2 # 2× expansion # Expand scale: (out_features, in_features//32) → (out_features, in_features) scale_f32 = scale.float() - scale_expanded = scale_f32.repeat_interleave(INT4_BLOCK_SIZE, dim=1) + scale_expanded = scale_f32.repeat_interleave(FP4_BLOCK_SIZE, dim=1) scale_expanded = scale_expanded[:, :in_features_full] # Interleave lower and upper nibbles - unpacked = torch.zeros(out_features, in_features_full, dtype=torch.float32) - unpacked[:, 0::2] = lower_signed - unpacked[:, 1::2] = upper_signed + unpacked = torch.empty(out_features, in_features_full, dtype=torch.float32, device=int8_packed.device) + unpacked[:, 0::2] = lower + unpacked[:, 1::2] = upper - # Dequantize + # Dequantize: FP4 value × E8M0 scale bf16_weight = (unpacked * scale_expanded).to(torch.bfloat16) return bf16_weight @@ -104,9 +117,8 @@ def dequantize_model(model_dir: str, out_dir: str): print(f"Found {total_shards} shards") # First pass: build scale-key → weight-key mapping - # Pattern: *.scale → *.weight print("\nScanning for weight+scale pairs...") - scale_to_weight = {} # scale_key → weight_key + scale_to_weight = {} for f in safetensor_files: with safe_open(f, framework="pt") as sf: for key in sf.keys(): @@ -114,37 +126,35 @@ def dequantize_model(model_dir: str, out_dir: str): weight_key = key[:-len(".scale")] + ".weight" scale_to_weight[key] = weight_key - # Also find weight → scale mapping weight_to_scale = {v: k for k, v in scale_to_weight.items()} print(f"Found {len(scale_to_weight)} weight+scale pairs") - # Classify weights by type - int4_weight_keys = set() + # Classify weights by type (sample first 2 shards) + fp4_weight_keys = set() fp8_weight_keys = set() scale_keys = set(scale_to_weight.keys()) - for f in safetensor_files[:2]: # Sample to classify + for f in safetensor_files[:2]: with safe_open(f, framework="pt") as sf: for key in sf.keys(): if key in weight_to_scale: t = sf.get_tensor(key) if t.dtype == torch.int8: - int4_weight_keys.add(key) + fp4_weight_keys.add(key) elif t.dtype == FP8_WEIGHT_DTYPE: fp8_weight_keys.add(key) - print(f" INT4 expert weights (packed): ~{len(int4_weight_keys)} per shard") + print(f" FP4 (E2M1) expert weights (packed): ~{len(fp4_weight_keys)} per shard") print(f" FP8 attention/shared-expert weights: ~{len(fp8_weight_keys)} per shard") # Second pass: dequantize and save - stats = {"int4_dequantized": 0, "fp8_dequantized": 0, "scales_removed": 0, "unchanged": 0} + stats = {"fp4_dequantized": 0, "fp8_dequantized": 0, "scales_removed": 0, "unchanged": 0} start_time = time.time() for i, f in enumerate(safetensor_files): shard_start = time.time() tensors = {} scales_in_shard = {} - weights_to_dequant = {} with safe_open(f, framework="pt") as sf: keys = list(sf.keys()) @@ -163,16 +173,16 @@ def dequantize_model(model_dir: str, out_dir: str): t = sf.get_tensor(key) if key in weight_to_scale and t.dtype == torch.int8: - # INT4 packed expert weight + # FP4 (E2M1) packed expert weight (MXFP4 microscaling) scale_key = weight_to_scale[key] scale = scales_in_shard.get(scale_key) if scale is None: print(f" WARNING: scale {scale_key} not in same shard as {key}") tensors[key] = t # keep as-is continue - bf16 = dequantize_int4_weight(t, scale) + bf16 = dequantize_fp4_weight(t, scale) tensors[key] = bf16 - stats["int4_dequantized"] += 1 + stats["fp4_dequantized"] += 1 del scales_in_shard[scale_key] stats["scales_removed"] += 1 @@ -208,8 +218,8 @@ def dequantize_model(model_dir: str, out_dir: str): eta = (total_shards - i - 1) / rate if rate > 0 else 0 print(f"[{i+1}/{total_shards}] {os.path.basename(f)} " - f"({stats['int4_dequantized']} int4, {stats['fp8_dequantized']} fp8, " - f"{stats['scales_removed']} scales removed) " + f"({stats['fp4_dequantized']} fp4, {stats['fp8_dequantized']} fp8, " + f"{stats['scales_removed']} scales rm) " f"[{shard_time:.1f}s, ETA: {eta/60:.0f}min]") del tensors, scales_in_shard @@ -227,7 +237,7 @@ def dequantize_model(model_dir: str, out_dir: str): total_time = time.time() - start_time print(f"\nDone in {total_time/60:.1f} minutes!") - print(f" INT4 expert weights dequantized: {stats['int4_dequantized']}") + print(f" FP4 expert weights dequantized: {stats['fp4_dequantized']}") print(f" FP8 weights dequantized: {stats['fp8_dequantized']}") print(f" Scale tensors removed: {stats['scales_removed']}") print(f" Unchanged tensors: {stats['unchanged']}") @@ -244,7 +254,7 @@ def dequantize_model(model_dir: str, out_dir: str): if remaining_compressed <= 5: print(f" REMAINING: {key} {t.dtype} {t.shape}") if remaining_compressed == 0: - print(" ✅ No compressed tensors remaining - model is pure BF16!") + print(" ✅ No compressed tensors remaining — model is pure BF16!") else: print(f" ⚠️ {remaining_compressed} compressed tensors still present")