sync B200 deployment files: Dockerfile, docker-compose, patches

This commit is contained in:
2026-05-14 14:12:52 +00:00
parent 7e2f219259
commit f2656dcf6d
5 changed files with 1852 additions and 248 deletions

54
Dockerfile Normal file
View File

@@ -0,0 +1,54 @@
# DeepSeek V4 NVFP4 vLLM + CUTLASS NVFP4 Mega MoE Kernel
FROM vllm/vllm-openai:nightly-x86_64
# Remove broken nixl_ep (built against CUDA 12, image is CUDA 13)
RUN pip uninstall -y nixl-ep; rm -rf /usr/local/lib/python3.12/dist-packages/nixl_ep
RUN apt-get update && apt-get install -y git screen cmake libcusolver-dev-13-0 libcusparse-dev-13-0 libcublas-dev-13-0 libcurand-dev-13-0 libcufft-dev-13-0 libnvjitlink-dev-13-0 && rm -rf /var/lib/apt/lists/*
# Remove the broken symlink if it exists
RUN rm -f /usr/local/cuda/lib64/libcudart.so.12
ENV CUDA_HOME=/usr/local/cuda
ENV TORCH_CUDA_ARCH_LIST="10.0"
# Clone latest CUTLASS (has NVFP4 block-scaled MMA support)
ARG CUTLASS_CACHE_BUSTER=1
RUN git clone --depth 1 https://github.com/NVIDIA/cutlass.git /root/cutlass
# Clone our NVFP4 mega_moe kernel
ARG KERNEL_CACHE_BUSTER=24
RUN git clone https://sweetapi.com/biondizzle/nvfp4-megamoe-kernel.git /root/nvfp4-megamoe-kernel && \
cd /root/nvfp4-megamoe-kernel && \
pip install -e .
# Build the CUTLASS NVFP4 block-scaled GEMM extension
RUN cd /root/nvfp4-megamoe-kernel/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm && \
mkdir -p cutlass_nvfp4_gemm && \
CUTLASS_INCLUDE_DIR=/root/cutlass/include \
TORCH_CUDA_ARCH_LIST=10.0 \
python3 setup.py build_ext --inplace
# Install TileLang (for potential future use)
RUN pip install tilelang
ENV PYTHONPATH="/root/nvfp4-megamoe-kernel/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm:/root/nvfp4-megamoe-kernel:${PYTHONPATH}"
# Copy patches
ARG PATCH_CACHE_BUSTER=82
COPY patches/deepseek_v4.py /tmp/patches/deepseek_v4.py
COPY patches/staging_kernel.py /tmp/patches/staging_kernel.py
COPY patches/deepseek_v4_attention.py /tmp/patches/deepseek_v4_attention.py
# Apply patches
RUN VLLM_MODELS_DIR=$(python3 -c "import vllm.model_executor.models; import os; print(os.path.dirname(vllm.model_executor.models.__file__))") && \
VLLM_LAYERS_DIR=$(python3 -c "import vllm.model_executor.layers; import os; print(os.path.dirname(vllm.model_executor.layers.__file__))") && \
cp /tmp/patches/deepseek_v4.py "$VLLM_MODELS_DIR/deepseek_v4.py" && \
cp /tmp/patches/staging_kernel.py "$VLLM_MODELS_DIR/staging_kernel.py" && \
cp /tmp/patches/deepseek_v4_attention.py "$VLLM_LAYERS_DIR/deepseek_v4_attention.py" && \
rm -rf /tmp/patches
# Verify
RUN python3 -c "import torch; import cutlass_nvfp4_gemm._C; print('CUTLASS NVFP4 OK')" && \
python3 -c "import vllm; print('vLLM OK')" && \
python3 -c "import nvfp4_megamoe_kernel; print('NVFP4 kernel OK')"

View File

@@ -1,30 +1,22 @@
services:
vllm:
image: atl.vultrcr.com/vllm/vllm-with-lmcache:dream-build
pull_policy: always
entrypoint:
- bash
- -c
- |
cp /patches/deepseek_v4.py /usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v4.py
exec vllm serve "$$@"
- --
build:
context: .
ports:
- "8000:8000"
environment:
- HF_TOKEN=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
- OMP_NUM_THREADS=128
- MEGA_MOE_DEBUG=1
- MEGA_MOE_STATIC=0
- MEGA_MOE_USE_CUTLASS=1
- DG_JIT_DEBUG=1
command:
- /model
- --trust-remote-code
- --kv-cache-dtype=fp8
- --block-size=256
- --enable-expert-parallel
- --tensor-parallel-size=8
- --compilation-config={"cudagraph_mode":"FULL_AND_PIECEWISE","custom_ops":["all"]}
- --attention_config.use_fp4_indexer_cache=True
- --enforce-eager
- --tokenizer-mode=deepseek_v4
- --tool-call-parser=deepseek_v4
- --enable-auto-tool-choice
- --reasoning-parser=deepseek_v4
- --speculative_config={"method":"mtp","num_speculative_tokens":2}
- --host=0.0.0.0
- --port=8000
deploy:
@@ -34,13 +26,5 @@ services:
- driver: nvidia
count: all
capabilities: [gpu]
ipc: host
security_opt:
- seccomp:unconfined
tty: true
stdin_open: true
volumes:
- /root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4:/model:ro
- /root/nvidia-meeting/deepseek-v4-quant/patches/deepseek_v4.py:/patches/deepseek_v4.py:ro
- /root/nvidia-meeting/deepseek-v4-quant/patches:/patches:ro
network_mode: host

View File

@@ -1,40 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ==============================================================================
# DeepSeek V4 NVFP4 Patch — Version Banner (printed at import time)
# ==============================================================================
import datetime as _dt
import os as _os
_git_commit = _os.popen("git -C /root/nvidia-meeting/deepseek-v4-quant rev-parse --short HEAD 2>/dev/null || echo 'unknown'").read().strip()
print(f"""
{'='*70}
DeepSeek V4 NVFP4 Patch
{'='*70}
Commit: {_git_commit}
Loaded: {_dt.datetime.now().strftime('%Y-%m-%d %H:%M:%S UTC')}
Node: {_os.uname().nodename}
Architecture:
wo_a → FP8 + DeepGEMM block scale (BMM einsum)
wq_b/wo_b → BF16 (UnquantizedLinearMethod)
fused_wqa → BF16 (stacked q_a + kv, dequantized from NVFP4)
compressor → BF16 (reconstructed from separate kv_proj+gate_proj)
shared_exp → FP8 (Fp8LinearMethod, DeepGEMM)
MoE experts → NVFP4 (FusedMoE, FLASHINFER_TRTLLM) — NOT converted
Bugs fixed:
#1 DeepGEMM sf.dim() — block scale format (deepgemm_post_process)
#2 fused_skip_regex — q_b/o_a/o_b scales no longer skipped
#3 input_scale — removed from weight dequant (activations only)
#4 compressor indexer — sub_path for .indexer keys
#5 block scale dtype — must be float32, not float8_e4m3fn
#6 block scale values — torch.full(fp8_scale) not torch.ones
#7 UE8M0 block scale — .to(float32) misinterprets E8M0 as E4M3
{'='*70}
""")
# ==============================================================================
import typing
from collections.abc import Callable, Iterable
from itertools import islice
@@ -185,8 +151,9 @@ class DeepseekV4FP8Config(Fp8Config):
try:
hf_config = get_current_vllm_config().model_config.hf_config
except Exception:
# vllm_config not yet set; defer the decision until a
# later call lands inside set_current_vllm_config.
# vllm_config not yet set; return safe default but do NOT
# cache — a later call inside set_current_vllm_config may
# resolve differently.
return "fp4"
expert_dtype = getattr(hf_config, "expert_dtype", "fp4")
if expert_dtype not in _DEEPSEEK_V4_EXPERT_DTYPES:
@@ -195,11 +162,6 @@ class DeepseekV4FP8Config(Fp8Config):
f"expected one of {_DEEPSEEK_V4_EXPERT_DTYPES}."
)
self._resolved_expert_dtype = expert_dtype
from vllm.logger import init_logger
init_logger(__name__).info_once(
"DeepSeek V4 expert_dtype resolved to %r", expert_dtype
)
return self._resolved_expert_dtype
@property
@@ -244,11 +206,23 @@ class DeepseekV4FP8Config(Fp8Config):
return isinstance(layer, FusedMoE) and self.expert_dtype == "fp4"
import triton
import triton.language as tl
import torch
"""
NVFP4 staging kernel — full FP4 (E2M1) activations + UE4M3 block16 scales.
The mxf4nvf4 PTX instruction requires BOTH A and B to be FP4 (E2M1 packed).
This kernel quantizes BF16 activations → E2M1 packed uint8 with UE4M3 scales.
"""
@triton.jit
def _deepseek_v4_stage_mega_moe_inputs_kernel(
hidden_states,
x_fp8,
x_sf,
x_fp4, # uint8, shape (M, K//2) — E2M1 packed, 2 values per byte
x_sf, # int32, shape (M, K//64) — UE4M3 packed, 4 scales per int32
topk_ids,
topk_weights,
topk_idx_out,
@@ -269,8 +243,8 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel(
topk_weights_out_stride_k: tl.constexpr,
hidden_size: tl.constexpr,
top_k: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_K: tl.constexpr,
BLOCK_K: tl.constexpr, # 128 elements (loaded from hidden)
GROUP_K: tl.constexpr, # 16 (NVFP4 group_size)
BLOCK_TOPK: tl.constexpr,
) -> None:
token_id = tl.program_id(0)
@@ -284,35 +258,94 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel(
other=0.0,
).to(tl.float32)
num_groups: tl.constexpr = BLOCK_K // GROUP_K
hidden_groups = tl.reshape(tl.abs(hidden), [num_groups, GROUP_K])
amax = tl.max(hidden_groups, axis=1)
num_groups: tl.constexpr = BLOCK_K // GROUP_K # 8
hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K])
abs_groups = tl.reshape(tl.abs(hidden), [num_groups, GROUP_K])
amax = tl.max(abs_groups, axis=1)
amax = tl.maximum(amax, 1.0e-4)
scale = amax / 448.0
# ---- UE4M3 scale computation ----
# scale = amax / 6.0 (E2M1 max value = 6)
# Then quantize scale to UE4M3 format
scale = amax / 6.0
scale_bits = scale.to(tl.uint32, bitcast=True)
scale_exp = ((scale_bits >> 23) & 0xFF) + ((scale_bits & 0x7FFFFF) != 0).to(
tl.uint32
)
scale_exp = tl.minimum(tl.maximum(scale_exp, 1), 254)
rounded_scale = (scale_exp << 23).to(tl.float32, bitcast=True)
scale_exp = (scale_bits >> 23) & 0xFF
scale_mant = scale_bits & 0x7FFFFF
hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K])
scaled = hidden_groups * (1.0 / rounded_scale)[:, None]
scaled = tl.reshape(scaled, [BLOCK_K])
fp8 = scaled.to(tl.float8e4nv)
# Convert FP32 → E4M3 manually
e4m3_exp = scale_exp - 120 # FP32 bias=127, E4M3 bias=7
e4m3_exp = tl.maximum(e4m3_exp, 0)
e4m3_exp = tl.minimum(e4m3_exp, 15)
e4m3_mant = scale_mant >> 20
round_bit = (scale_mant >> 19) & 1
e4m3_mant = e4m3_mant + round_bit
overflow = e4m3_mant >= 8
e4m3_mant = tl.where(overflow, 0, e4m3_mant)
e4m3_exp = tl.where(overflow, e4m3_exp + 1, e4m3_exp)
e4m3_exp = tl.minimum(e4m3_exp, 15)
scale_e4m3_bits = (e4m3_exp << 3) | e4m3_mant
# Reconstruct dequantized scale for E2M1 quantization
e4m3_exp_for_recon = tl.maximum(e4m3_exp.to(tl.int32) - 7, -126)
two_pow_exp_bits = (e4m3_exp_for_recon + 127).to(tl.uint32) << 23
two_pow_exp = two_pow_exp_bits.to(tl.float32, bitcast=True)
normal_value = (1.0 + e4m3_mant.to(tl.float32) / 8.0) * two_pow_exp
subnormal_value = (e4m3_mant.to(tl.float32) / 8.0) * 0.015625
e4m3_value = tl.where(e4m3_exp == 0, subnormal_value, normal_value)
# ---- E2M1 FP4 quantization ----
# E2M1 LUT (unsigned): [0, 0.5, 1, 1.5, 2, 3, 4, 6]
# Nearest-neighbor using thresholds (midpoints between consecutive values)
scaled = hidden_groups * (1.0 / tl.maximum(e4m3_value, 1e-6))[:, None]
# Clamp to E2M1 range [-6, 6]
scaled = tl.maximum(scaled, -6.0)
scaled = tl.minimum(scaled, 6.0)
abs_s = tl.abs(scaled)
# E2M1 quantization using arithmetic instead of nested tl.where (Triton compile error)
# LUT: [0, 0.5, 1, 1.5, 2, 3, 4, 6] → thresholds at midpoints
# idx = sum(abs_s >= threshold_i) for thresholds [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0]
e2m1_idx = ((abs_s >= 0.25).to(tl.int32) + (abs_s >= 0.75).to(tl.int32) +
(abs_s >= 1.25).to(tl.int32) + (abs_s >= 1.75).to(tl.int32) +
(abs_s >= 2.5).to(tl.int32) + (abs_s >= 3.5).to(tl.int32) +
(abs_s >= 5.0).to(tl.int32))
sign_bit = (scaled < 0).to(tl.int32)
e2m1_4bit = (sign_bit << 3) | e2m1_idx # 4-bit: (sign << 3) | index
# Pack 2 E2M1 values per byte: even→low nibble, odd→high nibble
PACKED_K: tl.constexpr = BLOCK_K // 2 # 64
e2m1_pairs = tl.reshape(e2m1_4bit, [PACKED_K, 2])
even, odd = tl.split(e2m1_pairs) # splits last axis (size 2) into two [PACKED_K] tensors
packed_byte = (odd.to(tl.uint8) << 4) | even.to(tl.uint8)
packed_k_offsets = k_block_id * PACKED_K + tl.arange(0, PACKED_K)
packed_k_mask = packed_k_offsets < (hidden_size // 2)
tl.store(
x_fp8 + token_id * x_stride_m + k_offsets * x_stride_k,
fp8,
mask=k_mask,
x_fp4 + token_id * x_stride_m + packed_k_offsets * x_stride_k,
packed_byte,
mask=packed_k_mask,
)
scale_offsets = tl.arange(0, num_groups)
packed_scale = tl.sum(scale_exp << (scale_offsets * 8), axis=0).to(tl.int32)
tl.store(
x_sf + token_id * x_sf_stride_m + k_block_id * x_sf_stride_k,
packed_scale,
)
# Pack UE4M3 bytes into int32 (NVFP4: group_size=16, 4 groups per 64 elements)
# 8 groups per k_block of 128 → 2 int32s per k_block
# int32 can only pack 4 bytes (shifts >= 32 are UB on GPU), so split into two packs
scale_offsets = tl.arange(0, num_groups) # [0..7]
first_half = scale_offsets < 4 # groups 0-3 → int32[0]
second_half = scale_offsets >= 4 # groups 4-7 → int32[1]
packed_lo = tl.sum(
tl.where(first_half, scale_e4m3_bits.to(tl.int32) << (scale_offsets * 8), 0),
axis=0,
).to(tl.int32)
packed_hi = tl.sum(
tl.where(second_half, scale_e4m3_bits.to(tl.int32) << ((scale_offsets - 4) * 8), 0),
axis=0,
).to(tl.int32)
# Write 2 int32s per k_block: x_sf shape is (M, K//64) = (M, num_k_blocks * 2)
sf_base = token_id * x_sf_stride_m + k_block_id * 2 * x_sf_stride_k
tl.store(x_sf + sf_base, packed_lo)
tl.store(x_sf + sf_base + x_sf_stride_k, packed_hi)
if k_block_id == 0:
topk_offsets = tl.arange(0, BLOCK_TOPK)
@@ -351,8 +384,8 @@ def _stage_deepseek_v4_mega_moe_inputs(
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
x_fp8: torch.Tensor,
x_sf: torch.Tensor,
x_fp4: torch.Tensor, # uint8, shape (M, K//2)
x_sf: torch.Tensor, # int32, shape (M, K//64)
topk_idx_out: torch.Tensor,
topk_weights_out: torch.Tensor,
) -> None:
@@ -376,7 +409,7 @@ def _stage_deepseek_v4_mega_moe_inputs(
block_topk = triton.next_power_of_2(top_k)
_deepseek_v4_stage_mega_moe_inputs_kernel[grid](
hidden_states,
x_fp8,
x_fp4,
x_sf,
topk_ids,
topk_weights,
@@ -384,8 +417,8 @@ def _stage_deepseek_v4_mega_moe_inputs(
topk_weights_out,
hidden_states.stride(0),
hidden_states.stride(1),
x_fp8.stride(0),
x_fp8.stride(1),
x_fp4.stride(0),
x_fp4.stride(1),
x_sf.stride(0),
x_sf.stride(1),
topk_ids.stride(0),
@@ -399,7 +432,7 @@ def _stage_deepseek_v4_mega_moe_inputs(
hidden_size,
top_k,
BLOCK_K=block_k,
GROUP_K=32,
GROUP_K=16, # NVFP4: group_size=16 (scale_vec::4X)
BLOCK_TOPK=block_topk,
num_warps=4,
)
@@ -425,8 +458,21 @@ def make_deepseek_v4_expert_params_mapping(
class DeepseekV4MegaMoEExperts(nn.Module):
"""MegaMoE experts for DeepSeek V4 with NVFP4 quantization.
Loads NVFP4 expert weights (E2M1 packed uint8 + float8_e4m3fn block scales
+ float32 global scales) and feeds them natively to the DeepGEMM
fp8_nvfp4_mega_moe kernel (kind::mxf4nvf4.scale_vec::4X).
No conversion to MXFP4. Experts stay NVFP4. The global scale (weight_scale_2)
is folded into the block scales before kernel consumption.
"""
_symm_buffer_cache: dict[tuple[int, int, int, int, int, int, int], object] = {}
# NVFP4 E2M1 lookup table (positive values, sign from bit 3)
E2M1_LUT = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
# MXFP4 E2M1 is the same format
def __init__(
self,
vllm_config: VllmConfig,
@@ -451,52 +497,83 @@ class DeepseekV4MegaMoEExperts(nn.Module):
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
weight_attrs = {"weight_loader": self.weight_loader}
# NVFP4 weights: E2M1 packed as uint8, 2 values per byte
self.w13_weight = nn.Parameter(
torch.zeros(
num_local_experts,
2 * intermediate_size,
hidden_size // 2,
dtype=torch.uint8,
dtype=torch.int8,
),
requires_grad=False,
)
set_weight_attrs(self.w13_weight, weight_attrs)
# NVFP4 block scales: float8_e4m3fn, group_size=16
# Shape: [num_local_experts, 2*intermediate_size, hidden_size // 16]
self.w13_weight_scale = nn.Parameter(
torch.zeros(
num_local_experts,
2 * intermediate_size,
hidden_size // 32,
dtype=torch.uint8,
hidden_size // 16,
dtype=torch.float8_e4m3fn,
),
requires_grad=False,
)
set_weight_attrs(self.w13_weight_scale, weight_attrs)
self.w13_weight_scale.quant_method = "block"
# NVFP4 global scales: float32, per-expert
self.w13_weight_scale_2 = nn.Parameter(
torch.zeros(num_local_experts, dtype=torch.float32),
requires_grad=False,
)
set_weight_attrs(self.w13_weight_scale_2, weight_attrs)
# NVFP4 activation scales: float32, per-expert
self.w13_input_scale = nn.Parameter(
torch.zeros(num_local_experts, dtype=torch.float32),
requires_grad=False,
)
set_weight_attrs(self.w13_input_scale, weight_attrs)
self.w2_weight = nn.Parameter(
torch.zeros(
num_local_experts,
hidden_size,
intermediate_size // 2,
dtype=torch.uint8,
dtype=torch.int8,
),
requires_grad=False,
)
set_weight_attrs(self.w2_weight, weight_attrs)
# NVFP4 block scales for w2
self.w2_weight_scale = nn.Parameter(
torch.zeros(
num_local_experts,
hidden_size,
intermediate_size // 32,
dtype=torch.uint8,
intermediate_size // 16,
dtype=torch.float8_e4m3fn,
),
requires_grad=False,
)
set_weight_attrs(self.w2_weight_scale, weight_attrs)
self.w2_weight_scale.quant_method = "block"
self.w2_weight_scale_2 = nn.Parameter(
torch.zeros(num_local_experts, dtype=torch.float32),
requires_grad=False,
)
set_weight_attrs(self.w2_weight_scale_2, weight_attrs)
self.w2_input_scale = nn.Parameter(
torch.zeros(num_local_experts, dtype=torch.float32),
requires_grad=False,
)
set_weight_attrs(self.w2_input_scale, weight_attrs)
self._transformed_l1_weights: tuple[torch.Tensor, torch.Tensor] | None = None
self._transformed_l2_weights: tuple[torch.Tensor, torch.Tensor] | None = None
@@ -519,21 +596,25 @@ class DeepseekV4MegaMoEExperts(nn.Module):
weight_name: str,
shard_id: str,
expert_id: int,
return_success: bool = False,
) -> bool | None:
) -> bool:
local_expert_id = self._map_global_expert_id(expert_id)
if local_expert_id == -1:
return False if return_success else None
return False
# Scalar params (weight_scale_2, input_scale): 1D per-expert
if "weight_scale_2" in weight_name or "input_scale" in weight_name:
param.data[local_expert_id].copy_(loaded_weight)
return True
expert_data = param.data[local_expert_id]
if shard_id in ("w1", "w3"):
if "w13_" not in weight_name:
return False if return_success else None
return False
shard_offset = 0 if shard_id == "w1" else self.intermediate_size
expert_data = expert_data.narrow(0, shard_offset, self.intermediate_size)
elif shard_id == "w2":
if "w2_" not in weight_name:
return False if return_success else None
return False
else:
raise ValueError(f"Unsupported expert shard id: {shard_id}")
@@ -544,11 +625,7 @@ class DeepseekV4MegaMoEExperts(nn.Module):
f"vs checkpoint {tuple(loaded_weight.shape)}"
)
expert_data.copy_(loaded_weight)
return True if return_success else None
@staticmethod
def _ue8m0_uint8_to_float(sf: torch.Tensor) -> torch.Tensor:
return (sf.to(torch.int32) << 23).view(torch.float32)
return True
def _check_runtime_supported(self) -> None:
if not torch.cuda.is_available():
@@ -558,7 +635,7 @@ class DeepseekV4MegaMoEExperts(nn.Module):
raise NotImplementedError(
"DeepSeek V4 MegaMoE expert weights must be loaded on CUDA."
)
if torch.cuda.get_device_capability(device)[0] != 10:
if torch.cuda.get_device_capability(device)[0] < 10:
raise NotImplementedError("DeepGEMM MegaMoE requires SM100 GPUs.")
if self.hidden_size % 128 != 0 or self.intermediate_size % 128 != 0:
raise ValueError(
@@ -571,41 +648,51 @@ class DeepseekV4MegaMoEExperts(nn.Module):
return
self._check_runtime_supported()
import vllm.third_party.deep_gemm as deep_gemm
from nvfp4_megamoe_kernel import transform_nvfp4_weights_for_mega_moe
w13_scale = deep_gemm.transform_sf_into_required_layout(
self._ue8m0_uint8_to_float(self.w13_weight_scale.data).contiguous(),
2 * self.intermediate_size,
self.hidden_size,
(1, 32),
self.num_local_experts,
)
w2_scale = deep_gemm.transform_sf_into_required_layout(
self._ue8m0_uint8_to_float(self.w2_weight_scale.data).contiguous(),
self.hidden_size,
self.intermediate_size,
(1, 32),
self.num_local_experts,
)
# === Native NVFP4 path ===
# The DeepGEMM nvfp4 mega_moe kernel consumes NVFP4 directly:
# - E2M1 packed uint8 (same as checkpoint)
# - UE4M3 block scales (float8_e4m3fn), group_size=16
# - float32 global scale folded into block scales
# No conversion to MXFP4. Experts stay NVFP4.
# Fold global scales into block scales and transform for the kernel
self._transformed_l1_weights, self._transformed_l2_weights = (
deep_gemm.transform_weights_for_mega_moe(
(self.w13_weight.data.view(torch.int8).contiguous(), w13_scale),
(self.w2_weight.data.view(torch.int8).contiguous(), w2_scale),
transform_nvfp4_weights_for_mega_moe(
(self.w13_weight.data.contiguous(),
self.w13_weight_scale.data.contiguous()),
(self.w2_weight.data.contiguous(),
self.w2_weight_scale.data.contiguous()),
l1_weight_scale_2=self.w13_weight_scale_2.data.contiguous(),
l2_weight_scale_2=self.w2_weight_scale_2.data.contiguous(),
)
)
# Drop the original loader-side parameters: the MegaMoE kernels only
# consume the transformed views above. transform_weights_for_mega_moe
# allocates a fresh tensor for the L1 weight (see _interleave_l1_weights)
# and fresh SF tensors for L1/L2; the L2 weight is the only tensor that
# aliases the original storage, and _transformed_l2_weights still holds
# it, so the storage stays live after we drop the Parameter.
# Drop the original loader-side parameters
self.w13_weight = None
self.w13_weight_scale = None
self.w13_weight_scale_2 = None
self.w13_input_scale = None
self.w2_weight = None
self.w2_weight_scale = None
self.w2_weight_scale_2 = None
self.w2_input_scale = None
@staticmethod
def _ue8m0_to_float32(sf: torch.Tensor) -> torch.Tensor:
"""Convert NVFP4 block scales (float8_e4m3fn / UE4M3) to float32.
Checkpoint stores float8_e4m3fn (standard NVFP4 spec, NOT UE8M0).
Simple .to(float32) is correct — shift-by-23 was wrong (Bug #7 fix).
"""
return sf.to(torch.float32)
def get_symm_buffer(self):
import vllm.third_party.deep_gemm as deep_gemm
import nvfp4_megamoe_kernel as deep_gemm
from nvfp4_megamoe_kernel import SymmBuffer, get_symm_buffer_for_nvfp4_mega_moe
group = get_ep_group().device_group
device = torch.accelerator.current_device_index()
@@ -620,7 +707,8 @@ class DeepseekV4MegaMoEExperts(nn.Module):
)
symm_buffer = self._symm_buffer_cache.get(key)
if symm_buffer is None:
symm_buffer = deep_gemm.get_symm_buffer_for_mega_moe(
# NVFP4 SymmBuffer: 2x SF size due to group_size=16
symm_buffer = get_symm_buffer_for_nvfp4_mega_moe(
group,
self.num_experts,
self.max_num_tokens,
@@ -666,7 +754,7 @@ class DeepseekV4MegaMoEExperts(nn.Module):
activation_clamp: float | None,
fast_math: bool,
) -> None:
import vllm.third_party.deep_gemm as deep_gemm
import nvfp4_megamoe_kernel as deep_gemm
symm_buffer = self.get_symm_buffer()
num_tokens = hidden_states.shape[0]
@@ -680,13 +768,57 @@ class DeepseekV4MegaMoEExperts(nn.Module):
symm_buffer.topk_weights[:num_tokens],
)
# Debug: check staging output
import os
if int(os.environ.get('MEGA_MOE_DEBUG', '0')):
print(f"[MEGA_MOE_DEBUG] After staging: x dtype={symm_buffer.x.dtype} shape={symm_buffer.x.shape}")
print(f"[MEGA_MOE_DEBUG] x_sf dtype={symm_buffer.x_sf.dtype} shape={symm_buffer.x_sf.shape}")
print(f"[MEGA_MOE_DEBUG] topk_idx dtype={symm_buffer.topk_idx.dtype} shape={symm_buffer.topk_idx.shape}")
print(f"[MEGA_MOE_DEBUG] topk_weights dtype={symm_buffer.topk_weights.dtype} shape={symm_buffer.topk_weights.shape}")
# Check for NaN/Inf in the staging output
x_sample = symm_buffer.x[:num_tokens]
sf_sample = symm_buffer.x_sf[:num_tokens]
print(f"[MEGA_MOE_DEBUG] x range: min={x_sample.min().item()} max={x_sample.max().item()}")
if sf_sample.numel() > 0:
print(f"[MEGA_MOE_DEBUG] x_sf range: min={sf_sample.to(torch.float32).min().item()} max={sf_sample.to(torch.float32).max().item}")
topk_sample = symm_buffer.topk_idx[:num_tokens]
print(f"[MEGA_MOE_DEBUG] topk_idx range: min={topk_sample.min().item()} max={topk_sample.max().item()}")
torch.cuda.synchronize()
print("[MEGA_MOE_DEBUG] Staging CUDA sync OK")
# This method must have been already called during the weight loading phase.
# We call it again here to cover the dummy weight loading case.
self.finalize_weights()
assert self._transformed_l1_weights is not None
assert self._transformed_l2_weights is not None
deep_gemm.fp8_fp4_mega_moe(
from nvfp4_megamoe_kernel import nvfp4_mega_moe_full as fp8_nvfp4_mega_moe
# Debug: dump shapes before mega_moe
import os
if int(os.environ.get('MEGA_MOE_DEBUG', '0')):
l1_w, l1_sf = self._transformed_l1_weights
l2_w, l2_sf = self._transformed_l2_weights
print(f"[MEGA_MOE_DEBUG] num_tokens={num_tokens}, hidden={hidden_states.shape[1]}")
print(f"[MEGA_MOE_DEBUG] l1_w: dtype={l1_w.dtype} shape={l1_w.shape} stride={l1_w.stride()}")
print(f"[MEGA_MOE_DEBUG] l1_sf: dtype={l1_sf.dtype} shape={l1_sf.shape} stride={l1_sf.stride()}")
print(f"[MEGA_MOE_DEBUG] l2_w: dtype={l2_w.dtype} shape={l2_w.shape} stride={l2_w.stride()}")
print(f"[MEGA_MOE_DEBUG] l2_sf: dtype={l2_sf.dtype} shape={l2_sf.shape} stride={l2_sf.stride()}")
print(f"[MEGA_MOE_DEBUG] symm_buffer nbytes={symm_buffer.buffer.nbytes} rank={symm_buffer.group.rank()}")
print(f"[MEGA_MOE_DEBUG] num_experts={self.num_experts} topk={topk_ids.shape[1]} max_tokens={self.max_num_tokens}")
print(f"[MEGA_MOE_DEBUG] y: dtype={y.dtype} shape={y.shape}")
# Force CUDA sync to catch any prior async errors
torch.cuda.synchronize()
print("[MEGA_MOE_DEBUG] CUDA sync OK — prior ops clean")
# MEGA_MOE_STATIC: skip the kernel entirely, return zeros
# Tests whether the crash is in the kernel launch or in prior data prep
if int(os.environ.get('MEGA_MOE_STATIC', '0')):
print(f"[MEGA_MOE_STATIC] Skipping fp8_nvfp4_mega_moe, returning zeros")
y.zero_()
return
fp8_nvfp4_mega_moe(
y,
self._transformed_l1_weights,
self._transformed_l2_weights,
@@ -694,6 +826,8 @@ class DeepseekV4MegaMoEExperts(nn.Module):
activation_clamp=activation_clamp,
fast_math=fast_math,
)
if os.environ.get('NVFP4_DEBUG_SYNC', '') == '1':
torch.cuda.synchronize()
DeepseekV4MegaMoEExperts.weight_loader.supports_moe_loading = True # type: ignore[attr-defined]
@@ -751,9 +885,7 @@ class DeepseekV4MoE(nn.Module):
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.prefix = prefix
self.use_mega_moe = (
vllm_config.kernel_config.moe_backend == "deep_gemm_mega_moe"
)
self.use_mega_moe = True # Force mega_moe for NVFP4 pipeline
if self.use_mega_moe and not vllm_config.parallel_config.enable_expert_parallel:
raise NotImplementedError(
"DeepSeek V4 MegaMoE currently requires expert parallel. "
@@ -774,12 +906,7 @@ class DeepseekV4MoE(nn.Module):
raise NotImplementedError(
"DeepSeek V4 MegaMoE currently supports sqrtsoftplus routing only."
)
if self.use_mega_moe and getattr(config, "expert_dtype", "fp4") != "fp4":
raise NotImplementedError(
"DeepSeek V4 MegaMoE only supports fp4 experts; got expert_dtype="
f"{config.expert_dtype!r}. Drop --kernel-config moe_backend="
"deep_gemm_mega_moe for this checkpoint."
)
# NVFP4 experts work with mega_moe via NVFP4→MXFP4 conversion in finalize_weights
self.gate = GateLinear(
config.hidden_size,
@@ -1045,7 +1172,7 @@ class DeepseekV4Attention(nn.Module):
self.rope_parameters = config.rope_scaling
# Initialize rotary embedding BEFORE DeepseekV4MLAModules (which needs it)
rope_parameters = config.rope_parameters
rope_parameters = dict(config.rope_parameters)
rope_parameters["rope_theta"] = (
config.compress_rope_theta if self.compress_ratio > 1 else config.rope_theta
)
@@ -1236,6 +1363,15 @@ class DeepseekV4DecoderLayer(nn.Module):
positions: torch.Tensor,
input_ids: torch.Tensor | None,
) -> torch.Tensor:
# DEBUG: skip attention entirely, just run FFN on raw input
if int(os.environ.get('SKIP_ATTENTION', '0')):
# Flatten to 2D for ffn, then restore
org_shape = x.shape
x_2d = x.view(-1, x.shape[-1])
x_2d = self.ffn_norm(x_2d)
x_2d = self.ffn(x_2d, input_ids)
return x_2d.view(org_shape)
residual = x
x, post, comb = self.hc_pre(
x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base
@@ -1262,9 +1398,7 @@ class DeepseekV4Model(nn.Module):
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.use_mega_moe = (
vllm_config.kernel_config.moe_backend == "deep_gemm_mega_moe"
)
self.use_mega_moe = True # Force mega_moe for NVFP4 pipeline
if self.use_mega_moe and not vllm_config.parallel_config.enable_expert_parallel:
raise NotImplementedError(
"DeepSeek V4 MegaMoE currently requires expert parallel. "
@@ -1461,14 +1595,19 @@ class DeepseekV4Model(nn.Module):
else:
if ".experts." in name:
# E8M0 scales are stored as float8_e8m0fnu in
# checkpoints but the MoE param is uint8. copy_()
# would do a numeric conversion (e.g. 2^-7 → 0),
# destroying the raw exponent bytes.
# MXFP4 checkpoints but NVFP4 uses float8_e4m3fn.
# The uint8 view+copy path is only valid for MXFP4;
# for NVFP4 it would paste raw E8M0 bytes into an
# E4M3 buffer, producing garbage.
if (
"weight_scale" in name
and loaded_weight.dtype == torch.float8_e8m0fnu
):
loaded_weight = loaded_weight.view(torch.uint8)
assert False, (
f"E8M0 weight_scale encountered for NVFP4 experts "
f"({name}) — this is only valid for MXFP4. "
f"Check checkpoint dtype."
)
for mapping in expert_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
@@ -1489,7 +1628,6 @@ class DeepseekV4Model(nn.Module):
name_mapped,
shard_id=shard_id,
expert_id=expert_id,
return_success=True,
)
if success:
name = name_mapped
@@ -1537,16 +1675,10 @@ class DeepseekV4Model(nn.Module):
weight_scale_2_val = global_amax / (6.0 * 448.0)
weight_scale_2 = weight_scale_2_val.to(torch.float32)
# Per-block scale (weight_scale): UE8M0 format
# scale_fmt=ue8m0: block_scale = 2^(exp-127), stored as
# uint8 exponent byte viewed as float8_e4m3fn
# Per-block scale (weight_scale): UE4M3 format (standard NVFP4)
# block_scale = amax / (6.0 * weight_scale_2)
block_scale = amax / (6.0 * weight_scale_2_val)
# Convert to UE8M0: floor to nearest power of 2
# UE8M0 exponent = floor(log2(block_scale)) + 127
block_scale_clamped = block_scale.clamp(min=2**-127)
block_scale_exp = torch.floor(torch.log2(block_scale_clamped)).to(torch.int32) + 127
block_scale_exp = block_scale_exp.clamp(0, 254).to(torch.uint8)
weight_scale = block_scale_exp.view(torch.float8_e4m3fn)
weight_scale = block_scale.clamp(0.0, 448.0).to(torch.float8_e4m3fn)
# Quantize to FP4 (E2M1)
# E2M1 LUT: 0, 0.5, 1, 1.5, 2, 3, 4, 6 (positive)
@@ -1554,10 +1686,8 @@ class DeepseekV4Model(nn.Module):
[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0],
dtype=torch.float32, device=w_bf16.device,
)
# For each block, dequantize the block scale from UE8M0
block_scale_f32 = (block_scale_exp.to(torch.int32) << 23).view(torch.float32)
# Scale the weight values: normalized = w / (block_scale * weight_scale_2)
# We need to find the nearest FP4 value
block_scale_f32 = block_scale.clamp(0.0, 448.0)
scaled = w_blocks / (block_scale_f32.unsqueeze(-1) * weight_scale_2_val)
# Find nearest FP4 index (0-7 for magnitude)
# Use absolute value for matching, then apply sign
@@ -1575,7 +1705,7 @@ class DeepseekV4Model(nn.Module):
even = fp4_flat[:, 0::2] # lower nibble
odd = fp4_flat[:, 1::2] # upper nibble
packed = (odd << 4) | even
weight_packed = packed.to(torch.uint8)
weight_packed = packed.to(torch.uint8).view(torch.int8)
# Reshape weight_scale to [out, n_blocks]
weight_scale_2d = weight_scale.reshape(out_dim, n_blocks)
@@ -1647,8 +1777,8 @@ class DeepseekV4Model(nn.Module):
.forward() which goes through quant_method; FP8 would dtype-mismatch)
- compressor.fused_wkv_wgate: Dequant NVFP4->bf16 (used via direct
torch.mm in attention parallel stream)
- shared_experts (gate_up_proj, down_proj): Dequant NVFP4->bf16
- MoE experts: Stay in native NVFP4 (ModelOptNvFp4FusedMoE)
- shared_experts (gate_up_proj, down_proj): Stay native NVFP4 via DeepGEMM fp8_fp4_gemm
- MoE experts: Handled by DeepseekV4MegaMoEExperts (NVFP4→MXFP4)
"""
E2M1_LUT = torch.tensor(
[0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16
@@ -1659,14 +1789,19 @@ class DeepseekV4Model(nn.Module):
# for fp8_einsum. Only layer that needs FP8 conversion.
fp8_proj_names = {"wo_a"}
# Attention layers called via .forward() — need bf16
bf16_proj_names = {"fused_wqa_wkv", "wq_b", "wo_b"}
# Shared expert layers called via .forward() — need bf16
bf16_shared_names = {"gate_up_proj", "down_proj"}
# cuBLAS BF16 is broken on Blackwell — nothing gets dequantized to BF16.
# Everything stays native NVFP4/FP8 via FlashInfer CUTLASS.
bf16_proj_names = set()
bf16_shared_names = set()
fp8_converted = 0
fp8_from_bf16 = 0
bf16_converted = 0
compressor_converted = 0
# Build shard index once for compressor reconstruction (avoids N×M full-shard loads)
_shard_index = self._build_shard_index("/model") if os.path.isdir("/model") else None
for layer_idx, layer in enumerate(self.layers):
attn = layer.attn
@@ -1677,13 +1812,11 @@ class DeepseekV4Model(nn.Module):
mod = getattr(attn, proj_name)
if not hasattr(mod, "weight"):
continue
if mod.weight.dtype == torch.uint8:
if mod.weight.dtype in (torch.uint8, torch.int8):
# NVFP4 -> dequant to bf16 -> requant to FP8
self._convert_nvfp4_to_fp8(mod, E2M1_LUT, FP8_MAX)
fp8_converted += 1
elif mod.weight.dtype == torch.bfloat16:
# modelopt did NOT quantize o_a_proj — it's bf16 already.
# Convert bf16 -> FP8 directly for fp8_einsum path.
self._convert_bf16_to_fp8(mod, FP8_MAX)
fp8_from_bf16 += 1
@@ -1692,7 +1825,7 @@ class DeepseekV4Model(nn.Module):
if not hasattr(attn, proj_name):
continue
mod = getattr(attn, proj_name)
if not hasattr(mod, "weight") or mod.weight.dtype != torch.uint8:
if not hasattr(mod, "weight") or mod.weight.dtype not in (torch.uint8, torch.int8):
continue
self._dequant_nvfp4_to_bf16(mod, E2M1_LUT)
bf16_converted += 1
@@ -1710,23 +1843,23 @@ class DeepseekV4Model(nn.Module):
compressor = getattr(mla_attn, "compressor", None)
if compressor is not None and hasattr(compressor, "fused_wkv_wgate"):
compressor_converted += self._reconstruct_compressor_weight(
compressor.fused_wkv_wgate, attn, layer_idx, E2M1_LUT)
compressor.fused_wkv_wgate, attn, layer_idx, E2M1_LUT, _shard_index=_shard_index)
# Indexer compressor (C4A layers only)
indexer = getattr(mla_attn, "indexer", None)
if indexer is not None:
idx_compressor = getattr(indexer, "compressor", None)
if idx_compressor is not None and hasattr(idx_compressor, "fused_wkv_wgate"):
compressor_converted += self._reconstruct_compressor_weight(
idx_compressor.fused_wkv_wgate, indexer, layer_idx, E2M1_LUT, sub_path=".indexer")
idx_compressor.fused_wkv_wgate, indexer, layer_idx, E2M1_LUT, sub_path=".indexer", _shard_index=_shard_index)
# Shared experts
# Shared experts: dequantize NVFP4 → BF16
ffn = layer.ffn
if hasattr(ffn, "shared_experts") and ffn.shared_experts is not None:
for proj_name in bf16_shared_names:
if not hasattr(ffn.shared_experts, proj_name):
continue
mod = getattr(ffn.shared_experts, proj_name)
if not hasattr(mod, "weight") or mod.weight.dtype != torch.uint8:
if not hasattr(mod, "weight") or mod.weight.dtype not in (torch.uint8, torch.int8):
continue
self._dequant_nvfp4_to_bf16(mod, E2M1_LUT)
bf16_converted += 1
@@ -1737,8 +1870,7 @@ class DeepseekV4Model(nn.Module):
print(f"NVFP4 post-load: {fp8_converted} NVFP4->FP8, "
f"{fp8_from_bf16} BF16->FP8, "
f"{bf16_converted} attn/shared->BF16, "
f"{compressor_converted} compressor->BF16, "
f"MoE experts stay NVFP4")
f"{compressor_converted} compressor->BF16")
def _dequant_nvfp4_to_bf16(self, mod, e2m1_lut):
@@ -1749,9 +1881,8 @@ class DeepseekV4Model(nn.Module):
# Dequantize with scales
if hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2"):
# scale_fmt=ue8m0: weight_scale bytes are E8M0 format (power-of-2 only).
# A simple .to(float32) misinterprets them as E4M3. Must reinterpret
# the raw uint8 bits as IEEE 754 exponent field.
# NVFP4 block scales are float8_e4m3fn (UE4M3) — standard spec.
# .to(float32) is correct (Bug #7: shift-by-23 was wrong, reverted)
block_scale = self._ue8m0_to_float32(mod.weight_scale.data)
if block_scale.dim() == 2 and w_bf16.dim() == 2:
block_size = w_bf16.shape[1] // block_scale.shape[1]
@@ -1773,8 +1904,10 @@ class DeepseekV4Model(nn.Module):
else:
w_dequant = w_bf16
# Replace weight with bf16 version
# Free source tensors eagerly to avoid holding uint8+bf16+fp32 simultaneously
del w_uint8, w_bf16
mod.weight = torch.nn.Parameter(w_dequant, requires_grad=False)
del w_dequant
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
mod.quant_method = UnquantizedLinearMethod()
for attr in ("weight_scale", "weight_scale_2", "input_scale",
@@ -1794,7 +1927,7 @@ class DeepseekV4Model(nn.Module):
# Dequantize with scales
if hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2"):
# scale_fmt=ue8m0: reinterpret E8M0 bytes as float32
# NVFP4 block scales: float8_e4m3fn → .to(float32) (Bug #7 reverted)
block_scale = self._ue8m0_to_float32(mod.weight_scale.data)
if block_scale.dim() == 2 and w_bf16.dim() == 2:
block_size = w_bf16.shape[1] // block_scale.shape[1]
@@ -1857,21 +1990,38 @@ class DeepseekV4Model(nn.Module):
bmm_batch_size=bmm_batch_size,
)
# Free source tensors eagerly
del w_uint8, w_bf16, w_dequant
mod.weight = torch.nn.Parameter(w_fp8, requires_grad=False)
del w_fp8
# weight_scale_inv is what the attention runtime reads as b_scale
# for deepseek_v4_fp8_einsum -> DeepGEMM fp8_einsum.
# It must be the DeepGEMM-formatted block scale (dg_ws), NOT the
# per-tensor scalar. See: deepseek_v4_attention.py line 319.
mod.weight_scale_inv = torch.nn.Parameter(ws, requires_grad=False)
# weight_scale is not used at runtime for BMM layers; remove it
# to avoid confusing other code paths.
del ws
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
mod.quant_method = UnquantizedLinearMethod()
for attr in ("weight_scale", "weight_scale_2", "input_scale"):
if hasattr(mod, attr):
delattr(mod, attr)
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
mod.quant_method = UnquantizedLinearMethod()
def _reconstruct_compressor_weight(self, fused_mod, parent_mod, layer_idx, e2m1_lut, sub_path=""):
@staticmethod
def _build_shard_index(ckpt_dir: str) -> dict[str, str]:
"""Build key→shard_path index from safetensors metadata (no tensor I/O)."""
import glob
from safetensors import safe_open
index = {}
for shard_file in sorted(glob.glob(os.path.join(ckpt_dir, "model-*.safetensors"))):
try:
with safe_open(shard_file, framework="pt") as f:
for key in f.keys():
index[key] = shard_file
except Exception:
continue
return index
def _reconstruct_compressor_weight(self, fused_mod, parent_mod, layer_idx, e2m1_lut, sub_path="", _shard_index=None):
"""Reconstruct compressor fused_wkv_wgate from checkpoint.
Compressor weights are SKIPPED during loading because NVFP4 uint8 data
@@ -1879,8 +2029,7 @@ class DeepseekV4Model(nn.Module):
We read the original uint8 data from the safetensors checkpoint, unpack
E2M1, dequantize, and stack into the fused weight param.
"""
import glob
from safetensors.torch import load_file
from safetensors import safe_open
# Find the checkpoint directory
# The model weights are mounted at /model in Docker
@@ -1895,49 +2044,45 @@ class DeepseekV4Model(nn.Module):
# We read from checkpoint (before mapper), so use original names
layer_prefix = f"model.layers.{layer_idx}.self_attn.compressor{sub_path}"
# Find which shard contains this layer's compressor weights
wkv_key = f"{layer_prefix}.kv_proj.weight"
wgate_key = f"{layer_prefix}.gate_proj.weight"
wkv_scale_key = f"{layer_prefix}.kv_proj.weight_scale"
wgate_scale_key = f"{layer_prefix}.gate_proj.weight_scale"
wkv_scale2_key = f"{layer_prefix}.kv_proj.weight_scale_2"
wgate_scale2_key = f"{layer_prefix}.gate_proj.weight_scale_2"
wkv_iscale_key = f"{layer_prefix}.kv_proj.input_scale"
wgate_iscale_key = f"{layer_prefix}.gate_proj.input_scale"
# All keys we need from the checkpoint
keys = {
'wkv_uint8': f"{layer_prefix}.kv_proj.weight",
'wgate_uint8': f"{layer_prefix}.gate_proj.weight",
'wkv_block_scale': f"{layer_prefix}.kv_proj.weight_scale",
'wgate_block_scale': f"{layer_prefix}.gate_proj.weight_scale",
'wkv_global_scale': f"{layer_prefix}.kv_proj.weight_scale_2",
'wgate_global_scale': f"{layer_prefix}.gate_proj.weight_scale_2",
'wkv_input_scale': f"{layer_prefix}.kv_proj.input_scale",
'wgate_input_scale': f"{layer_prefix}.gate_proj.input_scale",
}
# Load from safetensors
wkv_uint8 = None
wgate_uint8 = None
wkv_block_scale = None
wgate_block_scale = None
wkv_global_scale = None
wgate_global_scale = None
wkv_input_scale = None
wgate_input_scale = None
shard_files = sorted(glob.glob(os.path.join(ckpt_dir, "model-*.safetensors")))
for shard_file in shard_files:
# Read tensors using shard index for targeted access (no full-shard loads)
tensors = {}
for name, key in keys.items():
shard_path = (_shard_index or {}).get(key)
if shard_path is None:
continue
try:
shard_data = load_file(shard_file)
with safe_open(shard_path, framework="pt") as f:
if key in f.keys():
tensors[name] = f.get_tensor(key)
except Exception:
continue
if wkv_key in shard_data:
wkv_uint8 = shard_data[wkv_key]
wkv_block_scale = shard_data.get(wkv_scale_key)
wkv_global_scale = shard_data.get(wkv_scale2_key)
wkv_input_scale = shard_data.get(wkv_iscale_key)
if wgate_key in shard_data:
wgate_uint8 = shard_data[wgate_key]
wgate_block_scale = shard_data.get(wgate_scale_key)
wgate_global_scale = shard_data.get(wgate_scale2_key)
wgate_input_scale = shard_data.get(wgate_iscale_key)
if wkv_uint8 is not None and wgate_uint8 is not None:
break
wkv_uint8 = tensors.get('wkv_uint8')
wgate_uint8 = tensors.get('wgate_uint8')
if wkv_uint8 is None or wgate_uint8 is None:
# Layer might not have a compressor (compress_ratio=1 layers)
return 0
wkv_block_scale = tensors.get('wkv_block_scale')
wgate_block_scale = tensors.get('wgate_block_scale')
wkv_global_scale = tensors.get('wkv_global_scale')
wgate_global_scale = tensors.get('wgate_global_scale')
wkv_input_scale = tensors.get('wkv_input_scale')
wgate_input_scale = tensors.get('wgate_input_scale')
device = fused_mod.weight.device
wkv_uint8 = wkv_uint8.to(device)
wgate_uint8 = wgate_uint8.to(device)
@@ -1949,7 +2094,7 @@ class DeepseekV4Model(nn.Module):
# Dequantize with scales
def _dequant(w_bf16, block_scale, global_scale, input_scale):
if block_scale is not None and global_scale is not None:
# scale_fmt=ue8m0: reinterpret E8M0 bytes as float32
# NVFP4 block scales: float8_e4m3fn → .to(float32) (Bug #7 reverted)
block_scale = self._ue8m0_to_float32(block_scale.to(device))
if block_scale.dim() == 2 and w_bf16.dim() == 2:
block_size = w_bf16.shape[1] // block_scale.shape[1]
@@ -1972,8 +2117,6 @@ class DeepseekV4Model(nn.Module):
# fused_wkv_wgate.weight = cat([wkv, wgate], dim=0) → (2*head_dim, hidden_size)
w_fused = torch.cat([wkv_dequant, wgate_dequant], dim=0)
# DEBUG: log shapes to diagnose compressor weight mismatch
print(f"NVFP4 compressor layer {layer_idx}: wkv={wkv_dequant.shape}, wgate={wgate_dequant.shape}, fused={w_fused.shape}, existing_param={fused_mod.weight.shape}")
# Replace the weight
fused_mod.weight = torch.nn.Parameter(w_fused, requires_grad=False)
@@ -2041,17 +2184,12 @@ class DeepseekV4Model(nn.Module):
@staticmethod
def _ue8m0_to_float32(sf: torch.Tensor) -> torch.Tensor:
"""Convert UE8M0 (E8M0 power-of-2) scale bytes to float32.
"""Convert NVFP4 block scales (float8_e4m3fn / UE4M3) to float32.
NVFP4 checkpoints with scale_fmt=ue8m0 store per-block weight scales as
E8M0 format (8-bit exponent, no mantissa). The value = 2^(raw_byte - 127).
The bytes are loaded as float8_e4m3fn by safetensors, but a simple
.to(float32) misinterprets them as E4M3 (which has mantissa bits).
Correct conversion: place the raw uint8 bits into the exponent field
of an IEEE 754 float32 (bits 23-30), yielding 2^(raw-127) * implicit_1.
Checkpoint stores float8_e4m3fn (standard NVFP4 spec, NOT UE8M0).
Simple .to(float32) is correct — shift-by-23 was wrong (Bug #7 fix).
"""
raw_uint8 = sf.view(torch.uint8)
return (raw_uint8.to(torch.int32) << 23).view(torch.float32)
return sf.to(torch.float32)
def _unpack_nvfp4_to_bf16(self, w_uint8, e2m1_lut, device):
"""Unpack NVFP4 uint8 packed weights to bf16 using E2M1 format."""
@@ -2269,6 +2407,9 @@ class DeepseekV4ForCausalLM(nn.Module):
loaded_params = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
self.model.finalize_mega_moe_weights()
self.model._convert_nvfp4_post_load()
if os.environ.get('NVFP4_DEBUG_SYNC', '') == '1':
torch.cuda.synchronize()
print("[NVFP4] post-load conversion done, CUDA OK")
return loaded_params
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:

File diff suppressed because it is too large Load Diff

270
patches/staging_kernel.py Normal file
View File

@@ -0,0 +1,270 @@
"""
NVFP4 staging kernel — full FP4 (E2M1) activations + UE4M3 block16 scales.
The mxf4nvf4 PTX instruction requires BOTH A and B to be FP4 (E2M1 packed).
This kernel quantizes BF16 activations → E2M1 packed uint8 with UE4M3 scales.
"""
import triton
import triton.language as tl
import torch
@triton.jit
def _deepseek_v4_stage_mega_moe_inputs_kernel(
hidden_states,
x_fp4, # uint8, shape (M, K//2) — E2M1 packed, 2 values per byte
x_sf, # int32, shape (M, K//64) — UE4M3 packed, 4 scales per int32
topk_ids,
topk_weights,
topk_idx_out,
topk_weights_out,
hidden_stride_m: tl.constexpr,
hidden_stride_k: tl.constexpr,
x_stride_m: tl.constexpr,
x_stride_k: tl.constexpr,
x_sf_stride_m: tl.constexpr,
x_sf_stride_k: tl.constexpr,
topk_ids_stride_m: tl.constexpr,
topk_ids_stride_k: tl.constexpr,
topk_weights_stride_m: tl.constexpr,
topk_weights_stride_k: tl.constexpr,
topk_idx_stride_m: tl.constexpr,
topk_idx_stride_k: tl.constexpr,
topk_weights_out_stride_m: tl.constexpr,
topk_weights_out_stride_k: tl.constexpr,
hidden_size: tl.constexpr,
top_k: tl.constexpr,
BLOCK_K: tl.constexpr, # 128 elements (loaded from hidden)
GROUP_K: tl.constexpr, # 16 (NVFP4 group_size)
BLOCK_TOPK: tl.constexpr,
) -> None:
token_id = tl.program_id(0)
k_block_id = tl.program_id(1)
k_offsets = k_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
k_mask = k_offsets < hidden_size
hidden = tl.load(
hidden_states + token_id * hidden_stride_m + k_offsets * hidden_stride_k,
mask=k_mask,
other=0.0,
).to(tl.float32)
num_groups: tl.constexpr = BLOCK_K // GROUP_K # 8
hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K])
abs_groups = tl.reshape(tl.abs(hidden), [num_groups, GROUP_K])
amax = tl.max(abs_groups, axis=1)
amax = tl.maximum(amax, 1.0e-4)
# ---- UE4M3 scale computation ----
# scale = amax / 6.0 (E2M1 max value = 6)
# Then quantize scale to UE4M3 format
scale = amax / 6.0
scale_bits = scale.to(tl.uint32, bitcast=True)
scale_exp = (scale_bits >> 23) & 0xFF
scale_mant = scale_bits & 0x7FFFFF
# Convert FP32 → E4M3 manually (with subnormal support)
# FP32 bias=127, E4M3 bias=7 → raw exp = scale_exp - 120
e4m3_exp_raw = scale_exp - 120 # can be negative → subnormal
# Normal path: exp >= 1, just truncate mantissa top 3 bits
# RNE rounding: need guard (bit 19), sticky (OR of bits 18:0), and LSB of result
normal_mant = scale_mant >> 20
guard_bit = (scale_mant >> 19) & 1
sticky_bit = tl.where((scale_mant & 0x7FFFF) != 0, 1, 0) # OR of bits [18:0]
result_lsb = normal_mant & 1
# RNE: round up if (guard=1 and sticky=1) or (guard=1 and sticky=0 and lsb=1)
round_up = guard_bit & (sticky_bit | result_lsb)
normal_mant = normal_mant + round_up
normal_exp = e4m3_exp_raw
# Subnormal path: exp_raw <= 0
# Insert implicit leading 1 and right-shift by (1 - exp_raw)
# E4M3 subnormal: value = (mant/8) * 2^(1-7) = (mant/8) * 2^-6
# So we need: (1 + mant_fp32/2^23) * 2^(exp_raw - 7) = (shifted_mant/8) * 2^-6
# shifted_mant = (implicit_1 | mant_fp32) >> (1 - exp_raw - 1) then take top 3 bits
shift = 1 - e4m3_exp_raw # positive when subnormal
mant_with_leading = (0x800000 | scale_mant) # insert implicit 1
# Right-shift to get into the 3-bit E4M3 mantissa window
# We want bits [shift+19 : shift+23) of mant_with_leading for 3 mantissa bits + 1 round bit
subnormal_mant = (mant_with_leading >> (shift.to(tl.int32) + 20)) & 0x7
sub_guard_bit = (mant_with_leading >> (shift.to(tl.int32) + 19)) & 1
# Sticky: OR of all bits below the guard bit in the shifted result
# shift ≤ 8 in practice (amax floor = 1e-4 → scale ≈ 2^-15 → exp_raw ≈ -7), so mask ≤ 2^27
sub_sticky_mask = (1 << (shift.to(tl.int32) + 19)) - 1
sub_sticky_bit = tl.where((mant_with_leading & sub_sticky_mask) != 0, 1, 0)
sub_result_lsb = subnormal_mant & 1
sub_round_up = sub_guard_bit & (sub_sticky_bit | sub_result_lsb)
subnormal_mant = subnormal_mant + sub_round_up
is_normal = e4m3_exp_raw >= 1
e4m3_mant = tl.where(is_normal, normal_mant, subnormal_mant)
e4m3_exp = tl.where(is_normal, normal_exp, 0) # exp=0 for subnormals
# Handle mantissa overflow after rounding
overflow = e4m3_mant >= 8
e4m3_mant = tl.where(overflow, 0, e4m3_mant)
e4m3_exp = tl.where(overflow, e4m3_exp + 1, e4m3_exp)
e4m3_exp = tl.maximum(e4m3_exp, 0)
e4m3_exp = tl.minimum(e4m3_exp, 15)
scale_e4m3_bits = (e4m3_exp << 3) | e4m3_mant
# Reconstruct dequantized scale by decoding the STORED E4M3 bits.
# This guarantees the E2M1 quantization divides by exactly the value
# the CUDA kernel will multiply back — same bits, single decode, no
# possibility of encode/decode disagreement.
stored_exp = (scale_e4m3_bits >> 3) & 0xF
stored_mant = scale_e4m3_bits & 0x7
e4m3_exp_for_recon = tl.maximum(stored_exp.to(tl.int32) - 7, -126)
two_pow_exp_bits = (e4m3_exp_for_recon + 127).to(tl.uint32) << 23
two_pow_exp = two_pow_exp_bits.to(tl.float32, bitcast=True)
normal_value = (1.0 + stored_mant.to(tl.float32) / 8.0) * two_pow_exp
subnormal_value = (stored_mant.to(tl.float32) / 8.0) * 0.015625
e4m3_value = tl.where(stored_exp == 0, subnormal_value, normal_value)
# ---- E2M1 FP4 quantization (unpacked, 1 byte/element) ----
# E2M1 LUT (unsigned): [0, 0.5, 1, 1.5, 2, 3, 4, 6]
# Nearest-neighbor using thresholds (midpoints between consecutive values)
scaled = hidden_groups * (1.0 / tl.maximum(e4m3_value, 1e-6))[:, None]
# Clamp to E2M1 range [-6, 6]
scaled = tl.maximum(scaled, -6.0)
scaled = tl.minimum(scaled, 6.0)
abs_s = tl.abs(scaled)
# Thresholds: midpoints between [0, 0.5, 1, 1.5, 2, 3, 4, 6]
# [0, 0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0, INF]
e2m1_idx = tl.where(abs_s < 0.25, 0,
tl.where(abs_s < 0.75, 1,
tl.where(abs_s < 1.25, 2,
tl.where(abs_s < 1.75, 3,
tl.where(abs_s < 2.5, 4,
tl.where(abs_s < 3.5, 5,
tl.where(abs_s < 5.0, 6, 7)))))))
sign_bit = (scaled < 0).to(tl.int32)
e2m1_4bit = (sign_bit << 3) | e2m1_idx # 4-bit: (sign << 3) | index
# Pack E2M1 pairs into single bytes (2 per byte, low nibble first)
# mxf4nvf4 reads FP4 packed from SMEM — must match kernel's TMA layout
e2m1_flat = tl.reshape(e2m1_4bit, [BLOCK_K])
e2m1_lo = e2m1_flat[0::2] # even indices → low nibble
e2m1_hi = e2m1_flat[1::2] # odd indices → high nibble
e2m1_packed = (e2m1_hi << 4 | e2m1_lo).to(tl.uint8) # [BLOCK_K // 2]
k_offsets_out = k_block_id * (BLOCK_K // 2) + tl.arange(0, BLOCK_K // 2)
k_mask_out = k_offsets_out < (hidden_size // 2)
tl.store(
x_fp4 + token_id * x_stride_m + k_offsets_out * x_stride_k,
e2m1_packed,
mask=k_mask_out,
)
# Pack UE4M3 bytes into int32 (NVFP4: group_size=16, 4 groups per 64 elements)
# 8 groups per k_block of 128 → 2 int32s per k_block
# int32 can only pack 4 bytes (shifts >= 32 are UB), so split into two packs
scale_offsets = tl.arange(0, num_groups) # [0..7]
first_half = scale_offsets < 4 # groups 0-3 → int32[0]
second_half = scale_offsets >= 4 # groups 4-7 → int32[1]
packed_lo = tl.sum(
tl.where(first_half, scale_e4m3_bits.to(tl.int32) << (scale_offsets * 8), 0),
axis=0,
).to(tl.int32)
packed_hi = tl.sum(
tl.where(second_half, scale_e4m3_bits.to(tl.int32) << ((scale_offsets - 4) * 8), 0),
axis=0,
).to(tl.int32)
# Write 2 int32s per k_block: x_sf shape is (M, K//64) = (M, num_k_blocks * 2)
sf_base = token_id * x_sf_stride_m + k_block_id * 2 * x_sf_stride_k
tl.store(x_sf + sf_base, packed_lo)
tl.store(x_sf + sf_base + x_sf_stride_k, packed_hi)
if k_block_id == 0:
topk_offsets = tl.arange(0, BLOCK_TOPK)
topk_mask = topk_offsets < top_k
ids = tl.load(
topk_ids + token_id * topk_ids_stride_m + topk_offsets * topk_ids_stride_k,
mask=topk_mask,
other=0,
).to(tl.int64)
tl.store(
topk_idx_out
+ token_id * topk_idx_stride_m
+ topk_offsets * topk_idx_stride_k,
ids,
mask=topk_mask,
)
weights = tl.load(
topk_weights
+ token_id * topk_weights_stride_m
+ topk_offsets * topk_weights_stride_k,
mask=topk_mask,
other=0.0,
)
tl.store(
topk_weights_out
+ token_id * topk_weights_out_stride_m
+ topk_offsets * topk_weights_out_stride_k,
weights,
mask=topk_mask,
)
def _stage_deepseek_v4_mega_moe_inputs(
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
x_fp4: torch.Tensor, # uint8, shape (M, K//2)
x_sf: torch.Tensor, # int32, shape (M, K//64)
topk_idx_out: torch.Tensor,
topk_weights_out: torch.Tensor,
) -> None:
num_tokens, hidden_size = hidden_states.shape
if num_tokens == 0:
return
if hidden_size % 128 != 0:
raise ValueError(
"DeepSeek V4 MegaMoE input staging requires hidden_size to be "
"a multiple of 128."
)
top_k = topk_ids.shape[1]
if topk_weights.shape != topk_ids.shape:
raise ValueError(
"DeepSeek V4 MegaMoE input staging requires topk_weights and "
"topk_ids to have the same shape."
)
block_k = 128
grid = (num_tokens, triton.cdiv(hidden_size, block_k))
block_topk = triton.next_power_of_2(top_k)
_deepseek_v4_stage_mega_moe_inputs_kernel[grid](
hidden_states,
x_fp4,
x_sf,
topk_ids,
topk_weights,
topk_idx_out,
topk_weights_out,
hidden_states.stride(0),
hidden_states.stride(1),
x_fp4.stride(0),
x_fp4.stride(1),
x_sf.stride(0),
x_sf.stride(1),
topk_ids.stride(0),
topk_ids.stride(1),
topk_weights.stride(0),
topk_weights.stride(1),
topk_idx_out.stride(0),
topk_idx_out.stride(1),
topk_weights_out.stride(0),
topk_weights_out.stride(1),
hidden_size,
top_k,
BLOCK_K=block_k,
GROUP_K=16, # NVFP4: group_size=16 (scale_vec::4X)
BLOCK_TOPK=block_topk,
num_warps=4,
)