Proper NVFP4 integration: use ModelOptNvFp4Config + FusedMoE framework
Major refactor to eliminate all post-load hacks: - deepseek_v4.py: use upstream model with NVFP4 weight mapper only (gate_proj→w1, up_proj→w3, down_proj→w2, .self_attn→.attn, .mlp→.ffn) - Add CuTeDSLMoEExperts as a FusedMoEExpertsModular subclass that wraps our CuTeDSL runner as a proper vLLM MoE backend - Register CUTEDSL backend in the NVFP4 oracle - Use ModelOptNvFp4Config for quantization dispatch (not DeepseekV4FP8Config) - ModelOptNvFp4LinearMethod handles NVFP4 attention/shared expert projections - Remove nvfp4_cutedsl.py, cutedsl_quant_method.py, utils.py from Dockerfile - CuTeDSL runner moved to cutedsl/runner.py for clean imports - cos_sin_cache float32 fix in deepseek_v4_attention.py No more monkey-patching, no _convert_nvfp4_post_load, no CuTeDSLNvfp4Method.
This commit is contained in:
16
Dockerfile
16
Dockerfile
@@ -30,17 +30,21 @@ ENV PYTHONPATH="/root/nvfp4-megamoe-kernel:${PYTHONPATH}"
|
||||
# Patch vLLM — overwrite model files and register architecture
|
||||
ARG VLLM_MODELS_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models
|
||||
ARG VLLM_LAYERS_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers
|
||||
ARG VLLM_QUANT_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/quantization
|
||||
ARG VLLM_FUSED_MOE_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe
|
||||
ARG VLLM_LOADER_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/model_loader
|
||||
|
||||
# Core model patches
|
||||
COPY vllm/patches/deepseek_v4.py ${VLLM_MODELS_DIR}/deepseek_v4.py
|
||||
COPY vllm/patches/deepseek_v4_attention.py ${VLLM_LAYERS_DIR}/deepseek_v4_attention.py
|
||||
COPY vllm/nvfp4_cutedsl.py ${VLLM_MODELS_DIR}/nvfp4_cutedsl.py
|
||||
COPY vllm/cutedsl_quant_method.py ${VLLM_MODELS_DIR}/cutedsl_quant_method.py
|
||||
COPY cutedsl/nvfp4_linear.py /root/nvfp4-megamoe-kernel/cutedsl/nvfp4_linear.py
|
||||
COPY cutedsl/shared_expert_pipeline.py /root/nvfp4-megamoe-kernel/cutedsl/shared_expert_pipeline.py
|
||||
COPY vllm/patches/utils.py ${VLLM_LOADER_DIR}/utils.py
|
||||
|
||||
RUN sed -i 's/"DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),/"DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),\n "DeepseekV4ForCausalLM": ("deepseek_v4", "DeepseekV4ForCausalLM"),/' \
|
||||
# NVFP4 MoE backend registration
|
||||
COPY vllm/patches/fused_moe/oracle/nvfp4.py ${VLLM_FUSED_MOE_DIR}/oracle/nvfp4.py
|
||||
COPY vllm/patches/fused_moe/experts/cutedsl_moe.py ${VLLM_FUSED_MOE_DIR}/experts/cutedsl_moe.py
|
||||
|
||||
# Register DeepseekV4ForCausalLM model architecture (if not already in upstream)
|
||||
RUN grep -q '"DeepseekV4ForCausalLM"' ${VLLM_MODELS_DIR}/registry.py || \
|
||||
sed -i 's/"DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),/"DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),\n "DeepseekV4ForCausalLM": ("deepseek_v4", "DeepseekV4ForCausalLM"),/' \
|
||||
${VLLM_MODELS_DIR}/registry.py
|
||||
|
||||
# Verify
|
||||
|
||||
529
cutedsl/runner.py
Normal file
529
cutedsl/runner.py
Normal file
@@ -0,0 +1,529 @@
|
||||
"""
|
||||
vLLM integration for the CuTeDSL NVFP4 MoE kernel.
|
||||
|
||||
CUDA-graph-compatible design:
|
||||
- All intermediate buffers pre-allocated at max_num_tokens * top_k size
|
||||
- No .item(), .tolist(), .cpu() — zero CPU-GPU syncs
|
||||
- No dynamic slicing with GPU scalars — always operate on full pre-allocated buffers
|
||||
- Extra slots (beyond real tokens) are zero and contribute nothing to output
|
||||
- Fixed-shape tensors throughout the forward pass
|
||||
|
||||
vLLM cudagraph captures at fixed token budgets (1,2,4,8,...,8192).
|
||||
During capture, num_tokens equals the budget — all shapes are fixed.
|
||||
During replay, inputs are padded to the budget size. Our runner always
|
||||
processes max_slots = budget * top_k rows; padding rows are zeros.
|
||||
"""
|
||||
import torch
|
||||
|
||||
from cutedsl.bridge import (
    quantize_activation_nvfp4,
    quantize_weight_to_nvfp4,
    quantize_to_nvfp4,
    make_b_k_major,
    assemble_scales_3d_side,
    run_nvfp4_grouped_gemm,
)
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
    ceil_div as cutedsl_ceil_div,
    pad_and_swizzle_single,
)


class _MoEApply(torch.autograd.Function):
    """Forward-only ``autograd.Function`` that delegates to ``runner._run_impl``.

    The original author's stated intent is to make the CuTeDSL MoE runner
    opaque to torch.compile. No ``backward`` is defined here, so this wrapper
    is only usable for inference.
    """

    @staticmethod
    def forward(ctx, runner, hidden_states, topk_weights, topk_ids, expert_indices):
        """Call ``runner._run_impl`` with the routing inputs and return its result.

        ``ctx`` is unused; nothing is saved for a backward pass.
        """
        return runner._run_impl(hidden_states, topk_weights, topk_ids, expert_indices)
|
||||
|
||||
|
||||
class CuTeDSLMoERunner:
|
||||
"""Manages NVFP4 MoE execution via the CuTeDSL kernel.
|
||||
|
||||
CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs,
|
||||
no dynamic shapes. Always computes at max_num_tokens * top_k capacity.
|
||||
"""
|
||||
|
||||
def __init__(self, num_experts, hidden_size, intermediate_size,
             max_num_tokens=8192, top_k=8, device="cuda",
             experts_start_idx=0):
    """Record the MoE geometry and declare every attribute the runner uses.

    No tensors are allocated here; weights are supplied later through
    ``prepare_weights_*`` and buffers are created by ``_ensure_stacked`` /
    ``_allocate_buffers``.

    Args:
        num_experts: Number of experts handled by this runner.
        hidden_size: Model hidden dimension (input of L1, output of L2).
        intermediate_size: Expert intermediate dimension (input of L2).
        max_num_tokens: Largest token count the pre-allocated buffers cover.
        top_k: Experts selected per token.
        device: Device on which buffers are allocated.
        experts_start_idx: Offset subtracted from incoming expert IDs to
            obtain IDs local to this runner.
    """
    self.num_experts = num_experts
    self.hidden_size = hidden_size
    self.intermediate_size = intermediate_size
    self.max_num_tokens = max_num_tokens
    self.top_k = top_k
    self.device = device
    self.experts_start_idx = experts_start_idx
    self._swiglu_limit = None  # Set via set_swiglu_limit()

    # Weight storage (set before _ensure_stacked)
    self.l1_fp4 = None
    self.l1_sf = None
    self.l1_gs = None
    self.l2_fp4 = None
    self.l2_sf = None
    self.l2_gs = None

    # Stacked weight tensors (set in _ensure_stacked)
    self._l1_mat_b = None
    self._l2_mat_b = None
    self._l1_scale_b = None
    self._l2_scale_b = None
    self._l1_gsb = None
    self._l2_gsb = None

    # Default: 1/2688 ≈ 0.000372 (amax=1 → gs=1/2688)
    # Overridden by compute_activation_global_scales() after a warmup pass.
    self._l1_activation_global_scale = 1.0 / (6.0 * 448.0)
    self._l2_activation_global_scale = 1.0 / (6.0 * 448.0)

    # Pre-allocated cudagraph buffers (set in _ensure_stacked / _allocate_buffers)
    self._token_indices = None
    self._expert_id_range = None
    self._expert_offsets_buf = None
    self._per_expert_scale_bufs_l1 = None
    self._per_expert_scale_bufs_l2 = None
    self._padded_x_sf_buf_l1 = None
    self._padded_x_sf_buf_l2 = None
    self._l1_gsa_buf = None
    self._l2_gsa_buf = None
    self._output_buf = None
    self._row_indices_buf = None
    self._padded_hidden_buf = None
    self._padded_activated_buf = None  # unused, using shared
    self._padded_expert_offsets_buf = None
    self._shared_bufs = None

    # Number of 128-row chunks reserved per expert in the padded buffers.
    # The forward pass rounds every expert's row count up to a multiple of
    # 128 independently, so the padded total can exceed the raw slot count
    # (max_num_tokens * top_k) by up to 127 rows per expert. Sizing for the
    # raw slot count alone lets a full batch with uneven routing index past
    # the end of the buffers. One extra chunk per expert guarantees
    # num_experts * chunks * 128 >= slots + 128 * num_experts
    # (assuming cutedsl_ceil_div is ceiling division).
    self._max_chunks_per_expert = cutedsl_ceil_div(
        self.max_num_tokens * self.top_k, self.num_experts * 128
    ) + 1
    self._buffers_allocated = False
    # Set to True by _ensure_stacked; declared here so the attribute always exists.
    self._needs_token_refill = False
|
||||
|
||||
def set_swiglu_limit(self, limit: float | None):
|
||||
"""Set the swiglu_limit for activation clamping."""
|
||||
self._swiglu_limit = limit
|
||||
|
||||
def _fill_token_indices(self):
|
||||
"""Fill _token_indices with [0,0,..0, 1,1,..1, ...] (each token repeated top_k times).
|
||||
|
||||
Builds on CPU first, then copies to GPU, to ensure correctness
|
||||
regardless of CuTeDSL JIT GPU memory corruption.
|
||||
"""
|
||||
src = torch.arange(self.max_num_tokens, dtype=torch.int32)
|
||||
cpu_indices = src.unsqueeze(1).expand(-1, self.top_k).contiguous().view(-1)
|
||||
self._token_indices.copy_(cpu_indices)
|
||||
|
||||
def _allocate_buffers(self):
    """Pre-allocate the fixed-size buffers reused by every forward pass.

    Per-runner buffers (per-expert scale tiles, global-scale vectors, row
    indices, padded expert offsets) are always created fresh. The large
    padded activation / scale / output buffers are cached on the class per
    device string so that runners with identical geometry share them.

    A cached buffer is reused only when its shape matches what this runner
    needs. On a mismatch this runner gets a private buffer instead, so a
    runner created with different sizes never reads or writes a tensor
    sized for another runner.
    """
    dev = self.device
    max_rows_per_expert = self._max_chunks_per_expert * 128
    padded_max_slots = self.num_experts * max_rows_per_expert

    # Scale-factor columns: one per 16 elements of K, padded to a multiple
    # of 4. L1 and L2 differ because their K dimensions differ.
    k_sf_l1 = cutedsl_ceil_div(self.hidden_size, 16)
    padded_cols_l1 = cutedsl_ceil_div(k_sf_l1, 4) * 4
    k_sf_l2 = cutedsl_ceil_div(self.intermediate_size, 16)
    padded_cols_l2 = cutedsl_ceil_div(k_sf_l2, 4) * 4

    def fp8_zeros(rows, cols):
        # Created as fp16 zeros and cast, exactly as the original did.
        return torch.zeros(
            rows, cols, dtype=torch.float16, device=dev
        ).to(torch.float8_e4m3fn)

    def bf16_zeros(rows, cols):
        return torch.zeros(rows, cols, dtype=torch.bfloat16, device=dev)

    def fp4_zeros(rows, cols):
        # Stored as uint8 with half as many columns, viewed as packed FP4.
        return torch.zeros(
            rows, cols // 2, dtype=torch.uint8, device=dev
        ).view(torch.float4_e2m1fn_x2)

    # Per-expert scale tiles. They are passed to
    # _assemble_scales_cudagraph_safe, which does not read them.
    self._per_expert_scale_bufs_l1 = [
        fp8_zeros(128, padded_cols_l1) for _ in range(self.num_experts)
    ]
    self._per_expert_scale_bufs_l2 = [
        fp8_zeros(128, padded_cols_l2) for _ in range(self.num_experts)
    ]

    # Class-level cache of the large buffers, keyed by device string.
    if not hasattr(CuTeDSLMoERunner, '_shared_padded_bufs'):
        CuTeDSLMoERunner._shared_padded_bufs = {}
    shared = CuTeDSLMoERunner._shared_padded_bufs.setdefault(str(dev), {})

    # (name, factory, factory args, expected tensor shape)
    specs = (
        ('xsf_l1', fp8_zeros, (padded_max_slots, padded_cols_l1),
         (padded_max_slots, padded_cols_l1)),
        ('xsf_l2', fp8_zeros, (padded_max_slots, padded_cols_l2),
         (padded_max_slots, padded_cols_l2)),
        ('output', bf16_zeros, (self.max_num_tokens, self.hidden_size),
         (self.max_num_tokens, self.hidden_size)),
        ('hidden', bf16_zeros, (padded_max_slots, self.hidden_size),
         (padded_max_slots, self.hidden_size)),
        ('hidden_fp4', fp4_zeros, (padded_max_slots, self.hidden_size),
         (padded_max_slots, self.hidden_size // 2)),
        ('activated', bf16_zeros, (padded_max_slots, self.intermediate_size),
         (padded_max_slots, self.intermediate_size)),
        ('activated_fp4', fp4_zeros, (padded_max_slots, self.intermediate_size),
         (padded_max_slots, self.intermediate_size // 2)),
    )
    own = {}
    fully_shared = True
    for name, make, args, expected_shape in specs:
        buf = shared.get(name)
        if buf is None:
            # First runner on this device: allocate and publish.
            buf = make(*args)
            shared[name] = buf
        elif tuple(buf.shape) != expected_shape:
            # Cached buffer was sized for a different geometry: do not share.
            buf = make(*args)
            fully_shared = False
        own[name] = buf

    self._padded_x_sf_buf_l1 = own['xsf_l1']
    self._padded_x_sf_buf_l2 = own['xsf_l2']
    self._output_buf = own['output']
    # Keep pointing at the shared dict itself when nothing had to be
    # privatised, so the common case is unchanged from the original.
    self._shared_bufs = shared if fully_shared else own

    # Pre-allocated global_scale_a buffers (filled via .fill_(), no torch.full during capture)
    self._l1_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=dev)
    self._l2_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=dev)

    # Row indices for scale assembly (max_num_tokens * top_k slots)
    self._row_indices_buf = torch.arange(
        self.max_num_tokens * self.top_k, device=dev
    )

    # Padded expert offsets. Initialised to the uniform maximum layout
    # [0, max_rows, 2*max_rows, ...]; the forward pass zeroes this buffer
    # and rewrites it from the actual per-expert counts on every call.
    self._padded_expert_offsets_buf = torch.zeros(
        self.num_experts + 1, dtype=torch.int32, device=dev
    )
    self._padded_expert_offsets_buf[1:] = torch.arange(
        1, self.num_experts + 1, dtype=torch.int32, device=dev
    ) * max_rows_per_expert

    self._buffers_allocated = True
|
||||
|
||||
def _ensure_stacked(self):
    """Build the stacked weight tensors and runtime buffers once.

    Returns immediately when ``_l1_mat_b`` is already set. Otherwise the
    per-expert weight lists (``l1_*`` / ``l2_*``) are stacked into the
    tensors the grouped GEMM consumes, the per-expert lists are released,
    and the index/offset buffers are allocated.

    Statement order is deliberate: per the original author, weight stacking
    triggers CuTeDSL JIT compilation and buffers must be allocated after it.
    """
    if self._l1_mat_b is not None:
        return

    # Stack and prepare weight tensors FIRST (per the original author this
    # triggers CuTeDSL JIT compilation).
    self._l1_mat_b = make_b_k_major(torch.stack(self.l1_fp4))
    self._l2_mat_b = make_b_k_major(torch.stack(self.l2_fp4))
    self._l1_scale_b = assemble_scales_3d_side(self.l1_sf)
    self._l2_scale_b = assemble_scales_3d_side(self.l2_sf)
    self._l1_gsb = torch.tensor(self.l1_gs, dtype=torch.float32, device=self.device)
    self._l2_gsb = torch.tensor(self.l2_gs, dtype=torch.float32, device=self.device)
    # Drop the per-expert source lists now that the stacked forms exist.
    self.l1_fp4 = None
    self.l1_sf = None
    self.l1_gs = None
    self.l2_fp4 = None
    self.l2_sf = None
    self.l2_gs = None

    # Allocate buffers AFTER JIT compilation.
    # (Original author's note: CuTeDSL's cute.compile corrupts GPU memory
    # during JIT; tensors allocated before/during compilation may be zeroed.)
    #
    # _token_indices is a device tensor so the forward pass can slice it
    # without host work. It is filled here, after stacking, and filled
    # again at the end of the first _run_impl() call (guarded by
    # _needs_token_refill) because the GEMM is first invoked there.
    self._token_indices = torch.zeros(
        self.max_num_tokens * self.top_k, dtype=torch.int32, device=self.device
    )
    self._fill_token_indices()
    self._needs_token_refill = True  # cleared by _run_impl after its first call

    # [0, 1, ..., num_experts-1], built on CPU then moved to the device.
    self._expert_id_range = torch.arange(
        self.num_experts, dtype=torch.int32
    ).to(self.device)
    # Cumulative per-expert row offsets; rewritten on every forward pass.
    self._expert_offsets_buf = torch.zeros(
        self.num_experts + 1, dtype=torch.int32, device=self.device
    )
    self._allocate_buffers()
|
||||
|
||||
def prepare_weights_direct(self, l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs):
|
||||
self.l1_fp4 = l1_fp4
|
||||
self.l1_sf = l1_sf
|
||||
self.l1_gs = l1_gs
|
||||
self.l2_fp4 = l2_fp4
|
||||
self.l2_sf = l2_sf
|
||||
self.l2_gs = l2_gs
|
||||
self._l1_mat_b = None
|
||||
|
||||
def prepare_weights_from_dequantized(self, l1_weights_bf16, l2_weights_bf16):
    """Quantize per-expert weight matrices to NVFP4 and store the results.

    The two sequences are walked in lockstep (one entry per expert). Each
    matrix is transposed and passed to ``quantize_weight_to_nvfp4``; the
    returned FP4 data, scale factors and global scale are appended to the
    matching ``l1_*`` / ``l2_*`` lists. Stacking is deferred to
    ``_ensure_stacked``.
    """
    l1_store = ([], [], [])
    l2_store = ([], [], [])
    self.l1_fp4, self.l1_sf, self.l1_gs = l1_store
    self.l2_fp4, self.l2_sf, self.l2_gs = l2_store

    for l1_weight, l2_weight in zip(l1_weights_bf16, l2_weights_bf16):
        # L1 is quantized before L2 for each expert, as in the original.
        for weight, (fp4_list, sf_list, gs_list) in (
            (l1_weight, l1_store),
            (l2_weight, l2_store),
        ):
            fp4, sf, gs = quantize_weight_to_nvfp4(weight.T)
            fp4_list.append(fp4)
            sf_list.append(sf)
            gs_list.append(gs)

    # Force _ensure_stacked to rebuild from the new lists.
    self._l1_mat_b = None
|
||||
|
||||
def _assemble_scales_cudagraph_safe(self, x_sf, expert_offsets,
|
||||
padded_expert_offsets,
|
||||
padded_x_sf_buf, per_expert_bufs):
|
||||
"""Assemble 2D-side activation scales (cudagraph-safe, NO CPU syncs).
|
||||
|
||||
Phase 1: Scatter x_sf into padded per-expert sections (GPU-only).
|
||||
Phase 2: Apply full-buffer Blackwell 32_4_4 swizzle (no Python loops).
|
||||
|
||||
The buffer is 128-row aligned per expert (from padded_expert_offsets),
|
||||
so the full-buffer swizzle produces the correct layout. The GEMM reads
|
||||
scale_a using padded_expert_offsets, matching the scatter layout.
|
||||
"""
|
||||
K_sf = x_sf.shape[1]
|
||||
padded_x_sf = padded_x_sf_buf
|
||||
padded_x_sf.zero_()
|
||||
|
||||
# Phase 1: Scatter x_sf into padded per-expert sections (GPU-only)
|
||||
total_rows = x_sf.shape[0]
|
||||
row_indices = self._row_indices_buf[:total_rows]
|
||||
expert_assign = torch.searchsorted(
|
||||
expert_offsets[1:], row_indices, right=True
|
||||
).clamp(max=self.num_experts - 1)
|
||||
local_row = row_indices - expert_offsets[expert_assign]
|
||||
dst_rows = padded_expert_offsets[expert_assign] + local_row
|
||||
padded_x_sf[dst_rows, :K_sf] = x_sf
|
||||
|
||||
# Phase 2: Full-buffer swizzle (no CPU sync, no Python loops)
|
||||
# padded_x_sf is 128-row aligned per expert and 4-col aligned.
|
||||
# to_blocked: (rows, cols) → view(R, 128, C, 4) → permute(0,2,1,3)
|
||||
# → reshape(-1, 4, 32, 4) → transpose(1,2) → reshape(-1, 32, 16) → flatten
|
||||
rows = padded_x_sf.shape[0]
|
||||
cols = padded_x_sf.shape[1]
|
||||
R = rows // 128
|
||||
C = cols // 4
|
||||
blocks = padded_x_sf.view(R, 128, C, 4).permute(0, 2, 1, 3)
|
||||
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
|
||||
swizzled = rearranged.flatten().view(torch.float8_e4m3fn)
|
||||
return swizzled.reshape(rows, cols)
|
||||
|
||||
def compute_activation_global_scales(self, hidden_states_sample, topk_weights, topk_ids):
    """Calibrate the L1 and L2 activation global scales from a sample batch.

    Runs the L1 half of the forward pass on ``hidden_states_sample`` using
    the same padded layout as ``_run_impl`` and stores the global scales
    returned by ``quantize_to_nvfp4`` for the L1 input and for the
    SwiGLU-activated L1 output in ``_l1_activation_global_scale`` and
    ``_l2_activation_global_scale``. The docstring of the original states
    this is meant to be called before cudagraph capture.

    Expert IDs are remapped from global to local and clamped exactly as in
    ``_run_impl``, so calibration buckets rows the same way the real forward
    pass does when ``experts_start_idx`` is non-zero.

    Args:
        hidden_states_sample: Sample activations, one row per token.
        topk_weights: Accepted for signature parity with ``run``; not used.
        topk_ids: Per-token expert IDs, shape ``(num_tokens, top_k)``.
    """
    self._ensure_stacked()
    device = hidden_states_sample.device
    num_tokens = hidden_states_sample.shape[0]
    top_k = topk_ids.shape[1]

    with torch.no_grad():
        # Remap global expert IDs to local IDs (same as _run_impl).
        local_ids = topk_ids - self.experts_start_idx
        safe_ids = local_ids.clamp(0, self.num_experts - 1)

        # Build slot mapping (same as _run_impl)
        flat_ids = safe_ids.reshape(-1)
        num_slots = num_tokens * top_k
        token_indices = self._token_indices[:num_slots]
        sort_idx = flat_ids.argsort(stable=True)
        sorted_ids = flat_ids[sort_idx]
        sorted_token_ids = token_indices[sort_idx]
        slot_hidden = hidden_states_sample[sorted_token_ids]

        # L1: global scale as reported by quantize_to_nvfp4
        _, _, l1_gs = quantize_to_nvfp4(slot_hidden)

        # Quantize slot_hidden for the GEMM with that scale
        slot_x_fp4, slot_x_sf = quantize_activation_nvfp4(slot_hidden, l1_gs)

        # Real per-expert row counts and their cumulative offsets
        expert_id_range = self._expert_id_range
        tokens_per_expert = (sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)).sum(dim=0).int()
        expert_offsets = self._expert_offsets_buf
        expert_offsets.zero_()
        expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0)

        # Each expert's section is rounded up to a multiple of 128 rows
        padded_tokens_per_expert = ((tokens_per_expert + 127) // 128) * 128
        padded_expert_offsets = self._padded_expert_offsets_buf
        padded_expert_offsets.zero_()
        padded_expert_offsets[1:self.num_experts + 1] = padded_tokens_per_expert.cumsum(0)

        # Destination row of every slot in the padded layout (same as _run_impl)
        row_indices = self._row_indices_buf[:num_slots]
        expert_assign = torch.searchsorted(
            expert_offsets[1:], row_indices, right=True
        ).clamp(max=self.num_experts - 1)
        local_row = row_indices - expert_offsets[expert_assign]
        padded_dst = padded_expert_offsets[expert_assign] + local_row

        # Scatter x_fp4 into padded layout (via a uint8 view of the buffer)
        padded_x_fp4 = self._shared_bufs['hidden_fp4']
        padded_x_fp4.view(torch.uint8).zero_()
        padded_x_fp4.view(torch.uint8)[padded_dst] = slot_x_fp4.view(torch.uint8)

        l1_scale_a = self._assemble_scales_cudagraph_safe(
            slot_x_sf, expert_offsets[:self.num_experts + 1],
            padded_expert_offsets,
            self._padded_x_sf_buf_l1, self._per_expert_scale_bufs_l1
        )
        l1_gsa = torch.full((self.num_experts,), l1_gs, dtype=torch.float32, device=device)

        l1_out = run_nvfp4_grouped_gemm(
            mat_a=padded_x_fp4, mat_b=self._l1_mat_b,
            scale_a=l1_scale_a, scale_b=self._l1_scale_b,
            expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
            global_scale_a=l1_gsa, global_scale_b=self._l1_gsb,
        )

        # Extract real token outputs
        l1_out_real = l1_out[padded_dst]

        # L2: global scale of SiLU(gate) * up, with the optional clamp
        gate = l1_out_real[:, :self.intermediate_size]
        up = l1_out_real[:, self.intermediate_size:]
        gate_silu = torch.nn.functional.silu(gate)
        if self._swiglu_limit is not None:
            gate_silu = gate_silu.clamp(max=self._swiglu_limit)
            up = up.clamp(min=-self._swiglu_limit, max=self._swiglu_limit)
        activated = gate_silu * up
        _, _, l2_gs = quantize_to_nvfp4(activated)

        self._l1_activation_global_scale = l1_gs
        self._l2_activation_global_scale = l2_gs
|
||||
|
||||
|
||||
|
||||
def run(self, hidden_states, topk_weights, topk_ids, expert_indices=None):
    """Public forward entry point.

    Dispatches through ``_MoEApply.apply`` with this runner as the first
    argument, which in turn calls ``_run_impl`` with the same inputs.
    """
    forward_inputs = (hidden_states, topk_weights, topk_ids, expert_indices)
    return _MoEApply.apply(self, *forward_inputs)
|
||||
|
||||
def _run_impl(self, hidden_states, topk_weights, topk_ids, expert_indices=None):
    """Run the NVFP4 MoE forward pass.

    Steps: remap expert IDs to local IDs, sort the (token, expert) slots by
    expert, quantize and scatter the gathered activations into a layout in
    which every expert's rows are padded to a multiple of 128, run the L1
    grouped GEMM, apply SiLU(gate) * up (optionally clamped), quantize and
    run the L2 grouped GEMM, then weight and scatter-add the per-slot
    outputs back to their tokens.

    Slots whose expert ID falls outside this runner's range are clamped to a
    valid local expert and given weight zero, so they are still computed but
    add nothing to the output.

    The original author describes this path as cudagraph-safe (no CPU-GPU
    syncs, no dynamic shapes).

    Args:
        hidden_states: Input activations, one row per token.
        topk_weights: Routing weights, same shape as ``topk_ids``.
        topk_ids: Expert IDs per token, shape ``(num_tokens, top_k)``.
        expert_indices: Accepted but not used in this method.

    Returns:
        A slice of the pre-allocated output buffer with one row per token.
        The buffer is reused across calls, so the result is overwritten by
        the next forward pass that uses the same buffer.
    """
    num_tokens = hidden_states.shape[0]
    top_k = topk_ids.shape[1]
    device = hidden_states.device  # NOTE(review): not used below

    self._ensure_stacked()

    # -- Remap global expert IDs to local IDs --
    local_ids = topk_ids - self.experts_start_idx
    local_mask = (local_ids >= 0) & (local_ids < self.num_experts)
    # Out-of-range IDs are clamped so they index a valid expert; their
    # routing weight is zeroed so they contribute nothing to the output.
    safe_ids = local_ids.clamp(0, self.num_experts - 1)
    safe_weights = topk_weights * local_mask.float()

    # -- Build slot mapping --
    flat_ids = safe_ids.reshape(-1)
    flat_weights = safe_weights.reshape(-1)
    num_slots = num_tokens * top_k
    token_indices = self._token_indices[:num_slots]

    # Stable sort groups slots by expert while keeping token order within
    # each expert.
    sort_idx = flat_ids.argsort(stable=True)
    sorted_ids = flat_ids[sort_idx]
    sorted_weights = flat_weights[sort_idx]
    sorted_token_ids = token_indices[sort_idx]

    # Expert offsets (real token counts)
    expert_id_range = self._expert_id_range
    tokens_per_expert = (sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)).sum(dim=0).int()
    expert_offsets = self._expert_offsets_buf
    expert_offsets.zero_()
    expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0)

    # Pad each expert to 128-row alignment (GPU-only computation)
    padded_tokens_per_expert = ((tokens_per_expert + 127) // 128) * 128
    padded_expert_offsets = self._padded_expert_offsets_buf
    padded_expert_offsets.zero_()
    padded_expert_offsets[1:self.num_experts + 1] = padded_tokens_per_expert.cumsum(0)
    # NOTE(review): total_padded_slots is not used below. The padded
    # buffers must hold at least this many rows — confirm the capacity
    # derived from _max_chunks_per_expert covers uneven routing.
    total_padded_slots = padded_expert_offsets[self.num_experts]

    # -- Gather hidden states into slot order, compute padded_dst --
    slot_hidden = hidden_states[sorted_token_ids]
    row_indices = self._row_indices_buf[:num_slots]
    # For each sorted slot: which expert owns it, its row within that
    # expert, and therefore its row in the padded layout.
    expert_assign = torch.searchsorted(
        expert_offsets[1:], row_indices, right=True
    ).clamp(max=self.num_experts - 1)
    local_row = row_indices - expert_offsets[expert_assign]
    padded_dst = padded_expert_offsets[expert_assign] + local_row

    # === L1: gate + up ===
    # Quantize slot_hidden (the sorted, unpadded rows) rather than a padded
    # buffer: this yields exactly num_slots scale rows, one per slot, which
    # the scale assembly then scatters into the padded layout.
    slot_x_fp4, slot_x_sf = quantize_activation_nvfp4(
        slot_hidden, self._l1_activation_global_scale
    )
    # Scatter x_fp4 into padded layout for the GEMM.
    # Done through a uint8 view (original note: float4_e2m1fn_x2 doesn't
    # support index_put).
    padded_x_fp4 = self._shared_bufs['hidden_fp4']
    padded_x_fp4.view(torch.uint8).zero_()
    padded_x_fp4.view(torch.uint8)[padded_dst] = slot_x_fp4.view(torch.uint8)

    l1_scale_a = self._assemble_scales_cudagraph_safe(
        slot_x_sf, expert_offsets[:self.num_experts + 1],
        padded_expert_offsets,
        self._padded_x_sf_buf_l1, self._per_expert_scale_bufs_l1
    )
    # Same activation global scale for every expert, written in place.
    l1_gsa = self._l1_gsa_buf.fill_(self._l1_activation_global_scale)

    l1_out = run_nvfp4_grouped_gemm(
        mat_a=padded_x_fp4, mat_b=self._l1_mat_b,
        scale_a=l1_scale_a, scale_b=self._l1_scale_b,
        expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
        global_scale_a=l1_gsa, global_scale_b=self._l1_gsb,
    )

    # Extract real token outputs from padded GEMM output
    l1_out_real = l1_out[padded_dst]

    # === SiLU(gate) * up (with swiglu_limit clamp) ===
    # The first intermediate_size columns are the gate, the rest are up.
    gate = l1_out_real[:, :self.intermediate_size]
    up = l1_out_real[:, self.intermediate_size:]
    gate_silu = torch.nn.functional.silu(gate)
    # swiglu_limit: clamp silu(gate) from above and up on both sides
    if self._swiglu_limit is not None:
        gate_silu = gate_silu.clamp(max=self._swiglu_limit)
        up = up.clamp(min=-self._swiglu_limit, max=self._swiglu_limit)
    activated = gate_silu * up

    # === L2: down ===
    # Quantize activated (per-slot), scatter into padded FP4 buffer
    slot_l2_x_fp4, slot_l2_x_sf = quantize_activation_nvfp4(
        activated, self._l2_activation_global_scale
    )
    padded_activated_fp4 = self._shared_bufs['activated_fp4']
    padded_activated_fp4.view(torch.uint8).zero_()
    padded_activated_fp4.view(torch.uint8)[padded_dst] = slot_l2_x_fp4.view(torch.uint8)

    l2_scale_a = self._assemble_scales_cudagraph_safe(
        slot_l2_x_sf, expert_offsets[:self.num_experts + 1],
        padded_expert_offsets,
        self._padded_x_sf_buf_l2, self._per_expert_scale_bufs_l2
    )
    l2_gsa = self._l2_gsa_buf.fill_(self._l2_activation_global_scale)

    l2_out = run_nvfp4_grouped_gemm(
        mat_a=padded_activated_fp4, mat_b=self._l2_mat_b,
        scale_a=l2_scale_a, scale_b=self._l2_scale_b,
        expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
        global_scale_a=l2_gsa, global_scale_b=self._l2_gsb,
    )

    l2_out_real = l2_out[padded_dst]

    # === Scatter -> final output ===
    y = self._output_buf[:num_tokens]
    y.zero_()
    weighted_out = l2_out_real * sorted_weights.unsqueeze(1).to(l2_out_real.dtype)
    # Sum each token's top_k weighted expert outputs into its output row.
    # NOTE(review): sorted_token_ids derives from the int32 _token_indices
    # buffer; confirm the installed torch accepts an int32 scatter index.
    y.scatter_add_(
        0,
        sorted_token_ids.unsqueeze(1).expand(-1, self.hidden_size),
        weighted_out,
    )

    # Refill _token_indices once, after the first call.
    # (Original author's note: CuTeDSL's cute.compile may corrupt GPU
    # memory during the first GEMM.)
    if self._needs_token_refill:
        self._fill_token_indices()
        self._needs_token_refill = False

    return y
|
||||
File diff suppressed because it is too large
Load Diff
@@ -14,6 +14,7 @@ import torch.nn.functional as F
|
||||
from transformers import DeepseekV2Config, DeepseekV3Config
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.breakable_cudagraph import eager_break_during_capture
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ReplicatedLinear,
|
||||
)
|
||||
@@ -28,6 +29,7 @@ from vllm.v1.attention.ops.deepseek_v4_ops import (
|
||||
fused_inv_rope_fp8_quant,
|
||||
fused_q_kv_rmsnorm,
|
||||
)
|
||||
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import rocm_inv_rope_einsum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.attention.backends.mla.sparse_swa import (
|
||||
@@ -45,7 +47,7 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import PluggableLayer
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.layers.deepseek_compressor import DeepseekCompressor
|
||||
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import (
|
||||
QuantFP8,
|
||||
@@ -53,6 +55,7 @@ from vllm.model_executor.layers.quantization.input_quant_fp8 import (
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.multi_stream_utils import (
|
||||
execute_in_parallel,
|
||||
maybe_execute_in_parallel,
|
||||
@@ -198,8 +201,6 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
# Pick fp8_einsum recipe based on GPU arch:
|
||||
# SM90: FP32 block scales stay [g, r/128, d/128] → sfb_gran_mn=128
|
||||
# SM100: INT32 packed scales become [g, r, ...] → sfb_gran_mn=1
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
cap = current_platform.get_device_capability()
|
||||
assert cap is not None, "DeepseekV4 attention requires a CUDA device"
|
||||
self._einsum_recipe = (1, 128, 128) if cap.major <= 9 else (1, 1, 128)
|
||||
@@ -222,6 +223,7 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
+ 1 # 1B pad
|
||||
)
|
||||
|
||||
# Will be None on ROCm for now.
|
||||
self.aux_stream_list = mla_modules.aux_stream_list
|
||||
# [0]: GEMM start / post-GEMM event0. [1..3]: GEMM done events;
|
||||
# [1] doubles as post-GEMM event1. Reuse is safe: GEMM fully joins
|
||||
@@ -303,6 +305,19 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
)
|
||||
o = o_padded[:, : self.n_local_heads, :]
|
||||
|
||||
# Keep ROCm on the BF16 reference wo_a path until the kernel is ready.
|
||||
if current_platform.is_rocm():
|
||||
z = rocm_inv_rope_einsum(
|
||||
self.rotary_emb,
|
||||
o,
|
||||
positions,
|
||||
self.rope_head_dim,
|
||||
self.n_local_groups,
|
||||
self.o_lora_rank,
|
||||
self.wo_a,
|
||||
)
|
||||
return self.wo_b(z.flatten(1))
|
||||
|
||||
# O projection: inverse RoPE + FP8 quant + einsum + wo_b
|
||||
o_fp8, o_scale = fused_inv_rope_fp8_quant(
|
||||
o,
|
||||
@@ -336,12 +351,15 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
return self.wo_b(z.flatten(1))
|
||||
|
||||
def attn_gemm_parallel_execute(self, hidden_states) -> tuple[Any, ...]:
|
||||
assert self.aux_stream_list is not None
|
||||
assert len(self.aux_stream_list) >= 3
|
||||
aux_streams = self.aux_stream_list
|
||||
if aux_streams is not None:
|
||||
assert len(aux_streams) >= 3
|
||||
aux_streams = aux_streams[:3]
|
||||
|
||||
# fused_wqa_wkv (heaviest) on default; the three lighter input GEMMs
|
||||
# on aux streams 0..2 when their owning module exists. ln_events[0]
|
||||
# is the fan-out start event; ln_events[1..3] are per-aux done events.
|
||||
# On ROCm, aux_streams is None and execute_in_parallel runs serially.
|
||||
aux_fns: list[Callable[[], Any] | None] = [None, None, None]
|
||||
|
||||
if self.compressor is not None:
|
||||
@@ -385,7 +403,7 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
aux_fns,
|
||||
self.ln_events[0],
|
||||
self.ln_events[1:4],
|
||||
self.aux_stream_list[:3],
|
||||
aux_streams,
|
||||
enable=hidden_states.shape[0]
|
||||
<= envs.VLLM_MULTI_STREAM_GEMM_TOKEN_THRESHOLD,
|
||||
)
|
||||
@@ -419,8 +437,9 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
# downstream reads q on default). Indexer/compressor go on aux for
|
||||
# overlap with default's GEMM + cache write.
|
||||
if self.indexer is not None:
|
||||
assert self.aux_stream_list is not None
|
||||
aux_stream = self.aux_stream_list[0]
|
||||
aux_stream = (
|
||||
self.aux_stream_list[0] if self.aux_stream_list is not None else None
|
||||
)
|
||||
indexer = self.indexer
|
||||
# Local ref so the closure keeps a non-None type for mypy.
|
||||
assert self.compressor is not None
|
||||
@@ -448,8 +467,9 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
)
|
||||
elif self.compressor is not None:
|
||||
# wq_b + kv_insert on default, compressor on aux.
|
||||
assert self.aux_stream_list is not None
|
||||
aux_stream = self.aux_stream_list[0]
|
||||
aux_stream = (
|
||||
self.aux_stream_list[0] if self.aux_stream_list is not None else None
|
||||
)
|
||||
compressor = self.compressor
|
||||
|
||||
def wq_b_kv_insert() -> torch.Tensor:
|
||||
@@ -534,6 +554,7 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
)
|
||||
|
||||
|
||||
@eager_break_during_capture
|
||||
def deepseek_v4_attention(
|
||||
hidden_states: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
@@ -668,7 +689,7 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
|
||||
vllm_config.scheduler_config.max_num_batched_tokens
|
||||
)
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
# DeepseekV4 only supports fp8 kv-cache format for now
|
||||
# DeepseekV4 only supports fp8 kv-cache format for now.
|
||||
kv_cache_dtype = cache_config.cache_dtype if cache_config is not None else "fp8"
|
||||
|
||||
assert kv_cache_dtype.startswith("fp8"), (
|
||||
@@ -702,6 +723,12 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
|
||||
self.kv_cache = torch.tensor([])
|
||||
|
||||
def get_attn_backend(self) -> type[AttentionBackend]:
|
||||
if current_platform.is_rocm():
|
||||
from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse_dsv4 import (
|
||||
DeepseekV4ROCMAiterMLASparseBackend,
|
||||
)
|
||||
|
||||
return DeepseekV4ROCMAiterMLASparseBackend
|
||||
return DeepseekV4FlashMLASparseBackend
|
||||
|
||||
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
|
||||
@@ -734,6 +761,14 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
|
||||
f"output buffer dtype {output.dtype} must match q dtype {q.dtype}"
|
||||
)
|
||||
|
||||
if current_platform.is_rocm():
|
||||
from vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse_dsv4 import (
|
||||
DeepseekV4ROCMAiterMLASparseImpl,
|
||||
)
|
||||
|
||||
DeepseekV4ROCMAiterMLASparseImpl.forward(self, q, kv, positions, output)
|
||||
return
|
||||
|
||||
# Get SWA and indexer metadata from forward context
|
||||
forward_context = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
@@ -979,8 +1014,7 @@ class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase):
|
||||
M,
|
||||
N,
|
||||
)
|
||||
|
||||
output_chunk, _, _ = flash_mla_sparse_fwd(
|
||||
flash_mla_sparse_fwd(
|
||||
q=q[query_start:query_end],
|
||||
kv=kv.view(-1, 1, q.shape[-1]),
|
||||
indices=combined_indices.unsqueeze(1),
|
||||
@@ -1077,7 +1111,6 @@ class DeepseekV4Indexer(nn.Module):
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.weights_proj",
|
||||
)
|
||||
self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
|
||||
self.softmax_scale = self.head_dim**-0.5
|
||||
|
||||
self.scale_fmt = "ue8m0"
|
||||
|
||||
303
vllm/patches/fused_moe/experts/cutedsl_moe.py
Normal file
303
vllm/patches/fused_moe/experts/cutedsl_moe.py
Normal file
@@ -0,0 +1,303 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""CuTeDSL NVFP4 MoE experts for DeepSeek-V4.
|
||||
|
||||
Integrates the CuTeDSL NVFP4 grouped GEMM kernel into vLLM's FusedMoE
|
||||
modular kernel framework. This is the proper integration path — no
|
||||
monkey-patching, no post-load surgery.
|
||||
|
||||
The CuTeDSL kernel is a Python-based CUTLASS kernel compiled via MLIR → PTX.
|
||||
It handles:
|
||||
- L1 GEMM (gate + up projections)
|
||||
- SiLU activation with optional swiglu_limit clamping
|
||||
- L2 GEMM (down projection)
|
||||
- All with NVFP4 (float8_e4m3fn block scales + float32 global scales)
|
||||
|
||||
CUDA-graph-safe: all intermediate buffers pre-allocated, no CPU-GPU syncs,
|
||||
no dynamic shapes.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
FusedMoEQuantConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceNoOP,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kNvfp4Dynamic,
|
||||
kNvfp4Static,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from cutedsl.runner import CuTeDSLMoERunner
|
||||
|
||||
|
||||
class CuTeDSLMoEExperts(mk.FusedMoEExpertsModular):
    """CuTeDSL NVFP4 MoE experts using the custom CuTeDSL grouped GEMM kernel.

    Uses Standard activation format (non-batched). Handles input quantization
    internally (expects_unquantized_inputs=True).

    Expert parallelism: the first global expert ID owned by this rank
    (``ep_rank * num_local_experts``) is handed to the runner as
    ``experts_start_idx``.
    """

    def __init__(
        self,
        moe_config: FusedMoEConfig,
        quant_config: FusedMoEQuantConfig,
    ):
        """Record the MoE problem sizes; the runner itself is built later.

        Args:
            moe_config: Layer-level MoE configuration (sizes, parallelism).
            quant_config: Quantization configuration; must be ``"nvfp4"``.
        """
        super().__init__(
            moe_config=moe_config,
            quant_config=quant_config,
        )
        assert quant_config.quant_dtype == "nvfp4", (
            "CuTeDSL MoE only supports nvfp4 quantization, "
            f"got {quant_config.quant_dtype}"
        )
        self.out_dtype = moe_config.in_dtype
        self.hidden_dim = moe_config.hidden_dim
        self.intermediate_size_per_partition = (
            moe_config.intermediate_size_per_partition
        )
        self.topk = moe_config.experts_per_token
        self.local_num_experts = moe_config.num_local_experts
        self.global_num_experts = moe_config.num_experts
        self.ep_rank = moe_config.moe_parallel_config.ep_rank
        self.local_expert_offset = self.ep_rank * self.local_num_experts
        # Upper bound used by the runner for buffer pre-allocation; falls back
        # to 8192 when the config does not expose max_num_tokens.
        self.max_num_tokens = getattr(moe_config, 'max_num_tokens', 8192)

        # swiglu_limit: read from the FusedMoE layer in
        # process_weights_after_loading.
        self._swiglu_limit = None

        # Runner is created in process_weights_after_loading.
        self._runner: CuTeDSLMoERunner | None = None

    @staticmethod
    def _original_input_scale(inverted_scale):
        """Undo the ``1.0 / scale`` applied by convert_to_nvfp4_moe_kernel_format.

        Returns None when the layer carries no tensor-valued input scale.
        """
        if inverted_scale is None or isinstance(inverted_scale, float):
            return None
        return 1.0 / inverted_scale

    @staticmethod
    def _to_cutedsl_layout(
        packed_weight: torch.Tensor,
        block_scales: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Transpose one expert's packed weight and block scales for the runner.

        The uint8 weight is reinterpreted as float4_e2m1fn_x2 and both tensors
        are permuted from (N, K*) to a contiguous (K*, N); block scales are
        cast to float8_e4m3fn if they are not already.
        """
        weight = packed_weight.view(torch.float4_e2m1fn_x2).permute(1, 0).contiguous()
        scales = block_scales.permute(1, 0).contiguous()
        if scales.dtype != torch.float8_e4m3fn:
            scales = scales.to(torch.float8_e4m3fn)
        return weight, scales

    @staticmethod
    def _initial_activation_global_scale(input_scale_orig) -> float | None:
        """Return ``1 / mean(input_scale)`` or None if unavailable/non-positive."""
        if input_scale_orig is None:
            return None
        # Mean across experts (they should be similar).
        mean_scale = float(input_scale_orig.mean().item())
        if mean_scale > 0:
            return 1.0 / mean_scale
        return None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert NVFP4 MoE weights into CuTeDSL kernel format.

        Reads the w13/w2 weight, block-scale and global-scale tensors from the
        FusedMoE layer, converts them to the layout the CuTeDSL runner is fed
        with, and creates the runner.

        The activation global scale (input_scale) is NOT folded into the
        weight global scale (weight_scale_2): weight global scales are passed
        to the runner as-is, and the checkpoint input scales only seed the
        runner's initial activation global scales.

        Args:
            layer: The FusedMoE layer owning ``w13_*`` / ``w2_*`` tensors and,
                optionally, a ``swiglu_limit`` attribute.
        """
        num_experts = layer.w13_weight.shape[0]
        device = layer.w13_weight.device

        # convert_to_nvfp4_moe_kernel_format already inverted input_scale
        # (1.0 / a13_scale). Undo that inversion to recover the checkpoint
        # value, which seeds the activation global scale below.
        w13_input_scale_orig = self._original_input_scale(layer.w13_input_scale)
        w2_input_scale_orig = self._original_input_scale(layer.w2_input_scale)

        # Per-expert weight global scales. A single .tolist() costs one
        # device->host transfer per tensor instead of one .item() sync per
        # expert; reshape() also enforces exactly one scalar per expert.
        l1_gs_list = layer.w13_weight_scale_2.data.reshape(num_experts).tolist()
        l2_gs_list = layer.w2_weight_scale_2.data.reshape(num_experts).tolist()

        # w13_weight: (E, 2*intermediate, hidden//2) uint8 -- gate + up fused
        # w2_weight:  (E, hidden, intermediate//2) uint8 -- down
        l1_fp4_list = []
        l1_sf_list = []
        l2_fp4_list = []
        l2_sf_list = []
        for expert_id in range(num_experts):
            l1_w, l1_s = self._to_cutedsl_layout(
                layer.w13_weight.data[expert_id],
                layer.w13_weight_scale.data[expert_id],
            )
            l1_fp4_list.append(l1_w)
            l1_sf_list.append(l1_s)

            l2_w, l2_s = self._to_cutedsl_layout(
                layer.w2_weight.data[expert_id],
                layer.w2_weight_scale.data[expert_id],
            )
            l2_fp4_list.append(l2_w)
            l2_sf_list.append(l2_s)

        # Create the CuTeDSL runner.
        self._runner = CuTeDSLMoERunner(
            num_experts=num_experts,
            hidden_size=self.hidden_dim,
            intermediate_size=self.intermediate_size_per_partition,
            max_num_tokens=self.max_num_tokens,
            top_k=self.topk,
            device=str(device),
            experts_start_idx=self.local_expert_offset,
        )
        self._runner.prepare_weights_direct(
            l1_fp4_list, l1_sf_list, l1_gs_list,
            l2_fp4_list, l2_sf_list, l2_gs_list,
        )

        # swiglu_limit: the value on the FusedMoE layer wins; otherwise keep
        # whatever was already stored on this object.
        swiglu_limit = getattr(layer, 'swiglu_limit', None)
        if swiglu_limit is None:
            swiglu_limit = self._swiglu_limit
        if swiglu_limit is not None:
            self._swiglu_limit = swiglu_limit
            self._runner.set_swiglu_limit(float(swiglu_limit))

        # Seed the activation global scales with 1 / mean(checkpoint
        # input_scale). A later warmup step may override these values before
        # the first inference.
        l1_act_gs = self._initial_activation_global_scale(w13_input_scale_orig)
        if l1_act_gs is not None:
            self._runner._l1_activation_global_scale = l1_act_gs
        l2_act_gs = self._initial_activation_global_scale(w2_input_scale_orig)
        if l2_act_gs is not None:
            self._runner._l2_activation_global_scale = l2_act_gs

    @property
    def runner(self) -> CuTeDSLMoERunner | None:
        """The CuTeDSL runner, or None before process_weights_after_loading."""
        return self._runner

    @staticmethod
    def activation_format() -> mk.FusedMoEActivationFormat:
        return mk.FusedMoEActivationFormat.Standard

    @staticmethod
    def _supports_current_device() -> bool:
        # CuTeDSL requires CUDA SM100 (Blackwell)
        p = current_platform
        return p.is_cuda() and p.is_device_capability_family(100)

    @staticmethod
    def _supports_no_act_and_mul() -> bool:
        return False

    @staticmethod
    def _supports_quant_scheme(
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        SUPPORTED_W_A = [
            (kNvfp4Static, kNvfp4Dynamic),
        ]
        return (weight_key, activation_key) in SUPPORTED_W_A

    @staticmethod
    def _supports_activation(activation: MoEActivation) -> bool:
        # We handle SiLU + swiglu_limit internally
        return activation == MoEActivation.SILU

    @staticmethod
    def _supports_parallel_config(
        moe_parallel_config: FusedMoEParallelConfig,
    ) -> bool:
        return True

    def supports_expert_map(self) -> bool:
        return False

    @property
    def expects_unquantized_inputs(self) -> bool:
        # Our runner handles activation quantization internally
        return True

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        return TopKWeightAndReduceNoOP()

    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        activation: MoEActivation,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        """Return (workspace1, workspace2, output) shapes for the framework.

        Both workspaces are empty because the runner manages its own
        pre-allocated buffers; the output is always (M, hidden_dim).
        """
        workspace1 = (0,)
        workspace2 = (0,)
        # expects_unquantized_inputs is True, so K is presumably the unpacked
        # hidden size here (TODO confirm against the modular-kernel caller).
        # The packed convention (K == hidden_dim // 2 for uint8 FP4 inputs) is
        # still accepted so either caller keeps working.
        assert K == self.hidden_dim or K * 2 == self.hidden_dim, (
            f"Unexpected K={K} for hidden_dim={self.hidden_dim}"
        )
        output = (M, self.hidden_dim)
        return (workspace1, workspace2, output)

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor | None,
        workspace2: torch.Tensor | None,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool | None,
    ):
        """Run the CuTeDSL runner and copy its result into ``output``.

        Only ``hidden_states``, ``topk_weights`` and ``topk_ids`` are forwarded
        to the runner; the weights were handed over in
        process_weights_after_loading and the remaining arguments are unused.
        """
        assert self._runner is not None, (
            "CuTeDSL runner not initialized. "
            "Call process_weights_after_loading first."
        )

        # topk_ids are forwarded untouched. The runner was constructed with
        # experts_start_idx, so it is presumably responsible for mapping
        # global expert IDs onto this rank's local experts (expert_map is not
        # consulted; supports_expert_map() is False).
        result = self._runner.run(
            hidden_states=hidden_states,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
        )

        # Copy result into output tensor
        output.copy_(result)
|
||||
535
vllm/patches/fused_moe/oracle/nvfp4.py
Normal file
535
vllm/patches/fused_moe/oracle/nvfp4.py
Normal file
@@ -0,0 +1,535 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.config.kernel import MoEBackend
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
||||
maybe_make_prepare_finalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEQuantConfig,
|
||||
nvfp4_moe_quant_config,
|
||||
nvfp4_w4a16_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
|
||||
prepare_nvfp4_moe_layer_for_fi_or_cutlass,
|
||||
prepare_nvfp4_moe_layer_for_flashinfer_cutedsl,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
FlashinferMoeBackend,
|
||||
get_flashinfer_moe_backend,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
prepare_nvfp4_moe_layer_for_marlin,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (
|
||||
kE2M1ToFloat_handle,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class NvFp4MoeBackend(Enum):
    """Kernel backends that can be selected for NVFP4 fused-MoE layers.

    Each member's value is its own name as a string; the value is what the
    selection log messages in this module print.
    """

    FLASHINFER_TRTLLM = "FLASHINFER_TRTLLM"
    FLASHINFER_CUTLASS = "FLASHINFER_CUTLASS"
    FLASHINFER_CUTEDSL = "FLASHINFER_CUTEDSL"
    FLASHINFER_CUTEDSL_BATCHED = "FLASHINFER_CUTEDSL_BATCHED"
    VLLM_CUTLASS = "VLLM_CUTLASS"
    MARLIN = "MARLIN"
    # Custom CuTeDSL grouped-GEMM backend (resolved to CuTeDSLMoEExperts by
    # backend_to_kernel_cls); distinct from the FlashInfer CuTeDSL backends.
    CUTEDSL = "CUTEDSL"
    EMULATION = "EMULATION"
|
||||
|
||||
|
||||
# All backends implemented on top of FlashInfer. select_nvfp4_moe_backend
# removes these from the candidate list when VLLM_USE_FLASHINFER_MOE_FP4 is
# explicitly disabled, and restricts the search to them when it is enabled.
FLASHINFER_NVFP4_MOE_BACKENDS = [
    NvFp4MoeBackend.FLASHINFER_TRTLLM,
    NvFp4MoeBackend.FLASHINFER_CUTLASS,
    NvFp4MoeBackend.FLASHINFER_CUTEDSL,
    NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED,
]

# Backends served by the custom (non-FlashInfer) CuTeDSL kernel.
CUTEDSL_NVFP4_MOE_BACKENDS = [
    NvFp4MoeBackend.CUTEDSL,
]

# Translation from the FlashInfer backend enum (as returned by
# get_flashinfer_moe_backend) to this module's backend enum.
fi_2_vllm_backend_map: dict[FlashinferMoeBackend, NvFp4MoeBackend] = {
    FlashinferMoeBackend.CUTLASS: NvFp4MoeBackend.FLASHINFER_CUTLASS,
    FlashinferMoeBackend.TENSORRT_LLM: NvFp4MoeBackend.FLASHINFER_TRTLLM,
    FlashinferMoeBackend.CUTEDSL: NvFp4MoeBackend.FLASHINFER_CUTEDSL,
}
|
||||
|
||||
|
||||
def is_global_sf_supported_for_nvfp4_backend(backend: NvFp4MoeBackend) -> bool:
    """Tell whether `backend` can quantize with the scaling factors of all
    experts in Expert Parallel mode, i.e. when the experts are not all on
    the same rank.

    True for every FlashInfer backend and for the custom CuTeDSL backend.
    """
    global_sf_backends = (
        *FLASHINFER_NVFP4_MOE_BACKENDS,
        *CUTEDSL_NVFP4_MOE_BACKENDS,
    )
    return backend in global_sf_backends
|
||||
|
||||
|
||||
def backend_to_kernel_cls(
    backend: NvFp4MoeBackend,
) -> list[type[mk.FusedMoEExperts]]:
    """Resolve a backend to its candidate expert classes, best first.

    Each experts module is imported lazily, inside the branch that needs it,
    so only the selected backend's module is loaded.

    Raises:
        ValueError: if `backend` is not a known NvFp4MoeBackend member.
    """
    if backend is NvFp4MoeBackend.FLASHINFER_TRTLLM:
        from vllm.model_executor.layers.fused_moe.experts.trtllm_nvfp4_moe import (
            TrtLlmNvFp4ExpertsModular,
            TrtLlmNvFp4ExpertsMonolithic,
        )

        # NOTE: prefer Monolithic > Modular, so return Monolithic first.
        return [
            TrtLlmNvFp4ExpertsMonolithic,
            TrtLlmNvFp4ExpertsModular,
        ]

    if backend is NvFp4MoeBackend.FLASHINFER_CUTLASS:
        from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutlass_moe import (  # noqa: E501
            FlashInferExperts,
        )

        return [FlashInferExperts]

    if backend is NvFp4MoeBackend.FLASHINFER_CUTEDSL:
        from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutedsl_moe import (  # noqa: E501
            FlashInferCuteDSLExperts,
        )

        return [FlashInferCuteDSLExperts]

    if backend is NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED:
        from vllm.model_executor.layers.fused_moe.experts.flashinfer_cutedsl_batched_moe import (  # noqa: E501
            FlashInferCuteDSLBatchedExperts,
        )

        return [FlashInferCuteDSLBatchedExperts]

    if backend is NvFp4MoeBackend.CUTEDSL:
        from vllm.model_executor.layers.fused_moe.experts.cutedsl_moe import (  # noqa: E501
            CuTeDSLMoEExperts,
        )

        return [CuTeDSLMoEExperts]

    if backend is NvFp4MoeBackend.VLLM_CUTLASS:
        from vllm.model_executor.layers.fused_moe.experts.cutlass_moe import (
            CutlassExpertsFp4,
        )

        return [CutlassExpertsFp4]

    if backend is NvFp4MoeBackend.MARLIN:
        from vllm.model_executor.layers.fused_moe.experts.marlin_moe import (
            MarlinExperts,
        )

        return [MarlinExperts]

    if backend is NvFp4MoeBackend.EMULATION:
        from vllm.model_executor.layers.fused_moe.experts.nvfp4_emulation_moe import (
            Nvfp4QuantizationEmulationTritonExperts,
        )

        return [Nvfp4QuantizationEmulationTritonExperts]

    raise ValueError(f"Unknown NvFP4 MoE backend: {backend.value}")
|
||||
|
||||
|
||||
def map_nvfp4_backend(runner_backend: MoEBackend) -> NvFp4MoeBackend:
    """Translate the user-facing ``moe_backend`` setting into an NvFp4MoeBackend.

    Raises:
        ValueError: if `runner_backend` has no NVFP4 counterpart; the message
            lists the accepted names.
    """
    by_name = {
        "cutlass": NvFp4MoeBackend.VLLM_CUTLASS,
        "flashinfer_trtllm": NvFp4MoeBackend.FLASHINFER_TRTLLM,
        "flashinfer_cutlass": NvFp4MoeBackend.FLASHINFER_CUTLASS,
        "flashinfer_cutedsl": NvFp4MoeBackend.FLASHINFER_CUTEDSL,
        "cutedsl": NvFp4MoeBackend.CUTEDSL,
        "marlin": NvFp4MoeBackend.MARLIN,
        "emulation": NvFp4MoeBackend.EMULATION,
    }
    selected = by_name.get(runner_backend)
    if not selected:
        raise ValueError(
            f"moe_backend='{runner_backend}' is not supported for NvFP4 MoE. "
            f"Expected one of {list(by_name)}."
        )
    return selected
|
||||
|
||||
|
||||
def select_nvfp4_moe_backend(
    config: FusedMoEConfig,
    weight_key: QuantKey | None,
    activation_key: QuantKey | None,
) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEExperts]]:
    """
    Select the primary NvFP4 MoE backend
    Note: Shape-specific fallbacks may still occur at runtime.

    Precedence, highest first:
      1. An explicit ``config.moe_backend`` (anything other than "auto").
      2. The ``VLLM_USE_FLASHINFER_MOE_FP4`` /
         ``VLLM_FLASHINFER_MOE_BACKEND`` environment settings.
      3. ``VLLM_TEST_FORCE_FP8_MARLIN``.
      4. The first entry of ``AVAILABLE_BACKENDS`` whose expert class reports
         the configuration as supported.

    Returns:
        The chosen backend together with the expert class that accepted the
        configuration.

    Raises:
        ValueError: an explicitly requested backend does not support the
            configuration (or the requested name is unknown).
        NotImplementedError: no candidate backend supports the configuration.
    """

    # NOTE: the kernels are selected in the following order.
    AVAILABLE_BACKENDS = [
        NvFp4MoeBackend.FLASHINFER_TRTLLM,
        NvFp4MoeBackend.FLASHINFER_CUTEDSL,
        NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED,
        NvFp4MoeBackend.CUTEDSL,
        NvFp4MoeBackend.FLASHINFER_CUTLASS,
        NvFp4MoeBackend.VLLM_CUTLASS,
        NvFp4MoeBackend.MARLIN,
        NvFp4MoeBackend.EMULATION,
    ]

    # The activation format is dictated by the parallel config and is passed
    # to every is_supported_config() check below.
    use_batched = config.moe_parallel_config.use_batched_activation_format
    activation_format = (
        mk.FusedMoEActivationFormat.BatchedExperts
        if use_batched
        else mk.FusedMoEActivationFormat.Standard
    )

    def _make_log_backend(backend: NvFp4MoeBackend):
        # Message logged once a backend has been chosen. Reads
        # AVAILABLE_BACKENDS at call time, so it reflects any removals below.
        available_backend_strs = [b.value for b in AVAILABLE_BACKENDS]
        return (
            f"Using '{backend.value}' NvFp4 MoE backend out "
            f"of potential backends: {available_backend_strs}."
        )

    def _make_log_unsupported(backend: NvFp4MoeBackend, reason: str | None) -> str:
        # Message for a rejected backend, with the reason when one is given.
        if reason:
            return (
                f"NvFp4 MoE backend '{backend.value}' does not support the "
                f"deployment configuration since {reason}."
            )
        else:
            return (
                f"NvFp4 MoE backend '{backend.value}' does not support the "
                "deployment configuration."
            )

    def _return_or_raise(
        backend: NvFp4MoeBackend,
        config: FusedMoEConfig,
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
        activation_format: mk.FusedMoEActivationFormat,
    ) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEExperts]]:
        # Used when the backend is mandated (by the user or an env var): return
        # its first supporting expert class, or fail with the reason reported
        # by the last class that was tried.
        for k_cls in backend_to_kernel_cls(backend):
            supported, reason = k_cls.is_supported_config(
                k_cls, config, weight_key, activation_key, activation_format
            )
            if supported:
                logger.info_once(_make_log_backend(backend))
                return backend, k_cls

        raise ValueError(_make_log_unsupported(backend, reason))

    # Handle explicit moe_backend from user.
    runner_backend = config.moe_backend
    if runner_backend != "auto":
        requested_backend = map_nvfp4_backend(runner_backend)
        # For batched activation format, use batched variant if available.
        if (
            activation_format == mk.FusedMoEActivationFormat.BatchedExperts
            and requested_backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL
        ):
            requested_backend = NvFp4MoeBackend.FLASHINFER_CUTEDSL_BATCHED
        return _return_or_raise(
            requested_backend, config, weight_key, activation_key, activation_format
        )

    if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP4"):
        if not envs.VLLM_USE_FLASHINFER_MOE_FP4:
            # If the user rejects FlashInfer remove those backends.
            for b in FLASHINFER_NVFP4_MOE_BACKENDS:
                AVAILABLE_BACKENDS.remove(b)

        elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"):
            # If user is explicit about backend, validate it.
            backend = fi_2_vllm_backend_map[get_flashinfer_moe_backend()]
            return _return_or_raise(
                backend, config, weight_key, activation_key, activation_format
            )
        else:
            # If the user is not explicit about the backend, try each.
            for backend in FLASHINFER_NVFP4_MOE_BACKENDS:
                for k_cls in backend_to_kernel_cls(backend):
                    supported, reason = k_cls.is_supported_config(
                        k_cls,
                        config,
                        weight_key,
                        activation_key,
                        activation_format,
                    )
                    if supported:
                        logger.info_once(_make_log_backend(backend))
                        return backend, k_cls
                    else:
                        logger.debug_once(_make_log_unsupported(backend, reason))

            # FlashInfer was explicitly requested, so do not fall back to the
            # non-FlashInfer backends.
            raise NotImplementedError(
                "Found VLLM_USE_FLASHINFER_MOE_FP4=1, but no "
                "FlashInfer NVFP4 MoE backend supports the configuration."
            )

    if envs.VLLM_TEST_FORCE_FP8_MARLIN:
        backend = NvFp4MoeBackend.MARLIN
        return _return_or_raise(
            backend, config, weight_key, activation_key, activation_format
        )

    # Select kernels in order of backend.
    for backend in AVAILABLE_BACKENDS:
        for k_cls in backend_to_kernel_cls(backend):
            supported, reason = k_cls.is_supported_config(
                k_cls,
                config,
                weight_key,
                activation_key,
                activation_format,
            )
            if supported:
                logger.info_once(_make_log_backend(backend))
                return backend, k_cls
            else:
                logger.debug_once(_make_log_unsupported(backend, reason))

    raise NotImplementedError(
        "No NvFp4 MoE backend supports the deployment configuration."
    )
|
||||
|
||||
|
||||
def convert_to_nvfp4_moe_kernel_format(
    nvfp4_backend: NvFp4MoeBackend,
    layer: torch.nn.Module,
    w13: torch.Tensor,
    w13_scale: torch.Tensor,
    w13_scale_2: torch.Tensor,
    a13_scale: torch.Tensor | None,
    w2: torch.Tensor,
    w2_scale: torch.Tensor,
    w2_scale_2: torch.Tensor,
    a2_scale: torch.Tensor | None,
    is_act_and_mul: bool,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Convert loaded NVFP4 MoE tensors into the layout `nvfp4_backend` needs.

    Dispatches on the backend; every branch rebinds some or all of the weight
    and scale arguments, and the (possibly updated) values are returned.

    Returns:
        ``(w13, w13_scale, w13_scale_2, a13_scale, w2, w2_scale, w2_scale_2,
        a2_scale)``. For MARLIN both activation scales are returned as None,
        despite the annotated return type.

    Raises:
        ValueError: for an unknown backend, or for EMULATION when either
            activation scale is None.
    """
    if nvfp4_backend == NvFp4MoeBackend.CUTEDSL:
        # CuTeDSL kernel handles weight conversion in its own
        # process_weights_after_loading. Pass through raw weights.
        # Compute inverse activation scales for the quant config.
        # NOTE(review): make_nvfp4_moe_quant_config applies `1.0 / a13_scale`
        # again for this backend (it is neither MARLIN nor EMULATION), so the
        # quant config appears to end up with the un-inverted value -- confirm
        # that is intended.
        if a13_scale is not None:
            a13_scale = 1.0 / a13_scale
        if a2_scale is not None:
            a2_scale = 1.0 / a2_scale
    elif nvfp4_backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL:
        # Must be tested before the generic FlashInfer branch below, because
        # FLASHINFER_CUTEDSL is also a member of FLASHINFER_NVFP4_MOE_BACKENDS.
        (
            w13,
            w13_scale,
            w13_scale_2,
            a13_scale,
            w2,
            w2_scale,
            w2_scale_2,
            a2_scale,
        ) = prepare_nvfp4_moe_layer_for_flashinfer_cutedsl(
            layer=layer,
            w13=w13,
            w13_scale=w13_scale,
            w13_scale_2=w13_scale_2,
            a13_scale=a13_scale,
            w2=w2,
            w2_scale=w2_scale,
            w2_scale_2=w2_scale_2,
            a2_scale=a2_scale,
        )
    elif (
        nvfp4_backend in FLASHINFER_NVFP4_MOE_BACKENDS
        or nvfp4_backend == NvFp4MoeBackend.VLLM_CUTLASS
    ):
        (
            w13,
            w13_scale,
            w13_scale_2,
            a13_scale,
            w2,
            w2_scale,
            w2_scale_2,
            a2_scale,
        ) = prepare_nvfp4_moe_layer_for_fi_or_cutlass(
            backend=nvfp4_backend,
            layer=layer,
            w13=w13,
            w13_scale=w13_scale,
            w13_scale_2=w13_scale_2,
            a13_scale=a13_scale,
            w2=w2,
            w2_scale=w2_scale,
            w2_scale_2=w2_scale_2,
            a2_scale=a2_scale,
            is_act_and_mul=is_act_and_mul,
        )
    elif nvfp4_backend == NvFp4MoeBackend.MARLIN:
        # Activation scales are dropped for Marlin; make_nvfp4_moe_quant_config
        # builds a w4a16 config for this backend.
        a13_scale = None
        a2_scale = None
        (
            w13,
            w13_scale,
            w13_scale_2,
            w2,
            w2_scale,
            w2_scale_2,
        ) = prepare_nvfp4_moe_layer_for_marlin(
            layer=layer,
            w13=w13,
            w13_scale=w13_scale,
            w13_scale_2=w13_scale_2,
            w2=w2,
            w2_scale=w2_scale,
            w2_scale_2=w2_scale_2,
            is_act_and_mul=is_act_and_mul,
        )
    elif nvfp4_backend == NvFp4MoeBackend.EMULATION:
        # Move the E2M1 lookup table to the device now, because
        # `.to(device)` is not allowed during CUDA graph capture.
        kE2M1ToFloat_handle.val = kE2M1ToFloat_handle.val.to(w13.device)

        if a13_scale is None or a2_scale is None:
            raise ValueError(
                "Activation global scales should not be None, got"
                f" a13_scale={a13_scale}, a2_scale={a2_scale}"
            )

        if torch.unique(a13_scale).numel() != 1 or torch.unique(a2_scale).numel() != 1:
            logger.warning_once(
                "In NVFP4 linear, the activation global scale for inputs are different"
                " for MOE w13 (gate_up_proj) layer or MOE w2 (down_proj). Using"
                " a13_scale = a13_scale.max() and a2_scale = a2_scale.max()."
            )

        # 1. We take the max following e.g. quantization/utils/flashinfer_fp4_moe.py.
        # 2. moe_kernel_quantize_input -> ref_nvfp4_quant_dequant
        # use the inverse scale directly (large global scale).
        # NOTE: Before this point, `a13_scale` and `a2_scale` are such that:
        # `FP8_MAX = activation[expert_id].abs().max() * global_scale[expert_id]`,
        # and `global_scale[expert_id]` are small (~1e-4).
        # Taking the largest global scale likely results in overflowing the FP8 range
        # for other experts - other selection strategies may be used.
        a13_scale = 1.0 / a13_scale.max().to(torch.float32)
        a2_scale = 1.0 / a2_scale.max().to(torch.float32)
    else:
        raise ValueError(f"Unknown NvFp4 backend for MoE: {nvfp4_backend}")

    return (
        w13,
        w13_scale,
        w13_scale_2,
        a13_scale,
        w2,
        w2_scale,
        w2_scale_2,
        a2_scale,
    )
|
||||
|
||||
|
||||
def make_nvfp4_moe_quant_config(
    backend: NvFp4MoeBackend,
    w13_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    w13_scale_2: torch.Tensor,
    w2_scale_2: torch.Tensor,
    a13_scale: torch.Tensor,
    a2_scale: torch.Tensor,
) -> FusedMoEQuantConfig:
    """Build the FusedMoEQuantConfig matching `backend`.

    * MARLIN: a w4a16 config; the activation scales are not used.
    * EMULATION: the activation scales are passed through unchanged.
    * Every other backend: the activation scales are inverted, and
      ``is_scale_swizzled`` is False only for FLASHINFER_TRTLLM,
      FLASHINFER_CUTEDSL and CUTEDSL.
    """
    # Pass w13_scale_2 / w2_scale_2 directly as g1/g2_alphas.
    # The expert's process_weights_after_loading will fuse activation
    # scales in-place. Since the quant config references the same tensor
    # as the registered parameter, EPLB rearrangement stays in sync.
    weight_kwargs = {
        "g1_alphas": w13_scale_2,
        "g2_alphas": w2_scale_2,
        "w1_scale": w13_scale,
        "w2_scale": w2_scale,
    }

    if backend == NvFp4MoeBackend.MARLIN:
        return nvfp4_w4a16_moe_quant_config(**weight_kwargs)

    if backend == NvFp4MoeBackend.EMULATION:
        return nvfp4_moe_quant_config(
            a1_gscale=a13_scale,
            a2_gscale=a2_scale,
            **weight_kwargs,
        )

    # NOTE(rob): this is a hack until the MoE kernels
    # create their own quant configs. TRTLLM kernel
    # does not accept swizzled input quant scales.
    unswizzled_backends = (
        NvFp4MoeBackend.FLASHINFER_TRTLLM,
        NvFp4MoeBackend.FLASHINFER_CUTEDSL,
        NvFp4MoeBackend.CUTEDSL,
    )
    inv_a13_scale = 1.0 / a13_scale
    inv_a2_scale = 1.0 / a2_scale
    return nvfp4_moe_quant_config(
        a1_gscale=inv_a13_scale,
        a2_gscale=inv_a2_scale,
        is_scale_swizzled=backend not in unswizzled_backends,
        **weight_kwargs,
    )
|
||||
|
||||
|
||||
def make_nvfp4_moe_kernel(
    moe_quant_config: FusedMoEQuantConfig,
    moe_config: FusedMoEConfig,
    experts_cls: type[mk.FusedMoEExperts],
    routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEKernel:
    """Assemble a FusedMoEKernel from a prepare/finalize object and experts.

    The experts are instantiated from `experts_cls`; when the prepare/finalize
    object uses the BatchedExperts activation format they additionally receive
    ``max_num_tokens`` and ``num_dispatchers``.
    """
    # Step 1: prepare/finalize.
    is_monolithic = issubclass(experts_cls, mk.FusedMoEExpertsMonolithic)
    prepare_finalize = maybe_make_prepare_finalize(
        moe=moe_config,
        quant_config=moe_quant_config,
        routing_tables=routing_tables,
        allow_new_interface=True,
        use_monolithic=is_monolithic,
    )
    assert prepare_finalize is not None

    logger.info_once("Using %s", prepare_finalize.__class__.__name__)

    # Step 2: experts. Batched layouts need two extra constructor arguments.
    experts_kwargs = {
        "moe_config": moe_config,
        "quant_config": moe_quant_config,
    }
    batched_format = mk.FusedMoEActivationFormat.BatchedExperts
    if prepare_finalize.activation_format == batched_format:
        max_num_tokens = prepare_finalize.max_num_tokens_per_rank()
        assert max_num_tokens is not None
        experts_kwargs["max_num_tokens"] = max_num_tokens
        experts_kwargs["num_dispatchers"] = prepare_finalize.num_dispatchers()
    experts = experts_cls(**experts_kwargs)

    # Step 3: kernel.
    # TODO(rob): update inplace logic to be part of the kernel.
    return mk.FusedMoEKernel(
        prepare_finalize,
        experts,
        inplace=False,
    )
|
||||
2378
vllm/patches/modelopt.py
Normal file
2378
vllm/patches/modelopt.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user