# Status: server running on the B200 (port 8000) with the full NVFP4→vLLM
# bridge. All critical bugs are fixed: DeepGEMM scale format, compressor
# shapes, and block scale values.
# (Original listing: 163 lines, 9.1 KiB, Python.)
#!/usr/bin/env python3
"""Fix the bf16→uint8 handler to properly quantize to NVFP4 instead of switching to UnquantizedLinearMethod

Rewrites the o_a_proj weight-loading branch in the patched vLLM
``deepseek_v4.py``. The old branch swapped the layer to
``UnquantizedLinearMethod`` (bf16 at inference). The injected branch instead
quantizes the bf16 checkpoint weight to NVFP4 at load time and fills the
``weight_scale`` / ``weight_scale_2`` / ``input_scale`` sibling params.

Usage: python3 <this script> [PATH]   (PATH defaults to ``filepath``)

Exit status: 0 when the fix was applied or the old handler is already gone,
1 when ``UnquantizedLinearMethod`` is still present but the handler block
could not be matched (manual fix needed). The target file is only rewritten
when something actually changed.
"""

import sys

# Target file patched in place (overridable by the first CLI argument).
filepath = "/root/nvidia-meeting/deepseek-v4-quant/patches/deepseek_v4.py"

# Handler currently in the target file. It is matched line by line with all
# whitespace collapsed, so the indentation below is for readability only.
old_handler = '''\
# Handle bf16 → uint8 mismatch for o_a_proj:
# modelopt didn't quantize o_a_proj (bf16, no scales),
# but ModelOptNvFp4Config creates wo_a with NVFP4 quant
# (uint8 weight + scales). When loading bf16 into uint8,
# we replace the quant method with UnquantizedLinearMethod
# so the layer runs in bf16 at inference.
if (name.endswith(".weight")
        and loaded_weight.dtype != torch.uint8
        and param.data.dtype == torch.uint8):
    # Replace this layer's quant method with unquantized
    from vllm.model_executor.layers.linear import (
        UnquantizedLinearMethod,
    )
    parts = name.rsplit(".", 1)
    module_path = parts[0] # e.g., layers.0.attn.wo_a
    # Find the module and override its quant method
    mod = self
    for attr in module_path.split("."):
        if attr.isdigit():
            mod = mod[int(attr)]
        else:
            mod = getattr(mod, attr)
    if hasattr(mod, 'quant_method'):
        mod.quant_method = UnquantizedLinearMethod()
    # Replace the uint8 weight param with bf16
    new_param = torch.nn.Parameter(
        loaded_weight.clone(), requires_grad=False
    )
    mod.weight = new_param
    # Set weight_scale_inv = 1.0 (required by
    # DeepseekV4MLAModules forward pass which
    # reads wo_a.weight_scale_inv directly)
    mod.weight_scale_inv = torch.nn.Parameter(
        torch.tensor(1.0, dtype=torch.float32),
        requires_grad=False,
    )
    # Also set input_scale to prevent missing attr errors
    if hasattr(mod, 'input_scale'):
        mod.input_scale = torch.nn.Parameter(
            torch.tensor(1.0, dtype=torch.float32),
            requires_grad=False,
        )
    loaded_params.add(name)
    loaded_params.add(name.replace('.weight', '.weight_scale_inv'))
    continue'''

# Replacement handler. Indentation is relative to its first line; the whole
# block is re-indented to wherever the old handler was found. It relies on
# names in the target's load loop: torch, name, loaded_weight, param,
# params_dict, loaded_params, default_weight_loader.
new_handler = '''\
# Handle bf16 → uint8 mismatch for o_a_proj:
# modelopt didn't quantize o_a_proj (bf16, no scales),
# but ModelOptNvFp4Config creates wo_a with NVFP4 quant
# (uint8 weight + scales). We quantize the bf16 weight
# to NVFP4 at load time so the layer runs in NVFP4 path.
if (name.endswith(".weight")
        and loaded_weight.dtype != torch.uint8
        and param.data.dtype == torch.uint8):
    # Quantize bf16 → NVFP4 (E2M1 packed uint8 + scales)
    # Do the scale math in fp32 so block scales are rounded
    # once (to fp8), not twice (to bf16, then fp8).
    w = loaded_weight.to(torch.float32)
    out_dim, in_dim = w.shape
    block_size = 16
    if in_dim % block_size != 0:
        raise ValueError(
            f"{name}: in_dim={in_dim} is not a multiple of the "
            f"NVFP4 block size ({block_size})"
        )
    n_blocks = in_dim // block_size

    # [out, n_blocks, block_size]
    w_blocks = w.reshape(out_dim, n_blocks, block_size)
    # Per-block amax: [out, n_blocks]
    amax = w_blocks.abs().amax(dim=-1)

    # Global scale (weight_scale_2) = max amax / (6.0 * 448.0):
    # 6.0 is the largest E2M1 magnitude, 448.0 the largest
    # fp8 e4m3 value. An all-zero weight would give 0 here and
    # NaN block scales below, so fall back to 1.0.
    global_amax = amax.max()
    weight_scale_2 = torch.where(
        global_amax > 0,
        global_amax / (6.0 * 448.0),
        torch.ones_like(global_amax),
    )

    # Per-block scale (weight_scale), stored as fp8 e4m3
    block_scale = (amax / (6.0 * weight_scale_2)).clamp(min=0, max=448.0)
    weight_scale = block_scale.to(torch.float8_e4m3fn)

    # Normalize by the fp8-rounded block scale so that
    # fp4 * weight_scale * weight_scale_2 reproduces w. A block
    # whose scale rounded to 0 dequantizes to 0 whatever its
    # codes are, so emit 0 there instead of dividing by zero.
    denom = (weight_scale.to(torch.float32) * weight_scale_2).unsqueeze(-1)
    scaled = torch.where(
        denom > 0,
        w_blocks / denom.clamp(min=torch.finfo(torch.float32).tiny),
        torch.zeros_like(w_blocks),
    )

    # Round |scaled| to the nearest E2M1 magnitude
    # (codes 0..7 = 0, 0.5, 1, 1.5, 2, 3, 4, 6; anything above
    # 6 saturates to code 7). bucketize sends an exact midpoint
    # to the lower code; bump ties on the odd-indexed midpoints
    # so every tie lands on the even code (round-half-to-even).
    e2m1_bounds = torch.tensor(
        [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0],
        dtype=torch.float32, device=w.device,
    )
    scaled_abs = scaled.abs()
    fp4_idx = torch.bucketize(scaled_abs, e2m1_bounds)
    ties_up = (scaled_abs.unsqueeze(-1) == e2m1_bounds[1::2]).any(dim=-1)
    fp4_idx = fp4_idx + ties_up.to(fp4_idx.dtype)
    # Sign goes in bit 3 of each 4-bit code
    sign = (scaled < 0).to(fp4_idx.dtype)
    fp4_val = (sign << 3) | fp4_idx

    # Pack 2 codes per byte: even columns -> low nibble,
    # odd columns -> high nibble. Result: [out, in_dim // 2]
    fp4_flat = fp4_val.reshape(out_dim, in_dim)
    weight_packed = (
        (fp4_flat[:, 1::2] << 4) | fp4_flat[:, 0::2]
    ).to(torch.uint8)

    # weight_scale laid out as [out, n_blocks]
    weight_scale_2d = weight_scale.reshape(out_dim, n_blocks)

    # Load the quantized weight into the uint8 param
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, weight_packed)
    loaded_params.add(name)

    # Load scales into sibling params
    base = name.rsplit(".", 1)[0]
    # weight_scale
    ws_name = f"{base}.weight_scale"
    if ws_name in params_dict:
        ws_param = params_dict[ws_name]
        ws_loader = getattr(ws_param, "weight_loader", default_weight_loader)
        ws_loader(ws_param, weight_scale_2d)
        loaded_params.add(ws_name)
    # weight_scale_2
    ws2_name = f"{base}.weight_scale_2"
    if ws2_name in params_dict:
        ws2_param = params_dict[ws2_name]
        ws2_loader = getattr(ws2_param, "weight_loader", default_weight_loader)
        ws2_loader(ws2_param, weight_scale_2.reshape(1))
        loaded_params.add(ws2_name)
    # input_scale: use 1.0 default (dynamic quant)
    # NOTE(review): there is no activation calibration for this
    # layer; confirm 1.0 suits its activation range.
    is_name = f"{base}.input_scale"
    if is_name in params_dict:
        is_param = params_dict[is_name]
        is_loader = getattr(is_param, "weight_loader", default_weight_loader)
        is_loader(is_param, torch.tensor(1.0, dtype=torch.float32))
        loaded_params.add(is_name)
    continue'''


def _normalize(line: str) -> str:
    """Collapse all whitespace so matching ignores indentation and spacing."""
    return " ".join(line.split())


def _find_handler(lines: "list[str]", want: "list[str]", begin: int = 0) -> "tuple[int, int] | None":
    """Locate the first run of *lines* (from index *begin*) matching *want*.

    *want* holds already-normalized, non-blank pattern lines. Blank lines in
    *lines* are skipped. Returns ``(start, stop)`` as a slice into *lines*,
    or ``None`` when there is no match.
    """
    for start in range(begin, len(lines)):
        if _normalize(lines[start]) != want[0]:
            continue
        pos, matched = start, 0
        while pos < len(lines) and matched < len(want):
            got = _normalize(lines[pos])
            if not got:
                pos += 1
                continue
            if got != want[matched]:
                break
            pos += 1
            matched += 1
        if matched == len(want):
            return start, pos
    return None


def apply_fix(source: str) -> "tuple[str, bool]":
    """Replace every occurrence of ``old_handler`` in *source*.

    Matching ignores whitespace, because the target's indentation is not
    known here. Each replacement is re-indented to the indentation of the
    matched block's first line.

    Returns ``(patched_source, applied)``. When nothing matched,
    *patched_source* is *source* unchanged and *applied* is False.
    """
    want = [_normalize(ln) for ln in old_handler.splitlines() if ln.strip()]
    new_lines = new_handler.splitlines()
    lines = source.split("\n")
    applied = False
    pos = 0
    while True:
        span = _find_handler(lines, want, pos)
        if span is None:
            break
        start, stop = span
        first = lines[start]
        indent = first[: len(first) - len(first.lstrip())]
        replacement = [indent + ln if ln.strip() else "" for ln in new_lines]
        lines[start:stop] = replacement
        # Resume after the inserted block so it is never re-scanned.
        pos = start + len(replacement)
        applied = True
    if not applied:
        return source, False
    return "\n".join(lines), True


def main(argv=None) -> int:
    """Patch the target file; return the process exit status."""
    args = sys.argv[1:] if argv is None else argv
    path = args[0] if args else filepath

    # The templates contain non-ASCII ("→"), so never rely on the locale.
    with open(path, "r", encoding="utf-8") as f:
        c = f.read()

    patched, applied = apply_fix(c)
    if applied:
        with open(path, "w", encoding="utf-8") as f:
            f.write(patched)
        print('FIX 5 applied: Replaced UnquantizedLinearMethod with proper NVFP4 quantization')
        return 0

    # Nothing matched: leave the file untouched.
    print('FIX 5: Could not find handler block (whitespace-insensitive match)')
    if 'UnquantizedLinearMethod' in c:
        print(' Found UnquantizedLinearMethod in code - manual fix needed')
        return 1
    print(' UnquantizedLinearMethod not found - already replaced?')
    return 0


if __name__ == "__main__":
    sys.exit(main())