diff --git a/Dockerfile b/Dockerfile index 3622f0ae..04cbe608 100644 --- a/Dockerfile +++ b/Dockerfile @@ -30,17 +30,21 @@ ENV PYTHONPATH="/root/nvfp4-megamoe-kernel:${PYTHONPATH}" # Patch vLLM — overwrite model files and register architecture ARG VLLM_MODELS_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models ARG VLLM_LAYERS_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers +ARG VLLM_QUANT_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/quantization +ARG VLLM_FUSED_MOE_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe ARG VLLM_LOADER_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/model_loader +# Core model patches COPY vllm/patches/deepseek_v4.py ${VLLM_MODELS_DIR}/deepseek_v4.py COPY vllm/patches/deepseek_v4_attention.py ${VLLM_LAYERS_DIR}/deepseek_v4_attention.py -COPY vllm/nvfp4_cutedsl.py ${VLLM_MODELS_DIR}/nvfp4_cutedsl.py -COPY vllm/cutedsl_quant_method.py ${VLLM_MODELS_DIR}/cutedsl_quant_method.py -COPY cutedsl/nvfp4_linear.py /root/nvfp4-megamoe-kernel/cutedsl/nvfp4_linear.py -COPY cutedsl/shared_expert_pipeline.py /root/nvfp4-megamoe-kernel/cutedsl/shared_expert_pipeline.py -COPY vllm/patches/utils.py ${VLLM_LOADER_DIR}/utils.py -RUN sed -i 's/"DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),/"DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),\n "DeepseekV4ForCausalLM": ("deepseek_v4", "DeepseekV4ForCausalLM"),/' \ +# NVFP4 MoE backend registration +COPY vllm/patches/fused_moe/oracle/nvfp4.py ${VLLM_FUSED_MOE_DIR}/oracle/nvfp4.py +COPY vllm/patches/fused_moe/experts/cutedsl_moe.py ${VLLM_FUSED_MOE_DIR}/experts/cutedsl_moe.py + +# Register DeepseekV4ForCausalLM model architecture (if not already in upstream) +RUN grep -q '"DeepseekV4ForCausalLM"' ${VLLM_MODELS_DIR}/registry.py || \ + sed -i 's/"DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),/"DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),\n "DeepseekV4ForCausalLM": ("deepseek_v4", "DeepseekV4ForCausalLM"),/' \ ${VLLM_MODELS_DIR}/registry.py # Verify diff --git a/cutedsl/runner.py b/cutedsl/runner.py new file mode 100644 index 00000000..e22e60fa --- /dev/null +++ b/cutedsl/runner.py @@ -0,0 +1,529 @@ +""" +vLLM integration for the CuTeDSL NVFP4 MoE kernel. + +CUDA-graph-compatible design: +- All intermediate buffers pre-allocated at max_num_tokens * top_k size +- No .item(), .tolist(), .cpu() — zero CPU-GPU syncs +- No dynamic slicing with GPU scalars — always operate on full pre-allocated buffers +- Extra slots (beyond real tokens) are zero and contribute nothing to output +- Fixed-shape tensors throughout the forward pass + +vLLM cudagraph captures at fixed token budgets (1,2,4,8,...,8192). +During capture, num_tokens equals the budget — all shapes are fixed. +During replay, inputs are padded to the budget size. Our runner always +processes max_slots = budget * top_k rows; padding rows are zeros. +""" +import torch + +from cutedsl.bridge import ( + quantize_activation_nvfp4, + quantize_weight_to_nvfp4, + + +class _MoEApply(torch.autograd.Function): + """Custom autograd function to make CuTeDSL MoE runner opaque to torch.compile.""" + @staticmethod + def forward(ctx, runner, hidden_states, topk_weights, topk_ids, expert_indices): + return runner._run_impl(hidden_states, topk_weights, topk_ids, expert_indices) + quantize_to_nvfp4, + make_b_k_major, + assemble_scales_3d_side, + run_nvfp4_grouped_gemm, +) +from cutedsl.kernel.moe.torch_scaled_grouped_mm import ( + ceil_div as cutedsl_ceil_div, + pad_and_swizzle_single, +) + + +class CuTeDSLMoERunner: + """Manages NVFP4 MoE execution via the CuTeDSL kernel. + + CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs, + no dynamic shapes. Always computes at max_num_tokens * top_k capacity. + """ + + def __init__(self, num_experts, hidden_size, intermediate_size, + max_num_tokens=8192, top_k=8, device="cuda", + experts_start_idx=0): + self.num_experts = num_experts + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.max_num_tokens = max_num_tokens + self.top_k = top_k + self.device = device + self.experts_start_idx = experts_start_idx + self._swiglu_limit = None # Set via set_swiglu_limit() + + # Weight storage (set before _ensure_stacked) + self.l1_fp4 = None + self.l1_sf = None + self.l1_gs = None + self.l2_fp4 = None + self.l2_sf = None + self.l2_gs = None + + # Stacked weight tensors (set in _ensure_stacked) + self._l1_mat_b = None + self._l2_mat_b = None + self._l1_scale_b = None + self._l2_scale_b = None + self._l1_gsb = None + self._l2_gsb = None + + # Default: 1/2688 ≈ 0.000372 (amax=1 → gs=1/2688) + # Overridden in finalize_weights with checkpoint input_scale or warmup value + self._l1_activation_global_scale = 1.0 / (6.0 * 448.0) + self._l2_activation_global_scale = 1.0 / (6.0 * 448.0) + + # Pre-allocated cudagraph buffers (set in _allocate_buffers) + self._token_indices = None + self._expert_id_range = None + self._expert_offsets_buf = None + self._per_expert_scale_bufs_l1 = None + self._per_expert_scale_bufs_l2 = None + self._padded_x_sf_buf_l1 = None + self._padded_x_sf_buf_l2 = None + self._l1_gsa_buf = None + self._l2_gsa_buf = None + self._output_buf = None + self._row_indices_buf = None + self._padded_hidden_buf = None + self._padded_activated_buf = None # unused, using shared + self._padded_expert_offsets_buf = None + self._max_chunks_per_expert = cutedsl_ceil_div( + self.max_num_tokens * self.top_k, self.num_experts * 128 + ) + self._buffers_allocated = False + + def set_swiglu_limit(self, limit: float | None): + """Set the swiglu_limit for activation clamping.""" + self._swiglu_limit = limit + + def _fill_token_indices(self): + """Fill _token_indices with [0,0,..0, 1,1,..1, ...] (each token repeated top_k times). + + Builds on CPU first, then copies to GPU, to ensure correctness + regardless of CuTeDSL JIT GPU memory corruption. + """ + src = torch.arange(self.max_num_tokens, dtype=torch.int32) + cpu_indices = src.unsqueeze(1).expand(-1, self.top_k).contiguous().view(-1) + self._token_indices.copy_(cpu_indices) + + def _allocate_buffers(self): + """Pre-allocate scale buffers at max size for cudagraph compatibility.""" + # Per-expert scale buffers: separate L1/L2 since K_sf differs + K_sf_l1 = cutedsl_ceil_div(self.hidden_size, 16) + padded_cols_l1 = cutedsl_ceil_div(K_sf_l1, 4) * 4 + K_sf_l2 = cutedsl_ceil_div(self.intermediate_size, 16) + padded_cols_l2 = cutedsl_ceil_div(K_sf_l2, 4) * 4 + + self._per_expert_scale_bufs_l1 = [ + torch.zeros(128, padded_cols_l1, dtype=torch.float16, device=self.device).to(torch.float8_e4m3fn) + for _ in range(self.num_experts) + ] + self._per_expert_scale_bufs_l2 = [ + torch.zeros(128, padded_cols_l2, dtype=torch.float16, device=self.device).to(torch.float8_e4m3fn) + for _ in range(self.num_experts) + ] + + # Initialize shared buffers dict (if not already) + device_key = str(self.device) + if not hasattr(CuTeDSLMoERunner, '_shared_padded_bufs'): + CuTeDSLMoERunner._shared_padded_bufs = {} + if device_key not in CuTeDSLMoERunner._shared_padded_bufs: + CuTeDSLMoERunner._shared_padded_bufs[device_key] = {} + + # Padded x_sf buffers: SHARED across all runners (not per-layer) + max_sf_rows = self.num_experts * self._max_chunks_per_expert * 128 + if 'xsf_l1' not in CuTeDSLMoERunner._shared_padded_bufs[device_key]: + CuTeDSLMoERunner._shared_padded_bufs[device_key].update({ + 'xsf_l1': torch.zeros( + max_sf_rows, padded_cols_l1, dtype=torch.float16, device=self.device + ).to(torch.float8_e4m3fn), + 'xsf_l2': torch.zeros( + max_sf_rows, padded_cols_l2, dtype=torch.float16, device=self.device + ).to(torch.float8_e4m3fn), + 'output': torch.zeros( + self.max_num_tokens, self.hidden_size, dtype=torch.bfloat16, device=self.device + ), + }) + self._padded_x_sf_buf_l1 = CuTeDSLMoERunner._shared_padded_bufs[device_key]['xsf_l1'] + self._padded_x_sf_buf_l2 = CuTeDSLMoERunner._shared_padded_bufs[device_key]['xsf_l2'] + self._output_buf = CuTeDSLMoERunner._shared_padded_bufs[device_key]['output'] + + # Pre-allocated global_scale_a buffers (filled via .fill_(), no torch.full during capture) + self._l1_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=self.device) + self._l2_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=self.device) + + # Row indices for scale assembly (max_num_tokens * top_k slots) + self._row_indices_buf = torch.arange( + self.max_num_tokens * self.top_k, device=self.device + ) + + # Padded hidden/activated: SHARED across all runners (not per-layer) + max_rows_per_expert = self._max_chunks_per_expert * 128 + padded_max_slots = self.num_experts * max_rows_per_expert + if 'hidden' not in CuTeDSLMoERunner._shared_padded_bufs[device_key]: + CuTeDSLMoERunner._shared_padded_bufs[device_key].update({ + 'hidden': torch.zeros( + padded_max_slots, self.hidden_size, dtype=torch.bfloat16, device=self.device + ), + 'hidden_fp4': torch.zeros( + padded_max_slots, self.hidden_size // 2, dtype=torch.uint8, device=self.device + ).view(torch.float4_e2m1fn_x2), + 'activated': torch.zeros( + padded_max_slots, self.intermediate_size, dtype=torch.bfloat16, device=self.device + ), + 'activated_fp4': torch.zeros( + padded_max_slots, self.intermediate_size // 2, dtype=torch.uint8, device=self.device + ).view(torch.float4_e2m1fn_x2), + }) + self._shared_bufs = CuTeDSLMoERunner._shared_padded_bufs[device_key] + + # Padded expert offsets buffer: [0, max_rows, 2*max_rows, ...] (fixed) + self._padded_expert_offsets_buf = torch.zeros( + self.num_experts + 1, dtype=torch.int32, device=self.device + ) + max_rows_per_expert = self._max_chunks_per_expert * 128 + self._padded_expert_offsets_buf[1:] = torch.arange( + 1, self.num_experts + 1, dtype=torch.int32, device=self.device + ) * max_rows_per_expert + + self._buffers_allocated = True + + def _ensure_stacked(self): + if self._l1_mat_b is not None: + return + + # Stack and prepare weight tensors FIRST (triggers CuTeDSL JIT compilation) + self._l1_mat_b = make_b_k_major(torch.stack(self.l1_fp4)) + self._l2_mat_b = make_b_k_major(torch.stack(self.l2_fp4)) + self._l1_scale_b = assemble_scales_3d_side(self.l1_sf) + self._l2_scale_b = assemble_scales_3d_side(self.l2_sf) + self._l1_gsb = torch.tensor(self.l1_gs, dtype=torch.float32, device=self.device) + self._l2_gsb = torch.tensor(self.l2_gs, dtype=torch.float32, device=self.device) + self.l1_fp4 = None + self.l1_sf = None + self.l1_gs = None + self.l2_fp4 = None + self.l2_sf = None + self.l2_gs = None + + # Allocate buffers AFTER JIT compilation + # (CuTeDSL's cute.compile corrupts GPU memory during JIT; + # tensors allocated before/during compilation may be zeroed) + # + # _token_indices: GPU tensor for cudagraph compatibility. + # CuTeDSL JIT may corrupt GPU memory, so we fill AFTER stacking + # (which triggers the weight JIT). The GEMM JIT in run_nvfp4_grouped_gemm + # is triggered on the first run() call; we refill _token_indices after + # that first call via the _needs_token_refill flag. + self._token_indices = torch.zeros( + self.max_num_tokens * self.top_k, dtype=torch.int32, device=self.device + ) + self._fill_token_indices() + self._needs_token_refill = True # GEMM JIT may corrupt; refill after first run + + self._expert_id_range = torch.arange( + self.num_experts, dtype=torch.int32 + ).to(self.device) + self._expert_offsets_buf = torch.zeros( + self.num_experts + 1, dtype=torch.int32, device=self.device + ) + self._allocate_buffers() + + def prepare_weights_direct(self, l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs): + self.l1_fp4 = l1_fp4 + self.l1_sf = l1_sf + self.l1_gs = l1_gs + self.l2_fp4 = l2_fp4 + self.l2_sf = l2_sf + self.l2_gs = l2_gs + self._l1_mat_b = None + + def prepare_weights_from_dequantized(self, l1_weights_bf16, l2_weights_bf16): + self.l1_fp4, self.l1_sf, self.l1_gs = [], [], [] + self.l2_fp4, self.l2_sf, self.l2_gs = [], [], [] + for l1_w, l2_w in zip(l1_weights_bf16, l2_weights_bf16): + l1_w_t = l1_w.T + w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(l1_w_t) + self.l1_fp4.append(w_fp4) + self.l1_sf.append(w_sf) + self.l1_gs.append(w_gs) + l2_w_t = l2_w.T + w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(l2_w_t) + self.l2_fp4.append(w_fp4) + self.l2_sf.append(w_sf) + self.l2_gs.append(w_gs) + self._l1_mat_b = None + + def _assemble_scales_cudagraph_safe(self, x_sf, expert_offsets, + padded_expert_offsets, + padded_x_sf_buf, per_expert_bufs): + """Assemble 2D-side activation scales (cudagraph-safe, NO CPU syncs). + + Phase 1: Scatter x_sf into padded per-expert sections (GPU-only). + Phase 2: Apply full-buffer Blackwell 32_4_4 swizzle (no Python loops). + + The buffer is 128-row aligned per expert (from padded_expert_offsets), + so the full-buffer swizzle produces the correct layout. The GEMM reads + scale_a using padded_expert_offsets, matching the scatter layout. + """ + K_sf = x_sf.shape[1] + padded_x_sf = padded_x_sf_buf + padded_x_sf.zero_() + + # Phase 1: Scatter x_sf into padded per-expert sections (GPU-only) + total_rows = x_sf.shape[0] + row_indices = self._row_indices_buf[:total_rows] + expert_assign = torch.searchsorted( + expert_offsets[1:], row_indices, right=True + ).clamp(max=self.num_experts - 1) + local_row = row_indices - expert_offsets[expert_assign] + dst_rows = padded_expert_offsets[expert_assign] + local_row + padded_x_sf[dst_rows, :K_sf] = x_sf + + # Phase 2: Full-buffer swizzle (no CPU sync, no Python loops) + # padded_x_sf is 128-row aligned per expert and 4-col aligned. + # to_blocked: (rows, cols) → view(R, 128, C, 4) → permute(0,2,1,3) + # → reshape(-1, 4, 32, 4) → transpose(1,2) → reshape(-1, 32, 16) → flatten + rows = padded_x_sf.shape[0] + cols = padded_x_sf.shape[1] + R = rows // 128 + C = cols // 4 + blocks = padded_x_sf.view(R, 128, C, 4).permute(0, 2, 1, 3) + rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) + swizzled = rearranged.flatten().view(torch.float8_e4m3fn) + return swizzled.reshape(rows, cols) + + def compute_activation_global_scales(self, hidden_states_sample, topk_weights, topk_ids): + """Compute activation global scales from a warmup forward pass. + + Called BEFORE cudagraph capture. Uses the SAME padded GEMM path as run() + to ensure kernel JIT happens with the same layout, and L2 gs is computed + from actual L1 output (not an approximation). + """ + self._ensure_stacked() + device = hidden_states_sample.device + num_tokens = hidden_states_sample.shape[0] + top_k = topk_ids.shape[1] + + with torch.no_grad(): + # Build slot mapping (same as run()) + flat_ids = topk_ids.reshape(-1) + num_slots = num_tokens * top_k + token_indices = self._token_indices[:num_slots] + sort_idx = flat_ids.argsort(stable=True) + sorted_ids = flat_ids[sort_idx] + sorted_token_ids = token_indices[sort_idx] + slot_hidden = hidden_states_sample[sorted_token_ids] + + # L1: get exact gs from quantize_to_nvfp4 + _, _, l1_gs = quantize_to_nvfp4(slot_hidden) + + # Quantize slot_hidden for GEMM + slot_x_fp4, slot_x_sf = quantize_activation_nvfp4(slot_hidden, l1_gs) + + expert_id_range = self._expert_id_range + tokens_per_expert = (sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)).sum(dim=0).int() + expert_offsets = self._expert_offsets_buf + expert_offsets.zero_() + expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0) + + padded_tokens_per_expert = ((tokens_per_expert + 127) // 128) * 128 + padded_expert_offsets = self._padded_expert_offsets_buf + padded_expert_offsets.zero_() + padded_expert_offsets[1:self.num_experts + 1] = padded_tokens_per_expert.cumsum(0) + + # Compute padded_dst (same as run()) + row_indices = self._row_indices_buf[:num_slots] + expert_assign = torch.searchsorted( + expert_offsets[1:], row_indices, right=True + ).clamp(max=self.num_experts - 1) + local_row = row_indices - expert_offsets[expert_assign] + padded_dst = padded_expert_offsets[expert_assign] + local_row + + # Scatter x_fp4 into padded layout + padded_x_fp4 = self._shared_bufs['hidden_fp4'] + padded_x_fp4.view(torch.uint8).zero_() + padded_x_fp4.view(torch.uint8)[padded_dst] = slot_x_fp4.view(torch.uint8) + + l1_scale_a = self._assemble_scales_cudagraph_safe( + slot_x_sf, expert_offsets[:self.num_experts + 1], + padded_expert_offsets, + self._padded_x_sf_buf_l1, self._per_expert_scale_bufs_l1 + ) + l1_gsa = torch.full((self.num_experts,), l1_gs, dtype=torch.float32, device=device) + + l1_out = run_nvfp4_grouped_gemm( + mat_a=padded_x_fp4, mat_b=self._l1_mat_b, + scale_a=l1_scale_a, scale_b=self._l1_scale_b, + expert_offsets=padded_expert_offsets[1:self.num_experts + 1], + global_scale_a=l1_gsa, global_scale_b=self._l1_gsb, + ) + + # Extract real token outputs + l1_out_real = l1_out[padded_dst] + + # L2: get exact gs from SiLU(gate)*up + gate = l1_out_real[:, :self.intermediate_size] + up = l1_out_real[:, self.intermediate_size:] + gate_silu = torch.nn.functional.silu(gate) + if self._swiglu_limit is not None: + gate_silu = gate_silu.clamp(max=self._swiglu_limit) + up = up.clamp(min=-self._swiglu_limit, max=self._swiglu_limit) + activated = gate_silu * up + _, _, l2_gs = quantize_to_nvfp4(activated) + + self._l1_activation_global_scale = l1_gs + self._l2_activation_global_scale = l2_gs + + + + def run(self, hidden_states, topk_weights, topk_ids, expert_indices=None): + """Forward: route tokens to experts, GEMM, combine.""" + return _MoEApply.apply(self, hidden_states, topk_weights, topk_ids, expert_indices) + + def _run_impl(self, hidden_states, topk_weights, topk_ids, expert_indices=None): + """Run the NVFP4 MoE forward pass. + + Handles global→local expert ID remapping for expert parallelism. + Fully cudagraph-safe: no CPU-GPU syncs, no dynamic shapes. + + Each expert's slots are padded to multiples of 128 for the GEMM. + expert_offsets is [0, padded_e0, padded_e0+padded_e1, ...]. + scale_a is produced at those same offsets. + """ + num_tokens = hidden_states.shape[0] + top_k = topk_ids.shape[1] + device = hidden_states.device + + self._ensure_stacked() + + # -- Remap global expert IDs to local IDs -- + local_ids = topk_ids - self.experts_start_idx + local_mask = (local_ids >= 0) & (local_ids < self.num_experts) + safe_ids = local_ids.clamp(0, self.num_experts - 1) + safe_weights = topk_weights * local_mask.float() + + # -- Build slot mapping -- + flat_ids = safe_ids.reshape(-1) + flat_weights = safe_weights.reshape(-1) + num_slots = num_tokens * top_k + token_indices = self._token_indices[:num_slots] + + sort_idx = flat_ids.argsort(stable=True) + sorted_ids = flat_ids[sort_idx] + sorted_weights = flat_weights[sort_idx] + sorted_token_ids = token_indices[sort_idx] + + # Expert offsets (real token counts) + expert_id_range = self._expert_id_range + tokens_per_expert = (sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)).sum(dim=0).int() + expert_offsets = self._expert_offsets_buf + expert_offsets.zero_() + expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0) + + # Pad each expert to 128-row alignment (GPU-only computation) + padded_tokens_per_expert = ((tokens_per_expert + 127) // 128) * 128 + padded_expert_offsets = self._padded_expert_offsets_buf + padded_expert_offsets.zero_() + padded_expert_offsets[1:self.num_experts + 1] = padded_tokens_per_expert.cumsum(0) + total_padded_slots = padded_expert_offsets[self.num_experts] + + # -- Gather hidden states into slot order, compute padded_dst -- + slot_hidden = hidden_states[sorted_token_ids] + row_indices = self._row_indices_buf[:num_slots] + expert_assign = torch.searchsorted( + expert_offsets[1:], row_indices, right=True + ).clamp(max=self.num_experts - 1) + local_row = row_indices - expert_offsets[expert_assign] + padded_dst = padded_expert_offsets[expert_assign] + local_row + + # === L1: gate + up === + # Quantize slot_hidden (sorted tokens), NOT padded_hidden. + # padded_hidden is padded with zeros; quantizing it produces + # x_sf rows at padded positions, but x_sf[:num_slots] would + # only get scales for the first num_slots PADDED rows (expert 0), + # not the scattered token positions. Quantizing slot_hidden + # gives x_sf with num_slots rows (one per token), which the + # scale assembly correctly scatters into padded layout. + slot_x_fp4, slot_x_sf = quantize_activation_nvfp4( + slot_hidden, self._l1_activation_global_scale + ) + # Scatter x_fp4 into padded layout for the GEMM + # Must scatter as uint8 (float4_e2m1fn_x2 doesn't support index_put) + padded_x_fp4 = self._shared_bufs['hidden_fp4'] + padded_x_fp4.view(torch.uint8).zero_() + padded_x_fp4.view(torch.uint8)[padded_dst] = slot_x_fp4.view(torch.uint8) + + l1_scale_a = self._assemble_scales_cudagraph_safe( + slot_x_sf, expert_offsets[:self.num_experts + 1], + padded_expert_offsets, + self._padded_x_sf_buf_l1, self._per_expert_scale_bufs_l1 + ) + l1_gsa = self._l1_gsa_buf.fill_(self._l1_activation_global_scale) + + l1_out = run_nvfp4_grouped_gemm( + mat_a=padded_x_fp4, mat_b=self._l1_mat_b, + scale_a=l1_scale_a, scale_b=self._l1_scale_b, + expert_offsets=padded_expert_offsets[1:self.num_experts + 1], + global_scale_a=l1_gsa, global_scale_b=self._l1_gsb, + ) + + # Extract real token outputs from padded GEMM output + l1_out_real = l1_out[padded_dst] + + # === SiLU(gate) * up (with swiglu_limit clamp) === + gate = l1_out_real[:, :self.intermediate_size] + up = l1_out_real[:, self.intermediate_size:] + gate_silu = torch.nn.functional.silu(gate) + # Apply DeepSeek-V4 swiglu_limit: clamp both silu(gate) and up + if self._swiglu_limit is not None: + gate_silu = gate_silu.clamp(max=self._swiglu_limit) + up = up.clamp(min=-self._swiglu_limit, max=self._swiglu_limit) + activated = gate_silu * up + + # === L2: down === + # Quantize activated (per-token), scatter into padded FP4 buffer + slot_l2_x_fp4, slot_l2_x_sf = quantize_activation_nvfp4( + activated, self._l2_activation_global_scale + ) + padded_activated_fp4 = self._shared_bufs['activated_fp4'] + padded_activated_fp4.view(torch.uint8).zero_() + padded_activated_fp4.view(torch.uint8)[padded_dst] = slot_l2_x_fp4.view(torch.uint8) + + l2_scale_a = self._assemble_scales_cudagraph_safe( + slot_l2_x_sf, expert_offsets[:self.num_experts + 1], + padded_expert_offsets, + self._padded_x_sf_buf_l2, self._per_expert_scale_bufs_l2 + ) + l2_gsa = self._l2_gsa_buf.fill_(self._l2_activation_global_scale) + + l2_out = run_nvfp4_grouped_gemm( + mat_a=padded_activated_fp4, mat_b=self._l2_mat_b, + scale_a=l2_scale_a, scale_b=self._l2_scale_b, + expert_offsets=padded_expert_offsets[1:self.num_experts + 1], + global_scale_a=l2_gsa, global_scale_b=self._l2_gsb, + ) + + l2_out_real = l2_out[padded_dst] + + # === Scatter -> final output === + y = self._output_buf[:num_tokens] + y.zero_() + weighted_out = l2_out_real * sorted_weights.unsqueeze(1).to(l2_out_real.dtype) + y.scatter_add_( + 0, + sorted_token_ids.unsqueeze(1).expand(-1, self.hidden_size), + weighted_out, + ) + + # Refill _token_indices after GEMM JIT on first call + # (CuTeDSL's cute.compile may corrupt GPU memory during first GEMM) + if self._needs_token_refill: + self._fill_token_indices() + self._needs_token_refill = False + + return y diff --git a/vllm/patches/deepseek_v4.py b/vllm/patches/deepseek_v4.py index e6ad2b4e..a9ab511a 100644 --- a/vllm/patches/deepseek_v4.py +++ b/vllm/patches/deepseek_v4.py @@ -1,13 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import os import typing from collections.abc import Callable, Iterable from itertools import islice import regex as re -import os import torch import torch.nn as nn @@ -15,6 +12,7 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed import ( get_ep_group, + get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) @@ -25,11 +23,14 @@ from vllm.model_executor.layers.deepseek_v4_attention import ( DeepseekV4MLAModules, DeepseekV4MultiHeadLatentAttentionWrapper, ) -from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import ( fused_topk_bias, ) +from vllm.model_executor.layers.fused_moe.router.norm_gate_linear import ( + NormGateLinear, +) from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, @@ -37,6 +38,12 @@ from vllm.model_executor.layers.linear import ( RowParallelLinear, ) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mhc import ( + HCHeadOp, + MHCFusedPostPreOp, + MHCPostOp, + MHCPreOp, +) from vllm.model_executor.layers.quantization import ( QuantizationConfig, QuantizationMethods, @@ -52,6 +59,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import SupportsPP from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors @@ -60,8 +68,10 @@ from vllm.utils.torch_utils import direct_register_custom_op from .utils import ( AutoWeightsLoader, + PPMissingLayer, WeightsMapper, extract_layer_index, + is_pp_missing_parameter, make_layers, maybe_prefix, ) @@ -123,17 +133,28 @@ class DeepseekV4MLP(nn.Module): class DeepseekV4FP8Config(Fp8Config): """FP8 config for DeepSeek V4 with expert-dtype-aware MoE dispatch. - DeepSeek V4 checkpoints use FP8 block quantization for attention - layers and NVFP4 (E2M1 + float8_e4m3fn block scales) for MoE experts. + DeepSeek V4 checkpoints always use FP8 block quantization for + linear/attention layers. The MoE expert weights vary by checkpoint: + - ``expert_dtype="fp4"`` (e.g. DeepSeek-V4-Flash): MXFP4 experts + with ue8m0 (e8m0fnu) FP8 linear scales. + - ``expert_dtype="fp8"`` (e.g. DeepSeek-V4-Flash-Base): FP8 block + experts with float32 FP8 linear scales. - ``expert_dtype`` from hf_config determines the MoE dispatch path. - For NVFP4 checkpoints (our case), expert_dtype="fp4" which routes - to DeepseekV4MegaMoEExperts (native NVFP4 CUTLASS kernel). + The dispatch and the linear scale dtype are both keyed off + ``expert_dtype`` from the model's hf_config; missing values default + to ``"fp4"`` so existing FP4 checkpoints stay unchanged. + + NOTE: ``expert_dtype`` is resolved lazily because this config is + constructed during VllmConfig setup, before ``set_current_vllm_config`` + is active. Reading hf_config eagerly in ``__init__`` would always see + the default ``"fp4"`` and silently misroute Flash-Base checkpoints. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._resolved_expert_dtype: str | None = None + # ``is_scale_e8m0`` is a property that resolves on first read, + # by which time the current vllm_config has been set. @property def expert_dtype(self) -> str: @@ -141,9 +162,8 @@ class DeepseekV4FP8Config(Fp8Config): try: hf_config = get_current_vllm_config().model_config.hf_config except Exception: - # vllm_config not yet set; return safe default but do NOT - # cache — a later call inside set_current_vllm_config may - # resolve differently. + # vllm_config not yet set; defer the decision until a + # later call lands inside set_current_vllm_config. return "fp4" expert_dtype = getattr(hf_config, "expert_dtype", "fp4") if expert_dtype not in _DEEPSEEK_V4_EXPERT_DTYPES: @@ -152,8 +172,19 @@ class DeepseekV4FP8Config(Fp8Config): f"expected one of {_DEEPSEEK_V4_EXPERT_DTYPES}." ) self._resolved_expert_dtype = expert_dtype + from vllm.logger import init_logger + + init_logger(__name__).info_once( + "DeepSeek V4 expert_dtype resolved to %r", expert_dtype + ) return self._resolved_expert_dtype + @property + def is_scale_e8m0(self) -> bool: + # FP4 checkpoints store FP8 linear scales as e8m0fnu; FP8 expert + # checkpoints (Flash-Base) store them as float32. + return self.expert_dtype == "fp4" + @classmethod def get_name(cls) -> QuantizationMethods: return "deepseek_v4_fp8" @@ -190,45 +221,188 @@ class DeepseekV4FP8Config(Fp8Config): return isinstance(layer, FusedMoE) and self.expert_dtype == "fp4" +@triton.jit +def _deepseek_v4_stage_mega_moe_inputs_kernel( + hidden_states, + x_fp8, + x_sf, + topk_ids, + topk_weights, + topk_idx_out, + topk_weights_out, + hidden_stride_m: tl.constexpr, + hidden_stride_k: tl.constexpr, + x_stride_m: tl.constexpr, + x_stride_k: tl.constexpr, + x_sf_stride_m: tl.constexpr, + x_sf_stride_k: tl.constexpr, + topk_ids_stride_m: tl.constexpr, + topk_ids_stride_k: tl.constexpr, + topk_weights_stride_m: tl.constexpr, + topk_weights_stride_k: tl.constexpr, + topk_idx_stride_m: tl.constexpr, + topk_idx_stride_k: tl.constexpr, + topk_weights_out_stride_m: tl.constexpr, + topk_weights_out_stride_k: tl.constexpr, + hidden_size: tl.constexpr, + top_k: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_K: tl.constexpr, + BLOCK_TOPK: tl.constexpr, +) -> None: + token_id = tl.program_id(0) + k_block_id = tl.program_id(1) + + k_offsets = k_block_id * BLOCK_K + tl.arange(0, BLOCK_K) + k_mask = k_offsets < hidden_size + hidden = tl.load( + hidden_states + token_id * hidden_stride_m + k_offsets * hidden_stride_k, + mask=k_mask, + other=0.0, + ).to(tl.float32) + + num_groups: tl.constexpr = BLOCK_K // GROUP_K + hidden_groups = tl.reshape(tl.abs(hidden), [num_groups, GROUP_K]) + amax = tl.max(hidden_groups, axis=1) + amax = tl.maximum(amax, 1.0e-4) + + scale = amax / 448.0 + scale_bits = scale.to(tl.uint32, bitcast=True) + scale_exp = ((scale_bits >> 23) & 0xFF) + ((scale_bits & 0x7FFFFF) != 0).to( + tl.uint32 + ) + scale_exp = tl.minimum(tl.maximum(scale_exp, 1), 254) + rounded_scale = (scale_exp << 23).to(tl.float32, bitcast=True) + + hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K]) + scaled = hidden_groups * (1.0 / rounded_scale)[:, None] + scaled = tl.reshape(scaled, [BLOCK_K]) + fp8 = scaled.to(tl.float8e4nv) + tl.store( + x_fp8 + token_id * x_stride_m + k_offsets * x_stride_k, + fp8, + mask=k_mask, + ) + + scale_offsets = tl.arange(0, num_groups) + packed_scale = tl.sum(scale_exp << (scale_offsets * 8), axis=0).to(tl.int32) + tl.store( + x_sf + token_id * x_sf_stride_m + k_block_id * x_sf_stride_k, + packed_scale, + ) + + if k_block_id == 0: + topk_offsets = tl.arange(0, BLOCK_TOPK) + topk_mask = topk_offsets < top_k + + ids = tl.load( + topk_ids + token_id * topk_ids_stride_m + topk_offsets * topk_ids_stride_k, + mask=topk_mask, + other=0, + ).to(tl.int64) + tl.store( + topk_idx_out + + token_id * topk_idx_stride_m + + topk_offsets * topk_idx_stride_k, + ids, + mask=topk_mask, + ) + + weights = tl.load( + topk_weights + + token_id * topk_weights_stride_m + + topk_offsets * topk_weights_stride_k, + mask=topk_mask, + other=0.0, + ) + tl.store( + topk_weights_out + + token_id * topk_weights_out_stride_m + + topk_offsets * topk_weights_out_stride_k, + weights, + mask=topk_mask, + ) + + +def _stage_deepseek_v4_mega_moe_inputs( + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + x_fp8: torch.Tensor, + x_sf: torch.Tensor, + topk_idx_out: torch.Tensor, + topk_weights_out: torch.Tensor, +) -> None: + num_tokens, hidden_size = hidden_states.shape + if num_tokens == 0: + return + if hidden_size % 128 != 0: + raise ValueError( + "DeepSeek V4 MegaMoE input staging requires hidden_size to be " + "a multiple of 128." + ) + top_k = topk_ids.shape[1] + if topk_weights.shape != topk_ids.shape: + raise ValueError( + "DeepSeek V4 MegaMoE input staging requires topk_weights and " + "topk_ids to have the same shape." + ) + + block_k = 128 + grid = (num_tokens, triton.cdiv(hidden_size, block_k)) + block_topk = triton.next_power_of_2(top_k) + _deepseek_v4_stage_mega_moe_inputs_kernel[grid]( + hidden_states, + x_fp8, + x_sf, + topk_ids, + topk_weights, + topk_idx_out, + topk_weights_out, + hidden_states.stride(0), + hidden_states.stride(1), + x_fp8.stride(0), + x_fp8.stride(1), + x_sf.stride(0), + x_sf.stride(1), + topk_ids.stride(0), + topk_ids.stride(1), + topk_weights.stride(0), + topk_weights.stride(1), + topk_idx_out.stride(0), + topk_idx_out.stride(1), + topk_weights_out.stride(0), + topk_weights_out.stride(1), + hidden_size, + top_k, + BLOCK_K=block_k, + GROUP_K=32, + BLOCK_TOPK=block_topk, + num_warps=4, + ) def make_deepseek_v4_expert_params_mapping( num_experts: int, ) -> list[tuple[str, str, int, str]]: - # Checkpoint uses gate_proj/up_proj/down_proj, model params use w13_/w2_ return [ ( "experts.w13_" if shard_id in ("w1", "w3") else "experts.w2_", - f"experts.{expert_id}.{ckpt_name}.", + f"experts.{expert_id}.{weight_name}.", expert_id, shard_id, ) for expert_id in range(num_experts) - for shard_id, ckpt_name in [ - ("w1", "gate_proj"), - ("w2", "down_proj"), - ("w3", "up_proj"), + for shard_id, weight_name in [ + ("w1", "w1"), + ("w2", "w2"), + ("w3", "w3"), ] ] class DeepseekV4MegaMoEExperts(nn.Module): - """MegaMoE experts for DeepSeek V4 with NVFP4 quantization. - - Loads NVFP4 expert weights (E2M1 packed uint8 + float8_e4m3fn block scales - + float32 global scales) and runs them through the CuTeDSL NVFP4 kernel. - - The CuTeDSL kernel is a Python-based CUTLASS kernel compiled via MLIR → PTX. - It handles NVFP4 natively with full Blackwell pipeline overlap (TMA → MMA → Epilogue). - This replaces the broken C++ CUTLASS kernel (see README.md for the full story). - """ - _cutedsl_runner: 'CuTeDSLMoERunner | None' = None - _weight_load_count: int = 0 - _weight_load_tqdm: 'tqdm | None' = None - - # NVFP4 E2M1 lookup table (positive values, sign from bit 3) - E2M1_LUT = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0] - # MXFP4 E2M1 is the same format + _symm_buffer_cache: dict[tuple[int, int, int, int, int, int, int], object] = {} def __init__( self, @@ -252,88 +426,56 @@ class DeepseekV4MegaMoEExperts(nn.Module): self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens - self.swiglu_limit = vllm_config.model_config.hf_config.swiglu_limit weight_attrs = {"weight_loader": self.weight_loader} - - # NVFP4 weights: E2M1 packed as uint8, 2 values per byte self.w13_weight = nn.Parameter( torch.zeros( num_local_experts, 2 * intermediate_size, hidden_size // 2, - dtype=torch.int8, + dtype=torch.uint8, ), requires_grad=False, ) set_weight_attrs(self.w13_weight, weight_attrs) - # NVFP4 block scales: float8_e4m3fn, group_size=16 - # Shape: [num_local_experts, 2*intermediate_size, hidden_size // 16] self.w13_weight_scale = nn.Parameter( torch.zeros( num_local_experts, 2 * intermediate_size, - hidden_size // 16, - dtype=torch.float8_e4m3fn, + hidden_size // 32, + dtype=torch.uint8, ), requires_grad=False, ) set_weight_attrs(self.w13_weight_scale, weight_attrs) self.w13_weight_scale.quant_method = "block" - # NVFP4 global scales: float32, per-expert, per-projection (gate, up) - # shape (num_local_experts, 2) — one scale for gate_proj, one for up_proj - self.w13_weight_scale_2 = nn.Parameter( - torch.zeros(num_local_experts, 2, dtype=torch.float32), - requires_grad=False, - ) - set_weight_attrs(self.w13_weight_scale_2, weight_attrs) - - # NVFP4 activation scales: float32, per-expert - self.w13_input_scale = nn.Parameter( - torch.zeros(num_local_experts, dtype=torch.float32), - requires_grad=False, - ) - set_weight_attrs(self.w13_input_scale, weight_attrs) - self.w2_weight = nn.Parameter( torch.zeros( num_local_experts, hidden_size, intermediate_size // 2, - dtype=torch.int8, + dtype=torch.uint8, ), requires_grad=False, ) set_weight_attrs(self.w2_weight, weight_attrs) - # NVFP4 block scales for w2 self.w2_weight_scale = nn.Parameter( torch.zeros( num_local_experts, hidden_size, - intermediate_size // 16, - dtype=torch.float8_e4m3fn, + intermediate_size // 32, + dtype=torch.uint8, ), requires_grad=False, ) set_weight_attrs(self.w2_weight_scale, weight_attrs) self.w2_weight_scale.quant_method = "block" - self.w2_weight_scale_2 = nn.Parameter( - torch.zeros(num_local_experts, dtype=torch.float32), - requires_grad=False, - ) - set_weight_attrs(self.w2_weight_scale_2, weight_attrs) - - self.w2_input_scale = nn.Parameter( - torch.zeros(num_local_experts, dtype=torch.float32), - requires_grad=False, - ) - set_weight_attrs(self.w2_input_scale, weight_attrs) - - self._cutedsl_runner = None + self._transformed_l1_weights: tuple[torch.Tensor, torch.Tensor] | None = None + self._transformed_l2_weights: tuple[torch.Tensor, torch.Tensor] | None = None # Register in the static forward context so the custom-op wrapper # can look up this module by name from within a torch.compile graph. @@ -354,43 +496,21 @@ class DeepseekV4MegaMoEExperts(nn.Module): weight_name: str, shard_id: str, expert_id: int, - ) -> bool: - # Progress bar for k8s/docker liveness during GPU upload - if DeepseekV4MegaMoEExperts._weight_load_count == 0: - from tqdm import tqdm as _tqdm - DeepseekV4MegaMoEExperts._weight_load_tqdm = _tqdm( - total=self.num_local_experts * 20, # ~20 tensors per expert - desc=" Loading Native NVFP4 Expert Weights", - unit="tensor", - ) - DeepseekV4MegaMoEExperts._weight_load_count += 1 - DeepseekV4MegaMoEExperts._weight_load_tqdm.update(1) - + return_success: bool = False, + ) -> bool | None: local_expert_id = self._map_global_expert_id(expert_id) if local_expert_id == -1: - return False - - # Scalar params (weight_scale_2, input_scale): per-expert - if "weight_scale_2" in weight_name or "input_scale" in weight_name: - if "w13_" in weight_name and "weight_scale_2" in weight_name: - # w13 is fused gate+up — store gate and up scales separately - # shard_id tells us which projection: w1=gate, w3=up - proj_idx = 0 if shard_id == "w1" else 1 - param.data[local_expert_id, proj_idx].copy_(loaded_weight) - else: - # w2 or input_scale — single scalar per expert - param.data[local_expert_id].copy_(loaded_weight) - return True + return False if return_success else None expert_data = param.data[local_expert_id] if shard_id in ("w1", "w3"): if "w13_" not in weight_name: - return False + return False if return_success else None shard_offset = 0 if shard_id == "w1" else self.intermediate_size expert_data = expert_data.narrow(0, shard_offset, self.intermediate_size) elif shard_id == "w2": if "w2_" not in weight_name: - return False + return False if return_success else None else: raise ValueError(f"Unsupported expert shard id: {shard_id}") @@ -401,24 +521,21 @@ class DeepseekV4MegaMoEExperts(nn.Module): f"vs checkpoint {tuple(loaded_weight.shape)}" ) expert_data.copy_(loaded_weight) - return True + return True if return_success else None + + @staticmethod + def _ue8m0_uint8_to_float(sf: torch.Tensor) -> torch.Tensor: + return (sf.to(torch.int32) << 23).view(torch.float32) def _check_runtime_supported(self) -> None: if not torch.cuda.is_available(): raise NotImplementedError("DeepSeek V4 MegaMoE requires CUDA.") - # After finalize_weights, w13_weight is freed — get device from - # the stacked tensors in the cutedsl runner instead. - if self.w13_weight is not None: - device = self.w13_weight.device - elif self._cutedsl_runner is not None and self._cutedsl_runner._l1_mat_b is not None: - device = self._cutedsl_runner._l1_mat_b.device - else: - device = torch.device("cuda") # fallback + device = self.w13_weight.device if device.type != "cuda": raise NotImplementedError( "DeepSeek V4 MegaMoE expert weights must be loaded on CUDA." ) - if torch.cuda.get_device_capability(device)[0] < 10: + if torch.cuda.get_device_capability(device)[0] != 10: raise NotImplementedError("DeepGEMM MegaMoE requires SM100 GPUs.") if self.hidden_size % 128 != 0 or self.intermediate_size % 128 != 0: raise ValueError( @@ -427,169 +544,69 @@ class DeepseekV4MegaMoEExperts(nn.Module): ) def finalize_weights(self) -> None: - if self._cutedsl_runner is not None and (self._cutedsl_runner.l1_fp4 is not None or self._cutedsl_runner._l1_mat_b is not None): - return # Already finalized + if self._transformed_l1_weights is not None: + return self._check_runtime_supported() + import vllm.third_party.deep_gemm as deep_gemm - # ── Direct NVFP4 path (no BF16 round-trip) ── - # Checkpoint stores: - # weight: uint8 packed E2M1 (2 FP4 values/byte) → view as float4_e2m1fn_x2 - # weight_scale: float8_e4m3fn block scales → use directly - # weight_scale_2: float32 global scale → use directly - # The only conversion is uint8 → float4_e2m1fn_x2 (byte-preserving view cast). - # - # L1 complication: gate and up have different global scales, but the - # kernel takes one global_scale_b per expert. Solution: normalize to - # max(gate_gs, up_gs) and fold the ratio into block scales via float32 - # (one multiply + float8 round-trip on the *ratio only* — much better - # than dequantizing the entire weight matrix through BF16). - - from vllm.model_executor.models.nvfp4_cutedsl import CuTeDSLMoERunner - - l1_fp4, l1_sf, l1_gs = [], [], [] - l2_fp4, l2_sf, l2_gs = [], [], [] - - for e in range(self.num_local_experts): - # ── L1: gate + up (fused) ── - gate_w = self.w13_weight.data[e, :self.intermediate_size] # (intermediate, hidden//2) uint8 - up_w = self.w13_weight.data[e, self.intermediate_size:] # (intermediate, hidden//2) uint8 - gate_sf = self.w13_weight_scale.data[e, :self.intermediate_size] # (intermediate, hidden//16) float8 - up_sf = self.w13_weight_scale.data[e, self.intermediate_size:] - gate_gs = self.w13_weight_scale_2.data[e, 0].item() # float32 scalar - up_gs = self.w13_weight_scale_2.data[e, 1].item() - - # Fuse gate+up along N dim, then transpose to K-major (K_packed, N) - # Checkpoint is (N, K_packed) → permute to (K_packed, N) = (hidden//2, 2*intermediate) - fused_w = torch.cat([gate_w, up_w], dim=0) # (2*intermediate, hidden//2) - fused_w_fp4 = fused_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous() - # shape: (hidden//2, 2*intermediate) — K=hidden packed, N=2*intermediate - - # Fuse block scales: checkpoint is (N, K_sf), bridge expects (K_sf, N) - fused_sf = torch.cat([gate_sf, up_sf], dim=0) # (2*intermediate, hidden//16) = (N, K_sf) - # Transpose to (K_sf, N) for assemble_scales_3d_side - fused_sf = fused_sf.permute(1, 0).contiguous() - - # Handle dual global scales: normalize to max, fold ratio into block scales - l1_max_gs = max(gate_gs, up_gs) - if gate_gs != up_gs: - fused_sf_f32 = fused_sf.float() - # After transpose to (K_sf, N): gate is first intermediate cols, up is next - fused_sf_f32[:, :self.intermediate_size] *= (gate_gs / l1_max_gs) - fused_sf_f32[:, self.intermediate_size:] *= (up_gs / l1_max_gs) - fused_sf = fused_sf_f32.to(torch.float8_e4m3fn) - - l1_fp4.append(fused_w_fp4) - l1_sf.append(fused_sf) - l1_gs.append(l1_max_gs) - - # ── L2: down (single projection, straightforward) ── - down_w = self.w2_weight.data[e] # (hidden, intermediate//2) uint8 - down_sf = self.w2_weight_scale.data[e] # (hidden, intermediate//16) float8 - down_gs = self.w2_weight_scale_2.data[e].item() # float32 scalar - - # Checkpoint is (N, K_packed) → permute to (K_packed, N) - # K=intermediate (packed dim), N=hidden - down_w_fp4 = down_w.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous() - # shape: (intermediate//2, hidden) — K=intermediate packed, N=hidden - - # Block scales: checkpoint is (N, K_sf), bridge expects (K_sf, N) - down_sf = down_sf.permute(1, 0).contiguous() - - l2_fp4.append(down_w_fp4) - l2_sf.append(down_sf) - l2_gs.append(down_gs) - - # Create CuTeDSL runner with directly-cast weights - self._cutedsl_runner = CuTeDSLMoERunner( - num_experts=self.num_local_experts, - hidden_size=self.hidden_size, - intermediate_size=self.intermediate_size, - max_num_tokens=self.max_num_tokens, - top_k=self.top_k, - device=l1_fp4[0].device, - experts_start_idx=self.experts_start_idx, + w13_scale = deep_gemm.transform_sf_into_required_layout( + self._ue8m0_uint8_to_float(self.w13_weight_scale.data).contiguous(), + 2 * self.intermediate_size, + self.hidden_size, + (1, 32), + self.num_local_experts, ) - self._cutedsl_runner.l1_fp4 = l1_fp4 - self._cutedsl_runner.l1_sf = l1_sf - self._cutedsl_runner.l1_gs = l1_gs - self._cutedsl_runner.l2_fp4 = l2_fp4 - self._cutedsl_runner.l2_sf = l2_sf - self._cutedsl_runner.l2_gs = l2_gs - - # Set activation global scales from checkpoint input_scale - # The input_scale is the pre-computed activation normalization factor. - # w13_input_scale shape: (num_experts, 2) for gate+up, but may be (num_experts,) after EP split - # w2_input_scale shape: (num_experts, 1) or (num_experts,) - w13_igs = self.w13_input_scale.data - w2_igs = self.w2_input_scale.data - if w13_igs.dim() == 2: - l1_igs = w13_igs[:, 0] # gate input_scale - else: - l1_igs = w13_igs # already 1D per expert - if w2_igs.dim() == 2: - l2_igs = w2_igs[:, 0] - else: - l2_igs = w2_igs - # Use checkpoint input_scale as initial guess, then warmup will override - self._cutedsl_runner._l1_activation_global_scale = l1_igs.mean().item() - self._cutedsl_runner._l2_activation_global_scale = l2_igs.mean().item() - - # Drop the original loader-side parameters - self._w13_input_scale = self.w13_input_scale.data.clone() - self._w2_input_scale = self.w2_input_scale.data.clone() + w2_scale = deep_gemm.transform_sf_into_required_layout( + self._ue8m0_uint8_to_float(self.w2_weight_scale.data).contiguous(), + self.hidden_size, + self.intermediate_size, + (1, 32), + self.num_local_experts, + ) + self._transformed_l1_weights, self._transformed_l2_weights = ( + deep_gemm.transform_weights_for_mega_moe( + (self.w13_weight.data.view(torch.int8).contiguous(), w13_scale), + (self.w2_weight.data.view(torch.int8).contiguous(), w2_scale), + ) + ) + # Drop the original loader-side parameters: the MegaMoE kernels only + # consume the transformed views above. transform_weights_for_mega_moe + # allocates a fresh tensor for the L1 weight (see _interleave_l1_weights) + # and fresh SF tensors for L1/L2; the L2 weight is the only tensor that + # aliases the original storage, and _transformed_l2_weights still holds + # it, so the storage stays live after we drop the Parameter. self.w13_weight = None self.w13_weight_scale = None - self.w13_weight_scale_2 = None - self.w13_input_scale = None self.w2_weight = None self.w2_weight_scale = None - self.w2_weight_scale_2 = None - self.w2_input_scale = None - # Warmup: compute actual activation global scales from sample data. - # The checkpoint input_scale is a calibration value that doesn't match - # runtime activation magnitudes. We run a small forward pass to observe - # the actual amax and compute correct gs values. - self._warmup_activation_global_scales() + def get_symm_buffer(self): + import vllm.third_party.deep_gemm as deep_gemm - # Set swiglu_limit for activation clamping in the runner - if self.swiglu_limit is not None: - self._cutedsl_runner.set_swiglu_limit(float(self.swiglu_limit)) - - def _warmup_activation_global_scales(self) -> None: - """Run a warmup forward pass to compute correct activation global scales. - - Called once per layer during finalize_weights, before cudagraph capture. - Uses quantize_to_nvfp4 (which calls .max()) to get the exact gs - from real activation magnitudes, then stores them for use by - quantize_activation_nvfp4 (no .max(), cudagraph-safe). - """ - import torch - runner = self._cutedsl_runner - device = runner.device - num_tokens = min(8, runner.max_num_tokens) - top_k = runner.top_k - - with torch.no_grad(): - # Sample hidden states: typical BF16 activations have amax ~1-10 - hidden_states = torch.randn(num_tokens, runner.hidden_size, - dtype=torch.bfloat16, device=device) - # Assign all tokens to local experts (0..num_local_experts-1) - # compute_activation_global_scales expects local IDs, not global - topk_ids = torch.zeros(num_tokens, top_k, dtype=torch.int64, device=device) - for i in range(num_tokens): - for j in range(top_k): - topk_ids[i, j] = j % runner.num_experts - topk_weights = torch.ones(num_tokens, top_k, dtype=torch.float32, device=device) / top_k - - runner.compute_activation_global_scales( - hidden_states, topk_weights, topk_ids + group = get_ep_group().device_group + device = torch.accelerator.current_device_index() + key = ( + id(group), + device, + self.num_experts, + self.max_num_tokens, + self.top_k, + self.hidden_size, + self.intermediate_size, + ) + symm_buffer = self._symm_buffer_cache.get(key) + if symm_buffer is None: + symm_buffer = deep_gemm.get_symm_buffer_for_mega_moe( + group, + self.num_experts, + self.max_num_tokens, + self.top_k, + self.hidden_size, + self.intermediate_size, ) - - # Note: No explicit CuTeDSL warmup here. With FULL_AND_PIECEWISE - # CUDA graph mode, the kernel compiles during graph capture (startup). - # In eager mode, the first inference triggers JIT compilation. + self._symm_buffer_cache[key] = symm_buffer + return symm_buffer def forward( self, @@ -626,31 +643,34 @@ class DeepseekV4MegaMoEExperts(nn.Module): activation_clamp: float | None, fast_math: bool, ) -> None: - import os + import vllm.third_party.deep_gemm as deep_gemm + + symm_buffer = self.get_symm_buffer() + num_tokens = hidden_states.shape[0] + _stage_deepseek_v4_mega_moe_inputs( + hidden_states, + topk_weights, + topk_ids, + symm_buffer.x[:num_tokens], + symm_buffer.x_sf[:num_tokens], + symm_buffer.topk_idx[:num_tokens], + symm_buffer.topk_weights[:num_tokens], + ) # This method must have been already called during the weight loading phase. # We call it again here to cover the dummy weight loading case. self.finalize_weights() - assert self._cutedsl_runner is not None - # After _ensure_stacked, per-expert lists are freed and stacked - # tensors live in _l1_mat_b / _l2_mat_b instead. - assert (self._cutedsl_runner.l1_fp4 is not None - or self._cutedsl_runner._l1_mat_b is not None) - - # Build expert indices list for this rank - expert_indices = list(range(self.num_local_experts)) - - try: - result = self._cutedsl_runner.run( - hidden_states, topk_weights, topk_ids, - expert_indices=expert_indices, - ) - y.copy_(result) - except Exception as exc: - import traceback - traceback.print_exc() - raise + assert self._transformed_l1_weights is not None + assert self._transformed_l2_weights is not None + deep_gemm.fp8_fp4_mega_moe( + y, + self._transformed_l1_weights, + self._transformed_l2_weights, + symm_buffer, + activation_clamp=activation_clamp, + fast_math=fast_math, + ) DeepseekV4MegaMoEExperts.weight_loader.supports_moe_loading = True # type: ignore[attr-defined] @@ -708,7 +728,9 @@ class DeepseekV4MoE(nn.Module): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.prefix = prefix - self.use_mega_moe = True # Force mega_moe for NVFP4 pipeline + self.use_mega_moe = ( + vllm_config.kernel_config.moe_backend == "deep_gemm_mega_moe" + ) if self.use_mega_moe and not vllm_config.parallel_config.enable_expert_parallel: raise NotImplementedError( "DeepSeek V4 MegaMoE currently requires expert parallel. " @@ -729,25 +751,30 @@ class DeepseekV4MoE(nn.Module): raise NotImplementedError( "DeepSeek V4 MegaMoE currently supports sqrtsoftplus routing only." ) - # NVFP4 experts work with mega_moe via NVFP4 weight transformation in finalize_weights + if self.use_mega_moe and getattr(config, "expert_dtype", "fp4") != "fp4": + raise NotImplementedError( + "DeepSeek V4 MegaMoE only supports fp4 experts; got expert_dtype=" + f"{config.expert_dtype!r}. Drop --kernel-config moe_backend=" + "deep_gemm_mega_moe for this checkpoint." + ) - self.gate = GateLinear( - config.hidden_size, - config.n_routed_experts, - out_dtype=torch.float32, - bias=False, - prefix=f"{prefix}.gate", + # Fused RMSNorm + gate: owns both ffn_norm and the gate matmul. + self.norm_gate = NormGateLinear( + hidden_size=config.hidden_size, + num_experts=config.n_routed_experts, + rms_eps=config.rms_norm_eps, + prefix=f"{prefix}.norm_gate", ) - self.gate.e_score_correction_bias = None - self.gate.tid2eid = None + # Routing-side tensors live on ``norm_gate`` directly (not on the + # inner gate); they are initialized to None in NormGatedLinear and + # populated below depending on the MoE variant. is_hash_moe = extract_layer_index(prefix) < config.num_hash_layers self.hash_indices_dtype = torch.int64 if self.use_mega_moe else torch.int32 - if is_hash_moe: # hash MoE doesn't use e_score_correction_bias # Use randint instead of empty to avoid garbage values causing # invalid memory access in dummy mode (--load-format="dummy") - self.gate.tid2eid = nn.Parameter( + self.norm_gate.tid2eid = nn.Parameter( torch.randint( 0, config.n_routed_experts, @@ -757,7 +784,7 @@ class DeepseekV4MoE(nn.Module): requires_grad=False, ) elif getattr(config, "topk_method", None) == "noaux_tc": - self.gate.e_score_correction_bias = nn.Parameter( + self.norm_gate.e_score_correction_bias = nn.Parameter( torch.empty(config.n_routed_experts, dtype=torch.float32), requires_grad=False, ) @@ -820,10 +847,9 @@ class DeepseekV4MoE(nn.Module): self.n_local_experts = config.n_routed_experts // self.tp_size self.experts_start_idx = self.tp_rank * self.n_local_experts self.experts_end_idx = self.experts_start_idx + self.n_local_experts - + # We don't pass `gate` into FusedMoE self.experts = FusedMoE( shared_experts=self.shared_experts, - gate=self.gate, num_experts=config.n_routed_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, @@ -833,8 +859,8 @@ class DeepseekV4MoE(nn.Module): prefix=f"{prefix}.experts", scoring_func=self.scoring_func, routed_scaling_factor=self.routed_scaling_factor, - e_score_correction_bias=self.gate.e_score_correction_bias, - hash_indices_table=self.gate.tid2eid, + e_score_correction_bias=self.norm_gate.e_score_correction_bias, + hash_indices_table=self.norm_gate.tid2eid, swiglu_limit=self.swiglu_limit, router_logits_dtype=torch.float32, ) @@ -842,46 +868,40 @@ class DeepseekV4MoE(nn.Module): def forward( self, hidden_states: torch.Tensor, input_ids: torch.Tensor | None = None ) -> torch.Tensor: - if self.gate.tid2eid is not None and input_ids is None: + if self.norm_gate.tid2eid is not None and input_ids is None: raise ValueError("DeepSeek V4 hash MoE routing requires input_ids.") if not self.use_mega_moe: return self._forward_fused_moe(hidden_states, input_ids) org_shape = hidden_states.shape - router_logits, _ = self.gate(hidden_states) + normed_x, router_logits = self.norm_gate(hidden_states) topk_weights, topk_ids = fused_topk_bias( - hidden_states=hidden_states, + hidden_states=normed_x, gating_output=router_logits, scoring_func=self.scoring_func, - e_score_correction_bias=self.gate.e_score_correction_bias.data - if self.gate.e_score_correction_bias is not None + e_score_correction_bias=self.norm_gate.e_score_correction_bias.data + if self.norm_gate.e_score_correction_bias is not None else None, topk=self.n_activated_experts, renormalize=self.renormalize, indices_type=self.hash_indices_dtype, input_tokens=input_ids, - hash_indices_table=self.gate.tid2eid, + hash_indices_table=self.norm_gate.tid2eid, routed_scaling_factor=self.routed_scaling_factor, ) activation_clamp = ( float(self.swiglu_limit) if self.swiglu_limit is not None else None ) final_hidden_states = self.experts( - hidden_states, + normed_x, topk_weights, topk_ids, activation_clamp=activation_clamp, ) - # EP all-reduce: each rank only computes its local experts, - # so we must sum across EP ranks to get the full routed output. - torch.distributed.all_reduce( - final_hidden_states, group=self.ep_group.device_group - ) - if self.shared_experts is not None: - shared_output = self.shared_experts(hidden_states) + shared_output = self.shared_experts(normed_x) final_hidden_states += shared_output return final_hidden_states.view(org_shape) @@ -889,21 +909,14 @@ class DeepseekV4MoE(nn.Module): def _forward_fused_moe( self, hidden_states: torch.Tensor, input_ids: torch.Tensor | None = None ) -> torch.Tensor: + assert not self.experts.is_internal_router org_shape = hidden_states.shape - if self.experts.is_internal_router: - # In this case, the gate/router runs inside the FusedMoE class - final_hidden_states = self.experts( - hidden_states=hidden_states, - router_logits=hidden_states, - input_ids=input_ids, - ) - else: - router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts( - hidden_states=hidden_states, - router_logits=router_logits, - input_ids=input_ids, - ) + normed_x, router_logits = self.norm_gate(hidden_states) + final_hidden_states = self.experts( + hidden_states=normed_x, + router_logits=router_logits, + input_ids=input_ids, + ) return final_hidden_states.view(org_shape) @@ -1001,7 +1014,7 @@ class DeepseekV4Attention(nn.Module): self.rope_parameters = config.rope_scaling # Initialize rotary embedding BEFORE DeepseekV4MLAModules (which needs it) - rope_parameters = dict(config.rope_parameters) + rope_parameters = config.rope_parameters rope_parameters["rope_theta"] = ( config.compress_rope_theta if self.compress_ratio > 1 else config.rope_theta ) @@ -1107,7 +1120,8 @@ class DeepseekV4DecoderLayer(nn.Module): self.ffn = DeepseekV4MoE(vllm_config, prefix=f"{prefix}.ffn") self.attn_norm = RMSNorm(self.hidden_size, self.rms_norm_eps) - self.ffn_norm = RMSNorm(self.hidden_size, self.rms_norm_eps) + # ``ffn_norm`` is owned by ``self.ffn.norm_gate`` (fused with the + # router gate matmul); see ``NormGatedLinear``. self.hc_mult = config.hc_mult self.hc_sinkhorn_iters = config.hc_sinkhorn_iters self.hc_eps = config.hc_eps @@ -1156,6 +1170,9 @@ class DeepseekV4DecoderLayer(nn.Module): ), requires_grad=False, ) + self.mhc_pre = MHCPreOp() + self.mhc_post = MHCPostOp() + self.mhc_fused_post_pre = MHCFusedPostPreOp() def hc_pre( self, @@ -1164,7 +1181,7 @@ class DeepseekV4DecoderLayer(nn.Module): hc_scale: torch.Tensor, hc_base: torch.Tensor, ): - post_mix, res_mix, layer_input = torch.ops.vllm.mhc_pre( + post_mix, res_mix, layer_input = self.mhc_pre( residual=x, fn=hc_fn, hc_scale=hc_scale, @@ -1184,14 +1201,72 @@ class DeepseekV4DecoderLayer(nn.Module): post: torch.Tensor, comb: torch.Tensor, ): - return torch.ops.vllm.mhc_post(x, residual, post, comb) + return self.mhc_post(x, residual, post, comb) - def forward( + def _forward_cuda( self, x: torch.Tensor, positions: torch.Tensor, input_ids: torch.Tensor | None, - ) -> torch.Tensor: + post_mix: torch.Tensor | None = None, + res_mix: torch.Tensor | None = None, + residual: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + if residual is None: + # Run standalone hc_pre on first layer + residual = x + x, post_mix, res_mix = self.hc_pre( + x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base + ) + else: + residual, post_mix, res_mix, x = self.mhc_fused_post_pre( + x, + residual, + post_mix, + res_mix, + self.hc_attn_fn, + self.hc_attn_scale, + self.hc_attn_base, + self.rms_norm_eps, + self.hc_eps, + self.hc_eps, + self.hc_post_alpha, + self.hc_sinkhorn_iters, + ) + + x = self.attn_norm(x) + x = self.attn(positions, x, None) + + residual, post_mix, res_mix, x = self.mhc_fused_post_pre( + x, + residual, + post_mix, + res_mix, + self.hc_ffn_fn, + self.hc_ffn_scale, + self.hc_ffn_base, + self.rms_norm_eps, + self.hc_eps, + self.hc_eps, + self.hc_post_alpha, + self.hc_sinkhorn_iters, + ) + # ffn_norm is now folded into self.ffn.norm_gate; ffn() takes + # the pre-norm activation directly. + x = self.ffn(x, input_ids) + return x, residual, post_mix, res_mix + + def _forward_rocm( + self, + x: torch.Tensor, + positions: torch.Tensor, + input_ids: torch.Tensor | None, + post_mix: torch.Tensor | None = None, + res_mix: torch.Tensor | None = None, + residual: torch.Tensor | None = None, + ) -> tuple[ + torch.Tensor, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None + ]: residual = x x, post, comb = self.hc_pre( x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base @@ -1204,21 +1279,42 @@ class DeepseekV4DecoderLayer(nn.Module): x, post, comb = self.hc_pre( x, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base ) - x = self.ffn_norm(x) + # ffn_norm is now folded into self.ffn.norm_gate; ffn() takes + # the pre-norm activation directly. x = self.ffn(x, input_ids) x = self.hc_post(x, residual, post, comb) - return x + return x, None, None, None + + def forward( + self, + x: torch.Tensor, + positions: torch.Tensor, + input_ids: torch.Tensor | None, + post_mix: torch.Tensor | None = None, + res_mix: torch.Tensor | None = None, + residual: torch.Tensor | None = None, + ) -> tuple[ + torch.Tensor, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None + ]: + if current_platform.is_rocm(): + return self._forward_rocm( + x, positions, input_ids, post_mix, res_mix, residual + ) + + return self._forward_cuda(x, positions, input_ids, post_mix, res_mix, residual) @support_torch_compile class DeepseekV4Model(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + def __init__(self, *, vllm_config: Vllm_config, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config - self.use_mega_moe = True # Force mega_moe for NVFP4 pipeline + self.use_mega_moe = ( + vllm_config.kernel_config.moe_backend == "deep_gemm_mega_moe" + ) if self.use_mega_moe and not vllm_config.parallel_config.enable_expert_parallel: raise NotImplementedError( "DeepSeek V4 MegaMoE currently requires expert parallel. " @@ -1235,7 +1331,12 @@ class DeepseekV4Model(nn.Module): # DeepseekV4MultiHeadLatentAttentionWrapper.attn_gemm_parallel_execute # (compressor kv_score, indexer.weights_proj, indexer.compressor # kv_score). fused_wqa_wkv stays on the default stream. - aux_stream_list = [torch.cuda.Stream() for _ in range(3)] + # Disable them on ROCm because of hang issues. + aux_stream_list = ( + None + if current_platform.is_rocm() + else [torch.cuda.Stream() for _ in range(3)] + ) self.device = current_platform.device_type # Reserved topk indices buffer for all Indexer layers to reuse. @@ -1246,12 +1347,15 @@ class DeepseekV4Model(nn.Module): device=self.device, ) - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=f"{prefix}.embed_tokens", - ) + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens", + ) + else: + self.embed_tokens = PPMissingLayer() self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, @@ -1264,7 +1368,10 @@ class DeepseekV4Model(nn.Module): prefix=f"{prefix}.layers", ) - self.norm = RMSNorm(config.hidden_size, self.rms_norm_eps) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, self.rms_norm_eps) + else: + self.norm = PPMissingLayer() self.hc_head_fn = nn.Parameter( torch.empty( @@ -1285,20 +1392,46 @@ class DeepseekV4Model(nn.Module): torch.empty(1, dtype=torch.float32), requires_grad=False, ) - + self.hc_head_op = HCHeadOp() # Pre-hc_head residual stream buffer for the MTP draft. Stable # address (outside the cudagraph pool) so the copy_ in forward() # refreshes it correctly across captured shapes. - self._mtp_hidden_buffer = torch.empty( - vllm_config.scheduler_config.max_num_batched_tokens, - self.hc_dim, - dtype=vllm_config.model_config.dtype, - device=self.device, - ) + # refreshes it correctly across captured shapes. Only allocated on + # the last PP rank — that's where MTP target hidden states are + # produced. + if get_pp_group().is_last_rank: + self._mtp_hidden_buffer = torch.empty( + vllm_config.scheduler_config.max_num_batched_tokens, + self.hc_dim, + dtype=vllm_config.model_config.dtype, + device=self.device, + ) + else: + self._mtp_hidden_buffer = None def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) + def make_empty_intermediate_tensors( + self, + batch_size: int, + dtype: torch.dtype, + device: torch.device, + ) -> IntermediateTensors: + # PP intermediate tensors carry the multi-stream hidden_states + # of shape (num_tokens, hc_mult, hidden_size) — V4 expands the + # token embedding to hc_mult streams before the first decoder + # layer and keeps that shape until hc_head() collapses it. + return IntermediateTensors( + { + "hidden_states": torch.zeros( + (batch_size, self.hc_mult, self.config.hidden_size), + dtype=dtype, + device=device, + ), + } + ) + def forward( self, input_ids: torch.Tensor, @@ -1306,23 +1439,40 @@ class DeepseekV4Model(nn.Module): intermediate_tensors: IntermediateTensors | None, inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor | IntermediateTensors: - hidden_states = self.embed_input_ids(input_ids) - hidden_states = hidden_states.unsqueeze(-2).repeat(1, self.hc_mult, 1) + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_input_ids(input_ids) + hidden_states = hidden_states.unsqueeze(-2).repeat(1, self.hc_mult, 1) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + if self.use_mega_moe: input_ids = input_ids.to(torch.int64) - for layer_idx, layer in enumerate(islice(self.layers, self.start_layer, self.end_layer)): - hidden_states = layer( + + residual, post_mix, res_mix = None, None, None + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states, residual, post_mix, res_mix = layer( hidden_states, positions, input_ids, + post_mix, + res_mix, + residual, ) + if layer is not None and current_platform.is_cuda(): + hidden_states = layer.hc_post(hidden_states, residual, post_mix, res_mix) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) # Stash pre-hc_head residual for the MTP draft (captured copy_). num_tokens = hidden_states.shape[0] self._mtp_hidden_buffer[:num_tokens].copy_(hidden_states.flatten(1)) - hidden_states = hc_head( + hidden_states = self.hc_head_op( hidden_states, self.hc_head_fn, self.hc_head_scale, @@ -1343,59 +1493,6 @@ class DeepseekV4Model(nn.Module): ("compressor.fused_wkv_wgate", "compressor.wkv", 0), ("compressor.fused_wkv_wgate", "compressor.wgate", 1), ] - - # Checkpoint key → model param name substitutions. - # Applied to each (name, weight) pair before matching against - # params_dict. Order matters: longer/more-specific patterns first. - CKPT_KEY_SUBST = { - # self_attn projection names → vLLM attn attribute names - ".self_attn.q_a_proj.": ".attn.wq_a.", - ".self_attn.q_b_proj.": ".attn.wq_b.", - ".self_attn.q_a_norm.": ".attn.q_norm.", - ".self_attn.o_a_proj.": ".attn.wo_a.", - ".self_attn.o_b_proj.": ".attn.wo_b.", - ".self_attn.sinks": ".attn.attn_sink", - ".self_attn.kv_proj.": ".attn.wkv.", - ".self_attn.kv_norm.": ".attn.kv_norm.", - # Indexer: self_attn.compressor.indexer → attn.indexer - # MUST come before the generic .self_attn.compressor. rule - ".self_attn.compressor.indexer.q_b_proj.": ".attn.indexer.wq_b.", - ".self_attn.compressor.indexer.kv_norm.": ".attn.indexer.k_norm.", - ".self_attn.compressor.indexer.position_bias": ".attn.indexer.compressor.ape", - ".self_attn.compressor.indexer.gate_proj.": ".attn.indexer.compressor.wgate.", - ".self_attn.compressor.indexer.kv_proj.": ".attn.indexer.compressor.wkv.", - ".self_attn.compressor.indexer.": ".attn.indexer.", - # Compressor: self_attn.compressor → attn.mla_attn.compressor - # Compressor projections for stacking (fused_wkv_wgate) - ".self_attn.compressor.kv_proj.": ".attn.mla_attn.compressor.wkv.", - ".self_attn.compressor.gate_proj.": ".attn.mla_attn.compressor.gate.", - ".self_attn.compressor.kv_norm.": ".attn.kv_norm.", - ".self_attn.compressor.position_bias": ".attn.mla_attn.compressor.ape", - ".self_attn.compressor.": ".attn.mla_attn.compressor.", - # Shared expert projections (stacking into gate_up_proj) - # Must include .mlp. prefix since break prevents .mlp.→.ffn. from - # firing on the same key after these patterns match. - ".mlp.shared_experts.gate_proj.": ".ffn.shared_experts.w1.", - ".mlp.shared_experts.up_proj.": ".ffn.shared_experts.w3.", - ".mlp.shared_experts.down_proj.": ".ffn.shared_experts.down_proj.", - # Hadamard coding params: checkpoint has .attn_hc.base/fn/scale - # and .ffn_hc.base/fn/scale; model has hc_attn_base/fn/scale - # and hc_ffn_base/fn/scale (underscore not dot before base/fn/scale) - ".attn_hc.base": ".hc_attn_base", - ".attn_hc.fn": ".hc_attn_fn", - ".attn_hc.scale": ".hc_attn_scale", - ".ffn_hc.base": ".hc_ffn_base", - ".ffn_hc.fn": ".hc_ffn_fn", - ".ffn_hc.scale": ".hc_ffn_scale", - "hc_head.hc_base": "hc_head_base", - "hc_head.hc_fn": "hc_head_fn", - "hc_head.hc_scale": "hc_head_scale", - # compressor.position_bias → compressor.ape - ".compressor.position_bias": ".compressor.ape", - # modelopt uses mlp, vllm uses ffn internally - ".mlp.": ".ffn.", - } - params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() @@ -1411,97 +1508,38 @@ class DeepseekV4Model(nn.Module): expert_mapping = self.get_expert_mapping() for name, loaded_weight in weights: - # Strip 'model.' prefix from checkpoint keys. - # vLLM's weight iteration yields keys like 'model.layers.0...' - # but named_parameters() on DeepseekV4Model returns 'layers.0...' - if name.startswith("model."): - name = name[len("model."):] - - # Apply checkpoint → model name substitutions - for ckpt_pat, model_pat in CKPT_KEY_SUBST.items(): - if ckpt_pat in name: - name = name.replace(ckpt_pat, model_pat) - break # first match wins (order matters) - for param_name, weight_name, shard_id in stacked_params_mapping: - # Skip MoE routed experts (handled separately below). - # Use .ffn.experts. (not .experts.) to avoid skipping - # shared_experts which also contains ".experts.". - if ".ffn.experts." in name: + # Skip non-stacked layers and experts (experts handled below). + if ".experts." in name: continue if weight_name not in name: continue - name_mapped = name.replace(weight_name, param_name) - if name_mapped not in params_dict: - continue - name = name_mapped + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + break param = params_dict[name] weight_loader = param.weight_loader - - # ModelOpt NVFP4 packed weight fix for MergedColumnParallelLinear. - # - # modelopt exports NVFP4 packed weights as uint8 (2 values/byte - # along the column dim). But MergedColumnParallelLinear creates - # the weight param as bfloat16 (ModelWeightParameter), because - # ModelOptNvFp4Config only patches Linear, not - # MergedColumnParallelLinear. - # - # When loading uint8 packed weights into a bf16 param, we need to - # unpack them. Each uint8 byte contains 2 E2M1 FP4 values. - # We unpack using the LUT and return bf16. - # - # The weight_scale is loaded separately and process_weights_after_loading - # will handle the actual NVFP4 quantization. - if (loaded_weight.dtype == torch.uint8 - and param.data.dtype != torch.uint8 - and loaded_weight.shape[-1] * 2 == param.data.shape[-1]): - # Unpack NVFP4 (E2M1) → BF16 - # E2M1 LUT: 0→0, 1→0.5, 2→1, 3→1.5, 4→2, 5→3, 6→4, 7→6 - # Sign bit in bit 3 (indices 8-15 are negatives) - FP4_LUT = torch.tensor([ - 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, - -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, - ], dtype=torch.float32, device=loaded_weight.device) - lower = FP4_LUT[(loaded_weight & 0x0F).long()] # (..., in_packed, ) - upper = FP4_LUT[((loaded_weight >> 4) & 0x0F).long()] - # Interleave: [lower_0, upper_0, lower_1, upper_1, ...] - out = torch.empty( - *loaded_weight.shape[:-1], loaded_weight.shape[-1] * 2, - dtype=torch.float32, device=loaded_weight.device, - ) - out[..., 0::2] = lower - out[..., 1::2] = upper - loaded_weight = out.to(torch.bfloat16) - - try: - weight_loader(param, loaded_weight, shard_id) - except (AssertionError, ValueError, RuntimeError) as e: - raise RuntimeError( - f'Weight load failed: name={name} shard_id={shard_id} ' - f'param.shape={param.shape} param.dtype={param.data.dtype} ' - f'loaded.shape={loaded_weight.shape} loaded.dtype={loaded_weight.dtype}' - ) from e + weight_loader(param, loaded_weight, shard_id) loaded_params.add(name) break else: - if ".ffn.experts." in name: - # NVFP4 checkpoint stores float8_e4m3fn scales, not E8M0. - # E8M0 would indicate an MXFP4 checkpoint — wrong format. + if ".experts." in name: + # E8M0 scales are stored as float8_e8m0fnu in + # checkpoints but the MoE param is uint8. copy_() + # would do a numeric conversion (e.g. 2^-7 → 0), + # destroying the raw exponent bytes. if ( "weight_scale" in name and loaded_weight.dtype == torch.float8_e8m0fnu ): - raise ValueError( - f"E8M0 weight_scale in NVFP4 checkpoint ({name}) — " - f"checkpoint format mismatch" - ) + loaded_weight = loaded_weight.view(torch.uint8) for mapping in expert_mapping: param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue name_mapped = name.replace(weight_name, param_name) - if name_mapped not in params_dict: + if is_pp_missing_parameter(name_mapped, self): continue param = params_dict[name_mapped] # We should ask the weight loader to return success or not @@ -1516,16 +1554,15 @@ class DeepseekV4Model(nn.Module): name_mapped, shard_id=shard_id, expert_id=expert_id, + return_success=True, ) if success: name = name_mapped - loaded_params.add(name_mapped) break - else: - continue + loaded_params.add(name_mapped) continue elif "attn_sink" in name: - if name not in params_dict: + if is_pp_missing_parameter(name, self): continue narrow_weight = loaded_weight[head_rank_start:head_rank_end] n = narrow_weight.shape[0] @@ -1533,126 +1570,9 @@ class DeepseekV4Model(nn.Module): loaded_params.add(name) continue else: - if name not in params_dict: - # ModelOpt NVFP4 export includes params not in the - # vllm model (e.g., compressor.position_bias). - # Skip them silently. + if is_pp_missing_parameter(name, self): continue param = params_dict[name] - - # Handle bf16 → uint8 mismatch for o_a_proj: - # modelopt didn't quantize o_a_proj (bf16, no scales), - # but ModelOptNvFp4Config creates wo_a with NVFP4 quant - # (uint8 weight + scales). We quantize the bf16 weight - # to NVFP4 at load time so the layer runs in NVFP4 path. - if (name.endswith(".weight") - and loaded_weight.dtype != torch.uint8 - and param.data.dtype == torch.uint8): - # Quantize bf16 → NVFP4 (E2M1 packed uint8 + scales) - w_bf16 = loaded_weight - out_dim, in_dim = w_bf16.shape - block_size = 16 - assert in_dim % block_size == 0 - n_blocks = in_dim // block_size - - # Reshape into blocks - w_blocks = w_bf16.reshape(out_dim, n_blocks, block_size) - - # Compute per-block amax - amax = w_blocks.abs().amax(dim=-1) # [out, n_blocks] - - # Global scale (weight_scale_2): max amax / (6.0 * 448.0) - global_amax = amax.max() - # Use 448.0 as the max e4m3 value for scale computation - weight_scale_2_val = global_amax / (6.0 * 448.0) - weight_scale_2 = weight_scale_2_val.to(torch.float32) - - # Per-block scale (weight_scale): float8_e4m3fn - # block_scale = amax / (6.0 * weight_scale_2) - block_scale = amax / (6.0 * weight_scale_2_val) - weight_scale = block_scale.clamp(0.0, 448.0).to(torch.float8_e4m3fn) - - # Quantize to FP4 (E2M1) - # E2M1 LUT: 0, 0.5, 1, 1.5, 2, 3, 4, 6 (positive) - FP4_POS = torch.tensor( - [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], - dtype=torch.float32, device=w_bf16.device, - ) - # Scale the weight values: normalized = w / (block_scale * weight_scale_2) - block_scale_f32 = block_scale.clamp(0.0, 448.0) - scaled = w_blocks / (block_scale_f32.unsqueeze(-1) * weight_scale_2_val) - # Find nearest FP4 index (0-7 for magnitude) - # Use absolute value for matching, then apply sign - scaled_abs = scaled.abs() - # Find closest FP4 value - diff = (scaled_abs.unsqueeze(-1) - FP4_POS).abs() - fp4_idx = diff.argmin(dim=-1) # [out, n_blocks, block_size] - # Apply sign: negative values get bit 3 set - sign = (scaled < 0).int() - fp4_val = (sign << 3) | fp4_idx.int() - # Pack: 2 FP4 values per uint8 byte - # Even positions → lower nibble, Odd → upper nibble - fp4_flat = fp4_val.reshape(out_dim, -1) # [out, in_dim] - assert fp4_flat.shape[1] % 2 == 0 - even = fp4_flat[:, 0::2] # lower nibble - odd = fp4_flat[:, 1::2] # upper nibble - packed = (odd << 4) | even - weight_packed = packed.to(torch.uint8).view(torch.int8) - - # Reshape weight_scale to [out, n_blocks] - weight_scale_2d = weight_scale.reshape(out_dim, n_blocks) - - # Load the quantized weight into the uint8 param - weight_loader = param.weight_loader - weight_loader(param, weight_packed) - loaded_params.add(name) - - # Load scales into sibling params - base = name.rsplit(".", 1)[0] - # weight_scale - ws_name = f"{base}.weight_scale" - if ws_name in params_dict: - ws_param = params_dict[ws_name] - ws_loader = getattr(ws_param, "weight_loader", default_weight_loader) - ws_loader(ws_param, weight_scale_2d) - loaded_params.add(ws_name) - # weight_scale_2 - ws2_name = f"{base}.weight_scale_2" - if ws2_name in params_dict: - ws2_param = params_dict[ws2_name] - ws2_loader = getattr(ws2_param, "weight_loader", default_weight_loader) - ws2_loader(ws2_param, weight_scale_2.reshape(1)) - loaded_params.add(ws2_name) - # input_scale: use 1.0 default (dynamic quant) - is_name = f"{base}.input_scale" - if is_name in params_dict: - is_param = params_dict[is_name] - is_loader = getattr(is_param, "weight_loader", default_weight_loader) - is_loader(is_param, torch.tensor(1.0, dtype=torch.float32)) - loaded_params.add(is_name) - continue - - # Handle uint8 NVFP4 packed → bf16 unpack for non-stacked - # params (e.g. indexer.weights_proj). Checkpoint stores - # NVFP4 as uint8 (2 values/byte), but model param is bf16. - if (loaded_weight.dtype == torch.uint8 - and param.data.dtype != torch.uint8 - and loaded_weight.shape[-1] * 2 == param.data.shape[-1]): - FP4_LUT = torch.tensor([ - 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, - -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, - ], dtype=torch.float32, device=loaded_weight.device) - lower = FP4_LUT[(loaded_weight & 0x0F).long()] - upper = FP4_LUT[((loaded_weight >> 4) & 0x0F).long()] - out = torch.empty( - *loaded_weight.shape[:-1], - loaded_weight.shape[-1] * 2, - dtype=torch.float32, device=loaded_weight.device, - ) - out[..., 0::2] = lower - out[..., 1::2] = upper - loaded_weight = out.to(torch.bfloat16) - weight_loader = getattr( param, "weight_loader", default_weight_loader ) @@ -1677,532 +1597,9 @@ class DeepseekV4Model(nn.Module): ) def finalize_mega_moe_weights(self) -> None: - from tqdm import tqdm - layers = list(islice(self.layers, self.start_layer, self.end_layer)) - for layer in tqdm(layers, desc=" (JIT compile)NVFP4 MoE layers", unit="layer"): + for layer in islice(self.layers, self.start_layer, self.end_layer): layer.ffn.finalize_mega_moe_weights() - def _convert_nvfp4_post_load(self): - """Post-load setup of CuTeDSL NVFP4 runners for attention and shared experts. - - Replaces the broken FlashInferCutlassNvFp4LinearKernel with CuTeDSL GEMM. - For attention projections (fused_wqa_wkv, wq_b, wo_b), installs - CuTeDSLNvfp4Method which creates CuTeDSL runners during - process_weights_after_loading. - - For shared experts, creates CuTeDSLSharedExpertRunner which handles - the full L1 (gate_up) + SiLU + L2 (down) pipeline. - - wo_a is converted to FP8 for fp8_einsum (unchanged). - Compressor weights are reconstructed from checkpoint sub-weights (unchanged). - """ - from vllm.model_executor.models.cutedsl_quant_method import CuTeDSLNvfp4Method - - fp8_converted = 0 - compressor_converted = 0 - cutedsl_installed = 0 - shared_expert_installed = 0 - - _shard_index = self._build_shard_index("/model") if os.path.isdir("/model") else None - - from tqdm import tqdm - for layer_idx, layer in tqdm(enumerate(self.layers), total=len(self.layers), desc=" NVFP4→CuTeDSL setup", unit="layer"): - attn = layer.attn - - # FP8 conversion: wo_a (used by fp8_einsum, no input_scale) - FP8_MAX = torch.finfo(torch.float8_e4m3fn).max - if hasattr(attn, "wo_a") and hasattr(attn.wo_a, "weight"): - if attn.wo_a.weight.dtype in (torch.uint8, torch.int8): - E2M1_LUT = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16) - self._convert_nvfp4_to_fp8(attn.wo_a, E2M1_LUT, FP8_MAX) - fp8_converted += 1 - - # Install CuTeDSL quant method on attention NVFP4 projections. - # When vLLM calls process_weights_after_loading, CuTeDSLNvfp4Method - # will read the NVFP4 weights, create CuTeDSL runners, and swap - # the quant method to CuTeDSLNvfp4LinearMethod. - for proj_name in ["fused_wqa_wkv", "wq_b", "wo_b"]: - if not hasattr(attn, proj_name): - continue - mod = getattr(attn, proj_name) - if not hasattr(mod, "weight") or mod.weight.dtype not in (torch.uint8, torch.int8): - continue - is_fused = (proj_name == "fused_wqa_wkv") - mod.quant_method = CuTeDSLNvfp4Method(is_fused=is_fused) - cutedsl_installed += 1 - - # Compressor: BF16 reconstruction (unchanged) - mla_attn = getattr(attn, "mla_attn", None) - if mla_attn is not None: - E2M1_LUT = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16) - compressor = getattr(mla_attn, "compressor", None) - if compressor is not None and hasattr(compressor, "fused_wkv_wgate"): - compressor_converted += self._reconstruct_compressor_weight( - compressor.fused_wkv_wgate, attn, layer_idx, E2M1_LUT, _shard_index=_shard_index) - indexer = getattr(mla_attn, "indexer", None) - if indexer is not None: - idx_compressor = getattr(indexer, "compressor", None) - if idx_compressor is not None and hasattr(idx_compressor, "fused_wkv_wgate"): - compressor_converted += self._reconstruct_compressor_weight( - idx_compressor.fused_wkv_wgate, indexer, layer_idx, E2M1_LUT, sub_path=".indexer", _shard_index=_shard_index) - - # Shared expert: install CuTeDSL shared expert runner - ffn = layer.ffn - if hasattr(ffn, 'shared_experts') and ffn.shared_experts is not None: - swiglu_limit = ffn.swiglu_limit if hasattr(ffn, 'swiglu_limit') else None - se = ffn.shared_experts - if self._install_shared_expert_runner(se, swiglu_limit, layer_idx): - shared_expert_installed += 1 - - - - def _install_shared_expert_runner(self, se_mlp, swiglu_limit: float | None, layer_idx: int) -> bool: - """Install CuTeDSL shared expert runner on a DeepseekV4MLP. - - Extracts gate_up and down NVFP4 weights, creates - CuTeDSLSharedExpertRunner, and replaces the MLP's forward - with the fused L1+SiLU+L2 pipeline. - """ - from cutedsl.shared_expert_pipeline import CuTeDSLSharedExpertRunner - - gate_up = se_mlp.gate_up_proj - down = se_mlp.down_proj - - # Check that both projections have NVFP4 weights - if not (hasattr(gate_up, "weight") and hasattr(down, "weight")): - return False - if gate_up.weight.dtype not in (torch.uint8, torch.int8): - return False - if down.weight.dtype not in (torch.uint8, torch.int8): - return False - - device = gate_up.weight.device - hidden_size = gate_up.weight.shape[1] * 2 # 2 FP4 per uint8 - intermediate_size_2x = gate_up.weight.shape[0] # gate + up stacked - intermediate_size = intermediate_size_2x // 2 - - # ── L1: gate_up (MergedColumnParallelLinear, gate + up fused) ── - l1_w_uint8 = gate_up.weight.data # (2*intermediate, hidden//2) uint8 - l1_sf = gate_up.weight_scale.data # (2*intermediate, hidden//16) float8 - l1_gs_data = gate_up.weight_scale_2.data # float32 [2] (gate, up) - - # uint8 → float4_e2m1fn_x2, permute to (K_packed, N) - l1_w_fp4 = l1_w_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous() - - # Block scales: (N, K_sf) → (K_sf, N) - if l1_sf.dtype != torch.float8_e4m3fn: - l1_sf = l1_sf.to(torch.float8_e4m3fn) - l1_sf = l1_sf.permute(1, 0).contiguous() - - # Dual global scales: normalize to max, fold ratio into block scales - l1_gs1 = l1_gs_data[0].item() - l1_gs2 = l1_gs_data[1].item() - l1_gs = max(l1_gs1, l1_gs2) - if l1_gs1 != l1_gs2: - l1_sf_f32 = l1_sf.float() - # After permute to (K_sf, N): first intermediate rows are gate, then up - l1_sf_f32[:, :intermediate_size] *= (l1_gs1 / l1_gs) - l1_sf_f32[:, intermediate_size:] *= (l1_gs2 / l1_gs) - l1_sf = l1_sf_f32.to(torch.float8_e4m3fn) - - # ── L2: down (RowParallelLinear, single projection) ── - l2_w_uint8 = down.weight.data # (hidden, intermediate//2) uint8 - l2_sf = down.weight_scale.data # (hidden, intermediate//16) float8 - l2_gs = down.weight_scale_2.data.max().item() # float32 scalar - - # uint8 → float4_e2m1fn_x2, permute to (K_packed, N) - l2_w_fp4 = l2_w_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous() - - # Block scales: (N, K_sf) → (K_sf, N) - if l2_sf.dtype != torch.float8_e4m3fn: - l2_sf = l2_sf.to(torch.float8_e4m3fn) - l2_sf = l2_sf.permute(1, 0).contiguous() - - # Create runner, set weights, finalize - runner = CuTeDSLSharedExpertRunner( - hidden_size=hidden_size, - intermediate_size=intermediate_size, - device=device, - swiglu_limit=swiglu_limit if swiglu_limit is not None else 10.0, - ) - runner.l1_fp4 = [l1_w_fp4] - runner.l1_sf = [l1_sf] - runner.l1_gs = [l1_gs] - runner.l2_fp4 = [l2_w_fp4] - runner.l2_sf = [l2_sf] - runner.l2_gs = [l2_gs] - runner.finalize_weights() - - # Warmup: compute activation global scales - with torch.no_grad(): - sample = torch.randn(min(8, 256), hidden_size, - dtype=torch.bfloat16, device=device) * 2.0 - runner.compute_activation_global_scales(sample) - - # Replace the MLP's forward with the runner - se_mlp._cutedsl_runner = runner - - # Monkey-patch forward to use the CuTeDSL runner - original_cls = type(se_mlp) - - def _cutedsl_forward(self, x): - output = self._cutedsl_runner.run(x) - # Down_proj with reduce_results may need all-reduce handled - # by RowParallelLinear. Since we bypassed it, check if we need - # to all-reduce manually. - if hasattr(self, '_needs_tp_reduce') and self._needs_tp_reduce: - from vllm.distributed import tensor_model_parallel_all_reduce - output = tensor_model_parallel_all_reduce(output) - return output - - import types - se_mlp.forward = types.MethodType(_cutedsl_forward, se_mlp) - - # Check if down_proj needs TP all-reduce - # reduce_results=True means the original RowParallelLinear would all-reduce - if hasattr(down, 'reduce_results') and down.reduce_results and getattr(down, 'tp_size', 1) > 1: - se_mlp._needs_tp_reduce = True - else: - se_mlp._needs_tp_reduce = False - - # Free NVFP4 params from gate_up and down (replace with dummy BF16) - for mod in [gate_up, down]: - out_dim = mod.weight.shape[0] - in_dim = mod.weight.shape[1] * 2 - mod.weight = torch.nn.Parameter( - torch.zeros(out_dim, in_dim, dtype=torch.bfloat16, - device=device), - requires_grad=False, - ) - from vllm.model_executor.layers.linear import UnquantizedLinearMethod - mod.quant_method = UnquantizedLinearMethod() - for attr in ("weight_scale", "weight_scale_2", "input_scale", - "input_global_scale", "input_global_scale_inv", - "weight_global_scale", "alpha", "weight_scale_inv"): - if hasattr(mod, attr): - try: - delattr(mod, attr) - except Exception: - pass - - return True - - def _convert_nvfp4_to_fp8(self, mod, e2m1_lut, fp8_max): - """Convert NVFP4 weight to FP8 for fp8_einsum path (wo_a only). - - Uses DeepGEMM's deepgemm_post_process_fp8_weight_block to ensure - correct weight and scale format for fp8_einsum with BMM. - """ - w_uint8 = mod.weight.data - device = w_uint8.device - w_bf16 = self._unpack_nvfp4_to_bf16(w_uint8, e2m1_lut, device) - - # Dequantize with scales - if hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2"): - - block_scale = self._block_scale_to_float32(mod.weight_scale.data) - if block_scale.dim() == 2 and w_bf16.dim() == 2: - block_size = w_bf16.shape[1] // block_scale.shape[1] - block_scale_expanded = block_scale.unsqueeze(-1).expand( - -1, -1, block_size - ).reshape(w_bf16.shape) - else: - block_scale_expanded = block_scale - global_scale = mod.weight_scale_2.data.max().item() - input_scale = ( - mod.input_scale.data.max().item() - if hasattr(mod, "input_scale") - else 1.0 - ) - # NOTE: input_scale is for ACTIVATIONS, not weights. - # Weight dequant = e2m1 * block_scale * global_scale (NO input_scale) - w_dequant = w_bf16.float() * block_scale_expanded * global_scale - w_dequant = w_dequant.to(torch.bfloat16) - else: - w_dequant = w_bf16 - - # Re-quantize bf16 -> FP8 e4m3 with block quantization - # DeepGEMM expects block-scale format: weight_scale (FP8 e4m3 block scale) - # and weight_scale_inv (per-tensor scale). - # We do per-tensor quantization, so block_scale is all-ones. - w_amax = w_dequant.abs().amax() - if w_amax == 0: - w_amax = torch.tensor(1.0, device=device) - fp8_scale = w_amax / fp8_max - w_fp8 = (w_dequant / fp8_scale).to(torch.float8_e4m3fn) - - # Create block scale filled with the per-tensor fp8_scale value. - # DeepGEMM divides by the block scale, so each block gets fp8_scale. - BLOCK_SIZE = 128 - is_bmm = getattr(mod, "is_bmm", False) - bmm_batch_size = getattr(mod, "bmm_batch_size", 0) - - # Weight is 2D (output_size, input_size) before BMM reshape - # Block scale shape: (output_size / BLOCK_SIZE, input_size / BLOCK_SIZE) - rows = w_fp8.size(0) - cols = w_fp8.size(1) - block_rows = rows // BLOCK_SIZE - block_cols = cols // BLOCK_SIZE - - # Fill block scale with the per-tensor fp8_scale (NOT all-ones!) - # This is correct because we requantized with a single per-tensor scale, - # so every 128x128 block has the same scale = fp8_scale. - ws = torch.full((block_rows, block_cols), fp8_scale.item(), dtype=torch.float32, device=device) - - # Use DeepGEMM's post-processing for proper layout transformation - from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - deepgemm_post_process_fp8_weight_block, - ) - w_fp8, ws = deepgemm_post_process_fp8_weight_block( - wq=w_fp8, - ws=ws, - quant_block_shape=(BLOCK_SIZE, BLOCK_SIZE), - use_e8m0=True, # scale_fmt=ue8m0 - is_bmm=is_bmm, - bmm_batch_size=bmm_batch_size, - ) - - # Free source tensors eagerly - del w_uint8, w_bf16, w_dequant - mod.weight = torch.nn.Parameter(w_fp8, requires_grad=False) - del w_fp8 - # weight_scale_inv is what the attention runtime reads as b_scale - # for deepseek_v4_fp8_einsum -> DeepGEMM fp8_einsum. - # It must be the DeepGEMM-formatted block scale (dg_ws), NOT the - # per-tensor scalar. See: deepseek_v4_attention.py line 319. - mod.weight_scale_inv = torch.nn.Parameter(ws, requires_grad=False) - del ws - from vllm.model_executor.layers.linear import UnquantizedLinearMethod - mod.quant_method = UnquantizedLinearMethod() - for attr in ("weight_scale", "weight_scale_2", "input_scale"): - if hasattr(mod, attr): - delattr(mod, attr) - - @staticmethod - def _build_shard_index(ckpt_dir: str) -> dict[str, str]: - """Build key→shard_path index from safetensors metadata (no tensor I/O).""" - import glob - from safetensors import safe_open - index = {} - for shard_file in sorted(glob.glob(os.path.join(ckpt_dir, "model-*.safetensors"))): - try: - with safe_open(shard_file, framework="pt") as f: - for key in f.keys(): - index[key] = shard_file - except Exception: - continue - return index - - def _reconstruct_compressor_weight(self, fused_mod, parent_mod, layer_idx, e2m1_lut, sub_path="", _shard_index=None): - """Reconstruct compressor fused_wkv_wgate from checkpoint. - - Compressor weights are SKIPPED during loading because NVFP4 uint8 data - can't be loaded into bf16 MergedColumnParallelLinear params (shape mismatch). - We read the original uint8 data from the safetensors checkpoint, unpack - E2M1, dequantize, and stack into the fused weight param. - """ - from safetensors import safe_open - - # Find the checkpoint directory - # The model weights are mounted at /model in Docker - ckpt_dir = "/model" - if not os.path.isdir(ckpt_dir): - print(f"WARNING: layer {layer_idx} compressor: checkpoint dir {ckpt_dir} not found") - return 0 - - # Determine the layer's compressor key prefix in the checkpoint - # Before mapper: model.layers.N.self_attn.compressor.{kv_proj,gate_proj} - # After mapper: model.layers.N.attn.mla_attn.compressor.{wkv,wgate} - # We read from checkpoint (before mapper), so use original names - layer_prefix = f"model.layers.{layer_idx}.self_attn.compressor{sub_path}" - - # All keys we need from the checkpoint - keys = { - 'wkv_uint8': f"{layer_prefix}.kv_proj.weight", - 'wgate_uint8': f"{layer_prefix}.gate_proj.weight", - 'wkv_block_scale': f"{layer_prefix}.kv_proj.weight_scale", - 'wgate_block_scale': f"{layer_prefix}.gate_proj.weight_scale", - 'wkv_global_scale': f"{layer_prefix}.kv_proj.weight_scale_2", - 'wgate_global_scale': f"{layer_prefix}.gate_proj.weight_scale_2", - 'wkv_input_scale': f"{layer_prefix}.kv_proj.input_scale", - 'wgate_input_scale': f"{layer_prefix}.gate_proj.input_scale", - } - - # Read tensors using shard index for targeted access (no full-shard loads) - tensors = {} - for name, key in keys.items(): - shard_path = (_shard_index or {}).get(key) - if shard_path is None: - continue - try: - with safe_open(shard_path, framework="pt") as f: - if key in f.keys(): - tensors[name] = f.get_tensor(key) - except Exception: - continue - - wkv_uint8 = tensors.get('wkv_uint8') - wgate_uint8 = tensors.get('wgate_uint8') - - if wkv_uint8 is None or wgate_uint8 is None: - # Layer might not have a compressor (compress_ratio=1 layers) - return 0 - - wkv_block_scale = tensors.get('wkv_block_scale') - wgate_block_scale = tensors.get('wgate_block_scale') - wkv_global_scale = tensors.get('wkv_global_scale') - wgate_global_scale = tensors.get('wgate_global_scale') - wkv_input_scale = tensors.get('wkv_input_scale') - wgate_input_scale = tensors.get('wgate_input_scale') - - device = fused_mod.weight.device - wkv_uint8 = wkv_uint8.to(device) - wgate_uint8 = wgate_uint8.to(device) - - # Unpack E2M1 FP4→bf16 - wkv_bf16 = self._unpack_nvfp4_to_bf16(wkv_uint8, e2m1_lut, device) - wgate_bf16 = self._unpack_nvfp4_to_bf16(wgate_uint8, e2m1_lut, device) - - # Dequantize with scales - def _dequant(w_bf16, block_scale, global_scale, input_scale): - if block_scale is not None and global_scale is not None: - - block_scale = self._block_scale_to_float32(block_scale.to(device)) - if block_scale.dim() == 2 and w_bf16.dim() == 2: - block_size = w_bf16.shape[1] // block_scale.shape[1] - block_scale_exp = block_scale.unsqueeze(-1).expand( - -1, -1, block_size - ).reshape(w_bf16.shape) - else: - block_scale_exp = block_scale - gs = global_scale.to(device).max().item() - # NOTE: input_scale is for activations, not weights. - # Weight dequant = e2m1 * block_scale * global_scale (NO input_scale) - w = w_bf16.float() * block_scale_exp * gs - return w.to(torch.bfloat16) - return w_bf16 - - wkv_dequant = _dequant(wkv_bf16, wkv_block_scale, wkv_global_scale, wkv_input_scale) - wgate_dequant = _dequant(wgate_bf16, wgate_block_scale, wgate_global_scale, wgate_input_scale) - - # Stack: concatenate along output dim (dim 0) - # fused_wkv_wgate.weight = cat([wkv, wgate], dim=0) → (2*head_dim, hidden_size) - w_fused = torch.cat([wkv_dequant, wgate_dequant], dim=0) - - - # Replace the weight - fused_mod.weight = torch.nn.Parameter(w_fused, requires_grad=False) - from vllm.model_executor.layers.linear import UnquantizedLinearMethod - fused_mod.quant_method = UnquantizedLinearMethod() - for attr in ("weight_scale", "weight_scale_2", "input_scale", "weight_scale_inv"): - if hasattr(fused_mod, attr): - delattr(fused_mod, attr) - return 1 - - def _convert_bf16_to_fp8(self, mod, fp8_max): - """Convert BF16 weight to FP8 for fp8_einsum path. - - Used for wo_a which modelopt did NOT quantize (bf16 in checkpoint) - but which the attention forward reads as FP8 for deepseek_v4_fp8_einsum. - Uses DeepGEMM's post-processing for proper BMM + scale format. - """ - w_bf16 = mod.weight.data - device = w_bf16.device - - # Re-quantize bf16 -> FP8 e4m3 with block quantization - w_amax = w_bf16.abs().amax() - if w_amax == 0: - w_amax = torch.tensor(1.0, device=device) - fp8_scale = w_amax / fp8_max - w_fp8 = (w_bf16 / fp8_scale).to(torch.float8_e4m3fn) - - BLOCK_SIZE = 128 - is_bmm = getattr(mod, "is_bmm", False) - bmm_batch_size = getattr(mod, "bmm_batch_size", 0) - - rows = w_fp8.size(0) - cols = w_fp8.size(1) - block_rows = rows // BLOCK_SIZE - block_cols = cols // BLOCK_SIZE - # Fill block scale with per-tensor fp8_scale (NOT all-ones!) - ws = torch.full((block_rows, block_cols), fp8_scale.item(), dtype=torch.float32, device=device) - - from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - deepgemm_post_process_fp8_weight_block, - ) - w_fp8, ws = deepgemm_post_process_fp8_weight_block( - wq=w_fp8, - ws=ws, - quant_block_shape=(BLOCK_SIZE, BLOCK_SIZE), - use_e8m0=True, # scale_fmt=ue8m0 - is_bmm=is_bmm, - bmm_batch_size=bmm_batch_size, - ) - - mod.weight = torch.nn.Parameter(w_fp8, requires_grad=False) - # weight_scale_inv is what the attention runtime reads as b_scale - # for deepseek_v4_fp8_einsum -> DeepGEMM fp8_einsum. - # It must be the DeepGEMM-formatted block scale (dg_ws), NOT the - # per-tensor scalar. See: deepseek_v4_attention.py line 319. - mod.weight_scale_inv = torch.nn.Parameter(ws, requires_grad=False) - # weight_scale is not used at runtime for BMM layers; remove it - # to avoid confusing other code paths. - for attr in ("weight_scale", "weight_scale_2", "input_scale"): - if hasattr(mod, attr): - delattr(mod, attr) - from vllm.model_executor.layers.linear import UnquantizedLinearMethod - mod.quant_method = UnquantizedLinearMethod() - - @staticmethod - @staticmethod - def _block_scale_to_float32(sf: torch.Tensor) -> torch.Tensor: - """Convert NVFP4 block scales (float8_e4m3fn) to float32.""" - return sf.to(torch.float32) - - def _unpack_nvfp4_to_bf16(self, w_uint8, e2m1_lut, device): - """Unpack NVFP4 uint8 packed weights to bf16 using E2M1 format.""" - # Extract 4-bit FP4 values (0-15, bit 3 = sign) - even_raw = (w_uint8 & 0x0F).int() - odd_raw = ((w_uint8 >> 4) & 0x0F).int() - # Sign: 0-7 = positive, 8-15 = negative - even_sign = torch.where(even_raw >= 8, -1.0, 1.0).to(torch.bfloat16) - odd_sign = torch.where(odd_raw >= 8, -1.0, 1.0).to(torch.bfloat16) - # Magnitude index: lower 3 bits (0-7) - even_vals = even_sign * e2m1_lut.to(device)[even_raw & 0x07] - odd_vals = odd_sign * e2m1_lut.to(device)[odd_raw & 0x07] - # Interleave and flatten - w_bf16 = torch.stack([even_vals, odd_vals], dim=-1) - w_bf16 = w_bf16.reshape(w_uint8.shape[0], -1).to(torch.bfloat16) - return w_bf16 -@torch.compile(backend=current_platform.simple_compile_backend) -def hc_head( - hidden_states: torch.Tensor, - hc_fn: torch.Tensor, - hc_scale: torch.Tensor, - hc_base: torch.Tensor, - rms_norm_eps: float, - hc_eps: float, -) -> torch.Tensor: - hc_mult, hidden_size = hidden_states.shape[-2:] - outer_shape = hidden_states.shape[:-2] - hs_flat = hidden_states.view(-1, hc_mult, hidden_size) - num_tokens = hs_flat.shape[0] - out = torch.empty( - num_tokens, hidden_size, dtype=torch.bfloat16, device=hidden_states.device - ) - torch.ops.vllm.hc_head_fused_kernel( - hs_flat, - hc_fn, - hc_scale, - hc_base, - out, - hidden_size, - rms_norm_eps, - hc_eps, - hc_mult, - ) - return out.view(*outer_shape, hidden_size) - def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper: if expert_dtype == "fp4": @@ -2221,56 +1618,78 @@ def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper: scale_regex = { re.compile(r"\.scale$"): ".weight_scale_inv", } + return WeightsMapper( + orig_to_new_prefix={ + "layers.": "model.layers.", + "embed.": "model.embed.", + "norm.": "model.norm.", + "hc_head": "model.hc_head", + "mtp.": "model.mtp.", + }, + orig_to_new_regex=scale_regex, + orig_to_new_suffix={ + "head.weight": "lm_head.weight", + "embed.weight": "embed_tokens.weight", + # Pre-MoE norm + gate are now owned by ``DeepseekV4MoE.norm_gate`` + # (see NormGatedLinear). + ".ffn_norm.weight": ".ffn.norm_gate.norm.weight", + ".ffn.gate.weight": ".ffn.norm_gate.gate.weight", + ".ffn.gate.bias": ".ffn.norm_gate.e_score_correction_bias", + # Hash MoE table also moved off the inner gate. + ".ffn.gate.tid2eid": ".ffn.norm_gate.tid2eid", + }, + orig_to_new_substr={ + ".attn.compressor.": ".attn.mla_attn.compressor.", + ".shared_experts.w2": ".shared_experts.down_proj", + }, + ) - # ── ModelOpt NVFP4 export patches ──────────────────────────────── - # modelopt exports with different naming than the original HF ckpt: - # - Expert projections: gate_proj/up_proj/down_proj → w1/w3/w2 - # - Shared expert projections: gate_proj/up_proj → w1/w3 (stacking) - # - Compressor: kv_proj → wkv, gate_proj → wgate (stacking) - # - Attention: self_attn prefix, kv_proj → wkv (stacking) - # - modelopt uses mlp, vllm uses ffn - # Order matters for regex: skip patterns MUST come before renames. - # Skip NVFP4 scales for compressor+attention fused params. - # After substr renaming, these map to stacked params (fused_wkv_wgate, - # fused_wqa_wkv, gate_up_proj) which don't register NVFP4 scale params - # because ModelOptNvFp4Config only handles Linear, not - # MergedColumnParallelLinear. We unpack weights as bf16 and let - # process_weights_after_loading re-quantize them. - # Must match ORIGINAL checkpoint key names (before substr renaming). - fused_skip_regex = { - # Compressor: SKIP ALL tensors. The compressor uses quant_config=None, - # so MergedColumnParallelLinear creates bf16 weight params. NVFP4 uint8 - # checkpoint data can't be loaded into these params (shape mismatch: - # uint8 (head_dim, hidden_size//2) vs bf16 (head_dim, hidden_size)). - # The stacking weight_loader silently skips the sub-weights, leaving - # random bf16 initialization. We reconstruct the compressor weights - # manually in post-load conversion by reading from the checkpoint. - re.compile(r"\.compressor\.kv_proj\.weight$"): None, - re.compile(r"\.compressor\.gate_proj\.weight$"): None, - re.compile(r"\.compressor\.kv_proj\.weight_scale$"): None, - re.compile(r"\.compressor\.gate_proj\.weight_scale$"): None, - re.compile(r"\.compressor\.kv_proj\.weight_scale_2$"): None, - re.compile(r"\.compressor\.gate_proj\.weight_scale_2$"): None, - re.compile(r"\.compressor\.kv_proj\.input_scale$"): None, - re.compile(r"\.compressor\.gate_proj\.input_scale$"): None, - # Note: attention and shared expert scale tensors are NO LONGER - # skipped. After fixing substr mappings, they correctly map to the - # model's NVFP4 scale parameters (fused_wqa_wkv, wq_b, wo_a, - # wo_b, gate_up_proj). They load via the stacking logic. - } - # Routed expert projections: gate_proj→w1, up_proj→w3, down_proj→w2 - # Regex (not substr) to match ONLY .experts.N. — not .shared_experts. + + +def _make_deepseek_v4_nvfp4_weights_mapper() -> WeightsMapper: + """Weight mapper for NVFP4 (ModelOpt) DeepSeek-V4 checkpoints. + + NVFP4 checkpoints use different key naming than the upstream MXFP4 format: + - Expert weights: gate_proj/up_proj/down_proj (not w1/w3/w2) + - Scales already have .weight_scale / .weight_scale_2 / .input_scale suffixes + - Shared expert uses down_proj (not w2) + - Self-attention uses .self_attn. prefix (same as checkpoint, renamed to .attn.) + - Hadamard coding uses .attn_hc. and .ffn_hc. prefixes + + This is the mapper that should be used when quantization is modelopt_fp4. + """ + # Expert weight renames: gate_proj→w1, up_proj→w3, down_proj→w2 + # Must match BEFORE the general suffix renames expert_rename_regex = { re.compile(r"(\.experts\.\d+\.)gate_proj\."): r"\1w1.", re.compile(r"(\.experts\.\d+\.)up_proj\."): r"\1w3.", re.compile(r"(\.experts\.\d+\.)down_proj\."): r"\1w2.", } - # Merge: skip patterns first, then renames, then original scale_regex - merged_regex = {} - merged_regex.update(fused_skip_regex) - merged_regex.update(expert_rename_regex) - merged_regex.update(scale_regex) + + # Suffix renames for non-expert keys + # NVFP4 checkpoints already use .weight_scale (not .scale), so no scale→weight_scale mapping needed + # But .self_attn. → .attn. and .mlp. → .ffn. renames are needed + suffix_renames = { + "head.weight": "lm_head.weight", + "embed.weight": "embed_tokens.weight", + ".ffn_norm.weight": ".ffn.norm_gate.norm.weight", + ".ffn.gate.weight": ".ffn.norm_gate.gate.weight", + ".ffn.gate.bias": ".ffn.norm_gate.e_score_correction_bias", + ".ffn.gate.tid2eid": ".ffn.norm_gate.tid2eid", + } + + # Substr renames + substr_renames = { + ".attn.compressor.": ".attn.mla_attn.compressor.", + ".mlp.shared_experts.gate_proj.": ".ffn.shared_experts.w1.", + ".mlp.shared_experts.up_proj.": ".ffn.shared_experts.w3.", + ".mlp.shared_experts.down_proj.": ".ffn.shared_experts.down_proj.", + ".mlp.": ".ffn.", + ".self_attn.": ".attn.", + ".attn_hc.": ".attn.hc_op.", + ".ffn_hc.": ".ffn.hc_op.", + } return WeightsMapper( orig_to_new_prefix={ @@ -2280,66 +1699,54 @@ def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper: "hc_head": "model.hc_head", "mtp.": "model.mtp.", }, - orig_to_new_regex=merged_regex, - orig_to_new_suffix={ - "embed.weight": "embed_tokens.weight", - ".ffn.gate.bias": ".ffn.gate.e_score_correction_bias", - }, - orig_to_new_substr={ - ".attn.compressor.": ".attn.mla_attn.compressor.", - ".shared_experts.w2": ".shared_experts.down_proj", - # ── ModelOpt NVFP4 substr patches ── - # Attention: self_attn → attn (projections at attn level, not mla_attn) - ".self_attn.q_a_proj.": ".attn.wq_a.", - ".self_attn.q_b_proj.": ".attn.wq_b.", - ".self_attn.q_a_norm.": ".attn.q_norm.", - ".self_attn.o_a_proj.": ".attn.wo_a.", - ".self_attn.o_b_proj.": ".attn.wo_b.", - ".self_attn.sinks": ".attn.attn_sink", - # kv_proj → wkv (for stacking into fused_wqa_wkv) - ".self_attn.kv_proj.": ".attn.wkv.", - ".self_attn.kv_norm.": ".attn.kv_norm.", - # kv_norm is at attention level, not compressor/mla_attn level in vllm - # Must come before the general compressor mapping - ".self_attn.compressor.kv_norm.": ".attn.kv_norm.", - # Compressor: self_attn.compressor → attn.mla_attn.compressor - ".self_attn.compressor.": ".attn.mla_attn.compressor.", - # Compressor projections for stacking (fused_wkv_wgate) - ".compressor.kv_proj.": ".compressor.wkv.", - ".compressor.gate_proj.": ".compressor.wgate.", - # Shared expert projections (stacking into gate_up_proj) - # Checkpoint has .shared_experts. but model has .ffn.shared_experts. - ".shared_experts.gate_proj.": ".ffn.shared_experts.w1.", - ".shared_experts.up_proj.": ".ffn.shared_experts.w3.", - # modelopt uses mlp, vllm uses ffn internally - ".mlp.": ".ffn.", - }, + orig_to_new_regex=expert_rename_regex, + orig_to_new_suffix=suffix_renames, + orig_to_new_substr=substr_renames, ) -class DeepseekV4ForCausalLM(nn.Module): +class DeepseekV4ForCausalLM(nn.Module, SupportsPP): model_cls = DeepseekV4Model - # NOTE: We do NOT set hf_to_vllm_mapper here because our custom - # load_weights handles all checkpoint→model name remapping inline. - # If hf_to_vllm_mapper is set, vLLM's AutoWeightsLoader may be invoked - # INSTEAD of our load_weights, silently dropping NVFP4 weight loading. + # Default mapper assumes the original FP4-expert checkpoint layout. + # Overridden per-instance in __init__ when expert_dtype != "fp4". + hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper("fp4") def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config + + # Select weight mapper based on quantization method. + # NVFP4 (modelopt_fp4) checkpoints use different key naming + # than the default MXFP4 format. + quant_config = vllm_config.quant_config + if quant_config is not None and getattr(quant_config, "get_name", lambda: None)() == "modelopt_fp4": + self.hf_to_vllm_mapper = _make_deepseek_v4_nvfp4_weights_mapper() + elif getattr(config, "expert_dtype", "fp4") != "fp4": + self.hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper("fp8") + else: + self.hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper("fp4") self.config = config + expert_dtype = getattr(config, "expert_dtype", "fp4") + if expert_dtype != "fp4": + self.hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper(expert_dtype) self.model = self.model_cls( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) - self.lm_head = ParallelLMHead( - config.vocab_size, - config.hidden_size, - prefix=maybe_prefix(prefix, "lm_head"), - ) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + prefix=maybe_prefix(prefix, "lm_head"), + ) + else: + self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.embed_input_ids(input_ids) @@ -2370,92 +1777,10 @@ class DeepseekV4ForCausalLM(nn.Module): return getattr(self.model, "_mtp_hidden_buffer", None) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - # lm_head lives on this outer model, not on the inner DeepseekV4Model. - # The inner load_weights silently drops lm_head.weight via - # "if name not in params_dict: continue". Extract it here. - rest = [] - for name, loaded_weight in weights: - if name == "lm_head.weight" or name.endswith(".lm_head.weight"): - param = self.lm_head.weight - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - else: - rest.append((name, loaded_weight)) - # Use the model-level loader which handles NVFP4 expert mapping, - # uint8→bf16 unpacking for MergedColumnParallelLinear, and - # bf16→NVFP4 quantization for unquantized layers. - # AutoWeightsLoader bypasses this logic and would break NVFP4 loading. - loaded_params = self.model.load_weights(rest) - loaded_params.add("lm_head.weight") - print(" Checkpoint loaded. Preparing NVFP4...", flush=True) + loader = AutoWeightsLoader(self, skip_substrs=["mtp."]) + loaded_params = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) self.model.finalize_mega_moe_weights() - self.model._convert_nvfp4_post_load() - print(" Warming up tilelang kernels...", flush=True) - self._warmup_tilelang() - print(" NVFP4 model ready ✓", flush=True) - return loaded_params - def _warmup_tilelang(self) -> None: - """Force-compile all tilelang JIT kernels with dummy data. - - tilelang's @jit decorator compiles lazily on first call. In eager mode - (no cudagraphs), the HTTP server comes up before the first inference - triggers compilation — and any request hitting the model during - compilation crashes vLLM. This warmup ensures all kernels are compiled - before the server accepts traffic. - - We call the custom ops directly with 1-token dummy tensors to populate - the tilelang kernel cache. - """ - import torch - config = self.model.config - hc_mult = config.hc_mult - hidden_size = config.hidden_size - device = next(self.model.parameters()).device - hc_mult3 = hc_mult * (2 + hc_mult) - - # Warmup mhc_pre - residual = torch.randn(1, hc_mult, hidden_size, dtype=torch.bfloat16, - device=device) - fn = torch.randn(hc_mult3, hc_mult * hidden_size, dtype=torch.float32, - device=device) - hc_scale = torch.randn(3, dtype=torch.float32, device=device) - hc_base = torch.randn(hc_mult3, dtype=torch.float32, device=device) - - try: - torch.ops.vllm.mhc_pre( - residual=residual, - fn=fn, - hc_scale=hc_scale, - hc_base=hc_base, - rms_eps=config.rms_norm_eps, - hc_pre_eps=config.hc_eps, - hc_sinkhorn_eps=config.hc_eps, - hc_post_mult_value=2.0, - sinkhorn_repeat=config.hc_sinkhorn_iters, - ) - print(" mhc_pre ✓", flush=True) - except Exception as e: - print(f" mhc_pre warmup failed (non-fatal): {e}", flush=True) - - # Warmup mhc_post - x = torch.randn(1, hidden_size, dtype=torch.bfloat16, device=device) - post_mix = torch.randn(1, hc_mult, 1, dtype=torch.float32, device=device) - comb_mix = torch.randn(1, hc_mult, hc_mult, dtype=torch.float32, - device=device) - - try: - torch.ops.vllm.mhc_post(x, residual, post_mix, comb_mix) - print(" mhc_post ✓", flush=True) - except Exception as e: - print(f" mhc_post warmup failed (non-fatal): {e}", flush=True) - - # Free dummy tensors - del residual, fn, hc_scale, hc_base, x, post_mix, comb_mix - torch.cuda.empty_cache() - def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: return self.model.get_expert_mapping() - diff --git a/vllm/patches/deepseek_v4_attention.py b/vllm/patches/deepseek_v4_attention.py index 2bde4a53..4bb2a64c 100644 --- a/vllm/patches/deepseek_v4_attention.py +++ b/vllm/patches/deepseek_v4_attention.py @@ -14,6 +14,7 @@ import torch.nn.functional as F from transformers import DeepseekV2Config, DeepseekV3Config import vllm.envs as envs +from vllm.compilation.breakable_cudagraph import eager_break_during_capture from vllm.model_executor.layers.linear import ( ReplicatedLinear, ) @@ -28,6 +29,7 @@ from vllm.v1.attention.ops.deepseek_v4_ops import ( fused_inv_rope_fp8_quant, fused_q_kv_rmsnorm, ) +from vllm.v1.attention.ops.rocm_aiter_mla_sparse import rocm_inv_rope_einsum if TYPE_CHECKING: from vllm.v1.attention.backends.mla.sparse_swa import ( @@ -45,7 +47,7 @@ from vllm.logger import init_logger from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.deepseek_compressor import DeepseekCompressor -from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm +from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.input_quant_fp8 import ( QuantFP8, @@ -53,6 +55,7 @@ from vllm.model_executor.layers.quantization.input_quant_fp8 import ( from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, ) +from vllm.platforms import current_platform from vllm.utils.multi_stream_utils import ( execute_in_parallel, maybe_execute_in_parallel, @@ -198,8 +201,6 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): # Pick fp8_einsum recipe based on GPU arch: # SM90: FP32 block scales stay [g, r/128, d/128] → sfb_gran_mn=128 # SM100: INT32 packed scales become [g, r, ...] → sfb_gran_mn=1 - from vllm.platforms import current_platform - cap = current_platform.get_device_capability() assert cap is not None, "DeepseekV4 attention requires a CUDA device" self._einsum_recipe = (1, 128, 128) if cap.major <= 9 else (1, 1, 128) @@ -222,6 +223,7 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): + 1 # 1B pad ) + # Will be None on ROCm for now. self.aux_stream_list = mla_modules.aux_stream_list # [0]: GEMM start / post-GEMM event0. [1..3]: GEMM done events; # [1] doubles as post-GEMM event1. Reuse is safe: GEMM fully joins @@ -303,6 +305,19 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): ) o = o_padded[:, : self.n_local_heads, :] + # Keep ROCm on the BF16 reference wo_a path util kernel ready. + if current_platform.is_rocm(): + z = rocm_inv_rope_einsum( + self.rotary_emb, + o, + positions, + self.rope_head_dim, + self.n_local_groups, + self.o_lora_rank, + self.wo_a, + ) + return self.wo_b(z.flatten(1)) + # O projection: inverse RoPE + FP8 quant + einsum + wo_b o_fp8, o_scale = fused_inv_rope_fp8_quant( o, @@ -336,12 +351,15 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): return self.wo_b(z.flatten(1)) def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]: - assert self.aux_stream_list is not None - assert len(self.aux_stream_list) >= 3 + aux_streams = self.aux_stream_list + if aux_streams is not None: + assert len(aux_streams) >= 3 + aux_streams = aux_streams[:3] # fused_wqa_wkv (heaviest) on default; the three lighter input GEMMs # on aux streams 0..2 when their owning module exists. ln_events[0] # is the fan-out start event; ln_events[1..3] are per-aux done events. + # On ROCm, aux_streams is None and execute_in_parallel runs serially. aux_fns: list[Callable[[], Any] | None] = [None, None, None] if self.compressor is not None: @@ -385,7 +403,7 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): aux_fns, self.ln_events[0], self.ln_events[1:4], - self.aux_stream_list[:3], + aux_streams, enable=hidden_states.shape[0] <= envs.VLLM_MULTI_STREAM_GEMM_TOKEN_THRESHOLD, ) @@ -419,8 +437,9 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): # downstream reads q on default). Indexer/compressor go on aux for # overlap with default's GEMM + cache write. if self.indexer is not None: - assert self.aux_stream_list is not None - aux_stream = self.aux_stream_list[0] + aux_stream = ( + self.aux_stream_list[0] if self.aux_stream_list is not None else None + ) indexer = self.indexer # Local ref so the closure keeps a non-None type for mypy. assert self.compressor is not None @@ -448,8 +467,9 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): ) elif self.compressor is not None: # wq_b + kv_insert on default, compressor on aux. - assert self.aux_stream_list is not None - aux_stream = self.aux_stream_list[0] + aux_stream = ( + self.aux_stream_list[0] if self.aux_stream_list is not None else None + ) compressor = self.compressor def wq_b_kv_insert() -> torch.Tensor: @@ -534,6 +554,7 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): ) +@eager_break_during_capture def deepseek_v4_attention( hidden_states: torch.Tensor, positions: torch.Tensor, @@ -668,7 +689,7 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase): vllm_config.scheduler_config.max_num_batched_tokens ) self.max_model_len = vllm_config.model_config.max_model_len - # DeepseekV4 only supports fp8 kv-cache format for now + # DeepseekV4 only supports fp8 kv-cache format for now. kv_cache_dtype = cache_config.cache_dtype if cache_config is not None else "fp8" assert kv_cache_dtype.startswith("fp8"), ( @@ -702,6 +723,12 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase): self.kv_cache = torch.tensor([]) def get_attn_backend(self) -> type[AttentionBackend]: + if current_platform.is_rocm(): + from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse_dsv4 import ( + DeepseekV4ROCMAiterMLASparseBackend, + ) + + return DeepseekV4ROCMAiterMLASparseBackend return DeepseekV4FlashMLASparseBackend def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: @@ -734,6 +761,14 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase): f"output buffer dtype {output.dtype} must match q dtype {q.dtype}" ) + if current_platform.is_rocm(): + from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse_dsv4 import ( + DeepseekV4ROCMAiterMLASparseImpl, + ) + + DeepseekV4ROCMAiterMLASparseImpl.forward(self, q, kv, positions, output) + return + # Get SWA and indexer metadata from forward context forward_context = get_forward_context() attn_metadata = forward_context.attn_metadata @@ -979,8 +1014,7 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase): M, N, ) - - output_chunk, _, _ = flash_mla_sparse_fwd( + flash_mla_sparse_fwd( q=q[query_start:query_end], kv=kv.view(-1, 1, q.shape[-1]), indices=combined_indices.unsqueeze(1), @@ -1077,7 +1111,6 @@ class DeepseekV4Indexer(nn.Module): quant_config=None, prefix=f"{prefix}.weights_proj", ) - self.k_norm = LayerNorm(self.head_dim, eps=1e-6) self.softmax_scale = self.head_dim**-0.5 self.scale_fmt = "ue8m0" diff --git a/vllm/patches/fused_moe/experts/cutedsl_moe.py b/vllm/patches/fused_moe/experts/cutedsl_moe.py new file mode 100644 index 00000000..c4c540af --- /dev/null +++ b/vllm/patches/fused_moe/experts/cutedsl_moe.py @@ -0,0 +1,303 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""CuTeDSL NVFP4 MoE experts for DeepSeek-V4. + +Integrates the CuTeDSL NVFP4 grouped GEMM kernel into vLLM's FusedMoE +modular kernel framework. This is the proper integration path — no +monkey-patching, no post-load surgery. + +The CuTeDSL kernel is a Python-based CUTLASS kernel compiled via MLIR → PTX. +It handles: + - L1 GEMM (gate + up projections) + - SiLU activation with optional swiglu_limit clamping + - L2 GEMM (down projection) + - All with NVFP4 (float8_e4m3fn block scales + float32 global scales) + +CUDA-graph-safe: all intermediate buffers pre-allocated, no CPU-GPU syncs, +no dynamic shapes. +""" + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.activation import MoEActivation +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, + FusedMoEQuantConfig, +) +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceNoOP, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kNvfp4Dynamic, + kNvfp4Static, +) +from vllm.platforms import current_platform + +from cutedsl.runner import CuTeDSLMoERunner + + +class CuTeDSLMoEExperts(mk.FusedMoEExpertsModular): + """CuTeDSL NVFP4 MoE experts using the custom CuTeDSL grouped GEMM kernel. + + Uses Standard activation format (non-batched). Handles input quantization + internally (expects_unquantized_inputs=True). + + Supports expert parallelism: remaps global→local expert IDs. + """ + + def __init__( + self, + moe_config: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, + ): + super().__init__( + moe_config=moe_config, + quant_config=quant_config, + ) + assert quant_config.quant_dtype == "nvfp4", ( + "CuTeDSL MoE only supports nvfp4 quantization, " + f"got {quant_config.quant_dtype}" + ) + self.out_dtype = moe_config.in_dtype + self.hidden_dim = moe_config.hidden_dim + self.intermediate_size_per_partition = ( + moe_config.intermediate_size_per_partition + ) + self.topk = moe_config.experts_per_token + self.local_num_experts = moe_config.num_local_experts + self.global_num_experts = moe_config.num_experts + self.ep_rank = moe_config.moe_parallel_config.ep_rank + self.local_expert_offset = self.ep_rank * self.local_num_experts + # max_num_tokens from scheduler config (for buffer pre-allocation) + self.max_num_tokens = getattr(moe_config, 'max_num_tokens', 8192) + + # swiglu_limit: read from the FusedMoE layer in process_weights_after_loading + self._swiglu_limit = None + + # Runner is created in process_weights_after_loading + self._runner: CuTeDSLMoERunner | None = None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + """Convert NVFP4 MoE weights into CuTeDSL kernel format. + + Reads w13/w2 weight tensors from the FusedMoE layer, converts them + to the CuTeDSL runner's expected format, and creates the runner. + Also folds the activation global scale (input_scale) into the + weight global scale (weight_scale_2) as the runner's alpha. + """ + num_experts = layer.w13_weight.shape[0] + hidden_size = self.hidden_dim + intermediate_size = self.intermediate_size_per_partition + device = layer.w13_weight.device + + # NOTE: For the CuTeDSL kernel, we do NOT fold input_scale into + # weight_scale_2. The CuTeDSL runner uses weight global scale + # (weight_scale_2) and activation global scale separately. + # The activation global scale is computed via warmup before first inference. + # + # Also, convert_to_nvfp4_moe_kernel_format already inverted input_scale + # (1.0 / a13_scale) for the quant config. We undo that inversion here + # to get the original input_scale, then use it as initial activation gs. + if layer.w13_input_scale is not None and not isinstance(layer.w13_input_scale, float): + # input_scale was inverted in convert_to_nvfp4_moe_kernel_format + # Original: input_scale = amax / (6.0 * 448.0) + # Inverted: 1.0 / input_scale = 6.0 * 448.0 / amax + # We need the original for activation gs + w13_input_scale_orig = 1.0 / layer.w13_input_scale + else: + w13_input_scale_orig = None + if layer.w2_input_scale is not None and not isinstance(layer.w2_input_scale, float): + w2_input_scale_orig = 1.0 / layer.w2_input_scale + else: + w2_input_scale_orig = None + + # Extract and convert weights for CuTeDSL runner + # w13_weight: (E, 2*intermediate, hidden//2) uint8 — gate + up fused + # w2_weight: (E, hidden, intermediate//2) uint8 — down + l1_fp4_list = [] + l1_sf_list = [] + l1_gs_list = [] + l2_fp4_list = [] + l2_sf_list = [] + l2_gs_list = [] + + for expert_id in range(num_experts): + # L1: gate + up (w13) + w13_uint8 = layer.w13_weight.data[expert_id] # (2*inter, hidden//2) + w13_sf = layer.w13_weight_scale.data[expert_id] # (2*inter, hidden//16) fp8 + w13_gs = layer.w13_weight_scale_2.data[expert_id].item() # float32 + + # uint8 → float4_e2m1fn_x2, permute to (K_packed, N) for CuTeDSL + l1_w = w13_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous() + # Block scales: (N, K_sf) → (K_sf, N) for CuTeDSL + l1_s = w13_sf.permute(1, 0).contiguous() + if l1_s.dtype != torch.float8_e4m3fn: + l1_s = l1_s.to(torch.float8_e4m3fn) + + l1_fp4_list.append(l1_w) + l1_sf_list.append(l1_s) + l1_gs_list.append(w13_gs) + + # L2: down (w2) + w2_uint8 = layer.w2_weight.data[expert_id] # (hidden, intermediate//2) + w2_sf = layer.w2_weight_scale.data[expert_id] # (hidden, intermediate//16) fp8 + w2_gs = layer.w2_weight_scale_2.data[expert_id].item() # float32 + + l2_w = w2_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous() + l2_s = w2_sf.permute(1, 0).contiguous() + if l2_s.dtype != torch.float8_e4m3fn: + l2_s = l2_s.to(torch.float8_e4m3fn) + + l2_fp4_list.append(l2_w) + l2_sf_list.append(l2_s) + l2_gs_list.append(w2_gs) + + # Create the CuTeDSL runner + self._runner = CuTeDSLMoERunner( + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + max_num_tokens=self.max_num_tokens, + top_k=self.topk, + device=str(device), + experts_start_idx=self.local_expert_offset, + ) + self._runner.prepare_weights_direct( + l1_fp4_list, l1_sf_list, l1_gs_list, + l2_fp4_list, l2_sf_list, l2_gs_list, + ) + if self._swiglu_limit is not None: + self._runner.set_swiglu_limit(float(self._swiglu_limit)) + + # Read swiglu_limit from the FusedMoE layer (set by DeepseekV4MoE) + swiglu_limit = getattr(layer, 'swiglu_limit', None) + if swiglu_limit is not None: + self._swiglu_limit = swiglu_limit + self._runner.set_swiglu_limit(float(swiglu_limit)) + + # Set initial activation global scales from checkpoint input_scale. + # The CuTeDSL runner uses activation_gs = 1.0 / input_scale from the + # checkpoint as the starting value. The warmup step + # (compute_activation_global_scales) will override this with an + # empirically computed value before the first inference. + if w13_input_scale_orig is not None: + # input_scale = 448.0 / amax → activation_gs = 1.0 / input_scale = amax / 448.0 + # Mean across experts (they should be similar) + mean_l1_gs = float(w13_input_scale_orig.mean().item()) + if mean_l1_gs > 0: + self._runner._l1_activation_global_scale = 1.0 / mean_l1_gs + if w2_input_scale_orig is not None: + mean_l2_gs = float(w2_input_scale_orig.mean().item()) + if mean_l2_gs > 0: + self._runner._l2_activation_global_scale = 1.0 / mean_l2_gs + + # Note: activation global scale warmup must be done after + # process_weights_after_loading, before the first inference. + # This is handled by the model's load_weights or a separate warmup step. + + @property + def runner(self) -> CuTeDSLMoERunner | None: + return self._runner + + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + @staticmethod + def _supports_current_device() -> bool: + # CuTeDSL requires CUDA SM100 (Blackwell) + p = current_platform + return p.is_cuda() and p.is_device_capability_family(100) + + @staticmethod + def _supports_no_act_and_mul() -> bool: + return False + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + SUPPORTED_W_A = [ + (kNvfp4Static, kNvfp4Dynamic), + ] + return (weight_key, activation_key) in SUPPORTED_W_A + + @staticmethod + def _supports_activation(activation: MoEActivation) -> bool: + # We handle SiLU + swiglu_limit internally + return activation == MoEActivation.SILU + + @staticmethod + def _supports_parallel_config( + moe_parallel_config: FusedMoEParallelConfig, + ) -> bool: + return True + + def supports_expert_map(self) -> bool: + return False + + @property + def expects_unquantized_inputs(self) -> bool: + # Our runner handles activation quantization internally + return True + + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + return TopKWeightAndReduceNoOP() + + def workspace_shapes( + self, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + activation: MoEActivation, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: + # Our runner manages its own workspace internally (pre-allocated buffers) + workspace1 = (0,) + workspace2 = (0,) + # K is packed (K//2 for uint8), so output uses hidden_dim + assert self.hidden_dim == K * 2 + output = (M, self.hidden_dim) + return (workspace1, workspace2, output) + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: MoEActivation, + global_num_experts: int, + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, + workspace13: torch.Tensor | None, + workspace2: torch.Tensor | None, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + apply_router_weight_on_input: bool | None, + ): + assert self._runner is not None, ( + "CuTeDSL runner not initialized. " + "Call process_weights_after_loading first." + ) + + # Our runner expects topk_ids as global expert IDs. + # The modular kernel framework may pass local IDs with expert_map. + # We handle remapping internally via experts_start_idx. + result = self._runner.run( + hidden_states=hidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + ) + + # Copy result into output tensor + output.copy_(result) diff --git a/vllm/patches/fused_moe/oracle/nvfp4.py b/vllm/patches/fused_moe/oracle/nvfp4.py new file mode 100644 index 00000000..e29fc746 --- /dev/null +++ b/vllm/patches/fused_moe/oracle/nvfp4.py @@ -0,0 +1,535 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from enum import Enum + +import torch + +import vllm.envs as envs +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.config.kernel import MoEBackend +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.all2all_utils import ( + maybe_make_prepare_finalize, +) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEQuantConfig, + nvfp4_moe_quant_config, + nvfp4_w4a16_moe_quant_config, +) +from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( + prepare_nvfp4_moe_layer_for_fi_or_cutlass, + prepare_nvfp4_moe_layer_for_flashinfer_cutedsl, +) +from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( + FlashinferMoeBackend, + get_flashinfer_moe_backend, +) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + prepare_nvfp4_moe_layer_for_marlin, +) +from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( + kE2M1ToFloat_handle, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, +) + +logger = init_logger(__name__) + + +class NvFp4MoeBackend(Enum): + FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM" + FLASHINFER_CUTLASS = "FLASHINFER_CUTLASS" + FLASHINFER_CUTEDSL = "FLASHINFER_CUTEDSL" + FLASHINFER_CUTEDSL_BATCHED = "FLASHINFER_CUTEDSL_BATCHED" + VLLM_CUTLASS = "VLLM_CUTLASS" + MARLIN = "MARLIN" + CUTEDSL = "CUTEDSL" + EMULATION = "EMULATION" + + +FLASHINFER_NVFP4_MOE_BACKENDS = [ + NvFp4MoeBackend.FLASHINFER_TRTLLM, + NvFp4MoeBackend.FLASHINFER_CUTLASS, + NvFp4MoeBackend.FLASHINFER_CUTEDSL, + NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED, +] + +CUTEDSL_NVFP4_MOE_BACKENDS = [ + NvFp4MoeBackend.CUTEDSL, +] + +fi_2_vllm_backend_map: dict[FlashinferMoeBackend, NvFp4MoeBackend] = { + FlashinferMoeBackend.CUTLASS: NvFp4MoeBackend.FLASHINFER_CUTLASS, + FlashinferMoeBackend.TENSORRT_LLM: NvFp4MoeBackend.FLASHINFER_TRTLLM, + FlashinferMoeBackend.CUTEDSL: NvFp4MoeBackend.FLASHINFER_CUTEDSL, +} + + +def is_global_sf_supported_for_nvfp4_backend(backend: NvFp4MoeBackend) -> bool: + # Checks whether `backend` supports quantizing with scaling factors + # of all experts in Expert Parallel Mode when all experts are not + # on the same rank. + + return backend in FLASHINFER_NVFP4_MOE_BACKENDS or backend in CUTEDSL_NVFP4_MOE_BACKENDS + + +def backend_to_kernel_cls( + backend: NvFp4MoeBackend, +) -> list[type[mk.FusedMoEExperts]]: + if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: + from vllm.model_executor.layers.fused_moe.experts.trtllm_nvfp4_moe import ( + TrtLlmNvFp4ExpertsModular, + TrtLlmNvFp4ExpertsMonolithic, + ) + + # NOTE: prefer Monolthic > Modular, so return Monolithic first. + return [ + TrtLlmNvFp4ExpertsMonolithic, + TrtLlmNvFp4ExpertsModular, + ] + + elif backend == NvFp4MoeBackend.FLASHINFER_CUTLASS: + from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutlass_moe import ( # noqa: E501 + FlashInferExperts, + ) + + return [FlashInferExperts] + + elif backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL: + from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutedsl_moe import ( # noqa: E501 + FlashInferCuteDSLExperts, + ) + + return [FlashInferCuteDSLExperts] + + elif backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED: + from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutedsl_batched_moe import ( # noqa: E501 + FlashInferCuteDSLBatchedExperts, + ) + + return [FlashInferCuteDSLBatchedExperts] + + elif backend == NvFp4MoeBackend.CUTEDSL: + from vllm.model_executor.layers.fused_moe.experts.cutedsl_moe import ( # noqa: E501 + CuTeDSLMoEExperts, + ) + + return [CuTeDSLMoEExperts] + + elif backend == NvFp4MoeBackend.VLLM_CUTLASS: + from vllm.model_executor.layers.fused_moe.experts.cutlass_moe import ( + CutlassExpertsFp4, + ) + + return [CutlassExpertsFp4] + + elif backend == NvFp4MoeBackend.MARLIN: + from vllm.model_executor.layers.fused_moe.experts.marlin_moe import ( + MarlinExperts, + ) + + return [MarlinExperts] + elif backend == NvFp4MoeBackend.EMULATION: + from vllm.model_executor.layers.fused_moe.experts.nvfp4_emulation_moe import ( + Nvfp4QuantizationEmulationTritonExperts, + ) + + return [Nvfp4QuantizationEmulationTritonExperts] + else: + raise ValueError(f"Unknown NvFP4 MoE backend: {backend.value}") + + +def map_nvfp4_backend(runner_backend: MoEBackend) -> NvFp4MoeBackend: + """Map user's MoEBackend to NvFp4MoeBackend.""" + mapping = { + "cutlass": NvFp4MoeBackend.VLLM_CUTLASS, + "flashinfer_trtllm": NvFp4MoeBackend.FLASHINFER_TRTLLM, + "flashinfer_cutlass": NvFp4MoeBackend.FLASHINFER_CUTLASS, + "flashinfer_cutedsl": NvFp4MoeBackend.FLASHINFER_CUTEDSL, + "cutedsl": NvFp4MoeBackend.CUTEDSL, + "marlin": NvFp4MoeBackend.MARLIN, + "emulation": NvFp4MoeBackend.EMULATION, + } + if backend := mapping.get(runner_backend): + return backend + raise ValueError( + f"moe_backend='{runner_backend}' is not supported for NvFP4 MoE. " + f"Expected one of {list(mapping.keys())}." + ) + + +def select_nvfp4_moe_backend( + config: FusedMoEConfig, + weight_key: QuantKey | None, + activation_key: QuantKey | None, +) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEExperts]]: + """ + Select the primary NvFP4 MoE backend + Note: Shape-specific fallbacks may still occur at runtime. + """ + + # NOTE: the kernels are selected in the following order. + AVAILABLE_BACKENDS = [ + NvFp4MoeBackend.FLASHINFER_TRTLLM, + NvFp4MoeBackend.FLASHINFER_CUTEDSL, + NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED, + NvFp4MoeBackend.CUTEDSL, + NvFp4MoeBackend.FLASHINFER_CUTLASS, + NvFp4MoeBackend.VLLM_CUTLASS, + NvFp4MoeBackend.MARLIN, + NvFp4MoeBackend.EMULATION, + ] + + use_batched = config.moe_parallel_config.use_batched_activation_format + activation_format = ( + mk.FusedMoEActivationFormat.BatchedExperts + if use_batched + else mk.FusedMoEActivationFormat.Standard + ) + + def _make_log_backend(backend: NvFp4MoeBackend): + available_backend_strs = [b.value for b in AVAILABLE_BACKENDS] + return ( + f"Using '{backend.value}' NvFp4 MoE backend out " + f"of potential backends: {available_backend_strs}." + ) + + def _make_log_unsupported(backend: NvFp4MoeBackend, reason: str | None) -> str: + if reason: + return ( + f"NvFp4 MoE backend '{backend.value}' does not support the " + f"deployment configuration since {reason}." + ) + else: + return ( + f"NvFp4 MoE backend '{backend.value}' does not support the " + "deployment configuration." + ) + + def _return_or_raise( + backend: NvFp4MoeBackend, + config: FusedMoEConfig, + weight_key: QuantKey | None, + activation_key: QuantKey | None, + activation_format: mk.FusedMoEActivationFormat, + ) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEExperts]]: + for k_cls in backend_to_kernel_cls(backend): + supported, reason = k_cls.is_supported_config( + k_cls, config, weight_key, activation_key, activation_format + ) + if supported: + logger.info_once(_make_log_backend(backend)) + return backend, k_cls + + raise ValueError(_make_log_unsupported(backend, reason)) + + # Handle explicit moe_backend from user. + runner_backend = config.moe_backend + if runner_backend != "auto": + requested_backend = map_nvfp4_backend(runner_backend) + # For batched activation format, use batched variant if available. + if ( + activation_format == mk.FusedMoEActivationFormat.BatchedExperts + and requested_backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL + ): + requested_backend = NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED + return _return_or_raise( + requested_backend, config, weight_key, activation_key, activation_format + ) + + if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP4"): + if not envs.VLLM_USE_FLASHINFER_MOE_FP4: + # If the user rejects FlashInfer remove those backends. + for b in FLASHINFER_NVFP4_MOE_BACKENDS: + AVAILABLE_BACKENDS.remove(b) + + elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"): + # If user is explicit about backend, validate it. + backend = fi_2_vllm_backend_map[get_flashinfer_moe_backend()] + return _return_or_raise( + backend, config, weight_key, activation_key, activation_format + ) + else: + # If the user is not explicit about the backend, try each. + for backend in FLASHINFER_NVFP4_MOE_BACKENDS: + for k_cls in backend_to_kernel_cls(backend): + supported, reason = k_cls.is_supported_config( + k_cls, + config, + weight_key, + activation_key, + activation_format, + ) + if supported: + logger.info_once(_make_log_backend(backend)) + return backend, k_cls + else: + logger.debug_once(_make_log_unsupported(backend, reason)) + + raise NotImplementedError( + "Found VLLM_USE_FLASHINFER_MOE_FP4=1, but no " + "FlashInfer NVFP4 MoE backend supports the configuration." + ) + + if envs.VLLM_TEST_FORCE_FP8_MARLIN: + backend = NvFp4MoeBackend.MARLIN + return _return_or_raise( + backend, config, weight_key, activation_key, activation_format + ) + + # Select kernels in order of backend. + for backend in AVAILABLE_BACKENDS: + for k_cls in backend_to_kernel_cls(backend): + supported, reason = k_cls.is_supported_config( + k_cls, + config, + weight_key, + activation_key, + activation_format, + ) + if supported: + logger.info_once(_make_log_backend(backend)) + return backend, k_cls + else: + logger.debug_once(_make_log_unsupported(backend, reason)) + + raise NotImplementedError( + "No NvFp4 MoE backend supports the deployment configuration." + ) + + +def convert_to_nvfp4_moe_kernel_format( + nvfp4_backend: NvFp4MoeBackend, + layer: torch.nn.Module, + w13: torch.Tensor, + w13_scale: torch.Tensor, + w13_scale_2: torch.Tensor, + a13_scale: torch.Tensor | None, + w2: torch.Tensor, + w2_scale: torch.Tensor, + w2_scale_2: torch.Tensor, + a2_scale: torch.Tensor | None, + is_act_and_mul: bool, +) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, +]: + if nvfp4_backend == NvFp4MoeBackend.CUTEDSL: + # CuTeDSL kernel handles weight conversion in its own + # process_weights_after_loading. Pass through raw weights. + # Compute inverse activation scales for the quant config. + if a13_scale is not None: + a13_scale = 1.0 / a13_scale + if a2_scale is not None: + a2_scale = 1.0 / a2_scale + elif nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL: + ( + w13, + w13_scale, + w13_scale_2, + a13_scale, + w2, + w2_scale, + w2_scale_2, + a2_scale, + ) = prepare_nvfp4_moe_layer_for_flashinfer_cutedsl( + layer=layer, + w13=w13, + w13_scale=w13_scale, + w13_scale_2=w13_scale_2, + a13_scale=a13_scale, + w2=w2, + w2_scale=w2_scale, + w2_scale_2=w2_scale_2, + a2_scale=a2_scale, + ) + elif ( + nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS + or nvfp4_backend == NvFp4MoeBackend.VLLM_CUTLASS + ): + ( + w13, + w13_scale, + w13_scale_2, + a13_scale, + w2, + w2_scale, + w2_scale_2, + a2_scale, + ) = prepare_nvfp4_moe_layer_for_fi_or_cutlass( + backend=nvfp4_backend, + layer=layer, + w13=w13, + w13_scale=w13_scale, + w13_scale_2=w13_scale_2, + a13_scale=a13_scale, + w2=w2, + w2_scale=w2_scale, + w2_scale_2=w2_scale_2, + a2_scale=a2_scale, + is_act_and_mul=is_act_and_mul, + ) + elif nvfp4_backend == NvFp4MoeBackend.MARLIN: + a13_scale = None + a2_scale = None + ( + w13, + w13_scale, + w13_scale_2, + w2, + w2_scale, + w2_scale_2, + ) = prepare_nvfp4_moe_layer_for_marlin( + layer=layer, + w13=w13, + w13_scale=w13_scale, + w13_scale_2=w13_scale_2, + w2=w2, + w2_scale=w2_scale, + w2_scale_2=w2_scale_2, + is_act_and_mul=is_act_and_mul, + ) + elif nvfp4_backend == NvFp4MoeBackend.EMULATION: + # Move the E2M1 lookup table to the device now, because + # `.to(device)` is not allowed during CUDA graph capture. + kE2M1ToFloat_handle.val = kE2M1ToFloat_handle.val.to(w13.device) + + if a13_scale is None or a2_scale is None: + raise ValueError( + "Activation global scales should not be None, got" + f" a13_scale={a13_scale}, a2_scale={a2_scale}" + ) + + if torch.unique(a13_scale).numel() != 1 or torch.unique(a2_scale).numel() != 1: + logger.warning_once( + "In NVFP4 linear, the activation global scale for inputs are different" + " for MOE w13 (gate_up_proj) layer or MOE w2 (down_proj). Using" + " a13_scale = a13_scale.max() and a2_scale = a2_scale.max()." + ) + + # 1. We take the max following e.g. quantization/utils/flashinfer_fp4_moe.py. + # 2. moe_kernel_quantize_input -> ref_nvfp4_quant_dequant + # use the inverse scale directly (large global scale). + # NOTE: Before this point, `a13_scale` and `a2_scale` are such that: + # `FP8_MAX = activation[expert_id].abs().max() * global_scale[expert_id]`, + # and `global_scale[expert_id]` are small (~1e-4). + # Taking the largest global scale likely results in overflowing the FP8 range + # for other experts - other selection strategies may be used. + a13_scale = 1.0 / a13_scale.max().to(torch.float32) + a2_scale = 1.0 / a2_scale.max().to(torch.float32) + else: + raise ValueError(f"Unknown NvFp4 backend for MoE: {nvfp4_backend}") + + return ( + w13, + w13_scale, + w13_scale_2, + a13_scale, + w2, + w2_scale, + w2_scale_2, + a2_scale, + ) + + +def make_nvfp4_moe_quant_config( + backend: NvFp4MoeBackend, + w13_scale: torch.Tensor, + w2_scale: torch.Tensor, + w13_scale_2: torch.Tensor, + w2_scale_2: torch.Tensor, + a13_scale: torch.Tensor, + a2_scale: torch.Tensor, +) -> FusedMoEQuantConfig: + if backend == NvFp4MoeBackend.MARLIN: + return nvfp4_w4a16_moe_quant_config( + g1_alphas=w13_scale_2, + g2_alphas=w2_scale_2, + w1_scale=w13_scale, + w2_scale=w2_scale, + ) + elif backend == NvFp4MoeBackend.EMULATION: + return nvfp4_moe_quant_config( + g1_alphas=w13_scale_2, + g2_alphas=w2_scale_2, + a1_gscale=a13_scale, + a2_gscale=a2_scale, + w1_scale=w13_scale, + w2_scale=w2_scale, + ) + + # Pass w13_scale_2 / w2_scale_2 directly as g1/g2_alphas. + # The expert's process_weights_after_loading will fuse activation + # scales in-place. Since the quant config references the same tensor + # as the registered parameter, EPLB rearrangement stays in sync. + return nvfp4_moe_quant_config( + g1_alphas=w13_scale_2, + g2_alphas=w2_scale_2, + a1_gscale=(1.0 / a13_scale), + a2_gscale=(1.0 / a2_scale), + w1_scale=w13_scale, + w2_scale=w2_scale, + # NOTE(rob): this is a hack until the MoE kernels + # create their own quant configs. TRTLLM kernel + # does not accept swizzled input quant scales. + is_scale_swizzled=( + backend + not in ( + NvFp4MoeBackend.FLASHINFER_TRTLLM, + NvFp4MoeBackend.FLASHINFER_CUTEDSL, + NvFp4MoeBackend.CUTEDSL, + ) + ), + ) + + +def make_nvfp4_moe_kernel( + moe_quant_config: FusedMoEQuantConfig, + moe_config: FusedMoEConfig, + experts_cls: type[mk.FusedMoEExperts], + routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, +) -> mk.FusedMoEKernel: + # Create Prepare/Finalize. + prepare_finalize = maybe_make_prepare_finalize( + moe=moe_config, + quant_config=moe_quant_config, + routing_tables=routing_tables, + allow_new_interface=True, + use_monolithic=issubclass(experts_cls, mk.FusedMoEExpertsMonolithic), + ) + assert prepare_finalize is not None + + logger.info_once("Using %s", prepare_finalize.__class__.__name__) + + # Create Experts. + if prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts: + max_num_tokens = prepare_finalize.max_num_tokens_per_rank() + assert max_num_tokens is not None + experts = experts_cls( + moe_config=moe_config, + quant_config=moe_quant_config, + max_num_tokens=max_num_tokens, + num_dispatchers=prepare_finalize.num_dispatchers(), + ) + else: + experts = experts_cls( + moe_config=moe_config, + quant_config=moe_quant_config, + ) + + kernel = mk.FusedMoEKernel( + prepare_finalize, + experts, + inplace=False, + ) + + # TODO(rob): update inplace logic to be part of the kernel. + return kernel diff --git a/vllm/patches/modelopt.py b/vllm/patches/modelopt.py new file mode 100644 index 00000000..4a3b7619 --- /dev/null +++ b/vllm/patches/modelopt.py @@ -0,0 +1,2378 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from fnmatch import fnmatch +from typing import TYPE_CHECKING, Any + +import torch +from torch.nn.parameter import Parameter + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.config import get_current_vllm_config +from vllm.logger import init_logger +from vllm.model_executor.kernels.linear import ( + MarlinNvFp4LinearKernel, + NvFp4LinearLayerConfig, + init_fp8_linear_kernel, + init_mxfp8_linear_kernel, + init_nvfp4_linear_kernel, +) +from vllm.model_executor.layers.attention import Attention, MLAAttention +from vllm.model_executor.layers.fused_moe import ( + FusedMoEConfig, + FusedMoEMethodBase, + FusedMoEQuantConfig, + FusedMoeWeightScaleSupported, + MoEActivation, + RoutedExperts, + RoutingMethodType, + SharedExperts, +) +from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( + Fp8MoeBackend, + convert_to_fp8_moe_kernel_format, + make_fp8_moe_kernel, + make_fp8_moe_quant_config, + select_fp8_moe_backend, +) +from vllm.model_executor.layers.fused_moe.oracle.mxfp8 import ( + select_mxfp8_moe_backend, +) +from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( + convert_to_nvfp4_moe_kernel_format, + is_global_sf_supported_for_nvfp4_backend, + make_nvfp4_moe_kernel, + make_nvfp4_moe_quant_config, + select_nvfp4_moe_backend, +) +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( + swap_w13_to_w31, +) +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + process_fp8_input_tensor_strategy_moe, + process_fp8_weight_tensor_strategy_moe, +) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + get_marlin_input_dtype, +) +from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( + MXFP8_BLOCK_SIZE, + MXFP8_SCALE_DTYPE, + MXFP8_VALUE_DTYPE, + mxfp8_e4m3_quantize, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, + create_fp8_quant_key, + is_layer_skipped, + kFp8DynamicTokenSym, + kFp8StaticTensorSym, + kFp8StaticTokenSym, + kNvfp4Dynamic, + kNvfp4Static, +) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + requantize_with_max_scale, +) +from vllm.model_executor.parameter import ( + BlockQuantScaleParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter, +) +from vllm.model_executor.utils import replace_parameter, set_weight_attrs +from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe + +if TYPE_CHECKING: + from vllm.model_executor.models.utils import WeightsMapper + +logger = init_logger(__name__) + +QUANT_ALGOS = [ + # FP8 (per-tensor weight + optional static activation scale). + "FP8", + # FP8 per-channel weight scale + per-token activation scale. + "FP8_PER_CHANNEL_PER_TOKEN", + # FP8 per-block weight-only (ModelOpt may emit this as lowercase). + "FP8_PB_WO", + # NVFP4 W4A4 (4-bit float weights AND 4-bit float activations). + "NVFP4", + # W4A16 NVFP4 (4-bit float weights, fp16/bf16 activations). + "W4A16_NVFP4", + # MXFP8 + "MXFP8", + # MIXED_PRECISION, + "MIXED_PRECISION", +] +KV_CACHE_QUANT_ALGOS = ["FP8", "NVFP4"] + + +class ModelOptKVCacheMethod(BaseKVCacheMethod): + """ + Supports loading kv-cache scaling factors from FP8 or NVFP4 checkpoints. + """ + + def __init__(self, quant_config: "ModelOptQuantConfigBase"): + super().__init__(quant_config) + + +class ModelOptQuantConfigBase(QuantizationConfig): + LinearMethodCls: type = LinearMethodBase + FusedMoEMethodCls: type = FusedMoEMethodBase + KVCacheMethodCls: type = BaseKVCacheMethod + + def __init__( + self, + exclude_modules: list[str], + ): + super().__init__() + self.exclude_modules: list[str] = exclude_modules + + def is_layer_excluded(self, prefix: str) -> bool: + """ + Check if a layer should be excluded from quantization. + + Handles both exact matching (for fused layers) and ModelOpt wildcard matching. + + The ModelOpt exclude_modules list is a list of wildcards. + """ + if len(self.exclude_modules) == 0: + return False + + # First check exact matching with fused layer support + if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping): + return True + + # TODO: This special hard coded logic is not needed for quantized checkpoints + # generated by ModelOpt >= 0.39.0 where they are handled natually by the + # exclude_modules config. But need to keep them for loading quantized + # checkpoints generated by older versions. Then check substring matching + # for patterns not caught by exact match + for exclude_module in self.exclude_modules: + # Skip exact matches already handled above + if exclude_module != prefix and ( + exclude_module in prefix + or ( + prefix.startswith("language_model.") + and exclude_module in prefix.removeprefix("language_model.") + ) + ): + return True + + # modelopt exclude modules are not simple strings, they are wildcards + for wildcard_pattern in self.exclude_modules: + if fnmatch(prefix, wildcard_pattern): + return True + + return False + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> "QuantizeMethodBase | None": + # handle kv-cache first so we can focus only on weight quantization thereafter + if isinstance(layer, (Attention, MLAAttention)): + return self.KVCacheMethodCls(self) + + # handle exclusion + if self.is_layer_excluded(prefix): + if isinstance(layer, LinearBase): + return UnquantizedLinearMethod() + return None + + # TODO: This special hard coded logic is not needed for quantized checkpoints + # generated by ModelOpt >= 0.39.0 where they are handled natually by the + # exclude_modules config. But need to keep them for loading quantized + # checkpoints generated by older versions. Then check substring matching + # for patterns not caught by exact match + if "vision_tower" in prefix or "vision_model" in prefix: + return UnquantizedLinearMethod() + + # now, the layer is quantized, handle it here + if isinstance(layer, LinearBase): + quant_method = self.LinearMethodCls(self) + if getattr(quant_method, "backend", "") == "marlin": + quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix) + return quant_method + elif isinstance(layer, RoutedExperts): + quant_method = self.FusedMoEMethodCls( + quant_config=self, moe_config=layer.moe_config + ) + if getattr(quant_method, "backend", "") == "marlin": + quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix) + return quant_method + + return None + + def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): + if len(self.exclude_modules) > 0: + # This is a workaround for the weights remapping issue: + # https://github.com/vllm-project/vllm/issues/28072 + # Right now, the Nvidia ModelOpt library use just one wildcard pattern: + # module_path* + # It gets applied if the whole tree of modules rooted at module_path + # is not quantized. Here we replace such pattern by 2 patterns that are + # collectively equivalent to the original pattern: + # module_path + # module_path.* + new_exclude_modules = [] + for exclude in self.exclude_modules: + if len(exclude) >= 2 and exclude[-1] == "*" and exclude[-2] != ".": + new_exclude_modules.append(exclude[:-1]) + new_exclude_modules.append(exclude[:-1] + ".*") + else: + new_exclude_modules.append(exclude) + + self.exclude_modules = hf_to_vllm_mapper.apply_list(new_exclude_modules) + + @staticmethod + def _extract_modelopt_quant_algo( + hf_quant_cfg: dict[str, Any] | None, + ) -> str | None: + """Extract upper-cased quant_algo from a modelopt config. + + Returns the quant_algo string (upper-cased), or None if the config + is not a modelopt config. + """ + if hf_quant_cfg is None: + return None + if not hf_quant_cfg.get("quant_method", "").lower().startswith("modelopt"): + return None + if "quantization" in hf_quant_cfg: + quant_config = hf_quant_cfg["quantization"] + if isinstance(quant_config, dict): + return str(quant_config.get("quant_algo", "")).upper() + return None + return str(hf_quant_cfg.get("quant_algo", "")).upper() + + @staticmethod + def get_config_filenames() -> list[str]: + return ["hf_quant_config.json"] + + @classmethod + def _from_config( + cls, + *, + quant_method: str, + kv_cache_quant_method: str | None, + exclude_modules: list[str], + original_config: dict[str, Any], + group_size: int | None, + ) -> "ModelOptQuantConfigBase": + raise NotImplementedError("Please implement this function in sub classes") + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "ModelOptQuantConfigBase": + # Handle both ModelOpt format and compressed-tensors style format + if "quantization" in config: + # Traditional ModelOpt format: + # {"quantization": {"quant_algo": "..."}} + quant_config = cls.get_from_keys(config, ["quantization"]) + if not isinstance(quant_config, dict): + raise ValueError("Expected 'quantization' to be a dictionary in config") + + quant_method = quant_config.get("quant_algo") + + # Handle kv_cache_quant_algo with proper type validation + kv_cache_quant_method = quant_config.get("kv_cache_quant_algo") + + # Handle group_size with proper type validation + group_size_raw = quant_config.get("group_size") + + # "exclude_modules" is the key in the legacy hf_quant_config.json + exclude_modules = quant_config.get("exclude_modules", []) + else: + # Compressed-tensors style format (config.json quantization_config): + # {"quant_algo": "...", "quant_method": "modelopt"} + quant_method = config.get("quant_algo") + + # "kv_cache_scheme" (a dict) instead of "kv_cache_quant_algo" (a string). + kv_cache_scheme = config.get("kv_cache_scheme") + if isinstance(kv_cache_scheme, dict) and ( + kv_cache_scheme.get("type") == "float" + and kv_cache_scheme.get("num_bits") == 8 + ): + kv_cache_quant_method = "FP8" + else: + kv_cache_quant_method = None + + # "ignore" is the key in config.json + exclude_modules = config.get("ignore", []) + group_size_raw = config.get("group_size") + + if not quant_method: + raise ValueError("Missing 'quant_algo' in quantization config") + + # Normalize quant_algo for robust matching (ModelOpt may emit lowercase). + quant_method = str(quant_method).upper() + + if kv_cache_quant_method is None: + # No KV cache quantization, keep this branch just to have this comment + pass + elif not isinstance(kv_cache_quant_method, str): + raise ValueError( + f"kv_cache_quant_algo must be a string, got " + f"{type(kv_cache_quant_method)}" + ) + else: + kv_cache_quant_method = kv_cache_quant_method.upper() + + if not isinstance(exclude_modules, list): + raise ValueError( + f"exclude_modules must be a list, got {type(exclude_modules)}" + ) + + if group_size_raw is None: + group_size = None + elif isinstance(group_size_raw, int): + group_size = group_size_raw + else: + try: + group_size = int(group_size_raw) + except (ValueError, TypeError): + raise ValueError( + f"group_size must be an integer, got {type(group_size_raw)}" + ) from None + + if quant_method not in QUANT_ALGOS: + raise ValueError( + f"ModelOpt currently only supports: {QUANT_ALGOS} " + "quantizations in vLLM. Please check the " + "`hf_quant_config.json` file for your model's " + "quant configuration." + ) + return cls._from_config( + quant_method=quant_method, + kv_cache_quant_method=kv_cache_quant_method, + exclude_modules=exclude_modules, + group_size=group_size, + original_config=config, + ) + + +class ModelOptFp8Config(ModelOptQuantConfigBase): + """Config class for ModelOpt FP8.""" + + def __init__( + self, + quant_method: str, + is_checkpoint_fp8_serialized: bool, + kv_cache_quant_method: str | None, + exclude_modules: list[str], + ) -> None: + super().__init__(exclude_modules) + self.quant_method = quant_method + self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized + self.kv_cache_quant_method = kv_cache_quant_method + if is_checkpoint_fp8_serialized: + logger.warning( + "Detected ModelOpt fp8 checkpoint (quant_algo=%s). Please note " + "that the format is experimental and could change.", + quant_method, + ) + + # Select LinearMethod implementation based on quant_algo. + if self.quant_method == "FP8": + self.LinearMethodCls = ModelOptFp8LinearMethod + elif self.quant_method == "FP8_PER_CHANNEL_PER_TOKEN": + self.LinearMethodCls = ModelOptFp8PcPtLinearMethod + elif self.quant_method == "FP8_PB_WO": + self.LinearMethodCls = ModelOptFp8PbWoLinearMethod + else: + raise ValueError( + "Unsupported ModelOpt FP8 quant_algo for vLLM: " + f"{self.quant_method}. Supported: FP8 / " + "FP8_PER_CHANNEL_PER_TOKEN / FP8_PB_WO." + ) + + def get_name(self) -> QuantizationMethods: + return "modelopt" + + def get_supported_act_dtypes(self) -> list[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 89 + + @classmethod + def override_quantization_method( + cls, hf_quant_cfg, user_quant, hf_config=None + ) -> QuantizationMethods | None: + algo = cls._extract_modelopt_quant_algo(hf_quant_cfg) + if algo is not None and algo == "FP8": + return "modelopt" + return None + + @classmethod + def _from_config( + cls, + *, + quant_method: str, + kv_cache_quant_method: str | None, + exclude_modules: list[str], + original_config: dict[str, Any], + **kwargs: Any, + ) -> "ModelOptFp8Config": + is_checkpoint_fp8_serialized = "FP8" in quant_method + + return cls( + quant_method, + is_checkpoint_fp8_serialized, + kv_cache_quant_method, + exclude_modules, + ) + + +class ModelOptFp8LinearMethod(LinearMethodBase): + """Linear method for Model Optimizer static quantization. + Supports loading FP8 checkpoints with static weight scale and + activation scale. Future support might be added for dynamic + scales. + + Limitations: + 1. Only support per-tensor quantization due to torch._scaled_mm support. + 2. Only support float8_e4m3fn datatype + Args: quant_config: The ModelOpt quantization config. + """ + + def __init__(self, quant_config: ModelOptFp8Config) -> None: + self.quant_config = quant_config + self.out_dtype = torch.get_default_dtype() + self.input_dtype = get_current_vllm_config().model_config.dtype + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del input_size, output_size + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + weight_dtype = ( + torch.float8_e4m3fn + if self.quant_config.is_checkpoint_fp8_serialized + else params_dtype + ) + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, input_size_per_partition, dtype=weight_dtype + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + if self.quant_config.is_checkpoint_fp8_serialized: + # WEIGHT SCALE + weight_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + weight_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", weight_scale) + # INPUT SCALE + scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + + scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("input_scale", scale) + + self.fp8_linear = init_fp8_linear_kernel( + activation_quant_key=kFp8StaticTensorSym, + weight_quant_key=kFp8StaticTensorSym, + weight_shape=layer.weight.shape, + input_dtype=self.input_dtype, + out_dtype=self.out_dtype, + module_name=self.__class__.__name__, + ) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + weight = layer.weight + max_w_scale = layer.weight_scale.max() + if not (layer.weight_scale == layer.weight_scale[0]).all(): + max_w_scale, weight = requantize_with_max_scale( + layer.weight, layer.weight_scale, layer.logical_widths + ) + layer.weight = Parameter(weight.t(), requires_grad=False) + layer.weight_scale = Parameter(max_w_scale, requires_grad=False) + layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False) + self.fp8_linear.process_weights_after_loading(layer) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + return self.fp8_linear.apply_weights(layer, x, bias) + + +class ModelOptFp8PcPtLinearMethod(LinearMethodBase): + """Linear method for ModelOpt FP8_PER_CHANNEL_PER_TOKEN checkpoints. + + Expected checkpoint structure (per Linear): + - weight: fp8-e4m3fn, shape [out, in] + - weight_scale: fp32, shape [out] (per-output-channel) + - no input_scale (activations are dynamically quantized per-token) + """ + + def __init__(self, quant_config: ModelOptFp8Config) -> None: + self.quant_config = quant_config + self.out_dtype = torch.get_default_dtype() + self.input_dtype = get_current_vllm_config().model_config.dtype + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del input_size, output_size + + if not self.quant_config.is_checkpoint_fp8_serialized: + raise ValueError( + "FP8_PER_CHANNEL_PER_TOKEN currently only supports " + "FP8-serialized checkpoints." + ) + + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + weight_scale = ChannelQuantScaleParameter( + data=torch.empty(output_size_per_partition, dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader, + ) + weight_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", weight_scale) + + self.fp8_linear = init_fp8_linear_kernel( + activation_quant_key=kFp8DynamicTokenSym, + weight_quant_key=kFp8StaticTokenSym, + weight_shape=layer.weight.shape, + input_dtype=self.input_dtype, + out_dtype=self.out_dtype, + module_name=self.__class__.__name__, + ) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.weight = Parameter(layer.weight.t(), requires_grad=False) + layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False) + self.fp8_linear.process_weights_after_loading(layer) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + return self.fp8_linear.apply_weights(layer, x, bias) + + +class ModelOptFp8PbWoLinearMethod(LinearMethodBase): + """Linear method for ModelOpt FP8_PB_WO checkpoints. + + ModelOpt exports `weight_scale` as a 4D tensor: + [out_blk, 1, in_blk, 1] + where block size is typically 128 for both dims. + + vLLM executes it as FP8 GEMM with *dynamic per-token* activation quant. + """ + + _WEIGHT_BLOCK_SIZE: tuple[int, int] = (128, 128) + + def __init__(self, quant_config: ModelOptFp8Config) -> None: + self.quant_config = quant_config + block_n, block_k = self._WEIGHT_BLOCK_SIZE + self.weight_block_size = list(self._WEIGHT_BLOCK_SIZE) + + self.activation_quant_key = create_fp8_quant_key( + static=False, group_shape=GroupShape(1, block_k) + ) + self.weight_quant_key = create_fp8_quant_key( + static=True, group_shape=GroupShape(block_n, block_k) + ) + + self.out_dtype = torch.get_default_dtype() + self.input_dtype = get_current_vllm_config().model_config.dtype + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del input_size, output_size + + if not self.quant_config.is_checkpoint_fp8_serialized: + raise ValueError( + "FP8_PB_WO currently only supports FP8-serialized checkpoints." + ) + + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + + # Expose block size so the v2 weight loaders can translate offsets from + # element-space -> block-space for BlockQuantScaleParameter. + layer.weight_block_size = self.weight_block_size + + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + block_n, block_k = self._WEIGHT_BLOCK_SIZE + if output_size_per_partition % block_n != 0: + raise ValueError( + "ModelOpt FP8_PB_WO requires out_features divisible by " + f"{block_n}, got {output_size_per_partition}." + ) + if input_size_per_partition % block_k != 0: + raise ValueError( + "ModelOpt FP8_PB_WO requires in_features divisible by " + f"{block_k}, got {input_size_per_partition}." + ) + + out_blks = output_size_per_partition // block_n + in_blks = input_size_per_partition // block_k + + # Match ModelOpt's exported shape so weight loading works without a + # custom loader: [out_blk, 1, in_blk, 1] + weight_scale = BlockQuantScaleParameter( + data=torch.empty((out_blks, 1, in_blks, 1), dtype=torch.float32), + input_dim=2, + output_dim=0, + weight_loader=weight_loader, + ) + weight_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", weight_scale) + + self.w8a8_block_fp8_linear = init_fp8_linear_kernel( + activation_quant_key=self.activation_quant_key, + weight_quant_key=self.weight_quant_key, + weight_shape=layer.weight.shape, + input_dtype=self.input_dtype, + out_dtype=self.out_dtype, + module_name=self.__class__.__name__, + ) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Keep weight in [out, in] layout for Fp8BlockScaledMMLinearKernel. + layer.weight = Parameter(layer.weight.data, requires_grad=False) + + scale = layer.weight_scale + if scale.dim() == 4: + # [out_blk, 1, in_blk, 1] -> [out_blk, in_blk] + scale = scale.squeeze(1).squeeze(-1) + elif scale.dim() != 2: + raise ValueError( + "Unexpected ModelOpt FP8_PB_WO weight_scale shape: " + f"{tuple(scale.shape)}." + ) + + layer.weight_scale = Parameter(scale.contiguous(), requires_grad=False) + + if hasattr(self, "fp8_linear"): + self.fp8_linear.process_weights_after_loading(layer) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + return self.w8a8_block_fp8_linear.apply_weights(layer, x, bias) + + +class ModelOptFp8MoEMethod(FusedMoEMethodBase): + """MoE method for ModelOpt FP8. + Supports loading FP8 checkpoints with static weight scale and + activation scale. + Args: + quant_config: The ModelOpt quantization config. + """ + + def __init__( + self, + quant_config: ModelOptFp8Config, + moe_config: FusedMoEConfig, + ) -> None: + super().__init__(moe_config) + self.quant_config = quant_config + assert self.quant_config.is_checkpoint_fp8_serialized + + # Select Fp8 MoE backend + self.fp8_backend, self.experts_cls = select_fp8_moe_backend( + config=self.moe, + weight_key=kFp8StaticTensorSym, + activation_key=kFp8StaticTensorSym, + ) + + def maybe_make_prepare_finalize( + self, + routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, + ) -> mk.FusedMoEPrepareAndFinalizeModular | None: + raise ValueError( + f"{self.__class__.__name__} uses the new modular kernel initialization " + "logic. This function should not be called." + ) + + def select_gemm_impl( + self, + prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular, + layer: RoutedExperts, + ) -> mk.FusedMoEExpertsModular: + raise ValueError( + f"{self.__class__.__name__} uses the new modular kernel initialization " + "logic. This function should not be called." + ) + + def create_weights( + self, + layer: RoutedExperts, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + layer.orig_dtype = params_dtype + layer.num_experts = num_experts + + # Use FP8 dtype if checkpoint is serialized + weight_dtype = ( + torch.float8_e4m3fn + if self.quant_config.is_checkpoint_fp8_serialized + else params_dtype + ) + weight_loader = extra_weight_attrs.get("weight_loader") + + w13_num_shards = 2 if self.moe.is_act_and_mul else 1 + + w13_weight = ModelWeightParameter( + data=torch.empty( + num_experts, + w13_num_shards * intermediate_size_per_partition, + hidden_size, + dtype=weight_dtype, + ), + input_dim=2, + output_dim=1, + weight_loader=weight_loader, + ) + layer.register_parameter("w13_weight", w13_weight) + + w2_weight = ModelWeightParameter( + data=torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=weight_dtype, + ), + input_dim=2, + output_dim=1, + weight_loader=weight_loader, + ) + layer.register_parameter("w2_weight", w2_weight) + + # WEIGHT SCALES - Per-tensor scaling for ModelOpts + # For gated MoE, allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. + # For non-gated MoE, allocate 1 scale for w13. + w13_weight_scale = PerTensorScaleParameter( + data=torch.full( + (num_experts, w13_num_shards), + 1.0, + dtype=torch.float32, + ), + weight_loader=weight_loader, + ) + w2_weight_scale = PerTensorScaleParameter( + data=torch.full((num_experts,), 1.0, dtype=torch.float32), + weight_loader=weight_loader, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + + # INPUT SCALES - Per-tensor scaling for ModelOpt + w13_input_scale = PerTensorScaleParameter( + data=torch.full((num_experts,), 1.0, dtype=torch.float32), + weight_loader=weight_loader, + ) + w2_input_scale = PerTensorScaleParameter( + data=torch.full((num_experts,), 1.0, dtype=torch.float32), + weight_loader=weight_loader, + ) + layer.register_parameter("w13_input_scale", w13_input_scale) + layer.register_parameter("w2_input_scale", w2_input_scale) + + def _setup_kernel( + self, + layer: RoutedExperts, + w13: torch.Tensor, + w2: torch.Tensor, + w13_scale: torch.Tensor, + w2_scale: torch.Tensor, + w13_input_scale: torch.Tensor, + w2_input_scale: torch.Tensor, + ): + w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format( + fp8_backend=self.fp8_backend, + layer=layer, + w13=w13, + w2=w2, + w13_scale=w13_scale, + w2_scale=w2_scale, + w13_input_scale=w13_input_scale, + w2_input_scale=w2_input_scale, + ) + + # Replace parameters with updated versions. Note that this helper + # function ensures the replacement is compatible with RL weight reloads. + replace_parameter(layer, "w13_weight", w13) + replace_parameter(layer, "w2_weight", w2) + replace_parameter(layer, "w13_weight_scale", w13_scale) + replace_parameter(layer, "w2_weight_scale", w2_scale) + + # Setup modular kernel. + self.moe_quant_config = self.get_fused_moe_quant_config(layer) + assert self.experts_cls is not None + self.moe_kernel = make_fp8_moe_kernel( + moe_quant_config=self.moe_quant_config, + moe_config=self.moe, + fp8_backend=self.fp8_backend, + experts_cls=self.experts_cls, + routing_tables=layer._expert_routing_tables(), + ) + + def process_weights_after_loading(self, layer: RoutedExperts) -> None: + w13 = layer.w13_weight + w2 = layer.w2_weight + w13_scale = layer.w13_weight_scale + w2_scale = layer.w2_weight_scale + w13_input_scale = layer.w13_input_scale + w2_input_scale = layer.w2_input_scale + + # Per tensor kernels require single activation scale. Use the max. + w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe( + w13_input_scale, w2_input_scale + ) + replace_parameter(layer, "w13_input_scale", w13_input_scale) + replace_parameter(layer, "w2_input_scale", w2_input_scale) + + # Per tensor kernels require single weight scale for w13 per expert, but + # on disk there is a scale for w1 and w3. Use the max to requantize. + shard_size = layer.intermediate_size_per_partition + w13, w13_scale = process_fp8_weight_tensor_strategy_moe( + w13, + w13_scale, + shard_size, + num_experts=layer.w13_weight.shape[0], + is_act_and_mul=self.moe.is_act_and_mul, + ) + + # Shuffle weights to runtime format and setup kernel. + self._setup_kernel( + layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale + ) + + def get_fused_moe_quant_config(self, layer: RoutedExperts) -> FusedMoEQuantConfig: + w1_scale = layer.w13_weight_scale + w2_scale = layer.w2_weight_scale + a1_scale = layer.w13_input_scale + a2_scale = layer.w2_input_scale + + return make_fp8_moe_quant_config( + fp8_backend=self.fp8_backend, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + swiglu_limit=getattr(layer, "swiglu_limit", None), + ) + + def apply_monolithic( + self, + layer: RoutedExperts, + x: torch.Tensor, + router_logits: torch.Tensor, + input_ids: torch.Tensor | None = None, + ) -> torch.Tensor: + assert self.is_monolithic + assert self.moe_kernel is not None + return self.moe_kernel.apply_monolithic( + x, + layer.w13_weight, + layer.w2_weight, + router_logits, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + num_expert_group=layer.num_expert_group, + topk_group=layer.topk_group, + e_score_correction_bias=layer.e_score_correction_bias, + routed_scaling_factor=layer.routed_scaling_factor, + ) + + def apply( + self, + layer: RoutedExperts, + x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + shared_experts: SharedExperts | None, + shared_experts_input: torch.Tensor | None, + ) -> torch.Tensor: + assert not self.is_monolithic + assert self.moe_kernel is not None + return self.moe_kernel.apply( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + shared_experts=shared_experts, + shared_experts_input=shared_experts_input, + ) + + +ModelOptFp8Config.LinearMethodCls = ModelOptFp8LinearMethod +ModelOptFp8Config.FusedMoEMethodCls = ModelOptFp8MoEMethod +ModelOptFp8Config.KVCacheMethodCls = ModelOptKVCacheMethod + + +class ModelOptNvFp4Config(ModelOptQuantConfigBase): + """Config class for ModelOpt FP4.""" + + def __init__( + self, + quant_method: str = "NVFP4", + is_checkpoint_nvfp4_serialized: bool = False, + kv_cache_quant_algo: str | None = None, + exclude_modules: list[str] | None = None, + group_size: int = 16, + ) -> None: + if exclude_modules is None: + exclude_modules = [] + super().__init__(exclude_modules) + self.quant_method = quant_method + self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized + if is_checkpoint_nvfp4_serialized: + logger.warning( + "Detected ModelOpt NVFP4 checkpoint (quant_algo=%s). Please " + "note that the format is experimental and could change in " + "future.", + quant_method, + ) + + self.group_size = group_size + self.kv_cache_quant_algo = kv_cache_quant_algo + + # Select LinearMethod implementation based on quant_algo (FP8 pattern). + # NVFP4 -> W4A4: cutlass NVFP4 GEMM with input quantization + # W4A16_NVFP4 -> W4A16: FP4 Marlin GEMM with bf16/fp16 activations + if quant_method == "NVFP4": + self.LinearMethodCls = ModelOptNvFp4LinearMethod + elif quant_method == "W4A16_NVFP4": + self.LinearMethodCls = ModelOptNvFp4W4A16LinearMethod + else: + raise ValueError( + f"Unsupported ModelOpt NVFP4 quant_algo: {quant_method}. " + "Supported: NVFP4 / W4A16_NVFP4." + ) + + def get_name(self) -> QuantizationMethods: + return "modelopt_fp4" + + def get_supported_act_dtypes(self) -> list[torch.dtype]: + return [torch.bfloat16, torch.half, torch.float8_e4m3fn] + + @classmethod + def get_min_capability(cls) -> int: + return 75 + + @classmethod + def override_quantization_method( + cls, hf_quant_cfg, user_quant, hf_config=None + ) -> QuantizationMethods | None: + algo = cls._extract_modelopt_quant_algo(hf_quant_cfg) + if algo is not None and ("NVFP4" in algo or "FP4" in algo): + return "modelopt_fp4" + return None + + @classmethod + def _from_config( + cls, + *, + quant_method: str, + kv_cache_quant_method: str | None, + exclude_modules: list[str], + original_config: dict[str, Any], + group_size: int | None, + **kwargs: Any, + ) -> "ModelOptNvFp4Config": + is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method + + if group_size is None: + group_size = 16 # Default value + + # For FP4, these fields are required + if is_checkpoint_nvfp4_serialized and "quantization" in original_config: + # Check if required fields are present in the quantization config + quant_config = original_config["quantization"] + required_fields = ["group_size", "kv_cache_quant_algo", "exclude_modules"] + missing_fields = [ + field for field in required_fields if field not in quant_config + ] + if missing_fields: + raise ValueError( + f"NVFP4 quantization requires the following fields in " + f"hf_quant_config.json: {missing_fields}" + ) + + return cls( + quant_method, + is_checkpoint_nvfp4_serialized, + kv_cache_quant_method, + exclude_modules, + group_size, + ) + + +class ModelOptNvFp4LinearMethod(LinearMethodBase): + """Linear method for Model Optimizer NVFP4. + Supports loading NVFP4 checkpoints with the following structure: + + input_scale: torch.float32, scalar , + weight: NVFP4(represented as byte) Shape: [1, X, y/2] + weight_scale: FP8-E4M3, Shape: [X, Y], aka per block scale, + weight_scale_2: torch.float32, scalar, + Args: quant_config: The ModelOpt quantization config. + """ + + def __init__(self, quant_config: ModelOptNvFp4Config) -> None: + self.quant_config = quant_config + self.marlin_input_dtype = None + self.kernel = init_nvfp4_linear_kernel() + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del input_size, output_size + if not self.quant_config.is_checkpoint_nvfp4_serialized: + raise ValueError( + "NVFP4 quantization was selected, " + " dynamic quantization is not supported." + ) + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + + if input_size_per_partition % 16 != 0: + raise ValueError( + "Unsupported model when in features size is not multiple of 16" + ) + # The nvfp4 weight is still represented as + weight_dtype = ( + torch.float8_e4m3fn + if self.quant_config.is_checkpoint_nvfp4_serialized + else params_dtype + ) + # Weight + weight = ModelWeightParameter( + data=torch.empty( + # 2 fp4 items are packed in the input dimension + layer.output_size_per_partition, + layer.input_size_per_partition // 2, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + # Input Global Scale + input_global_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + layer.register_parameter("input_scale", input_global_scale) + + # Weight Global Scale + weight_global_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale_2", weight_global_scale) + + # Per Block Weight Scale + weight_scale = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition // self.quant_config.group_size, + dtype=weight_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + + layer.register_parameter("weight_scale", weight_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if ( + torch.unique(layer.input_scale).numel() != 1 + or torch.unique(layer.weight_scale_2).numel() != 1 + ): + logger.warning_once( + "In NVFP4 linear, the global scale for input or weight are different" + " for parallel layers (e.g. q_proj, k_proj, v_proj). This " + " will likely results in reduce accuracy. Please verify the model" + " accuracy. Consider using a checkpoint with a shared global NVFP4" + " scale for parallel layers." + ) + + # Rename ModelOpt checkpoint names to standardized names + input_global_scale = layer.input_scale.max().to(torch.float32) + layer.input_global_scale = Parameter(input_global_scale, requires_grad=False) + del layer.input_scale + + weight_global_scale = layer.weight_scale_2.max().to(torch.float32) + layer.weight_global_scale = Parameter(weight_global_scale, requires_grad=False) + del layer.weight_scale_2 + + # Pre-compute alpha and inverse for runtime quantization + layer.alpha = Parameter( + layer.input_global_scale * layer.weight_global_scale, requires_grad=False + ) + layer.input_global_scale_inv = Parameter( + (1.0 / layer.input_global_scale).to(torch.float32), requires_grad=False + ) + + # Convert layer to NVFP4 linear kernel format + self.kernel.process_weights_after_loading(layer) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + return self.kernel.apply_weights(layer=layer, x=x, bias=bias) + + +class ModelOptNvFp4W4A16LinearMethod(LinearMethodBase): + """Linear method for ModelOpt NVFP4 W4A16. + + 4-bit NVFP4 weights, fp16/bf16 activations. Loads ModelOpt-style names + directly (no on-disk conversion) and dispatches to the FP4 Marlin GEMM: + + weight uint8 packed NVFP4 (2 nibbles/byte along input dim) + weight_scale fp8-e4m3 per 16-elem group along input dim + weight_scale_2 fp32 per-tensor global scale = amax / (6.0 * 448.0) + + No activation quantization. Marlin expects the global scale in the same + form ModelOpt stores (amax/2688), so we rename weight_scale_2 -> + weight_global_scale **without reciprocation** -- the CT W4A16 path + reciprocates only because CT stores the inverse on disk. + + We also register a placeholder input_scale parameter so that W4A4-shaped + checkpoints (which contain *_proj.input_scale tensors) can be loaded + under this method without the per-shard loader hitting a KeyError on + the merged-name lookup. The placeholder is discarded in + process_weights_after_loading -- its value is never used. + """ + + def __init__(self, quant_config: ModelOptNvFp4Config) -> None: + self.quant_config = quant_config + # Vestigial slot mirrored from ModelOptNvFp4LinearMethod: the parent + # config's get_quant_method only fills marlin_input_dtype when + # backend == "marlin"; we don't set that since we pin the kernel + # below, but we keep the attribute for shape parity. + self.marlin_input_dtype = None + # Direct-instantiate the Marlin NVFP4 adapter rather than going through + # init_nvfp4_linear_kernel(): the latter's priority list returns a + # cutlass W4A4 kernel as first-pick on this hardware, which would + # silently try to quantize activations (we have no input_scale). For + # W4A16 there is exactly one valid kernel, so we pin it. + self.kernel = MarlinNvFp4LinearKernel(NvFp4LinearLayerConfig()) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del input_size, output_size + if not self.quant_config.is_checkpoint_nvfp4_serialized: + raise ValueError( + "W4A16_NVFP4 quantization was selected; " + "dynamic quantization is not supported." + ) + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + + if input_size_per_partition % 16 != 0: + raise ValueError( + "Unsupported model: input feature size is not a multiple of 16." + ) + + # Packed NVFP4 weights: uint8, 2 nibbles per byte along the input dim. + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition // 2, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + # Per-tensor global weight scale (fp32). ModelOpt stores + # amax / (NVFP4_max * fp8_e4m3_max) = amax / 2688. PerTensorScaleParameter + # holds one entry per fused output partition (e.g. q/k/v in a fused QKV). + weight_scale_2 = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale_2", weight_scale_2) + + # Per-group fp8 weight scale. + weight_scale = GroupQuantScaleParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition // self.quant_config.group_size, + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", weight_scale) + + # Placeholder input_scale param so W4A4-shaped checkpoints can be + # loaded under this method without KeyError on the merged-name + # lookup (qwen2-style stacked-loader path renames *_proj.input_scale + # to e.g. qkv_proj.input_scale and looks it up unconditionally). + # Discarded in process_weights_after_loading; never read by the kernel. + # For native W4A16 checkpoints (no input_scale on disk) the param + # stays uninitialized and is simply deleted. + input_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + layer.register_parameter("input_scale", input_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Discard the input_scale placeholder. Whether it carries values + # (W4A4 ckpt loaded as W4A16) or is uninitialized (native W4A16 + # ckpt), W4A16 mode does not quantize activations, so this is unused. + if hasattr(layer, "input_scale"): + del layer.input_scale + + if torch.unique(layer.weight_scale_2).numel() != 1: + logger.warning_once( + "In W4A16_NVFP4 linear, the global weight scale " + "(weight_scale_2) differs across fused parallel layers " + "(e.g. q/k/v_proj). This will likely reduce accuracy. " + "Consider a checkpoint with a shared global scale." + ) + + # Rename weight_scale_2 -> weight_global_scale. NO reciprocation: + # ModelOpt already stores amax/2688, which is exactly what Marlin + # consumes via nvfp4_marlin_process_global_scale (called inside the + # Marlin adapter's process_weights_after_loading). + layer.weight_global_scale = Parameter( + layer.weight_scale_2.max().to(torch.float32), requires_grad=False + ) + del layer.weight_scale_2 + + self.kernel.process_weights_after_loading(layer) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + return self.kernel.apply_weights(layer=layer, x=x, bias=bias) + + +class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): + """ + MoE Method for FP4 Quantization. + Args: + quant_config: NVFP4 Quant Config + """ + + def __init__( + self, + quant_config: ModelOptNvFp4Config, + moe_config: FusedMoEConfig, + ) -> None: + super().__init__(moe_config) + self.quant_config = quant_config + # Select experts implementation. + self.nvfp4_backend, self.experts_cls = select_nvfp4_moe_backend( + config=self.moe, + weight_key=kNvfp4Static, + activation_key=kNvfp4Dynamic, + ) + + self.use_global_sf = is_global_sf_supported_for_nvfp4_backend( + self.nvfp4_backend + ) + + def maybe_make_prepare_finalize( + self, + routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, + ) -> mk.FusedMoEPrepareAndFinalizeModular | None: + raise ValueError( + f"{self.__class__.__name__} uses the new modular kernel initialization " + "logic. This function should not be called." + ) + + def uses_weight_scale_2_pattern(self) -> bool: + """ + FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales. + """ + return True + + def create_weights( + self, + layer: RoutedExperts, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + assert self.quant_config.is_checkpoint_nvfp4_serialized + + layer.num_experts = num_experts + layer.params_dtype = params_dtype + layer.quant_config = self.quant_config + weight_dtype = torch.uint8 + weight_scale_dtype = torch.float8_e4m3fn + weight_loader = extra_weight_attrs.get("weight_loader") + global_num_experts = extra_weight_attrs.get("global_num_experts") + w13_num_shards = 2 if self.moe.is_act_and_mul else 1 + # GEMM 1 + w13_weight = ModelWeightParameter( + data=torch.empty( + num_experts, + w13_num_shards * intermediate_size_per_partition, + # 2 fp4 items are packed in the input dimension + hidden_size // 2, + dtype=weight_dtype, + ), + input_dim=1, + output_dim=2, + weight_loader=weight_loader, + ) + layer.register_parameter("w13_weight", w13_weight) + + # GEMM 2 + w2_weight = ModelWeightParameter( + data=torch.empty( + num_experts, + hidden_size, + # 2 fp4 items are packed in the input dimension + intermediate_size_per_partition // 2, + dtype=weight_dtype, + ), + input_dim=1, + output_dim=2, + weight_loader=weight_loader, + ) + layer.register_parameter("w2_weight", w2_weight) + + w13_weight_scale = ModelWeightParameter( + data=torch.empty( + num_experts, + w13_num_shards * intermediate_size_per_partition, + # 2 fp4 items are packed in the input dimension + hidden_size // self.quant_config.group_size, + dtype=weight_scale_dtype, + ), + input_dim=1, + output_dim=2, + weight_loader=weight_loader, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + + w2_weight_scale = ModelWeightParameter( + data=torch.empty( + num_experts, + hidden_size, + # 2 fp4 items are packed in the input dimension + intermediate_size_per_partition // self.quant_config.group_size, + dtype=weight_scale_dtype, + ), + input_dim=1, + output_dim=2, + weight_loader=weight_loader, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value} + ) + + w13_weight_scale_2 = PerTensorScaleParameter( + data=torch.empty(num_experts, w13_num_shards, dtype=torch.float32), + weight_loader=weight_loader, + ) + layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2) + + w2_weight_scale_2 = PerTensorScaleParameter( + data=torch.empty(num_experts, dtype=torch.float32), + weight_loader=weight_loader, + ) + layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2) + + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) + + global_sf_num_experts = ( + global_num_experts if self.use_global_sf else num_experts + ) + w13_input_scale = PerTensorScaleParameter( + data=torch.empty( + global_sf_num_experts, + w13_num_shards, + dtype=torch.float32, + ), + weight_loader=weight_loader, + ) + layer.register_parameter("w13_input_scale", w13_input_scale) + + w2_input_scale = PerTensorScaleParameter( + data=torch.empty(global_sf_num_experts, dtype=torch.float32), + weight_loader=weight_loader, + ) + layer.register_parameter("w2_input_scale", w2_input_scale) + + def process_weights_after_loading(self, layer: RoutedExperts) -> None: + """ + Convert NVFP4 MoE weights into kernel format and setup the kernel. + """ + + # Use a single gscale for w13. + if self.moe.is_act_and_mul and not torch.allclose( + layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1] + ): + logger.warning_once( + "w1_weight_scale_2 must match w3_weight_scale_2. " + "Accuracy may be affected." + ) + w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0].contiguous() + + ( + w13, + w13_scale, + w13_scale_2, + a13_scale, + w2, + w2_scale, + w2_scale_2, + a2_scale, + ) = convert_to_nvfp4_moe_kernel_format( + nvfp4_backend=self.nvfp4_backend, + layer=layer, + w13=layer.w13_weight, + w13_scale=layer.w13_weight_scale, + w13_scale_2=w13_weight_scale_2, + a13_scale=layer.w13_input_scale, + w2=layer.w2_weight, + w2_scale=layer.w2_weight_scale, + w2_scale_2=layer.w2_weight_scale_2, + a2_scale=layer.w2_input_scale, + is_act_and_mul=self.moe.is_act_and_mul, + ) + + replace_parameter(layer, "w13_weight", w13) + replace_parameter(layer, "w13_weight_scale", w13_scale) + replace_parameter(layer, "w13_weight_scale_2", w13_scale_2) + replace_parameter(layer, "w13_input_scale", a13_scale) + replace_parameter(layer, "w2_weight", w2) + replace_parameter(layer, "w2_weight_scale", w2_scale) + replace_parameter(layer, "w2_weight_scale_2", w2_scale_2) + replace_parameter(layer, "w2_input_scale", a2_scale) + + # Setup modular kernel. + self.moe_quant_config = self.get_fused_moe_quant_config(layer) + assert self.experts_cls is not None + self.moe_kernel = make_nvfp4_moe_kernel( + moe_quant_config=self.moe_quant_config, + moe_config=self.moe, + experts_cls=self.experts_cls, + routing_tables=layer._expert_routing_tables(), + ) + self.moe_kernel.fused_experts.process_weights_after_loading(layer) + + def get_fused_moe_quant_config(self, layer: RoutedExperts) -> FusedMoEQuantConfig: + return make_nvfp4_moe_quant_config( + backend=self.nvfp4_backend, + w13_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + w13_scale_2=layer.w13_weight_scale_2, + w2_scale_2=layer.w2_weight_scale_2, + a13_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + ) + + @property + def supports_eplb(self) -> bool: + return True + + def apply_monolithic( + self, + layer: RoutedExperts, + x: torch.Tensor, + router_logits: torch.Tensor, + input_ids: torch.Tensor | None = None, + ) -> torch.Tensor: + assert self.is_monolithic + assert self.moe_kernel is not None + return self.moe_kernel.apply_monolithic( + x, + layer.w13_weight, + layer.w2_weight, + router_logits, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + num_expert_group=layer.num_expert_group, + topk_group=layer.topk_group, + e_score_correction_bias=layer.e_score_correction_bias, + routed_scaling_factor=layer.routed_scaling_factor, + ) + + def apply( + self, + layer: RoutedExperts, + x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + shared_experts: SharedExperts | None, + shared_experts_input: torch.Tensor | None, + ) -> torch.Tensor: + assert not self.is_monolithic + assert self.moe_kernel is not None + return self.moe_kernel.apply( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + shared_experts=shared_experts, + shared_experts_input=shared_experts_input, + ) + + +ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod +ModelOptNvFp4Config.FusedMoEMethodCls = ModelOptNvFp4FusedMoE +ModelOptNvFp4Config.KVCacheMethodCls = ModelOptKVCacheMethod + + +class ModelOptMxFp8Config(ModelOptQuantConfigBase): + """Config class for ModelOpt MXFP8.""" + + def __init__( + self, + is_checkpoint_mxfp8_serialized: bool, + kv_cache_quant_algo: str | None, + exclude_modules: list[str], + ) -> None: + super().__init__(exclude_modules) + self.is_checkpoint_mxfp8_serialized = is_checkpoint_mxfp8_serialized + + if not is_checkpoint_mxfp8_serialized: + raise ValueError( + "MXFP8 quantization requires a serialized checkpoint. " + "Dynamic quantization is not supported." + ) + + logger.warning( + "Detected ModelOpt MXFP8 checkpoint. Please note that " + "the format is experimental and could change in future." + ) + + self.kv_cache_quant_algo = kv_cache_quant_algo + + def get_name(self) -> QuantizationMethods: + return "modelopt_mxfp8" + + def get_supported_act_dtypes(self) -> list[torch.dtype]: + return [torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + # Marlin kernel supports MXFP8 on SM80+ + return 80 + + @classmethod + def override_quantization_method( + cls, hf_quant_cfg, user_quant, hf_config=None + ) -> QuantizationMethods | None: + algo = cls._extract_modelopt_quant_algo(hf_quant_cfg) + if algo is not None and "MXFP8" in algo: + return "modelopt_mxfp8" + return None + + @classmethod + def _from_config( + cls, + *, + quant_method: str, + kv_cache_quant_method: str | None, + exclude_modules: list[str], + original_config: dict[str, Any], + **kwargs: Any, + ) -> "ModelOptMxFp8Config": + is_checkpoint_mxfp8_serialized = "MXFP8" in quant_method.upper() + + # For MXFP8, validate required fields in the config + if is_checkpoint_mxfp8_serialized and "quantization" in original_config: + quant_config = original_config["quantization"] + required_fields = ["kv_cache_quant_algo", "exclude_modules"] + missing_fields = [ + field for field in required_fields if field not in quant_config + ] + if missing_fields: + raise ValueError( + f"MXFP8 quantization requires the following fields in " + f"hf_quant_config.json: {missing_fields}" + ) + + return cls( + is_checkpoint_mxfp8_serialized, + kv_cache_quant_method, + exclude_modules, + ) + + +class ModelOptMxFp8LinearMethod(LinearMethodBase): + """Linear method for ModelOpt MXFP8 quantization.""" + + def __init__(self, quant_config: ModelOptMxFp8Config) -> None: + self.quant_config = quant_config + + if not self.quant_config.is_checkpoint_mxfp8_serialized: + raise ValueError( + "MXFP8 currently only supports serialized checkpoints. " + "Dynamic quantization is not supported." + ) + + self.kernel = init_mxfp8_linear_kernel() + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del input_size, output_size + + if not self.quant_config.is_checkpoint_mxfp8_serialized: + raise ValueError( + "MXFP8 quantization was selected, but checkpoint is not " + "MXFP8 serialized. Dynamic quantization is not supported." + ) + + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + + if input_size_per_partition % MXFP8_BLOCK_SIZE != 0: + raise ValueError( + f"MXFP8 requires input dimension to be divisible by " + f"{MXFP8_BLOCK_SIZE}, got {input_size_per_partition}" + ) + + # Weight tensor: FP8 E4M3 format + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=MXFP8_VALUE_DTYPE, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + # Weight scale tensor (E8M0 encoded as uint8), one scale per block of 32 along K + weight_scale = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition // MXFP8_BLOCK_SIZE, + dtype=MXFP8_SCALE_DTYPE, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", weight_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Validate weight tensor + if layer.weight.ndim != 2: + raise ValueError( + f"MXFP8 weight must be 2D tensor [N, K], got {layer.weight.ndim}D " + f"with shape {tuple(layer.weight.shape)}" + ) + + if layer.weight.dtype != MXFP8_VALUE_DTYPE: + raise ValueError( + f"MXFP8 weight must be {MXFP8_VALUE_DTYPE} (FP8 E4M3), " + f"got {layer.weight.dtype}. The checkpoint may not be properly " + f"quantized with MXFP8." + ) + + # Validate weight scale tensor (should be 2D, not swizzled) + assert layer.weight_scale.ndim == 2, ( + f"MXFP8 weight scale must be 2D, got {layer.weight_scale.ndim}D" + ) + assert layer.weight_scale.dtype == MXFP8_SCALE_DTYPE, ( + f"MXFP8 weight scale must be {MXFP8_SCALE_DTYPE}," + f" got {layer.weight_scale.dtype}" + ) + + self.kernel.process_weights_after_loading(layer) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + return self.kernel.apply_weights(layer, x, bias) + + +class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): + """FlashInfer TRTLLM MXFP8 block-scale MoE for ModelOpt checkpoints.""" + + def __init__( + self, + quant_config: ModelOptMxFp8Config, + moe_config: FusedMoEConfig, + ) -> None: + super().__init__(moe_config) + self.quant_config = quant_config + assert self.quant_config.is_checkpoint_mxfp8_serialized + + self.mxfp8_backend, _ = select_mxfp8_moe_backend(self.moe) + + def create_weights( + self, + layer: RoutedExperts, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + assert layer.intermediate_size_per_partition == intermediate_size_per_partition + assert layer.hidden_size == hidden_size + layer.orig_dtype = params_dtype + + if hidden_size % MXFP8_BLOCK_SIZE != 0: + raise ValueError( + f"MXFP8 MoE requires hidden_size divisible by {MXFP8_BLOCK_SIZE}, " + f"got {hidden_size}." + ) + if intermediate_size_per_partition % MXFP8_BLOCK_SIZE != 0: + raise ValueError( + "MXFP8 MoE requires intermediate_size_per_partition divisible by " + f"{MXFP8_BLOCK_SIZE}, got {intermediate_size_per_partition}." + ) + + layer.num_experts = num_experts + weight_loader = extra_weight_attrs.get("weight_loader") + w13_num_shards = 2 if self.moe.is_act_and_mul else 1 + + # GEMM 1 weights: [E, (2I or I), H] + w13_weight = ModelWeightParameter( + data=torch.empty( + num_experts, + w13_num_shards * intermediate_size_per_partition, + hidden_size, + dtype=MXFP8_VALUE_DTYPE, + ), + input_dim=2, + output_dim=1, + weight_loader=weight_loader, + ) + layer.register_parameter("w13_weight", w13_weight) + + # GEMM 2 weights: [E, H, I] + w2_weight = ModelWeightParameter( + data=torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=MXFP8_VALUE_DTYPE, + ), + input_dim=2, + output_dim=1, + weight_loader=weight_loader, + ) + layer.register_parameter("w2_weight", w2_weight) + + # Per-block (K=32) E8M0 scales. + w13_weight_scale = ModelWeightParameter( + data=torch.empty( + num_experts, + w13_num_shards * intermediate_size_per_partition, + hidden_size // MXFP8_BLOCK_SIZE, + dtype=MXFP8_SCALE_DTYPE, + ), + input_dim=2, + output_dim=1, + weight_loader=weight_loader, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + + w2_weight_scale = ModelWeightParameter( + data=torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // MXFP8_BLOCK_SIZE, + dtype=MXFP8_SCALE_DTYPE, + ), + input_dim=2, + output_dim=1, + weight_loader=weight_loader, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + + # Ensure the generic MoE weight-loader treats these as block scales. + set_weight_attrs( + layer.w13_weight_scale, + {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}, + ) + set_weight_attrs( + layer.w2_weight_scale, + {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}, + ) + + @staticmethod + def _check_weight_dtypes(layer: torch.nn.Module) -> None: + """Validate weight and scale dtypes before processing.""" + expected = { + "w13_weight": MXFP8_VALUE_DTYPE, + "w2_weight": MXFP8_VALUE_DTYPE, + "w13_weight_scale": MXFP8_SCALE_DTYPE, + "w2_weight_scale": MXFP8_SCALE_DTYPE, + } + for name, expected_dtype in expected.items(): + actual = getattr(layer, name).dtype + if actual != expected_dtype: + raise ValueError( + f"Expected {name} dtype {expected_dtype}, got {actual}." + ) + + def _shuffle_weights_for_trtllm(self, layer: torch.nn.Module) -> None: + """Shuffle weights and scales into FlashInfer TRTLLM MXFP8 layout.""" + from flashinfer import ( + reorder_rows_for_gated_act_gemm, + shuffle_matrix_a, + shuffle_matrix_sf_a, + ) + + epilogue_tile_m = 128 + num_experts = layer.w13_weight.shape[0] + is_gated = self.moe.is_act_and_mul + intermediate_size_factor = 2 if is_gated else 1 + + w13_weight = layer.w13_weight.data + w13_scale = layer.w13_weight_scale.data + if is_gated: + # FI TRTLLM gated kernels use W31 ordering. Model checkpoints store + # gated projection as W13, so convert once before shuffling. + w13_weight = swap_w13_to_w31(w13_weight) + w13_scale = swap_w13_to_w31(w13_scale) + + w13_weight_shuffled = [] + w2_weight_shuffled = [] + w13_scale_shuffled = [] + w2_scale_shuffled = [] + for i in range(num_experts): + w13_i = w13_weight[i].reshape( + intermediate_size_factor * layer.intermediate_size_per_partition, -1 + ) + w13_sf_i = w13_scale[i].reshape( + intermediate_size_factor * layer.intermediate_size_per_partition, -1 + ) + if is_gated: + # Reorder rows for gated activation layout expected by TRTLLM. + w13_i = reorder_rows_for_gated_act_gemm(w13_i.clone()) + w13_sf_i = reorder_rows_for_gated_act_gemm(w13_sf_i.clone()) + + w13_shuffled_i = shuffle_matrix_a(w13_i.view(torch.uint8), epilogue_tile_m) + w2_shuffled_i = shuffle_matrix_a( + layer.w2_weight.data[i].view(torch.uint8), epilogue_tile_m + ) + w13_weight_shuffled.append( + w13_shuffled_i.contiguous().view(MXFP8_VALUE_DTYPE) + ) + w2_weight_shuffled.append( + w2_shuffled_i.contiguous().view(MXFP8_VALUE_DTYPE) + ) + w13_sf_shuffled_i = shuffle_matrix_sf_a( + w13_sf_i.view(torch.uint8).reshape( + intermediate_size_factor * layer.intermediate_size_per_partition, + -1, + ), + epilogue_tile_m, + ) + w2_sf_shuffled_i = shuffle_matrix_sf_a( + layer.w2_weight_scale.data[i] + .view(torch.uint8) + .reshape(layer.hidden_size, -1), + epilogue_tile_m, + ) + w13_scale_shuffled.append( + w13_sf_shuffled_i.contiguous().view(MXFP8_SCALE_DTYPE) + ) + w2_scale_shuffled.append( + w2_sf_shuffled_i.contiguous().view(MXFP8_SCALE_DTYPE) + ) + + replace_parameter( + layer, "w13_weight", torch.stack(w13_weight_shuffled).contiguous() + ) + replace_parameter( + layer, "w2_weight", torch.stack(w2_weight_shuffled).contiguous() + ) + replace_parameter( + layer, + "w13_weight_scale", + torch.stack(w13_scale_shuffled).contiguous(), + ) + replace_parameter( + layer, + "w2_weight_scale", + torch.stack(w2_scale_shuffled).contiguous(), + ) + + def process_weights_after_loading(self, layer: RoutedExperts) -> None: + if getattr(layer, "_already_called_process_weights_after_loading", False): + return + + self._check_weight_dtypes(layer) + self._shuffle_weights_for_trtllm(layer) + layer._already_called_process_weights_after_loading = True + + def maybe_make_prepare_finalize( + self, + routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, + ) -> mk.FusedMoEPrepareAndFinalizeModular | None: + raise ValueError( + f"{self.__class__.__name__} uses the new modular kernel initialization " + "logic. This function should not be called." + ) + + def select_gemm_impl( + self, + prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular, + layer: RoutedExperts, + ) -> mk.FusedMoEExpertsModular: + raise ValueError( + f"{self.__class__.__name__} uses the new modular kernel initialization " + "logic. This function should not be called." + ) + + def get_fused_moe_quant_config( + self, layer: RoutedExperts + ) -> FusedMoEQuantConfig | None: + # TRTLLM MXFP8 path is monolithic and does not use modular kernel config. + return None + + @property + def is_monolithic(self) -> bool: + return self.mxfp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM + + def apply_monolithic( + self, + layer: RoutedExperts, + x: torch.Tensor, + router_logits: torch.Tensor, + input_ids: torch.Tensor | None = None, + ) -> torch.Tensor: + from flashinfer.fused_moe.core import ( + ActivationType, + Fp8QuantizationType, + ) + + assert self.mxfp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM + + if layer.eplb_state is not None: + raise NotImplementedError( + "EPLB is not supported for FlashInfer TRTLLM MXFP8 MoE backend." + ) + + supported_activations = [MoEActivation.SILU] + if layer.activation not in supported_activations: + raise NotImplementedError( + "FlashInfer TRTLLM MXFP8 MoE supports only " + f"{supported_activations}, got {layer.activation}." + ) + + # Map vLLM MoEActivation to FlashInfer ActivationType. + activation_map = { + MoEActivation.SILU: ActivationType.Swiglu, + MoEActivation.RELU2_NO_MUL: ActivationType.Relu2, + } + fi_activation_type: ActivationType = activation_map[layer.activation] + + # DeepSeekV3 routing requires float32 logits; others expect bfloat16. + if layer.routing_method_type == RoutingMethodType.DeepSeekV3: + assert router_logits.dtype == torch.float32, ( + "DeepSeekV3 routing requires float32 router_logits, " + f"got {router_logits.dtype}." + ) + else: + router_logits = router_logits.to(torch.bfloat16) + + # Treat 0 as "unset" for compatibility with ungrouped routing configs. + n_group = layer.num_expert_group or None + topk_group = layer.topk_group or None + + hidden_states_mxfp8, hidden_states_scale = mxfp8_e4m3_quantize( + x, + is_sf_swizzled_layout=False, + ) + + kwargs: dict = dict( + routing_logits=router_logits, + routing_bias=layer.e_score_correction_bias, + hidden_states=hidden_states_mxfp8, + hidden_states_scale=hidden_states_scale, + gemm1_weights=layer.w13_weight, + gemm1_weights_scale=layer.w13_weight_scale, + gemm2_weights=layer.w2_weight, + gemm2_weights_scale=layer.w2_weight_scale, + num_experts=layer.global_num_experts, + top_k=layer.top_k, + # Keep Optional semantics: FlashInfer expects None for non-grouped + # routing (e.g. Qwen3 Renormalize), not 0. + n_group=n_group, + topk_group=topk_group, + intermediate_size=layer.intermediate_size_per_partition, + local_expert_offset=layer.ep_rank * layer.local_num_experts, + local_num_experts=layer.local_num_experts, + routed_scaling_factor=layer.routed_scaling_factor, + routing_method_type=layer.routing_method_type, + use_shuffled_weight=True, + weight_layout=0, + fp8_quantization_type=Fp8QuantizationType.MxFp8, + ) + + if fi_activation_type != ActivationType.Swiglu: + raise NotImplementedError( + "FlashInfer TRTLLM MXFP8 MoE supports only Swiglu activation, " + f"got {fi_activation_type}." + ) + + return flashinfer_trtllm_fp8_block_scale_moe(**kwargs) + + def apply( + self, + layer: RoutedExperts, + x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + shared_experts: SharedExperts | None, + shared_experts_input: torch.Tensor | None, + ) -> torch.Tensor: + assert not self.is_monolithic + raise NotImplementedError( + "Non-monolithic MXFP8 MoE path is not yet implemented." + ) + + +# Register the method classes for ModelOptMxFp8Config +ModelOptMxFp8Config.LinearMethodCls = ModelOptMxFp8LinearMethod +ModelOptMxFp8Config.FusedMoEMethodCls = ModelOptMxFp8FusedMoE +ModelOptMxFp8Config.KVCacheMethodCls = ModelOptKVCacheMethod + + +class ModelOptMixedPrecisionConfig(ModelOptQuantConfigBase): + """Config class for ModelOpt MIXED_PRECISION. + + Supports checkpoints where different layers use different quantization + algorithms (e.g., FP8 for dense layers and NVFP4 for MoE experts). + The per-layer algorithm is specified in the ``quantized_layers`` dict + inside ``config.json``'s ``quantization_config`` (preferred) or the + legacy ``hf_quant_config.json``. + """ + + def __init__( + self, + kv_cache_quant_method: str | None, + exclude_modules: list[str], + quantized_layers: dict[str, dict[str, Any]], + fp8_config: ModelOptFp8Config, + nvfp4_config: ModelOptNvFp4Config, + ) -> None: + super().__init__(exclude_modules) + self.kv_cache_quant_method = kv_cache_quant_method + self.quantized_layers = quantized_layers + self.fp8_config = fp8_config + self.nvfp4_config = nvfp4_config + + def get_name(self) -> QuantizationMethods: + return "modelopt_mixed" + + def get_supported_act_dtypes(self) -> list[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 89 + + @classmethod + def override_quantization_method( + cls, hf_quant_cfg, user_quant, hf_config=None + ) -> QuantizationMethods | None: + algo = cls._extract_modelopt_quant_algo(hf_quant_cfg) + if algo is not None and algo == "MIXED_PRECISION": + return "modelopt_mixed" + return None + + @classmethod + def _from_config( + cls, + *, + quant_method: str, + kv_cache_quant_method: str | None, + exclude_modules: list[str], + original_config: dict[str, Any], + group_size: int | None, + **kwargs: Any, + ) -> "ModelOptMixedPrecisionConfig": + if "quantization" in original_config: + quantized_layers = original_config["quantization"].get( + "quantized_layers", {} + ) + else: + quantized_layers = original_config.get("quantized_layers", {}) + + if not quantized_layers: + raise ValueError( + "MIXED_PRECISION quant_algo requires a non-empty " + "'quantized_layers' mapping in the quantization config." + ) + + # Determine group_size from the first NVFP4 entry if not provided. + if group_size is None: + for layer_info in quantized_layers.values(): + if layer_info.get("quant_algo", "").upper() == "NVFP4": + group_size = layer_info.get("group_size", 16) + break + if group_size is None: + group_size = 16 + + fp8_config = ModelOptFp8Config( + quant_method="FP8", + is_checkpoint_fp8_serialized=True, + kv_cache_quant_method=kv_cache_quant_method, + exclude_modules=[], + ) + nvfp4_config = ModelOptNvFp4Config( + is_checkpoint_nvfp4_serialized=True, + kv_cache_quant_algo=kv_cache_quant_method, + exclude_modules=[], + group_size=group_size, + ) + + return cls( + kv_cache_quant_method=kv_cache_quant_method, + exclude_modules=exclude_modules, + quantized_layers=quantized_layers, + fp8_config=fp8_config, + nvfp4_config=nvfp4_config, + ) + + def _resolve_quant_algo(self, prefix: str) -> str | None: + """Look up the quant_algo for a vLLM-side layer prefix. + + Tries three strategies in order: + 1. Direct lookup in ``quantized_layers``. + 2. Packed/fused-layer lookup (unfuse via ``packed_modules_mapping``). + 3. Prefix-based lookup for RoutedExperts (any child key starts with + ``prefix + "."``). + + Returns the upper-cased quant_algo string, or *None* if the prefix + is not found. + """ + # 1. Direct lookup + if prefix in self.quantized_layers: + return self.quantized_layers[prefix]["quant_algo"].upper() + + # 2. Packed / fused layer lookup + proj_name = prefix.rsplit(".", 1)[-1] + if self.packed_modules_mapping and proj_name in self.packed_modules_mapping: + algos: set[str] = set() + base = prefix.rsplit(".", 1)[0] + for shard_name in self.packed_modules_mapping[proj_name]: + shard_prefix = f"{base}.{shard_name}" + if shard_prefix in self.quantized_layers: + algos.add(self.quantized_layers[shard_prefix]["quant_algo"].upper()) + if len(algos) == 1: + return algos.pop() + if len(algos) > 1: + raise ValueError( + f"Mixed quant_algo within fused layer {prefix}: " + f"{algos}. All shards must use the same quantization." + ) + + # 3. Prefix-based lookup (for RoutedExperts / parent modules) + prefix_dot = prefix + "." + for key, info in self.quantized_layers.items(): + if key.startswith(prefix_dot): + return info["quant_algo"].upper() + + return None + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> "QuantizeMethodBase | None": + """Return quantize-method based on layer.""" + # KV-cache quantization + if isinstance(layer, Attention): + if self.kv_cache_quant_method: + return ModelOptKVCacheMethod(self) + return None + + # Excluded layers + if self.is_layer_excluded(prefix): + if isinstance(layer, LinearBase): + return UnquantizedLinearMethod() + return None + + quant_algo = self._resolve_quant_algo(prefix) + + if isinstance(layer, LinearBase): + if quant_algo == "FP8": + return ModelOptFp8LinearMethod(self.fp8_config) + if quant_algo == "NVFP4": + return ModelOptNvFp4LinearMethod(self.nvfp4_config) + # Layer not in quantized_layers — leave unquantized + return UnquantizedLinearMethod() + + if isinstance(layer, RoutedExperts): + if quant_algo == "FP8": + return ModelOptFp8MoEMethod( + quant_config=self.fp8_config, + moe_config=layer.moe_config, + ) + if quant_algo == "NVFP4": + return ModelOptNvFp4FusedMoE( + quant_config=self.nvfp4_config, + moe_config=layer.moe_config, + ) + return None + + return None + + def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): + super().apply_vllm_mapper(hf_to_vllm_mapper) + if self.quantized_layers: + self.quantized_layers = hf_to_vllm_mapper.apply_dict(self.quantized_layers)