tweax n shit

This commit is contained in:
2026-05-12 23:16:33 +00:00
parent 2bdda36bb7
commit f08bcd456b
3 changed files with 89 additions and 206 deletions

View File

@@ -8,18 +8,18 @@ RUN apt-get update && apt-get install -y git screen cmake && rm -rf /var/lib/apt
# Clone and build DeepGEMM with NVFP4 mega_moe kernel
# CACHE_BUSTER: increment to force fresh clone
RUN git clone -b nvfp4-mega-moe https://sweetapi.com/biondizzle/DeepGEMM.git /root/DeepGEMM && PATCH_CACHE_BUSTER=69
RUN git clone -b nvfp4-mega-moe https://sweetapi.com/biondizzle/DeepGEMM.git /root/DeepGEMM && PATCH_CACHE_BUSTER=70
# Build DeepGEMM with proper CUDA/NVRTC paths
ENV CPATH="/usr/local/lib/python3.12/dist-packages/flashinfer/data/cutlass/include:/usr/local/lib/python3.12/dist-packages/nvidia/cu13/include:/usr/local/cuda-13.0/include:${CPATH}"
ENV PYTHONPATH="/root/DeepGEMM:${PYTHONPATH}"
# NVRTC lives in the pip nvidia/cu13 package, but the linker expects it in cuda/lib64
# Create a symlink so -lnvrtc resolves
RUN ln -sf /usr/local/lib/python3.12/dist-packages/nvidia/cu13/lib/libnvrtc.so.13 /usr/local/cuda/lib64/libnvrtc.so && PATCH_CACHE_BUSTER=69
RUN ln -sf /usr/local/lib/python3.12/dist-packages/nvidia/cu13/lib/libnvrtc.so.13 /usr/local/cuda/lib64/libnvrtc.so && PATCH_CACHE_BUSTER=70
RUN cd /root/DeepGEMM && python3 setup.py build_ext --inplace && PATCH_CACHE_BUSTER=69
# Bust cache for patch changes — ARG before COPY ensures layer invalidation
ARG PATCH_CACHE_BUSTER=69
ARG PATCH_CACHE_BUSTER=70
# Copy our DeepSeek V4 patch over vLLM's model file
COPY patches/deepseek_v4.py /usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v4.py
# Copy the NVFP4 staging kernel (BF16→E2M1+UE4M3 quantization for activations)

View File

@@ -12,7 +12,7 @@ services:
command:
- /model
- --trust-remote-code
- --kv-cache-dtype=fp8
#- --kv-cache-dtype=fp8 # maybe we just let it figure its own shit out
#- --block-size=256
- --enable-expert-parallel
- --tensor-parallel-size=8

View File

@@ -1,20 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ==============================================================================
# DeepSeek V4 NVFP4 Patch — Version Banner (printed at import time)
# ==============================================================================
import datetime as _dt
import os as _os
_git_commit = _os.popen("git -C /root/nvidia-meeting/deepseek-v4-quant rev-parse --short HEAD 2>/dev/null || echo 'unknown'").read().strip()
print(f"""
{'='*70}
DeepSeek V4 NVFP4 Patch (commit {_git_commit})
Loaded: {_dt.datetime.now().strftime('%Y-%m-%d %H:%M:%S UTC')}
{'='*70}
""")
# ==============================================================================
import typing
from collections.abc import Callable, Iterable
from itertools import islice
@@ -165,8 +151,9 @@ class DeepseekV4FP8Config(Fp8Config):
try:
hf_config = get_current_vllm_config().model_config.hf_config
except Exception:
# vllm_config not yet set; defer the decision until a
# later call lands inside set_current_vllm_config.
# vllm_config not yet set; return safe default but do NOT
# cache — a later call inside set_current_vllm_config may
# resolve differently.
return "fp4"
expert_dtype = getattr(hf_config, "expert_dtype", "fp4")
if expert_dtype not in _DEEPSEEK_V4_EXPERT_DTYPES:
@@ -175,11 +162,6 @@ class DeepseekV4FP8Config(Fp8Config):
f"expected one of {_DEEPSEEK_V4_EXPERT_DTYPES}."
)
self._resolved_expert_dtype = expert_dtype
from vllm.logger import init_logger
init_logger(__name__).info_once(
"DeepSeek V4 expert_dtype resolved to %r", expert_dtype
)
return self._resolved_expert_dtype
@property
@@ -614,26 +596,25 @@ class DeepseekV4MegaMoEExperts(nn.Module):
weight_name: str,
shard_id: str,
expert_id: int,
return_success: bool = False,
) -> bool | None:
) -> bool:
local_expert_id = self._map_global_expert_id(expert_id)
if local_expert_id == -1:
return False if return_success else None
return False
# Scalar params (weight_scale_2, input_scale): 1D per-expert
if "weight_scale_2" in weight_name or "input_scale" in weight_name:
param.data[local_expert_id].copy_(loaded_weight)
return True if return_success else None
return True
expert_data = param.data[local_expert_id]
if shard_id in ("w1", "w3"):
if "w13_" not in weight_name:
return False if return_success else None
return False
shard_offset = 0 if shard_id == "w1" else self.intermediate_size
expert_data = expert_data.narrow(0, shard_offset, self.intermediate_size)
elif shard_id == "w2":
if "w2_" not in weight_name:
return False if return_success else None
return False
else:
raise ValueError(f"Unsupported expert shard id: {shard_id}")
@@ -644,7 +625,7 @@ class DeepseekV4MegaMoEExperts(nn.Module):
f"vs checkpoint {tuple(loaded_weight.shape)}"
)
expert_data.copy_(loaded_weight)
return True if return_success else None
return True
def _check_runtime_supported(self) -> None:
if not torch.cuda.is_available():
@@ -654,7 +635,7 @@ class DeepseekV4MegaMoEExperts(nn.Module):
raise NotImplementedError(
"DeepSeek V4 MegaMoE expert weights must be loaded on CUDA."
)
if torch.cuda.get_device_capability(device)[0] != 10:
if torch.cuda.get_device_capability(device)[0] < 10:
raise NotImplementedError("DeepGEMM MegaMoE requires SM100 GPUs.")
if self.hidden_size % 128 != 0 or self.intermediate_size % 128 != 0:
raise ValueError(
@@ -707,98 +688,7 @@ class DeepseekV4MegaMoEExperts(nn.Module):
"""
return sf.to(torch.float32)
def _nvfp4_to_bf16(
self,
w_uint8: torch.Tensor, # [E, M, K//2]
w_scale_f8: torch.Tensor, # [E, M, K//16] (float8_e4m3fn, UE4M3 standard NVFP4)
w_scale_2: torch.Tensor, # [E] float32 global scale
w_input_scale: torch.Tensor, # [E] float32 activation scale
) -> torch.Tensor:
"""Dequantize NVFP4 expert weights to BF16.
Formula: weight_bf16 = e2m1_value * block_scale_ue4m3 * global_scale
(input_scale is for activations, not weights)
"""
device = w_uint8.device
E, M, K2 = w_uint8.shape
K = K2 * 2 # unpacked dim
# Unpack E2M1 FP4 → BF16
e2m1_lut = torch.tensor(
[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0],
dtype=torch.bfloat16, device=device,
)
even_raw = (w_uint8 & 0x0F).int()
odd_raw = ((w_uint8 >> 4) & 0x0F).int()
even_sign = torch.where(even_raw >= 8, -1.0, 1.0).to(torch.bfloat16)
odd_sign = torch.where(odd_raw >= 8, -1.0, 1.0).to(torch.bfloat16)
even_vals = even_sign * e2m1_lut[even_raw & 0x07]
odd_vals = odd_sign * e2m1_lut[odd_raw & 0x07]
w_bf16 = torch.stack([even_vals, odd_vals], dim=-1).reshape(E, M, K)
# Dequantize: e2m1 * block_scale * global_scale
block_scale = self._ue8m0_to_float32(w_scale_f8) # [E, M, K//16]
GROUP_SIZE = 16
# Expand block scale to match weight elements
block_scale_expanded = block_scale.unsqueeze(-1).expand(
-1, -1, -1, GROUP_SIZE
).reshape(E, M, K)
# Global scale: [E] → [E, 1, 1]
global_scale = w_scale_2.view(E, 1, 1)
w_dequant = w_bf16.float() * block_scale_expanded * global_scale
return w_dequant.to(torch.bfloat16)
def _bf16_to_mxfp4(
self,
w_bf16: torch.Tensor, # [E, M, K]
group_size: int = 32,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Re-quantize BF16 → MXFP4 (E2M1 packed uint8 + UE8M0 uint8 scales).
Returns (w_packed_uint8, w_scale_uint8) where:
w_packed_uint8: [E, M, K//2]
w_scale_uint8: [E, M, K//group_size] as uint8 UE8M0 bytes
"""
device = w_bf16.device
E, M, K = w_bf16.shape
# Block quantization
n_groups = K // group_size
w_groups = w_bf16.reshape(E, M, n_groups, group_size)
# Compute block amax
amax = w_groups.abs().amax(dim=-1) # [E, M, n_groups]
# UE8M0 scale: floor to nearest power of 2
# value = 2^(exp-127), so exp = floor(log2(amax)) + 127
amax_clamped = amax.clamp(min=2**-126)
scale_exp = torch.floor(torch.log2(amax_clamped)).to(torch.int32) + 127
scale_exp = scale_exp.clamp(1, 254).to(torch.uint8) # avoid 0 and 255
# Recover float32 scale for quantization
scale_f32 = (scale_exp.to(torch.int32) << 23).view(torch.float32) # [E, M, n_groups]
# Quantize each element: value / scale → nearest E2M1
E2M1_POS = torch.tensor(
[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0],
dtype=torch.float32, device=device,
)
scaled = w_groups.float() / scale_f32.unsqueeze(-1) # [E, M, n_groups, group_size]
scaled_abs = scaled.abs()
diff = (scaled_abs.unsqueeze(-1) - E2M1_POS).abs()
fp4_idx = diff.argmin(dim=-1) # [E, M, n_groups, group_size]
sign = (scaled < 0).int()
fp4_val = (sign << 3) | fp4_idx.int()
# Pack 2 FP4 values per byte
fp4_flat = fp4_val.reshape(E, M, K)
even = fp4_flat[:, :, 0::2]
odd = fp4_flat[:, :, 1::2]
w_packed = ((odd << 4) | even).to(torch.uint8).view(torch.int8)
return w_packed, scale_exp
def get_symm_buffer(self):
import deep_gemm
@@ -893,8 +783,8 @@ class DeepseekV4MegaMoEExperts(nn.Module):
activation_clamp=activation_clamp,
fast_math=fast_math,
)
# Sync to catch mega_moe CUDA errors immediately
torch.cuda.synchronize()
if os.environ.get('NVFP4_DEBUG_SYNC', '') == '1':
torch.cuda.synchronize()
DeepseekV4MegaMoEExperts.weight_loader.supports_moe_loading = True # type: ignore[attr-defined]
@@ -1239,7 +1129,7 @@ class DeepseekV4Attention(nn.Module):
self.rope_parameters = config.rope_scaling
# Initialize rotary embedding BEFORE DeepseekV4MLAModules (which needs it)
rope_parameters = config.rope_parameters
rope_parameters = dict(config.rope_parameters)
rope_parameters["rope_theta"] = (
config.compress_rope_theta if self.compress_ratio > 1 else config.rope_theta
)
@@ -1653,14 +1543,19 @@ class DeepseekV4Model(nn.Module):
else:
if ".experts." in name:
# E8M0 scales are stored as float8_e8m0fnu in
# checkpoints but the MoE param is uint8. copy_()
# would do a numeric conversion (e.g. 2^-7 → 0),
# destroying the raw exponent bytes.
# MXFP4 checkpoints but NVFP4 uses float8_e4m3fn.
# The uint8 view+copy path is only valid for MXFP4;
# for NVFP4 it would paste raw E8M0 bytes into an
# E4M3 buffer, producing garbage.
if (
"weight_scale" in name
and loaded_weight.dtype == torch.float8_e8m0fnu
):
loaded_weight = loaded_weight.view(torch.uint8)
assert False, (
f"E8M0 weight_scale encountered for NVFP4 experts "
f"({name}) — this is only valid for MXFP4. "
f"Check checkpoint dtype."
)
for mapping in expert_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
@@ -1681,7 +1576,6 @@ class DeepseekV4Model(nn.Module):
name_mapped,
shard_id=shard_id,
expert_id=expert_id,
return_success=True,
)
if success:
name = name_mapped
@@ -1851,7 +1745,10 @@ class DeepseekV4Model(nn.Module):
fp8_from_bf16 = 0
bf16_converted = 0
compressor_converted = 0
diag_printed = False
# Build shard index once for compressor reconstruction (avoids N×M full-shard loads)
_shard_index = self._build_shard_index("/model") if os.path.isdir("/model") else None
for layer_idx, layer in enumerate(self.layers):
attn = layer.attn
@@ -1864,20 +1761,11 @@ class DeepseekV4Model(nn.Module):
continue
if mod.weight.dtype in (torch.uint8, torch.int8):
# NVFP4 -> dequant to bf16 -> requant to FP8
if not diag_printed and layer_idx == 0:
ws = getattr(mod, 'weight_scale', None)
ws2 = getattr(mod, 'weight_scale_2', None)
print(f"[DIAG-wo_a:0] dtype={mod.weight.dtype} shape={mod.weight.shape} "
f"ws_dtype={ws.dtype if ws is not None else None} ws_shape={ws.shape if ws is not None else None} "
f"ws2_val={ws2.data.item() if ws2 is not None else None}")
self._convert_nvfp4_to_fp8(mod, E2M1_LUT, FP8_MAX)
fp8_converted += 1
elif mod.weight.dtype == torch.bfloat16:
if not diag_printed and layer_idx == 0:
print(f"[DIAG-wo_a:0] dtype=bf16 shape={mod.weight.shape} (direct bf16→fp8)")
self._convert_bf16_to_fp8(mod, FP8_MAX)
fp8_from_bf16 += 1
diag_printed = True
# BF16 conversion: attention layers via .forward()
for proj_name in bf16_proj_names:
@@ -1886,10 +1774,6 @@ class DeepseekV4Model(nn.Module):
mod = getattr(attn, proj_name)
if not hasattr(mod, "weight") or mod.weight.dtype not in (torch.uint8, torch.int8):
continue
if not diag_printed and layer_idx == 0:
ws = getattr(mod, 'weight_scale', None)
print(f"[DIAG-bf16:0/{proj_name}] dtype={mod.weight.dtype} shape={mod.weight.shape} "
f"ws_dtype={ws.dtype if ws is not None else None} ws_shape={ws.shape if ws is not None else None}")
self._dequant_nvfp4_to_bf16(mod, E2M1_LUT)
bf16_converted += 1
@@ -1906,14 +1790,14 @@ class DeepseekV4Model(nn.Module):
compressor = getattr(mla_attn, "compressor", None)
if compressor is not None and hasattr(compressor, "fused_wkv_wgate"):
compressor_converted += self._reconstruct_compressor_weight(
compressor.fused_wkv_wgate, attn, layer_idx, E2M1_LUT)
compressor.fused_wkv_wgate, attn, layer_idx, E2M1_LUT, _shard_index=_shard_index)
# Indexer compressor (C4A layers only)
indexer = getattr(mla_attn, "indexer", None)
if indexer is not None:
idx_compressor = getattr(indexer, "compressor", None)
if idx_compressor is not None and hasattr(idx_compressor, "fused_wkv_wgate"):
compressor_converted += self._reconstruct_compressor_weight(
idx_compressor.fused_wkv_wgate, indexer, layer_idx, E2M1_LUT, sub_path=".indexer")
idx_compressor.fused_wkv_wgate, indexer, layer_idx, E2M1_LUT, sub_path=".indexer", _shard_index=_shard_index)
# Shared experts
ffn = layer.ffn
@@ -1968,17 +1852,10 @@ class DeepseekV4Model(nn.Module):
else:
w_dequant = w_bf16
# Replace weight with bf16 version
# Diagnostics: check for NaN/Inf from bad dequant
if w_dequant.isnan().any() or w_dequant.isinf().any():
nan_count = w_dequant.isnan().sum().item()
inf_count = w_dequant.isinf().sum().item()
print(f"[NVFP4-DEQUANT-WARN] {getattr(mod, 'prefix', 'unknown')}: "
f"shape={w_dequant.shape}, dtype={w_dequant.dtype}, "
f"NaN={nan_count}, Inf={inf_count}, "
f"block_scale range=[{block_scale.min().item():.6f}, {block_scale.max().item():.6f}], "
f"global_scale={global_scale:.6f}")
# Free source tensors eagerly to avoid holding uint8+bf16+fp32 simultaneously
del w_uint8, w_bf16
mod.weight = torch.nn.Parameter(w_dequant, requires_grad=False)
del w_dequant
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
mod.quant_method = UnquantizedLinearMethod()
for attr in ("weight_scale", "weight_scale_2", "input_scale",
@@ -2061,25 +1938,38 @@ class DeepseekV4Model(nn.Module):
bmm_batch_size=bmm_batch_size,
)
# Free source tensors eagerly
del w_uint8, w_bf16, w_dequant
mod.weight = torch.nn.Parameter(w_fp8, requires_grad=False)
del w_fp8
# weight_scale_inv is what the attention runtime reads as b_scale
# for deepseek_v4_fp8_einsum -> DeepGEMM fp8_einsum.
# It must be the DeepGEMM-formatted block scale (dg_ws), NOT the
# per-tensor scalar. See: deepseek_v4_attention.py line 319.
mod.weight_scale_inv = torch.nn.Parameter(ws, requires_grad=False)
# weight_scale is not used at runtime for BMM layers; remove it
# to avoid confusing other code paths.
for attr in ("weight_scale", "weight_scale_2", "input_scale"):
if hasattr(mod, attr):
delattr(mod, attr)
del ws
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
mod.quant_method = UnquantizedLinearMethod()
for attr in ("weight_scale", "weight_scale_2", "input_scale"):
if hasattr(mod, attr):
delattr(mod, attr)
def _reconstruct_compressor_weight(self, fused_mod, parent_mod, layer_idx, e2m1_lut, sub_path=""):
@staticmethod
def _build_shard_index(ckpt_dir: str) -> dict[str, str]:
"""Build key→shard_path index from safetensors metadata (no tensor I/O)."""
import glob
from safetensors import safe_open
index = {}
for shard_file in sorted(glob.glob(os.path.join(ckpt_dir, "model-*.safetensors"))):
try:
with safe_open(shard_file, framework="pt") as f:
for key in f.keys():
index[key] = shard_file
except Exception:
continue
return index
def _reconstruct_compressor_weight(self, fused_mod, parent_mod, layer_idx, e2m1_lut, sub_path="", _shard_index=None):
"""Reconstruct compressor fused_wkv_wgate from checkpoint.
Compressor weights are SKIPPED during loading because NVFP4 uint8 data
@@ -2087,8 +1977,7 @@ class DeepseekV4Model(nn.Module):
We read the original uint8 data from the safetensors checkpoint, unpack
E2M1, dequantize, and stack into the fused weight param.
"""
import glob
from safetensors.torch import load_file
from safetensors import safe_open
# Find the checkpoint directory
# The model weights are mounted at /model in Docker
@@ -2103,49 +1992,45 @@ class DeepseekV4Model(nn.Module):
# We read from checkpoint (before mapper), so use original names
layer_prefix = f"model.layers.{layer_idx}.self_attn.compressor{sub_path}"
# Find which shard contains this layer's compressor weights
wkv_key = f"{layer_prefix}.kv_proj.weight"
wgate_key = f"{layer_prefix}.gate_proj.weight"
wkv_scale_key = f"{layer_prefix}.kv_proj.weight_scale"
wgate_scale_key = f"{layer_prefix}.gate_proj.weight_scale"
wkv_scale2_key = f"{layer_prefix}.kv_proj.weight_scale_2"
wgate_scale2_key = f"{layer_prefix}.gate_proj.weight_scale_2"
wkv_iscale_key = f"{layer_prefix}.kv_proj.input_scale"
wgate_iscale_key = f"{layer_prefix}.gate_proj.input_scale"
# All keys we need from the checkpoint
keys = {
'wkv_uint8': f"{layer_prefix}.kv_proj.weight",
'wgate_uint8': f"{layer_prefix}.gate_proj.weight",
'wkv_block_scale': f"{layer_prefix}.kv_proj.weight_scale",
'wgate_block_scale': f"{layer_prefix}.gate_proj.weight_scale",
'wkv_global_scale': f"{layer_prefix}.kv_proj.weight_scale_2",
'wgate_global_scale': f"{layer_prefix}.gate_proj.weight_scale_2",
'wkv_input_scale': f"{layer_prefix}.kv_proj.input_scale",
'wgate_input_scale': f"{layer_prefix}.gate_proj.input_scale",
}
# Load from safetensors
wkv_uint8 = None
wgate_uint8 = None
wkv_block_scale = None
wgate_block_scale = None
wkv_global_scale = None
wgate_global_scale = None
wkv_input_scale = None
wgate_input_scale = None
shard_files = sorted(glob.glob(os.path.join(ckpt_dir, "model-*.safetensors")))
for shard_file in shard_files:
# Read tensors using shard index for targeted access (no full-shard loads)
tensors = {}
for name, key in keys.items():
shard_path = (_shard_index or {}).get(key)
if shard_path is None:
continue
try:
shard_data = load_file(shard_file)
with safe_open(shard_path, framework="pt") as f:
if key in f.keys():
tensors[name] = f.get_tensor(key)
except Exception:
continue
if wkv_key in shard_data:
wkv_uint8 = shard_data[wkv_key]
wkv_block_scale = shard_data.get(wkv_scale_key)
wkv_global_scale = shard_data.get(wkv_scale2_key)
wkv_input_scale = shard_data.get(wkv_iscale_key)
if wgate_key in shard_data:
wgate_uint8 = shard_data[wgate_key]
wgate_block_scale = shard_data.get(wgate_scale_key)
wgate_global_scale = shard_data.get(wgate_scale2_key)
wgate_input_scale = shard_data.get(wgate_iscale_key)
if wkv_uint8 is not None and wgate_uint8 is not None:
break
wkv_uint8 = tensors.get('wkv_uint8')
wgate_uint8 = tensors.get('wgate_uint8')
if wkv_uint8 is None or wgate_uint8 is None:
# Layer might not have a compressor (compress_ratio=1 layers)
return 0
wkv_block_scale = tensors.get('wkv_block_scale')
wgate_block_scale = tensors.get('wgate_block_scale')
wkv_global_scale = tensors.get('wkv_global_scale')
wgate_global_scale = tensors.get('wgate_global_scale')
wkv_input_scale = tensors.get('wkv_input_scale')
wgate_input_scale = tensors.get('wgate_input_scale')
device = fused_mod.weight.device
wkv_uint8 = wkv_uint8.to(device)
wgate_uint8 = wgate_uint8.to(device)
@@ -2469,12 +2354,10 @@ class DeepseekV4ForCausalLM(nn.Module):
loader = AutoWeightsLoader(self, skip_substrs=["mtp."])
loaded_params = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
self.model.finalize_mega_moe_weights()
# Sync after mega_moe weight transform to catch CUDA errors early
torch.cuda.synchronize()
print("[NVFP4] mega_moe finalize_weights done, CUDA OK")
self.model._convert_nvfp4_post_load()
torch.cuda.synchronize()
print("[NVFP4] post-load conversion done, CUDA OK")
if os.environ.get('NVFP4_DEBUG_SYNC', '') == '1':
torch.cuda.synchronize()
print("[NVFP4] post-load conversion done, CUDA OK")
return loaded_params
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: