Compare commits
2 Commits
mega-moe-n
...
modelopt-n
| Author | SHA1 | Date | |
|---|---|---|---|
| 0c77a88757 | |||
| f2656dcf6d |
54
Dockerfile
Normal file
54
Dockerfile
Normal file
@@ -0,0 +1,54 @@
|
||||
# DeepSeek V4 NVFP4 vLLM + CUTLASS NVFP4 Mega MoE Kernel
#
# The base image defaults to the floating nightly tag. For a reproducible
# build, override VLLM_BASE_IMAGE with a digest-pinned reference, e.g.
#   docker build --build-arg VLLM_BASE_IMAGE=vllm/vllm-openai@sha256:<digest> .
ARG VLLM_BASE_IMAGE=vllm/vllm-openai:nightly-x86_64
FROM ${VLLM_BASE_IMAGE}

# Remove broken nixl_ep (built against CUDA 12, image is CUDA 13).
# The uninstall is best-effort (its exit status is deliberately ignored);
# the directory removal below is what guarantees the package is gone.
RUN (pip uninstall -y nixl-ep || true) && \
    rm -rf /usr/local/lib/python3.12/dist-packages/nixl_ep

# Build tooling and CUDA 13 math-library headers needed by the CUTLASS extension.
RUN apt-get update && apt-get install -y \
        cmake \
        git \
        libcublas-dev-13-0 \
        libcufft-dev-13-0 \
        libcurand-dev-13-0 \
        libcusolver-dev-13-0 \
        libcusparse-dev-13-0 \
        libnvjitlink-dev-13-0 \
        screen \
    && rm -rf /var/lib/apt/lists/*

# Remove the broken symlink if it exists
RUN rm -f /usr/local/cuda/lib64/libcudart.so.12

ENV CUDA_HOME=/usr/local/cuda \
    TORCH_CUDA_ARCH_LIST="10.0"

# Clone CUTLASS (has NVFP4 block-scaled MMA support).
# CUTLASS_REF accepts a branch or tag name; the default tracks the latest main.
# Bump CUTLASS_CACHE_BUSTER to force a fresh clone of a moving ref.
ARG CUTLASS_CACHE_BUSTER=1
ARG CUTLASS_REF=main
RUN git clone --depth 1 --branch "${CUTLASS_REF}" \
        https://github.com/NVIDIA/cutlass.git /root/cutlass

# Clone our NVFP4 mega_moe kernel and install it in editable mode.
# Bump KERNEL_CACHE_BUSTER to force a fresh clone.
ARG KERNEL_CACHE_BUSTER=40
RUN git clone https://sweetapi.com/biondizzle/nvfp4-megamoe-kernel.git /root/nvfp4-megamoe-kernel && \
    pip install --no-cache-dir -e /root/nvfp4-megamoe-kernel

# Build the CUTLASS NVFP4 block-scaled GEMM extension.
# `cd` is used instead of WORKDIR on purpose: WORKDIR would change the working
# directory of the final image and of every container started from it.
RUN cd /root/nvfp4-megamoe-kernel/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm && \
    mkdir -p cutlass_nvfp4_gemm && \
    CUTLASS_INCLUDE_DIR=/root/cutlass/include \
    TORCH_CUDA_ARCH_LIST=10.0 \
    python3 setup.py build_ext --inplace

# Install TileLang (for potential future use)
RUN pip install --no-cache-dir tilelang

# Append any inherited PYTHONPATH only when it is non-empty. An unconditional
# ":${PYTHONPATH}" leaves a trailing ':' when the variable is unset, which
# Python treats as an empty entry and resolves to the current directory.
ENV PYTHONPATH="/root/nvfp4-megamoe-kernel/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm:/root/nvfp4-megamoe-kernel${PYTHONPATH:+:${PYTHONPATH}}"

# Copy patches.
# COPY is cached by file content; PATCH_CACHE_BUSTER additionally invalidates
# the RUN step below when bumped.
ARG PATCH_CACHE_BUSTER=82
COPY patches/deepseek_v4.py \
     patches/staging_kernel.py \
     patches/deepseek_v4_attention.py \
     /tmp/patches/

# Apply patches: overwrite the installed vLLM modules in place. The target
# directories are resolved from the installed package rather than hard-coded.
RUN VLLM_MODELS_DIR=$(python3 -c "import vllm.model_executor.models; import os; print(os.path.dirname(vllm.model_executor.models.__file__))") && \
    VLLM_LAYERS_DIR=$(python3 -c "import vllm.model_executor.layers; import os; print(os.path.dirname(vllm.model_executor.layers.__file__))") && \
    cp /tmp/patches/deepseek_v4.py "$VLLM_MODELS_DIR/deepseek_v4.py" && \
    cp /tmp/patches/staging_kernel.py "$VLLM_MODELS_DIR/staging_kernel.py" && \
    cp /tmp/patches/deepseek_v4_attention.py "$VLLM_LAYERS_DIR/deepseek_v4_attention.py" && \
    rm -rf /tmp/patches

# Verify: fail the build early if any of the three pieces cannot be imported.
RUN python3 -c "import torch; import cutlass_nvfp4_gemm._C; print('CUTLASS NVFP4 OK')" && \
    python3 -c "import vllm; print('vLLM OK')" && \
    python3 -c "import nvfp4_megamoe_kernel; print('NVFP4 kernel OK')"
|
||||
@@ -1,30 +1,22 @@
|
||||
services:
|
||||
vllm:
|
||||
image: atl.vultrcr.com/vllm/vllm-with-lmcache:dream-build
|
||||
pull_policy: always
|
||||
entrypoint:
|
||||
- bash
|
||||
- -c
|
||||
- |
|
||||
cp /patches/deepseek_v4.py /usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v4.py
|
||||
exec vllm serve "$$@"
|
||||
- --
|
||||
build:
|
||||
context: .
|
||||
ports:
|
||||
- "8000:8000"
|
||||
environment:
|
||||
# Injected from the host environment or an untracked .env file. Never commit the
# token itself; the value previously hard-coded here is exposed and must be revoked.
- HF_TOKEN=${HF_TOKEN}
|
||||
- OMP_NUM_THREADS=128
|
||||
- MEGA_MOE_DEBUG=1
|
||||
- MEGA_MOE_STATIC=0
|
||||
- MEGA_MOE_USE_CUTLASS=1
|
||||
- DG_JIT_DEBUG=1
|
||||
command:
|
||||
- /model
|
||||
- --trust-remote-code
|
||||
- --kv-cache-dtype=fp8
|
||||
- --block-size=256
|
||||
- --enable-expert-parallel
|
||||
- --tensor-parallel-size=8
|
||||
- --compilation-config={"cudagraph_mode":"FULL_AND_PIECEWISE","custom_ops":["all"]}
|
||||
- --attention_config.use_fp4_indexer_cache=True
|
||||
- --enforce-eager
|
||||
- --tokenizer-mode=deepseek_v4
|
||||
- --tool-call-parser=deepseek_v4
|
||||
- --enable-auto-tool-choice
|
||||
- --reasoning-parser=deepseek_v4
|
||||
- --speculative_config={"method":"mtp","num_speculative_tokens":2}
|
||||
- --host=0.0.0.0
|
||||
- --port=8000
|
||||
deploy:
|
||||
@@ -34,13 +26,5 @@ services:
|
||||
- driver: nvidia
|
||||
count: all
|
||||
capabilities: [gpu]
|
||||
ipc: host
|
||||
security_opt:
|
||||
- seccomp:unconfined
|
||||
tty: true
|
||||
stdin_open: true
|
||||
volumes:
|
||||
- /root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4:/model:ro
|
||||
- /root/nvidia-meeting/deepseek-v4-quant/patches/deepseek_v4.py:/patches/deepseek_v4.py:ro
|
||||
- /root/nvidia-meeting/deepseek-v4-quant/patches:/patches:ro
|
||||
network_mode: host
|
||||
|
||||
@@ -1,40 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# ==============================================================================
|
||||
# DeepSeek V4 NVFP4 Patch — Version Banner (printed at import time)
|
||||
# ==============================================================================
|
||||
import datetime as _dt
|
||||
import os as _os
|
||||
_git_commit = _os.popen("git -C /root/nvidia-meeting/deepseek-v4-quant rev-parse --short HEAD 2>/dev/null || echo 'unknown'").read().strip()
|
||||
print(f"""
|
||||
{'='*70}
|
||||
DeepSeek V4 NVFP4 Patch
|
||||
{'='*70}
|
||||
Commit: {_git_commit}
|
||||
Loaded: {_dt.datetime.now().strftime('%Y-%m-%d %H:%M:%S UTC')}
|
||||
Node: {_os.uname().nodename}
|
||||
|
||||
Architecture:
|
||||
wo_a → FP8 + DeepGEMM block scale (BMM einsum)
|
||||
wq_b/wo_b → BF16 (UnquantizedLinearMethod)
|
||||
fused_wqa → BF16 (stacked q_a + kv, dequantized from NVFP4)
|
||||
compressor → BF16 (reconstructed from separate kv_proj+gate_proj)
|
||||
shared_exp → FP8 (Fp8LinearMethod, DeepGEMM)
|
||||
MoE experts → NVFP4 (FusedMoE, FLASHINFER_TRTLLM) — NOT converted
|
||||
|
||||
Bugs fixed:
|
||||
#1 DeepGEMM sf.dim() — block scale format (deepgemm_post_process)
|
||||
#2 fused_skip_regex — q_b/o_a/o_b scales no longer skipped
|
||||
#3 input_scale — removed from weight dequant (activations only)
|
||||
#4 compressor indexer — sub_path for .indexer keys
|
||||
#5 block scale dtype — must be float32, not float8_e4m3fn
|
||||
#6 block scale values — torch.full(fp8_scale) not torch.ones
|
||||
#7 UE8M0 block scale — .to(float32) misinterprets E8M0 as E4M3
|
||||
{'='*70}
|
||||
""")
|
||||
# ==============================================================================
|
||||
|
||||
import typing
|
||||
from collections.abc import Callable, Iterable
|
||||
from itertools import islice
|
||||
@@ -185,8 +151,9 @@ class DeepseekV4FP8Config(Fp8Config):
|
||||
try:
|
||||
hf_config = get_current_vllm_config().model_config.hf_config
|
||||
except Exception:
|
||||
# vllm_config not yet set; defer the decision until a
|
||||
# later call lands inside set_current_vllm_config.
|
||||
# vllm_config not yet set; return safe default but do NOT
|
||||
# cache — a later call inside set_current_vllm_config may
|
||||
# resolve differently.
|
||||
return "fp4"
|
||||
expert_dtype = getattr(hf_config, "expert_dtype", "fp4")
|
||||
if expert_dtype not in _DEEPSEEK_V4_EXPERT_DTYPES:
|
||||
@@ -195,11 +162,6 @@ class DeepseekV4FP8Config(Fp8Config):
|
||||
f"expected one of {_DEEPSEEK_V4_EXPERT_DTYPES}."
|
||||
)
|
||||
self._resolved_expert_dtype = expert_dtype
|
||||
from vllm.logger import init_logger
|
||||
|
||||
init_logger(__name__).info_once(
|
||||
"DeepSeek V4 expert_dtype resolved to %r", expert_dtype
|
||||
)
|
||||
return self._resolved_expert_dtype
|
||||
|
||||
@property
|
||||
@@ -244,11 +206,23 @@ class DeepseekV4FP8Config(Fp8Config):
|
||||
return isinstance(layer, FusedMoE) and self.expert_dtype == "fp4"
|
||||
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import torch
|
||||
|
||||
"""
|
||||
NVFP4 staging kernel — full FP4 (E2M1) activations + UE4M3 block16 scales.
|
||||
|
||||
The mxf4nvf4 PTX instruction requires BOTH A and B to be FP4 (E2M1 packed).
|
||||
This kernel quantizes BF16 activations → E2M1 packed uint8 with UE4M3 scales.
|
||||
"""
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _deepseek_v4_stage_mega_moe_inputs_kernel(
|
||||
hidden_states,
|
||||
x_fp8,
|
||||
x_sf,
|
||||
x_fp4, # uint8, shape (M, K//2) — E2M1 packed, 2 values per byte
|
||||
x_sf, # int32, shape (M, K//64) — UE4M3 packed, 4 scales per int32
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
topk_idx_out,
|
||||
@@ -269,8 +243,8 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel(
|
||||
topk_weights_out_stride_k: tl.constexpr,
|
||||
hidden_size: tl.constexpr,
|
||||
top_k: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
GROUP_K: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr, # 128 elements (loaded from hidden)
|
||||
GROUP_K: tl.constexpr, # 16 (NVFP4 group_size)
|
||||
BLOCK_TOPK: tl.constexpr,
|
||||
) -> None:
|
||||
token_id = tl.program_id(0)
|
||||
@@ -284,35 +258,94 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel(
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
|
||||
num_groups: tl.constexpr = BLOCK_K // GROUP_K
|
||||
hidden_groups = tl.reshape(tl.abs(hidden), [num_groups, GROUP_K])
|
||||
amax = tl.max(hidden_groups, axis=1)
|
||||
num_groups: tl.constexpr = BLOCK_K // GROUP_K # 8
|
||||
hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K])
|
||||
abs_groups = tl.reshape(tl.abs(hidden), [num_groups, GROUP_K])
|
||||
amax = tl.max(abs_groups, axis=1)
|
||||
amax = tl.maximum(amax, 1.0e-4)
|
||||
|
||||
scale = amax / 448.0
|
||||
# ---- UE4M3 scale computation ----
|
||||
# scale = amax / 6.0 (E2M1 max value = 6)
|
||||
# Then quantize scale to UE4M3 format
|
||||
scale = amax / 6.0
|
||||
scale_bits = scale.to(tl.uint32, bitcast=True)
|
||||
scale_exp = ((scale_bits >> 23) & 0xFF) + ((scale_bits & 0x7FFFFF) != 0).to(
|
||||
tl.uint32
|
||||
)
|
||||
scale_exp = tl.minimum(tl.maximum(scale_exp, 1), 254)
|
||||
rounded_scale = (scale_exp << 23).to(tl.float32, bitcast=True)
|
||||
scale_exp = (scale_bits >> 23) & 0xFF
|
||||
scale_mant = scale_bits & 0x7FFFFF
|
||||
|
||||
hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K])
|
||||
scaled = hidden_groups * (1.0 / rounded_scale)[:, None]
|
||||
scaled = tl.reshape(scaled, [BLOCK_K])
|
||||
fp8 = scaled.to(tl.float8e4nv)
|
||||
# Convert FP32 → E4M3 manually
|
||||
# Re-bias in signed arithmetic: scale_exp is uint32, so subtracting 120 directly
# wraps around for scales below 2**-7 and the clamp to 0 below would never fire
# (the exponent would saturate to 15 instead).
e4m3_exp = scale_exp.to(tl.int32) - 120  # FP32 bias=127, E4M3 bias=7
|
||||
e4m3_exp = tl.maximum(e4m3_exp, 0)
|
||||
e4m3_exp = tl.minimum(e4m3_exp, 15)
|
||||
e4m3_mant = scale_mant >> 20
|
||||
round_bit = (scale_mant >> 19) & 1
|
||||
e4m3_mant = e4m3_mant + round_bit
|
||||
overflow = e4m3_mant >= 8
|
||||
e4m3_mant = tl.where(overflow, 0, e4m3_mant)
|
||||
e4m3_exp = tl.where(overflow, e4m3_exp + 1, e4m3_exp)
|
||||
e4m3_exp = tl.minimum(e4m3_exp, 15)
|
||||
scale_e4m3_bits = (e4m3_exp << 3) | e4m3_mant
|
||||
|
||||
# Reconstruct dequantized scale for E2M1 quantization
|
||||
e4m3_exp_for_recon = tl.maximum(e4m3_exp.to(tl.int32) - 7, -126)
|
||||
two_pow_exp_bits = (e4m3_exp_for_recon + 127).to(tl.uint32) << 23
|
||||
two_pow_exp = two_pow_exp_bits.to(tl.float32, bitcast=True)
|
||||
normal_value = (1.0 + e4m3_mant.to(tl.float32) / 8.0) * two_pow_exp
|
||||
subnormal_value = (e4m3_mant.to(tl.float32) / 8.0) * 0.015625
|
||||
e4m3_value = tl.where(e4m3_exp == 0, subnormal_value, normal_value)
|
||||
|
||||
# ---- E2M1 FP4 quantization ----
|
||||
# E2M1 LUT (unsigned): [0, 0.5, 1, 1.5, 2, 3, 4, 6]
|
||||
# Nearest-neighbor using thresholds (midpoints between consecutive values)
|
||||
scaled = hidden_groups * (1.0 / tl.maximum(e4m3_value, 1e-6))[:, None]
|
||||
# Clamp to E2M1 range [-6, 6]
|
||||
scaled = tl.maximum(scaled, -6.0)
|
||||
scaled = tl.minimum(scaled, 6.0)
|
||||
|
||||
abs_s = tl.abs(scaled)
|
||||
# E2M1 quantization using arithmetic instead of nested tl.where (Triton compile error)
|
||||
# LUT: [0, 0.5, 1, 1.5, 2, 3, 4, 6] → thresholds at midpoints
|
||||
# idx = sum(abs_s >= threshold_i) for thresholds [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0]
|
||||
e2m1_idx = ((abs_s >= 0.25).to(tl.int32) + (abs_s >= 0.75).to(tl.int32) +
|
||||
(abs_s >= 1.25).to(tl.int32) + (abs_s >= 1.75).to(tl.int32) +
|
||||
(abs_s >= 2.5).to(tl.int32) + (abs_s >= 3.5).to(tl.int32) +
|
||||
(abs_s >= 5.0).to(tl.int32))
|
||||
sign_bit = (scaled < 0).to(tl.int32)
|
||||
e2m1_4bit = (sign_bit << 3) | e2m1_idx # 4-bit: (sign << 3) | index
|
||||
|
||||
# Pack 2 E2M1 values per byte: even→low nibble, odd→high nibble
|
||||
PACKED_K: tl.constexpr = BLOCK_K // 2 # 64
|
||||
e2m1_pairs = tl.reshape(e2m1_4bit, [PACKED_K, 2])
|
||||
even, odd = tl.split(e2m1_pairs) # splits last axis (size 2) into two [PACKED_K] tensors
|
||||
packed_byte = (odd.to(tl.uint8) << 4) | even.to(tl.uint8)
|
||||
|
||||
packed_k_offsets = k_block_id * PACKED_K + tl.arange(0, PACKED_K)
|
||||
packed_k_mask = packed_k_offsets < (hidden_size // 2)
|
||||
tl.store(
|
||||
x_fp8 + token_id * x_stride_m + k_offsets * x_stride_k,
|
||||
fp8,
|
||||
mask=k_mask,
|
||||
x_fp4 + token_id * x_stride_m + packed_k_offsets * x_stride_k,
|
||||
packed_byte,
|
||||
mask=packed_k_mask,
|
||||
)
|
||||
|
||||
scale_offsets = tl.arange(0, num_groups)
|
||||
packed_scale = tl.sum(scale_exp << (scale_offsets * 8), axis=0).to(tl.int32)
|
||||
tl.store(
|
||||
x_sf + token_id * x_sf_stride_m + k_block_id * x_sf_stride_k,
|
||||
packed_scale,
|
||||
)
|
||||
# Pack UE4M3 bytes into int32 (NVFP4: group_size=16, 4 groups per 64 elements)
|
||||
# 8 groups per k_block of 128 → 2 int32s per k_block
|
||||
# int32 can only pack 4 bytes (shifts >= 32 are UB on GPU), so split into two packs
|
||||
scale_offsets = tl.arange(0, num_groups) # [0..7]
|
||||
first_half = scale_offsets < 4 # groups 0-3 → int32[0]
|
||||
second_half = scale_offsets >= 4 # groups 4-7 → int32[1]
|
||||
|
||||
packed_lo = tl.sum(
|
||||
tl.where(first_half, scale_e4m3_bits.to(tl.int32) << (scale_offsets * 8), 0),
|
||||
axis=0,
|
||||
).to(tl.int32)
|
||||
packed_hi = tl.sum(
|
||||
tl.where(second_half, scale_e4m3_bits.to(tl.int32) << ((scale_offsets - 4) * 8), 0),
|
||||
axis=0,
|
||||
).to(tl.int32)
|
||||
|
||||
# Write 2 int32s per k_block: x_sf shape is (M, K//64) = (M, num_k_blocks * 2)
|
||||
sf_base = token_id * x_sf_stride_m + k_block_id * 2 * x_sf_stride_k
|
||||
tl.store(x_sf + sf_base, packed_lo)
|
||||
tl.store(x_sf + sf_base + x_sf_stride_k, packed_hi)
|
||||
|
||||
if k_block_id == 0:
|
||||
topk_offsets = tl.arange(0, BLOCK_TOPK)
|
||||
@@ -351,8 +384,8 @@ def _stage_deepseek_v4_mega_moe_inputs(
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
x_fp8: torch.Tensor,
|
||||
x_sf: torch.Tensor,
|
||||
x_fp4: torch.Tensor, # uint8, shape (M, K//2)
|
||||
x_sf: torch.Tensor, # int32, shape (M, K//64)
|
||||
topk_idx_out: torch.Tensor,
|
||||
topk_weights_out: torch.Tensor,
|
||||
) -> None:
|
||||
@@ -376,7 +409,7 @@ def _stage_deepseek_v4_mega_moe_inputs(
|
||||
block_topk = triton.next_power_of_2(top_k)
|
||||
_deepseek_v4_stage_mega_moe_inputs_kernel[grid](
|
||||
hidden_states,
|
||||
x_fp8,
|
||||
x_fp4,
|
||||
x_sf,
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
@@ -384,8 +417,8 @@ def _stage_deepseek_v4_mega_moe_inputs(
|
||||
topk_weights_out,
|
||||
hidden_states.stride(0),
|
||||
hidden_states.stride(1),
|
||||
x_fp8.stride(0),
|
||||
x_fp8.stride(1),
|
||||
x_fp4.stride(0),
|
||||
x_fp4.stride(1),
|
||||
x_sf.stride(0),
|
||||
x_sf.stride(1),
|
||||
topk_ids.stride(0),
|
||||
@@ -399,7 +432,7 @@ def _stage_deepseek_v4_mega_moe_inputs(
|
||||
hidden_size,
|
||||
top_k,
|
||||
BLOCK_K=block_k,
|
||||
GROUP_K=32,
|
||||
GROUP_K=16, # NVFP4: group_size=16 (scale_vec::4X)
|
||||
BLOCK_TOPK=block_topk,
|
||||
num_warps=4,
|
||||
)
|
||||
@@ -425,8 +458,21 @@ def make_deepseek_v4_expert_params_mapping(
|
||||
|
||||
|
||||
class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
"""MegaMoE experts for DeepSeek V4 with NVFP4 quantization.
|
||||
|
||||
Loads NVFP4 expert weights (E2M1 packed uint8 + float8_e4m3fn block scales
|
||||
+ float32 global scales) and feeds them natively to the DeepGEMM
|
||||
fp8_nvfp4_mega_moe kernel (kind::mxf4nvf4.scale_vec::4X).
|
||||
|
||||
No conversion to MXFP4. Experts stay NVFP4. The global scale (weight_scale_2)
|
||||
is folded into the block scales before kernel consumption.
|
||||
"""
|
||||
_symm_buffer_cache: dict[tuple[int, int, int, int, int, int, int], object] = {}
|
||||
|
||||
# NVFP4 E2M1 lookup table (positive values, sign from bit 3)
|
||||
E2M1_LUT = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
|
||||
# MXFP4 E2M1 is the same format
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
@@ -451,52 +497,83 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
|
||||
|
||||
weight_attrs = {"weight_loader": self.weight_loader}
|
||||
|
||||
# NVFP4 weights: E2M1 packed as uint8, 2 values per byte
|
||||
self.w13_weight = nn.Parameter(
|
||||
torch.zeros(
|
||||
num_local_experts,
|
||||
2 * intermediate_size,
|
||||
hidden_size // 2,
|
||||
dtype=torch.uint8,
|
||||
dtype=torch.int8,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(self.w13_weight, weight_attrs)
|
||||
|
||||
# NVFP4 block scales: float8_e4m3fn, group_size=16
|
||||
# Shape: [num_local_experts, 2*intermediate_size, hidden_size // 16]
|
||||
self.w13_weight_scale = nn.Parameter(
|
||||
torch.zeros(
|
||||
num_local_experts,
|
||||
2 * intermediate_size,
|
||||
hidden_size // 32,
|
||||
dtype=torch.uint8,
|
||||
hidden_size // 16,
|
||||
dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(self.w13_weight_scale, weight_attrs)
|
||||
self.w13_weight_scale.quant_method = "block"
|
||||
|
||||
# NVFP4 global scales: float32, per-expert
|
||||
self.w13_weight_scale_2 = nn.Parameter(
|
||||
torch.zeros(num_local_experts, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(self.w13_weight_scale_2, weight_attrs)
|
||||
|
||||
# NVFP4 activation scales: float32, per-expert
|
||||
self.w13_input_scale = nn.Parameter(
|
||||
torch.zeros(num_local_experts, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(self.w13_input_scale, weight_attrs)
|
||||
|
||||
self.w2_weight = nn.Parameter(
|
||||
torch.zeros(
|
||||
num_local_experts,
|
||||
hidden_size,
|
||||
intermediate_size // 2,
|
||||
dtype=torch.uint8,
|
||||
dtype=torch.int8,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(self.w2_weight, weight_attrs)
|
||||
|
||||
# NVFP4 block scales for w2
|
||||
self.w2_weight_scale = nn.Parameter(
|
||||
torch.zeros(
|
||||
num_local_experts,
|
||||
hidden_size,
|
||||
intermediate_size // 32,
|
||||
dtype=torch.uint8,
|
||||
intermediate_size // 16,
|
||||
dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(self.w2_weight_scale, weight_attrs)
|
||||
self.w2_weight_scale.quant_method = "block"
|
||||
|
||||
self.w2_weight_scale_2 = nn.Parameter(
|
||||
torch.zeros(num_local_experts, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(self.w2_weight_scale_2, weight_attrs)
|
||||
|
||||
self.w2_input_scale = nn.Parameter(
|
||||
torch.zeros(num_local_experts, dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(self.w2_input_scale, weight_attrs)
|
||||
|
||||
self._transformed_l1_weights: tuple[torch.Tensor, torch.Tensor] | None = None
|
||||
self._transformed_l2_weights: tuple[torch.Tensor, torch.Tensor] | None = None
|
||||
|
||||
@@ -519,21 +596,25 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
weight_name: str,
|
||||
shard_id: str,
|
||||
expert_id: int,
|
||||
return_success: bool = False,
|
||||
) -> bool | None:
|
||||
) -> bool:
|
||||
local_expert_id = self._map_global_expert_id(expert_id)
|
||||
if local_expert_id == -1:
|
||||
return False if return_success else None
|
||||
return False
|
||||
|
||||
# Scalar params (weight_scale_2, input_scale): 1D per-expert
|
||||
if "weight_scale_2" in weight_name or "input_scale" in weight_name:
|
||||
param.data[local_expert_id].copy_(loaded_weight)
|
||||
return True
|
||||
|
||||
expert_data = param.data[local_expert_id]
|
||||
if shard_id in ("w1", "w3"):
|
||||
if "w13_" not in weight_name:
|
||||
return False if return_success else None
|
||||
return False
|
||||
shard_offset = 0 if shard_id == "w1" else self.intermediate_size
|
||||
expert_data = expert_data.narrow(0, shard_offset, self.intermediate_size)
|
||||
elif shard_id == "w2":
|
||||
if "w2_" not in weight_name:
|
||||
return False if return_success else None
|
||||
return False
|
||||
else:
|
||||
raise ValueError(f"Unsupported expert shard id: {shard_id}")
|
||||
|
||||
@@ -544,11 +625,7 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
f"vs checkpoint {tuple(loaded_weight.shape)}"
|
||||
)
|
||||
expert_data.copy_(loaded_weight)
|
||||
return True if return_success else None
|
||||
|
||||
@staticmethod
|
||||
def _ue8m0_uint8_to_float(sf: torch.Tensor) -> torch.Tensor:
|
||||
return (sf.to(torch.int32) << 23).view(torch.float32)
|
||||
return True
|
||||
|
||||
def _check_runtime_supported(self) -> None:
|
||||
if not torch.cuda.is_available():
|
||||
@@ -558,7 +635,7 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
raise NotImplementedError(
|
||||
"DeepSeek V4 MegaMoE expert weights must be loaded on CUDA."
|
||||
)
|
||||
if torch.cuda.get_device_capability(device)[0] != 10:
|
||||
if torch.cuda.get_device_capability(device)[0] < 10:
|
||||
raise NotImplementedError("DeepGEMM MegaMoE requires SM100 GPUs.")
|
||||
if self.hidden_size % 128 != 0 or self.intermediate_size % 128 != 0:
|
||||
raise ValueError(
|
||||
@@ -571,41 +648,51 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
return
|
||||
|
||||
self._check_runtime_supported()
|
||||
import vllm.third_party.deep_gemm as deep_gemm
|
||||
from nvfp4_megamoe_kernel import transform_nvfp4_weights_for_mega_moe
|
||||
|
||||
w13_scale = deep_gemm.transform_sf_into_required_layout(
|
||||
self._ue8m0_uint8_to_float(self.w13_weight_scale.data).contiguous(),
|
||||
2 * self.intermediate_size,
|
||||
self.hidden_size,
|
||||
(1, 32),
|
||||
self.num_local_experts,
|
||||
)
|
||||
w2_scale = deep_gemm.transform_sf_into_required_layout(
|
||||
self._ue8m0_uint8_to_float(self.w2_weight_scale.data).contiguous(),
|
||||
self.hidden_size,
|
||||
self.intermediate_size,
|
||||
(1, 32),
|
||||
self.num_local_experts,
|
||||
)
|
||||
# === Native NVFP4 path ===
|
||||
# The DeepGEMM nvfp4 mega_moe kernel consumes NVFP4 directly:
|
||||
# - E2M1 packed uint8 (same as checkpoint)
|
||||
# - UE4M3 block scales (float8_e4m3fn), group_size=16
|
||||
# - float32 global scale folded into block scales
|
||||
# No conversion to MXFP4. Experts stay NVFP4.
|
||||
|
||||
# Fold global scales into block scales and transform for the kernel
|
||||
self._transformed_l1_weights, self._transformed_l2_weights = (
|
||||
deep_gemm.transform_weights_for_mega_moe(
|
||||
(self.w13_weight.data.view(torch.int8).contiguous(), w13_scale),
|
||||
(self.w2_weight.data.view(torch.int8).contiguous(), w2_scale),
|
||||
transform_nvfp4_weights_for_mega_moe(
|
||||
(self.w13_weight.data.contiguous(),
|
||||
self.w13_weight_scale.data.contiguous()),
|
||||
(self.w2_weight.data.contiguous(),
|
||||
self.w2_weight_scale.data.contiguous()),
|
||||
l1_weight_scale_2=self.w13_weight_scale_2.data.contiguous(),
|
||||
l2_weight_scale_2=self.w2_weight_scale_2.data.contiguous(),
|
||||
)
|
||||
)
|
||||
# Drop the original loader-side parameters: the MegaMoE kernels only
|
||||
# consume the transformed views above. transform_weights_for_mega_moe
|
||||
# allocates a fresh tensor for the L1 weight (see _interleave_l1_weights)
|
||||
# and fresh SF tensors for L1/L2; the L2 weight is the only tensor that
|
||||
# aliases the original storage, and _transformed_l2_weights still holds
|
||||
# it, so the storage stays live after we drop the Parameter.
|
||||
|
||||
# Drop the original loader-side parameters
|
||||
self.w13_weight = None
|
||||
self.w13_weight_scale = None
|
||||
self.w13_weight_scale_2 = None
|
||||
self.w13_input_scale = None
|
||||
self.w2_weight = None
|
||||
self.w2_weight_scale = None
|
||||
self.w2_weight_scale_2 = None
|
||||
self.w2_input_scale = None
|
||||
|
||||
@staticmethod
|
||||
def _ue8m0_to_float32(sf: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert NVFP4 block scales (float8_e4m3fn / UE4M3) to float32.
|
||||
|
||||
Checkpoint stores float8_e4m3fn (standard NVFP4 spec, NOT UE8M0).
|
||||
Simple .to(float32) is correct — shift-by-23 was wrong (Bug #7 fix).
|
||||
"""
|
||||
return sf.to(torch.float32)
|
||||
|
||||
|
||||
|
||||
def get_symm_buffer(self):
|
||||
import vllm.third_party.deep_gemm as deep_gemm
|
||||
import nvfp4_megamoe_kernel as deep_gemm
|
||||
from nvfp4_megamoe_kernel import SymmBuffer, get_symm_buffer_for_nvfp4_mega_moe
|
||||
|
||||
group = get_ep_group().device_group
|
||||
device = torch.accelerator.current_device_index()
|
||||
@@ -620,7 +707,8 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
)
|
||||
symm_buffer = self._symm_buffer_cache.get(key)
|
||||
if symm_buffer is None:
|
||||
symm_buffer = deep_gemm.get_symm_buffer_for_mega_moe(
|
||||
# NVFP4 SymmBuffer: 2x SF size due to group_size=16
|
||||
symm_buffer = get_symm_buffer_for_nvfp4_mega_moe(
|
||||
group,
|
||||
self.num_experts,
|
||||
self.max_num_tokens,
|
||||
@@ -666,7 +754,7 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
activation_clamp: float | None,
|
||||
fast_math: bool,
|
||||
) -> None:
|
||||
import vllm.third_party.deep_gemm as deep_gemm
|
||||
import nvfp4_megamoe_kernel as deep_gemm
|
||||
|
||||
symm_buffer = self.get_symm_buffer()
|
||||
num_tokens = hidden_states.shape[0]
|
||||
@@ -680,13 +768,57 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
symm_buffer.topk_weights[:num_tokens],
|
||||
)
|
||||
|
||||
# Debug: check staging output
|
||||
import os
|
||||
if int(os.environ.get('MEGA_MOE_DEBUG', '0')):
|
||||
print(f"[MEGA_MOE_DEBUG] After staging: x dtype={symm_buffer.x.dtype} shape={symm_buffer.x.shape}")
|
||||
print(f"[MEGA_MOE_DEBUG] x_sf dtype={symm_buffer.x_sf.dtype} shape={symm_buffer.x_sf.shape}")
|
||||
print(f"[MEGA_MOE_DEBUG] topk_idx dtype={symm_buffer.topk_idx.dtype} shape={symm_buffer.topk_idx.shape}")
|
||||
print(f"[MEGA_MOE_DEBUG] topk_weights dtype={symm_buffer.topk_weights.dtype} shape={symm_buffer.topk_weights.shape}")
|
||||
# Check for NaN/Inf in the staging output
|
||||
x_sample = symm_buffer.x[:num_tokens]
|
||||
sf_sample = symm_buffer.x_sf[:num_tokens]
|
||||
print(f"[MEGA_MOE_DEBUG] x range: min={x_sample.min().item()} max={x_sample.max().item()}")
|
||||
if sf_sample.numel() > 0:
|
||||
# Both extrema must be materialised with .item(); without the call the f-string
# prints the bound-method repr instead of the maximum.
print(f"[MEGA_MOE_DEBUG]   x_sf range: min={sf_sample.to(torch.float32).min().item()} max={sf_sample.to(torch.float32).max().item()}")
|
||||
topk_sample = symm_buffer.topk_idx[:num_tokens]
|
||||
print(f"[MEGA_MOE_DEBUG] topk_idx range: min={topk_sample.min().item()} max={topk_sample.max().item()}")
|
||||
torch.cuda.synchronize()
|
||||
print("[MEGA_MOE_DEBUG] Staging CUDA sync OK")
|
||||
|
||||
# This method must have been already called during the weight loading phase.
|
||||
# We call it again here to cover the dummy weight loading case.
|
||||
self.finalize_weights()
|
||||
|
||||
assert self._transformed_l1_weights is not None
|
||||
assert self._transformed_l2_weights is not None
|
||||
deep_gemm.fp8_fp4_mega_moe(
|
||||
from nvfp4_megamoe_kernel import nvfp4_mega_moe_full as fp8_nvfp4_mega_moe
|
||||
|
||||
# Debug: dump shapes before mega_moe
|
||||
import os
|
||||
if int(os.environ.get('MEGA_MOE_DEBUG', '0')):
|
||||
l1_w, l1_sf = self._transformed_l1_weights
|
||||
l2_w, l2_sf = self._transformed_l2_weights
|
||||
print(f"[MEGA_MOE_DEBUG] num_tokens={num_tokens}, hidden={hidden_states.shape[1]}")
|
||||
print(f"[MEGA_MOE_DEBUG] l1_w: dtype={l1_w.dtype} shape={l1_w.shape} stride={l1_w.stride()}")
|
||||
print(f"[MEGA_MOE_DEBUG] l1_sf: dtype={l1_sf.dtype} shape={l1_sf.shape} stride={l1_sf.stride()}")
|
||||
print(f"[MEGA_MOE_DEBUG] l2_w: dtype={l2_w.dtype} shape={l2_w.shape} stride={l2_w.stride()}")
|
||||
print(f"[MEGA_MOE_DEBUG] l2_sf: dtype={l2_sf.dtype} shape={l2_sf.shape} stride={l2_sf.stride()}")
|
||||
print(f"[MEGA_MOE_DEBUG] symm_buffer nbytes={symm_buffer.buffer.nbytes} rank={symm_buffer.group.rank()}")
|
||||
print(f"[MEGA_MOE_DEBUG] num_experts={self.num_experts} topk={topk_ids.shape[1]} max_tokens={self.max_num_tokens}")
|
||||
print(f"[MEGA_MOE_DEBUG] y: dtype={y.dtype} shape={y.shape}")
|
||||
# Force CUDA sync to catch any prior async errors
|
||||
torch.cuda.synchronize()
|
||||
print("[MEGA_MOE_DEBUG] CUDA sync OK — prior ops clean")
|
||||
|
||||
# MEGA_MOE_STATIC: skip the kernel entirely, return zeros
|
||||
# Tests whether the crash is in the kernel launch or in prior data prep
|
||||
if int(os.environ.get('MEGA_MOE_STATIC', '0')):
|
||||
print(f"[MEGA_MOE_STATIC] Skipping fp8_nvfp4_mega_moe, returning zeros")
|
||||
y.zero_()
|
||||
return
|
||||
|
||||
fp8_nvfp4_mega_moe(
|
||||
y,
|
||||
self._transformed_l1_weights,
|
||||
self._transformed_l2_weights,
|
||||
@@ -694,6 +826,8 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
activation_clamp=activation_clamp,
|
||||
fast_math=fast_math,
|
||||
)
|
||||
if os.environ.get('NVFP4_DEBUG_SYNC', '') == '1':
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
DeepseekV4MegaMoEExperts.weight_loader.supports_moe_loading = True # type: ignore[attr-defined]
|
||||
@@ -751,9 +885,7 @@ class DeepseekV4MoE(nn.Module):
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.prefix = prefix
|
||||
self.use_mega_moe = (
|
||||
vllm_config.kernel_config.moe_backend == "deep_gemm_mega_moe"
|
||||
)
|
||||
self.use_mega_moe = True # Force mega_moe for NVFP4 pipeline
|
||||
if self.use_mega_moe and not vllm_config.parallel_config.enable_expert_parallel:
|
||||
raise NotImplementedError(
|
||||
"DeepSeek V4 MegaMoE currently requires expert parallel. "
|
||||
@@ -774,12 +906,7 @@ class DeepseekV4MoE(nn.Module):
|
||||
raise NotImplementedError(
|
||||
"DeepSeek V4 MegaMoE currently supports sqrtsoftplus routing only."
|
||||
)
|
||||
if self.use_mega_moe and getattr(config, "expert_dtype", "fp4") != "fp4":
|
||||
raise NotImplementedError(
|
||||
"DeepSeek V4 MegaMoE only supports fp4 experts; got expert_dtype="
|
||||
f"{config.expert_dtype!r}. Drop --kernel-config moe_backend="
|
||||
"deep_gemm_mega_moe for this checkpoint."
|
||||
)
|
||||
# NVFP4 experts are consumed natively by mega_moe (no MXFP4 conversion); finalize_weights only folds the global scale into the block scales
|
||||
|
||||
self.gate = GateLinear(
|
||||
config.hidden_size,
|
||||
@@ -1045,7 +1172,7 @@ class DeepseekV4Attention(nn.Module):
|
||||
self.rope_parameters = config.rope_scaling
|
||||
|
||||
# Initialize rotary embedding BEFORE DeepseekV4MLAModules (which needs it)
|
||||
rope_parameters = config.rope_parameters
|
||||
rope_parameters = dict(config.rope_parameters)
|
||||
rope_parameters["rope_theta"] = (
|
||||
config.compress_rope_theta if self.compress_ratio > 1 else config.rope_theta
|
||||
)
|
||||
@@ -1236,6 +1363,15 @@ class DeepseekV4DecoderLayer(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
input_ids: torch.Tensor | None,
|
||||
) -> torch.Tensor:
|
||||
# DEBUG: skip attention entirely, just run FFN on raw input
|
||||
if int(os.environ.get('SKIP_ATTENTION', '0')):
|
||||
# Flatten to 2D for ffn, then restore
|
||||
org_shape = x.shape
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
x_2d = self.ffn_norm(x_2d)
|
||||
x_2d = self.ffn(x_2d, input_ids)
|
||||
return x_2d.view(org_shape)
|
||||
|
||||
residual = x
|
||||
x, post, comb = self.hc_pre(
|
||||
x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base
|
||||
@@ -1262,9 +1398,7 @@ class DeepseekV4Model(nn.Module):
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.use_mega_moe = (
|
||||
vllm_config.kernel_config.moe_backend == "deep_gemm_mega_moe"
|
||||
)
|
||||
self.use_mega_moe = True # Force mega_moe for NVFP4 pipeline
|
||||
if self.use_mega_moe and not vllm_config.parallel_config.enable_expert_parallel:
|
||||
raise NotImplementedError(
|
||||
"DeepSeek V4 MegaMoE currently requires expert parallel. "
|
||||
@@ -1461,14 +1595,19 @@ class DeepseekV4Model(nn.Module):
|
||||
else:
|
||||
if ".experts." in name:
|
||||
# E8M0 scales are stored as float8_e8m0fnu in
|
||||
# checkpoints but the MoE param is uint8. copy_()
|
||||
# would do a numeric conversion (e.g. 2^-7 → 0),
|
||||
# destroying the raw exponent bytes.
|
||||
# MXFP4 checkpoints but NVFP4 uses float8_e4m3fn.
|
||||
# The uint8 view+copy path is only valid for MXFP4;
|
||||
# for NVFP4 it would paste raw E8M0 bytes into an
|
||||
# E4M3 buffer, producing garbage.
|
||||
if (
|
||||
"weight_scale" in name
|
||||
and loaded_weight.dtype == torch.float8_e8m0fnu
|
||||
):
|
||||
loaded_weight = loaded_weight.view(torch.uint8)
|
||||
assert False, (
|
||||
f"E8M0 weight_scale encountered for NVFP4 experts "
|
||||
f"({name}) — this is only valid for MXFP4. "
|
||||
f"Check checkpoint dtype."
|
||||
)
|
||||
for mapping in expert_mapping:
|
||||
param_name, weight_name, expert_id, shard_id = mapping
|
||||
if weight_name not in name:
|
||||
@@ -1489,7 +1628,6 @@ class DeepseekV4Model(nn.Module):
|
||||
name_mapped,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id,
|
||||
return_success=True,
|
||||
)
|
||||
if success:
|
||||
name = name_mapped
|
||||
@@ -1537,16 +1675,10 @@ class DeepseekV4Model(nn.Module):
|
||||
weight_scale_2_val = global_amax / (6.0 * 448.0)
|
||||
weight_scale_2 = weight_scale_2_val.to(torch.float32)
|
||||
|
||||
# Per-block scale (weight_scale): UE8M0 format
|
||||
# scale_fmt=ue8m0: block_scale = 2^(exp-127), stored as
|
||||
# uint8 exponent byte viewed as float8_e4m3fn
|
||||
# Per-block scale (weight_scale): UE4M3 format (standard NVFP4)
|
||||
# block_scale = amax / (6.0 * weight_scale_2)
|
||||
block_scale = amax / (6.0 * weight_scale_2_val)
|
||||
# Convert to UE8M0: floor to nearest power of 2
|
||||
# UE8M0 exponent = floor(log2(block_scale)) + 127
|
||||
block_scale_clamped = block_scale.clamp(min=2**-127)
|
||||
block_scale_exp = torch.floor(torch.log2(block_scale_clamped)).to(torch.int32) + 127
|
||||
block_scale_exp = block_scale_exp.clamp(0, 254).to(torch.uint8)
|
||||
weight_scale = block_scale_exp.view(torch.float8_e4m3fn)
|
||||
weight_scale = block_scale.clamp(0.0, 448.0).to(torch.float8_e4m3fn)
|
||||
|
||||
# Quantize to FP4 (E2M1)
|
||||
# E2M1 LUT: 0, 0.5, 1, 1.5, 2, 3, 4, 6 (positive)
|
||||
@@ -1554,10 +1686,8 @@ class DeepseekV4Model(nn.Module):
|
||||
[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0],
|
||||
dtype=torch.float32, device=w_bf16.device,
|
||||
)
|
||||
# For each block, dequantize the block scale from UE8M0
|
||||
block_scale_f32 = (block_scale_exp.to(torch.int32) << 23).view(torch.float32)
|
||||
# Scale the weight values: normalized = w / (block_scale * weight_scale_2)
|
||||
# We need to find the nearest FP4 value
|
||||
block_scale_f32 = block_scale.clamp(0.0, 448.0)
|
||||
scaled = w_blocks / (block_scale_f32.unsqueeze(-1) * weight_scale_2_val)
|
||||
# Find nearest FP4 index (0-7 for magnitude)
|
||||
# Use absolute value for matching, then apply sign
|
||||
@@ -1575,7 +1705,7 @@ class DeepseekV4Model(nn.Module):
|
||||
even = fp4_flat[:, 0::2] # lower nibble
|
||||
odd = fp4_flat[:, 1::2] # upper nibble
|
||||
packed = (odd << 4) | even
|
||||
weight_packed = packed.to(torch.uint8)
|
||||
weight_packed = packed.to(torch.uint8).view(torch.int8)
|
||||
|
||||
# Reshape weight_scale to [out, n_blocks]
|
||||
weight_scale_2d = weight_scale.reshape(out_dim, n_blocks)
|
||||
@@ -1647,8 +1777,8 @@ class DeepseekV4Model(nn.Module):
|
||||
.forward() which goes through quant_method; FP8 would dtype-mismatch)
|
||||
- compressor.fused_wkv_wgate: Dequant NVFP4->bf16 (used via direct
|
||||
torch.mm in attention parallel stream)
|
||||
- shared_experts (gate_up_proj, down_proj): Dequant NVFP4->bf16
|
||||
- MoE experts: Stay in native NVFP4 (ModelOptNvFp4FusedMoE)
|
||||
- shared_experts (gate_up_proj, down_proj): Stay native NVFP4 via DeepGEMM fp8_fp4_gemm
|
||||
- MoE experts: Handled by DeepseekV4MegaMoEExperts (NVFP4→MXFP4)
|
||||
"""
|
||||
E2M1_LUT = torch.tensor(
|
||||
[0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16
|
||||
@@ -1659,14 +1789,19 @@ class DeepseekV4Model(nn.Module):
|
||||
# for fp8_einsum. Only layer that needs FP8 conversion.
|
||||
fp8_proj_names = {"wo_a"}
|
||||
# Attention layers called via .forward() — need bf16
|
||||
bf16_proj_names = {"fused_wqa_wkv", "wq_b", "wo_b"}
|
||||
# Shared expert layers called via .forward() — need bf16
|
||||
bf16_shared_names = {"gate_up_proj", "down_proj"}
|
||||
# cuBLAS BF16 is broken on Blackwell — nothing gets dequantized to BF16.
|
||||
# Everything stays native NVFP4/FP8 via FlashInfer CUTLASS.
|
||||
bf16_proj_names = set()
|
||||
bf16_shared_names = set()
|
||||
|
||||
fp8_converted = 0
|
||||
fp8_from_bf16 = 0
|
||||
bf16_converted = 0
|
||||
compressor_converted = 0
|
||||
|
||||
# Build shard index once for compressor reconstruction (avoids N×M full-shard loads)
|
||||
_shard_index = self._build_shard_index("/model") if os.path.isdir("/model") else None
|
||||
|
||||
for layer_idx, layer in enumerate(self.layers):
|
||||
attn = layer.attn
|
||||
|
||||
@@ -1677,13 +1812,11 @@ class DeepseekV4Model(nn.Module):
|
||||
mod = getattr(attn, proj_name)
|
||||
if not hasattr(mod, "weight"):
|
||||
continue
|
||||
if mod.weight.dtype == torch.uint8:
|
||||
if mod.weight.dtype in (torch.uint8, torch.int8):
|
||||
# NVFP4 -> dequant to bf16 -> requant to FP8
|
||||
self._convert_nvfp4_to_fp8(mod, E2M1_LUT, FP8_MAX)
|
||||
fp8_converted += 1
|
||||
elif mod.weight.dtype == torch.bfloat16:
|
||||
# modelopt did NOT quantize o_a_proj — it's bf16 already.
|
||||
# Convert bf16 -> FP8 directly for fp8_einsum path.
|
||||
self._convert_bf16_to_fp8(mod, FP8_MAX)
|
||||
fp8_from_bf16 += 1
|
||||
|
||||
@@ -1692,7 +1825,7 @@ class DeepseekV4Model(nn.Module):
|
||||
if not hasattr(attn, proj_name):
|
||||
continue
|
||||
mod = getattr(attn, proj_name)
|
||||
if not hasattr(mod, "weight") or mod.weight.dtype != torch.uint8:
|
||||
if not hasattr(mod, "weight") or mod.weight.dtype not in (torch.uint8, torch.int8):
|
||||
continue
|
||||
self._dequant_nvfp4_to_bf16(mod, E2M1_LUT)
|
||||
bf16_converted += 1
|
||||
@@ -1710,23 +1843,23 @@ class DeepseekV4Model(nn.Module):
|
||||
compressor = getattr(mla_attn, "compressor", None)
|
||||
if compressor is not None and hasattr(compressor, "fused_wkv_wgate"):
|
||||
compressor_converted += self._reconstruct_compressor_weight(
|
||||
compressor.fused_wkv_wgate, attn, layer_idx, E2M1_LUT)
|
||||
compressor.fused_wkv_wgate, attn, layer_idx, E2M1_LUT, _shard_index=_shard_index)
|
||||
# Indexer compressor (C4A layers only)
|
||||
indexer = getattr(mla_attn, "indexer", None)
|
||||
if indexer is not None:
|
||||
idx_compressor = getattr(indexer, "compressor", None)
|
||||
if idx_compressor is not None and hasattr(idx_compressor, "fused_wkv_wgate"):
|
||||
compressor_converted += self._reconstruct_compressor_weight(
|
||||
idx_compressor.fused_wkv_wgate, indexer, layer_idx, E2M1_LUT, sub_path=".indexer")
|
||||
idx_compressor.fused_wkv_wgate, indexer, layer_idx, E2M1_LUT, sub_path=".indexer", _shard_index=_shard_index)
|
||||
|
||||
# Shared experts
|
||||
# Shared experts: dequantize NVFP4 → BF16
|
||||
ffn = layer.ffn
|
||||
if hasattr(ffn, "shared_experts") and ffn.shared_experts is not None:
|
||||
for proj_name in bf16_shared_names:
|
||||
if not hasattr(ffn.shared_experts, proj_name):
|
||||
continue
|
||||
mod = getattr(ffn.shared_experts, proj_name)
|
||||
if not hasattr(mod, "weight") or mod.weight.dtype != torch.uint8:
|
||||
if not hasattr(mod, "weight") or mod.weight.dtype not in (torch.uint8, torch.int8):
|
||||
continue
|
||||
self._dequant_nvfp4_to_bf16(mod, E2M1_LUT)
|
||||
bf16_converted += 1
|
||||
@@ -1737,8 +1870,7 @@ class DeepseekV4Model(nn.Module):
|
||||
print(f"NVFP4 post-load: {fp8_converted} NVFP4->FP8, "
|
||||
f"{fp8_from_bf16} BF16->FP8, "
|
||||
f"{bf16_converted} attn/shared->BF16, "
|
||||
f"{compressor_converted} compressor->BF16, "
|
||||
f"MoE experts stay NVFP4")
|
||||
f"{compressor_converted} compressor->BF16")
|
||||
|
||||
|
||||
def _dequant_nvfp4_to_bf16(self, mod, e2m1_lut):
|
||||
@@ -1749,9 +1881,8 @@ class DeepseekV4Model(nn.Module):
|
||||
|
||||
# Dequantize with scales
|
||||
if hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2"):
|
||||
# scale_fmt=ue8m0: weight_scale bytes are E8M0 format (power-of-2 only).
|
||||
# A simple .to(float32) misinterprets them as E4M3. Must reinterpret
|
||||
# the raw uint8 bits as IEEE 754 exponent field.
|
||||
# NVFP4 block scales are float8_e4m3fn (UE4M3) — standard spec.
|
||||
# .to(float32) is correct (Bug #7: shift-by-23 was wrong, reverted)
|
||||
block_scale = self._ue8m0_to_float32(mod.weight_scale.data)
|
||||
if block_scale.dim() == 2 and w_bf16.dim() == 2:
|
||||
block_size = w_bf16.shape[1] // block_scale.shape[1]
|
||||
@@ -1773,8 +1904,10 @@ class DeepseekV4Model(nn.Module):
|
||||
else:
|
||||
w_dequant = w_bf16
|
||||
|
||||
# Replace weight with bf16 version
|
||||
# Free source tensors eagerly to avoid holding uint8+bf16+fp32 simultaneously
|
||||
del w_uint8, w_bf16
|
||||
mod.weight = torch.nn.Parameter(w_dequant, requires_grad=False)
|
||||
del w_dequant
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
mod.quant_method = UnquantizedLinearMethod()
|
||||
for attr in ("weight_scale", "weight_scale_2", "input_scale",
|
||||
@@ -1794,7 +1927,7 @@ class DeepseekV4Model(nn.Module):
|
||||
|
||||
# Dequantize with scales
|
||||
if hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2"):
|
||||
# scale_fmt=ue8m0: reinterpret E8M0 bytes as float32
|
||||
# NVFP4 block scales: float8_e4m3fn → .to(float32) (Bug #7 reverted)
|
||||
block_scale = self._ue8m0_to_float32(mod.weight_scale.data)
|
||||
if block_scale.dim() == 2 and w_bf16.dim() == 2:
|
||||
block_size = w_bf16.shape[1] // block_scale.shape[1]
|
||||
@@ -1857,21 +1990,38 @@ class DeepseekV4Model(nn.Module):
|
||||
bmm_batch_size=bmm_batch_size,
|
||||
)
|
||||
|
||||
# Free source tensors eagerly
|
||||
del w_uint8, w_bf16, w_dequant
|
||||
mod.weight = torch.nn.Parameter(w_fp8, requires_grad=False)
|
||||
del w_fp8
|
||||
# weight_scale_inv is what the attention runtime reads as b_scale
|
||||
# for deepseek_v4_fp8_einsum -> DeepGEMM fp8_einsum.
|
||||
# It must be the DeepGEMM-formatted block scale (dg_ws), NOT the
|
||||
# per-tensor scalar. See: deepseek_v4_attention.py line 319.
|
||||
mod.weight_scale_inv = torch.nn.Parameter(ws, requires_grad=False)
|
||||
# weight_scale is not used at runtime for BMM layers; remove it
|
||||
# to avoid confusing other code paths.
|
||||
del ws
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
mod.quant_method = UnquantizedLinearMethod()
|
||||
for attr in ("weight_scale", "weight_scale_2", "input_scale"):
|
||||
if hasattr(mod, attr):
|
||||
delattr(mod, attr)
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
mod.quant_method = UnquantizedLinearMethod()
|
||||
|
||||
def _reconstruct_compressor_weight(self, fused_mod, parent_mod, layer_idx, e2m1_lut, sub_path=""):
|
||||
@staticmethod
def _build_shard_index(ckpt_dir: str) -> dict[str, str]:
    """Build a tensor-key -> shard-path index for a sharded checkpoint.

    Every ``model-*.safetensors`` file directly inside ``ckpt_dir`` is opened
    and only its key listing is read; no tensor is fetched here.

    Args:
        ckpt_dir: Directory that holds the ``model-*.safetensors`` shards.

    Returns:
        Mapping from checkpoint tensor name to the path of the shard that
        lists it. Shards are visited in sorted order, so if a key appears in
        several shards the last one wins. Shards that cannot be opened are
        reported and left out of the index.
    """
    import glob
    from safetensors import safe_open

    index: dict[str, str] = {}
    shard_pattern = os.path.join(ckpt_dir, "model-*.safetensors")
    for shard_file in sorted(glob.glob(shard_pattern)):
        try:
            with safe_open(shard_file, framework="pt") as f:
                index.update(dict.fromkeys(f.keys(), shard_file))
        except Exception as exc:
            # Still best-effort (one bad shard must not abort indexing), but
            # say so: keys from a skipped shard will be absent from the index
            # and lookups for them will come back empty.
            print(f"NVFP4 shard index: skipping unreadable shard "
                  f"{shard_file}: {exc!r}")
    return index
|
||||
|
||||
def _reconstruct_compressor_weight(self, fused_mod, parent_mod, layer_idx, e2m1_lut, sub_path="", _shard_index=None):
|
||||
"""Reconstruct compressor fused_wkv_wgate from checkpoint.
|
||||
|
||||
Compressor weights are SKIPPED during loading because NVFP4 uint8 data
|
||||
@@ -1879,8 +2029,7 @@ class DeepseekV4Model(nn.Module):
|
||||
We read the original uint8 data from the safetensors checkpoint, unpack
|
||||
E2M1, dequantize, and stack into the fused weight param.
|
||||
"""
|
||||
import glob
|
||||
from safetensors.torch import load_file
|
||||
from safetensors import safe_open
|
||||
|
||||
# Find the checkpoint directory
|
||||
# The model weights are mounted at /model in Docker
|
||||
@@ -1895,49 +2044,45 @@ class DeepseekV4Model(nn.Module):
|
||||
# We read from checkpoint (before mapper), so use original names
|
||||
layer_prefix = f"model.layers.{layer_idx}.self_attn.compressor{sub_path}"
|
||||
|
||||
# Find which shard contains this layer's compressor weights
|
||||
wkv_key = f"{layer_prefix}.kv_proj.weight"
|
||||
wgate_key = f"{layer_prefix}.gate_proj.weight"
|
||||
wkv_scale_key = f"{layer_prefix}.kv_proj.weight_scale"
|
||||
wgate_scale_key = f"{layer_prefix}.gate_proj.weight_scale"
|
||||
wkv_scale2_key = f"{layer_prefix}.kv_proj.weight_scale_2"
|
||||
wgate_scale2_key = f"{layer_prefix}.gate_proj.weight_scale_2"
|
||||
wkv_iscale_key = f"{layer_prefix}.kv_proj.input_scale"
|
||||
wgate_iscale_key = f"{layer_prefix}.gate_proj.input_scale"
|
||||
# All keys we need from the checkpoint
|
||||
keys = {
|
||||
'wkv_uint8': f"{layer_prefix}.kv_proj.weight",
|
||||
'wgate_uint8': f"{layer_prefix}.gate_proj.weight",
|
||||
'wkv_block_scale': f"{layer_prefix}.kv_proj.weight_scale",
|
||||
'wgate_block_scale': f"{layer_prefix}.gate_proj.weight_scale",
|
||||
'wkv_global_scale': f"{layer_prefix}.kv_proj.weight_scale_2",
|
||||
'wgate_global_scale': f"{layer_prefix}.gate_proj.weight_scale_2",
|
||||
'wkv_input_scale': f"{layer_prefix}.kv_proj.input_scale",
|
||||
'wgate_input_scale': f"{layer_prefix}.gate_proj.input_scale",
|
||||
}
|
||||
|
||||
# Load from safetensors
|
||||
wkv_uint8 = None
|
||||
wgate_uint8 = None
|
||||
wkv_block_scale = None
|
||||
wgate_block_scale = None
|
||||
wkv_global_scale = None
|
||||
wgate_global_scale = None
|
||||
wkv_input_scale = None
|
||||
wgate_input_scale = None
|
||||
|
||||
shard_files = sorted(glob.glob(os.path.join(ckpt_dir, "model-*.safetensors")))
|
||||
for shard_file in shard_files:
|
||||
# Read tensors using shard index for targeted access (no full-shard loads)
|
||||
tensors = {}
|
||||
for name, key in keys.items():
|
||||
shard_path = (_shard_index or {}).get(key)
|
||||
if shard_path is None:
|
||||
continue
|
||||
try:
|
||||
shard_data = load_file(shard_file)
|
||||
with safe_open(shard_path, framework="pt") as f:
|
||||
if key in f.keys():
|
||||
tensors[name] = f.get_tensor(key)
|
||||
except Exception:
|
||||
continue
|
||||
if wkv_key in shard_data:
|
||||
wkv_uint8 = shard_data[wkv_key]
|
||||
wkv_block_scale = shard_data.get(wkv_scale_key)
|
||||
wkv_global_scale = shard_data.get(wkv_scale2_key)
|
||||
wkv_input_scale = shard_data.get(wkv_iscale_key)
|
||||
if wgate_key in shard_data:
|
||||
wgate_uint8 = shard_data[wgate_key]
|
||||
wgate_block_scale = shard_data.get(wgate_scale_key)
|
||||
wgate_global_scale = shard_data.get(wgate_scale2_key)
|
||||
wgate_input_scale = shard_data.get(wgate_iscale_key)
|
||||
if wkv_uint8 is not None and wgate_uint8 is not None:
|
||||
break
|
||||
|
||||
wkv_uint8 = tensors.get('wkv_uint8')
|
||||
wgate_uint8 = tensors.get('wgate_uint8')
|
||||
|
||||
if wkv_uint8 is None or wgate_uint8 is None:
|
||||
# Layer might not have a compressor (compress_ratio=1 layers)
|
||||
return 0
|
||||
|
||||
wkv_block_scale = tensors.get('wkv_block_scale')
|
||||
wgate_block_scale = tensors.get('wgate_block_scale')
|
||||
wkv_global_scale = tensors.get('wkv_global_scale')
|
||||
wgate_global_scale = tensors.get('wgate_global_scale')
|
||||
wkv_input_scale = tensors.get('wkv_input_scale')
|
||||
wgate_input_scale = tensors.get('wgate_input_scale')
|
||||
|
||||
device = fused_mod.weight.device
|
||||
wkv_uint8 = wkv_uint8.to(device)
|
||||
wgate_uint8 = wgate_uint8.to(device)
|
||||
@@ -1949,7 +2094,7 @@ class DeepseekV4Model(nn.Module):
|
||||
# Dequantize with scales
|
||||
def _dequant(w_bf16, block_scale, global_scale, input_scale):
|
||||
if block_scale is not None and global_scale is not None:
|
||||
# scale_fmt=ue8m0: reinterpret E8M0 bytes as float32
|
||||
# NVFP4 block scales: float8_e4m3fn → .to(float32) (Bug #7 reverted)
|
||||
block_scale = self._ue8m0_to_float32(block_scale.to(device))
|
||||
if block_scale.dim() == 2 and w_bf16.dim() == 2:
|
||||
block_size = w_bf16.shape[1] // block_scale.shape[1]
|
||||
@@ -1972,8 +2117,6 @@ class DeepseekV4Model(nn.Module):
|
||||
# fused_wkv_wgate.weight = cat([wkv, wgate], dim=0) → (2*head_dim, hidden_size)
|
||||
w_fused = torch.cat([wkv_dequant, wgate_dequant], dim=0)
|
||||
|
||||
# DEBUG: log shapes to diagnose compressor weight mismatch
|
||||
print(f"NVFP4 compressor layer {layer_idx}: wkv={wkv_dequant.shape}, wgate={wgate_dequant.shape}, fused={w_fused.shape}, existing_param={fused_mod.weight.shape}")
|
||||
|
||||
# Replace the weight
|
||||
fused_mod.weight = torch.nn.Parameter(w_fused, requires_grad=False)
|
||||
@@ -2041,17 +2184,12 @@ class DeepseekV4Model(nn.Module):
|
||||
|
||||
@staticmethod
|
||||
def _ue8m0_to_float32(sf: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert UE8M0 (E8M0 power-of-2) scale bytes to float32.
|
||||
"""Convert NVFP4 block scales (float8_e4m3fn / UE4M3) to float32.
|
||||
|
||||
NVFP4 checkpoints with scale_fmt=ue8m0 store per-block weight scales as
|
||||
E8M0 format (8-bit exponent, no mantissa). The value = 2^(raw_byte - 127).
|
||||
The bytes are loaded as float8_e4m3fn by safetensors, but a simple
|
||||
.to(float32) misinterprets them as E4M3 (which has mantissa bits).
|
||||
Correct conversion: place the raw uint8 bits into the exponent field
|
||||
of an IEEE 754 float32 (bits 23-30), yielding 2^(raw-127) * implicit_1.
|
||||
Checkpoint stores float8_e4m3fn (standard NVFP4 spec, NOT UE8M0).
|
||||
Simple .to(float32) is correct — shift-by-23 was wrong (Bug #7 fix).
|
||||
"""
|
||||
raw_uint8 = sf.view(torch.uint8)
|
||||
return (raw_uint8.to(torch.int32) << 23).view(torch.float32)
|
||||
return sf.to(torch.float32)
|
||||
|
||||
def _unpack_nvfp4_to_bf16(self, w_uint8, e2m1_lut, device):
|
||||
"""Unpack NVFP4 uint8 packed weights to bf16 using E2M1 format."""
|
||||
@@ -2269,6 +2407,9 @@ class DeepseekV4ForCausalLM(nn.Module):
|
||||
loaded_params = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
self.model.finalize_mega_moe_weights()
|
||||
self.model._convert_nvfp4_post_load()
|
||||
if os.environ.get('NVFP4_DEBUG_SYNC', '') == '1':
|
||||
torch.cuda.synchronize()
|
||||
print("[NVFP4] post-load conversion done, CUDA OK")
|
||||
return loaded_params
|
||||
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||
|
||||
1155
patches/deepseek_v4_attention.py
Normal file
1155
patches/deepseek_v4_attention.py
Normal file
File diff suppressed because it is too large
Load Diff
133
patches/nvfp4_linear.py
Normal file
133
patches/nvfp4_linear.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
NVFP4 Linear Method — runs BF16 input through DeepGEMM fp8_fp4_gemm natively.
|
||||
|
||||
Weight format: NVFP4 (E2M1 packed int8 + UE4M3 block16 scales + float32 global scale)
|
||||
Activation: BF16 → FP8 e4m3fn with UE8M0 per-token scales
|
||||
GEMM: deep_gemm.fp8_fp4_gemm_nn(a=(fp8, ue8m0_scale), b=(nvfp4_packed, float32_scale))
|
||||
Output: BF16
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from vllm.model_executor.layers.linear import LinearMethodBase
|
||||
|
||||
|
||||
class NVFP4LinearMethod(LinearMethodBase):
    """Linear method that runs BF16 x NVFP4 via DeepGEMM fp8_fp4_gemm.

    The layer must have:
    - weight: E2M1 packed int8 (2 values per byte), shape (N, K//2)
    - weight_scale: float8_e4m3fn UE4M3 block scales, shape (N, K//16)
    - weight_scale_2: float32 global scale, shape (num_logical_weights,)
    - input_scale: float32 activation scale (unused, dynamic quant)
    """

    def create_weights(
        self,
        layer: nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register nothing.

        This method creates no parameters; ``process_weights_after_loading``
        and ``apply`` read tensors that must already exist on ``layer``.
        NOTE(review): presumably the method is attached to layers whose
        parameters were created by another quant method — confirm at the
        call site.
        """
        pass

    def process_weights_after_loading(self, layer: nn.Module) -> None:
        """Fold global scale into block scales and prepare for DeepGEMM consumption.

        Steps, in order:
        1. Return untouched if ``layer.weight`` is not uint8/int8.
        2. Take the block scales from ``weight_scale`` (or, failing that,
           ``weight_scale_inv``).
        3. Pick the global scale: ``weight_global_scale`` if present, else
           ``weight_scale_2``, else 1.0.
        4. Store ``block_scale * global_scale`` as float32 in
           ``layer.weight_scale_inv``, zero-padding the scale columns up to a
           multiple of 16.
        5. Reinterpret a uint8 weight as int8 and make it contiguous.
        6. Delete the scale attributes that are no longer needed.
        """
        w_data = layer.weight.data
        device = w_data.device

        if w_data.dtype not in (torch.uint8, torch.int8):
            return

        N = w_data.shape[0]
        K = w_data.shape[1] * 2  # unpacked K

        # Get block scales
        sf_e4m3 = None
        for attr in ("weight_scale", "weight_scale_inv"):
            if hasattr(layer, attr):
                sf_e4m3 = getattr(layer, attr).data
                break
        assert sf_e4m3 is not None, (
            "NVFP4LinearMethod needs weight_scale or weight_scale_inv on the layer"
        )

        # Get global scale
        if hasattr(layer, "weight_global_scale"):
            global_scale = layer.weight_global_scale.data.to(torch.float32)
        elif hasattr(layer, "weight_scale_2"):
            ws2 = layer.weight_scale_2.data
            if ws2.numel() > 1:
                logical_widths = getattr(layer, 'logical_widths', None)
                if logical_widths is not None and len(ws2) == len(logical_widths):
                    # One global scale per fused logical weight: repeat each
                    # scale over that weight's rows, giving a column vector
                    # that broadcasts against the (rows, blocks) scales.
                    expanded = []
                    for i, w in enumerate(logical_widths):
                        expanded.append(ws2[i:i+1].expand(w))
                    global_scale = torch.cat(expanded).to(torch.float32).unsqueeze(1)
                else:
                    # NOTE(review): several global scales but no usable
                    # logical_widths — the largest one is applied to every
                    # row. Confirm this approximation is acceptable.
                    global_scale = ws2.max().to(torch.float32)
            else:
                global_scale = ws2.max().to(torch.float32)
        else:
            global_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

        # Fold global scale into block scales and store as float32
        # (DeepGEMM fp8_fp4_gemm_nn expects float32 scales, NOT float8_e4m3fn)
        sf_f32 = sf_e4m3.to(torch.float32) * global_scale
        # Pad to align with gran_k=16 for DeepGEMM
        sf_k = sf_f32.shape[1]  # K//16
        gran_k = 16
        aligned_k = (sf_k + gran_k - 1) // gran_k * gran_k
        if aligned_k > sf_k:
            # Pad the scale tensor to be aligned
            sf_padded = torch.zeros(N, aligned_k, dtype=torch.float32, device=device)
            sf_padded[:, :sf_k] = sf_f32
            sf_f32 = sf_padded

        layer.weight_scale_inv = nn.Parameter(sf_f32.contiguous(), requires_grad=False)
        del sf_f32, sf_e4m3

        # Ensure weight is contiguous int8, K-major (required by DeepGEMM)
        if w_data.dtype == torch.uint8:
            layer.weight.data = w_data.view(torch.int8).contiguous()
        else:
            layer.weight.data = w_data.contiguous()

        # Free source attributes
        for attr in ("weight_scale", "weight_scale_2", "input_scale",
                     "weight_global_scale", "input_global_scale",
                     "alpha", "input_global_scale_inv"):
            if hasattr(layer, attr):
                delattr(layer, attr)

    def apply(
        self,
        layer: nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the linear layer output for ``x`` with the NVFP4 weight.

        Args:
            layer: Module carrying ``weight`` and ``weight_scale_inv`` as left
                by ``process_weights_after_loading``.
            x: 2-D activation tensor; anything else fails on shape unpacking.
            bias: Optional bias added to the GEMM result.

        Returns:
            A new bfloat16 tensor of shape ``(x.shape[0], layer.weight.shape[0])``.
        """
        import deep_gemm

        M, _ = x.shape

        # Quantize activation to FP8 with UE8M0 per-token scales
        x_fp8, x_sf = deep_gemm.per_token_cast_to_fp8(
            x, use_ue8m0=True, use_packed_ue8m0=True)

        # Weight: E2M1 packed int8 + folded float32 block scales
        b_weight = layer.weight.data  # (N, K//2) int8
        b_sf = layer.weight_scale_inv.data  # (N, K//16) float32

        N = b_weight.shape[0]
        d = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)

        # DeepGEMM fp8_fp4_gemm: A is FP8 (M, K), B is FP4 (N, K//2 packed)
        # B scales are float32 with gran_k=16 (NVFP4 block size)
        deep_gemm.fp8_fp4_gemm_nn(
            a=(x_fp8, x_sf),
            b=(b_weight, b_sf),
            d=d,
            recipe_b=(1, 16),  # NVFP4: gran_mn=1, gran_k=16
        )

        # The bias used to be accepted and then dropped; apply it so layers
        # configured with a bias produce the expected output.
        if bias is not None:
            d.add_(bias.to(d.dtype))

        return d
|
||||
270
patches/staging_kernel.py
Normal file
270
patches/staging_kernel.py
Normal file
@@ -0,0 +1,270 @@
|
||||
"""
|
||||
NVFP4 staging kernel — full FP4 (E2M1) activations + UE4M3 block16 scales.
|
||||
|
||||
The mxf4nvf4 PTX instruction requires BOTH A and B to be FP4 (E2M1 packed).
|
||||
This kernel quantizes BF16 activations → E2M1 packed uint8 with UE4M3 scales.
|
||||
"""
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import torch
|
||||
|
||||
|
||||
@triton.jit
def _deepseek_v4_stage_mega_moe_inputs_kernel(
    hidden_states,
    x_fp4,  # uint8, shape (M, K//2) — E2M1 packed, 2 values per byte
    x_sf,  # int32, shape (M, K//64) — UE4M3 packed, 4 scales per int32
    topk_ids,
    topk_weights,
    topk_idx_out,
    topk_weights_out,
    hidden_stride_m: tl.constexpr,
    hidden_stride_k: tl.constexpr,
    x_stride_m: tl.constexpr,
    x_stride_k: tl.constexpr,
    x_sf_stride_m: tl.constexpr,
    x_sf_stride_k: tl.constexpr,
    topk_ids_stride_m: tl.constexpr,
    topk_ids_stride_k: tl.constexpr,
    topk_weights_stride_m: tl.constexpr,
    topk_weights_stride_k: tl.constexpr,
    topk_idx_stride_m: tl.constexpr,
    topk_idx_stride_k: tl.constexpr,
    topk_weights_out_stride_m: tl.constexpr,
    topk_weights_out_stride_k: tl.constexpr,
    hidden_size: tl.constexpr,
    top_k: tl.constexpr,
    BLOCK_K: tl.constexpr,  # 128 elements (loaded from hidden)
    GROUP_K: tl.constexpr,  # 16 (NVFP4 group_size)
    BLOCK_TOPK: tl.constexpr,
) -> None:
    """Quantize one (token, K-block) tile of activations to packed FP4.

    Program axis 0 selects the token row and axis 1 the BLOCK_K-wide slice of
    the hidden dimension. For that tile the kernel:
      * computes one scale per GROUP_K elements as amax / 6 and encodes it as
        an E4M3 byte (round-to-nearest-even, subnormals supported),
      * divides each group by the value decoded from that byte, maps every
        element to the nearest E2M1 magnitude, and packs two 4-bit codes per
        output byte (even element in the low nibble),
      * packs the scale bytes four per int32 and stores two int32 per tile,
        which assumes BLOCK_K // GROUP_K == 8.
    The program with k_block_id == 0 also copies the token's top-k ids (cast
    to int64) and top-k weights to the output buffers.
    """
    token_id = tl.program_id(0)
    k_block_id = tl.program_id(1)

    k_offsets = k_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
    k_mask = k_offsets < hidden_size
    hidden = tl.load(
        hidden_states + token_id * hidden_stride_m + k_offsets * hidden_stride_k,
        mask=k_mask,
        other=0.0,
    ).to(tl.float32)

    num_groups: tl.constexpr = BLOCK_K // GROUP_K  # 8 for BLOCK_K=128, GROUP_K=16
    hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K])
    abs_groups = tl.reshape(tl.abs(hidden), [num_groups, GROUP_K])
    amax = tl.max(abs_groups, axis=1)
    amax = tl.maximum(amax, 1.0e-4)

    # ---- UE4M3 scale computation ----
    # scale = amax / 6.0 (E2M1 max value = 6), saturated at 448, the largest
    # finite E4M3 value. Without the cap an oversized scale could be encoded
    # as exponent 15 / mantissa 7 (byte 0x7F), which is NaN in float8_e4m3fn.
    scale = tl.minimum(amax / 6.0, 448.0)
    scale_bits = scale.to(tl.uint32, bitcast=True)
    # Work in signed int32 from here on: the re-biased exponent below is
    # negative for small scales, and in unsigned arithmetic it would wrap to a
    # huge value and be misclassified as a normal number.
    scale_exp = ((scale_bits >> 23) & 0xFF).to(tl.int32)
    scale_mant = (scale_bits & 0x7FFFFF).to(tl.int32)

    # Convert FP32 → E4M3 manually (with subnormal support)
    # FP32 bias=127, E4M3 bias=7 → raw exp = scale_exp - 120
    e4m3_exp_raw = scale_exp - 120  # <= 0 means E4M3 subnormal

    # Normal path: exp >= 1, keep the top 3 mantissa bits.
    # RNE rounding: guard = bit 19, sticky = OR of bits 18:0, plus result LSB.
    normal_mant = scale_mant >> 20
    guard_bit = (scale_mant >> 19) & 1
    sticky_bit = tl.where((scale_mant & 0x7FFFF) != 0, 1, 0)
    result_lsb = normal_mant & 1
    # Round up if (guard and sticky) or (guard and tie and odd LSB).
    round_up = guard_bit & (sticky_bit | result_lsb)
    normal_mant = normal_mant + round_up
    normal_exp = e4m3_exp_raw

    # Subnormal path: exp_raw <= 0.
    # E4M3 subnormal value = (mant/8) * 2^-6, so the 24-bit significand with
    # its implicit leading 1 is shifted right by (1 - exp_raw) + 20.
    # The amax floor keeps the real shift at 10 or below; the clamp to
    # [0, 11] only affects lanes that take the normal path, keeping every
    # shift amount within 0..31 there too.
    shift = tl.minimum(tl.maximum(1 - e4m3_exp_raw, 0), 11)
    mant_with_leading = 0x800000 | scale_mant  # insert implicit 1
    subnormal_mant = (mant_with_leading >> (shift + 20)) & 0x7
    sub_guard_bit = (mant_with_leading >> (shift + 19)) & 1
    # Sticky: OR of all bits below the guard bit.
    sub_sticky_mask = (1 << (shift + 19)) - 1
    sub_sticky_bit = tl.where((mant_with_leading & sub_sticky_mask) != 0, 1, 0)
    sub_result_lsb = subnormal_mant & 1
    sub_round_up = sub_guard_bit & (sub_sticky_bit | sub_result_lsb)
    subnormal_mant = subnormal_mant + sub_round_up

    is_normal = e4m3_exp_raw >= 1
    e4m3_mant = tl.where(is_normal, normal_mant, subnormal_mant)
    e4m3_exp = tl.where(is_normal, normal_exp, 0)  # exp=0 for subnormals

    # Handle mantissa overflow after rounding
    overflow = e4m3_mant >= 8
    e4m3_mant = tl.where(overflow, 0, e4m3_mant)
    e4m3_exp = tl.where(overflow, e4m3_exp + 1, e4m3_exp)
    e4m3_exp = tl.maximum(e4m3_exp, 0)
    e4m3_exp = tl.minimum(e4m3_exp, 15)
    scale_e4m3_bits = (e4m3_exp << 3) | e4m3_mant

    # Reconstruct the dequantized scale by decoding the STORED E4M3 bits, so
    # the E2M1 quantization divides by exactly the value those bits encode.
    stored_exp = (scale_e4m3_bits >> 3) & 0xF
    stored_mant = scale_e4m3_bits & 0x7
    e4m3_exp_for_recon = tl.maximum(stored_exp.to(tl.int32) - 7, -126)
    two_pow_exp_bits = (e4m3_exp_for_recon + 127).to(tl.uint32) << 23
    two_pow_exp = two_pow_exp_bits.to(tl.float32, bitcast=True)
    normal_value = (1.0 + stored_mant.to(tl.float32) / 8.0) * two_pow_exp
    subnormal_value = (stored_mant.to(tl.float32) / 8.0) * 0.015625  # 2^-6
    e4m3_value = tl.where(stored_exp == 0, subnormal_value, normal_value)

    # ---- E2M1 FP4 quantization ----
    # E2M1 LUT (unsigned): [0, 0.5, 1, 1.5, 2, 3, 4, 6]
    # Nearest-neighbor using thresholds (midpoints between consecutive values)
    scaled = hidden_groups * (1.0 / tl.maximum(e4m3_value, 1e-6))[:, None]
    # Clamp to E2M1 range [-6, 6]
    scaled = tl.maximum(scaled, -6.0)
    scaled = tl.minimum(scaled, 6.0)

    abs_s = tl.abs(scaled)
    # Thresholds: [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0]
    e2m1_idx = tl.where(abs_s < 0.25, 0,
               tl.where(abs_s < 0.75, 1,
               tl.where(abs_s < 1.25, 2,
               tl.where(abs_s < 1.75, 3,
               tl.where(abs_s < 2.5, 4,
               tl.where(abs_s < 3.5, 5,
               tl.where(abs_s < 5.0, 6, 7)))))))
    sign_bit = (scaled < 0).to(tl.int32)
    e2m1_4bit = (sign_bit << 3) | e2m1_idx  # 4-bit: (sign << 3) | index

    # Pack E2M1 pairs into single bytes (2 per byte, low nibble first).
    # Triton tensors cannot be sliced with a stride ([0::2] / [1::2]), so view
    # the codes as (BLOCK_K // 2, 2) and split the trailing pair dimension:
    # column 0 holds the even elements, column 1 the odd ones.
    e2m1_pairs = tl.reshape(e2m1_4bit, [BLOCK_K // 2, 2])
    e2m1_lo, e2m1_hi = tl.split(e2m1_pairs)
    e2m1_packed = ((e2m1_hi << 4) | e2m1_lo).to(tl.uint8)  # [BLOCK_K // 2]

    k_offsets_out = k_block_id * (BLOCK_K // 2) + tl.arange(0, BLOCK_K // 2)
    k_mask_out = k_offsets_out < (hidden_size // 2)
    tl.store(
        x_fp4 + token_id * x_stride_m + k_offsets_out * x_stride_k,
        e2m1_packed,
        mask=k_mask_out,
    )

    # Pack UE4M3 bytes into int32: 8 groups per tile → 2 int32, 4 bytes each.
    # The byte position is the group index modulo 4, so every lane uses a
    # shift of 0/8/16/24 (never negative or >= 32), including masked-out ones.
    scale_offsets = tl.arange(0, num_groups)  # [0..7]
    first_half = scale_offsets < 4  # groups 0-3 → int32[0]
    second_half = scale_offsets >= 4  # groups 4-7 → int32[1]
    shifted_scales = scale_e4m3_bits.to(tl.int32) << ((scale_offsets % 4) * 8)

    packed_lo = tl.sum(
        tl.where(first_half, shifted_scales, 0),
        axis=0,
    ).to(tl.int32)
    packed_hi = tl.sum(
        tl.where(second_half, shifted_scales, 0),
        axis=0,
    ).to(tl.int32)

    # Write 2 int32s per k_block: x_sf shape is (M, K//64) = (M, num_k_blocks * 2)
    sf_base = token_id * x_sf_stride_m + k_block_id * 2 * x_sf_stride_k
    tl.store(x_sf + sf_base, packed_lo)
    tl.store(x_sf + sf_base + x_sf_stride_k, packed_hi)

    # Routing metadata is copied once per token, by the first K-block program.
    if k_block_id == 0:
        topk_offsets = tl.arange(0, BLOCK_TOPK)
        topk_mask = topk_offsets < top_k

        ids = tl.load(
            topk_ids + token_id * topk_ids_stride_m + topk_offsets * topk_ids_stride_k,
            mask=topk_mask,
            other=0,
        ).to(tl.int64)
        tl.store(
            topk_idx_out
            + token_id * topk_idx_stride_m
            + topk_offsets * topk_idx_stride_k,
            ids,
            mask=topk_mask,
        )

        weights = tl.load(
            topk_weights
            + token_id * topk_weights_stride_m
            + topk_offsets * topk_weights_stride_k,
            mask=topk_mask,
            other=0.0,
        )
        tl.store(
            topk_weights_out
            + token_id * topk_weights_out_stride_m
            + topk_offsets * topk_weights_out_stride_k,
            weights,
            mask=topk_mask,
        )
|
||||
|
||||
|
||||
def _stage_deepseek_v4_mega_moe_inputs(
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
x_fp4: torch.Tensor, # uint8, shape (M, K//2)
|
||||
x_sf: torch.Tensor, # int32, shape (M, K//64)
|
||||
topk_idx_out: torch.Tensor,
|
||||
topk_weights_out: torch.Tensor,
|
||||
) -> None:
|
||||
num_tokens, hidden_size = hidden_states.shape
|
||||
if num_tokens == 0:
|
||||
return
|
||||
if hidden_size % 128 != 0:
|
||||
raise ValueError(
|
||||
"DeepSeek V4 MegaMoE input staging requires hidden_size to be "
|
||||
"a multiple of 128."
|
||||
)
|
||||
top_k = topk_ids.shape[1]
|
||||
if topk_weights.shape != topk_ids.shape:
|
||||
raise ValueError(
|
||||
"DeepSeek V4 MegaMoE input staging requires topk_weights and "
|
||||
"topk_ids to have the same shape."
|
||||
)
|
||||
|
||||
block_k = 128
|
||||
grid = (num_tokens, triton.cdiv(hidden_size, block_k))
|
||||
block_topk = triton.next_power_of_2(top_k)
|
||||
_deepseek_v4_stage_mega_moe_inputs_kernel[grid](
|
||||
hidden_states,
|
||||
x_fp4,
|
||||
x_sf,
|
||||
topk_ids,
|
||||
topk_weights,
|
||||
topk_idx_out,
|
||||
topk_weights_out,
|
||||
hidden_states.stride(0),
|
||||
hidden_states.stride(1),
|
||||
x_fp4.stride(0),
|
||||
x_fp4.stride(1),
|
||||
x_sf.stride(0),
|
||||
x_sf.stride(1),
|
||||
topk_ids.stride(0),
|
||||
topk_ids.stride(1),
|
||||
topk_weights.stride(0),
|
||||
topk_weights.stride(1),
|
||||
topk_idx_out.stride(0),
|
||||
topk_idx_out.stride(1),
|
||||
topk_weights_out.stride(0),
|
||||
topk_weights_out.stride(1),
|
||||
hidden_size,
|
||||
top_k,
|
||||
BLOCK_K=block_k,
|
||||
GROUP_K=16, # NVFP4: group_size=16 (scale_vec::4X)
|
||||
BLOCK_TOPK=block_topk,
|
||||
num_warps=4,
|
||||
)
|
||||
Reference in New Issue
Block a user