From f2656dcf6d4013adf8ee7e05ab5be36dc3ef4604 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 14 May 2026 14:12:52 +0000 Subject: [PATCH] sync B200 deployment files: Dockerfile, docker-compose, patches --- Dockerfile | 54 ++ docker-compose.yml | 36 +- patches/deepseek_v4.py | 585 +++++++++------ patches/deepseek_v4_attention.py | 1155 ++++++++++++++++++++++++++++++ patches/staging_kernel.py | 270 +++++++ 5 files changed, 1852 insertions(+), 248 deletions(-) create mode 100644 Dockerfile create mode 100644 patches/deepseek_v4_attention.py create mode 100644 patches/staging_kernel.py diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..783a55c --- /dev/null +++ b/Dockerfile @@ -0,0 +1,54 @@ +# DeepSeek V4 NVFP4 vLLM + CUTLASS NVFP4 Mega MoE Kernel +FROM vllm/vllm-openai:nightly-x86_64 + +# Remove broken nixl_ep (built against CUDA 12, image is CUDA 13) +RUN pip uninstall -y nixl-ep; rm -rf /usr/local/lib/python3.12/dist-packages/nixl_ep + +RUN apt-get update && apt-get install -y git screen cmake libcusolver-dev-13-0 libcusparse-dev-13-0 libcublas-dev-13-0 libcurand-dev-13-0 libcufft-dev-13-0 libnvjitlink-dev-13-0 && rm -rf /var/lib/apt/lists/* + +# Remove the broken symlink if it exists +RUN rm -f /usr/local/cuda/lib64/libcudart.so.12 + +ENV CUDA_HOME=/usr/local/cuda +ENV TORCH_CUDA_ARCH_LIST="10.0" + +# Clone latest CUTLASS (has NVFP4 block-scaled MMA support) +ARG CUTLASS_CACHE_BUSTER=1 +RUN git clone --depth 1 https://github.com/NVIDIA/cutlass.git /root/cutlass + +# Clone our NVFP4 mega_moe kernel +ARG KERNEL_CACHE_BUSTER=24 +RUN git clone https://sweetapi.com/biondizzle/nvfp4-megamoe-kernel.git /root/nvfp4-megamoe-kernel && \ + cd /root/nvfp4-megamoe-kernel && \ + pip install -e . 
+ +# Build the CUTLASS NVFP4 block-scaled GEMM extension +RUN cd /root/nvfp4-megamoe-kernel/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm && \ + mkdir -p cutlass_nvfp4_gemm && \ + CUTLASS_INCLUDE_DIR=/root/cutlass/include \ + TORCH_CUDA_ARCH_LIST=10.0 \ + python3 setup.py build_ext --inplace + +# Install TileLang (for potential future use) +RUN pip install tilelang + +ENV PYTHONPATH="/root/nvfp4-megamoe-kernel/src/nvfp4_megamoe_kernel/cutlass_nvfp4_gemm:/root/nvfp4-megamoe-kernel:${PYTHONPATH}" + +# Copy patches +ARG PATCH_CACHE_BUSTER=82 +COPY patches/deepseek_v4.py /tmp/patches/deepseek_v4.py +COPY patches/staging_kernel.py /tmp/patches/staging_kernel.py +COPY patches/deepseek_v4_attention.py /tmp/patches/deepseek_v4_attention.py + +# Apply patches +RUN VLLM_MODELS_DIR=$(python3 -c "import vllm.model_executor.models; import os; print(os.path.dirname(vllm.model_executor.models.__file__))") && \ + VLLM_LAYERS_DIR=$(python3 -c "import vllm.model_executor.layers; import os; print(os.path.dirname(vllm.model_executor.layers.__file__))") && \ + cp /tmp/patches/deepseek_v4.py "$VLLM_MODELS_DIR/deepseek_v4.py" && \ + cp /tmp/patches/staging_kernel.py "$VLLM_MODELS_DIR/staging_kernel.py" && \ + cp /tmp/patches/deepseek_v4_attention.py "$VLLM_LAYERS_DIR/deepseek_v4_attention.py" && \ + rm -rf /tmp/patches + +# Verify +RUN python3 -c "import torch; import cutlass_nvfp4_gemm._C; print('CUTLASS NVFP4 OK')" && \ + python3 -c "import vllm; print('vLLM OK')" && \ + python3 -c "import nvfp4_megamoe_kernel; print('NVFP4 kernel OK')" diff --git a/docker-compose.yml b/docker-compose.yml index 886c56f..cd5cb0c 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,30 +1,22 @@ services: vllm: - image: atl.vultrcr.com/vllm/vllm-with-lmcache:dream-build - pull_policy: always - entrypoint: - - bash - - -c - - | - cp /patches/deepseek_v4.py /usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v4.py - exec vllm serve "$$@" - - -- + build: + context: . 
+ ports: + - "8000:8000" environment: - - HF_TOKEN=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 + - OMP_NUM_THREADS=128 + - MEGA_MOE_DEBUG=1 + - MEGA_MOE_STATIC=0 + - MEGA_MOE_USE_CUTLASS=1 + - DG_JIT_DEBUG=1 command: - /model - --trust-remote-code - - --kv-cache-dtype=fp8 - - --block-size=256 - --enable-expert-parallel - --tensor-parallel-size=8 - - --compilation-config={"cudagraph_mode":"FULL_AND_PIECEWISE","custom_ops":["all"]} - - --attention_config.use_fp4_indexer_cache=True + - --enforce-eager - --tokenizer-mode=deepseek_v4 - - --tool-call-parser=deepseek_v4 - - --enable-auto-tool-choice - - --reasoning-parser=deepseek_v4 - - --speculative_config={"method":"mtp","num_speculative_tokens":2} - --host=0.0.0.0 - --port=8000 deploy: @@ -34,13 +26,5 @@ services: - driver: nvidia count: all capabilities: [gpu] - ipc: host - security_opt: - - seccomp:unconfined - tty: true - stdin_open: true volumes: - /root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4:/model:ro - - /root/nvidia-meeting/deepseek-v4-quant/patches/deepseek_v4.py:/patches/deepseek_v4.py:ro - - /root/nvidia-meeting/deepseek-v4-quant/patches:/patches:ro - network_mode: host diff --git a/patches/deepseek_v4.py b/patches/deepseek_v4.py index b4c7126..f477ff7 100644 --- a/patches/deepseek_v4.py +++ b/patches/deepseek_v4.py @@ -1,40 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# ============================================================================== -# DeepSeek V4 NVFP4 Patch — Version Banner (printed at import time) -# ============================================================================== -import datetime as _dt -import os as _os -_git_commit = _os.popen("git -C /root/nvidia-meeting/deepseek-v4-quant rev-parse --short HEAD 2>/dev/null || echo 'unknown'").read().strip() -print(f""" -{'='*70} - DeepSeek V4 NVFP4 Patch - {'='*70} - Commit: {_git_commit} - Loaded: {_dt.datetime.now().strftime('%Y-%m-%d %H:%M:%S UTC')} - Node: 
{_os.uname().nodename} - - Architecture: - wo_a → FP8 + DeepGEMM block scale (BMM einsum) - wq_b/wo_b → BF16 (UnquantizedLinearMethod) - fused_wqa → BF16 (stacked q_a + kv, dequantized from NVFP4) - compressor → BF16 (reconstructed from separate kv_proj+gate_proj) - shared_exp → FP8 (Fp8LinearMethod, DeepGEMM) - MoE experts → NVFP4 (FusedMoE, FLASHINFER_TRTLLM) — NOT converted - - Bugs fixed: - #1 DeepGEMM sf.dim() — block scale format (deepgemm_post_process) - #2 fused_skip_regex — q_b/o_a/o_b scales no longer skipped - #3 input_scale — removed from weight dequant (activations only) - #4 compressor indexer — sub_path for .indexer keys - #5 block scale dtype — must be float32, not float8_e4m3fn - #6 block scale values — torch.full(fp8_scale) not torch.ones - #7 UE8M0 block scale — .to(float32) misinterprets E8M0 as E4M3 -{'='*70} -""") -# ============================================================================== - import typing from collections.abc import Callable, Iterable from itertools import islice @@ -185,8 +151,9 @@ class DeepseekV4FP8Config(Fp8Config): try: hf_config = get_current_vllm_config().model_config.hf_config except Exception: - # vllm_config not yet set; defer the decision until a - # later call lands inside set_current_vllm_config. + # vllm_config not yet set; return safe default but do NOT + # cache — a later call inside set_current_vllm_config may + # resolve differently. return "fp4" expert_dtype = getattr(hf_config, "expert_dtype", "fp4") if expert_dtype not in _DEEPSEEK_V4_EXPERT_DTYPES: @@ -195,11 +162,6 @@ class DeepseekV4FP8Config(Fp8Config): f"expected one of {_DEEPSEEK_V4_EXPERT_DTYPES}." 
) self._resolved_expert_dtype = expert_dtype - from vllm.logger import init_logger - - init_logger(__name__).info_once( - "DeepSeek V4 expert_dtype resolved to %r", expert_dtype - ) return self._resolved_expert_dtype @property @@ -244,11 +206,23 @@ class DeepseekV4FP8Config(Fp8Config): return isinstance(layer, FusedMoE) and self.expert_dtype == "fp4" +import triton +import triton.language as tl +import torch + +""" +NVFP4 staging kernel — full FP4 (E2M1) activations + UE4M3 block16 scales. + +The mxf4nvf4 PTX instruction requires BOTH A and B to be FP4 (E2M1 packed). +This kernel quantizes BF16 activations → E2M1 packed uint8 with UE4M3 scales. +""" + + @triton.jit def _deepseek_v4_stage_mega_moe_inputs_kernel( hidden_states, - x_fp8, - x_sf, + x_fp4, # uint8, shape (M, K//2) — E2M1 packed, 2 values per byte + x_sf, # int32, shape (M, K//64) — UE4M3 packed, 4 scales per int32 topk_ids, topk_weights, topk_idx_out, @@ -269,8 +243,8 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel( topk_weights_out_stride_k: tl.constexpr, hidden_size: tl.constexpr, top_k: tl.constexpr, - BLOCK_K: tl.constexpr, - GROUP_K: tl.constexpr, + BLOCK_K: tl.constexpr, # 128 elements (loaded from hidden) + GROUP_K: tl.constexpr, # 16 (NVFP4 group_size) BLOCK_TOPK: tl.constexpr, ) -> None: token_id = tl.program_id(0) @@ -284,35 +258,94 @@ def _deepseek_v4_stage_mega_moe_inputs_kernel( other=0.0, ).to(tl.float32) - num_groups: tl.constexpr = BLOCK_K // GROUP_K - hidden_groups = tl.reshape(tl.abs(hidden), [num_groups, GROUP_K]) - amax = tl.max(hidden_groups, axis=1) + num_groups: tl.constexpr = BLOCK_K // GROUP_K # 8 + hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K]) + abs_groups = tl.reshape(tl.abs(hidden), [num_groups, GROUP_K]) + amax = tl.max(abs_groups, axis=1) amax = tl.maximum(amax, 1.0e-4) - scale = amax / 448.0 + # ---- UE4M3 scale computation ---- + # scale = amax / 6.0 (E2M1 max value = 6) + # Then quantize scale to UE4M3 format + scale = amax / 6.0 scale_bits = 
scale.to(tl.uint32, bitcast=True) - scale_exp = ((scale_bits >> 23) & 0xFF) + ((scale_bits & 0x7FFFFF) != 0).to( - tl.uint32 - ) - scale_exp = tl.minimum(tl.maximum(scale_exp, 1), 254) - rounded_scale = (scale_exp << 23).to(tl.float32, bitcast=True) + scale_exp = (scale_bits >> 23) & 0xFF + scale_mant = scale_bits & 0x7FFFFF - hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K]) - scaled = hidden_groups * (1.0 / rounded_scale)[:, None] - scaled = tl.reshape(scaled, [BLOCK_K]) - fp8 = scaled.to(tl.float8e4nv) + # Convert FP32 → E4M3 manually + e4m3_exp = scale_exp - 120 # FP32 bias=127, E4M3 bias=7 + e4m3_exp = tl.maximum(e4m3_exp, 0) + e4m3_exp = tl.minimum(e4m3_exp, 15) + e4m3_mant = scale_mant >> 20 + round_bit = (scale_mant >> 19) & 1 + e4m3_mant = e4m3_mant + round_bit + overflow = e4m3_mant >= 8 + e4m3_mant = tl.where(overflow, 0, e4m3_mant) + e4m3_exp = tl.where(overflow, e4m3_exp + 1, e4m3_exp) + e4m3_exp = tl.minimum(e4m3_exp, 15) + scale_e4m3_bits = (e4m3_exp << 3) | e4m3_mant + + # Reconstruct dequantized scale for E2M1 quantization + e4m3_exp_for_recon = tl.maximum(e4m3_exp.to(tl.int32) - 7, -126) + two_pow_exp_bits = (e4m3_exp_for_recon + 127).to(tl.uint32) << 23 + two_pow_exp = two_pow_exp_bits.to(tl.float32, bitcast=True) + normal_value = (1.0 + e4m3_mant.to(tl.float32) / 8.0) * two_pow_exp + subnormal_value = (e4m3_mant.to(tl.float32) / 8.0) * 0.015625 + e4m3_value = tl.where(e4m3_exp == 0, subnormal_value, normal_value) + + # ---- E2M1 FP4 quantization ---- + # E2M1 LUT (unsigned): [0, 0.5, 1, 1.5, 2, 3, 4, 6] + # Nearest-neighbor using thresholds (midpoints between consecutive values) + scaled = hidden_groups * (1.0 / tl.maximum(e4m3_value, 1e-6))[:, None] + # Clamp to E2M1 range [-6, 6] + scaled = tl.maximum(scaled, -6.0) + scaled = tl.minimum(scaled, 6.0) + + abs_s = tl.abs(scaled) + # E2M1 quantization using arithmetic instead of nested tl.where (Triton compile error) + # LUT: [0, 0.5, 1, 1.5, 2, 3, 4, 6] → thresholds at midpoints + # idx = 
sum(abs_s >= threshold_i) for thresholds [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0] + e2m1_idx = ((abs_s >= 0.25).to(tl.int32) + (abs_s >= 0.75).to(tl.int32) + + (abs_s >= 1.25).to(tl.int32) + (abs_s >= 1.75).to(tl.int32) + + (abs_s >= 2.5).to(tl.int32) + (abs_s >= 3.5).to(tl.int32) + + (abs_s >= 5.0).to(tl.int32)) + sign_bit = (scaled < 0).to(tl.int32) + e2m1_4bit = (sign_bit << 3) | e2m1_idx # 4-bit: (sign << 3) | index + + # Pack 2 E2M1 values per byte: even→low nibble, odd→high nibble + PACKED_K: tl.constexpr = BLOCK_K // 2 # 64 + e2m1_pairs = tl.reshape(e2m1_4bit, [PACKED_K, 2]) + even, odd = tl.split(e2m1_pairs) # splits last axis (size 2) into two [PACKED_K] tensors + packed_byte = (odd.to(tl.uint8) << 4) | even.to(tl.uint8) + + packed_k_offsets = k_block_id * PACKED_K + tl.arange(0, PACKED_K) + packed_k_mask = packed_k_offsets < (hidden_size // 2) tl.store( - x_fp8 + token_id * x_stride_m + k_offsets * x_stride_k, - fp8, - mask=k_mask, + x_fp4 + token_id * x_stride_m + packed_k_offsets * x_stride_k, + packed_byte, + mask=packed_k_mask, ) - scale_offsets = tl.arange(0, num_groups) - packed_scale = tl.sum(scale_exp << (scale_offsets * 8), axis=0).to(tl.int32) - tl.store( - x_sf + token_id * x_sf_stride_m + k_block_id * x_sf_stride_k, - packed_scale, - ) + # Pack UE4M3 bytes into int32 (NVFP4: group_size=16, 4 groups per 64 elements) + # 8 groups per k_block of 128 → 2 int32s per k_block + # int32 can only pack 4 bytes (shifts >= 32 are UB on GPU), so split into two packs + scale_offsets = tl.arange(0, num_groups) # [0..7] + first_half = scale_offsets < 4 # groups 0-3 → int32[0] + second_half = scale_offsets >= 4 # groups 4-7 → int32[1] + + packed_lo = tl.sum( + tl.where(first_half, scale_e4m3_bits.to(tl.int32) << (scale_offsets * 8), 0), + axis=0, + ).to(tl.int32) + packed_hi = tl.sum( + tl.where(second_half, scale_e4m3_bits.to(tl.int32) << ((scale_offsets - 4) * 8), 0), + axis=0, + ).to(tl.int32) + + # Write 2 int32s per k_block: x_sf shape is (M, K//64) = (M, 
num_k_blocks * 2) + sf_base = token_id * x_sf_stride_m + k_block_id * 2 * x_sf_stride_k + tl.store(x_sf + sf_base, packed_lo) + tl.store(x_sf + sf_base + x_sf_stride_k, packed_hi) if k_block_id == 0: topk_offsets = tl.arange(0, BLOCK_TOPK) @@ -351,8 +384,8 @@ def _stage_deepseek_v4_mega_moe_inputs( hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - x_fp8: torch.Tensor, - x_sf: torch.Tensor, + x_fp4: torch.Tensor, # uint8, shape (M, K//2) + x_sf: torch.Tensor, # int32, shape (M, K//64) topk_idx_out: torch.Tensor, topk_weights_out: torch.Tensor, ) -> None: @@ -376,7 +409,7 @@ def _stage_deepseek_v4_mega_moe_inputs( block_topk = triton.next_power_of_2(top_k) _deepseek_v4_stage_mega_moe_inputs_kernel[grid]( hidden_states, - x_fp8, + x_fp4, x_sf, topk_ids, topk_weights, @@ -384,8 +417,8 @@ def _stage_deepseek_v4_mega_moe_inputs( topk_weights_out, hidden_states.stride(0), hidden_states.stride(1), - x_fp8.stride(0), - x_fp8.stride(1), + x_fp4.stride(0), + x_fp4.stride(1), x_sf.stride(0), x_sf.stride(1), topk_ids.stride(0), @@ -399,7 +432,7 @@ def _stage_deepseek_v4_mega_moe_inputs( hidden_size, top_k, BLOCK_K=block_k, - GROUP_K=32, + GROUP_K=16, # NVFP4: group_size=16 (scale_vec::4X) BLOCK_TOPK=block_topk, num_warps=4, ) @@ -425,8 +458,21 @@ def make_deepseek_v4_expert_params_mapping( class DeepseekV4MegaMoEExperts(nn.Module): + """MegaMoE experts for DeepSeek V4 with NVFP4 quantization. + + Loads NVFP4 expert weights (E2M1 packed uint8 + float8_e4m3fn block scales + + float32 global scales) and feeds them natively to the DeepGEMM + fp8_nvfp4_mega_moe kernel (kind::mxf4nvf4.scale_vec::4X). + + No conversion to MXFP4. Experts stay NVFP4. The global scale (weight_scale_2) + is folded into the block scales before kernel consumption. 
+ """ _symm_buffer_cache: dict[tuple[int, int, int, int, int, int, int], object] = {} + # NVFP4 E2M1 lookup table (positive values, sign from bit 3) + E2M1_LUT = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0] + # MXFP4 E2M1 is the same format + def __init__( self, vllm_config: VllmConfig, @@ -451,52 +497,83 @@ class DeepseekV4MegaMoEExperts(nn.Module): self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens weight_attrs = {"weight_loader": self.weight_loader} + + # NVFP4 weights: E2M1 packed as uint8, 2 values per byte self.w13_weight = nn.Parameter( torch.zeros( num_local_experts, 2 * intermediate_size, hidden_size // 2, - dtype=torch.uint8, + dtype=torch.int8, ), requires_grad=False, ) set_weight_attrs(self.w13_weight, weight_attrs) + # NVFP4 block scales: float8_e4m3fn, group_size=16 + # Shape: [num_local_experts, 2*intermediate_size, hidden_size // 16] self.w13_weight_scale = nn.Parameter( torch.zeros( num_local_experts, 2 * intermediate_size, - hidden_size // 32, - dtype=torch.uint8, + hidden_size // 16, + dtype=torch.float8_e4m3fn, ), requires_grad=False, ) set_weight_attrs(self.w13_weight_scale, weight_attrs) self.w13_weight_scale.quant_method = "block" + # NVFP4 global scales: float32, per-expert + self.w13_weight_scale_2 = nn.Parameter( + torch.zeros(num_local_experts, dtype=torch.float32), + requires_grad=False, + ) + set_weight_attrs(self.w13_weight_scale_2, weight_attrs) + + # NVFP4 activation scales: float32, per-expert + self.w13_input_scale = nn.Parameter( + torch.zeros(num_local_experts, dtype=torch.float32), + requires_grad=False, + ) + set_weight_attrs(self.w13_input_scale, weight_attrs) + self.w2_weight = nn.Parameter( torch.zeros( num_local_experts, hidden_size, intermediate_size // 2, - dtype=torch.uint8, + dtype=torch.int8, ), requires_grad=False, ) set_weight_attrs(self.w2_weight, weight_attrs) + # NVFP4 block scales for w2 self.w2_weight_scale = nn.Parameter( torch.zeros( num_local_experts, hidden_size, - intermediate_size // 32, 
- dtype=torch.uint8, + intermediate_size // 16, + dtype=torch.float8_e4m3fn, ), requires_grad=False, ) set_weight_attrs(self.w2_weight_scale, weight_attrs) self.w2_weight_scale.quant_method = "block" + self.w2_weight_scale_2 = nn.Parameter( + torch.zeros(num_local_experts, dtype=torch.float32), + requires_grad=False, + ) + set_weight_attrs(self.w2_weight_scale_2, weight_attrs) + + self.w2_input_scale = nn.Parameter( + torch.zeros(num_local_experts, dtype=torch.float32), + requires_grad=False, + ) + set_weight_attrs(self.w2_input_scale, weight_attrs) + self._transformed_l1_weights: tuple[torch.Tensor, torch.Tensor] | None = None self._transformed_l2_weights: tuple[torch.Tensor, torch.Tensor] | None = None @@ -519,21 +596,25 @@ class DeepseekV4MegaMoEExperts(nn.Module): weight_name: str, shard_id: str, expert_id: int, - return_success: bool = False, - ) -> bool | None: + ) -> bool: local_expert_id = self._map_global_expert_id(expert_id) if local_expert_id == -1: - return False if return_success else None + return False + + # Scalar params (weight_scale_2, input_scale): 1D per-expert + if "weight_scale_2" in weight_name or "input_scale" in weight_name: + param.data[local_expert_id].copy_(loaded_weight) + return True expert_data = param.data[local_expert_id] if shard_id in ("w1", "w3"): if "w13_" not in weight_name: - return False if return_success else None + return False shard_offset = 0 if shard_id == "w1" else self.intermediate_size expert_data = expert_data.narrow(0, shard_offset, self.intermediate_size) elif shard_id == "w2": if "w2_" not in weight_name: - return False if return_success else None + return False else: raise ValueError(f"Unsupported expert shard id: {shard_id}") @@ -544,11 +625,7 @@ class DeepseekV4MegaMoEExperts(nn.Module): f"vs checkpoint {tuple(loaded_weight.shape)}" ) expert_data.copy_(loaded_weight) - return True if return_success else None - - @staticmethod - def _ue8m0_uint8_to_float(sf: torch.Tensor) -> torch.Tensor: - return 
(sf.to(torch.int32) << 23).view(torch.float32) + return True def _check_runtime_supported(self) -> None: if not torch.cuda.is_available(): @@ -558,7 +635,7 @@ class DeepseekV4MegaMoEExperts(nn.Module): raise NotImplementedError( "DeepSeek V4 MegaMoE expert weights must be loaded on CUDA." ) - if torch.cuda.get_device_capability(device)[0] != 10: + if torch.cuda.get_device_capability(device)[0] < 10: raise NotImplementedError("DeepGEMM MegaMoE requires SM100 GPUs.") if self.hidden_size % 128 != 0 or self.intermediate_size % 128 != 0: raise ValueError( @@ -571,41 +648,51 @@ class DeepseekV4MegaMoEExperts(nn.Module): return self._check_runtime_supported() - import vllm.third_party.deep_gemm as deep_gemm + from nvfp4_megamoe_kernel import transform_nvfp4_weights_for_mega_moe - w13_scale = deep_gemm.transform_sf_into_required_layout( - self._ue8m0_uint8_to_float(self.w13_weight_scale.data).contiguous(), - 2 * self.intermediate_size, - self.hidden_size, - (1, 32), - self.num_local_experts, - ) - w2_scale = deep_gemm.transform_sf_into_required_layout( - self._ue8m0_uint8_to_float(self.w2_weight_scale.data).contiguous(), - self.hidden_size, - self.intermediate_size, - (1, 32), - self.num_local_experts, - ) + # === Native NVFP4 path === + # The DeepGEMM nvfp4 mega_moe kernel consumes NVFP4 directly: + # - E2M1 packed uint8 (same as checkpoint) + # - UE4M3 block scales (float8_e4m3fn), group_size=16 + # - float32 global scale folded into block scales + # No conversion to MXFP4. Experts stay NVFP4. 
+ + # Fold global scales into block scales and transform for the kernel self._transformed_l1_weights, self._transformed_l2_weights = ( - deep_gemm.transform_weights_for_mega_moe( - (self.w13_weight.data.view(torch.int8).contiguous(), w13_scale), - (self.w2_weight.data.view(torch.int8).contiguous(), w2_scale), + transform_nvfp4_weights_for_mega_moe( + (self.w13_weight.data.contiguous(), + self.w13_weight_scale.data.contiguous()), + (self.w2_weight.data.contiguous(), + self.w2_weight_scale.data.contiguous()), + l1_weight_scale_2=self.w13_weight_scale_2.data.contiguous(), + l2_weight_scale_2=self.w2_weight_scale_2.data.contiguous(), ) ) - # Drop the original loader-side parameters: the MegaMoE kernels only - # consume the transformed views above. transform_weights_for_mega_moe - # allocates a fresh tensor for the L1 weight (see _interleave_l1_weights) - # and fresh SF tensors for L1/L2; the L2 weight is the only tensor that - # aliases the original storage, and _transformed_l2_weights still holds - # it, so the storage stays live after we drop the Parameter. + + # Drop the original loader-side parameters self.w13_weight = None self.w13_weight_scale = None + self.w13_weight_scale_2 = None + self.w13_input_scale = None self.w2_weight = None self.w2_weight_scale = None + self.w2_weight_scale_2 = None + self.w2_input_scale = None + + @staticmethod + def _ue8m0_to_float32(sf: torch.Tensor) -> torch.Tensor: + """Convert NVFP4 block scales (float8_e4m3fn / UE4M3) to float32. + + Checkpoint stores float8_e4m3fn (standard NVFP4 spec, NOT UE8M0). + Simple .to(float32) is correct — shift-by-23 was wrong (Bug #7 fix). 
+ """ + return sf.to(torch.float32) + + def get_symm_buffer(self): - import vllm.third_party.deep_gemm as deep_gemm + import nvfp4_megamoe_kernel as deep_gemm + from nvfp4_megamoe_kernel import SymmBuffer, get_symm_buffer_for_nvfp4_mega_moe group = get_ep_group().device_group device = torch.accelerator.current_device_index() @@ -620,7 +707,8 @@ class DeepseekV4MegaMoEExperts(nn.Module): ) symm_buffer = self._symm_buffer_cache.get(key) if symm_buffer is None: - symm_buffer = deep_gemm.get_symm_buffer_for_mega_moe( + # NVFP4 SymmBuffer: 2x SF size due to group_size=16 + symm_buffer = get_symm_buffer_for_nvfp4_mega_moe( group, self.num_experts, self.max_num_tokens, @@ -666,7 +754,7 @@ class DeepseekV4MegaMoEExperts(nn.Module): activation_clamp: float | None, fast_math: bool, ) -> None: - import vllm.third_party.deep_gemm as deep_gemm + import nvfp4_megamoe_kernel as deep_gemm symm_buffer = self.get_symm_buffer() num_tokens = hidden_states.shape[0] @@ -680,13 +768,57 @@ class DeepseekV4MegaMoEExperts(nn.Module): symm_buffer.topk_weights[:num_tokens], ) + # Debug: check staging output + import os + if int(os.environ.get('MEGA_MOE_DEBUG', '0')): + print(f"[MEGA_MOE_DEBUG] After staging: x dtype={symm_buffer.x.dtype} shape={symm_buffer.x.shape}") + print(f"[MEGA_MOE_DEBUG] x_sf dtype={symm_buffer.x_sf.dtype} shape={symm_buffer.x_sf.shape}") + print(f"[MEGA_MOE_DEBUG] topk_idx dtype={symm_buffer.topk_idx.dtype} shape={symm_buffer.topk_idx.shape}") + print(f"[MEGA_MOE_DEBUG] topk_weights dtype={symm_buffer.topk_weights.dtype} shape={symm_buffer.topk_weights.shape}") + # Check for NaN/Inf in the staging output + x_sample = symm_buffer.x[:num_tokens] + sf_sample = symm_buffer.x_sf[:num_tokens] + print(f"[MEGA_MOE_DEBUG] x range: min={x_sample.min().item()} max={x_sample.max().item()}") + if sf_sample.numel() > 0: + print(f"[MEGA_MOE_DEBUG] x_sf range: min={sf_sample.to(torch.float32).min().item()} max={sf_sample.to(torch.float32).max().item()}") + topk_sample = 
symm_buffer.topk_idx[:num_tokens] + print(f"[MEGA_MOE_DEBUG] topk_idx range: min={topk_sample.min().item()} max={topk_sample.max().item()}") + torch.cuda.synchronize() + print("[MEGA_MOE_DEBUG] Staging CUDA sync OK") + # This method must have been already called during the weight loading phase. # We call it again here to cover the dummy weight loading case. self.finalize_weights() assert self._transformed_l1_weights is not None assert self._transformed_l2_weights is not None - deep_gemm.fp8_fp4_mega_moe( + from nvfp4_megamoe_kernel import nvfp4_mega_moe_full as fp8_nvfp4_mega_moe + + # Debug: dump shapes before mega_moe + import os + if int(os.environ.get('MEGA_MOE_DEBUG', '0')): + l1_w, l1_sf = self._transformed_l1_weights + l2_w, l2_sf = self._transformed_l2_weights + print(f"[MEGA_MOE_DEBUG] num_tokens={num_tokens}, hidden={hidden_states.shape[1]}") + print(f"[MEGA_MOE_DEBUG] l1_w: dtype={l1_w.dtype} shape={l1_w.shape} stride={l1_w.stride()}") + print(f"[MEGA_MOE_DEBUG] l1_sf: dtype={l1_sf.dtype} shape={l1_sf.shape} stride={l1_sf.stride()}") + print(f"[MEGA_MOE_DEBUG] l2_w: dtype={l2_w.dtype} shape={l2_w.shape} stride={l2_w.stride()}") + print(f"[MEGA_MOE_DEBUG] l2_sf: dtype={l2_sf.dtype} shape={l2_sf.shape} stride={l2_sf.stride()}") + print(f"[MEGA_MOE_DEBUG] symm_buffer nbytes={symm_buffer.buffer.nbytes} rank={symm_buffer.group.rank()}") + print(f"[MEGA_MOE_DEBUG] num_experts={self.num_experts} topk={topk_ids.shape[1]} max_tokens={self.max_num_tokens}") + print(f"[MEGA_MOE_DEBUG] y: dtype={y.dtype} shape={y.shape}") + # Force CUDA sync to catch any prior async errors + torch.cuda.synchronize() + print("[MEGA_MOE_DEBUG] CUDA sync OK — prior ops clean") + + # MEGA_MOE_STATIC: skip the kernel entirely, return zeros + # Tests whether the crash is in the kernel launch or in prior data prep + if int(os.environ.get('MEGA_MOE_STATIC', '0')): + print(f"[MEGA_MOE_STATIC] Skipping fp8_nvfp4_mega_moe, returning zeros") + y.zero_() + return + + fp8_nvfp4_mega_moe( y, 
self._transformed_l1_weights, self._transformed_l2_weights, @@ -694,6 +826,8 @@ class DeepseekV4MegaMoEExperts(nn.Module): activation_clamp=activation_clamp, fast_math=fast_math, ) + if os.environ.get('NVFP4_DEBUG_SYNC', '') == '1': + torch.cuda.synchronize() DeepseekV4MegaMoEExperts.weight_loader.supports_moe_loading = True # type: ignore[attr-defined] @@ -751,9 +885,7 @@ class DeepseekV4MoE(nn.Module): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.prefix = prefix - self.use_mega_moe = ( - vllm_config.kernel_config.moe_backend == "deep_gemm_mega_moe" - ) + self.use_mega_moe = True # Force mega_moe for NVFP4 pipeline if self.use_mega_moe and not vllm_config.parallel_config.enable_expert_parallel: raise NotImplementedError( "DeepSeek V4 MegaMoE currently requires expert parallel. " @@ -774,12 +906,7 @@ class DeepseekV4MoE(nn.Module): raise NotImplementedError( "DeepSeek V4 MegaMoE currently supports sqrtsoftplus routing only." ) - if self.use_mega_moe and getattr(config, "expert_dtype", "fp4") != "fp4": - raise NotImplementedError( - "DeepSeek V4 MegaMoE only supports fp4 experts; got expert_dtype=" - f"{config.expert_dtype!r}. Drop --kernel-config moe_backend=" - "deep_gemm_mega_moe for this checkpoint." 
- ) + # NVFP4 experts are consumed natively by mega_moe; finalize_weights only folds the global scale into the block scales (no MXFP4 conversion) self.gate = GateLinear( config.hidden_size, @@ -1045,7 +1172,7 @@ class DeepseekV4Attention(nn.Module): self.rope_parameters = config.rope_scaling # Initialize rotary embedding BEFORE DeepseekV4MLAModules (which needs it) - rope_parameters = config.rope_parameters + rope_parameters = dict(config.rope_parameters) rope_parameters["rope_theta"] = ( config.compress_rope_theta if self.compress_ratio > 1 else config.rope_theta ) @@ -1236,6 +1363,15 @@ class DeepseekV4DecoderLayer(nn.Module): positions: torch.Tensor, input_ids: torch.Tensor | None, ) -> torch.Tensor: + # DEBUG: skip attention entirely, just run FFN on raw input + if int(os.environ.get('SKIP_ATTENTION', '0')): + # Flatten to 2D for ffn, then restore + org_shape = x.shape + x_2d = x.view(-1, x.shape[-1]) + x_2d = self.ffn_norm(x_2d) + x_2d = self.ffn(x_2d, input_ids) + return x_2d.view(org_shape) + residual = x x, post, comb = self.hc_pre( x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base @@ -1262,9 +1398,7 @@ class DeepseekV4Model(nn.Module): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config - self.use_mega_moe = ( - vllm_config.kernel_config.moe_backend == "deep_gemm_mega_moe" - ) + self.use_mega_moe = True # Force mega_moe for NVFP4 pipeline if self.use_mega_moe and not vllm_config.parallel_config.enable_expert_parallel: raise NotImplementedError( "DeepSeek V4 MegaMoE currently requires expert parallel. " @@ -1461,14 +1595,19 @@ class DeepseekV4Model(nn.Module): else: if ".experts." in name: # E8M0 scales are stored as float8_e8m0fnu in - # checkpoints but the MoE param is uint8. copy_() - # would do a numeric conversion (e.g. 2^-7 → 0), - # destroying the raw exponent bytes. + # MXFP4 checkpoints but NVFP4 uses float8_e4m3fn. 
+ # The uint8 view+copy path is only valid for MXFP4; + # for NVFP4 it would paste raw E8M0 bytes into an + # E4M3 buffer, producing garbage. if ( "weight_scale" in name and loaded_weight.dtype == torch.float8_e8m0fnu ): - loaded_weight = loaded_weight.view(torch.uint8) + raise ValueError( + f"E8M0 weight_scale encountered for NVFP4 experts " + f"({name}) — this is only valid for MXFP4. " + f"Check checkpoint dtype." + ) for mapping in expert_mapping: param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: @@ -1489,7 +1628,6 @@ class DeepseekV4Model(nn.Module): name_mapped, shard_id=shard_id, expert_id=expert_id, - return_success=True, ) if success: name = name_mapped @@ -1537,16 +1675,10 @@ class DeepseekV4Model(nn.Module): weight_scale_2_val = global_amax / (6.0 * 448.0) weight_scale_2 = weight_scale_2_val.to(torch.float32) - # Per-block scale (weight_scale): UE8M0 format - # scale_fmt=ue8m0: block_scale = 2^(exp-127), stored as - # uint8 exponent byte viewed as float8_e4m3fn + # Per-block scale (weight_scale): UE4M3 format (standard NVFP4) + # block_scale = amax / (6.0 * weight_scale_2) block_scale = amax / (6.0 * weight_scale_2_val) - # Convert to UE8M0: floor to nearest power of 2 - # UE8M0 exponent = floor(log2(block_scale)) + 127 - block_scale_clamped = block_scale.clamp(min=2**-127) - block_scale_exp = torch.floor(torch.log2(block_scale_clamped)).to(torch.int32) + 127 - block_scale_exp = block_scale_exp.clamp(0, 254).to(torch.uint8) - weight_scale = block_scale_exp.view(torch.float8_e4m3fn) + weight_scale = block_scale.clamp(0.0, 448.0).to(torch.float8_e4m3fn) # Quantize to FP4 (E2M1) # E2M1 LUT: 0, 0.5, 1, 1.5, 2, 3, 4, 6 (positive) @@ -1554,10 +1686,8 @@ class DeepseekV4Model(nn.Module): [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32, device=w_bf16.device, ) - # For each block, dequantize the block scale from UE8M0 - block_scale_f32 = (block_scale_exp.to(torch.int32) << 23).view(torch.float32) # Scale the weight 
values: normalized = w / (block_scale * weight_scale_2) - # We need to find the nearest FP4 value + block_scale_f32 = block_scale.clamp(0.0, 448.0) scaled = w_blocks / (block_scale_f32.unsqueeze(-1) * weight_scale_2_val) # Find nearest FP4 index (0-7 for magnitude) # Use absolute value for matching, then apply sign @@ -1575,7 +1705,7 @@ class DeepseekV4Model(nn.Module): even = fp4_flat[:, 0::2] # lower nibble odd = fp4_flat[:, 1::2] # upper nibble packed = (odd << 4) | even - weight_packed = packed.to(torch.uint8) + weight_packed = packed.to(torch.uint8).view(torch.int8) # Reshape weight_scale to [out, n_blocks] weight_scale_2d = weight_scale.reshape(out_dim, n_blocks) @@ -1647,8 +1777,8 @@ class DeepseekV4Model(nn.Module): .forward() which goes through quant_method; FP8 would dtype-mismatch) - compressor.fused_wkv_wgate: Dequant NVFP4->bf16 (used via direct torch.mm in attention parallel stream) - - shared_experts (gate_up_proj, down_proj): Dequant NVFP4->bf16 - - MoE experts: Stay in native NVFP4 (ModelOptNvFp4FusedMoE) + - shared_experts (gate_up_proj, down_proj): Stay native NVFP4 via DeepGEMM fp8_fp4_gemm + - MoE experts: Handled by DeepseekV4MegaMoEExperts (NVFP4→MXFP4) """ E2M1_LUT = torch.tensor( [0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16 @@ -1659,14 +1789,19 @@ class DeepseekV4Model(nn.Module): # for fp8_einsum. Only layer that needs FP8 conversion. fp8_proj_names = {"wo_a"} # Attention layers called via .forward() — need bf16 - bf16_proj_names = {"fused_wqa_wkv", "wq_b", "wo_b"} - # Shared expert layers called via .forward() — need bf16 - bf16_shared_names = {"gate_up_proj", "down_proj"} + # cuBLAS BF16 is broken on Blackwell — nothing gets dequantized to BF16. + # Everything stays native NVFP4/FP8 via FlashInfer CUTLASS. 
+ bf16_proj_names = set() + bf16_shared_names = set() fp8_converted = 0 fp8_from_bf16 = 0 bf16_converted = 0 compressor_converted = 0 + + # Build shard index once for compressor reconstruction (avoids N×M full-shard loads) + _shard_index = self._build_shard_index("/model") if os.path.isdir("/model") else None + for layer_idx, layer in enumerate(self.layers): attn = layer.attn @@ -1677,13 +1812,11 @@ class DeepseekV4Model(nn.Module): mod = getattr(attn, proj_name) if not hasattr(mod, "weight"): continue - if mod.weight.dtype == torch.uint8: + if mod.weight.dtype in (torch.uint8, torch.int8): # NVFP4 -> dequant to bf16 -> requant to FP8 self._convert_nvfp4_to_fp8(mod, E2M1_LUT, FP8_MAX) fp8_converted += 1 elif mod.weight.dtype == torch.bfloat16: - # modelopt did NOT quantize o_a_proj — it's bf16 already. - # Convert bf16 -> FP8 directly for fp8_einsum path. self._convert_bf16_to_fp8(mod, FP8_MAX) fp8_from_bf16 += 1 @@ -1692,7 +1825,7 @@ class DeepseekV4Model(nn.Module): if not hasattr(attn, proj_name): continue mod = getattr(attn, proj_name) - if not hasattr(mod, "weight") or mod.weight.dtype != torch.uint8: + if not hasattr(mod, "weight") or mod.weight.dtype not in (torch.uint8, torch.int8): continue self._dequant_nvfp4_to_bf16(mod, E2M1_LUT) bf16_converted += 1 @@ -1710,23 +1843,23 @@ class DeepseekV4Model(nn.Module): compressor = getattr(mla_attn, "compressor", None) if compressor is not None and hasattr(compressor, "fused_wkv_wgate"): compressor_converted += self._reconstruct_compressor_weight( - compressor.fused_wkv_wgate, attn, layer_idx, E2M1_LUT) + compressor.fused_wkv_wgate, attn, layer_idx, E2M1_LUT, _shard_index=_shard_index) # Indexer compressor (C4A layers only) indexer = getattr(mla_attn, "indexer", None) if indexer is not None: idx_compressor = getattr(indexer, "compressor", None) if idx_compressor is not None and hasattr(idx_compressor, "fused_wkv_wgate"): compressor_converted += self._reconstruct_compressor_weight( - idx_compressor.fused_wkv_wgate, 
indexer, layer_idx, E2M1_LUT, sub_path=".indexer") + idx_compressor.fused_wkv_wgate, indexer, layer_idx, E2M1_LUT, sub_path=".indexer", _shard_index=_shard_index) - # Shared experts + # Shared experts: dequantize NVFP4 → BF16 ffn = layer.ffn if hasattr(ffn, "shared_experts") and ffn.shared_experts is not None: for proj_name in bf16_shared_names: if not hasattr(ffn.shared_experts, proj_name): continue mod = getattr(ffn.shared_experts, proj_name) - if not hasattr(mod, "weight") or mod.weight.dtype != torch.uint8: + if not hasattr(mod, "weight") or mod.weight.dtype not in (torch.uint8, torch.int8): continue self._dequant_nvfp4_to_bf16(mod, E2M1_LUT) bf16_converted += 1 @@ -1737,8 +1870,7 @@ class DeepseekV4Model(nn.Module): print(f"NVFP4 post-load: {fp8_converted} NVFP4->FP8, " f"{fp8_from_bf16} BF16->FP8, " f"{bf16_converted} attn/shared->BF16, " - f"{compressor_converted} compressor->BF16, " - f"MoE experts stay NVFP4") + f"{compressor_converted} compressor->BF16") def _dequant_nvfp4_to_bf16(self, mod, e2m1_lut): @@ -1749,9 +1881,8 @@ class DeepseekV4Model(nn.Module): # Dequantize with scales if hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2"): - # scale_fmt=ue8m0: weight_scale bytes are E8M0 format (power-of-2 only). - # A simple .to(float32) misinterprets them as E4M3. Must reinterpret - # the raw uint8 bits as IEEE 754 exponent field. + # NVFP4 block scales are float8_e4m3fn (UE4M3) — standard spec. 
+ # .to(float32) is correct (Bug #7: shift-by-23 was wrong, reverted) block_scale = self._ue8m0_to_float32(mod.weight_scale.data) if block_scale.dim() == 2 and w_bf16.dim() == 2: block_size = w_bf16.shape[1] // block_scale.shape[1] @@ -1773,8 +1904,10 @@ class DeepseekV4Model(nn.Module): else: w_dequant = w_bf16 - # Replace weight with bf16 version + # Free source tensors eagerly to avoid holding uint8+bf16+fp32 simultaneously + del w_uint8, w_bf16 mod.weight = torch.nn.Parameter(w_dequant, requires_grad=False) + del w_dequant from vllm.model_executor.layers.linear import UnquantizedLinearMethod mod.quant_method = UnquantizedLinearMethod() for attr in ("weight_scale", "weight_scale_2", "input_scale", @@ -1794,7 +1927,7 @@ class DeepseekV4Model(nn.Module): # Dequantize with scales if hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2"): - # scale_fmt=ue8m0: reinterpret E8M0 bytes as float32 + # NVFP4 block scales: float8_e4m3fn → .to(float32) (Bug #7 reverted) block_scale = self._ue8m0_to_float32(mod.weight_scale.data) if block_scale.dim() == 2 and w_bf16.dim() == 2: block_size = w_bf16.shape[1] // block_scale.shape[1] @@ -1857,21 +1990,38 @@ class DeepseekV4Model(nn.Module): bmm_batch_size=bmm_batch_size, ) + # Free source tensors eagerly + del w_uint8, w_bf16, w_dequant mod.weight = torch.nn.Parameter(w_fp8, requires_grad=False) + del w_fp8 # weight_scale_inv is what the attention runtime reads as b_scale # for deepseek_v4_fp8_einsum -> DeepGEMM fp8_einsum. # It must be the DeepGEMM-formatted block scale (dg_ws), NOT the # per-tensor scalar. See: deepseek_v4_attention.py line 319. mod.weight_scale_inv = torch.nn.Parameter(ws, requires_grad=False) - # weight_scale is not used at runtime for BMM layers; remove it - # to avoid confusing other code paths. 
+ del ws + from vllm.model_executor.layers.linear import UnquantizedLinearMethod + mod.quant_method = UnquantizedLinearMethod() for attr in ("weight_scale", "weight_scale_2", "input_scale"): if hasattr(mod, attr): delattr(mod, attr) - from vllm.model_executor.layers.linear import UnquantizedLinearMethod - mod.quant_method = UnquantizedLinearMethod() - def _reconstruct_compressor_weight(self, fused_mod, parent_mod, layer_idx, e2m1_lut, sub_path=""): + @staticmethod + def _build_shard_index(ckpt_dir: str) -> dict[str, str]: + """Build key→shard_path index from safetensors metadata (no tensor I/O).""" + import glob + from safetensors import safe_open + index = {} + for shard_file in sorted(glob.glob(os.path.join(ckpt_dir, "model-*.safetensors"))): + try: + with safe_open(shard_file, framework="pt") as f: + for key in f.keys(): + index[key] = shard_file + except Exception: + continue + return index + + def _reconstruct_compressor_weight(self, fused_mod, parent_mod, layer_idx, e2m1_lut, sub_path="", _shard_index=None): """Reconstruct compressor fused_wkv_wgate from checkpoint. Compressor weights are SKIPPED during loading because NVFP4 uint8 data @@ -1879,8 +2029,7 @@ class DeepseekV4Model(nn.Module): We read the original uint8 data from the safetensors checkpoint, unpack E2M1, dequantize, and stack into the fused weight param. 
""" - import glob - from safetensors.torch import load_file + from safetensors import safe_open # Find the checkpoint directory # The model weights are mounted at /model in Docker @@ -1895,49 +2044,45 @@ class DeepseekV4Model(nn.Module): # We read from checkpoint (before mapper), so use original names layer_prefix = f"model.layers.{layer_idx}.self_attn.compressor{sub_path}" - # Find which shard contains this layer's compressor weights - wkv_key = f"{layer_prefix}.kv_proj.weight" - wgate_key = f"{layer_prefix}.gate_proj.weight" - wkv_scale_key = f"{layer_prefix}.kv_proj.weight_scale" - wgate_scale_key = f"{layer_prefix}.gate_proj.weight_scale" - wkv_scale2_key = f"{layer_prefix}.kv_proj.weight_scale_2" - wgate_scale2_key = f"{layer_prefix}.gate_proj.weight_scale_2" - wkv_iscale_key = f"{layer_prefix}.kv_proj.input_scale" - wgate_iscale_key = f"{layer_prefix}.gate_proj.input_scale" + # All keys we need from the checkpoint + keys = { + 'wkv_uint8': f"{layer_prefix}.kv_proj.weight", + 'wgate_uint8': f"{layer_prefix}.gate_proj.weight", + 'wkv_block_scale': f"{layer_prefix}.kv_proj.weight_scale", + 'wgate_block_scale': f"{layer_prefix}.gate_proj.weight_scale", + 'wkv_global_scale': f"{layer_prefix}.kv_proj.weight_scale_2", + 'wgate_global_scale': f"{layer_prefix}.gate_proj.weight_scale_2", + 'wkv_input_scale': f"{layer_prefix}.kv_proj.input_scale", + 'wgate_input_scale': f"{layer_prefix}.gate_proj.input_scale", + } - # Load from safetensors - wkv_uint8 = None - wgate_uint8 = None - wkv_block_scale = None - wgate_block_scale = None - wkv_global_scale = None - wgate_global_scale = None - wkv_input_scale = None - wgate_input_scale = None - - shard_files = sorted(glob.glob(os.path.join(ckpt_dir, "model-*.safetensors"))) - for shard_file in shard_files: + # Read tensors using shard index for targeted access (no full-shard loads) + tensors = {} + for name, key in keys.items(): + shard_path = (_shard_index or {}).get(key) + if shard_path is None: + continue try: - shard_data = 
load_file(shard_file) + with safe_open(shard_path, framework="pt") as f: + if key in f.keys(): + tensors[name] = f.get_tensor(key) except Exception: continue - if wkv_key in shard_data: - wkv_uint8 = shard_data[wkv_key] - wkv_block_scale = shard_data.get(wkv_scale_key) - wkv_global_scale = shard_data.get(wkv_scale2_key) - wkv_input_scale = shard_data.get(wkv_iscale_key) - if wgate_key in shard_data: - wgate_uint8 = shard_data[wgate_key] - wgate_block_scale = shard_data.get(wgate_scale_key) - wgate_global_scale = shard_data.get(wgate_scale2_key) - wgate_input_scale = shard_data.get(wgate_iscale_key) - if wkv_uint8 is not None and wgate_uint8 is not None: - break + + wkv_uint8 = tensors.get('wkv_uint8') + wgate_uint8 = tensors.get('wgate_uint8') if wkv_uint8 is None or wgate_uint8 is None: # Layer might not have a compressor (compress_ratio=1 layers) return 0 + wkv_block_scale = tensors.get('wkv_block_scale') + wgate_block_scale = tensors.get('wgate_block_scale') + wkv_global_scale = tensors.get('wkv_global_scale') + wgate_global_scale = tensors.get('wgate_global_scale') + wkv_input_scale = tensors.get('wkv_input_scale') + wgate_input_scale = tensors.get('wgate_input_scale') + device = fused_mod.weight.device wkv_uint8 = wkv_uint8.to(device) wgate_uint8 = wgate_uint8.to(device) @@ -1949,7 +2094,7 @@ class DeepseekV4Model(nn.Module): # Dequantize with scales def _dequant(w_bf16, block_scale, global_scale, input_scale): if block_scale is not None and global_scale is not None: - # scale_fmt=ue8m0: reinterpret E8M0 bytes as float32 + # NVFP4 block scales: float8_e4m3fn → .to(float32) (Bug #7 reverted) block_scale = self._ue8m0_to_float32(block_scale.to(device)) if block_scale.dim() == 2 and w_bf16.dim() == 2: block_size = w_bf16.shape[1] // block_scale.shape[1] @@ -1972,8 +2117,6 @@ class DeepseekV4Model(nn.Module): # fused_wkv_wgate.weight = cat([wkv, wgate], dim=0) → (2*head_dim, hidden_size) w_fused = torch.cat([wkv_dequant, wgate_dequant], dim=0) - # DEBUG: log 
shapes to diagnose compressor weight mismatch - print(f"NVFP4 compressor layer {layer_idx}: wkv={wkv_dequant.shape}, wgate={wgate_dequant.shape}, fused={w_fused.shape}, existing_param={fused_mod.weight.shape}") # Replace the weight fused_mod.weight = torch.nn.Parameter(w_fused, requires_grad=False) @@ -2041,17 +2184,12 @@ class DeepseekV4Model(nn.Module): @staticmethod def _ue8m0_to_float32(sf: torch.Tensor) -> torch.Tensor: - """Convert UE8M0 (E8M0 power-of-2) scale bytes to float32. + """Convert NVFP4 block scales (float8_e4m3fn / UE4M3) to float32. - NVFP4 checkpoints with scale_fmt=ue8m0 store per-block weight scales as - E8M0 format (8-bit exponent, no mantissa). The value = 2^(raw_byte - 127). - The bytes are loaded as float8_e4m3fn by safetensors, but a simple - .to(float32) misinterprets them as E4M3 (which has mantissa bits). - Correct conversion: place the raw uint8 bits into the exponent field - of an IEEE 754 float32 (bits 23-30), yielding 2^(raw-127) * implicit_1. + Checkpoint stores float8_e4m3fn (standard NVFP4 spec, NOT UE8M0). + Simple .to(float32) is correct — shift-by-23 was wrong (Bug #7 fix). 
""" - raw_uint8 = sf.view(torch.uint8) - return (raw_uint8.to(torch.int32) << 23).view(torch.float32) + return sf.to(torch.float32) def _unpack_nvfp4_to_bf16(self, w_uint8, e2m1_lut, device): """Unpack NVFP4 uint8 packed weights to bf16 using E2M1 format.""" @@ -2269,6 +2407,9 @@ class DeepseekV4ForCausalLM(nn.Module): loaded_params = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) self.model.finalize_mega_moe_weights() self.model._convert_nvfp4_post_load() + if os.environ.get('NVFP4_DEBUG_SYNC', '') == '1': + torch.cuda.synchronize() + print("[NVFP4] post-load conversion done, CUDA OK") return loaded_params def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: diff --git a/patches/deepseek_v4_attention.py b/patches/deepseek_v4_attention.py new file mode 100644 index 0000000..2bde4a5 --- /dev/null +++ b/patches/deepseek_v4_attention.py @@ -0,0 +1,1155 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +DeepseekV4 MLA Attention Layer +""" + +from collections.abc import Callable +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, cast + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import DeepseekV2Config, DeepseekV3Config + +import vllm.envs as envs +from vllm.model_executor.layers.linear import ( + ReplicatedLinear, +) +from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer +from vllm.utils.deep_gemm import fp8_einsum +from vllm.utils.torch_utils import direct_register_custom_op +from vllm.v1.attention.ops.deepseek_v4_ops import ( + combine_topk_swa_indices, + compute_global_topk_indices_and_lens, + dequantize_and_gather_k_cache, + fused_indexer_q_rope_quant, + fused_inv_rope_fp8_quant, + fused_q_kv_rmsnorm, +) + +if TYPE_CHECKING: + from vllm.v1.attention.backends.mla.sparse_swa import ( + DeepseekSparseSWAMetadata, + ) + +from vllm.config import ( + CacheConfig, + VllmConfig, + 
get_current_vllm_config, +) +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.custom_op import PluggableLayer +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.model_executor.layers.deepseek_compressor import DeepseekCompressor +from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.input_quant_fp8 import ( + QuantFP8, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, +) +from vllm.utils.multi_stream_utils import ( + execute_in_parallel, + maybe_execute_in_parallel, +) +from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata +from vllm.v1.attention.backends.mla.flashmla_sparse import ( + DeepseekV4FlashMLASparseBackend, + FlashMLASparseBackend, + FlashMLASparseMetadata, +) +from vllm.v1.attention.backends.mla.indexer import ( + DeepseekV4IndexerBackend, + get_max_prefill_buffer_size, +) +from vllm.v1.attention.backends.mla.sparse_swa import DeepseekV4SWACache +from vllm.v1.attention.ops.flashmla import ( + flash_mla_sparse_fwd, + flash_mla_with_kvcache, +) +from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec +from vllm.v1.worker.workspace import current_workspace_manager + +logger = init_logger(__name__) + +# Prefill is processed in fixed-size chunks; this bounds the bf16 kv-gather +# workspace allocated at _forward_prefill (and the matching profile-time +# reservation in attention_impl's dummy-run branch). 
+PREFILL_CHUNK_SIZE = 4 + + +@dataclass +class DeepseekV4MLAModules: + """Modules used in DeepseekV4 MLA.""" + + vllm_config: VllmConfig + fused_wqa_wkv: torch.nn.Module + q_norm: torch.nn.Module + wq_b: torch.nn.Module + kv_norm: torch.nn.Module + wo_a: torch.nn.Module + wo_b: torch.nn.Module + attn_sink: torch.nn.Module + rotary_emb: torch.nn.Module + indexer: torch.nn.Module | None + indexer_rotary_emb: torch.nn.Module + topk_indices_buffer: torch.Tensor | None + aux_stream_list: list[torch.cuda.Stream] | None = None + + +# --8<-- [start:multi_head_latent_attention] +@PluggableLayer.register("deepseek_v4_multi_head_latent_attention") +class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): + """Pluggable MLA layer which allows OOT backends to add + custom implementations of the outer MLA layer (including rope & o_proj). + Note that currently oot platforms can still use CustomOp.register_oot to + replace MLA layer entirely, although we use PluggableLayer to register + this layer now. + + This class takes positions and hidden_states as input. + The input tensors can either contain prefill tokens or decode tokens. + The class does the following: + + 1. MLA Preprocess. + 2. Perform multi-head attention to prefill tokens and + multi-query attention to decode tokens separately. + 3. Return the output tensor. 
+ """ + + # --8<-- [end:multi_head_latent_attention] + + def __init__( + self, + hidden_size: int, + num_heads: int, + head_dim: int, + scale: float, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: int | None, + kv_lora_rank: int, + o_lora_rank: int | None, + mla_modules: DeepseekV4MLAModules, + window_size: int, + compress_ratio: int | None, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.n_local_heads = num_heads + self.head_dim = head_dim + self.scale = scale + + # FlashMLA sparse kernel only supports 64 or 128 heads; pad up to the + # next supported size. Must match DeepseekV4MLAAttention.padded_heads. + if num_heads <= 64: + self.padded_heads = 64 + elif num_heads <= 128: + self.padded_heads = 128 + else: + raise ValueError( + f"DeepseekV4 attention does not support {num_heads} heads " + "(must be <= 128)." + ) + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.window_size = window_size + self.compress_ratio = compress_ratio if compress_ratio is not None else 1 + self.prefix = prefix + + # Extract config from vllm_config + config = mla_modules.vllm_config.model_config.hf_config + tp_size = get_tensor_model_parallel_world_size() + + # DeepseekV4-specific attributes (num_heads is already TP-adjusted) + self.eps = config.rms_norm_eps + self.rope_head_dim = config.qk_rope_head_dim + self.nope_head_dim = head_dim - self.rope_head_dim + self.n_local_groups = config.o_groups // tp_size + self.o_lora_rank = config.o_lora_rank + + # Store projection modules + self.fused_wqa_wkv = mla_modules.fused_wqa_wkv + self.q_norm = mla_modules.q_norm + self.wq_b = mla_modules.wq_b + + self.kv_norm = mla_modules.kv_norm + self.wo_a = mla_modules.wo_a + + self._wo_a_act_quant = QuantFP8( + static=False, + group_shape=GroupShape(1, 128), + use_ue8m0=True, + ) + # Bypass 
packed-for-deepgemm path — we need FP32 scales (not packed + # INT32) so fp8_einsum can handle layout transform internally. + self._wo_a_act_quant.use_deep_gemm_supported = False + self.wo_b = mla_modules.wo_b + + # Pick fp8_einsum recipe based on GPU arch: + # SM90: FP32 block scales stay [g, r/128, d/128] → sfb_gran_mn=128 + # SM100: INT32 packed scales become [g, r, ...] → sfb_gran_mn=1 + from vllm.platforms import current_platform + + cap = current_platform.get_device_capability() + assert cap is not None, "DeepseekV4 attention requires a CUDA device" + self._einsum_recipe = (1, 128, 128) if cap.major <= 9 else (1, 1, 128) + self._tma_aligned_scales = cap.major >= 10 + + self.rotary_emb = mla_modules.rotary_emb + self.indexer_rotary_emb = mla_modules.indexer_rotary_emb + self.topk_indices_buffer = mla_modules.topk_indices_buffer + + self.indexer = mla_modules.indexer + + # Per-head RMS normalization for Q (no learnable weights) + self.q_head_norm = RMSNorm(head_dim, eps=self.eps, has_weight=False) + + # TODO(yifan): currently hardcoded for FP8 sparse, make it more generic + head_bytes = ( + self.nope_head_dim # 448 fp8 NoPE + + self.rope_head_dim * 2 # 64 bf16 RoPE + + self.nope_head_dim // 64 # 7B scale factors + + 1 # 1B pad + ) + + self.aux_stream_list = mla_modules.aux_stream_list + # [0]: GEMM start / post-GEMM event0. [1..3]: GEMM done events; + # [1] doubles as post-GEMM event1. Reuse is safe: GEMM fully joins + # before post-GEMM starts. 
+ self.ln_events = [torch.cuda.Event() for _ in range(4)] + + assert cache_config is not None, "DeepseekV4 attention requires cache_config" + self.swa_cache_layer = DeepseekV4SWACache( + head_dim=self.head_dim, + window_size=self.window_size, + dtype=torch.uint8, + prefix=f"{prefix}.swa_cache", + cache_config=cache_config, + ) + + self.mla_attn = DeepseekV4MLAAttention( + num_heads=self.n_local_heads, + head_dim=self.head_dim, + scale=self.scale, + qk_nope_head_dim=self.nope_head_dim, + qk_rope_head_dim=self.rope_head_dim, + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + compress_ratio=self.compress_ratio, + window_size=self.window_size, + head_bytes=head_bytes, + swa_cache_layer=self.swa_cache_layer, + attn_sink=mla_modules.attn_sink, # already padded with -inf + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + indexer=self.indexer, + topk_indices_buffer=self.topk_indices_buffer, + ) + # Register this layer in the compilation config's static forward context + # This allows the custom op to retrieve the layer during execution + compilation_config = mla_modules.vllm_config.compilation_config + # HACK + self.layer_name = prefix + ".deepseek_v4_multi_head_latent_attention" + if self.layer_name in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {self.layer_name}") + compilation_config.static_forward_context[self.layer_name] = self + + # Create the compressor for layers with compress_ratio > 1; after + # creating the DeepseekV4MLAAttention layer to get its cache. 
+ self.compressor = None + if self.compress_ratio > 1: + self.compressor = DeepseekCompressor( + vllm_config=mla_modules.vllm_config, + compress_ratio=self.compress_ratio, + hidden_size=self.hidden_size, + head_dim=self.head_dim, + rotate=True, + prefix=f"{prefix}.compressor", + k_cache_prefix=self.mla_attn.prefix, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + llama_4_scaling: torch.Tensor | None = None, + ) -> torch.Tensor: + # Pre-allocate attention output with FlashMLA-padded head count. + # The op writes into `o_padded`; we slice to n_local_heads after. + num_tokens = hidden_states.shape[0] + o_padded = torch.empty( + (num_tokens, self.padded_heads, self.head_dim), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + + # Attention (inside custom op for torch.compile boundary) + torch.ops.vllm.deepseek_v4_attention( + hidden_states, + positions, + o_padded, + self.layer_name, + ) + o = o_padded[:, : self.n_local_heads, :] + + # O projection: inverse RoPE + FP8 quant + einsum + wo_b + o_fp8, o_scale = fused_inv_rope_fp8_quant( + o, + positions, + self.rotary_emb.cos_sin_cache.to(torch.float32), + n_groups=self.n_local_groups, + heads_per_group=self.n_local_heads // self.n_local_groups, + nope_dim=self.nope_head_dim, + rope_dim=self.rope_head_dim, + tma_aligned_scales=self._tma_aligned_scales, + ) + + wo_a_fp8 = self.wo_a.weight + wo_a_scale = self.wo_a.weight_scale_inv + + z = torch.empty( + (num_tokens, self.n_local_groups, self.o_lora_rank), + device=o.device, + dtype=torch.bfloat16, + ) + torch.ops.vllm.deepseek_v4_fp8_einsum( + o_fp8, + o_scale, + wo_a_fp8, + wo_a_scale, + z, + "bhr,hdr->bhd", + list(self._einsum_recipe), + ) + + return self.wo_b(z.flatten(1)) + + def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]: + assert self.aux_stream_list is not None + assert len(self.aux_stream_list) >= 3 + + # fused_wqa_wkv (heaviest) on default; the three lighter input GEMMs + # on aux 
streams 0..2 when their owning module exists. ln_events[0] + # is the fan-out start event; ln_events[1..3] are per-aux done events. + aux_fns: list[Callable[[], Any] | None] = [None, None, None] + + if self.compressor is not None: + # Local ref so the closure keeps a non-None type for mypy. + compressor = self.compressor + + def compressor_kv_score() -> torch.Tensor: + return torch.mm( + hidden_states, + compressor.fused_wkv_wgate.weight.T, + out_dtype=torch.float32, + ) + + aux_fns[0] = compressor_kv_score + + if self.indexer is not None: + indexer = self.indexer + + def indexer_weights_proj() -> torch.Tensor: + # ReplicatedLinear returns (output, bias); bias is None. + weights, _ = indexer.weights_proj(hidden_states) + return weights + + def indexer_compressor_kv_score() -> torch.Tensor: + return torch.mm( + hidden_states, + indexer.compressor.fused_wkv_wgate.weight.T, + out_dtype=torch.float32, + ) + + aux_fns[1] = indexer_weights_proj + aux_fns[2] = indexer_compressor_kv_score + + def fused_wqa_wkv() -> torch.Tensor: + # MergedColumnParallelLinear returns (output, bias); bias is None. 
+ qr_kv, _ = self.fused_wqa_wkv(hidden_states) + return qr_kv + + qr_kv, (kv_score, indexer_weights, indexer_kv_score) = execute_in_parallel( + fused_wqa_wkv, + aux_fns, + self.ln_events[0], + self.ln_events[1:4], + self.aux_stream_list[:3], + enable=hidden_states.shape[0] + <= envs.VLLM_MULTI_STREAM_GEMM_TOKEN_THRESHOLD, + ) + + return qr_kv, kv_score, indexer_kv_score, indexer_weights + + def attention_impl( + self, + hidden_states: torch.Tensor, + positions: torch.Tensor, + out: torch.Tensor, # [num_tokens, padded_heads, head_dim], written in place + ) -> None: + forward_context = get_forward_context() + attn_metadata = forward_context.attn_metadata + + qr_kv, kv_score, indexer_kv_score, indexer_weights = ( + self.attn_gemm_parallel_execute(hidden_states) + ) + + qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1) + qr, kv = fused_q_kv_rmsnorm( + qr, + kv, + self.q_norm.weight.data, + self.kv_norm.weight.data, + self.eps, + ) + + # wq_b + kv_insert (+ MLA compressor when an indexer is present) ride + # on the default stream so q stays on its consumer stream (mla_attn + # downstream reads q on default). Indexer/compressor go on aux for + # overlap with default's GEMM + cache write. + if self.indexer is not None: + assert self.aux_stream_list is not None + aux_stream = self.aux_stream_list[0] + indexer = self.indexer + # Local ref so the closure keeps a non-None type for mypy. 
+ assert self.compressor is not None + compressor = self.compressor + + def wq_b_kv_insert_and_compress() -> torch.Tensor: + q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim) + self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata) + compressor(kv_score, positions, self.rotary_emb) + return q + + q, _ = maybe_execute_in_parallel( + wq_b_kv_insert_and_compress, + lambda: indexer( + hidden_states, + qr, + indexer_kv_score, + indexer_weights, + positions, + self.indexer_rotary_emb, + ), + self.ln_events[0], + self.ln_events[1], + aux_stream, + ) + elif self.compressor is not None: + # wq_b + kv_insert on default, compressor on aux. + assert self.aux_stream_list is not None + aux_stream = self.aux_stream_list[0] + compressor = self.compressor + + def wq_b_kv_insert() -> torch.Tensor: + q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim) + self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata) + return q + + q, _ = maybe_execute_in_parallel( + wq_b_kv_insert, + lambda: compressor(kv_score, positions, self.rotary_emb), + self.ln_events[0], + self.ln_events[1], + aux_stream, + ) + else: + # SWA-only layer: no compressor, no overlap. + q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim) + self._fused_qnorm_rope_kv_insert(q, kv, positions, attn_metadata) + + # Handle dummy run (no metadata). + if not isinstance(attn_metadata, dict): + # Reserve _forward_prefill's bf16-gather workspace; the dummy + # run returns before mla_attn runs, so without this the shared + # workspace locks below the real prefill size. 
+ sub = self.mla_attn + swa_only = sub.compress_ratio <= 1 + N = ( + 0 + if swa_only + else (sub.max_model_len + sub.compress_ratio - 1) // sub.compress_ratio + ) + M = N + sub.window_size + sub.max_num_batched_tokens + current_workspace_manager().get_simultaneous( + ((PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), + ) + out.zero_() + return + + # Pad q to FlashMLA-required head count (64 or 128) + if self.n_local_heads < self.padded_heads: + pad_size = self.padded_heads - self.n_local_heads + q = F.pad(q, (0, 0, 0, pad_size), value=0.0) + + # MLA attention writes into the pre-allocated `out` buffer + # ([num_tokens, padded_heads, head_dim]). + self.mla_attn(q, kv, positions, output=out) + + def _fused_qnorm_rope_kv_insert( + self, + q: torch.Tensor, + kv: torch.Tensor, + positions: torch.Tensor, + attn_metadata: ( + dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] | None + ), + ) -> None: + if not isinstance(attn_metadata, dict): + return + + swa_metadata = cast( + "DeepseekSparseSWAMetadata | None", + attn_metadata.get(self.swa_cache_layer.prefix), + ) + assert swa_metadata is not None + + swa_kv_cache = self.swa_cache_layer.kv_cache + swa_kv_cache_2d = swa_kv_cache.view(swa_kv_cache.shape[0], -1) + + # Horizontally fused: + # Q side: q_head_norm (per-head RMSNorm, no weight) + GPT-J RoPE + # KV side: GPT-J RoPE + UE8M0 FP8 quant + paged cache insert + # kv is unchanged; mla_attn reads kv solely via swa_kv_cache. 
+ torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert( + q, + kv, + swa_kv_cache_2d, + swa_metadata.slot_mapping, + positions.to(torch.int64), + self.rotary_emb.cos_sin_cache.to(torch.float32), + self.eps, + swa_metadata.block_size, + ) + + +def deepseek_v4_attention( + hidden_states: torch.Tensor, + positions: torch.Tensor, + out: torch.Tensor, + layer_name: str, +) -> None: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + self.attention_impl(hidden_states, positions, out) + + +def deepseek_v4_attention_fake( + hidden_states: torch.Tensor, + positions: torch.Tensor, + out: torch.Tensor, + layer_name: str, +) -> None: + return None + + +direct_register_custom_op( + op_name="deepseek_v4_attention", + op_func=deepseek_v4_attention, + mutates_args=["out"], + fake_impl=deepseek_v4_attention_fake, +) + + +def deepseek_v4_fp8_einsum( + a: torch.Tensor, + a_scale: torch.Tensor, + b: torch.Tensor, + b_scale: torch.Tensor, + out: torch.Tensor, + equation: str, + recipe: list[int], +) -> None: + fp8_einsum(equation, (a, a_scale), (b, b_scale), out, recipe=tuple(recipe)) + + +def deepseek_v4_fp8_einsum_fake( + a: torch.Tensor, + a_scale: torch.Tensor, + b: torch.Tensor, + b_scale: torch.Tensor, + out: torch.Tensor, + equation: str, + recipe: list[int], +) -> None: + return None + + +direct_register_custom_op( + op_name="deepseek_v4_fp8_einsum", + op_func=deepseek_v4_fp8_einsum, + mutates_args=["out"], + fake_impl=deepseek_v4_fp8_einsum_fake, +) + + +class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase): + # FlashMLA FP8 sparse only supports 64 or 128 heads + SUPPORTED_HEAD_COUNTS = (64, 128) + + def __init__( + self, + num_heads: int, + head_dim: int, + scale: float, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + q_lora_rank: int | None, + kv_lora_rank: int, + compress_ratio: int, + window_size: int, + head_bytes: int, + swa_cache_layer: DeepseekV4SWACache, + attn_sink: torch.Tensor, + 
class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
    """Sparse MLA attention layer for DeepSeek V4.

    Every layer attends over the sliding-window (SWA) KV cache owned by
    ``swa_cache_layer``. Layers with ``compress_ratio > 1`` additionally
    attend over their own compressed KV cache (``self.kv_cache``), selected
    through top-k indices.
    """

    # FlashMLA FP8 sparse only supports 64 or 128 heads
    SUPPORTED_HEAD_COUNTS = (64, 128)

    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        scale: float,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        q_lora_rank: int | None,
        kv_lora_rank: int,
        compress_ratio: int,
        window_size: int,
        head_bytes: int,
        swa_cache_layer: DeepseekV4SWACache,
        attn_sink: torch.Tensor,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        # Sparse MLA Args
        indexer: object | None = None,
        topk_indices_buffer: torch.Tensor | None = None,
        aux_stream: torch.cuda.Stream | None = None,
        **extra_impl_args,
    ) -> None:
        """Store the layer configuration and register the layer by ``prefix``.

        ``quant_config`` and ``extra_impl_args`` are accepted but not read
        here. When ``cache_config`` is given and its dtype is an fp8 variant
        other than ``"fp8_ds_mla"``, ``cache_config.cache_dtype`` is
        rewritten in place to ``"fp8_ds_mla"``.

        Raises:
            ValueError: if ``num_heads`` exceeds 128, or if ``prefix`` is
                already present in the static forward context.
        """
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = 1
        self.head_dim = head_dim
        self.scale = scale
        self.window_size = window_size
        self.head_bytes = head_bytes
        self.compress_ratio = compress_ratio
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.nope_head_dim = qk_nope_head_dim
        self.rope_head_dim = qk_rope_head_dim
        self.indexer = indexer
        self.topk_indices_buffer = topk_indices_buffer

        self.prefix = prefix  # Alias for compatibility with compressor

        self.aux_stream = aux_stream
        self.ln_events = [torch.cuda.Event(), torch.cuda.Event()]

        # Determine padded head count for FlashMLA
        if num_heads not in self.SUPPORTED_HEAD_COUNTS:
            if num_heads < 64:
                self.padded_heads = 64
            elif num_heads < 128:
                self.padded_heads = 128
            else:
                raise ValueError(
                    f"DeepseekV4MLAAttention does not support {num_heads} heads. "
                    f"Supported: <= 128 (will be padded to 64 or 128)"
                )
        else:
            self.padded_heads = num_heads

        # Store attention sink
        assert attn_sink is not None
        self.attn_sink: torch.Tensor = attn_sink
        # Store SWA cache
        assert swa_cache_layer is not None
        self.swa_cache_layer: DeepseekV4SWACache = swa_cache_layer

        # Get vllm config for cache setup
        vllm_config = get_current_vllm_config()
        self.max_num_batched_tokens = (
            vllm_config.scheduler_config.max_num_batched_tokens
        )
        self.max_model_len = vllm_config.model_config.max_model_len
        # DeepseekV4 only supports fp8 kv-cache format for now
        kv_cache_dtype = cache_config.cache_dtype if cache_config is not None else "fp8"

        assert kv_cache_dtype.startswith("fp8"), (
            f"DeepseekV4 only supports fp8 kv-cache format for now, "
            f"got {kv_cache_dtype}"
        )
        assert issubclass(self.get_attn_backend(), FlashMLASparseBackend), (
            "Only FlashMLA Sparse Attention backend is supported for DeepseekV4 for now"
        )
        # FlashMLA Sparse Attention fp8 backend uses "fp8_ds_mla" kv-cache format
        # Automatically convert fp8 kv-cache format to "fp8_ds_mla"
        if (
            issubclass(self.get_attn_backend(), FlashMLASparseBackend)
            and kv_cache_dtype.startswith("fp8")
            and kv_cache_dtype != "fp8_ds_mla"
        ):
            # cache_config is optional: with the default (None) the dtype
            # falls back to "fp8" and this branch is always taken, so the
            # write-back must be conditional rather than asserted.
            if cache_config is not None:
                cache_config.cache_dtype = "fp8_ds_mla"
            kv_cache_dtype = "fp8_ds_mla"
            logger.info_once("Using DeepSeek's fp8_ds_mla KV cache format.")

        self.kv_cache_dtype = kv_cache_dtype

        # Register with compilation context for metadata lookup
        compilation_config = vllm_config.compilation_config
        if prefix and prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        if prefix:
            compilation_config.static_forward_context[prefix] = self

        self.kv_cache = torch.tensor([])

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the attention backend class used by this layer."""
        return DeepseekV4FlashMLASparseBackend

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
        """Return the compressed-cache spec, or ``None`` for SWA-only layers."""
        if (
            self.compress_ratio <= 1
        ):  # SWA part. Allocated separately as DeepseekV4SWACache.
            return None
        return MLAAttentionSpec(
            block_size=vllm_config.cache_config.block_size,
            num_kv_heads=1,
            head_size=self.head_dim,
            dtype=torch.uint8,
            compress_ratio=self.compress_ratio,
            cache_dtype_str=self.kv_cache_dtype,
            alignment=576,  # NOTE: FlashMLA requires 576B alignment
            model_version="deepseek_v4",
        )

    def forward(
        self,
        q: torch.Tensor,
        kv: torch.Tensor,
        positions: torch.Tensor,
        output: torch.Tensor,
    ) -> None:
        """Run attention for a mixed decode/prefill batch, writing ``output``.

        Decode tokens occupy the first ``num_decode_tokens`` rows of ``q`` and
        ``output``; prefill tokens follow. ``kv`` is not read by this method.
        ``output`` must match ``q`` in shape and dtype.
        """
        assert output.shape == q.shape, (
            f"output buffer shape {output.shape} must match q shape {q.shape}"
        )
        assert output.dtype == q.dtype, (
            f"output buffer dtype {output.dtype} must match q dtype {q.dtype}"
        )

        # Get SWA and indexer metadata from forward context
        forward_context = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        assert isinstance(attn_metadata, dict)
        flashmla_metadata = cast(
            FlashMLASparseMetadata | None, attn_metadata.get(self.prefix)
        )
        swa_metadata = cast(
            "DeepseekSparseSWAMetadata | None",
            attn_metadata.get(self.swa_cache_layer.prefix),
        )
        assert swa_metadata is not None

        swa_only = self.compress_ratio <= 1
        # SWA-only layers (compress_ratio <= 1) don't have their own KV cache
        # allocation, so self.kv_cache may be empty after profiling cleanup.
        self_kv_cache = self.kv_cache if not swa_only else None
        swa_kv_cache = self.swa_cache_layer.kv_cache

        # Split prefill and decode
        num_decodes = swa_metadata.num_decodes
        num_prefills = swa_metadata.num_prefills
        num_decode_tokens = swa_metadata.num_decode_tokens

        if num_prefills > 0:
            self._forward_prefill(
                q=q[num_decode_tokens:],
                positions=positions[num_decode_tokens:],
                compressed_k_cache=self_kv_cache,
                swa_k_cache=swa_kv_cache,
                output=output[num_decode_tokens:],
                attn_metadata=flashmla_metadata,
                swa_metadata=swa_metadata,
            )
        if num_decodes > 0:
            self._forward_decode(
                q=q[:num_decode_tokens],
                kv_cache=self_kv_cache,
                swa_metadata=swa_metadata,
                attn_metadata=flashmla_metadata,
                swa_only=swa_only,
                output=output[:num_decode_tokens],
            )

    def _forward_decode(
        self,
        q: torch.Tensor,
        kv_cache: torch.Tensor | None,  # Only used when compress_ratio > 1
        swa_metadata: "DeepseekSparseSWAMetadata",
        attn_metadata: FlashMLASparseMetadata | None,
        swa_only: bool,
        output: torch.Tensor,
    ) -> None:
        """Decode-path attention via ``flash_mla_with_kvcache``.

        The kernel writes directly into ``output`` (passed as ``out=``), so
        its return value is not used.

        Raises:
            ValueError: if ``self.compress_ratio`` is not <= 1, 4 or 128.
        """
        num_decodes = swa_metadata.num_decodes
        num_decode_tokens = swa_metadata.num_decode_tokens

        topk_indices = None
        topk_lens = None
        if not swa_only:
            assert attn_metadata is not None
            assert swa_metadata.is_valid_token is not None
            block_size = attn_metadata.block_size // self.compress_ratio
            is_valid = swa_metadata.is_valid_token[:num_decode_tokens]
            if self.compress_ratio == 4:
                # C4A: local indices differ per layer (filled by Indexer).
                assert self.topk_indices_buffer is not None
                global_indices, topk_lens = compute_global_topk_indices_and_lens(
                    self.topk_indices_buffer[:num_decode_tokens],
                    swa_metadata.token_to_req_indices,
                    attn_metadata.block_table[:num_decodes],
                    block_size,
                    is_valid,
                )
                topk_indices = global_indices.view(num_decode_tokens, 1, -1)
            else:
                # C128A: pre-computed during metadata build.
                topk_indices = attn_metadata.c128a_global_decode_topk_indices
                topk_lens = attn_metadata.c128a_decode_topk_lens

        swa_indices = swa_metadata.decode_swa_indices
        swa_lens = swa_metadata.decode_swa_lens

        # We treat queries in the same seq as different queries
        # and later we only attend by generated indices.
        # q arrives pre-padded to self.padded_heads by the outer wrapper.
        q = q.unsqueeze(1)

        # Prepare SWA cache (num_blocks, swa_block_size, 1, head_bytes)
        # Use unsqueeze to preserve strides (handles padded blocks correctly)
        swa_cache = self.swa_cache_layer.kv_cache.unsqueeze(-2)
        # Reshape KV cache to (num_blocks, block_size, 1, head_bytes)
        if kv_cache is not None:
            kv_cache = kv_cache.unsqueeze(-2)

        # One FlashMLASchedMeta per layer type, shared across all same-type
        # layers within this decode step. The first forward call per type
        # triggers the in-kernel planner (allocating tile_scheduler_metadata
        # and num_splits via PyTorch's graph-aware allocator so CUDA graph
        # capture reuses the same addresses on replay); subsequent same-type
        # layers see have_initialized=True and skip the planner.
        if self.compress_ratio <= 1:
            tile_metadata = swa_metadata.tile_sched_swaonly
        elif self.compress_ratio == 4:
            tile_metadata = swa_metadata.tile_sched_c4a
        elif self.compress_ratio == 128:
            tile_metadata = swa_metadata.tile_sched_c128a
        else:
            raise ValueError(
                f"Unsupported compress_ratio={self.compress_ratio}; "
                "expected 1, 4, or 128."
            )
        assert tile_metadata is not None, (
            "swa_metadata missing tile_sched entry for "
            f"compress_ratio={self.compress_ratio}; "
            "DeepseekSparseSWAMetadataBuilder.build_tile_scheduler did not "
            "allocate one for this layer type."
        )

        flash_mla_with_kvcache(
            q=q,
            k_cache=swa_cache,
            block_table=None,
            head_dim_v=512,
            tile_scheduler_metadata=tile_metadata,
            cache_seqlens=None,
            is_fp8_kvcache=True,
            indices=swa_indices,
            topk_length=swa_lens,
            softmax_scale=self.scale,
            attn_sink=self.attn_sink,
            extra_k_cache=kv_cache if not swa_only else None,
            extra_indices_in_kvcache=topk_indices,
            extra_topk_length=topk_lens,
            out=output.unsqueeze(1),
        )

    def _forward_prefill(
        self,
        q: torch.Tensor,
        positions: torch.Tensor,
        compressed_k_cache: torch.Tensor | None,  # Only used when compress_ratio > 1
        swa_k_cache: torch.Tensor,
        output: torch.Tensor,
        attn_metadata: FlashMLASparseMetadata | None,
        swa_metadata: "DeepseekSparseSWAMetadata",
    ) -> None:
        """Prefill-path attention, processed ``PREFILL_CHUNK_SIZE`` requests at a time.

        For each chunk the compressed KV (if any) and the SWA KV are
        dequantized and gathered into one bf16 workspace, the top-k and SWA
        indices are combined, and ``flash_mla_sparse_fwd`` writes the result
        into the matching rows of ``output``. ``positions`` is not read here.
        The layer is treated as SWA-only when ``attn_metadata`` is ``None``.
        """
        swa_only = attn_metadata is None

        num_prefills = swa_metadata.num_prefills
        num_prefill_tokens = swa_metadata.num_prefill_tokens
        num_decodes = swa_metadata.num_decodes
        num_decode_tokens = swa_metadata.num_decode_tokens

        # Use pre-computed prefill metadata.
        seq_lens = swa_metadata.prefill_seq_lens
        gather_lens = swa_metadata.prefill_gather_lens
        assert seq_lens is not None
        assert gather_lens is not None

        # Derive prefill-local token offsets from the full query_start_loc_cpu.
        query_start_loc_cpu = swa_metadata.query_start_loc_cpu
        query_start_loc = swa_metadata.query_start_loc
        assert query_start_loc_cpu is not None
        assert query_start_loc is not None
        prefill_token_base = query_start_loc_cpu[num_decodes]

        if not swa_only:
            if self.compress_ratio == 4:
                assert self.topk_indices_buffer is not None
                topk_indices = self.topk_indices_buffer[num_decode_tokens:]
                topk_indices = topk_indices[:num_prefill_tokens]
            else:
                # C128A: pre-computed during metadata build.
                assert attn_metadata is not None
                topk_indices = attn_metadata.c128a_prefill_topk_indices
            top_k = topk_indices.shape[-1]
            # Compressed region must fit the full compressed pool (seq_len //
            # compress_ratio), not just top_k. top_k bounds how many indices
            # the indexer selects, not the pool size it indexes into.
            N = (self.max_model_len + self.compress_ratio - 1) // self.compress_ratio
        else:
            # NOTE(woosuk): topk_indices will not be used for SWA-only layers.
            assert self.topk_indices_buffer is not None
            topk_indices = self.topk_indices_buffer[num_decode_tokens:]
            top_k = 0
            N = 0

        M = N + self.window_size + self.max_num_batched_tokens
        num_chunks = (num_prefills + PREFILL_CHUNK_SIZE - 1) // PREFILL_CHUNK_SIZE

        workspace_manager = current_workspace_manager()
        kv = workspace_manager.get_simultaneous(
            ((PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16),
        )[0]

        # Prefill-only views of the block tables do not change between
        # chunks, so slice them once instead of once per iteration.
        swa_block_table = swa_metadata.block_table[num_decodes:]
        block_table = None
        compressed_block_size = 0
        if not swa_only:
            assert attn_metadata is not None
            block_table = attn_metadata.block_table[num_decodes:]
            compressed_block_size = attn_metadata.block_size // self.compress_ratio

        for chunk_idx in range(num_chunks):
            chunk_start = chunk_idx * PREFILL_CHUNK_SIZE
            chunk_end = min(chunk_start + PREFILL_CHUNK_SIZE, num_prefills)
            chunk_size = chunk_end - chunk_start
            if not swa_only:
                # Gather compressed KV
                assert block_table is not None
                dequantize_and_gather_k_cache(
                    kv[:chunk_size],
                    compressed_k_cache,
                    seq_lens=seq_lens[chunk_start:chunk_end] // self.compress_ratio,
                    gather_lens=None,
                    block_table=block_table[chunk_start:chunk_end],
                    block_size=compressed_block_size,
                    offset=0,
                )

            # Gather SWA KV
            dequantize_and_gather_k_cache(
                kv[:chunk_size],
                swa_k_cache,
                seq_lens=seq_lens[chunk_start:chunk_end],
                gather_lens=gather_lens[chunk_start:chunk_end],
                block_table=swa_block_table[chunk_start:chunk_end],
                block_size=swa_metadata.block_size,
                offset=N,
            )

            # Combine the topk indices and SWA indices for gathered KV cache
            query_start = (
                query_start_loc_cpu[num_decodes + chunk_start] - prefill_token_base
            )
            query_end = (
                query_start_loc_cpu[num_decodes + chunk_end] - prefill_token_base
            )

            combined_indices, combined_lens = combine_topk_swa_indices(
                topk_indices[query_start:query_end],
                query_start_loc[
                    num_decodes + chunk_start : num_decodes + chunk_end + 1
                ],
                seq_lens[chunk_start:chunk_end],
                gather_lens[chunk_start:chunk_end],
                self.window_size,
                self.compress_ratio,
                top_k,
                M,
                N,
            )

            # The kernel writes into the output slice passed as ``out=``.
            flash_mla_sparse_fwd(
                q=q[query_start:query_end],
                kv=kv.view(-1, 1, q.shape[-1]),
                indices=combined_indices.unsqueeze(1),
                sm_scale=self.scale,
                attn_sink=self.attn_sink,
                topk_length=combined_lens,
                out=output[query_start:query_end],
            )
class DeepseekV4IndexerCache(torch.nn.Module, AttentionLayerBase):
    """Cache-holding layer for the DeepSeek V4 indexer's keys.

    The layer has no computation of its own; it describes its KV-cache page
    layout and registers itself in the static forward context under
    ``prefix``.
    """

    def __init__(
        self,
        head_dim: int,
        dtype: torch.dtype,
        prefix: str,
        cache_config: CacheConfig,
        compress_ratio: int = 1,
    ):
        super().__init__()
        # Placeholder tensor stored until a real cache is assigned.
        self.kv_cache = torch.tensor([])
        self.prefix = prefix
        self.head_dim = head_dim
        self.dtype = dtype
        self.compress_ratio = compress_ratio
        self.cache_config = cache_config

        # Layer names must be unique within the static forward context.
        registry = get_current_vllm_config().compilation_config.static_forward_context
        if prefix in registry:
            raise ValueError(f"Duplicate layer name: {prefix}")
        registry[prefix] = self

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Build the MLA cache spec from the stored cache configuration."""
        # head_dim already carries the fp8 scale padding.
        # compress_ratio=1 for V3.2, >1 for DeepseekV4; both use the same
        # cache layout. DeepseekV4 aligns indexer pages to FlashMLA's 576B so
        # they can pack with the indexer's compressor state cache.
        spec_kwargs = dict(
            block_size=self.cache_config.block_size,
            num_kv_heads=1,
            head_size=self.head_dim,
            dtype=self.dtype,
            compress_ratio=self.compress_ratio,
            alignment=576,
        )
        return MLAAttentionSpec(**spec_kwargs)

    def forward(self):
        """No-op; this layer only owns a cache."""
        return None

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the indexer's attention backend class."""
        return DeepseekV4IndexerBackend
class DeepseekV4Indexer(nn.Module):
    # Top-k token indexer: projects and quantizes the query, compresses keys
    # through a DeepseekCompressor, and delegates the selection itself to
    # SparseAttnIndexer.

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: DeepseekV2Config | DeepseekV3Config,
        hidden_size: int,
        q_lora_rank: int,
        quant_config: QuantizationConfig | None,
        cache_config: CacheConfig | None,
        topk_indices_buffer: torch.Tensor | None,
        compress_ratio: int = 1,
        prefix: str = "",
    ) -> None:
        """Build the projections, key cache, compressor and indexer op.

        ``cache_config`` is typed optional but must not be ``None`` (asserted
        below). Sequence-length limits are divided by ``compress_ratio``
        before being stored.
        """
        super().__init__()
        self.vllm_config = vllm_config
        self.config = config
        self.quant_config = quant_config
        # self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"]
        self.topk_tokens = config.index_topk
        self.n_head = config.index_n_heads  # 64
        self.head_dim = config.index_head_dim  # 128
        self.rope_dim = config.qk_rope_head_dim  # 64
        self.q_lora_rank = q_lora_rank  # 1536
        self.compress_ratio = compress_ratio
        self.use_fp4_kv = self.vllm_config.attention_config.use_fp4_indexer_cache
        logger.info_once(
            "Using %s indexer cache for Lightning Indexer.",
            "MXFP4" if self.use_fp4_kv else "FP8",
        )

        # no tensor parallel, just replicated
        self.wq_b = ReplicatedLinear(
            self.q_lora_rank,
            self.head_dim * self.n_head,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wq_b",
        )
        # NOTE(review): weights_proj and k_norm are constructed here but not
        # called in this class's forward(); presumably used by the caller or
        # kept for checkpoint loading — confirm.
        self.weights_proj = ReplicatedLinear(
            hidden_size,
            self.n_head,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.weights_proj",
        )
        self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
        self.softmax_scale = self.head_dim**-0.5

        self.scale_fmt = "ue8m0"
        self.quant_block_size = 128  # TODO: get from config
        self.topk_indices_buffer = topk_indices_buffer

        # Lengths are expressed in compressed tokens, hence the division.
        self.max_model_len = (
            vllm_config.model_config.max_model_len // self.compress_ratio
        )
        self.prefix = prefix

        self.max_total_seq_len = (
            get_max_prefill_buffer_size(vllm_config) // self.compress_ratio
        )

        assert cache_config is not None, "Deepseek V4 indexer requires cache_config"
        # NOTE(yifan): The FP8 indexer cache uses the same layout as V3.2:
        # head_dim bytes = 128 fp8 + 4 fp32 scale = 132.
        # For FP4 indexer cache, we still allocate the same amount of memory as FP8,
        # but only use the first half of the memory.
        k_cache_head_dim = self.head_dim + self.head_dim // self.quant_block_size * 4
        self.k_cache = DeepseekV4IndexerCache(
            head_dim=k_cache_head_dim,
            dtype=torch.uint8,
            prefix=f"{prefix}.k_cache",
            cache_config=cache_config,
            compress_ratio=self.compress_ratio,
        )
        self.compressor = DeepseekCompressor(
            vllm_config=vllm_config,
            compress_ratio=self.compress_ratio,
            hidden_size=hidden_size,
            head_dim=self.head_dim,
            rotate=True,
            prefix=f"{prefix}.compressor",
            k_cache_prefix=self.k_cache.prefix,
            use_fp4_cache=self.use_fp4_kv,
        )

        self.indexer_op = SparseAttnIndexer(
            self.k_cache,
            self.quant_block_size,
            self.scale_fmt,
            self.topk_tokens,
            self.head_dim,
            self.max_model_len,
            self.max_total_seq_len,
            self.topk_indices_buffer,
            skip_k_cache_insert=True,
            use_fp4_cache=self.use_fp4_kv,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        qr: torch.Tensor,
        compressed_kv_score: torch.Tensor,
        indexer_weights: torch.Tensor,
        positions: torch.Tensor,
        rotary_emb: nn.Module,
    ) -> torch.Tensor:
        """Return the result of ``indexer_op`` for this batch.

        ``qr`` is projected by ``wq_b`` and viewed as
        ``(-1, n_head, head_dim)``; keys come from the compressor; the query
        is then rotated and quantized together with ``indexer_weights`` by
        ``fused_indexer_q_rope_quant``.
        """
        # ReplicatedLinear returns (output, bias); bias is None.
        q, _ = self.wq_b(qr)
        q = q.view(-1, self.n_head, self.head_dim)
        k = self.compressor(compressed_kv_score, positions, rotary_emb)
        q_quant, weights = fused_indexer_q_rope_quant(
            positions,
            q,
            rotary_emb.cos_sin_cache,
            indexer_weights,
            self.softmax_scale,
            self.n_head**-0.5,
            use_fp4=self.use_fp4_kv,
        )
        return self.indexer_op(hidden_states, q_quant, k, weights)
# Stage one (token, 128-wide K block) of MegaMoE input:
#   * quantize BF16 activations to packed E2M1 (2 values per byte) with one
#     UE4M3 scale per GROUP_K elements, 4 scale bytes packed per int32;
#   * on the first K block of each token, copy the top-k ids (as int64) and
#     weights into the output buffers.
# Grid: (num_tokens, cdiv(hidden_size, BLOCK_K)). The scale stores are
# unmasked, so the launcher must guarantee hidden_size % BLOCK_K == 0.
@triton.jit
def _deepseek_v4_stage_mega_moe_inputs_kernel(
    hidden_states,
    x_fp4,  # uint8, shape (M, K//2) — E2M1 packed, 2 values per byte
    x_sf,  # int32, shape (M, K//64) — UE4M3 packed, 4 scales per int32
    topk_ids,
    topk_weights,
    topk_idx_out,
    topk_weights_out,
    hidden_stride_m: tl.constexpr,
    hidden_stride_k: tl.constexpr,
    x_stride_m: tl.constexpr,
    x_stride_k: tl.constexpr,
    x_sf_stride_m: tl.constexpr,
    x_sf_stride_k: tl.constexpr,
    topk_ids_stride_m: tl.constexpr,
    topk_ids_stride_k: tl.constexpr,
    topk_weights_stride_m: tl.constexpr,
    topk_weights_stride_k: tl.constexpr,
    topk_idx_stride_m: tl.constexpr,
    topk_idx_stride_k: tl.constexpr,
    topk_weights_out_stride_m: tl.constexpr,
    topk_weights_out_stride_k: tl.constexpr,
    hidden_size: tl.constexpr,
    top_k: tl.constexpr,
    BLOCK_K: tl.constexpr,  # 128 elements (loaded from hidden)
    GROUP_K: tl.constexpr,  # 16 (NVFP4 group_size)
    BLOCK_TOPK: tl.constexpr,
) -> None:
    # The scale packing below writes exactly two int32 words per K block.
    tl.static_assert(BLOCK_K // GROUP_K == 8)

    token_id = tl.program_id(0)
    k_block_id = tl.program_id(1)

    k_offsets = k_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
    k_mask = k_offsets < hidden_size
    hidden = tl.load(
        hidden_states + token_id * hidden_stride_m + k_offsets * hidden_stride_k,
        mask=k_mask,
        other=0.0,
    ).to(tl.float32)

    num_groups: tl.constexpr = BLOCK_K // GROUP_K  # 8
    hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K])
    abs_groups = tl.reshape(tl.abs(hidden), [num_groups, GROUP_K])
    amax = tl.max(abs_groups, axis=1)
    amax = tl.maximum(amax, 1.0e-4)

    # ---- UE4M3 scale computation ----
    # scale = amax / 6.0 (E2M1 max value = 6), then encode as E4M3.
    scale = amax / 6.0
    # Work in SIGNED int32: the rebased exponent below is negative for
    # scales under 2^-7, and unsigned arithmetic would wrap it to a huge
    # value, misclassifying subnormals as normals. scale > 0, so the sign
    # bit is clear and the bitcast is safe.
    scale_bits = scale.to(tl.int32, bitcast=True)
    scale_exp = (scale_bits >> 23) & 0xFF
    scale_mant = scale_bits & 0x7FFFFF

    # FP32 bias=127, E4M3 bias=7 → raw exp = scale_exp - 120
    e4m3_exp_raw = scale_exp - 120  # negative or zero → subnormal

    # Normal path: keep the top 3 mantissa bits, round to nearest even using
    # guard (bit 19), sticky (OR of bits 18:0) and the result LSB.
    normal_mant = scale_mant >> 20
    guard_bit = (scale_mant >> 19) & 1
    sticky_bit = tl.where((scale_mant & 0x7FFFF) != 0, 1, 0)
    result_lsb = normal_mant & 1
    round_up = guard_bit & (sticky_bit | result_lsb)
    normal_mant = normal_mant + round_up

    # Subnormal path (exp_raw <= 0): E4M3 subnormal = (mant/8) * 2^-6, so
    # mant = (implicit_1 | mant_fp32) >> (shift + 20) with shift = 1 - exp_raw.
    # Clamp shift to [0, 11]: 0 keeps shift counts non-negative in normal
    # lanes (whose result is discarded), 11 keeps them below 32; any larger
    # true shift rounds to zero anyway.
    shift = tl.minimum(tl.maximum(1 - e4m3_exp_raw, 0), 11)
    mant_with_leading = 0x800000 | scale_mant  # insert implicit 1
    subnormal_mant = (mant_with_leading >> (shift + 20)) & 0x7
    sub_guard_bit = (mant_with_leading >> (shift + 19)) & 1
    sub_sticky_mask = (1 << (shift + 19)) - 1
    sub_sticky_bit = tl.where((mant_with_leading & sub_sticky_mask) != 0, 1, 0)
    sub_result_lsb = subnormal_mant & 1
    sub_round_up = sub_guard_bit & (sub_sticky_bit | sub_result_lsb)
    subnormal_mant = subnormal_mant + sub_round_up

    is_normal = e4m3_exp_raw >= 1
    e4m3_mant = tl.where(is_normal, normal_mant, subnormal_mant)
    e4m3_exp = tl.where(is_normal, e4m3_exp_raw, 0)  # exp=0 for subnormals

    # Handle mantissa overflow after rounding (also promotes the largest
    # subnormal to the smallest normal).
    overflow = e4m3_mant >= 8
    e4m3_mant = tl.where(overflow, 0, e4m3_mant)
    e4m3_exp = tl.where(overflow, e4m3_exp + 1, e4m3_exp)
    # The encoding is monotonic in magnitude, so saturating the combined bits
    # at 0x7E (448, the largest finite E4M3 value) handles exponent overflow
    # and avoids emitting 0x7F, which is NaN.
    scale_e4m3_bits = tl.minimum((e4m3_exp << 3) | e4m3_mant, 0x7E)

    # Reconstruct dequantized scale by decoding the STORED E4M3 bits.
    # This guarantees the E2M1 quantization divides by exactly the value
    # the CUDA kernel will multiply back — same bits, single decode, no
    # possibility of encode/decode disagreement.
    stored_exp = (scale_e4m3_bits >> 3) & 0xF
    stored_mant = scale_e4m3_bits & 0x7
    e4m3_exp_for_recon = tl.maximum(stored_exp.to(tl.int32) - 7, -126)
    two_pow_exp_bits = (e4m3_exp_for_recon + 127).to(tl.uint32) << 23
    two_pow_exp = two_pow_exp_bits.to(tl.float32, bitcast=True)
    normal_value = (1.0 + stored_mant.to(tl.float32) / 8.0) * two_pow_exp
    subnormal_value = (stored_mant.to(tl.float32) / 8.0) * 0.015625
    e4m3_value = tl.where(stored_exp == 0, subnormal_value, normal_value)

    # ---- E2M1 FP4 quantization (one 4-bit code per element) ----
    # E2M1 LUT (unsigned): [0, 0.5, 1, 1.5, 2, 3, 4, 6]
    # Nearest-neighbor using thresholds (midpoints between consecutive values)
    scaled = hidden_groups * (1.0 / tl.maximum(e4m3_value, 1e-6))[:, None]
    # Clamp to E2M1 range [-6, 6]
    scaled = tl.maximum(scaled, -6.0)
    scaled = tl.minimum(scaled, 6.0)

    abs_s = tl.abs(scaled)
    # Thresholds: midpoints between [0, 0.5, 1, 1.5, 2, 3, 4, 6]
    # [0, 0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0, INF]
    e2m1_idx = tl.where(abs_s < 0.25, 0,
               tl.where(abs_s < 0.75, 1,
               tl.where(abs_s < 1.25, 2,
               tl.where(abs_s < 1.75, 3,
               tl.where(abs_s < 2.5, 4,
               tl.where(abs_s < 3.5, 5,
               tl.where(abs_s < 5.0, 6, 7)))))))
    sign_bit = (scaled < 0).to(tl.int32)
    e2m1_4bit = (sign_bit << 3) | e2m1_idx  # 4-bit: (sign << 3) | index

    # Pack E2M1 pairs into single bytes (2 per byte, low nibble first).
    # Triton tensors do not take strided indices, so view the codes as
    # (pairs, 2) and split the last axis: column 0 holds the even elements
    # (low nibble), column 1 the odd elements (high nibble).
    e2m1_pairs = tl.reshape(e2m1_4bit, [BLOCK_K // 2, 2])
    e2m1_lo, e2m1_hi = tl.split(e2m1_pairs)
    e2m1_packed = ((e2m1_hi << 4) | e2m1_lo).to(tl.uint8)  # [BLOCK_K // 2]

    k_offsets_out = k_block_id * (BLOCK_K // 2) + tl.arange(0, BLOCK_K // 2)
    k_mask_out = k_offsets_out < (hidden_size // 2)
    tl.store(
        x_fp4 + token_id * x_stride_m + k_offsets_out * x_stride_k,
        e2m1_packed,
        mask=k_mask_out,
    )

    # Pack UE4M3 bytes into int32 (NVFP4: group_size=16, 4 groups per 64
    # elements): groups 0-3 → first word, groups 4-7 → second word. The byte
    # position inside a word is (group % 4), which keeps every shift count
    # within [0, 24].
    scale_offsets = tl.arange(0, num_groups)  # [0..7]
    shifted_scales = scale_e4m3_bits.to(tl.int32) << ((scale_offsets % 4) * 8)
    first_half = scale_offsets < 4
    packed_lo = tl.sum(tl.where(first_half, shifted_scales, 0), axis=0).to(tl.int32)
    packed_hi = tl.sum(tl.where(first_half, 0, shifted_scales), axis=0).to(tl.int32)

    # Write 2 int32s per k_block: x_sf shape is (M, K//64) = (M, num_k_blocks * 2)
    sf_base = token_id * x_sf_stride_m + k_block_id * 2 * x_sf_stride_k
    tl.store(x_sf + sf_base, packed_lo)
    tl.store(x_sf + sf_base + x_sf_stride_k, packed_hi)

    # Only one program per token copies the routing information.
    if k_block_id == 0:
        topk_offsets = tl.arange(0, BLOCK_TOPK)
        topk_mask = topk_offsets < top_k

        ids = tl.load(
            topk_ids + token_id * topk_ids_stride_m + topk_offsets * topk_ids_stride_k,
            mask=topk_mask,
            other=0,
        ).to(tl.int64)
        tl.store(
            topk_idx_out
            + token_id * topk_idx_stride_m
            + topk_offsets * topk_idx_stride_k,
            ids,
            mask=topk_mask,
        )

        weights = tl.load(
            topk_weights
            + token_id * topk_weights_stride_m
            + topk_offsets * topk_weights_stride_k,
            mask=topk_mask,
            other=0.0,
        )
        tl.store(
            topk_weights_out
            + token_id * topk_weights_out_stride_m
            + topk_offsets * topk_weights_out_stride_k,
            weights,
            mask=topk_mask,
        )
def _stage_deepseek_v4_mega_moe_inputs(
    hidden_states: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    x_fp4: torch.Tensor,  # uint8, shape (M, K//2)
    x_sf: torch.Tensor,  # int32, shape (M, K//64)
    topk_idx_out: torch.Tensor,
    topk_weights_out: torch.Tensor,
) -> None:
    """Validate the staging buffers and launch the NVFP4 staging kernel.

    The kernel is launched on a ``(num_tokens, hidden_size / 128)`` grid and
    fills ``x_fp4``, ``x_sf``, ``topk_idx_out`` and ``topk_weights_out`` in
    place. An empty batch returns immediately without any validation.

    Output buffers may be larger than required (for example pre-allocated
    for a maximum batch), but never smaller: the kernel writes its scale
    words unmasked, so an undersized buffer would be written out of bounds.

    Raises:
        ValueError: if ``hidden_size`` is not a multiple of 128, the top-k
            tensors disagree in shape or are not 2-D with ``top_k >= 1``, or
            any buffer has fewer rows/columns than the kernel will touch.
    """
    num_tokens, hidden_size = hidden_states.shape
    if num_tokens == 0:
        return
    if hidden_size % 128 != 0:
        raise ValueError(
            "DeepSeek V4 MegaMoE input staging requires hidden_size to be "
            "a multiple of 128."
        )
    if topk_ids.dim() != 2 or topk_ids.shape[1] < 1:
        raise ValueError(
            "DeepSeek V4 MegaMoE input staging requires topk_ids to be a "
            f"2-D tensor with top_k >= 1, got shape {tuple(topk_ids.shape)}."
        )
    top_k = topk_ids.shape[1]
    if topk_weights.shape != topk_ids.shape:
        raise ValueError(
            "DeepSeek V4 MegaMoE input staging requires topk_weights and "
            "topk_ids to have the same shape."
        )

    # (name, tensor, minimum number of columns the kernel reads or writes).
    capacity_requirements = (
        ("topk_ids", topk_ids, top_k),
        ("x_fp4", x_fp4, hidden_size // 2),
        ("x_sf", x_sf, hidden_size // 64),
        ("topk_idx_out", topk_idx_out, top_k),
        ("topk_weights_out", topk_weights_out, top_k),
    )
    for name, tensor, min_cols in capacity_requirements:
        if (
            tensor.dim() != 2
            or tensor.shape[0] < num_tokens
            or tensor.shape[1] < min_cols
        ):
            raise ValueError(
                f"DeepSeek V4 MegaMoE input staging requires {name} to be a "
                f"2-D tensor of at least ({num_tokens}, {min_cols}), "
                f"got shape {tuple(tensor.shape)}."
            )

    block_k = 128
    grid = (num_tokens, triton.cdiv(hidden_size, block_k))
    block_topk = triton.next_power_of_2(top_k)
    _deepseek_v4_stage_mega_moe_inputs_kernel[grid](
        hidden_states,
        x_fp4,
        x_sf,
        topk_ids,
        topk_weights,
        topk_idx_out,
        topk_weights_out,
        hidden_states.stride(0),
        hidden_states.stride(1),
        x_fp4.stride(0),
        x_fp4.stride(1),
        x_sf.stride(0),
        x_sf.stride(1),
        topk_ids.stride(0),
        topk_ids.stride(1),
        topk_weights.stride(0),
        topk_weights.stride(1),
        topk_idx_out.stride(0),
        topk_idx_out.stride(1),
        topk_weights_out.stride(0),
        topk_weights_out.stride(1),
        hidden_size,
        top_k,
        BLOCK_K=block_k,
        GROUP_K=16,  # NVFP4: group_size=16 (scale_vec::4X)
        BLOCK_TOPK=block_topk,
        num_warps=4,
    )