Files
nvfp4-megamoe-kernel/cutedsl/shared_expert_pipeline.py

304 lines
12 KiB
Python
Raw Normal View History

"""CuTeDSL Shared Expert Pipeline
NVFP4 inference for DeepSeek V4 shared experts.
Uses ScaledGroupedGemmKernel with num_groups=1.
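Pipeline:
1. Quantize activation: BF16 → NVFP4 (using the warmup global scale)
2. L1 GEMM: NVFP4_act × NVFP4_weight(gate_up) → BF16
3. SiLU(gate) * up → BF16 (when swiglu_limit is set, gate is clamped from above and up to [-limit, limit] first)
4. Re-quantize: BF16 → NVFP4 (using the warmup global scale)
5. L2 GEMM: NVFP4_act × NVFP4_weight(down) → BF16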
Unlike MoE, there's no routing, no scatter, no expert offsets.
All tokens go through the same expert (the shared expert).
Scale assembly is just: quantize activation → pad to 128-row alignment → Blackwell swizzle.
CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs,
no dynamic shapes. Padding rows are zeros that contribute nothing to GEMM output.
"""
import torch
from cutedsl.bridge import (
    quantize_activation_nvfp4,
    quantize_to_nvfp4,
    make_b_k_major,
    assemble_scales_3d_side,
    run_nvfp4_grouped_gemm,
)
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
    ceil_div as cutedsl_ceil_div,
    pad_and_swizzle_single,
)


class _SharedExpertApply(torch.autograd.Function):
    """Custom autograd function wrapping the CuTeDSL shared-expert forward.

    The intent is to make the CuTeDSL runner opaque to torch.compile.
    NOTE(review): recent Dynamo versions can trace into
    ``autograd.Function.forward``; confirm this wrapper still prevents
    tracing into the runner (``torch.library.custom_op`` is the documented
    way to register an opaque op).

    No ``backward`` is defined, so this path is inference-only.
    """

    @staticmethod
    def forward(ctx, runner, hidden_states):
        # ``runner`` is a CuTeDSLSharedExpertRunner; ctx is unused because
        # nothing is saved for a backward pass.
        return runner._run_impl(hidden_states)
class CuTeDSLSharedExpertRunner:
    """NVFP4 shared expert runner using CuTeDSL GEMM (num_groups=1).

    Lifecycle:
      1. Construct, then assign the raw weights ``l1_fp4`` / ``l1_sf`` /
         ``l1_gs`` (gate_up) and ``l2_fp4`` / ``l2_sf`` / ``l2_gs`` (down).
         ``l1_fp4`` and ``l2_fp4`` are one-element lists (the shared expert).
      2. Optionally call ``finalize_weights``; otherwise it runs lazily on
         first use via ``_ensure_initialized``.
      3. Call ``compute_activation_global_scales`` on a warmup batch BEFORE
         CUDA-graph capture.
      4. Call ``run`` for each forward pass.

    CUDA-graph-compatible: the FP4 activation and scale staging buffers are
    pre-allocated for ``max_num_tokens`` rows (rounded up to 128), and the
    forward path performs no CPU-GPU syncs.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        max_num_tokens: int = 8192,
        device: str = "cuda",
        swiglu_limit: float = 10.0,
    ):
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.max_num_tokens = max_num_tokens
        self.device = device
        # ``None`` disables the gate/up clamp applied before SiLU.
        self.swiglu_limit = swiglu_limit
        # Raw weights (set after construction, consumed by finalize_weights)
        self.l1_fp4 = None
        self.l1_sf = None
        self.l1_gs = None
        self.l2_fp4 = None
        self.l2_sf = None
        self.l2_gs = None
        # Processed GEMM operands (set by finalize_weights)
        self._l1_mat_b = None
        self._l2_mat_b = None
        self._l1_scale_b = None
        self._l2_scale_b = None
        self._l1_gsb = None
        self._l2_gsb = None
        # Activation global scales. These defaults are replaced by
        # compute_activation_global_scales with values measured on real data.
        self._l1_activation_global_scale = 1.0 / (6.0 * 448.0)
        self._l2_activation_global_scale = 1.0 / (6.0 * 448.0)
        # Pre-allocated CUDA-graph buffers (set in _allocate_buffers)
        self._padded_x_fp4_buf_l1 = None
        self._padded_x_sf_buf_l1 = None
        self._padded_x_fp4_buf_l2 = None
        self._padded_x_sf_buf_l2 = None
        self._l1_gsa_buf = None
        self._l2_gsa_buf = None
        self._expert_offsets_buf = None
        self._buffers_allocated = False

    def set_swiglu_limit(self, limit: float):
        """Set the SwiGLU clamp limit (``None`` disables clamping)."""
        self.swiglu_limit = limit

    def finalize_weights(self):
        """Convert the raw shared-expert weights into CuTeDSL GEMM operands.

        Must be called after the ``l1_*`` / ``l2_*`` weights are assigned.
        The raw references are dropped afterwards, so a second call fails.

        Raises:
            RuntimeError: if any raw weight attribute is still ``None``.
        """
        missing = [
            name
            for name in ("l1_fp4", "l1_sf", "l1_gs", "l2_fp4", "l2_sf", "l2_gs")
            if getattr(self, name) is None
        ]
        if missing:
            raise RuntimeError(
                "CuTeDSLSharedExpertRunner.finalize_weights(): weights not set: "
                + ", ".join(missing)
            )
        # l1_fp4/l2_fp4 are lists with 1 element (the shared expert); stack
        # them into a single-group batch and convert to K-major.
        self._l1_mat_b = make_b_k_major(torch.stack(self.l1_fp4))  # (1, K_packed, N_packed)
        self._l2_mat_b = make_b_k_major(torch.stack(self.l2_fp4))
        self._l1_scale_b = assemble_scales_3d_side(self.l1_sf)  # (1, N, K_sf_padded)
        self._l2_scale_b = assemble_scales_3d_side(self.l2_sf)
        self._l1_gsb = torch.tensor(self.l1_gs, dtype=torch.float32, device=self.device)
        self._l2_gsb = torch.tensor(self.l2_gs, dtype=torch.float32, device=self.device)
        # Free raw weights
        self.l1_fp4 = None
        self.l1_sf = None
        self.l1_gs = None
        self.l2_fp4 = None
        self.l2_sf = None
        self.l2_gs = None

    def _allocate_buffers(self):
        """Pre-allocate all staging buffers at max size for CUDA-graph capture."""
        max_rows = cutedsl_ceil_div(self.max_num_tokens, 128) * 128  # pad to 128

        def zeros_u8(rows, cols):
            return torch.zeros(rows, cols, dtype=torch.uint8, device=self.device)

        # Packed FP4 activations: two values per byte, hence K // 2 columns.
        # L1 consumes hidden_size inputs, L2 consumes intermediate_size inputs.
        self._padded_x_fp4_buf_l1 = zeros_u8(max_rows, self.hidden_size // 2).view(
            torch.float4_e2m1fn_x2
        )
        self._padded_x_fp4_buf_l2 = zeros_u8(max_rows, self.intermediate_size // 2).view(
            torch.float4_e2m1fn_x2
        )
        # Scale buffers: one FP8 scale per 16 inputs, columns padded to a
        # multiple of 4 (the same padding _assemble_scales_single_group uses).
        # Zero bytes reinterpret as FP8 zeros, so no float16 temporary is needed.
        padded_cols_l1 = cutedsl_ceil_div(cutedsl_ceil_div(self.hidden_size, 16), 4) * 4
        padded_cols_l2 = cutedsl_ceil_div(cutedsl_ceil_div(self.intermediate_size, 16), 4) * 4
        self._padded_x_sf_buf_l1 = zeros_u8(max_rows, padded_cols_l1).view(torch.float8_e4m3fn)
        self._padded_x_sf_buf_l2 = zeros_u8(max_rows, padded_cols_l2).view(torch.float8_e4m3fn)
        # Global scale buffers
        self._l1_gsa_buf = torch.zeros(1, dtype=torch.float32, device=self.device)
        self._l2_gsa_buf = torch.zeros(1, dtype=torch.float32, device=self.device)
        # Cumulative expert offsets for num_groups=1: a single element holding
        # the 128-aligned row count (filled per call).
        self._expert_offsets_buf = torch.zeros(1, dtype=torch.int32, device=self.device)
        self._buffers_allocated = True

    def _ensure_initialized(self):
        """Lazily initialize stacked weights and buffers."""
        if self._l1_mat_b is None:
            self.finalize_weights()
        if not self._buffers_allocated:
            self._allocate_buffers()

    def _assemble_scales_single_group(self, x_sf, num_tokens, padded_x_sf_buf):
        """Assemble 2D-side activation scales for num_groups=1.

        Steps:
          1. Copy ``x_sf`` into a zeroed buffer padded to 128 rows / 4 columns.
             This uses the first ``padded_rows`` rows of the pre-allocated
             ``padded_x_sf_buf`` (never the full max-size buffer). It falls
             back to a fresh allocation when that buffer is missing or has a
             different column count.
          2. Apply pad_and_swizzle_single (Blackwell swizzle).
          3. Reshape back to 2D (the kernel expects a 2D scale_a).

        ``num_tokens`` is kept for interface compatibility; the row count is
        taken from ``x_sf.shape``.
        """
        num_rows, num_cols = x_sf.shape
        padded_rows = cutedsl_ceil_div(num_rows, 128) * 128
        padded_cols = cutedsl_ceil_div(num_cols, 4) * 4
        buffer_fits = (
            padded_x_sf_buf is not None
            and padded_rows <= padded_x_sf_buf.shape[0]
            and padded_cols == padded_x_sf_buf.shape[1]
        )
        if buffer_fits:
            # Contiguous row slice of the persistent buffer; reused every call.
            buf = padded_x_sf_buf[:padded_rows]
            buf.view(torch.uint8).zero_()
        else:
            buf = torch.zeros(
                padded_rows, padded_cols, dtype=torch.uint8, device=x_sf.device
            ).view(torch.float8_e4m3fn)
        buf[:num_rows, :num_cols] = x_sf
        swizzled_flat = pad_and_swizzle_single(buf)
        return swizzled_flat.reshape(padded_rows, padded_cols)

    def _swiglu(self, l1_out: torch.Tensor) -> torch.Tensor:
        """Split the L1 output into gate/up halves and return clamped SiLU(gate) * up.

        Matches SiluAndMulWithClamp: when ``swiglu_limit`` is set, gate is
        clamped from above BEFORE SiLU and up is clamped to [-limit, limit].
        """
        gate = l1_out[:, : self.intermediate_size]
        up = l1_out[:, self.intermediate_size :]
        if self.swiglu_limit is not None:
            gate = gate.clamp(max=self.swiglu_limit)
            up = up.clamp(min=-self.swiglu_limit, max=self.swiglu_limit)
        return torch.nn.functional.silu(gate) * up

    def compute_activation_global_scales(self, hidden_states_sample):
        """Compute activation global scales from a warmup forward pass.

        Called BEFORE CUDA-graph capture. Uses quantize_to_nvfp4 to get the
        exact global scale from the data, then runs L1 to compute the L2
        global scale from the actual SwiGLU output. If the L1 output contains
        NaNs, the L2 global scale keeps its previous value.
        """
        self._ensure_initialized()
        with torch.no_grad():
            # L1: exact gs from quantize_to_nvfp4
            _, _, l1_gs = quantize_to_nvfp4(hidden_states_sample)
            self._l1_activation_global_scale = l1_gs
            # Run L1 GEMM (with the new L1 gs) to get the intermediate for L2 gs.
            l1_out = self._run_l1(hidden_states_sample)
            # The NaN check syncs with the GPU; acceptable here (warmup only).
            if not torch.isnan(l1_out).any():
                _, _, l2_gs = quantize_to_nvfp4(self._swiglu(l1_out))
                self._l2_activation_global_scale = l2_gs

    def _run_layer(
        self,
        x: torch.Tensor,
        activation_global_scale,
        fp4_buf: torch.Tensor,
        sf_buf: torch.Tensor,
        gsa_buf: torch.Tensor,
        mat_b: torch.Tensor,
        scale_b: torch.Tensor,
        gsb: torch.Tensor,
    ) -> torch.Tensor:
        """Quantize ``x``, stage it in the padded buffers, and run one GEMM.

        Returns only the rows belonging to real tokens.

        Raises:
            ValueError: if ``x`` has more rows than the pre-allocated buffer.
        """
        num_tokens = x.shape[0]
        padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128
        if padded_rows > fp4_buf.shape[0]:
            raise ValueError(
                f"CuTeDSLSharedExpertRunner: {num_tokens} tokens exceed the "
                f"pre-allocated capacity of {fp4_buf.shape[0]} rows "
                f"(max_num_tokens={self.max_num_tokens})"
            )
        # Quantize activation
        x_fp4, x_sf = quantize_activation_nvfp4(x, activation_global_scale)
        # Scatter into the padded buffer. Only the padding rows inside the
        # single group ([num_tokens, padded_rows)) need zeroing; rows past
        # padded_rows lie outside expert_offsets, so the full max-size buffer
        # is not memset on every call.
        fp4_bytes = fp4_buf.view(torch.uint8)
        fp4_bytes[:num_tokens] = x_fp4.view(torch.uint8)
        fp4_bytes[num_tokens:padded_rows].zero_()
        # Assemble A-side scales
        scale_a = self._assemble_scales_single_group(x_sf, num_tokens, sf_buf)
        # Expert offsets: [padded_rows] for 1 group
        expert_offsets = self._expert_offsets_buf
        expert_offsets.fill_(padded_rows)
        # Global scales
        gsa = gsa_buf.fill_(activation_global_scale)
        # The full pre-allocated buffer is passed so mat_a's shape stays fixed.
        out = run_nvfp4_grouped_gemm(
            mat_a=fp4_buf,
            mat_b=mat_b,
            scale_a=scale_a,
            scale_b=scale_b,
            expert_offsets=expert_offsets,
            global_scale_a=gsa,
            global_scale_b=gsb,
        )
        # Extract real token outputs
        return out[:num_tokens]

    def _run_l1(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """L1 GEMM: activation × gate_up_weight → BF16."""
        return self._run_layer(
            hidden_states,
            self._l1_activation_global_scale,
            fp4_buf=self._padded_x_fp4_buf_l1,
            sf_buf=self._padded_x_sf_buf_l1,
            gsa_buf=self._l1_gsa_buf,
            mat_b=self._l1_mat_b,
            scale_b=self._l1_scale_b,
            gsb=self._l1_gsb,
        )

    def _run_l2(self, intermediate: torch.Tensor) -> torch.Tensor:
        """L2 GEMM: intermediate × down_weight → BF16."""
        return self._run_layer(
            intermediate,
            self._l2_activation_global_scale,
            fp4_buf=self._padded_x_fp4_buf_l2,
            sf_buf=self._padded_x_sf_buf_l2,
            gsa_buf=self._l2_gsa_buf,
            mat_b=self._l2_mat_b,
            scale_b=self._l2_scale_b,
            gsb=self._l2_gsb,
        )

    def run(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Full shared expert forward: L1 → SiLU → L2 → output."""
        return _SharedExpertApply.apply(self, hidden_states)

    def _run_impl(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Actual implementation, called through the custom autograd function."""
        self._ensure_initialized()
        intermediate = self._swiglu(self._run_l1(hidden_states))
        return self._run_l2(intermediate)