rewrite: cudagraph-safe runner - no dynamic slicing, no GPU scalar indices
- Removed all [:total_slots] dynamic slicing with GPU scalars - slot_hidden gathers from hidden_states directly using sorted_token_ids - scatter_add uses full sorted_token_ids (padding slots have zero weight) - _assemble_scales_cudagraph_safe returns 2D via padded_scales.shape[0] - Fixed padded_scales_buf allocation via float16->float8 cast - GEMM output size: n_dim * 2 for float4_e2m1fn_x2 packed format
This commit is contained in:
@@ -1,9 +1,17 @@
|
||||
"""
|
||||
vLLM integration for the CuTeDSL NVFP4 MoE kernel.
|
||||
|
||||
CUDA-graph-compatible: no .item() calls, no Python loops over tokens,
|
||||
no dynamic shapes, no CPU-GPU syncs, no torch.cuda.synchronize().
|
||||
All buffers pre-allocated at max_num_tokens size.
|
||||
CUDA-graph-compatible design:
|
||||
- All intermediate buffers pre-allocated at max_num_tokens * top_k size
|
||||
- No .item(), .tolist(), .cpu() — zero CPU-GPU syncs
|
||||
- No dynamic slicing with GPU scalars — always operate on full pre-allocated buffers
|
||||
- Extra slots (beyond real tokens) are zero and contribute nothing to output
|
||||
- Fixed-shape tensors throughout the forward pass
|
||||
|
||||
vLLM cudagraph captures at fixed token budgets (1,2,4,8,...,8192).
|
||||
During capture, num_tokens equals the budget — all shapes are fixed.
|
||||
During replay, inputs are padded to the budget size. Our runner always
|
||||
processes max_slots = budget * top_k rows; padding rows are zeros.
|
||||
"""
|
||||
import torch
|
||||
|
||||
@@ -16,7 +24,6 @@ from cutedsl.bridge import (
|
||||
)
|
||||
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
|
||||
ceil_div as cutedsl_ceil_div,
|
||||
round_up as cutedsl_round_up,
|
||||
pad_and_swizzle_single,
|
||||
)
|
||||
|
||||
@@ -24,7 +31,8 @@ from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
|
||||
class CuTeDSLMoERunner:
|
||||
"""Manages NVFP4 MoE execution via the CuTeDSL kernel.
|
||||
|
||||
CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs.
|
||||
CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs,
|
||||
no dynamic shapes. Always computes at max_num_tokens * top_k capacity.
|
||||
"""
|
||||
|
||||
def __init__(self, num_experts, hidden_size, intermediate_size,
|
||||
@@ -36,6 +44,7 @@ class CuTeDSLMoERunner:
|
||||
self.top_k = top_k
|
||||
self.device = device
|
||||
|
||||
# Weight storage (set before _ensure_stacked)
|
||||
self.l1_fp4 = None
|
||||
self.l1_sf = None
|
||||
self.l1_gs = None
|
||||
@@ -43,6 +52,7 @@ class CuTeDSLMoERunner:
|
||||
self.l2_sf = None
|
||||
self.l2_gs = None
|
||||
|
||||
# Stacked weight tensors (set in _ensure_stacked)
|
||||
self._l1_mat_b = None
|
||||
self._l2_mat_b = None
|
||||
self._l1_scale_b = None
|
||||
@@ -53,12 +63,11 @@ class CuTeDSLMoERunner:
|
||||
self._l1_activation_global_scale = 1.0 / 2688.0
|
||||
self._l2_activation_global_scale = 1.0 / 2688.0
|
||||
|
||||
# Pre-allocated buffers (set in _allocate_buffers)
|
||||
# Pre-allocated cudagraph buffers (set in _allocate_buffers)
|
||||
self._token_indices = None
|
||||
self._expert_id_range = None
|
||||
self._output_buf = None
|
||||
self._padded_scales_buf = None
|
||||
self._expert_offsets_buf = None
|
||||
self._padded_scales_buf = None
|
||||
self._padded_expert_offsets_buf = None
|
||||
self._buffers_allocated = False
|
||||
|
||||
@@ -67,8 +76,10 @@ class CuTeDSLMoERunner:
|
||||
max_slots = self.max_num_tokens * self.top_k
|
||||
K_sf = cutedsl_ceil_div(self.hidden_size, 16)
|
||||
padded_cols = cutedsl_ceil_div(K_sf, 4) * 4
|
||||
max_padded_rows = self.num_experts * 128 # worst case: 1 token per expert, each padded to 128
|
||||
# Worst case: 1 token per expert, each padded to 128 rows
|
||||
max_padded_rows = self.num_experts * 128
|
||||
|
||||
# Slot -> token mapping: [0,0,...,0, 1,1,...,1, ...] (top_k repeats)
|
||||
self._token_indices = torch.arange(
|
||||
self.max_num_tokens, device=self.device
|
||||
).unsqueeze(1).expand(-1, self.top_k).reshape(-1)
|
||||
@@ -81,14 +92,9 @@ class CuTeDSLMoERunner:
|
||||
self._padded_expert_offsets_buf = torch.zeros(
|
||||
self.num_experts + 1, dtype=torch.int32, device=self.device
|
||||
)
|
||||
|
||||
self._output_buf = torch.zeros(
|
||||
max_slots, self.hidden_size, dtype=torch.bfloat16, device=self.device
|
||||
)
|
||||
|
||||
self._padded_scales_buf = torch.zeros(
|
||||
max_padded_rows, padded_cols, dtype=torch.float8_e4m3fn, device=self.device
|
||||
)
|
||||
max_padded_rows, padded_cols, dtype=torch.float16, device=self.device
|
||||
).to(torch.float8_e4m3fn)
|
||||
|
||||
self._buffers_allocated = True
|
||||
|
||||
@@ -137,10 +143,8 @@ class CuTeDSLMoERunner:
|
||||
def _assemble_scales_cudagraph_safe(self, x_sf, expert_offsets):
|
||||
"""Assemble 2D-side activation scales (cudagraph-safe, no CPU sync).
|
||||
|
||||
Pre-allocates a padded buffer at max size. Uses index_copy_ with
|
||||
GPU-computed indices to scatter scale data into padded positions.
|
||||
Then applies the swizzle to the whole buffer.
|
||||
|
||||
Uses GPU-computed indices to scatter scale data into padded positions,
|
||||
then applies the swizzle. Returns 2D tensor.
|
||||
No .item(), no .tolist(), no Python control flow on GPU data.
|
||||
"""
|
||||
num_experts = self.num_experts
|
||||
@@ -164,16 +168,12 @@ class CuTeDSLMoERunner:
|
||||
padded_scales = self._padded_scales_buf[:total_padded_rows, :padded_cols]
|
||||
padded_scales.zero_()
|
||||
|
||||
# Build index mapping: for each row in x_sf, where does it go in padded_scales?
|
||||
# Row i in x_sf belongs to expert e where expert_offsets[e] <= i < expert_offsets[e+1]
|
||||
# Its destination is padded_expert_offsets[e] + (i - expert_offsets[e])
|
||||
|
||||
# Use searchsorted to find which expert each row belongs to
|
||||
# Build index mapping: for each row in x_sf, which expert does it belong to?
|
||||
total_rows = x_sf.shape[0]
|
||||
# Use pre-allocated token indices (sliced to actual size)
|
||||
row_indices = self._token_indices[:total_rows]
|
||||
# expert_assign[i] = which expert row i belongs to
|
||||
expert_assign = torch.searchsorted(expert_offsets[1:], row_indices, right=False).clamp(max=num_experts - 1)
|
||||
expert_assign = torch.searchsorted(
|
||||
expert_offsets[1:], row_indices, right=False
|
||||
).clamp(max=num_experts - 1)
|
||||
|
||||
# Destination row in padded buffer
|
||||
local_row = row_indices - expert_offsets[expert_assign]
|
||||
@@ -182,91 +182,111 @@ class CuTeDSLMoERunner:
|
||||
# Scatter x_sf into padded_scales
|
||||
padded_scales[dst_rows, :K_sf] = x_sf
|
||||
|
||||
# Apply swizzle to the whole padded tensor, return 2D for 2D-side scale_a
|
||||
# to_blocked preserves element count, so reshape to match padded shape
|
||||
# Apply swizzle, reshape to 2D (element count preserved by swizzle)
|
||||
swizzled = pad_and_swizzle_single(padded_scales)
|
||||
return swizzled.reshape(padded_scales.shape[0], -1)
|
||||
|
||||
def run(self, hidden_states, topk_weights, topk_ids, expert_indices=None):
|
||||
num_tokens, hidden_size = hidden_states.shape
|
||||
"""Run the NVFP4 MoE forward pass.
|
||||
|
||||
Fully cudagraph-safe: no CPU-GPU syncs, no dynamic shapes.
|
||||
|
||||
expert_offsets are computed from the actual token distribution
|
||||
via GPU-only ops (argsort, broadcast ==, cumsum). These offsets
|
||||
are passed to the GEMM as a GPU tensor, never converted to Python.
|
||||
|
||||
The GEMM and quantize functions see the full slot buffer.
|
||||
Padding rows are zeros that produce zero output, contributing
|
||||
nothing to the final scatter_add.
|
||||
|
||||
Args:
|
||||
hidden_states: (num_tokens, hidden_size) bf16
|
||||
topk_weights: (num_tokens, top_k) float32
|
||||
topk_ids: (num_tokens, top_k) int
|
||||
expert_indices: ignored (uses all experts)
|
||||
|
||||
Returns:
|
||||
(num_tokens, hidden_size) bf16 - MoE output
|
||||
"""
|
||||
num_tokens = hidden_states.shape[0]
|
||||
top_k = topk_ids.shape[1]
|
||||
device = hidden_states.device
|
||||
|
||||
if expert_indices is None:
|
||||
expert_indices = list(range(self.num_experts))
|
||||
|
||||
num_experts = len(expert_indices)
|
||||
self._ensure_stacked()
|
||||
|
||||
# ── Build slot mapping ──
|
||||
# -- Build slot mapping --
|
||||
flat_ids = topk_ids.reshape(-1)
|
||||
flat_weights = topk_weights.reshape(-1)
|
||||
token_indices = self._token_indices[:num_tokens * top_k]
|
||||
num_slots = num_tokens * top_k
|
||||
token_indices = self._token_indices[:num_slots]
|
||||
|
||||
sort_idx = flat_ids.argsort(stable=True)
|
||||
sorted_ids = flat_ids[sort_idx]
|
||||
sorted_weights = flat_weights[sort_idx]
|
||||
sorted_token_ids = token_indices[sort_idx]
|
||||
|
||||
# Expert offsets (GPU-only)
|
||||
expert_id_range = self._expert_id_range[:num_experts]
|
||||
# Expert offsets (GPU-only, never touches CPU)
|
||||
expert_id_range = self._expert_id_range
|
||||
tokens_per_expert = (sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)).sum(dim=0).int()
|
||||
expert_offsets = self._expert_offsets_buf
|
||||
expert_offsets.zero_()
|
||||
expert_offsets[1:num_experts + 1] = tokens_per_expert.cumsum(0)
|
||||
expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0)
|
||||
|
||||
total_slots = expert_offsets[num_experts]
|
||||
# -- Gather hidden states into slot order --
|
||||
slot_hidden = hidden_states[sorted_token_ids]
|
||||
|
||||
slot_hidden = hidden_states[sorted_token_ids[:total_slots]]
|
||||
|
||||
# ════════════════════════════════════════════════════════════
|
||||
# L1: gate + up
|
||||
# ════════════════════════════════════════════════════════════
|
||||
# === L1: gate + up ===
|
||||
x_fp4, x_sf = quantize_activation_nvfp4(
|
||||
slot_hidden, self._l1_activation_global_scale
|
||||
)
|
||||
|
||||
l1_scale_a = self._assemble_scales_cudagraph_safe(x_sf, expert_offsets[:num_experts + 1])
|
||||
l1_gsa = torch.full((num_experts,), self._l1_activation_global_scale,
|
||||
dtype=torch.float32, device=device)
|
||||
l1_scale_a = self._assemble_scales_cudagraph_safe(
|
||||
x_sf, expert_offsets[:self.num_experts + 1]
|
||||
)
|
||||
l1_gsa = torch.full(
|
||||
(self.num_experts,), self._l1_activation_global_scale,
|
||||
dtype=torch.float32, device=device
|
||||
)
|
||||
|
||||
l1_out = run_nvfp4_grouped_gemm(
|
||||
mat_a=x_fp4, mat_b=self._l1_mat_b,
|
||||
scale_a=l1_scale_a, scale_b=self._l1_scale_b,
|
||||
expert_offsets=expert_offsets[:num_experts + 1],
|
||||
expert_offsets=expert_offsets[:self.num_experts + 1],
|
||||
global_scale_a=l1_gsa, global_scale_b=self._l1_gsb,
|
||||
)
|
||||
|
||||
# ════════════════════════════════════════════════════════════
|
||||
# SiLU(gate) * up
|
||||
# ════════════════════════════════════════════════════════════
|
||||
# === SiLU(gate) * up ===
|
||||
gate = l1_out[:, :self.intermediate_size]
|
||||
up = l1_out[:, self.intermediate_size:]
|
||||
activated = torch.nn.functional.silu(gate) * up
|
||||
|
||||
# ════════════════════════════════════════════════════════════
|
||||
# L2: down
|
||||
# ════════════════════════════════════════════════════════════
|
||||
# === L2: down ===
|
||||
l2_x_fp4, l2_x_sf = quantize_activation_nvfp4(
|
||||
activated, self._l2_activation_global_scale
|
||||
)
|
||||
|
||||
l2_scale_a = self._assemble_scales_cudagraph_safe(l2_x_sf, expert_offsets[:num_experts + 1])
|
||||
l2_gsa = torch.full((num_experts,), self._l2_activation_global_scale,
|
||||
dtype=torch.float32, device=device)
|
||||
l2_scale_a = self._assemble_scales_cudagraph_safe(
|
||||
l2_x_sf, expert_offsets[:self.num_experts + 1]
|
||||
)
|
||||
l2_gsa = torch.full(
|
||||
(self.num_experts,), self._l2_activation_global_scale,
|
||||
dtype=torch.float32, device=device
|
||||
)
|
||||
|
||||
l2_out = run_nvfp4_grouped_gemm(
|
||||
mat_a=l2_x_fp4, mat_b=self._l2_mat_b,
|
||||
scale_a=l2_scale_a, scale_b=self._l2_scale_b,
|
||||
expert_offsets=expert_offsets[:num_experts + 1],
|
||||
expert_offsets=expert_offsets[:self.num_experts + 1],
|
||||
global_scale_a=l2_gsa, global_scale_b=self._l2_gsb,
|
||||
)
|
||||
|
||||
# ════════════════════════════════════════════════════════════
|
||||
# Scatter → final output
|
||||
# ════════════════════════════════════════════════════════════
|
||||
y = torch.zeros(num_tokens, hidden_size, dtype=torch.bfloat16, device=device)
|
||||
weighted_out = l2_out * sorted_weights[:total_slots].unsqueeze(1).to(l2_out.dtype)
|
||||
y.scatter_add_(0, sorted_token_ids[:total_slots].unsqueeze(1).expand(-1, hidden_size), weighted_out)
|
||||
# === Scatter -> final output ===
|
||||
y = torch.zeros(num_tokens, self.hidden_size, dtype=torch.bfloat16, device=device)
|
||||
weighted_out = l2_out * sorted_weights.unsqueeze(1).to(l2_out.dtype)
|
||||
y.scatter_add_(
|
||||
0,
|
||||
sorted_token_ids.unsqueeze(1).expand(-1, self.hidden_size),
|
||||
weighted_out,
|
||||
)
|
||||
|
||||
return y
|
||||
|
||||
Reference in New Issue
Block a user