Proper NVFP4 integration: use ModelOptNvFp4Config + FusedMoE framework

Major refactor to eliminate all post-load hacks:
- deepseek_v4.py: use upstream model with NVFP4 weight mapper only
  (gate_proj→w1, up_proj→w3, down_proj→w2, .self_attn→.attn, .mlp→.ffn)
- Add CuTeDSLMoEExperts as a FusedMoEExpertsModular subclass
  that wraps our CuTeDSL runner as a proper vLLM MoE backend
- Register CUTEDSL backend in the NVFP4 oracle
- Use ModelOptNvFp4Config for quantization dispatch (not DeepseekV4FP8Config)
- ModelOptNvFp4LinearMethod handles NVFP4 attention/shared expert projections
- Remove nvfp4_cutedsl.py, cutedsl_quant_method.py, utils.py from Dockerfile
- CuTeDSL runner moved to cutedsl/runner.py for clean imports
- cos_sin_cache float32 fix in deepseek_v4_attention.py

No more monkey-patching, no _convert_nvfp4_post_load, no CuTeDSLNvfp4Method.
This commit is contained in:
2026-05-18 22:19:23 +00:00
parent 48386e34ad
commit a7ed8faec6
7 changed files with 4427 additions and 1320 deletions

View File

@@ -30,17 +30,21 @@ ENV PYTHONPATH="/root/nvfp4-megamoe-kernel:${PYTHONPATH}"
# Patch vLLM — overwrite model files and register architecture
ARG VLLM_MODELS_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models
ARG VLLM_LAYERS_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers
ARG VLLM_QUANT_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/quantization
ARG VLLM_FUSED_MOE_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe
ARG VLLM_LOADER_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/model_loader
# Core model patches
COPY vllm/patches/deepseek_v4.py ${VLLM_MODELS_DIR}/deepseek_v4.py
COPY vllm/patches/deepseek_v4_attention.py ${VLLM_LAYERS_DIR}/deepseek_v4_attention.py
COPY vllm/nvfp4_cutedsl.py ${VLLM_MODELS_DIR}/nvfp4_cutedsl.py
COPY vllm/cutedsl_quant_method.py ${VLLM_MODELS_DIR}/cutedsl_quant_method.py
COPY cutedsl/nvfp4_linear.py /root/nvfp4-megamoe-kernel/cutedsl/nvfp4_linear.py
COPY cutedsl/shared_expert_pipeline.py /root/nvfp4-megamoe-kernel/cutedsl/shared_expert_pipeline.py
COPY vllm/patches/utils.py ${VLLM_LOADER_DIR}/utils.py
RUN sed -i 's/"DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),/"DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),\n "DeepseekV4ForCausalLM": ("deepseek_v4", "DeepseekV4ForCausalLM"),/' \
# NVFP4 MoE backend registration
COPY vllm/patches/fused_moe/oracle/nvfp4.py ${VLLM_FUSED_MOE_DIR}/oracle/nvfp4.py
COPY vllm/patches/fused_moe/experts/cutedsl_moe.py ${VLLM_FUSED_MOE_DIR}/experts/cutedsl_moe.py
# Register DeepseekV4ForCausalLM model architecture (if not already in upstream)
RUN grep -q '"DeepseekV4ForCausalLM"' ${VLLM_MODELS_DIR}/registry.py || \
sed -i 's/"DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),/"DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),\n "DeepseekV4ForCausalLM": ("deepseek_v4", "DeepseekV4ForCausalLM"),/' \
${VLLM_MODELS_DIR}/registry.py
# Verify

529
cutedsl/runner.py Normal file
View File

@@ -0,0 +1,529 @@
"""
vLLM integration for the CuTeDSL NVFP4 MoE kernel.
CUDA-graph-compatible design:
- All intermediate buffers pre-allocated at max_num_tokens * top_k size
- No .item(), .tolist(), .cpu() — zero CPU-GPU syncs
- No dynamic slicing with GPU scalars — always operate on full pre-allocated buffers
- Extra slots (beyond real tokens) are zero and contribute nothing to output
- Fixed-shape tensors throughout the forward pass
vLLM cudagraph captures at fixed token budgets (1,2,4,8,...,8192).
During capture, num_tokens equals the budget — all shapes are fixed.
During replay, inputs are padded to the budget size. Our runner always
processes max_slots = budget * top_k rows; padding rows are zeros.
"""
import torch
from cutedsl.bridge import (
    quantize_activation_nvfp4,
    quantize_weight_to_nvfp4,
    quantize_to_nvfp4,
    make_b_k_major,
    assemble_scales_3d_side,
    run_nvfp4_grouped_gemm,
)
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
    ceil_div as cutedsl_ceil_div,
    pad_and_swizzle_single,
)


class _MoEApply(torch.autograd.Function):
    """Custom autograd function to make CuTeDSL MoE runner opaque to torch.compile.

    Only ``forward`` is defined here; no ``backward`` is provided, so this
    wrapper is usable for inference only.
    """

    @staticmethod
    def forward(ctx, runner, hidden_states, topk_weights, topk_ids, expert_indices):
        # Delegate straight to the runner; ``ctx`` is not used because nothing
        # is saved for a backward pass.
        return runner._run_impl(hidden_states, topk_weights, topk_ids, expert_indices)
class CuTeDSLMoERunner:
"""Manages NVFP4 MoE execution via the CuTeDSL kernel.
CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs,
no dynamic shapes. Always computes at max_num_tokens * top_k capacity.
"""
def __init__(self, num_experts, hidden_size, intermediate_size,
             max_num_tokens=8192, top_k=8, device="cuda",
             experts_start_idx=0):
    """Record the layer geometry and create empty placeholders.

    No device memory is allocated here. Weights are supplied later through
    ``prepare_weights_direct`` / ``prepare_weights_from_dequantized`` and
    the buffers are created by ``_ensure_stacked`` / ``_allocate_buffers``.

    Args:
        num_experts: Number of experts handled by this runner.
        hidden_size: Model hidden dimension.
        intermediate_size: Expert intermediate dimension.
        max_num_tokens: Largest token count the buffers are sized for.
        top_k: Experts selected per token.
        device: Device on which buffers are allocated.
        experts_start_idx: Offset subtracted from incoming expert ids in
            ``_run_impl`` to obtain local ids.
    """
    # Geometry / placement.
    self.num_experts = num_experts
    self.hidden_size = hidden_size
    self.intermediate_size = intermediate_size
    self.max_num_tokens = max_num_tokens
    self.top_k = top_k
    self.device = device
    self.experts_start_idx = experts_start_idx
    self._swiglu_limit = None  # Set via set_swiglu_limit()

    # Per-expert weight storage (consumed and cleared by _ensure_stacked).
    self.l1_fp4 = self.l1_sf = self.l1_gs = None
    self.l2_fp4 = self.l2_sf = self.l2_gs = None

    # Stacked weight tensors (built in _ensure_stacked).
    self._l1_mat_b = self._l2_mat_b = None
    self._l1_scale_b = self._l2_scale_b = None
    self._l1_gsb = self._l2_gsb = None

    # Default activation global scale: 1 / (6 * 448) = 1/2688.
    # compute_activation_global_scales() overwrites both values.
    default_activation_gs = 1.0 / (6.0 * 448.0)
    self._l1_activation_global_scale = default_activation_gs
    self._l2_activation_global_scale = default_activation_gs

    # Pre-allocated buffers; all stay None until _ensure_stacked /
    # _allocate_buffers run.
    self._token_indices = None
    self._expert_id_range = None
    self._expert_offsets_buf = None
    self._per_expert_scale_bufs_l1 = self._per_expert_scale_bufs_l2 = None
    self._padded_x_sf_buf_l1 = self._padded_x_sf_buf_l2 = None
    self._l1_gsa_buf = self._l2_gsa_buf = None
    self._output_buf = None
    self._row_indices_buf = None
    self._padded_hidden_buf = None
    self._padded_activated_buf = None  # unused, using shared
    self._padded_expert_offsets_buf = None

    # Number of 128-row chunks reserved per expert.
    # NOTE(review): this divides the total slot count evenly across experts,
    # while _run_impl pads each expert's *actual* count up to 128 rows; with
    # heavily skewed routing the padded total may exceed the rows allocated
    # from this value - confirm the capacity assumption.
    total_slots = self.max_num_tokens * self.top_k
    self._max_chunks_per_expert = cutedsl_ceil_div(
        total_slots, self.num_experts * 128
    )
    self._buffers_allocated = False
def set_swiglu_limit(self, limit: float | None):
"""Set the swiglu_limit for activation clamping."""
self._swiglu_limit = limit
def _fill_token_indices(self):
"""Fill _token_indices with [0,0,..0, 1,1,..1, ...] (each token repeated top_k times).
Builds on CPU first, then copies to GPU, to ensure correctness
regardless of CuTeDSL JIT GPU memory corruption.
"""
src = torch.arange(self.max_num_tokens, dtype=torch.int32)
cpu_indices = src.unsqueeze(1).expand(-1, self.top_k).contiguous().view(-1)
self._token_indices.copy_(cpu_indices)
def _allocate_buffers(self):
    """Allocate every fixed-size buffer the forward pass reuses.

    Per-runner buffers are stored on ``self``. The large padded activation,
    scale and output buffers live in the class-level dict
    ``CuTeDSLMoERunner._shared_padded_bufs`` keyed by ``str(self.device)``,
    so they are shared by every runner on that device.
    NOTE(review): the shared entries are sized by whichever runner
    allocates first; runners with different shapes on the same device
    would reuse them unchanged - confirm all runners share one geometry.
    """
    dev = self.device
    fp8 = torch.float8_e4m3fn

    def fp8_zeros(rows, cols):
        # Zeros are created as float16 and then converted to FP8.
        return torch.zeros(rows, cols, dtype=torch.float16, device=dev).to(fp8)

    def fp4_zeros(rows, packed_cols):
        # uint8 storage reinterpreted as packed FP4 pairs.
        return torch.zeros(
            rows, packed_cols, dtype=torch.uint8, device=dev
        ).view(torch.float4_e2m1fn_x2)

    # Scale-factor column counts: one per 16 elements, rounded up to a
    # multiple of 4. L1 and L2 differ because their K dimension differs.
    sf_cols_l1 = cutedsl_ceil_div(cutedsl_ceil_div(self.hidden_size, 16), 4) * 4
    sf_cols_l2 = cutedsl_ceil_div(cutedsl_ceil_div(self.intermediate_size, 16), 4) * 4

    self._per_expert_scale_bufs_l1 = [
        fp8_zeros(128, sf_cols_l1) for _ in range(self.num_experts)
    ]
    self._per_expert_scale_bufs_l2 = [
        fp8_zeros(128, sf_cols_l2) for _ in range(self.num_experts)
    ]

    # Class-level registry of shared buffers, one sub-dict per device.
    if not hasattr(CuTeDSLMoERunner, '_shared_padded_bufs'):
        CuTeDSLMoERunner._shared_padded_bufs = {}
    shared = CuTeDSLMoERunner._shared_padded_bufs.setdefault(str(dev), {})

    rows_per_expert = self._max_chunks_per_expert * 128
    padded_rows = self.num_experts * rows_per_expert

    # Padded activation-scale buffers and the output buffer (shared).
    if 'xsf_l1' not in shared:
        shared.update({
            'xsf_l1': fp8_zeros(padded_rows, sf_cols_l1),
            'xsf_l2': fp8_zeros(padded_rows, sf_cols_l2),
            'output': torch.zeros(
                self.max_num_tokens, self.hidden_size,
                dtype=torch.bfloat16, device=dev,
            ),
        })
    self._padded_x_sf_buf_l1 = shared['xsf_l1']
    self._padded_x_sf_buf_l2 = shared['xsf_l2']
    self._output_buf = shared['output']

    # Per-expert activation global scales; filled with .fill_() at run time.
    self._l1_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=dev)
    self._l2_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=dev)

    # 0..max_num_tokens*top_k-1, sliced at run time for scale assembly.
    self._row_indices_buf = torch.arange(
        self.max_num_tokens * self.top_k, device=dev
    )

    # Padded hidden / activated buffers in bf16 and packed FP4 (shared).
    if 'hidden' not in shared:
        shared.update({
            'hidden': torch.zeros(
                padded_rows, self.hidden_size, dtype=torch.bfloat16, device=dev
            ),
            'hidden_fp4': fp4_zeros(padded_rows, self.hidden_size // 2),
            'activated': torch.zeros(
                padded_rows, self.intermediate_size, dtype=torch.bfloat16, device=dev
            ),
            'activated_fp4': fp4_zeros(padded_rows, self.intermediate_size // 2),
        })
    self._shared_bufs = shared

    # Padded expert offsets, initialised to [0, r, 2r, ...] with
    # r = rows_per_expert. The forward pass zeroes and rewrites this buffer.
    self._padded_expert_offsets_buf = torch.zeros(
        self.num_experts + 1, dtype=torch.int32, device=dev
    )
    self._padded_expert_offsets_buf[1:] = torch.arange(
        1, self.num_experts + 1, dtype=torch.int32, device=dev
    ) * rows_per_expert

    self._buffers_allocated = True
def _ensure_stacked(self):
if self._l1_mat_b is not None:
return
# Stack and prepare weight tensors FIRST (triggers CuTeDSL JIT compilation)
self._l1_mat_b = make_b_k_major(torch.stack(self.l1_fp4))
self._l2_mat_b = make_b_k_major(torch.stack(self.l2_fp4))
self._l1_scale_b = assemble_scales_3d_side(self.l1_sf)
self._l2_scale_b = assemble_scales_3d_side(self.l2_sf)
self._l1_gsb = torch.tensor(self.l1_gs, dtype=torch.float32, device=self.device)
self._l2_gsb = torch.tensor(self.l2_gs, dtype=torch.float32, device=self.device)
self.l1_fp4 = None
self.l1_sf = None
self.l1_gs = None
self.l2_fp4 = None
self.l2_sf = None
self.l2_gs = None
# Allocate buffers AFTER JIT compilation
# (CuTeDSL's cute.compile corrupts GPU memory during JIT;
# tensors allocated before/during compilation may be zeroed)
#
# _token_indices: GPU tensor for cudagraph compatibility.
# CuTeDSL JIT may corrupt GPU memory, so we fill AFTER stacking
# (which triggers the weight JIT). The GEMM JIT in run_nvfp4_grouped_gemm
# is triggered on the first run() call; we refill _token_indices after
# that first call via the _needs_token_refill flag.
self._token_indices = torch.zeros(
self.max_num_tokens * self.top_k, dtype=torch.int32, device=self.device
)
self._fill_token_indices()
self._needs_token_refill = True # GEMM JIT may corrupt; refill after first run
self._expert_id_range = torch.arange(
self.num_experts, dtype=torch.int32
).to(self.device)
self._expert_offsets_buf = torch.zeros(
self.num_experts + 1, dtype=torch.int32, device=self.device
)
self._allocate_buffers()
def prepare_weights_direct(self, l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs):
self.l1_fp4 = l1_fp4
self.l1_sf = l1_sf
self.l1_gs = l1_gs
self.l2_fp4 = l2_fp4
self.l2_sf = l2_sf
self.l2_gs = l2_gs
self._l1_mat_b = None
def prepare_weights_from_dequantized(self, l1_weights_bf16, l2_weights_bf16):
self.l1_fp4, self.l1_sf, self.l1_gs = [], [], []
self.l2_fp4, self.l2_sf, self.l2_gs = [], [], []
for l1_w, l2_w in zip(l1_weights_bf16, l2_weights_bf16):
l1_w_t = l1_w.T
w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(l1_w_t)
self.l1_fp4.append(w_fp4)
self.l1_sf.append(w_sf)
self.l1_gs.append(w_gs)
l2_w_t = l2_w.T
w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(l2_w_t)
self.l2_fp4.append(w_fp4)
self.l2_sf.append(w_sf)
self.l2_gs.append(w_gs)
self._l1_mat_b = None
def _assemble_scales_cudagraph_safe(self, x_sf, expert_offsets,
padded_expert_offsets,
padded_x_sf_buf, per_expert_bufs):
"""Assemble 2D-side activation scales (cudagraph-safe, NO CPU syncs).
Phase 1: Scatter x_sf into padded per-expert sections (GPU-only).
Phase 2: Apply full-buffer Blackwell 32_4_4 swizzle (no Python loops).
The buffer is 128-row aligned per expert (from padded_expert_offsets),
so the full-buffer swizzle produces the correct layout. The GEMM reads
scale_a using padded_expert_offsets, matching the scatter layout.
"""
K_sf = x_sf.shape[1]
padded_x_sf = padded_x_sf_buf
padded_x_sf.zero_()
# Phase 1: Scatter x_sf into padded per-expert sections (GPU-only)
total_rows = x_sf.shape[0]
row_indices = self._row_indices_buf[:total_rows]
expert_assign = torch.searchsorted(
expert_offsets[1:], row_indices, right=True
).clamp(max=self.num_experts - 1)
local_row = row_indices - expert_offsets[expert_assign]
dst_rows = padded_expert_offsets[expert_assign] + local_row
padded_x_sf[dst_rows, :K_sf] = x_sf
# Phase 2: Full-buffer swizzle (no CPU sync, no Python loops)
# padded_x_sf is 128-row aligned per expert and 4-col aligned.
# to_blocked: (rows, cols) → view(R, 128, C, 4) → permute(0,2,1,3)
# → reshape(-1, 4, 32, 4) → transpose(1,2) → reshape(-1, 32, 16) → flatten
rows = padded_x_sf.shape[0]
cols = padded_x_sf.shape[1]
R = rows // 128
C = cols // 4
blocks = padded_x_sf.view(R, 128, C, 4).permute(0, 2, 1, 3)
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
swizzled = rearranged.flatten().view(torch.float8_e4m3fn)
return swizzled.reshape(rows, cols)
def compute_activation_global_scales(self, hidden_states_sample, topk_weights, topk_ids):
    """Calibrate the L1/L2 activation global scales from a sample batch.

    Follows the same routing, padding and L1 GEMM steps as ``_run_impl`` so
    the L2 scale is derived from real L1 output. The results are stored in
    ``_l1_activation_global_scale`` and ``_l2_activation_global_scale``.
    Intended to run before CUDA-graph capture (it allocates with
    ``torch.full``).

    Expert ids are remapped to local ids and clamped exactly as
    ``_run_impl`` does. Without that, ids outside ``[0, num_experts)``
    (any expert-parallel rank) produced per-expert counts that did not sum
    to the slot count, and therefore inconsistent scatter destinations.

    Args:
        hidden_states_sample: Sample activations, one row per token.
        topk_weights: Unused; kept for signature parity with ``run``.
        topk_ids: Expert ids per token, in the same numbering ``run``
            receives.
    """
    self._ensure_stacked()
    device = hidden_states_sample.device
    num_tokens = hidden_states_sample.shape[0]
    top_k = topk_ids.shape[1]
    with torch.no_grad():
        # Remap global -> local expert ids (same as _run_impl).
        local_ids = topk_ids - self.experts_start_idx
        safe_ids = local_ids.clamp(0, self.num_experts - 1)

        # Build slot mapping (same as _run_impl).
        flat_ids = safe_ids.reshape(-1)
        num_slots = num_tokens * top_k
        token_indices = self._token_indices[:num_slots]
        sort_idx = flat_ids.argsort(stable=True)
        sorted_ids = flat_ids[sort_idx]
        sorted_token_ids = token_indices[sort_idx]
        slot_hidden = hidden_states_sample[sorted_token_ids]

        # L1: take the global scale reported by quantize_to_nvfp4.
        _, _, l1_gs = quantize_to_nvfp4(slot_hidden)
        # Quantize slot_hidden for the GEMM with that scale.
        slot_x_fp4, slot_x_sf = quantize_activation_nvfp4(slot_hidden, l1_gs)

        # Real and 128-row-padded cumulative per-expert offsets.
        expert_id_range = self._expert_id_range
        tokens_per_expert = (
            sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)
        ).sum(dim=0).int()
        expert_offsets = self._expert_offsets_buf
        expert_offsets.zero_()
        expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0)
        padded_tokens_per_expert = ((tokens_per_expert + 127) // 128) * 128
        padded_expert_offsets = self._padded_expert_offsets_buf
        padded_expert_offsets.zero_()
        padded_expert_offsets[1:self.num_experts + 1] = padded_tokens_per_expert.cumsum(0)

        # Destination row of every slot in the padded layout.
        row_indices = self._row_indices_buf[:num_slots]
        expert_assign = torch.searchsorted(
            expert_offsets[1:], row_indices, right=True
        ).clamp(max=self.num_experts - 1)
        local_row = row_indices - expert_offsets[expert_assign]
        padded_dst = padded_expert_offsets[expert_assign] + local_row

        # Scatter packed FP4 rows into the padded buffer (as uint8).
        padded_x_fp4 = self._shared_bufs['hidden_fp4']
        padded_x_fp4.view(torch.uint8).zero_()
        padded_x_fp4.view(torch.uint8)[padded_dst] = slot_x_fp4.view(torch.uint8)

        l1_scale_a = self._assemble_scales_cudagraph_safe(
            slot_x_sf, expert_offsets[:self.num_experts + 1],
            padded_expert_offsets,
            self._padded_x_sf_buf_l1, self._per_expert_scale_bufs_l1
        )
        l1_gsa = torch.full((self.num_experts,), l1_gs, dtype=torch.float32, device=device)
        l1_out = run_nvfp4_grouped_gemm(
            mat_a=padded_x_fp4, mat_b=self._l1_mat_b,
            scale_a=l1_scale_a, scale_b=self._l1_scale_b,
            expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
            global_scale_a=l1_gsa, global_scale_b=self._l1_gsb,
        )
        # Keep only the rows that correspond to real slots.
        l1_out_real = l1_out[padded_dst]

        # L2: global scale of SiLU(gate) * up, with the optional clamp.
        gate = l1_out_real[:, :self.intermediate_size]
        up = l1_out_real[:, self.intermediate_size:]
        gate_silu = torch.nn.functional.silu(gate)
        if self._swiglu_limit is not None:
            gate_silu = gate_silu.clamp(max=self._swiglu_limit)
            up = up.clamp(min=-self._swiglu_limit, max=self._swiglu_limit)
        activated = gate_silu * up
        _, _, l2_gs = quantize_to_nvfp4(activated)

    self._l1_activation_global_scale = l1_gs
    self._l2_activation_global_scale = l2_gs
def run(self, hidden_states, topk_weights, topk_ids, expert_indices=None):
    """Public forward entry point.

    Dispatches through ``_MoEApply`` with this runner as the first
    argument and returns whatever that call returns.
    """
    forward_args = (self, hidden_states, topk_weights, topk_ids, expert_indices)
    return _MoEApply.apply(*forward_args)
def _run_impl(self, hidden_states, topk_weights, topk_ids, expert_indices=None):
"""Run the NVFP4 MoE forward pass.
Handles global→local expert ID remapping for expert parallelism.
Fully cudagraph-safe: no CPU-GPU syncs, no dynamic shapes.
Each expert's slots are padded to multiples of 128 for the GEMM.
expert_offsets is [0, padded_e0, padded_e0+padded_e1, ...].
scale_a is produced at those same offsets.
"""
num_tokens = hidden_states.shape[0]
top_k = topk_ids.shape[1]
device = hidden_states.device
self._ensure_stacked()
# -- Remap global expert IDs to local IDs --
local_ids = topk_ids - self.experts_start_idx
local_mask = (local_ids >= 0) & (local_ids < self.num_experts)
safe_ids = local_ids.clamp(0, self.num_experts - 1)
safe_weights = topk_weights * local_mask.float()
# -- Build slot mapping --
flat_ids = safe_ids.reshape(-1)
flat_weights = safe_weights.reshape(-1)
num_slots = num_tokens * top_k
token_indices = self._token_indices[:num_slots]
sort_idx = flat_ids.argsort(stable=True)
sorted_ids = flat_ids[sort_idx]
sorted_weights = flat_weights[sort_idx]
sorted_token_ids = token_indices[sort_idx]
# Expert offsets (real token counts)
expert_id_range = self._expert_id_range
tokens_per_expert = (sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)).sum(dim=0).int()
expert_offsets = self._expert_offsets_buf
expert_offsets.zero_()
expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0)
# Pad each expert to 128-row alignment (GPU-only computation)
padded_tokens_per_expert = ((tokens_per_expert + 127) // 128) * 128
padded_expert_offsets = self._padded_expert_offsets_buf
padded_expert_offsets.zero_()
padded_expert_offsets[1:self.num_experts + 1] = padded_tokens_per_expert.cumsum(0)
total_padded_slots = padded_expert_offsets[self.num_experts]
# -- Gather hidden states into slot order, compute padded_dst --
slot_hidden = hidden_states[sorted_token_ids]
row_indices = self._row_indices_buf[:num_slots]
expert_assign = torch.searchsorted(
expert_offsets[1:], row_indices, right=True
).clamp(max=self.num_experts - 1)
local_row = row_indices - expert_offsets[expert_assign]
padded_dst = padded_expert_offsets[expert_assign] + local_row
# === L1: gate + up ===
# Quantize slot_hidden (sorted tokens), NOT padded_hidden.
# padded_hidden is padded with zeros; quantizing it produces
# x_sf rows at padded positions, but x_sf[:num_slots] would
# only get scales for the first num_slots PADDED rows (expert 0),
# not the scattered token positions. Quantizing slot_hidden
# gives x_sf with num_slots rows (one per token), which the
# scale assembly correctly scatters into padded layout.
slot_x_fp4, slot_x_sf = quantize_activation_nvfp4(
slot_hidden, self._l1_activation_global_scale
)
# Scatter x_fp4 into padded layout for the GEMM
# Must scatter as uint8 (float4_e2m1fn_x2 doesn't support index_put)
padded_x_fp4 = self._shared_bufs['hidden_fp4']
padded_x_fp4.view(torch.uint8).zero_()
padded_x_fp4.view(torch.uint8)[padded_dst] = slot_x_fp4.view(torch.uint8)
l1_scale_a = self._assemble_scales_cudagraph_safe(
slot_x_sf, expert_offsets[:self.num_experts + 1],
padded_expert_offsets,
self._padded_x_sf_buf_l1, self._per_expert_scale_bufs_l1
)
l1_gsa = self._l1_gsa_buf.fill_(self._l1_activation_global_scale)
l1_out = run_nvfp4_grouped_gemm(
mat_a=padded_x_fp4, mat_b=self._l1_mat_b,
scale_a=l1_scale_a, scale_b=self._l1_scale_b,
expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
global_scale_a=l1_gsa, global_scale_b=self._l1_gsb,
)
# Extract real token outputs from padded GEMM output
l1_out_real = l1_out[padded_dst]
# === SiLU(gate) * up (with swiglu_limit clamp) ===
gate = l1_out_real[:, :self.intermediate_size]
up = l1_out_real[:, self.intermediate_size:]
gate_silu = torch.nn.functional.silu(gate)
# Apply DeepSeek-V4 swiglu_limit: clamp both silu(gate) and up
if self._swiglu_limit is not None:
gate_silu = gate_silu.clamp(max=self._swiglu_limit)
up = up.clamp(min=-self._swiglu_limit, max=self._swiglu_limit)
activated = gate_silu * up
# === L2: down ===
# Quantize activated (per-token), scatter into padded FP4 buffer
slot_l2_x_fp4, slot_l2_x_sf = quantize_activation_nvfp4(
activated, self._l2_activation_global_scale
)
padded_activated_fp4 = self._shared_bufs['activated_fp4']
padded_activated_fp4.view(torch.uint8).zero_()
padded_activated_fp4.view(torch.uint8)[padded_dst] = slot_l2_x_fp4.view(torch.uint8)
l2_scale_a = self._assemble_scales_cudagraph_safe(
slot_l2_x_sf, expert_offsets[:self.num_experts + 1],
padded_expert_offsets,
self._padded_x_sf_buf_l2, self._per_expert_scale_bufs_l2
)
l2_gsa = self._l2_gsa_buf.fill_(self._l2_activation_global_scale)
l2_out = run_nvfp4_grouped_gemm(
mat_a=padded_activated_fp4, mat_b=self._l2_mat_b,
scale_a=l2_scale_a, scale_b=self._l2_scale_b,
expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
global_scale_a=l2_gsa, global_scale_b=self._l2_gsb,
)
l2_out_real = l2_out[padded_dst]
# === Scatter -> final output ===
y = self._output_buf[:num_tokens]
y.zero_()
weighted_out = l2_out_real * sorted_weights.unsqueeze(1).to(l2_out_real.dtype)
y.scatter_add_(
0,
sorted_token_ids.unsqueeze(1).expand(-1, self.hidden_size),
weighted_out,
)
# Refill _token_indices after GEMM JIT on first call
# (CuTeDSL's cute.compile may corrupt GPU memory during first GEMM)
if self._needs_token_refill:
self._fill_token_indices()
self._needs_token_refill = False
return y

File diff suppressed because it is too large Load Diff

View File

@@ -14,6 +14,7 @@ import torch.nn.functional as F
from transformers import DeepseekV2Config, DeepseekV3Config
import vllm.envs as envs
from vllm.compilation.breakable_cudagraph import eager_break_during_capture
from vllm.model_executor.layers.linear import (
ReplicatedLinear,
)
@@ -28,6 +29,7 @@ from vllm.v1.attention.ops.deepseek_v4_ops import (
fused_inv_rope_fp8_quant,
fused_q_kv_rmsnorm,
)
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import rocm_inv_rope_einsum
if TYPE_CHECKING:
from vllm.v1.attention.backends.mla.sparse_swa import (
@@ -45,7 +47,7 @@ from vllm.logger import init_logger
from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.deepseek_compressor import DeepseekCompressor
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.input_quant_fp8 import (
QuantFP8,
@@ -53,6 +55,7 @@ from vllm.model_executor.layers.quantization.input_quant_fp8 import (
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
)
from vllm.platforms import current_platform
from vllm.utils.multi_stream_utils import (
execute_in_parallel,
maybe_execute_in_parallel,
@@ -198,8 +201,6 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
# Pick fp8_einsum recipe based on GPU arch:
# SM90: FP32 block scales stay [g, r/128, d/128] → sfb_gran_mn=128
# SM100: INT32 packed scales become [g, r, ...] → sfb_gran_mn=1
from vllm.platforms import current_platform
cap = current_platform.get_device_capability()
assert cap is not None, "DeepseekV4 attention requires a CUDA device"
self._einsum_recipe = (1, 128, 128) if cap.major <= 9 else (1, 1, 128)
@@ -222,6 +223,7 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
+ 1 # 1B pad
)
# Will be None on ROCm for now.
self.aux_stream_list = mla_modules.aux_stream_list
# [0]: GEMM start / post-GEMM event0. [1..3]: GEMM done events;
# [1] doubles as post-GEMM event1. Reuse is safe: GEMM fully joins
@@ -303,6 +305,19 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
)
o = o_padded[:, : self.n_local_heads, :]
# Keep ROCm on the BF16 reference wo_a path until the kernel is ready.
if current_platform.is_rocm():
z = rocm_inv_rope_einsum(
self.rotary_emb,
o,
positions,
self.rope_head_dim,
self.n_local_groups,
self.o_lora_rank,
self.wo_a,
)
return self.wo_b(z.flatten(1))
# O projection: inverse RoPE + FP8 quant + einsum + wo_b
o_fp8, o_scale = fused_inv_rope_fp8_quant(
o,
@@ -336,12 +351,15 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
return self.wo_b(z.flatten(1))
def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]:
assert self.aux_stream_list is not None
assert len(self.aux_stream_list) >= 3
aux_streams = self.aux_stream_list
if aux_streams is not None:
assert len(aux_streams) >= 3
aux_streams = aux_streams[:3]
# fused_wqa_wkv (heaviest) on default; the three lighter input GEMMs
# on aux streams 0..2 when their owning module exists. ln_events[0]
# is the fan-out start event; ln_events[1..3] are per-aux done events.
# On ROCm, aux_streams is None and execute_in_parallel runs serially.
aux_fns: list[Callable[[], Any] | None] = [None, None, None]
if self.compressor is not None:
@@ -385,7 +403,7 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
aux_fns,
self.ln_events[0],
self.ln_events[1:4],
self.aux_stream_list[:3],
aux_streams,
enable=hidden_states.shape[0]
<= envs.VLLM_MULTI_STREAM_GEMM_TOKEN_THRESHOLD,
)
@@ -419,8 +437,9 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
# downstream reads q on default). Indexer/compressor go on aux for
# overlap with default's GEMM + cache write.
if self.indexer is not None:
assert self.aux_stream_list is not None
aux_stream = self.aux_stream_list[0]
aux_stream = (
self.aux_stream_list[0] if self.aux_stream_list is not None else None
)
indexer = self.indexer
# Local ref so the closure keeps a non-None type for mypy.
assert self.compressor is not None
@@ -448,8 +467,9 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
)
elif self.compressor is not None:
# wq_b + kv_insert on default, compressor on aux.
assert self.aux_stream_list is not None
aux_stream = self.aux_stream_list[0]
aux_stream = (
self.aux_stream_list[0] if self.aux_stream_list is not None else None
)
compressor = self.compressor
def wq_b_kv_insert() -> torch.Tensor:
@@ -534,6 +554,7 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
)
@eager_break_during_capture
def deepseek_v4_attention(
hidden_states: torch.Tensor,
positions: torch.Tensor,
@@ -668,7 +689,7 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
vllm_config.scheduler_config.max_num_batched_tokens
)
self.max_model_len = vllm_config.model_config.max_model_len
# DeepseekV4 only supports fp8 kv-cache format for now
# DeepseekV4 only supports fp8 kv-cache format for now.
kv_cache_dtype = cache_config.cache_dtype if cache_config is not None else "fp8"
assert kv_cache_dtype.startswith("fp8"), (
@@ -702,6 +723,12 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
self.kv_cache = torch.tensor([])
def get_attn_backend(self) -> type[AttentionBackend]:
if current_platform.is_rocm():
from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse_dsv4 import (
DeepseekV4ROCMAiterMLASparseBackend,
)
return DeepseekV4ROCMAiterMLASparseBackend
return DeepseekV4FlashMLASparseBackend
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
@@ -734,6 +761,14 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
f"output buffer dtype {output.dtype} must match q dtype {q.dtype}"
)
if current_platform.is_rocm():
from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse_dsv4 import (
DeepseekV4ROCMAiterMLASparseImpl,
)
DeepseekV4ROCMAiterMLASparseImpl.forward(self, q, kv, positions, output)
return
# Get SWA and indexer metadata from forward context
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
@@ -979,8 +1014,7 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
M,
N,
)
output_chunk, _, _ = flash_mla_sparse_fwd(
flash_mla_sparse_fwd(
q=q[query_start:query_end],
kv=kv.view(-1, 1, q.shape[-1]),
indices=combined_indices.unsqueeze(1),
@@ -1077,7 +1111,6 @@ class DeepseekV4Indexer(nn.Module):
quant_config=None,
prefix=f"{prefix}.weights_proj",
)
self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
self.softmax_scale = self.head_dim**-0.5
self.scale_fmt = "ue8m0"

View File

@@ -0,0 +1,303 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""CuTeDSL NVFP4 MoE experts for DeepSeek-V4.
Integrates the CuTeDSL NVFP4 grouped GEMM kernel into vLLM's FusedMoE
modular kernel framework. This is the proper integration path — no
monkey-patching, no post-load surgery.
The CuTeDSL kernel is a Python-based CUTLASS kernel compiled via MLIR → PTX.
It handles:
- L1 GEMM (gate + up projections)
- SiLU activation with optional swiglu_limit clamping
- L2 GEMM (down projection)
- All with NVFP4 (float8_e4m3fn block scales + float32 global scales)
CUDA-graph-safe: all intermediate buffers pre-allocated, no CPU-GPU syncs,
no dynamic shapes.
"""
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kNvfp4Dynamic,
kNvfp4Static,
)
from vllm.platforms import current_platform
from cutedsl.runner import CuTeDSLMoERunner
class CuTeDSLMoEExperts(mk.FusedMoEExpertsModular):
"""CuTeDSL NVFP4 MoE experts using the custom CuTeDSL grouped GEMM kernel.
Uses Standard activation format (non-batched). Handles input quantization
internally (expects_unquantized_inputs=True).
Supports expert parallelism: remaps global→local expert IDs.
"""
def __init__(
    self,
    moe_config: FusedMoEConfig,
    quant_config: FusedMoEQuantConfig,
):
    """Cache the MoE geometry needed to build the CuTeDSL runner later.

    The runner itself is not created here; ``_runner`` stays ``None`` until
    ``process_weights_after_loading`` runs.

    Args:
        moe_config: Layer configuration (dimensions, expert counts,
            parallel layout).
        quant_config: Quantization configuration; its ``quant_dtype`` must
            be ``"nvfp4"``.
    """
    super().__init__(moe_config=moe_config, quant_config=quant_config)
    assert quant_config.quant_dtype == "nvfp4", (
        "CuTeDSL MoE only supports nvfp4 quantization, "
        f"got {quant_config.quant_dtype}"
    )

    # Dimensions and dtypes copied from the layer config.
    self.out_dtype = moe_config.in_dtype
    self.hidden_dim = moe_config.hidden_dim
    self.intermediate_size_per_partition = (
        moe_config.intermediate_size_per_partition
    )
    self.topk = moe_config.experts_per_token

    # Expert counts and this rank's first expert id.
    self.local_num_experts = moe_config.num_local_experts
    self.global_num_experts = moe_config.num_experts
    parallel_cfg = moe_config.moe_parallel_config
    self.ep_rank = parallel_cfg.ep_rank
    self.local_expert_offset = self.ep_rank * self.local_num_experts

    # Buffer capacity; falls back to 8192 when the config has no
    # max_num_tokens attribute.
    self.max_num_tokens = getattr(moe_config, 'max_num_tokens', 8192)

    # Both are filled in by process_weights_after_loading.
    self._swiglu_limit = None
    self._runner: CuTeDSLMoERunner | None = None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"""Convert NVFP4 MoE weights into CuTeDSL kernel format.
Reads w13/w2 weight tensors from the FusedMoE layer, converts them
to the CuTeDSL runner's expected format, and creates the runner.
Also folds the activation global scale (input_scale) into the
weight global scale (weight_scale_2) as the runner's alpha.
"""
num_experts = layer.w13_weight.shape[0]
hidden_size = self.hidden_dim
intermediate_size = self.intermediate_size_per_partition
device = layer.w13_weight.device
# NOTE: For the CuTeDSL kernel, we do NOT fold input_scale into
# weight_scale_2. The CuTeDSL runner uses weight global scale
# (weight_scale_2) and activation global scale separately.
# The activation global scale is computed via warmup before first inference.
#
# Also, convert_to_nvfp4_moe_kernel_format already inverted input_scale
# (1.0 / a13_scale) for the quant config. We undo that inversion here
# to get the original input_scale, then use it as initial activation gs.
if layer.w13_input_scale is not None and not isinstance(layer.w13_input_scale, float):
# input_scale was inverted in convert_to_nvfp4_moe_kernel_format
# Original: input_scale = amax / (6.0 * 448.0)
# Inverted: 1.0 / input_scale = 6.0 * 448.0 / amax
# We need the original for activation gs
w13_input_scale_orig = 1.0 / layer.w13_input_scale
else:
w13_input_scale_orig = None
if layer.w2_input_scale is not None and not isinstance(layer.w2_input_scale, float):
w2_input_scale_orig = 1.0 / layer.w2_input_scale
else:
w2_input_scale_orig = None
# Extract and convert weights for CuTeDSL runner
# w13_weight: (E, 2*intermediate, hidden//2) uint8 — gate + up fused
# w2_weight: (E, hidden, intermediate//2) uint8 — down
l1_fp4_list = []
l1_sf_list = []
l1_gs_list = []
l2_fp4_list = []
l2_sf_list = []
l2_gs_list = []
for expert_id in range(num_experts):
# L1: gate + up (w13)
w13_uint8 = layer.w13_weight.data[expert_id] # (2*inter, hidden//2)
w13_sf = layer.w13_weight_scale.data[expert_id] # (2*inter, hidden//16) fp8
w13_gs = layer.w13_weight_scale_2.data[expert_id].item() # float32
# uint8 → float4_e2m1fn_x2, permute to (K_packed, N) for CuTeDSL
l1_w = w13_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
# Block scales: (N, K_sf) → (K_sf, N) for CuTeDSL
l1_s = w13_sf.permute(1, 0).contiguous()
if l1_s.dtype != torch.float8_e4m3fn:
l1_s = l1_s.to(torch.float8_e4m3fn)
l1_fp4_list.append(l1_w)
l1_sf_list.append(l1_s)
l1_gs_list.append(w13_gs)
# L2: down (w2)
w2_uint8 = layer.w2_weight.data[expert_id] # (hidden, intermediate//2)
w2_sf = layer.w2_weight_scale.data[expert_id] # (hidden, intermediate//16) fp8
w2_gs = layer.w2_weight_scale_2.data[expert_id].item() # float32
l2_w = w2_uint8.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
l2_s = w2_sf.permute(1, 0).contiguous()
if l2_s.dtype != torch.float8_e4m3fn:
l2_s = l2_s.to(torch.float8_e4m3fn)
l2_fp4_list.append(l2_w)
l2_sf_list.append(l2_s)
l2_gs_list.append(w2_gs)
# Create the CuTeDSL runner
self._runner = CuTeDSLMoERunner(
num_experts=num_experts,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
max_num_tokens=self.max_num_tokens,
top_k=self.topk,
device=str(device),
experts_start_idx=self.local_expert_offset,
)
self._runner.prepare_weights_direct(
l1_fp4_list, l1_sf_list, l1_gs_list,
l2_fp4_list, l2_sf_list, l2_gs_list,
)
if self._swiglu_limit is not None:
self._runner.set_swiglu_limit(float(self._swiglu_limit))
# Read swiglu_limit from the FusedMoE layer (set by DeepseekV4MoE)
swiglu_limit = getattr(layer, 'swiglu_limit', None)
if swiglu_limit is not None:
self._swiglu_limit = swiglu_limit
self._runner.set_swiglu_limit(float(swiglu_limit))
# Set initial activation global scales from checkpoint input_scale.
# The CuTeDSL runner uses activation_gs = 1.0 / input_scale from the
# checkpoint as the starting value. The warmup step
# (compute_activation_global_scales) will override this with an
# empirically computed value before the first inference.
if w13_input_scale_orig is not None:
# input_scale = 448.0 / amax → activation_gs = 1.0 / input_scale = amax / 448.0
# Mean across experts (they should be similar)
mean_l1_gs = float(w13_input_scale_orig.mean().item())
if mean_l1_gs > 0:
self._runner._l1_activation_global_scale = 1.0 / mean_l1_gs
if w2_input_scale_orig is not None:
mean_l2_gs = float(w2_input_scale_orig.mean().item())
if mean_l2_gs > 0:
self._runner._l2_activation_global_scale = 1.0 / mean_l2_gs
# Note: activation global scale warmup must be done after
# process_weights_after_loading, before the first inference.
# This is handled by the model's load_weights or a separate warmup step.
@property
def runner(self) -> CuTeDSLMoERunner | None:
return self._runner
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
@staticmethod
def _supports_current_device() -> bool:
# CuTeDSL requires CUDA SM100 (Blackwell)
p = current_platform
return p.is_cuda() and p.is_device_capability_family(100)
@staticmethod
def _supports_no_act_and_mul() -> bool:
return False
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
SUPPORTED_W_A = [
(kNvfp4Static, kNvfp4Dynamic),
]
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
# We handle SiLU + swiglu_limit internally
return activation == MoEActivation.SILU
@staticmethod
def _supports_parallel_config(
moe_parallel_config: FusedMoEParallelConfig,
) -> bool:
return True
def supports_expert_map(self) -> bool:
return False
@property
def expects_unquantized_inputs(self) -> bool:
# Our runner handles activation quantization internally
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP()
def workspace_shapes(
self,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# Our runner manages its own workspace internally (pre-allocated buffers)
workspace1 = (0,)
workspace2 = (0,)
# K is packed (K//2 for uint8), so output uses hidden_dim
assert self.hidden_dim == K * 2
output = (M, self.hidden_dim)
return (workspace1, workspace2, output)
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor | None,
workspace2: torch.Tensor | None,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool | None,
):
assert self._runner is not None, (
"CuTeDSL runner not initialized. "
"Call process_weights_after_loading first."
)
# Our runner expects topk_ids as global expert IDs.
# The modular kernel framework may pass local IDs with expert_map.
# We handle remapping internally via experts_start_idx.
result = self._runner.run(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
)
# Copy result into output tensor
output.copy_(result)

View File

@@ -0,0 +1,535 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
import torch
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config.kernel import MoEBackend
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
nvfp4_moe_quant_config,
nvfp4_w4a16_moe_quant_config,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
prepare_nvfp4_moe_layer_for_fi_or_cutlass,
prepare_nvfp4_moe_layer_for_flashinfer_cutedsl,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
get_flashinfer_moe_backend,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
prepare_nvfp4_moe_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (
kE2M1ToFloat_handle,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
)
logger = init_logger(__name__)
class NvFp4MoeBackend(Enum):
    """Identifiers for the NVFP4 MoE kernel backends known to this module.

    Each member's value is its own name as a string; the values appear in
    log and error messages produced by the selection helpers.
    """

    # FlashInfer-based backends (also listed in FLASHINFER_NVFP4_MOE_BACKENDS).
    FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM"
    FLASHINFER_CUTLASS = "FLASHINFER_CUTLASS"
    FLASHINFER_CUTEDSL = "FLASHINFER_CUTEDSL"
    FLASHINFER_CUTEDSL_BATCHED = "FLASHINFER_CUTEDSL_BATCHED"
    VLLM_CUTLASS = "VLLM_CUTLASS"
    MARLIN = "MARLIN"
    # Custom CuTeDSL backend (also listed in CUTEDSL_NVFP4_MOE_BACKENDS).
    CUTEDSL = "CUTEDSL"
    EMULATION = "EMULATION"
# FlashInfer-based backends. select_nvfp4_moe_backend removes these from its
# candidate list when VLLM_USE_FLASHINFER_MOE_FP4 is set to a false value,
# and probes only these when it is set to a true value.
FLASHINFER_NVFP4_MOE_BACKENDS = [
    NvFp4MoeBackend.FLASHINFER_TRTLLM,
    NvFp4MoeBackend.FLASHINFER_CUTLASS,
    NvFp4MoeBackend.FLASHINFER_CUTEDSL,
    NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED,
]
# Backends served by the custom CuTeDSL kernel.
CUTEDSL_NVFP4_MOE_BACKENDS = [
    NvFp4MoeBackend.CUTEDSL,
]
# Translates the FlashInfer backend reported by get_flashinfer_moe_backend()
# into the corresponding NvFp4MoeBackend member.
fi_2_vllm_backend_map: dict[FlashinferMoeBackend, NvFp4MoeBackend] = {
    FlashinferMoeBackend.CUTLASS: NvFp4MoeBackend.FLASHINFER_CUTLASS,
    FlashinferMoeBackend.TENSORRT_LLM: NvFp4MoeBackend.FLASHINFER_TRTLLM,
    FlashinferMoeBackend.CUTEDSL: NvFp4MoeBackend.FLASHINFER_CUTEDSL,
}
def is_global_sf_supported_for_nvfp4_backend(backend: NvFp4MoeBackend) -> bool:
    """Report whether *backend* can quantize with all experts' scale factors.

    This matters in expert-parallel mode, where not every expert lives on
    the same rank. The FlashInfer and CuTeDSL backend groups qualify.
    """
    if backend in FLASHINFER_NVFP4_MOE_BACKENDS:
        return True
    return backend in CUTEDSL_NVFP4_MOE_BACKENDS
def backend_to_kernel_cls(
backend: NvFp4MoeBackend,
) -> list[type[mk.FusedMoEExperts]]:
if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
from vllm.model_executor.layers.fused_moe.experts.trtllm_nvfp4_moe import (
TrtLlmNvFp4ExpertsModular,
TrtLlmNvFp4ExpertsMonolithic,
)
# NOTE: prefer Monolthic > Modular, so return Monolithic first.
return [
TrtLlmNvFp4ExpertsMonolithic,
TrtLlmNvFp4ExpertsModular,
]
elif backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutlass_moe import ( # noqa: E501
FlashInferExperts,
)
return [FlashInferExperts]
elif backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL:
from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutedsl_moe import ( # noqa: E501
FlashInferCuteDSLExperts,
)
return [FlashInferCuteDSLExperts]
elif backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED:
from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutedsl_batched_moe import ( # noqa: E501
FlashInferCuteDSLBatchedExperts,
)
return [FlashInferCuteDSLBatchedExperts]
elif backend == NvFp4MoeBackend.CUTEDSL:
from vllm.model_executor.layers.fused_moe.experts.cutedsl_moe import ( # noqa: E501
CuTeDSLMoEExperts,
)
return [CuTeDSLMoEExperts]
elif backend == NvFp4MoeBackend.VLLM_CUTLASS:
from vllm.model_executor.layers.fused_moe.experts.cutlass_moe import (
CutlassExpertsFp4,
)
return [CutlassExpertsFp4]
elif backend == NvFp4MoeBackend.MARLIN:
from vllm.model_executor.layers.fused_moe.experts.marlin_moe import (
MarlinExperts,
)
return [MarlinExperts]
elif backend == NvFp4MoeBackend.EMULATION:
from vllm.model_executor.layers.fused_moe.experts.nvfp4_emulation_moe import (
Nvfp4QuantizationEmulationTritonExperts,
)
return [Nvfp4QuantizationEmulationTritonExperts]
else:
raise ValueError(f"Unknown NvFP4 MoE backend: {backend.value}")
def map_nvfp4_backend(runner_backend: MoEBackend) -> NvFp4MoeBackend:
    """Translate the user-facing ``moe_backend`` name into an NvFp4MoeBackend.

    Raises:
        ValueError: If *runner_backend* has no NVFP4 counterpart; the
            message lists the accepted names.
    """
    mapping = {
        "cutlass": NvFp4MoeBackend.VLLM_CUTLASS,
        "flashinfer_trtllm": NvFp4MoeBackend.FLASHINFER_TRTLLM,
        "flashinfer_cutlass": NvFp4MoeBackend.FLASHINFER_CUTLASS,
        "flashinfer_cutedsl": NvFp4MoeBackend.FLASHINFER_CUTEDSL,
        "cutedsl": NvFp4MoeBackend.CUTEDSL,
        "marlin": NvFp4MoeBackend.MARLIN,
        "emulation": NvFp4MoeBackend.EMULATION,
    }
    selected = mapping.get(runner_backend)
    if not selected:
        raise ValueError(
            f"moe_backend='{runner_backend}' is not supported for NvFP4 MoE. "
            f"Expected one of {list(mapping.keys())}."
        )
    return selected
def select_nvfp4_moe_backend(
    config: FusedMoEConfig,
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEExperts]]:
    """Select the primary NvFP4 MoE backend and its experts class.

    Resolution order:
      1. An explicit ``config.moe_backend`` (anything other than "auto").
      2. The ``VLLM_USE_FLASHINFER_MOE_FP4`` / ``VLLM_FLASHINFER_MOE_BACKEND``
         environment settings.
      3. ``VLLM_TEST_FORCE_FP8_MARLIN``.
      4. The first backend in the priority list that supports the
         configuration.

    Note: Shape-specific fallbacks may still occur at runtime.

    Raises:
        ValueError: If an explicitly requested backend does not support
            the configuration (or the requested name is unknown).
        NotImplementedError: If no candidate backend supports it.
    """
    # NOTE: the kernels are selected in the following order.
    AVAILABLE_BACKENDS = [
        NvFp4MoeBackend.FLASHINFER_TRTLLM,
        NvFp4MoeBackend.FLASHINFER_CUTEDSL,
        NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED,
        NvFp4MoeBackend.CUTEDSL,
        NvFp4MoeBackend.FLASHINFER_CUTLASS,
        NvFp4MoeBackend.VLLM_CUTLASS,
        NvFp4MoeBackend.MARLIN,
        NvFp4MoeBackend.EMULATION,
    ]

    use_batched = config.moe_parallel_config.use_batched_activation_format
    activation_format = (
        mk.FusedMoEActivationFormat.BatchedExperts
        if use_batched
        else mk.FusedMoEActivationFormat.Standard
    )

    def _make_log_backend(backend: NvFp4MoeBackend):
        available_backend_strs = [b.value for b in AVAILABLE_BACKENDS]
        return (
            f"Using '{backend.value}' NvFp4 MoE backend out "
            f"of potential backends: {available_backend_strs}."
        )

    def _make_log_unsupported(backend: NvFp4MoeBackend, reason: str | None) -> str:
        if reason:
            return (
                f"NvFp4 MoE backend '{backend.value}' does not support the "
                f"deployment configuration since {reason}."
            )
        else:
            return (
                f"NvFp4 MoE backend '{backend.value}' does not support the "
                "deployment configuration."
            )

    def _first_supported_kernel(
        backend: NvFp4MoeBackend,
        *,
        log_unsupported: bool,
    ) -> tuple[type[mk.FusedMoEExperts] | None, str | None]:
        # Probe the backend's kernel classes in preference order. Returns
        # (class, None) on the first match, else (None, last reason seen).
        reason = None
        for k_cls in backend_to_kernel_cls(backend):
            supported, reason = k_cls.is_supported_config(
                k_cls, config, weight_key, activation_key, activation_format
            )
            if supported:
                logger.info_once(_make_log_backend(backend))
                return k_cls, None
            if log_unsupported:
                logger.debug_once(_make_log_unsupported(backend, reason))
        return None, reason

    def _return_or_raise(
        backend: NvFp4MoeBackend,
    ) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEExperts]]:
        # Used when the backend was chosen explicitly: failure is an error.
        k_cls, reason = _first_supported_kernel(backend, log_unsupported=False)
        if k_cls is None:
            raise ValueError(_make_log_unsupported(backend, reason))
        return backend, k_cls

    def _first_supported_backend(
        backends: list[NvFp4MoeBackend],
    ) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEExperts]] | None:
        # Walk the candidates in order; None when nothing fits.
        for backend in backends:
            k_cls, _ = _first_supported_kernel(backend, log_unsupported=True)
            if k_cls is not None:
                return backend, k_cls
        return None

    # Handle explicit moe_backend from user.
    runner_backend = config.moe_backend
    if runner_backend != "auto":
        requested_backend = map_nvfp4_backend(runner_backend)
        # For batched activation format, use batched variant if available.
        if (
            activation_format == mk.FusedMoEActivationFormat.BatchedExperts
            and requested_backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL
        ):
            requested_backend = NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED
        return _return_or_raise(requested_backend)

    if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP4"):
        if not envs.VLLM_USE_FLASHINFER_MOE_FP4:
            # If the user rejects FlashInfer remove those backends.
            AVAILABLE_BACKENDS = [
                b for b in AVAILABLE_BACKENDS
                if b not in FLASHINFER_NVFP4_MOE_BACKENDS
            ]
        elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"):
            # If user is explicit about backend, validate it.
            backend = fi_2_vllm_backend_map[get_flashinfer_moe_backend()]
            return _return_or_raise(backend)
        else:
            # If the user is not explicit about the backend, try each.
            selection = _first_supported_backend(FLASHINFER_NVFP4_MOE_BACKENDS)
            if selection is not None:
                return selection
            raise NotImplementedError(
                "Found VLLM_USE_FLASHINFER_MOE_FP4=1, but no "
                "FlashInfer NVFP4 MoE backend supports the configuration."
            )

    if envs.VLLM_TEST_FORCE_FP8_MARLIN:
        return _return_or_raise(NvFp4MoeBackend.MARLIN)

    # Select kernels in order of backend.
    selection = _first_supported_backend(AVAILABLE_BACKENDS)
    if selection is not None:
        return selection
    raise NotImplementedError(
        "No NvFp4 MoE backend supports the deployment configuration."
    )
def convert_to_nvfp4_moe_kernel_format(
    nvfp4_backend: NvFp4MoeBackend,
    layer: torch.nn.Module,
    w13: torch.Tensor,
    w13_scale: torch.Tensor,
    w13_scale_2: torch.Tensor,
    a13_scale: torch.Tensor | None,
    w2: torch.Tensor,
    w2_scale: torch.Tensor,
    w2_scale_2: torch.Tensor,
    a2_scale: torch.Tensor | None,
    is_act_and_mul: bool,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Convert NVFP4 MoE weights and scales into the chosen backend's format.

    Returns:
        ``(w13, w13_scale, w13_scale_2, a13_scale, w2, w2_scale,
        w2_scale_2, a2_scale)`` after backend-specific preparation.

    Raises:
        ValueError: For an unknown backend, or for the EMULATION backend
            when either activation scale is None.
    """
    # Keyword bundles shared by the per-backend prepare helpers.
    weight_kwargs = dict(
        layer=layer,
        w13=w13,
        w13_scale=w13_scale,
        w13_scale_2=w13_scale_2,
        w2=w2,
        w2_scale=w2_scale,
        w2_scale_2=w2_scale_2,
    )
    act_kwargs = dict(a13_scale=a13_scale, a2_scale=a2_scale)

    if nvfp4_backend == NvFp4MoeBackend.CUTEDSL:
        # Weights and weight scales pass through untouched; the CuTeDSL
        # experts class converts them itself. Only the activation scales
        # are replaced by their reciprocals.
        # NOTE(review): make_nvfp4_moe_quant_config applies 1.0 / a13_scale
        # again for this backend, so the quant config ends up holding the
        # un-inverted value — confirm this is intended.
        if a13_scale is not None:
            a13_scale = 1.0 / a13_scale
        if a2_scale is not None:
            a2_scale = 1.0 / a2_scale
    elif nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL:
        # Must be tested before the generic FlashInfer branch below, since
        # FLASHINFER_CUTEDSL is also in FLASHINFER_NVFP4_MOE_BACKENDS.
        prepared = prepare_nvfp4_moe_layer_for_flashinfer_cutedsl(
            **weight_kwargs,
            **act_kwargs,
        )
        (
            w13,
            w13_scale,
            w13_scale_2,
            a13_scale,
            w2,
            w2_scale,
            w2_scale_2,
            a2_scale,
        ) = prepared
    elif (
        nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS
        or nvfp4_backend == NvFp4MoeBackend.VLLM_CUTLASS
    ):
        prepared = prepare_nvfp4_moe_layer_for_fi_or_cutlass(
            backend=nvfp4_backend,
            is_act_and_mul=is_act_and_mul,
            **weight_kwargs,
            **act_kwargs,
        )
        (
            w13,
            w13_scale,
            w13_scale_2,
            a13_scale,
            w2,
            w2_scale,
            w2_scale_2,
            a2_scale,
        ) = prepared
    elif nvfp4_backend == NvFp4MoeBackend.MARLIN:
        # The Marlin path drops the activation scales.
        a13_scale = None
        a2_scale = None
        prepared = prepare_nvfp4_moe_layer_for_marlin(
            is_act_and_mul=is_act_and_mul,
            **weight_kwargs,
        )
        (
            w13,
            w13_scale,
            w13_scale_2,
            w2,
            w2_scale,
            w2_scale_2,
        ) = prepared
    elif nvfp4_backend == NvFp4MoeBackend.EMULATION:
        # Move the E2M1 lookup table to the device now, because
        # `.to(device)` is not allowed during CUDA graph capture.
        kE2M1ToFloat_handle.val = kE2M1ToFloat_handle.val.to(w13.device)
        if a13_scale is None or a2_scale is None:
            raise ValueError(
                "Activation global scales should not be None, got"
                f" a13_scale={a13_scale}, a2_scale={a2_scale}"
            )
        scales_differ = (
            torch.unique(a13_scale).numel() != 1
            or torch.unique(a2_scale).numel() != 1
        )
        if scales_differ:
            logger.warning_once(
                "In NVFP4 linear, the activation global scale for inputs are different"
                " for MOE w13 (gate_up_proj) layer or MOE w2 (down_proj). Using"
                " a13_scale = a13_scale.max() and a2_scale = a2_scale.max()."
            )
        # Collapse the per-expert scales to a single value by taking the
        # max, then store its reciprocal as a float32 tensor. The original
        # author notes that taking the largest scale may overflow the FP8
        # range for other experts, and that other selection strategies
        # could be used instead.
        a13_scale = 1.0 / a13_scale.max().to(torch.float32)
        a2_scale = 1.0 / a2_scale.max().to(torch.float32)
    else:
        raise ValueError(f"Unknown NvFp4 backend for MoE: {nvfp4_backend}")

    return (
        w13,
        w13_scale,
        w13_scale_2,
        a13_scale,
        w2,
        w2_scale,
        w2_scale_2,
        a2_scale,
    )
def make_nvfp4_moe_quant_config(
    backend: NvFp4MoeBackend,
    w13_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w13_scale_2: torch.Tensor,
    w2_scale_2: torch.Tensor,
    a13_scale: torch.Tensor,
    a2_scale: torch.Tensor,
) -> FusedMoEQuantConfig:
    """Build the FusedMoEQuantConfig matching *backend*.

    MARLIN gets a weight-only (w4a16) config without activation scales.
    EMULATION receives the activation scales as given. Every other backend
    receives their reciprocals plus an ``is_scale_swizzled`` flag.
    """
    # Weight-side arguments are the same for every backend.
    # w13_scale_2 / w2_scale_2 are handed over as-is (no copy is made here)
    # to serve as g1/g2_alphas.
    weight_scales = dict(
        g1_alphas=w13_scale_2,
        g2_alphas=w2_scale_2,
        w1_scale=w13_scale,
        w2_scale=w2_scale,
    )

    if backend == NvFp4MoeBackend.MARLIN:
        return nvfp4_w4a16_moe_quant_config(**weight_scales)

    if backend == NvFp4MoeBackend.EMULATION:
        return nvfp4_moe_quant_config(
            a1_gscale=a13_scale,
            a2_gscale=a2_scale,
            **weight_scales,
        )

    # NOTE(rob): this is a hack until the MoE kernels
    # create their own quant configs. TRTLLM kernel
    # does not accept swizzled input quant scales.
    unswizzled_backends = (
        NvFp4MoeBackend.FLASHINFER_TRTLLM,
        NvFp4MoeBackend.FLASHINFER_CUTEDSL,
        NvFp4MoeBackend.CUTEDSL,
    )
    return nvfp4_moe_quant_config(
        a1_gscale=(1.0 / a13_scale),
        a2_gscale=(1.0 / a2_scale),
        is_scale_swizzled=backend not in unswizzled_backends,
        **weight_scales,
    )
def make_nvfp4_moe_kernel(
    moe_quant_config: FusedMoEQuantConfig,
    moe_config: FusedMoEConfig,
    experts_cls: type[mk.FusedMoEExperts],
    routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEKernel:
    """Assemble a FusedMoEKernel from a prepare/finalize stage and experts.

    The experts class is instantiated with ``moe_config`` and
    ``quant_config``; when the prepare/finalize stage uses the
    BatchedExperts activation format it additionally receives
    ``max_num_tokens`` and ``num_dispatchers``.

    Raises:
        AssertionError: If no prepare/finalize stage could be created, or
            the batched stage reports no per-rank token limit.
    """
    # Create Prepare/Finalize.
    is_monolithic = issubclass(experts_cls, mk.FusedMoEExpertsMonolithic)
    prepare_finalize = maybe_make_prepare_finalize(
        moe=moe_config,
        quant_config=moe_quant_config,
        routing_tables=routing_tables,
        allow_new_interface=True,
        use_monolithic=is_monolithic,
    )
    assert prepare_finalize is not None
    logger.info_once("Using %s", prepare_finalize.__class__.__name__)

    # Create Experts.
    experts_kwargs = {
        "moe_config": moe_config,
        "quant_config": moe_quant_config,
    }
    batched_format = mk.FusedMoEActivationFormat.BatchedExperts
    if prepare_finalize.activation_format == batched_format:
        max_num_tokens = prepare_finalize.max_num_tokens_per_rank()
        assert max_num_tokens is not None
        experts_kwargs["max_num_tokens"] = max_num_tokens
        experts_kwargs["num_dispatchers"] = prepare_finalize.num_dispatchers()
    experts = experts_cls(**experts_kwargs)

    # TODO(rob): update inplace logic to be part of the kernel.
    return mk.FusedMoEKernel(
        prepare_finalize,
        experts,
        inplace=False,
    )

2378
vllm/patches/modelopt.py Normal file

File diff suppressed because it is too large Load Diff