rewrite: cudagraph-safe runner - no dynamic slicing, no GPU scalar indices

- Removed all [:total_slots] dynamic slicing with GPU scalars
- slot_hidden gathers from hidden_states directly using sorted_token_ids
- scatter_add uses full sorted_token_ids (padding slots have zero weight)
- _assemble_scales_cudagraph_safe returns 2D via padded_scales.shape[0]
- Fixed padded_scales_buf allocation via float16->float8 cast
- GEMM output size: n_dim * 2 for float4_e2m1fn_x2 packed format
This commit is contained in:
2026-05-16 18:44:25 +00:00
parent 4300775bfe
commit 53c25bee0b

View File

@@ -1,9 +1,17 @@
"""
vLLM integration for the CuTeDSL NVFP4 MoE kernel.
CUDA-graph-compatible: no .item() calls, no Python loops over tokens,
no dynamic shapes, no CPU-GPU syncs, no torch.cuda.synchronize().
All buffers pre-allocated at max_num_tokens size.
CUDA-graph-compatible design:
- All intermediate buffers pre-allocated at max_num_tokens * top_k size
- No .item(), .tolist(), .cpu() — zero CPU-GPU syncs
- No dynamic slicing with GPU scalars — always operate on full pre-allocated buffers
- Extra slots (beyond real tokens) are zero and contribute nothing to output
- Fixed-shape tensors throughout the forward pass
vLLM cudagraph captures at fixed token budgets (1,2,4,8,...,8192).
During capture, num_tokens equals the budget — all shapes are fixed.
During replay, inputs are padded to the budget size. Our runner always
processes max_slots = budget * top_k rows; padding rows are zeros.
"""
import torch
@@ -16,7 +24,6 @@ from cutedsl.bridge import (
)
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
ceil_div as cutedsl_ceil_div,
round_up as cutedsl_round_up,
pad_and_swizzle_single,
)
@@ -24,7 +31,8 @@ from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
class CuTeDSLMoERunner:
"""Manages NVFP4 MoE execution via the CuTeDSL kernel.
CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs.
CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs,
no dynamic shapes. Always computes at max_num_tokens * top_k capacity.
"""
def __init__(self, num_experts, hidden_size, intermediate_size,
@@ -36,6 +44,7 @@ class CuTeDSLMoERunner:
self.top_k = top_k
self.device = device
# Weight storage (set before _ensure_stacked)
self.l1_fp4 = None
self.l1_sf = None
self.l1_gs = None
@@ -43,6 +52,7 @@ class CuTeDSLMoERunner:
self.l2_sf = None
self.l2_gs = None
# Stacked weight tensors (set in _ensure_stacked)
self._l1_mat_b = None
self._l2_mat_b = None
self._l1_scale_b = None
@@ -53,12 +63,11 @@ class CuTeDSLMoERunner:
self._l1_activation_global_scale = 1.0 / 2688.0
self._l2_activation_global_scale = 1.0 / 2688.0
# Pre-allocated buffers (set in _allocate_buffers)
# Pre-allocated cudagraph buffers (set in _allocate_buffers)
self._token_indices = None
self._expert_id_range = None
self._output_buf = None
self._padded_scales_buf = None
self._expert_offsets_buf = None
self._padded_scales_buf = None
self._padded_expert_offsets_buf = None
self._buffers_allocated = False
@@ -67,8 +76,10 @@ class CuTeDSLMoERunner:
max_slots = self.max_num_tokens * self.top_k
K_sf = cutedsl_ceil_div(self.hidden_size, 16)
padded_cols = cutedsl_ceil_div(K_sf, 4) * 4
max_padded_rows = self.num_experts * 128 # worst case: 1 token per expert, each padded to 128
# Worst case: 1 token per expert, each padded to 128 rows
max_padded_rows = self.num_experts * 128
# Slot -> token mapping: [0,0,...,0, 1,1,...,1, ...] (top_k repeats)
self._token_indices = torch.arange(
self.max_num_tokens, device=self.device
).unsqueeze(1).expand(-1, self.top_k).reshape(-1)
@@ -81,14 +92,9 @@ class CuTeDSLMoERunner:
self._padded_expert_offsets_buf = torch.zeros(
self.num_experts + 1, dtype=torch.int32, device=self.device
)
self._output_buf = torch.zeros(
max_slots, self.hidden_size, dtype=torch.bfloat16, device=self.device
)
self._padded_scales_buf = torch.zeros(
max_padded_rows, padded_cols, dtype=torch.float8_e4m3fn, device=self.device
)
max_padded_rows, padded_cols, dtype=torch.float16, device=self.device
).to(torch.float8_e4m3fn)
self._buffers_allocated = True
@@ -137,10 +143,8 @@ class CuTeDSLMoERunner:
def _assemble_scales_cudagraph_safe(self, x_sf, expert_offsets):
"""Assemble 2D-side activation scales (cudagraph-safe, no CPU sync).
Pre-allocates a padded buffer at max size. Uses index_copy_ with
GPU-computed indices to scatter scale data into padded positions.
Then applies the swizzle to the whole buffer.
Uses GPU-computed indices to scatter scale data into padded positions,
then applies the swizzle. Returns 2D tensor.
No .item(), no .tolist(), no Python control flow on GPU data.
"""
num_experts = self.num_experts
@@ -164,16 +168,12 @@ class CuTeDSLMoERunner:
padded_scales = self._padded_scales_buf[:total_padded_rows, :padded_cols]
padded_scales.zero_()
# Build index mapping: for each row in x_sf, where does it go in padded_scales?
# Row i in x_sf belongs to expert e where expert_offsets[e] <= i < expert_offsets[e+1]
# Its destination is padded_expert_offsets[e] + (i - expert_offsets[e])
# Use searchsorted to find which expert each row belongs to
# Build index mapping: for each row in x_sf, which expert does it belong to?
total_rows = x_sf.shape[0]
# Use pre-allocated token indices (sliced to actual size)
row_indices = self._token_indices[:total_rows]
# expert_assign[i] = which expert row i belongs to
expert_assign = torch.searchsorted(expert_offsets[1:], row_indices, right=False).clamp(max=num_experts - 1)
expert_assign = torch.searchsorted(
expert_offsets[1:], row_indices, right=False
).clamp(max=num_experts - 1)
# Destination row in padded buffer
local_row = row_indices - expert_offsets[expert_assign]
@@ -182,91 +182,111 @@ class CuTeDSLMoERunner:
# Scatter x_sf into padded_scales
padded_scales[dst_rows, :K_sf] = x_sf
# Apply swizzle to the whole padded tensor, return 2D for 2D-side scale_a
# to_blocked preserves element count, so reshape to match padded shape
# Apply swizzle, reshape to 2D (element count preserved by swizzle)
swizzled = pad_and_swizzle_single(padded_scales)
return swizzled.reshape(padded_scales.shape[0], -1)
def run(self, hidden_states, topk_weights, topk_ids, expert_indices=None):
num_tokens, hidden_size = hidden_states.shape
"""Run the NVFP4 MoE forward pass.
Fully cudagraph-safe: no CPU-GPU syncs, no dynamic shapes.
expert_offsets are computed from the actual token distribution
via GPU-only ops (argsort, broadcast ==, cumsum). These offsets
are passed to the GEMM as a GPU tensor, never converted to Python.
The GEMM and quantize functions see the full slot buffer.
Padding rows are zeros that produce zero output, contributing
nothing to the final scatter_add.
Args:
hidden_states: (num_tokens, hidden_size) bf16
topk_weights: (num_tokens, top_k) float32
topk_ids: (num_tokens, top_k) int
expert_indices: ignored (uses all experts)
Returns:
(num_tokens, hidden_size) bf16 - MoE output
"""
num_tokens = hidden_states.shape[0]
top_k = topk_ids.shape[1]
device = hidden_states.device
if expert_indices is None:
expert_indices = list(range(self.num_experts))
num_experts = len(expert_indices)
self._ensure_stacked()
# ── Build slot mapping ──
# -- Build slot mapping --
flat_ids = topk_ids.reshape(-1)
flat_weights = topk_weights.reshape(-1)
token_indices = self._token_indices[:num_tokens * top_k]
num_slots = num_tokens * top_k
token_indices = self._token_indices[:num_slots]
sort_idx = flat_ids.argsort(stable=True)
sorted_ids = flat_ids[sort_idx]
sorted_weights = flat_weights[sort_idx]
sorted_token_ids = token_indices[sort_idx]
# Expert offsets (GPU-only)
expert_id_range = self._expert_id_range[:num_experts]
# Expert offsets (GPU-only, never touches CPU)
expert_id_range = self._expert_id_range
tokens_per_expert = (sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)).sum(dim=0).int()
expert_offsets = self._expert_offsets_buf
expert_offsets.zero_()
expert_offsets[1:num_experts + 1] = tokens_per_expert.cumsum(0)
expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0)
total_slots = expert_offsets[num_experts]
# -- Gather hidden states into slot order --
slot_hidden = hidden_states[sorted_token_ids]
slot_hidden = hidden_states[sorted_token_ids[:total_slots]]
# ════════════════════════════════════════════════════════════
# L1: gate + up
# ════════════════════════════════════════════════════════════
# === L1: gate + up ===
x_fp4, x_sf = quantize_activation_nvfp4(
slot_hidden, self._l1_activation_global_scale
)
l1_scale_a = self._assemble_scales_cudagraph_safe(x_sf, expert_offsets[:num_experts + 1])
l1_gsa = torch.full((num_experts,), self._l1_activation_global_scale,
dtype=torch.float32, device=device)
l1_scale_a = self._assemble_scales_cudagraph_safe(
x_sf, expert_offsets[:self.num_experts + 1]
)
l1_gsa = torch.full(
(self.num_experts,), self._l1_activation_global_scale,
dtype=torch.float32, device=device
)
l1_out = run_nvfp4_grouped_gemm(
mat_a=x_fp4, mat_b=self._l1_mat_b,
scale_a=l1_scale_a, scale_b=self._l1_scale_b,
expert_offsets=expert_offsets[:num_experts + 1],
expert_offsets=expert_offsets[:self.num_experts + 1],
global_scale_a=l1_gsa, global_scale_b=self._l1_gsb,
)
# ════════════════════════════════════════════════════════════
# SiLU(gate) * up
# ════════════════════════════════════════════════════════════
# === SiLU(gate) * up ===
gate = l1_out[:, :self.intermediate_size]
up = l1_out[:, self.intermediate_size:]
activated = torch.nn.functional.silu(gate) * up
# ════════════════════════════════════════════════════════════
# L2: down
# ════════════════════════════════════════════════════════════
# === L2: down ===
l2_x_fp4, l2_x_sf = quantize_activation_nvfp4(
activated, self._l2_activation_global_scale
)
l2_scale_a = self._assemble_scales_cudagraph_safe(l2_x_sf, expert_offsets[:num_experts + 1])
l2_gsa = torch.full((num_experts,), self._l2_activation_global_scale,
dtype=torch.float32, device=device)
l2_scale_a = self._assemble_scales_cudagraph_safe(
l2_x_sf, expert_offsets[:self.num_experts + 1]
)
l2_gsa = torch.full(
(self.num_experts,), self._l2_activation_global_scale,
dtype=torch.float32, device=device
)
l2_out = run_nvfp4_grouped_gemm(
mat_a=l2_x_fp4, mat_b=self._l2_mat_b,
scale_a=l2_scale_a, scale_b=self._l2_scale_b,
expert_offsets=expert_offsets[:num_experts + 1],
expert_offsets=expert_offsets[:self.num_experts + 1],
global_scale_a=l2_gsa, global_scale_b=self._l2_gsb,
)
# ════════════════════════════════════════════════════════════
# Scatter → final output
# ════════════════════════════════════════════════════════════
y = torch.zeros(num_tokens, hidden_size, dtype=torch.bfloat16, device=device)
weighted_out = l2_out * sorted_weights[:total_slots].unsqueeze(1).to(l2_out.dtype)
y.scatter_add_(0, sorted_token_ids[:total_slots].unsqueeze(1).expand(-1, hidden_size), weighted_out)
# === Scatter -> final output ===
y = torch.zeros(num_tokens, self.hidden_size, dtype=torch.bfloat16, device=device)
weighted_out = l2_out * sorted_weights.unsqueeze(1).to(l2_out.dtype)
y.scatter_add_(
0,
sorted_token_ids.unsqueeze(1).expand(-1, self.hidden_size),
weighted_out,
)
return y