588 lines
27 KiB
Python
588 lines
27 KiB
Python
"""
|
|
vLLM integration for the CuTeDSL NVFP4 MoE kernel.
|
|
|
|
CUDA-graph-compatible design:
|
|
- All intermediate buffers pre-allocated at max_num_tokens * top_k size
|
|
- No .item(), .tolist(), .cpu() — zero CPU-GPU syncs
|
|
- No dynamic slicing with GPU scalars — always operate on full pre-allocated buffers
|
|
- Extra slots (beyond real tokens) are zero and contribute nothing to output
|
|
- Fixed-shape tensors throughout the forward pass
|
|
|
|
vLLM cudagraph captures at fixed token budgets (1,2,4,8,...,8192).
|
|
During capture, num_tokens equals the budget — all shapes are fixed.
|
|
During replay, inputs are padded to the budget size. Our runner always
|
|
processes max_slots = budget * top_k rows; padding rows are zeros.
|
|
"""
|
|
import torch
|
|
|
|
from cutedsl.bridge import (
|
|
quantize_activation_nvfp4,
|
|
quantize_weight_to_nvfp4,
|
|
quantize_to_nvfp4,
|
|
make_b_k_major,
|
|
assemble_scales_3d_side,
|
|
run_nvfp4_grouped_gemm,
|
|
)
|
|
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
|
|
ceil_div as cutedsl_ceil_div,
|
|
pad_and_swizzle_single,
|
|
)
|
|
from cutedsl.custom_ops import register_runner, nvfp4_moe_gemm
|
|
|
|
|
|
class CuTeDSLMoERunner:
|
|
"""Manages NVFP4 MoE execution via the CuTeDSL kernel.
|
|
|
|
CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs,
|
|
no dynamic shapes. Always computes at max_num_tokens * top_k capacity.
|
|
"""
|
|
|
|
def __init__(self, num_experts, hidden_size, intermediate_size,
             max_num_tokens=8192, top_k=8, device="cuda",
             experts_start_idx=0):
    """Record the layer geometry and declare all runner state.

    Nothing is allocated on the device here. Weights arrive through the
    ``prepare_weights_*`` methods; buffers are created lazily by
    ``_ensure_stacked`` / ``_allocate_buffers``.

    Args:
        num_experts: Number of experts handled by this runner.
        hidden_size: Width of the hidden states.
        intermediate_size: Width of the activated tensor fed to the L2 GEMM.
        max_num_tokens: Largest token count a single call may carry.
        top_k: Number of experts selected per token.
        device: Device on which buffers are later allocated.
        experts_start_idx: Global ID of this runner's first expert; it is
            subtracted from incoming expert IDs in ``_run_impl``.
    """
    self.num_experts = num_experts
    self.hidden_size = hidden_size
    self.intermediate_size = intermediate_size
    self.max_num_tokens = max_num_tokens
    self.top_k = top_k
    self.device = device
    self.experts_start_idx = experts_start_idx
    self._swiglu_limit = None  # Set via set_swiglu_limit()

    # Weight storage (set before _ensure_stacked)
    self.l1_fp4 = None
    self.l1_sf = None
    self.l1_gs = None
    self.l2_fp4 = None
    self.l2_sf = None
    self.l2_gs = None

    # Pre-stacked checkpoint tensors (set by prepare_weights_from_stacked,
    # consumed and cleared by _ensure_stacked)
    self.l1_fp4_stacked = None
    self.l1_sf_stacked = None
    self.l2_fp4_stacked = None
    self.l2_sf_stacked = None

    # Stacked weight tensors (set in _ensure_stacked)
    self._l1_mat_b = None
    self._l2_mat_b = None
    self._l1_scale_b = None
    self._l2_scale_b = None
    self._l1_gsb = None
    self._l2_gsb = None

    # Default: 1/2688 ≈ 0.000372 (amax=1 → gs=1/2688)
    # Overridden in finalize_weights with checkpoint input_scale or warmup value
    self._l1_activation_global_scale = 1.0 / (6.0 * 448.0)
    self._l2_activation_global_scale = 1.0 / (6.0 * 448.0)

    # Pre-allocated cudagraph buffers (set in _ensure_stacked / _allocate_buffers)
    self._token_indices = None
    self._expert_id_range = None
    self._expert_offsets_buf = None
    self._per_expert_scale_bufs_l1 = None
    self._per_expert_scale_bufs_l2 = None
    self._padded_x_sf_buf_l1 = None
    self._padded_x_sf_buf_l2 = None
    self._l1_gsa_buf = None
    self._l2_gsa_buf = None
    self._output_buf = None
    self._row_indices_buf = None
    self._padded_hidden_buf = None
    self._padded_activated_buf = None  # unused, using shared
    self._padded_expert_offsets_buf = None
    self._shared_bufs = None
    self._needs_token_refill = False

    # Capacity of the padded buffers, expressed as 128-row chunks per expert
    # (the buffers hold num_experts * _max_chunks_per_expert * 128 rows).
    #
    # The forward pass rounds every expert's slot count up to a multiple of
    # 128, so the padded total can reach
    #     max_num_tokens * top_k + 127 * num_experts
    # when routing is unbalanced. Sizing for the balanced case only
    # (ceil(slots / (num_experts * 128))) lets the scatter into the padded
    # buffers run out of bounds, so the per-expert rounding slack is included.
    rows_per_chunk_all_experts = self.num_experts * 128
    worst_case_rows = self.max_num_tokens * self.top_k + self.num_experts * 127
    # Integer ceil-division without going through floats.
    self._max_chunks_per_expert = -(-worst_case_rows // rows_per_chunk_all_experts)
    self._buffers_allocated = False
|
|
|
|
def set_swiglu_limit(self, limit: float | None):
|
|
"""Set the swiglu_limit for activation clamping."""
|
|
self._swiglu_limit = limit
|
|
|
|
def _fill_token_indices(self):
    """Write the slot→token map [0,0,..,0, 1,1,..,1, ...] into _token_indices.

    Each token index appears ``top_k`` times in a row. The pattern is
    assembled on the CPU and then copied into the existing buffer, so the
    result does not depend on device memory that CuTeDSL's JIT may have
    disturbed.
    """
    token_ids = torch.arange(self.max_num_tokens, dtype=torch.int32)
    host_pattern = token_ids.repeat_interleave(self.top_k)
    self._token_indices.copy_(host_pattern)
|
|
|
|
def _allocate_buffers(self):
    """Pre-allocate every fixed-size buffer used by the forward pass.

    Large scratch buffers (padded activations, padded activation scales and
    the output) are shared between runners through the class-level
    ``_shared_padded_bufs`` registry. The registry is keyed by device AND
    buffer geometry: a runner whose sizes differ from the first one created
    on the device gets its own correctly shaped buffers instead of silently
    reusing mismatched ones.

    Small per-runner buffers (global-scale vectors, row indices, expert
    offsets, per-expert scale tiles) are always allocated fresh.
    """
    device = self.device

    def fp8_zeros(rows, cols):
        # Zero-filled scale buffer; created as fp16 and cast to fp8.
        return torch.zeros(
            rows, cols, dtype=torch.float16, device=device
        ).to(torch.float8_e4m3fn)

    def fp4_zeros(rows, unpacked_cols):
        # Zero-filled packed buffer: uint8 storage of width unpacked_cols // 2,
        # reinterpreted as float4_e2m1fn_x2.
        return torch.zeros(
            rows, unpacked_cols // 2, dtype=torch.uint8, device=device
        ).view(torch.float4_e2m1fn_x2)

    # One scale per 16 elements along K, columns padded to a multiple of 4.
    # L1 and L2 differ because K is hidden_size vs intermediate_size.
    k_sf_l1 = cutedsl_ceil_div(self.hidden_size, 16)
    padded_cols_l1 = cutedsl_ceil_div(k_sf_l1, 4) * 4
    k_sf_l2 = cutedsl_ceil_div(self.intermediate_size, 16)
    padded_cols_l2 = cutedsl_ceil_div(k_sf_l2, 4) * 4

    self._per_expert_scale_bufs_l1 = [
        fp8_zeros(128, padded_cols_l1) for _ in range(self.num_experts)
    ]
    self._per_expert_scale_bufs_l2 = [
        fp8_zeros(128, padded_cols_l2) for _ in range(self.num_experts)
    ]

    # Row capacity of every padded buffer (activations and their scales).
    max_rows_per_expert = self._max_chunks_per_expert * 128
    padded_max_slots = self.num_experts * max_rows_per_expert

    # Shared scratch buffers: one set per (device, geometry).
    if not hasattr(CuTeDSLMoERunner, '_shared_padded_bufs'):
        CuTeDSLMoERunner._shared_padded_bufs = {}
    registry = CuTeDSLMoERunner._shared_padded_bufs
    shared_key = (
        str(device),
        padded_max_slots,
        self.max_num_tokens,
        self.hidden_size,
        self.intermediate_size,
    )
    shared = registry.get(shared_key)
    if shared is None:
        shared = {
            'xsf_l1': fp8_zeros(padded_max_slots, padded_cols_l1),
            'xsf_l2': fp8_zeros(padded_max_slots, padded_cols_l2),
            'output': torch.zeros(
                self.max_num_tokens, self.hidden_size,
                dtype=torch.bfloat16, device=device,
            ),
            'hidden': torch.zeros(
                padded_max_slots, self.hidden_size,
                dtype=torch.bfloat16, device=device,
            ),
            'hidden_fp4': fp4_zeros(padded_max_slots, self.hidden_size),
            'activated': torch.zeros(
                padded_max_slots, self.intermediate_size,
                dtype=torch.bfloat16, device=device,
            ),
            'activated_fp4': fp4_zeros(padded_max_slots, self.intermediate_size),
        }
        registry[shared_key] = shared

    self._padded_x_sf_buf_l1 = shared['xsf_l1']
    self._padded_x_sf_buf_l2 = shared['xsf_l2']
    self._output_buf = shared['output']
    self._shared_bufs = shared

    # Pre-allocated global_scale_a buffers (filled via .fill_(), no torch.full during capture)
    self._l1_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=device)
    self._l2_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=device)

    # Row indices for scale assembly (max_num_tokens * top_k slots)
    self._row_indices_buf = torch.arange(
        self.max_num_tokens * self.top_k, device=device
    )

    # Padded expert offsets buffer, initialised to
    # [0, max_rows, 2*max_rows, ...]; the forward pass overwrites it with
    # the actual padded offsets on every call.
    self._padded_expert_offsets_buf = torch.arange(
        self.num_experts + 1, dtype=torch.int32, device=device
    ) * max_rows_per_expert

    self._buffers_allocated = True
|
|
|
|
def _ensure_stacked(self):
    """Convert the prepared weights to kernel layout and allocate buffers.

    Runs once per prepared weight set: it returns immediately when
    ``_l1_mat_b`` is already populated. Two sources are supported:

    * pre-stacked checkpoint tensors (``l1_fp4_stacked`` etc.), and
    * per-expert lists (``l1_fp4`` etc.).

    The source tensors are released as soon as they have been converted.
    Index and offset buffers are created only after the weight conversion,
    then ``_allocate_buffers`` is called.
    """
    if self._l1_mat_b is not None:
        return  # already converted

    # NOTE: no local alias is taken for the stacked tensors, so clearing the
    # attributes below really drops the last reference held by this object.
    if getattr(self, 'l1_fp4_stacked', None) is not None:
        # Fast path: checkpoint layout (E, N, K) → (E, K, N), then K-major.
        fp4_l1_ekn = self.l1_fp4_stacked.permute(0, 2, 1).contiguous()
        fp4_l2_ekn = self.l2_fp4_stacked.permute(0, 2, 1).contiguous()

        # Release the checkpoint copies before make_b_k_major runs
        # (saves one copy).
        self.l1_fp4_stacked = None
        self.l2_fp4_stacked = None
        torch.cuda.empty_cache()

        self._l1_mat_b = make_b_k_major(fp4_l1_ekn)
        self._l2_mat_b = make_b_k_major(fp4_l2_ekn)
        del fp4_l1_ekn, fp4_l2_ekn
        torch.cuda.empty_cache()

        # Scales: checkpoint is (E, N, K_sf); the swizzle wants one
        # (N, K_sf) tensor per expert. Indexing yields views, not copies.
        expert_ids = range(self.num_experts)
        sf_views_l1 = [self.l1_sf_stacked[e] for e in expert_ids]
        sf_views_l2 = [self.l2_sf_stacked[e] for e in expert_ids]
        self.l1_sf_stacked = None
        self.l2_sf_stacked = None
        torch.cuda.empty_cache()

        # assemble_scales_3d_side expects (K_sf, N) per expert and
        # transposes internally; the checkpoint scales are already
        # (N, K_sf), so the raw assembly routine is called directly.
        from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
            assemble_raw_scales_2d3d_3d_side,
        )
        self._l1_scale_b = assemble_raw_scales_2d3d_3d_side(sf_views_l1)
        self._l2_scale_b = assemble_raw_scales_2d3d_3d_side(sf_views_l2)
        del sf_views_l1, sf_views_l2
    else:
        # Legacy path: per-expert lists.
        self._l1_mat_b = make_b_k_major(torch.stack(self.l1_fp4))
        self._l2_mat_b = make_b_k_major(torch.stack(self.l2_fp4))
        self._l1_scale_b = assemble_scales_3d_side(self.l1_sf)
        self._l2_scale_b = assemble_scales_3d_side(self.l2_sf)
        self.l1_fp4 = None
        self.l1_sf = None
        self.l2_fp4 = None
        self.l2_sf = None

    # Per-expert weight global scales as float32 device tensors.
    self._l1_gsb = torch.tensor(self.l1_gs, dtype=torch.float32, device=self.device)
    self._l2_gsb = torch.tensor(self.l2_gs, dtype=torch.float32, device=self.device)
    self.l1_gs = None
    self.l2_gs = None

    # Buffers are allocated AFTER the weight conversion above:
    # CuTeDSL's cute.compile may corrupt GPU memory during JIT, so tensors
    # allocated before/during compilation may be zeroed.
    #
    # _token_indices lives on the GPU for cudagraph compatibility. It is
    # filled here (after the weight JIT) and filled again after the first
    # forward call, because the GEMM JIT in run_nvfp4_grouped_gemm is only
    # triggered then; _needs_token_refill requests that second fill.
    slot_count = self.max_num_tokens * self.top_k
    self._token_indices = torch.zeros(
        slot_count, dtype=torch.int32, device=self.device
    )
    self._fill_token_indices()
    self._needs_token_refill = True  # GEMM JIT may corrupt; refill after first run

    self._expert_id_range = torch.arange(
        self.num_experts, dtype=torch.int32
    ).to(self.device)
    self._expert_offsets_buf = torch.zeros(
        self.num_experts + 1, dtype=torch.int32, device=self.device
    )
    self._allocate_buffers()
|
|
|
|
def prepare_weights_direct(self, l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs):
    """Store already-quantized per-expert weights for deferred conversion.

    The values are only recorded here; ``_ensure_stacked`` later stacks the
    fp4 tensors, assembles the scale factors and turns the global scales
    into device tensors.

    Args:
        l1_fp4, l2_fp4: Per-expert sequences of packed fp4 weight tensors
            (they are passed to ``torch.stack``).
        l1_sf, l2_sf: Per-expert block scale factors (passed to
            ``assemble_scales_3d_side``).
        l1_gs, l2_gs: Per-expert global scales (passed to ``torch.tensor``).
    """
    self.l1_fp4 = l1_fp4
    self.l1_sf = l1_sf
    self.l1_gs = l1_gs
    self.l2_fp4 = l2_fp4
    self.l2_sf = l2_sf
    self.l2_gs = l2_gs

    # _ensure_stacked prefers the pre-stacked tensors whenever they are set.
    # Drop any left over from an earlier prepare_weights_from_stacked call so
    # the weights given here are the ones that actually get used.
    self.l1_fp4_stacked = None
    self.l1_sf_stacked = None
    self.l2_fp4_stacked = None
    self.l2_sf_stacked = None

    # Force _ensure_stacked to rebuild the kernel-format weights.
    self._l1_mat_b = None
|
|
|
|
def prepare_weights_from_stacked(self, l1_fp4_stacked, l1_sf_stacked,
                                 l1_gs, l2_fp4_stacked, l2_sf_stacked,
                                 l2_gs):
    """Prepare weights from pre-stacked 3D tensors (checkpoint format).

    Takes (E, N, K_packed) fp4 and (E, N, K_sf) scale tensors directly
    from the checkpoint, avoiding the per-expert list→stack round-trip.

    The conversion to K-major and swizzled layout happens in
    ``_ensure_stacked``; this method only stores the tensors for deferred
    processing.

    Args:
        l1_fp4_stacked, l2_fp4_stacked: Stacked packed-fp4 weights.
        l1_sf_stacked, l2_sf_stacked: Stacked block scale factors.
        l1_gs, l2_gs: Per-expert global scales (passed to ``torch.tensor``
            in ``_ensure_stacked``).
    """
    # Store in checkpoint format (E, N, K) — _ensure_stacked will convert
    self.l1_fp4_stacked = l1_fp4_stacked
    self.l1_sf_stacked = l1_sf_stacked
    self.l1_gs = l1_gs
    self.l2_fp4_stacked = l2_fp4_stacked
    self.l2_sf_stacked = l2_sf_stacked
    self.l2_gs = l2_gs

    # The stacked path never reads the per-expert lists. Release any left
    # over from an earlier prepare call so they do not stay resident.
    self.l1_fp4 = None
    self.l1_sf = None
    self.l2_fp4 = None
    self.l2_sf = None

    # Force _ensure_stacked to rebuild the kernel-format weights.
    self._l1_mat_b = None
|
|
|
|
def prepare_weights_from_dequantized(self, l1_weights_bf16, l2_weights_bf16):
    """Quantize per-expert dense weights to NVFP4 and store them.

    Each weight is transposed and passed to ``quantize_weight_to_nvfp4``;
    the returned packed weights, scale factors and global scales are
    collected in per-expert lists for ``_ensure_stacked``.

    Args:
        l1_weights_bf16: Iterable of per-expert L1 weights.
        l2_weights_bf16: Iterable of per-expert L2 weights, iterated in
            lockstep with ``l1_weights_bf16``.
    """
    self.l1_fp4, self.l1_sf, self.l1_gs = [], [], []
    self.l2_fp4, self.l2_sf, self.l2_gs = [], [], []

    for l1_w, l2_w in zip(l1_weights_bf16, l2_weights_bf16):
        # L1 first, then L2, for every expert.
        for weight, fp4_out, sf_out, gs_out in (
            (l1_w, self.l1_fp4, self.l1_sf, self.l1_gs),
            (l2_w, self.l2_fp4, self.l2_sf, self.l2_gs),
        ):
            w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(weight.T)
            fp4_out.append(w_fp4)
            sf_out.append(w_sf)
            gs_out.append(w_gs)

    # _ensure_stacked prefers the pre-stacked tensors whenever they are set.
    # Drop any left over from an earlier prepare_weights_from_stacked call so
    # the freshly quantized weights are the ones that get used.
    self.l1_fp4_stacked = None
    self.l1_sf_stacked = None
    self.l2_fp4_stacked = None
    self.l2_sf_stacked = None

    # Force _ensure_stacked to rebuild the kernel-format weights.
    self._l1_mat_b = None
|
|
|
|
def _assemble_scales_cudagraph_safe(self, x_sf, expert_offsets,
|
|
padded_expert_offsets,
|
|
padded_x_sf_buf, per_expert_bufs):
|
|
"""Assemble 2D-side activation scales (cudagraph-safe, NO CPU syncs).
|
|
|
|
Phase 1: Scatter x_sf into padded per-expert sections (GPU-only).
|
|
Phase 2: Apply full-buffer Blackwell 32_4_4 swizzle (no Python loops).
|
|
|
|
The buffer is 128-row aligned per expert (from padded_expert_offsets),
|
|
so the full-buffer swizzle produces the correct layout. The GEMM reads
|
|
scale_a using padded_expert_offsets, matching the scatter layout.
|
|
"""
|
|
K_sf = x_sf.shape[1]
|
|
padded_x_sf = padded_x_sf_buf
|
|
padded_x_sf.zero_()
|
|
|
|
# Phase 1: Scatter x_sf into padded per-expert sections (GPU-only)
|
|
total_rows = x_sf.shape[0]
|
|
row_indices = self._row_indices_buf[:total_rows]
|
|
expert_assign = torch.searchsorted(
|
|
expert_offsets[1:], row_indices, right=True
|
|
).clamp(max=self.num_experts - 1)
|
|
local_row = row_indices - expert_offsets[expert_assign]
|
|
dst_rows = padded_expert_offsets[expert_assign] + local_row
|
|
padded_x_sf[dst_rows, :K_sf] = x_sf
|
|
|
|
# Phase 2: Full-buffer swizzle (no CPU sync, no Python loops)
|
|
# padded_x_sf is 128-row aligned per expert and 4-col aligned.
|
|
# to_blocked: (rows, cols) → view(R, 128, C, 4) → permute(0,2,1,3)
|
|
# → reshape(-1, 4, 32, 4) → transpose(1,2) → reshape(-1, 32, 16) → flatten
|
|
rows = padded_x_sf.shape[0]
|
|
cols = padded_x_sf.shape[1]
|
|
R = rows // 128
|
|
C = cols // 4
|
|
blocks = padded_x_sf.view(R, 128, C, 4).permute(0, 2, 1, 3)
|
|
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
|
|
swizzled = rearranged.flatten().view(torch.float8_e4m3fn)
|
|
return swizzled.reshape(rows, cols)
|
|
|
|
def compute_activation_global_scales(self, hidden_states_sample, topk_weights, topk_ids):
    """Compute activation global scales from a warmup forward pass.

    Called BEFORE cudagraph capture. Uses the SAME padded GEMM path as run()
    to ensure kernel JIT happens with the same layout, and the L2 global
    scale is computed from the actual L1 output (not an approximation).

    Expert IDs are remapped from global to local exactly as in
    ``_run_impl`` (subtract ``experts_start_idx``, clamp into range), so
    calibration sees the same slot layout as the real forward pass when
    this runner owns only a slice of the experts.

    Args:
        hidden_states_sample: Sample hidden states, one row per token.
        topk_weights: Routing weights; not used for calibration.
        topk_ids: Global expert IDs, shape (num_tokens, top_k).

    Raises:
        ValueError: If ``topk_ids`` has a different top-k than the runner
            was built for, or the sample has more than ``max_num_tokens``
            rows (the pre-built index buffers would not match).

    Side effects:
        Stores the results in ``_l1_activation_global_scale`` and
        ``_l2_activation_global_scale``; overwrites the shared padded
        buffers and the expert-offset buffers.
    """
    num_tokens = hidden_states_sample.shape[0]
    top_k = topk_ids.shape[1]

    # _token_indices repeats every token index self.top_k times; any other
    # top_k would pair slots with the wrong tokens without an error.
    if top_k != self.top_k:
        raise ValueError(
            f"topk_ids has top_k={top_k}, but the runner was built for "
            f"top_k={self.top_k}"
        )
    if num_tokens > self.max_num_tokens:
        raise ValueError(
            f"num_tokens={num_tokens} exceeds max_num_tokens="
            f"{self.max_num_tokens}"
        )

    self._ensure_stacked()
    device = hidden_states_sample.device

    with torch.no_grad():
        # Global → local expert IDs (same as _run_impl). IDs outside this
        # runner's range are clamped onto a local expert.
        local_ids = topk_ids - self.experts_start_idx
        safe_ids = local_ids.clamp(0, self.num_experts - 1)

        # Build slot mapping (same as _run_impl)
        flat_ids = safe_ids.reshape(-1)
        num_slots = num_tokens * top_k
        token_indices = self._token_indices[:num_slots]
        sort_idx = flat_ids.argsort(stable=True)
        sorted_ids = flat_ids[sort_idx]
        sorted_token_ids = token_indices[sort_idx]
        slot_hidden = hidden_states_sample[sorted_token_ids]

        # L1: get exact gs from quantize_to_nvfp4
        _, _, l1_gs = quantize_to_nvfp4(slot_hidden)

        # Quantize slot_hidden for GEMM
        slot_x_fp4, slot_x_sf = quantize_activation_nvfp4(slot_hidden, l1_gs)

        # Slots per expert and their cumulative offsets (leading 0).
        expert_id_range = self._expert_id_range
        tokens_per_expert = (
            sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)
        ).sum(dim=0).int()
        expert_offsets = self._expert_offsets_buf
        expert_offsets.zero_()
        expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0)

        # Same offsets with every expert rounded up to a multiple of 128.
        padded_tokens_per_expert = ((tokens_per_expert + 127) // 128) * 128
        padded_expert_offsets = self._padded_expert_offsets_buf
        padded_expert_offsets.zero_()
        padded_expert_offsets[1:self.num_experts + 1] = padded_tokens_per_expert.cumsum(0)

        # Destination row of every sorted slot in the padded layout
        # (same as _run_impl)
        row_indices = self._row_indices_buf[:num_slots]
        expert_assign = torch.searchsorted(
            expert_offsets[1:], row_indices, right=True
        ).clamp(max=self.num_experts - 1)
        local_row = row_indices - expert_offsets[expert_assign]
        padded_dst = padded_expert_offsets[expert_assign] + local_row

        # Scatter x_fp4 into padded layout (through a uint8 view)
        padded_x_fp4 = self._shared_bufs['hidden_fp4']
        padded_x_fp4.view(torch.uint8).zero_()
        padded_x_fp4.view(torch.uint8)[padded_dst] = slot_x_fp4.view(torch.uint8)

        l1_scale_a = self._assemble_scales_cudagraph_safe(
            slot_x_sf, expert_offsets[:self.num_experts + 1],
            padded_expert_offsets,
            self._padded_x_sf_buf_l1, self._per_expert_scale_bufs_l1
        )
        l1_gsa = torch.full((self.num_experts,), l1_gs, dtype=torch.float32, device=device)

        l1_out = run_nvfp4_grouped_gemm(
            mat_a=padded_x_fp4, mat_b=self._l1_mat_b,
            scale_a=l1_scale_a, scale_b=self._l1_scale_b,
            expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
            global_scale_a=l1_gsa, global_scale_b=self._l1_gsb,
        )

        # Extract real token outputs
        l1_out_real = l1_out[padded_dst]

        # L2: get exact gs from SiLU(gate)*up
        gate = l1_out_real[:, :self.intermediate_size]
        up = l1_out_real[:, self.intermediate_size:]
        gate_silu = torch.nn.functional.silu(gate)
        if self._swiglu_limit is not None:
            gate_silu = gate_silu.clamp(max=self._swiglu_limit)
            up = up.clamp(min=-self._swiglu_limit, max=self._swiglu_limit)
        activated = gate_silu * up
        _, _, l2_gs = quantize_to_nvfp4(activated)

        self._l1_activation_global_scale = l1_gs
        self._l2_activation_global_scale = l2_gs
|
|
|
|
|
|
|
|
def run(self, hidden_states, topk_weights, topk_ids, expert_indices=None):
    """Forward entry point: route tokens to experts, GEMM, combine.

    The work is dispatched through the ``nvfp4_moe_gemm`` custom op
    (torch.library.custom_op ``nvfp4::moe_gemm``) so torch.compile treats
    it as an opaque op; the op calls ``_run_impl`` internally. The runner
    registers itself on first use and reuses the returned ID afterwards.

    ``expert_indices`` is accepted but not forwarded.
    """
    try:
        runner_id = self._runner_id
    except AttributeError:
        # First call: obtain an ID for this runner and remember it.
        runner_id = register_runner(self)
        self._runner_id = runner_id
    return nvfp4_moe_gemm(
        hidden_states, topk_weights, topk_ids,
        runner_id, self.hidden_size,
    )
|
|
|
|
def _run_impl(self, hidden_states, topk_weights, topk_ids, expert_indices=None):
    """Run the NVFP4 MoE forward pass.

    Handles global→local expert ID remapping for expert parallelism.
    Fully cudagraph-safe: no CPU-GPU syncs, no dynamic shapes (the shape
    checks below only compare Python ints).

    Each expert's slots are padded to multiples of 128 for the GEMM.
    The padded expert offsets are [0, padded_e0, padded_e0+padded_e1, ...]
    and scale_a is produced at those same offsets.

    Args:
        hidden_states: Input activations, one row per token.
        topk_weights: Routing weights, same shape as ``topk_ids``.
        topk_ids: Global expert IDs, shape (num_tokens, top_k).
        expert_indices: Accepted for interface compatibility; unused.

    Returns:
        A view of the first ``num_tokens`` rows of the shared output
        buffer. The buffer is reused by later calls, so the result is
        overwritten by the next forward pass that shares it.

    Raises:
        ValueError: If ``topk_ids`` has a different top-k than the runner
            was built for, or there are more than ``max_num_tokens``
            tokens. Both would otherwise pair slots with the wrong tokens
            or index past the pre-built buffers.
    """
    num_tokens = hidden_states.shape[0]
    top_k = topk_ids.shape[1]

    # _token_indices repeats every token index self.top_k times; any other
    # top_k would silently map slots to the wrong tokens.
    if top_k != self.top_k:
        raise ValueError(
            f"topk_ids has top_k={top_k}, but the runner was built for "
            f"top_k={self.top_k}"
        )
    if num_tokens > self.max_num_tokens:
        raise ValueError(
            f"num_tokens={num_tokens} exceeds max_num_tokens="
            f"{self.max_num_tokens}"
        )

    self._ensure_stacked()

    # -- Remap global expert IDs to local IDs --
    # Slots routed to experts outside this runner's range are clamped onto
    # a local expert and given zero weight, so they add nothing to the sum.
    local_ids = topk_ids - self.experts_start_idx
    local_mask = (local_ids >= 0) & (local_ids < self.num_experts)
    safe_ids = local_ids.clamp(0, self.num_experts - 1)
    safe_weights = topk_weights * local_mask.float()

    # -- Build slot mapping --
    flat_ids = safe_ids.reshape(-1)
    flat_weights = safe_weights.reshape(-1)
    num_slots = num_tokens * top_k
    token_indices = self._token_indices[:num_slots]

    # Stable sort groups slots by expert while keeping token order.
    sort_idx = flat_ids.argsort(stable=True)
    sorted_ids = flat_ids[sort_idx]
    sorted_weights = flat_weights[sort_idx]
    sorted_token_ids = token_indices[sort_idx]

    # Expert offsets (real token counts)
    expert_id_range = self._expert_id_range
    tokens_per_expert = (
        sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)
    ).sum(dim=0).int()
    expert_offsets = self._expert_offsets_buf
    expert_offsets.zero_()
    expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0)

    # Pad each expert to 128-row alignment (GPU-only computation)
    padded_tokens_per_expert = ((tokens_per_expert + 127) // 128) * 128
    padded_expert_offsets = self._padded_expert_offsets_buf
    padded_expert_offsets.zero_()
    padded_expert_offsets[1:self.num_experts + 1] = padded_tokens_per_expert.cumsum(0)

    # -- Gather hidden states into slot order, compute padded_dst --
    slot_hidden = hidden_states[sorted_token_ids]
    row_indices = self._row_indices_buf[:num_slots]
    expert_assign = torch.searchsorted(
        expert_offsets[1:], row_indices, right=True
    ).clamp(max=self.num_experts - 1)
    local_row = row_indices - expert_offsets[expert_assign]
    padded_dst = padded_expert_offsets[expert_assign] + local_row

    # === L1: gate + up ===
    # Quantize slot_hidden (sorted tokens), NOT a zero-padded copy: this
    # yields one scale row per slot, which the scale assembly then scatters
    # to the padded positions. Quantizing a padded buffer and slicing the
    # first num_slots scale rows would only cover the leading padded rows,
    # not the scattered token positions.
    slot_x_fp4, slot_x_sf = quantize_activation_nvfp4(
        slot_hidden, self._l1_activation_global_scale
    )
    # Scatter x_fp4 into padded layout for the GEMM
    # Must scatter as uint8 (float4_e2m1fn_x2 doesn't support index_put)
    padded_x_fp4 = self._shared_bufs['hidden_fp4']
    padded_x_fp4.view(torch.uint8).zero_()
    padded_x_fp4.view(torch.uint8)[padded_dst] = slot_x_fp4.view(torch.uint8)

    l1_scale_a = self._assemble_scales_cudagraph_safe(
        slot_x_sf, expert_offsets[:self.num_experts + 1],
        padded_expert_offsets,
        self._padded_x_sf_buf_l1, self._per_expert_scale_bufs_l1
    )
    l1_gsa = self._l1_gsa_buf.fill_(self._l1_activation_global_scale)

    l1_out = run_nvfp4_grouped_gemm(
        mat_a=padded_x_fp4, mat_b=self._l1_mat_b,
        scale_a=l1_scale_a, scale_b=self._l1_scale_b,
        expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
        global_scale_a=l1_gsa, global_scale_b=self._l1_gsb,
    )

    # Extract real token outputs from padded GEMM output
    l1_out_real = l1_out[padded_dst]

    # === SiLU(gate) * up (with swiglu_limit clamp) ===
    # The first intermediate_size columns are the gate, the rest are up.
    gate = l1_out_real[:, :self.intermediate_size]
    up = l1_out_real[:, self.intermediate_size:]
    gate_silu = torch.nn.functional.silu(gate)
    # Apply DeepSeek-V4 swiglu_limit: clamp both silu(gate) and up
    if self._swiglu_limit is not None:
        gate_silu = gate_silu.clamp(max=self._swiglu_limit)
        up = up.clamp(min=-self._swiglu_limit, max=self._swiglu_limit)
    activated = gate_silu * up

    # === L2: down ===
    # Quantize activated (per-token), scatter into padded FP4 buffer
    slot_l2_x_fp4, slot_l2_x_sf = quantize_activation_nvfp4(
        activated, self._l2_activation_global_scale
    )
    padded_activated_fp4 = self._shared_bufs['activated_fp4']
    padded_activated_fp4.view(torch.uint8).zero_()
    padded_activated_fp4.view(torch.uint8)[padded_dst] = slot_l2_x_fp4.view(torch.uint8)

    l2_scale_a = self._assemble_scales_cudagraph_safe(
        slot_l2_x_sf, expert_offsets[:self.num_experts + 1],
        padded_expert_offsets,
        self._padded_x_sf_buf_l2, self._per_expert_scale_bufs_l2
    )
    l2_gsa = self._l2_gsa_buf.fill_(self._l2_activation_global_scale)

    l2_out = run_nvfp4_grouped_gemm(
        mat_a=padded_activated_fp4, mat_b=self._l2_mat_b,
        scale_a=l2_scale_a, scale_b=self._l2_scale_b,
        expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
        global_scale_a=l2_gsa, global_scale_b=self._l2_gsb,
    )

    l2_out_real = l2_out[padded_dst]

    # === Scatter -> final output ===
    # Weighted sum of each token's expert outputs, accumulated per token.
    y = self._output_buf[:num_tokens]
    y.zero_()
    weighted_out = l2_out_real * sorted_weights.unsqueeze(1).to(l2_out_real.dtype)
    y.scatter_add_(
        0,
        sorted_token_ids.unsqueeze(1).expand(-1, self.hidden_size),
        weighted_out,
    )

    # Refill _token_indices after GEMM JIT on first call
    # (CuTeDSL's cute.compile may corrupt GPU memory during first GEMM)
    if self._needs_token_refill:
        self._fill_token_indices()
        self._needs_token_refill = False

    return y
|