Files
nvfp4-megamoe-kernel/cutedsl/shared_expert_pipeline.py
biondizzle 48386e34ad Fix torch.compile: use custom autograd Function instead of @torch.compiler.disable
torch.compile fullgraph mode can't handle @torch.compiler.disable (skips
the function and refuses to compile). Custom autograd Functions are treated
as opaque ops by torch.compile — they execute eagerly without the compiler
trying to trace into CuTeDSL internals (JIT, Path.cwd, etc).
2026-05-18 21:38:28 +00:00

304 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""CuTeDSL Shared Expert Pipeline
NVFP4 inference for DeepSeek V4 shared experts.
Uses ScaledGroupedGemmKernel with num_groups=1.
Pipeline:
1. Quantize activation: BF16 → NVFP4 (using warmup gs)
2. L1 GEMM: NVFP4_act × NVFP4_weight(gate_up) → BF16
3. SiLU(gate) * up → BF16
4. Re-quantize: BF16 → NVFP4 (using warmup gs)
5. L2 GEMM: NVFP4_act × NVFP4_weight(down) → BF16
Unlike MoE, there's no routing, no scatter, no expert offsets.
All tokens go through the same expert (the shared expert).
Scale assembly is just: quantize activation → pad to 128-row alignment → Blackwell swizzle.
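Designed for CUDA-graph capture: the padded FP4 activation, global-scale and
expert-offset buffers are pre-allocated at max size, and the forward path in
this file does no explicit CPU-GPU syncs and has no data-dependent shapes.
Note that the swizzled activation-scale buffer is still allocated on every
call, sized from the token count rounded up to 128 rows.
Padding rows are zeros that contribute nothing to GEMM output.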
"""
import torch
from cutedsl.bridge import (
    quantize_activation_nvfp4,
    quantize_to_nvfp4,
    make_b_k_major,
    assemble_scales_3d_side,
    run_nvfp4_grouped_gemm,
)
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
    ceil_div as cutedsl_ceil_div,
    pad_and_swizzle_single,
)


class _SharedExpertApply(torch.autograd.Function):
    """Route the shared-expert forward pass through a custom autograd Function.

    ``forward`` simply delegates to ``runner._run_impl``. No ``backward`` is
    defined here, so this wrapper only supports inference-style use.

    The stated intent is to keep torch.compile from tracing into the CuTeDSL
    runner internals.
    NOTE(review): TorchDynamo can trace into ``autograd.Function.forward``;
    confirm this wrapper really keeps the runner out of the compiled graph
    (``torch.library.custom_op`` is the documented way to get an opaque op).
    """

    @staticmethod
    def forward(ctx, runner, hidden_states):
        """Run ``runner._run_impl(hidden_states)`` and return its result.

        Args:
            ctx: Autograd context (unused; nothing is saved for backward).
            runner: Object exposing ``_run_impl``; in this file that is a
                ``CuTeDSLSharedExpertRunner``.
            hidden_states: Activation tensor forwarded unchanged to the runner.
        """
        return runner._run_impl(hidden_states)
class CuTeDSLSharedExpertRunner:
    """NVFP4 shared expert runner using CuTeDSL GEMM (num_groups=1).

    Lifecycle, as implemented by the methods below:

    1. Construct, then assign the raw weights ``l1_fp4``/``l1_sf``/``l1_gs``
       and ``l2_fp4``/``l2_sf``/``l2_gs``.
    2. ``finalize_weights()`` (also triggered lazily on first use) converts
       them into the tensors handed to the GEMM and drops the raw references.
    3. Optionally call ``compute_activation_global_scales(sample)`` before
       CUDA graph capture to calibrate the activation global scales.
    4. Call ``run(hidden_states)`` for the forward pass.

    Intended to be CUDA-graph-compatible: the padded FP4 activation buffers,
    global-scale buffers and expert-offset buffer are allocated once at the
    maximum size. The swizzled activation-scale buffer is the exception; it is
    created per call in ``_assemble_scales_single_group``.
    """

    # Token rows are padded up to a multiple of this before the GEMM.
    _ROW_ALIGNMENT = 128
    # Scale-factor columns are padded up to a multiple of this.
    _SF_COL_ALIGNMENT = 4
    # One scale-factor column is reserved per this many activation elements.
    _SF_GROUP_SIZE = 16

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        max_num_tokens: int = 8192,
        device: str = "cuda",
        swiglu_limit: float = 10.0,
    ):
        """Store the configuration; no tensors are allocated here.

        Args:
            hidden_size: Width of the input activation (L1 GEMM K dimension).
            intermediate_size: Width of each of the gate and up halves of the
                L1 output (and the L2 GEMM K dimension).
            max_num_tokens: Token count used to size the pre-allocated buffers
                (rounded up to a multiple of 128 rows).
            device: Device on which buffers and global-scale tensors are created.
            swiglu_limit: Clamp bound applied before the SiLU activation; the
                code treats ``None`` as "do not clamp".
        """
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.max_num_tokens = max_num_tokens
        self.device = device
        self.swiglu_limit = swiglu_limit

        # Raw weights (set after construction, then call finalize_weights).
        self.l1_fp4 = None
        self.l1_sf = None
        self.l1_gs = None
        self.l2_fp4 = None
        self.l2_sf = None
        self.l2_gs = None

        # Processed weights (set by finalize_weights).
        self._l1_mat_b = None
        self._l2_mat_b = None
        self._l1_scale_b = None
        self._l2_scale_b = None
        self._l1_gsb = None
        self._l2_gsb = None

        # Activation global scales; these defaults are replaced by
        # compute_activation_global_scales when it is called.
        self._l1_activation_global_scale = 1.0 / (6.0 * 448.0)
        self._l2_activation_global_scale = 1.0 / (6.0 * 448.0)

        # Pre-allocated cudagraph buffers (set in _allocate_buffers).
        self._padded_x_fp4_buf_l1 = None
        self._padded_x_sf_buf_l1 = None
        self._padded_x_fp4_buf_l2 = None
        self._padded_x_sf_buf_l2 = None
        self._l1_gsa_buf = None
        self._l2_gsa_buf = None
        self._expert_offsets_buf = None
        self._buffers_allocated = False

    @staticmethod
    def _round_up(value: int, multiple: int) -> int:
        """Round ``value`` up to the nearest multiple of ``multiple``."""
        return cutedsl_ceil_div(value, multiple) * multiple

    @staticmethod
    def _zeros_fp8(rows: int, cols: int, device) -> torch.Tensor:
        """Return a zero-filled ``(rows, cols)`` float8_e4m3fn tensor.

        The tensor is created as float16 and then converted, exactly as the
        scale buffers in this file have always been built.
        """
        return torch.zeros(rows, cols, dtype=torch.float16, device=device).to(
            torch.float8_e4m3fn
        )

    def set_swiglu_limit(self, limit: float):
        """Replace the clamp bound used by the SwiGLU activation."""
        self.swiglu_limit = limit

    def finalize_weights(self):
        """Process weights for CuTeDSL GEMM. Must be called after setting l1/l2 weights.

        Raises:
            RuntimeError: If any raw weight attribute is still ``None`` (never
                assigned, or already released by a previous call).
        """
        raw_weights = {
            "l1_fp4": self.l1_fp4,
            "l1_sf": self.l1_sf,
            "l1_gs": self.l1_gs,
            "l2_fp4": self.l2_fp4,
            "l2_sf": self.l2_sf,
            "l2_gs": self.l2_gs,
        }
        missing = [name for name, value in raw_weights.items() if value is None]
        if missing:
            # Fail with a readable message instead of an opaque error from
            # torch.stack(None) further down.
            raise RuntimeError(
                "finalize_weights() requires the raw weights to be set; "
                "missing: " + ", ".join(missing)
            )

        # Stack weights and convert to K-major.
        # l1_fp4/l2_fp4 are lists with 1 element (the shared expert).
        self._l1_mat_b = make_b_k_major(torch.stack(self.l1_fp4))  # (1, K_packed, N_packed)
        self._l2_mat_b = make_b_k_major(torch.stack(self.l2_fp4))
        self._l1_scale_b = assemble_scales_3d_side(self.l1_sf)  # (1, N, K_sf_padded)
        self._l2_scale_b = assemble_scales_3d_side(self.l2_sf)
        self._l1_gsb = torch.tensor(self.l1_gs, dtype=torch.float32, device=self.device)
        self._l2_gsb = torch.tensor(self.l2_gs, dtype=torch.float32, device=self.device)

        # Drop the references to the raw weights so they can be freed.
        self.l1_fp4 = None
        self.l1_sf = None
        self.l1_gs = None
        self.l2_fp4 = None
        self.l2_sf = None
        self.l2_gs = None

    def _allocate_buffers(self):
        """Pre-allocate all buffers at max size for cudagraph compatibility."""
        max_rows = self._round_up(self.max_num_tokens, self._ROW_ALIGNMENT)

        # Packed FP4 activations: two 4-bit values per byte, hence width // 2.
        # L1 consumes hidden_size columns, L2 consumes intermediate_size columns.
        self._padded_x_fp4_buf_l1 = torch.zeros(
            max_rows, self.hidden_size // 2, dtype=torch.uint8, device=self.device
        ).view(torch.float4_e2m1fn_x2)
        self._padded_x_fp4_buf_l2 = torch.zeros(
            max_rows, self.intermediate_size // 2, dtype=torch.uint8, device=self.device
        ).view(torch.float4_e2m1fn_x2)

        # Padded scale buffers. NOTE(review): these are handed to
        # _assemble_scales_single_group, which currently ignores them and
        # builds its own buffer; they are kept so the attributes stay populated.
        sf_cols_l1 = self._round_up(
            cutedsl_ceil_div(self.hidden_size, self._SF_GROUP_SIZE),
            self._SF_COL_ALIGNMENT,
        )
        sf_cols_l2 = self._round_up(
            cutedsl_ceil_div(self.intermediate_size, self._SF_GROUP_SIZE),
            self._SF_COL_ALIGNMENT,
        )
        self._padded_x_sf_buf_l1 = self._zeros_fp8(max_rows, sf_cols_l1, self.device)
        self._padded_x_sf_buf_l2 = self._zeros_fp8(max_rows, sf_cols_l2, self.device)

        # Global scale buffers (one float32 each).
        self._l1_gsa_buf = torch.zeros(1, dtype=torch.float32, device=self.device)
        self._l2_gsa_buf = torch.zeros(1, dtype=torch.float32, device=self.device)

        # Expert offsets for num_groups=1: a single element, filled with the
        # padded row count before every GEMM call.
        self._expert_offsets_buf = torch.zeros(1, dtype=torch.int32, device=self.device)
        self._buffers_allocated = True

    def _ensure_initialized(self):
        """Lazily initialize stacked weights and buffers."""
        if self._l1_mat_b is None:
            self.finalize_weights()
        if not self._buffers_allocated:
            self._allocate_buffers()

    def _assemble_scales_single_group(self, x_sf, num_tokens, padded_x_sf_buf):
        """Assemble 2D-side activation scales for num_groups=1.

        Steps:
        1. Copy ``x_sf`` into a zeroed buffer padded to 128 rows and 4 columns.
        2. Apply ``pad_and_swizzle_single`` to that buffer.
        3. Reshape the result back to the padded 2D shape.

        The buffer is sized for the 128-aligned row count of ``x_sf`` itself,
        not for ``max_num_tokens``, so a fresh tensor is created on each call.

        Args:
            x_sf: 2D tensor of activation scale factors.
            num_tokens: Unused; kept for call-site compatibility.
            padded_x_sf_buf: Unused; kept for call-site compatibility.

        Returns:
            The swizzled scales reshaped to ``(padded_rows, padded_cols)``.
        """
        num_rows, num_cols = x_sf.shape
        padded_rows = self._round_up(num_rows, self._ROW_ALIGNMENT)
        padded_cols = self._round_up(num_cols, self._SF_COL_ALIGNMENT)
        staging = self._zeros_fp8(padded_rows, padded_cols, x_sf.device)
        staging[:num_rows, :num_cols] = x_sf
        swizzled_flat = pad_and_swizzle_single(staging)
        return swizzled_flat.reshape(padded_rows, padded_cols)

    def _apply_swiglu(self, l1_out: torch.Tensor) -> torch.Tensor:
        """Split the L1 output into gate/up halves and return SiLU(gate) * up.

        The first ``intermediate_size`` columns are the gate, the rest are up.
        When ``swiglu_limit`` is not ``None`` the gate is clamped from above
        BEFORE the SiLU and up is clamped to ``[-limit, limit]``, matching
        SiluAndMulWithClamp.
        """
        gate = l1_out[:, :self.intermediate_size]
        up = l1_out[:, self.intermediate_size:]
        if self.swiglu_limit is not None:
            gate = gate.clamp(max=self.swiglu_limit)
            up = up.clamp(min=-self.swiglu_limit, max=self.swiglu_limit)
        return torch.nn.functional.silu(gate) * up

    def compute_activation_global_scales(self, hidden_states_sample):
        """Compute activation global scales from a warmup forward pass.

        Call this BEFORE cudagraph capture: the NaN check below reads a tensor
        value on the host. The L1 scale is taken from ``quantize_to_nvfp4`` on
        the sample; the L2 scale is taken from ``quantize_to_nvfp4`` on the
        SiLU(gate) * up result of running L1 on the same sample. If the L1
        output contains NaNs, the L2 scale is left unchanged.
        """
        self._ensure_initialized()
        with torch.no_grad():
            # L1: global scale derived directly from the sample.
            _, _, l1_gs = quantize_to_nvfp4(hidden_states_sample)
            self._l1_activation_global_scale = l1_gs

            # Run the L1 GEMM so the L2 scale reflects real intermediate values.
            l1_out = self._run_l1(hidden_states_sample)
            if l1_out is not None and not torch.isnan(l1_out).any():
                activated = self._apply_swiglu(l1_out)
                _, _, l2_gs = quantize_to_nvfp4(activated)
                self._l2_activation_global_scale = l2_gs

    def _run_quantized_gemm(
        self,
        activation: torch.Tensor,
        activation_global_scale,
        padded_x_fp4: torch.Tensor,
        padded_x_sf_buf,
        gsa_buf: torch.Tensor,
        mat_b,
        scale_b,
        global_scale_b,
    ) -> torch.Tensor:
        """Quantize ``activation`` to NVFP4 and run one single-group GEMM.

        Shared implementation behind ``_run_l1`` and ``_run_l2``; the two only
        differ in which buffers, weights and global scale they pass in.

        Raises:
            RuntimeError: If ``activation`` has more rows than the
                pre-allocated FP4 buffer can hold.
        """
        num_tokens = activation.shape[0]
        capacity = padded_x_fp4.shape[0]
        if num_tokens > capacity:
            # Reject oversized batches up front rather than failing with a
            # shape-mismatch error after the buffer has already been zeroed.
            raise RuntimeError(
                f"num_tokens={num_tokens} exceeds the pre-allocated buffer "
                f"capacity of {capacity} rows "
                f"(max_num_tokens={self.max_num_tokens})"
            )
        padded_rows = self._round_up(num_tokens, self._ROW_ALIGNMENT)

        # Quantize activation.
        x_fp4, x_sf = quantize_activation_nvfp4(activation, activation_global_scale)

        # Copy the packed FP4 bytes into the zeroed, padded buffer so the rows
        # past num_tokens are all zeros.
        fp4_bytes = padded_x_fp4.view(torch.uint8)
        fp4_bytes.zero_()
        fp4_bytes[:num_tokens] = x_fp4.view(torch.uint8)

        # Assemble A-side scales.
        scale_a = self._assemble_scales_single_group(x_sf, num_tokens, padded_x_sf_buf)

        # Expert offsets: [padded_rows] for the single group.
        expert_offsets = self._expert_offsets_buf
        expert_offsets.fill_(padded_rows)

        # Global scale; fill_ returns the buffer itself.
        gsa = gsa_buf.fill_(activation_global_scale)

        out = run_nvfp4_grouped_gemm(
            mat_a=padded_x_fp4,
            mat_b=mat_b,
            scale_a=scale_a,
            scale_b=scale_b,
            expert_offsets=expert_offsets,
            global_scale_a=gsa,
            global_scale_b=global_scale_b,
        )
        # Keep only the rows that correspond to real tokens.
        return out[:num_tokens]

    def _run_l1(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """L1 GEMM: activation × gate_up_weight → BF16."""
        return self._run_quantized_gemm(
            hidden_states,
            self._l1_activation_global_scale,
            self._padded_x_fp4_buf_l1,
            self._padded_x_sf_buf_l1,
            self._l1_gsa_buf,
            self._l1_mat_b,
            self._l1_scale_b,
            self._l1_gsb,
        )

    def _run_l2(self, intermediate: torch.Tensor) -> torch.Tensor:
        """L2 GEMM: intermediate × down_weight → BF16."""
        return self._run_quantized_gemm(
            intermediate,
            self._l2_activation_global_scale,
            self._padded_x_fp4_buf_l2,
            self._padded_x_sf_buf_l2,
            self._l2_gsa_buf,
            self._l2_mat_b,
            self._l2_scale_b,
            self._l2_gsb,
        )

    def run(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Full shared expert forward: L1 → SiLU → L2 → output.

        Dispatches through ``_SharedExpertApply`` (a custom autograd Function
        defined in this module), which calls ``_run_impl``.
        """
        return _SharedExpertApply.apply(self, hidden_states)

    def _run_impl(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Actual implementation — called via custom autograd to be torch.compile-safe."""
        self._ensure_initialized()
        l1_out = self._run_l1(hidden_states)
        intermediate = self._apply_swiglu(l1_out)
        return self._run_l2(intermediate)