Files
deepseek-v4-quant/tmp/fix5_nvfp4.py
biondizzle 02b8ea536f Update MEMORY.md and memory files with vLLM NVFP4 serving progress
Server running on B200 port 8000 with full NVFP4→vLLM bridge.
All critical bugs fixed: DeepGEMM scale format, compressor shapes, block scale values.
2026-05-11 02:02:49 +00:00

163 lines
9.1 KiB
Python

#!/usr/bin/env python3
"""Fix the bf16→uint8 handler to properly quantize to NVFP4 instead of switching to UnquantizedLinearMethod.

This is a one-shot source patcher: it reads the file at ``filepath``, replaces
the exact text ``old_handler`` with ``new_handler`` and writes the result
back. When the old text is not present the file is left untouched and a
diagnostic is printed instead.
"""

# File that receives the patch.
filepath = "/root/nvidia-meeting/deepseek-v4-quant/patches/deepseek_v4.py"

# Text that is searched for verbatim. The match is an exact substring match,
# so every character (including whitespace) has to agree with the target file.
# NOTE(review): the literal below carries no leading indentation on its lines,
# while the handler looks like code nested inside a loop/method — confirm that
# this matches the target file, otherwise the exact match will never succeed.
old_handler = ''' # Handle bf16 → uint8 mismatch for o_a_proj:
# modelopt didn't quantize o_a_proj (bf16, no scales),
# but ModelOptNvFp4Config creates wo_a with NVFP4 quant
# (uint8 weight + scales). When loading bf16 into uint8,
# we replace the quant method with UnquantizedLinearMethod
# so the layer runs in bf16 at inference.
if (name.endswith(".weight")
and loaded_weight.dtype != torch.uint8
and param.data.dtype == torch.uint8):
# Replace this layer's quant method with unquantized
from vllm.model_executor.layers.linear import (
UnquantizedLinearMethod,
)
parts = name.rsplit(".", 1)
module_path = parts[0] # e.g., layers.0.attn.wo_a
# Find the module and override its quant method
mod = self
for attr in module_path.split("."):
if attr.isdigit():
mod = mod[int(attr)]
else:
mod = getattr(mod, attr)
if hasattr(mod, 'quant_method'):
mod.quant_method = UnquantizedLinearMethod()
# Replace the uint8 weight param with bf16
new_param = torch.nn.Parameter(
loaded_weight.clone(), requires_grad=False
)
mod.weight = new_param
# Set weight_scale_inv = 1.0 (required by
# DeepseekV4MLAModules forward pass which
# reads wo_a.weight_scale_inv directly)
mod.weight_scale_inv = torch.nn.Parameter(
torch.tensor(1.0, dtype=torch.float32),
requires_grad=False,
)
# Also set input_scale to prevent missing attr errors
if hasattr(mod, 'input_scale'):
mod.input_scale = torch.nn.Parameter(
torch.tensor(1.0, dtype=torch.float32),
requires_grad=False,
)
loaded_params.add(name)
loaded_params.add(name.replace('.weight', '.weight_scale_inv'))
continue'''

# Replacement text: instead of swapping the layer to an unquantized method, it
# quantizes the incoming weight to packed FP4 (two 4-bit codes per uint8) with
# one fp8 scale per 16-element block plus a single global scale, and feeds the
# results through the parameters' weight loaders.
new_handler = ''' # Handle bf16 → uint8 mismatch for o_a_proj:
# modelopt didn't quantize o_a_proj (bf16, no scales),
# but ModelOptNvFp4Config creates wo_a with NVFP4 quant
# (uint8 weight + scales). We quantize the bf16 weight
# to NVFP4 at load time so the layer runs in NVFP4 path.
if (name.endswith(".weight")
and loaded_weight.dtype != torch.uint8
and param.data.dtype == torch.uint8):
# Quantize bf16 → NVFP4 (E2M1 packed uint8 + scales)
w_bf16 = loaded_weight
out_dim, in_dim = w_bf16.shape
block_size = 16
assert in_dim % block_size == 0
n_blocks = in_dim // block_size
# Reshape into blocks
w_blocks = w_bf16.reshape(out_dim, n_blocks, block_size)
# Compute per-block amax
amax = w_blocks.abs().amax(dim=-1) # [out, n_blocks]
# Global scale (weight_scale_2): max amax / (6.0 * 448.0)
global_amax = amax.max()
# Use 448.0 as the max e4m3 value for scale computation
weight_scale_2_val = global_amax / (6.0 * 448.0)
weight_scale_2 = weight_scale_2_val.to(torch.float32)
# Per-block scale (weight_scale): fp8 e4m3
# block_scale = amax / (6.0 * weight_scale_2)
block_scale = amax / (6.0 * weight_scale_2_val)
# Clamp to fp8 e4m3 range and cast
block_scale = block_scale.clamp(min=0, max=448.0)
weight_scale = block_scale.to(torch.float8_e4m3fn)
# Quantize to FP4 (E2M1)
# E2M1 LUT: 0, 0.5, 1, 1.5, 2, 3, 4, 6 (positive)
FP4_POS = torch.tensor(
[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0],
dtype=torch.float32, device=w_bf16.device,
)
# For each block, dequantize the block scale from fp8
block_scale_f32 = weight_scale.to(torch.float32)
# Scale the weight values: normalized = w / (block_scale * weight_scale_2)
# We need to find the nearest FP4 value
scaled = w_blocks / (block_scale_f32.unsqueeze(-1) * weight_scale_2_val)
# Find nearest FP4 index (0-7 for magnitude)
# Use absolute value for matching, then apply sign
scaled_abs = scaled.abs()
# Find closest FP4 value
diff = (scaled_abs.unsqueeze(-1) - FP4_POS).abs()
fp4_idx = diff.argmin(dim=-1) # [out, n_blocks, block_size]
# Apply sign: negative values get bit 3 set
sign = (scaled < 0).int()
fp4_val = (sign << 3) | fp4_idx.int()
# Pack: 2 FP4 values per uint8 byte
# Even positions → lower nibble, Odd → upper nibble
fp4_flat = fp4_val.reshape(out_dim, -1) # [out, in_dim]
assert fp4_flat.shape[1] % 2 == 0
even = fp4_flat[:, 0::2] # lower nibble
odd = fp4_flat[:, 1::2] # upper nibble
packed = (odd << 4) | even
weight_packed = packed.to(torch.uint8)
# Reshape weight_scale to [out, n_blocks]
weight_scale_2d = weight_scale.reshape(out_dim, n_blocks)
# Load the quantized weight into the uint8 param
weight_loader = param.weight_loader
weight_loader(param, weight_packed)
loaded_params.add(name)
# Load scales into sibling params
base = name.rsplit(".", 1)[0]
# weight_scale
ws_name = f"{base}.weight_scale"
if ws_name in params_dict:
ws_param = params_dict[ws_name]
ws_loader = getattr(ws_param, "weight_loader", default_weight_loader)
ws_loader(ws_param, weight_scale_2d)
loaded_params.add(ws_name)
# weight_scale_2
ws2_name = f"{base}.weight_scale_2"
if ws2_name in params_dict:
ws2_param = params_dict[ws2_name]
ws2_loader = getattr(ws2_param, "weight_loader", default_weight_loader)
ws2_loader(ws2_param, weight_scale_2.reshape(1))
loaded_params.add(ws2_name)
# input_scale: use 1.0 default (dynamic quant)
is_name = f"{base}.input_scale"
if is_name in params_dict:
is_param = params_dict[is_name]
is_loader = getattr(is_param, "weight_loader", default_weight_loader)
is_loader(is_param, torch.tensor(1.0, dtype=torch.float32))
loaded_params.add(is_name)
continue'''


def apply_fix(path: str = filepath) -> bool:
    """Replace ``old_handler`` with ``new_handler`` in the file at *path*.

    The file is read and written as UTF-8 because both handler texts contain
    non-ASCII characters; relying on the locale's default encoding could make
    the exact match fail or corrupt the file. The file is only rewritten when
    the replacement actually happened.

    Args:
        path: File to patch. Defaults to the module-level ``filepath``.

    Returns:
        True if the old handler was found and replaced, False if the file was
        left unchanged.

    Raises:
        OSError: If the file cannot be read or written.
    """
    with open(path, 'r', encoding='utf-8') as f:
        c = f.read()

    if old_handler in c:
        patched = c.replace(old_handler, new_handler)
        with open(path, 'w', encoding='utf-8') as f:
            f.write(patched)
        # Report success only after the write went through.
        print('FIX 5 applied: Replaced UnquantizedLinearMethod with proper NVFP4 quantization')
        return True

    # No exact match: give a hint about the state of the file, but do not
    # touch it.
    print('FIX 5: Could not find exact handler block, trying flexible match...')
    if 'UnquantizedLinearMethod' in c:
        print(' Found UnquantizedLinearMethod in code - manual fix needed')
    else:
        print(' UnquantizedLinearMethod not found - already replaced?')
    return False


if __name__ == "__main__":
    apply_fix()