"""
vLLM integration for the CuTeDSL NVFP4 MoE kernel.

CUDA-graph-compatible design:

- All intermediate buffers are pre-allocated at their maximum size
  (derived from max_num_tokens * top_k routed slots).
- No .item(), .tolist() or .cpu() in the forward pass: zero CPU-GPU syncs.
- No dynamic slicing with GPU scalars: kernels always operate on the full
  pre-allocated buffers.
- Extra slots (beyond the real tokens) are zero and contribute nothing to
  the output.
- Tensor shapes are fixed throughout the forward pass.

vLLM captures CUDA graphs at fixed token budgets (1, 2, 4, 8, ..., 8192).
During capture, num_tokens equals the budget, so all shapes are fixed.
During replay, inputs are padded to the budget size, so the runner always
processes max_slots = budget * top_k rows; padding rows are zeros.
"""
|
|
import torch
|
|
|
|
from dsv4.ops.quantize import (
|
|
quantize_activation_nvfp4,
|
|
quantize_weight_to_nvfp4,
|
|
quantize_to_nvfp4,
|
|
quantize_nvfp4_gpu,
|
|
deinterleave_quantize_nvfp4_cuda,
|
|
)
|
|
from dsv4.ops.layouts import (
|
|
make_b_k_major,
|
|
assemble_scales_3d_side,
|
|
interleave_l1_weights,
|
|
deinterleave_l1_weights,
|
|
)
|
|
from dsv4.ops.gemm_runner import (
|
|
run_nvfp4_grouped_gemm,
|
|
run_fused_swiglu_grouped_gemm,
|
|
warmup_fused_swiglu_compilation,
|
|
)
|
|
from dsv4.ops.layouts import (
|
|
ceil_div as cutedsl_ceil_div,
|
|
pad_and_swizzle_single,
|
|
)
|
|
from dsv4.ops.custom_ops import register_runner, nvfp4_moe_gemm
|
|
|
|
|
|
class Nvfp4MoE:
|
|
"""Manages NVFP4 MoE execution via the CuTeDSL kernel.
|
|
|
|
CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs,
|
|
no dynamic shapes. Always computes at max_num_tokens * top_k capacity.
|
|
"""
|
|
|
|
def __init__(self, num_experts, hidden_size, intermediate_size,
|
|
max_num_tokens=8192, top_k=8, device="cuda",
|
|
experts_start_idx=0):
|
|
self.num_experts = num_experts
|
|
self.hidden_size = hidden_size
|
|
self.intermediate_size = intermediate_size
|
|
self.max_num_tokens = max_num_tokens
|
|
self.top_k = top_k
|
|
self.device = device
|
|
self.experts_start_idx = experts_start_idx
|
|
self._swiglu_limit = None # Set via set_swiglu_limit()
|
|
self._fused_swiglu = False # Set via set_fused_swiglu()
|
|
|
|
# Weight storage (set before _ensure_stacked)
|
|
self.l1_fp4 = None
|
|
self.l1_sf = None
|
|
self.l1_gs = None
|
|
self.l2_fp4 = None
|
|
self.l2_sf = None
|
|
self.l2_gs = None
|
|
|
|
# Stacked weight tensors (set in _ensure_stacked)
|
|
self._l1_mat_b = None
|
|
self._l2_mat_b = None
|
|
self._l1_scale_b = None
|
|
self._l2_scale_b = None
|
|
self._l1_gsb = None
|
|
self._l2_gsb = None
|
|
|
|
# Default: 1/2688 ≈ 0.000372 (amax=1 → gs=1/2688)
|
|
# Overridden in finalize_weights with checkpoint input_scale or warmup value
|
|
self._l1_activation_global_scale = 1.0 / (6.0 * 448.0)
|
|
self._l2_activation_global_scale = 1.0 / (6.0 * 448.0)
|
|
|
|
# Pre-allocated cudagraph buffers (set in _allocate_buffers)
|
|
self._token_indices = None
|
|
self._expert_offsets_buf = None
|
|
self._per_expert_scale_bufs_l1 = None
|
|
self._per_expert_scale_bufs_l2 = None
|
|
self._padded_x_sf_buf_l1 = None
|
|
self._padded_x_sf_buf_l2 = None
|
|
self._l1_gsa_buf = None
|
|
self._l2_gsa_buf = None
|
|
self._l1_out_buf = None # pre-allocated L1 GEMM output for graph capture
|
|
self._output_buf = None
|
|
self._row_indices_buf = None
|
|
self._padded_hidden_buf = None
|
|
self._padded_activated_buf = None # unused, using shared
|
|
self._padded_expert_offsets_buf = None
|
|
self._max_chunks_per_expert = cutedsl_ceil_div(
|
|
self.max_num_tokens * self.top_k, self.num_experts * 128
|
|
)
|
|
self._buffers_allocated = False
|
|
|
|
def set_swiglu_limit(self, limit: float | None):
|
|
"""Set the swiglu_limit for activation clamping."""
|
|
self._swiglu_limit = limit
|
|
|
|
def set_fused_swiglu(self, enabled: bool):
|
|
"""Enable fused L1 GEMM + SwiGLU kernel (saves 240+ BF16 kernel launches per token)."""
|
|
self._fused_swiglu = enabled
|
|
|
|
def _fill_token_indices(self):
|
|
"""Fill _token_indices with [0,0,..0, 1,1,..1, ...] (each token repeated top_k times).
|
|
|
|
Builds on CPU first, then copies to GPU, to ensure correctness
|
|
regardless of CuTeDSL JIT GPU memory corruption.
|
|
"""
|
|
src = torch.arange(self.max_num_tokens, dtype=torch.int32)
|
|
cpu_indices = src.unsqueeze(1).expand(-1, self.top_k).contiguous().view(-1)
|
|
self._token_indices.copy_(cpu_indices)
|
|
|
|
def _allocate_buffers(self):
|
|
"""Pre-allocate scale buffers at max size for cudagraph compatibility."""
|
|
# Per-expert scale buffers: separate L1/L2 since K_sf differs
|
|
K_sf_l1 = cutedsl_ceil_div(self.hidden_size, 16)
|
|
padded_cols_l1 = cutedsl_ceil_div(K_sf_l1, 4) * 4
|
|
K_sf_l2 = cutedsl_ceil_div(self.intermediate_size, 16)
|
|
padded_cols_l2 = cutedsl_ceil_div(K_sf_l2, 4) * 4
|
|
|
|
self._per_expert_scale_bufs_l1 = [
|
|
torch.zeros(128, padded_cols_l1, dtype=torch.float16, device=self.device).to(torch.float8_e4m3fn)
|
|
for _ in range(self.num_experts)
|
|
]
|
|
self._per_expert_scale_bufs_l2 = [
|
|
torch.zeros(128, padded_cols_l2, dtype=torch.float16, device=self.device).to(torch.float8_e4m3fn)
|
|
for _ in range(self.num_experts)
|
|
]
|
|
|
|
# Initialize shared buffers dict (if not already)
|
|
device_key = str(self.device)
|
|
if not hasattr(Nvfp4MoE, '_shared_padded_bufs'):
|
|
Nvfp4MoE._shared_padded_bufs = {}
|
|
if device_key not in Nvfp4MoE._shared_padded_bufs:
|
|
Nvfp4MoE._shared_padded_bufs[device_key] = {}
|
|
|
|
# Padded x_sf buffers: SHARED across all runners (not per-layer)
|
|
max_sf_rows = self.num_experts * self._max_chunks_per_expert * 128
|
|
if 'xsf_l1' not in Nvfp4MoE._shared_padded_bufs[device_key]:
|
|
Nvfp4MoE._shared_padded_bufs[device_key].update({
|
|
'xsf_l1': torch.zeros(
|
|
max_sf_rows, padded_cols_l1, dtype=torch.float16, device=self.device
|
|
).to(torch.float8_e4m3fn),
|
|
'xsf_l2': torch.zeros(
|
|
max_sf_rows, padded_cols_l2, dtype=torch.float16, device=self.device
|
|
).to(torch.float8_e4m3fn),
|
|
'output': torch.zeros(
|
|
self.max_num_tokens, self.hidden_size, dtype=torch.bfloat16, device=self.device
|
|
),
|
|
})
|
|
self._padded_x_sf_buf_l1 = Nvfp4MoE._shared_padded_bufs[device_key]['xsf_l1']
|
|
self._padded_x_sf_buf_l2 = Nvfp4MoE._shared_padded_bufs[device_key]['xsf_l2']
|
|
self._output_buf = Nvfp4MoE._shared_padded_bufs[device_key]['output']
|
|
|
|
# Pre-allocated swizzled scale output buffers (same size as padded_x_sf)
|
|
# Required for CUDA graph capture — Python view ops (reshape, transpose) not capturable
|
|
if 'xsf_swizzled_l1' not in Nvfp4MoE._shared_padded_bufs[device_key]:
|
|
Nvfp4MoE._shared_padded_bufs[device_key].update({
|
|
'xsf_swizzled_l1': torch.zeros_like(Nvfp4MoE._shared_padded_bufs[device_key]['xsf_l1']),
|
|
'xsf_swizzled_l2': torch.zeros_like(Nvfp4MoE._shared_padded_bufs[device_key]['xsf_l2']),
|
|
})
|
|
self._padded_x_sf_swizzled_buf_l1 = Nvfp4MoE._shared_padded_bufs[device_key]['xsf_swizzled_l1']
|
|
self._padded_x_sf_swizzled_buf_l2 = Nvfp4MoE._shared_padded_bufs[device_key]['xsf_swizzled_l2']
|
|
|
|
# Pre-allocated global_scale_a buffers (filled via .fill_(), no torch.full during capture)
|
|
self._l1_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=self.device)
|
|
self._l2_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=self.device)
|
|
|
|
# Pre-allocated L1 GEMM output — avoids torch.zeros() in run_fused_swiglu_grouped_gemm
|
|
# Shape: (max_tokens * top_k, 2*intermediate_size) — gate+up combined
|
|
self._l1_out_buf = torch.zeros(
|
|
self.max_num_tokens * self.top_k, 2 * self.intermediate_size,
|
|
dtype=torch.bfloat16, device=self.device
|
|
)
|
|
# Pre-allocated L2 GEMM output — avoids torch.zeros() in run_nvfp4_grouped_gemm
|
|
# Shape: (max_tokens * top_k, hidden_size) — down projection
|
|
self._l2_out_buf = torch.zeros(
|
|
self.max_num_tokens * self.top_k, self.hidden_size,
|
|
dtype=torch.bfloat16, device=self.device
|
|
)
|
|
|
|
# Pre-allocated tokens-per-expert buffer — replaces torch.bincount
|
|
# (bincount produces data-dependent shapes, breaks CUDA graph capture)
|
|
self._tokens_per_expert_buf = torch.zeros(self.num_experts, dtype=torch.int32, device=self.device)
|
|
|
|
# Row indices for scale assembly (max_num_tokens * top_k slots)
|
|
self._row_indices_buf = torch.arange(
|
|
self.max_num_tokens * self.top_k, device=self.device
|
|
)
|
|
|
|
# Padded hidden/activated: SHARED across all runners (not per-layer)
|
|
max_rows_per_expert = self._max_chunks_per_expert * 128
|
|
padded_max_slots = self.num_experts * max_rows_per_expert
|
|
if 'hidden' not in Nvfp4MoE._shared_padded_bufs[device_key]:
|
|
Nvfp4MoE._shared_padded_bufs[device_key].update({
|
|
'hidden': torch.zeros(
|
|
padded_max_slots, self.hidden_size, dtype=torch.bfloat16, device=self.device
|
|
),
|
|
'hidden_fp4': torch.zeros(
|
|
padded_max_slots, self.hidden_size // 2, dtype=torch.uint8, device=self.device
|
|
).view(torch.float4_e2m1fn_x2),
|
|
'activated': torch.zeros(
|
|
padded_max_slots, self.intermediate_size, dtype=torch.bfloat16, device=self.device
|
|
),
|
|
'activated_fp4': torch.zeros(
|
|
padded_max_slots, self.intermediate_size // 2, dtype=torch.uint8, device=self.device
|
|
).view(torch.float4_e2m1fn_x2),
|
|
})
|
|
self._shared_bufs = Nvfp4MoE._shared_padded_bufs[device_key]
|
|
|
|
# Padded expert offsets buffer: [0, max_rows, 2*max_rows, ...] (fixed)
|
|
self._padded_expert_offsets_buf = torch.zeros(
|
|
self.num_experts + 1, dtype=torch.int32, device=self.device
|
|
)
|
|
max_rows_per_expert = self._max_chunks_per_expert * 128
|
|
self._padded_expert_offsets_buf[1:] = torch.arange(
|
|
1, self.num_experts + 1, dtype=torch.int32, device=self.device
|
|
) * max_rows_per_expert
|
|
|
|
self._buffers_allocated = True
|
|
|
|
def _ensure_stacked(self):
|
|
if self._l1_mat_b is not None:
|
|
return
|
|
|
|
# Convert weights to kernel format
|
|
if hasattr(self, 'l1_fp4_stacked') and self.l1_fp4_stacked is not None:
|
|
# Fast path: pre-stacked 3D tensors in checkpoint format (E, N, K)
|
|
# Permute to (E, K, N) then make K-major
|
|
l1_fp4_ekn = self.l1_fp4_stacked.permute(0, 2, 1).contiguous()
|
|
l2_fp4_ekn = self.l2_fp4_stacked.permute(0, 2, 1).contiguous()
|
|
# Interleave L1 gate/up weights at granularity 4 BF16.
|
|
# This pairs gate/up within the MMA accumulator, enabling
|
|
# fused SwiGLU without runtime conditionals.
|
|
l1_fp4_ekn = interleave_l1_weights(l1_fp4_ekn)
|
|
# Convert uint8 checkpoint weights to float4_e2m1fn_x2 view
|
|
if l1_fp4_ekn.dtype == torch.uint8:
|
|
l1_fp4_ekn = l1_fp4_ekn.view(torch.float4_e2m1fn_x2)
|
|
if l2_fp4_ekn.dtype == torch.uint8:
|
|
l2_fp4_ekn = l2_fp4_ekn.view(torch.float4_e2m1fn_x2)
|
|
# Free stacked checkpoints before make_b_k_major (saves one copy)
|
|
self.l1_fp4_stacked = None
|
|
self.l2_fp4_stacked = None
|
|
torch.cuda.empty_cache()
|
|
|
|
self._l1_mat_b = make_b_k_major(l1_fp4_ekn)
|
|
self._l2_mat_b = make_b_k_major(l2_fp4_ekn)
|
|
del l1_fp4_ekn, l2_fp4_ekn
|
|
torch.cuda.empty_cache()
|
|
|
|
# Scales: checkpoint is (E, N, K_sf) — the kernel expects (N, K_sf)
|
|
# per expert for swizzle. Split into views (no copy), then assemble.
|
|
l1_sf_list = [self.l1_sf_stacked[i] for i in range(self.num_experts)]
|
|
l2_sf_list = [self.l2_sf_stacked[i] for i in range(self.num_experts)]
|
|
self.l1_sf_stacked = None
|
|
self.l2_sf_stacked = None
|
|
torch.cuda.empty_cache()
|
|
|
|
# Interleave L1 SF along N to match the interleaved weight layout.
|
|
# SF per expert from checkpoint is (N, K_sf). Interleave along N.
|
|
# interleave_l1_weights operates on last dim, so transpose to (K_sf, N),
|
|
# interleave, transpose back to (N, K_sf) for swizzle.
|
|
l1_sf_il = []
|
|
for sf_nk in l1_sf_list:
|
|
sf_kn = sf_nk.T.contiguous().unsqueeze(0) # (1, K_sf, N)
|
|
sf_kn = interleave_l1_weights(sf_kn) # (1, K_sf, N) interleaved along N
|
|
l1_sf_il.append(sf_kn[0].T.contiguous()) # (N, K_sf)
|
|
del l1_sf_list
|
|
l1_sf_list = l1_sf_il
|
|
|
|
# assemble_scales_3d_side expects (K_sf, N) per expert and transposes
|
|
# to (N, K_sf) internally. But our scales are already (N, K_sf) from
|
|
# the checkpoint! Skip the transpose by calling the assembly directly.
|
|
from dsv4.ops.layouts import (
|
|
assemble_raw_scales_2d3d_3d_side,
|
|
)
|
|
self._l1_scale_b = assemble_raw_scales_2d3d_3d_side(l1_sf_list)
|
|
self._l2_scale_b = assemble_raw_scales_2d3d_3d_side(l2_sf_list)
|
|
del l1_sf_list, l2_sf_list
|
|
else:
|
|
# Legacy path: per-expert lists
|
|
l1_stacked = torch.stack(self.l1_fp4) # (E, K, N)
|
|
l1_stacked = interleave_l1_weights(l1_stacked) # interleave gate/up
|
|
if l1_stacked.dtype == torch.uint8:
|
|
l1_stacked = l1_stacked.view(torch.float4_e2m1fn_x2)
|
|
l2_stacked = torch.stack(self.l2_fp4)
|
|
if l2_stacked.dtype == torch.uint8:
|
|
l2_stacked = l2_stacked.view(torch.float4_e2m1fn_x2)
|
|
self._l1_mat_b = make_b_k_major(l1_stacked)
|
|
self._l2_mat_b = make_b_k_major(l2_stacked)
|
|
# Interleave L1 SF to match weight interleave
|
|
# SF from quantize_weight_to_nvfp4 is (K_sf, N). Interleave along N,
|
|
# then transpose to (N, K_sf) for swizzle via assemble_scales_3d_side.
|
|
l1_sf_il = []
|
|
for sf in self.l1_sf:
|
|
sf_ekn = sf.unsqueeze(0) # (1, K_sf, N)
|
|
sf_ekn = interleave_l1_weights(sf_ekn) # interleaved along N
|
|
l1_sf_il.append(sf_ekn[0]) # (K_sf, N)
|
|
self._l1_scale_b = assemble_scales_3d_side(l1_sf_il)
|
|
self._l2_scale_b = assemble_scales_3d_side(self.l2_sf)
|
|
del l1_stacked, l1_sf_il
|
|
self.l1_fp4 = None
|
|
self.l1_sf = None
|
|
self.l2_fp4 = None
|
|
self.l2_sf = None
|
|
|
|
self._l1_gsb = torch.tensor(self.l1_gs, dtype=torch.float32, device=self.device)
|
|
self._l2_gsb = torch.tensor(self.l2_gs, dtype=torch.float32, device=self.device)
|
|
|
|
# Fold weight_scale_2 into global_scale_b
|
|
# gsb = input_scale * weight_scale_2
|
|
if self.l1_ws2 is not None:
|
|
for i, ws2 in enumerate(self.l1_ws2):
|
|
if ws2 is not None:
|
|
self._l1_gsb[i] *= ws2.float().item()
|
|
if self.l2_ws2 is not None:
|
|
for i, ws2 in enumerate(self.l2_ws2):
|
|
if ws2 is not None:
|
|
self._l2_gsb[i] *= ws2.float().item()
|
|
|
|
self.l1_gs = None
|
|
self.l2_gs = None
|
|
self.l1_ws2 = None
|
|
self.l2_ws2 = None
|
|
|
|
# Allocate buffers and eagerly warmup JIT compilation.
|
|
# cute.compile does NOT corrupt GPU memory (verified 2026-05-20).
|
|
# We warmup eagerly here to ensure compilation happens before
|
|
# the model's first forward pass, not during it.
|
|
self._token_indices = torch.zeros(
|
|
self.max_num_tokens * self.top_k, dtype=torch.int32, device=self.device
|
|
)
|
|
self._fill_token_indices()
|
|
# No _needs_token_refill: cute.compile does NOT corrupt GPU memory.
|
|
# The original corruption was a misdiagnosis (see bridge.py cache docs).
|
|
|
|
# Eagerly JIT-compile GEMM kernels for L1 and L2 shapes.
|
|
# This triggers cute.compile once per shape, caching the compiled
|
|
# kernel + workspace. Subsequent run() calls hit the cache.
|
|
# MUST happen before model forward pass to avoid OOM from lazy JIT.
|
|
from dsv4.ops.layouts import (
|
|
ceil_div as bridge_ceil_div,
|
|
)
|
|
from dsv4.ops.gemm_runner import (
|
|
warmup_compilation,
|
|
warmup_fused_swiglu_compilation,
|
|
)
|
|
K_packed = self.hidden_size // 2
|
|
N_packed_l1 = (2 * self.intermediate_size) // 2 # gate+up combined
|
|
N_packed_l2 = self.hidden_size // 2 # down
|
|
warmup_compilation(self.num_experts, K_packed, N_packed_l1, self.device) # L1
|
|
warmup_compilation(self.num_experts, K_packed, N_packed_l2, self.device) # L2
|
|
if self._fused_swiglu:
|
|
warmup_fused_swiglu_compilation(
|
|
self.num_experts, K_packed, N_packed_l1, self.device,
|
|
swiglu_limit=self._swiglu_limit if self._swiglu_limit is not None else 0.0,
|
|
) # Fused L1
|
|
|
|
self._expert_offsets_buf = torch.zeros(
|
|
self.num_experts + 1, dtype=torch.int32, device=self.device
|
|
)
|
|
self._allocate_buffers()
|
|
|
|
def prepare_weights_direct(self, l1_fp4, l1_sf, l1_gs, l2_fp4, l2_sf, l2_gs):
|
|
"""DEPRECATED: Use prepare_weights_from_stacked() for checkpoint weights.
|
|
|
|
This path takes pre-quantized per-expert lists. The stacked path is
|
|
more memory-efficient and avoids per-expert list overhead.
|
|
"""
|
|
self.l1_fp4 = l1_fp4
|
|
self.l1_sf = l1_sf
|
|
self.l1_gs = l1_gs
|
|
self.l2_fp4 = l2_fp4
|
|
self.l2_sf = l2_sf
|
|
self.l2_gs = l2_gs
|
|
self._l1_mat_b = None
|
|
|
|
def prepare_weights_from_stacked(self, l1_fp4_stacked, l1_sf_stacked,
|
|
l1_gs, l2_fp4_stacked, l2_sf_stacked,
|
|
l2_gs):
|
|
"""Prepare weights from pre-stacked 3D tensors (checkpoint format).
|
|
|
|
Takes (E, N, K_packed) fp4 and (E, N, K_sf) scale tensors directly
|
|
from the checkpoint, avoiding the per-expert list→stack round-trip.
|
|
|
|
The conversion to K-major and swizzled layout happens in _ensure_stacked.
|
|
This just stores the tensors for deferred processing.
|
|
"""
|
|
# Store in checkpoint format (E, N, K) — _ensure_stacked will convert
|
|
self.l1_fp4_stacked = l1_fp4_stacked
|
|
self.l1_sf_stacked = l1_sf_stacked
|
|
self.l1_gs = l1_gs
|
|
self.l2_fp4_stacked = l2_fp4_stacked
|
|
self.l2_sf_stacked = l2_sf_stacked
|
|
self.l2_gs = l2_gs
|
|
self._l1_mat_b = None
|
|
|
|
def prepare_weights_from_dequantized(self, l1_weights_bf16, l2_weights_bf16):
|
|
"""DEPRECATED: Use prepare_weights_from_stacked() instead.
|
|
|
|
This path dequantizes checkpoint NVFP4 to BF16 then re-quantizes to our FP4.
|
|
While the round-trip is lossless for DeepSeek-V4 (our packing matches
|
|
the checkpoint convention exactly), it wastes memory and compute.
|
|
The direct byte path (prepare_weights_from_stacked) is preferred.
|
|
"""
|
|
self.l1_fp4, self.l1_sf, self.l1_gs = [], [], []
|
|
self.l2_fp4, self.l2_sf, self.l2_gs = [], [], []
|
|
for l1_w, l2_w in zip(l1_weights_bf16, l2_weights_bf16):
|
|
l1_w_t = l1_w.T
|
|
w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(l1_w_t)
|
|
self.l1_fp4.append(w_fp4)
|
|
self.l1_sf.append(w_sf)
|
|
self.l1_gs.append(w_gs)
|
|
l2_w_t = l2_w.T
|
|
w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(l2_w_t)
|
|
self.l2_fp4.append(w_fp4)
|
|
self.l2_sf.append(w_sf)
|
|
self.l2_gs.append(w_gs)
|
|
self._l1_mat_b = None
|
|
|
|
def _assemble_scales_cudagraph_safe(self, x_sf, expert_offsets,
|
|
padded_expert_offsets,
|
|
padded_x_sf_buf, per_expert_bufs):
|
|
"""Assemble 2D-side activation scales (cudagraph-safe, NO CPU syncs).
|
|
|
|
Phase 1: Scatter x_sf into padded per-expert sections (GPU-only).
|
|
Phase 2: Apply full-buffer Blackwell 32_4_4 swizzle (no Python loops).
|
|
|
|
The buffer is 128-row aligned per expert (from padded_expert_offsets),
|
|
so the full-buffer swizzle produces the correct layout. The GEMM reads
|
|
scale_a using padded_expert_offsets, matching the scatter layout.
|
|
"""
|
|
K_sf = x_sf.shape[1]
|
|
padded_x_sf = padded_x_sf_buf
|
|
padded_x_sf.zero_()
|
|
|
|
# Phase 1: Scatter x_sf into padded per-expert sections (GPU-only)
|
|
total_rows = x_sf.shape[0]
|
|
row_indices = self._row_indices_buf[:total_rows]
|
|
expert_assign = torch.searchsorted(
|
|
expert_offsets[1:], row_indices, right=True
|
|
).clamp(max=self.num_experts - 1)
|
|
local_row = row_indices - expert_offsets[expert_assign]
|
|
dst_rows = padded_expert_offsets[expert_assign] + local_row
|
|
padded_x_sf[dst_rows, :K_sf] = x_sf
|
|
|
|
# Phase 2: Full-buffer swizzle (no CPU sync, no Python loops)
|
|
# During graph capture, Python view ops (reshape, transpose) are not allowed.
|
|
# Use CUDA swizzle kernel instead.
|
|
rows = padded_x_sf.shape[0]
|
|
cols = padded_x_sf.shape[1]
|
|
if torch.cuda.is_current_stream_capturing():
|
|
from dsv4.kernels.cuda.loader import get_cuda_module
|
|
mod = get_cuda_module("blackwell_swizzle", ["blackwell_swizzle.cu"])
|
|
out_buf = self._padded_x_sf_swizzled_buf_l1 if padded_x_sf is self._padded_x_sf_buf_l1 else self._padded_x_sf_swizzled_buf_l2
|
|
mod.blackwell_swizzle_32_4_4(
|
|
padded_x_sf.view(torch.uint8), out_buf.view(torch.uint8),
|
|
rows, cols
|
|
)
|
|
return out_buf.view(torch.float8_e4m3fn).reshape(rows, cols)
|
|
# Eager path: Python view operations
|
|
R = rows // 128
|
|
C = cols // 4
|
|
blocks = padded_x_sf.view(R, 128, C, 4).permute(0, 2, 1, 3)
|
|
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
|
|
swizzled = rearranged.flatten().view(torch.float8_e4m3fn)
|
|
return swizzled.reshape(rows, cols)
|
|
|
|
def compute_activation_global_scales(self, hidden_states_sample, topk_weights, topk_ids):
|
|
"""Compute activation global scales from a warmup forward pass.
|
|
|
|
Called BEFORE cudagraph capture. Uses the SAME padded GEMM path as run()
|
|
to ensure kernel JIT happens with the same layout, and L2 gs is computed
|
|
from actual L1 output (not an approximation).
|
|
"""
|
|
self._ensure_stacked()
|
|
device = hidden_states_sample.device
|
|
num_tokens = hidden_states_sample.shape[0]
|
|
top_k = topk_ids.shape[1]
|
|
|
|
with torch.no_grad():
|
|
# Build slot mapping (same as run())
|
|
flat_ids = topk_ids.reshape(-1)
|
|
num_slots = num_tokens * top_k
|
|
token_indices = self._token_indices[:num_slots]
|
|
sort_idx = flat_ids.argsort(stable=True)
|
|
sorted_ids = flat_ids[sort_idx]
|
|
sorted_token_ids = token_indices[sort_idx]
|
|
slot_hidden = hidden_states_sample[sorted_token_ids]
|
|
|
|
# L1: get exact gs from quantize_to_nvfp4
|
|
_, _, l1_gs = quantize_to_nvfp4(slot_hidden)
|
|
|
|
# Quantize slot_hidden for GEMM
|
|
slot_x_fp4, slot_x_sf = quantize_activation_nvfp4(slot_hidden, l1_gs)
|
|
|
|
# Compute tokens_per_expert — CUDA-graph-safe alternative to torch.bincount.
|
|
# torch.bincount produces data-dependent shapes (violates graph capture).
|
|
# Instead, use scatter_add_ into a pre-allocated buffer (fixed shape, GPU-only).
|
|
self._tokens_per_expert_buf.zero_()
|
|
# scatter_add_ requires int64 indices — ensure sorted_ids is int64
|
|
sorted_ids_i64 = sorted_ids.long()
|
|
n_slots = sorted_ids_i64.shape[0]
|
|
if not hasattr(self, '_ones_buf') or self._ones_buf.shape[0] < n_slots:
|
|
self._ones_buf = torch.ones(self.max_num_tokens * self.top_k, dtype=self._tokens_per_expert_buf.dtype, device=sorted_ids_i64.device)
|
|
self._tokens_per_expert_buf.scatter_add_(0, sorted_ids_i64, self._ones_buf[:n_slots])
|
|
tokens_per_expert = self._tokens_per_expert_buf[:self.num_experts]
|
|
expert_offsets = self._expert_offsets_buf
|
|
expert_offsets.zero_()
|
|
expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0)
|
|
|
|
padded_tokens_per_expert = ((tokens_per_expert + 127) // 128) * 128
|
|
padded_expert_offsets = self._padded_expert_offsets_buf
|
|
padded_expert_offsets.zero_()
|
|
padded_expert_offsets[1:self.num_experts + 1] = padded_tokens_per_expert.cumsum(0)
|
|
|
|
# Compute padded_dst (same as run())
|
|
row_indices = self._row_indices_buf[:num_slots]
|
|
expert_assign = torch.searchsorted(
|
|
expert_offsets[1:], row_indices, right=True
|
|
).clamp(max=self.num_experts - 1)
|
|
local_row = row_indices - expert_offsets[expert_assign]
|
|
padded_dst = padded_expert_offsets[expert_assign] + local_row
|
|
|
|
# Scatter x_fp4 into padded layout
|
|
padded_x_fp4 = self._shared_bufs['hidden_fp4']
|
|
padded_x_fp4.view(torch.uint8).zero_()
|
|
padded_x_fp4.view(torch.uint8)[padded_dst] = slot_x_fp4.view(torch.uint8)
|
|
|
|
l1_scale_a = self._assemble_scales_cudagraph_safe(
|
|
slot_x_sf, expert_offsets[:self.num_experts + 1],
|
|
padded_expert_offsets,
|
|
self._padded_x_sf_buf_l1, self._per_expert_scale_bufs_l1
|
|
)
|
|
# l1_gsa: pre-allocated buffer, no per-call allocation
|
|
self._l1_gsa_buf.fill_(l1_gs)
|
|
l1_gsa = self._l1_gsa_buf
|
|
|
|
l1_out = run_nvfp4_grouped_gemm(
|
|
mat_a=padded_x_fp4, mat_b=self._l1_mat_b,
|
|
scale_a=l1_scale_a, scale_b=self._l1_scale_b,
|
|
expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
|
|
global_scale_a=l1_gsa, global_scale_b=self._l1_gsb,
|
|
)
|
|
|
|
# Extract real token outputs
|
|
l1_out_real = l1_out[padded_dst]
|
|
|
|
# L2: get exact gs from SiLU(gate)*up
|
|
# De-interleave L1 output: with interleaved weights, L1 GEMM
|
|
# output has [gate]*4, [up]*4 pattern. De-interleave before splitting.
|
|
l1_deil = deinterleave_l1_weights(l1_out_real.unsqueeze(0).contiguous())[0]
|
|
gate = l1_deil[:, :self.intermediate_size]
|
|
up = l1_deil[:, self.intermediate_size:]
|
|
gate_silu = torch.nn.functional.silu(gate)
|
|
if self._swiglu_limit is not None:
|
|
gate_silu = gate_silu.clamp(max=self._swiglu_limit)
|
|
up = up.clamp(min=-self._swiglu_limit, max=self._swiglu_limit)
|
|
activated = gate_silu * up
|
|
_, _, l2_gs = quantize_to_nvfp4(activated)
|
|
|
|
self._l1_activation_global_scale = l1_gs
|
|
self._l2_activation_global_scale = l2_gs
|
|
|
|
|
|
|
|
def run(self, hidden_states, topk_weights, topk_ids, expert_indices=None):
|
|
"""Forward: route tokens to experts, GEMM, combine.
|
|
|
|
Uses torch.library.custom_op (nvfp4::moe_gemm) so torch.compile
|
|
treats this as an opaque op. The custom op calls _run_impl internally.
|
|
"""
|
|
if not hasattr(self, '_runner_id'):
|
|
self._runner_id = register_runner(self)
|
|
return nvfp4_moe_gemm(
|
|
hidden_states, topk_weights, topk_ids,
|
|
self._runner_id, self.hidden_size,
|
|
)
|
|
|
|
def _run_impl(self, hidden_states, topk_weights, topk_ids, expert_indices=None):
|
|
"""Run the NVFP4 MoE forward pass.
|
|
|
|
Handles global→local expert ID remapping for expert parallelism.
|
|
Fully cudagraph-safe: no CPU-GPU syncs, no dynamic shapes.
|
|
|
|
Each expert's slots are padded to multiples of 128 for the GEMM.
|
|
expert_offsets is [0, padded_e0, padded_e0+padded_e1, ...].
|
|
scale_a is produced at those same offsets.
|
|
"""
|
|
num_tokens = hidden_states.shape[0]
|
|
top_k = topk_ids.shape[1]
|
|
device = hidden_states.device
|
|
|
|
self._ensure_stacked()
|
|
|
|
# -- Remap global expert IDs to local IDs --
|
|
local_ids = topk_ids - self.experts_start_idx
|
|
local_mask = (local_ids >= 0) & (local_ids < self.num_experts)
|
|
safe_ids = local_ids.clamp(0, self.num_experts - 1)
|
|
safe_weights = topk_weights * local_mask.float()
|
|
|
|
# -- Build slot mapping --
|
|
flat_ids = safe_ids.reshape(-1)
|
|
flat_weights = safe_weights.reshape(-1)
|
|
num_slots = num_tokens * top_k
|
|
token_indices = self._token_indices[:num_slots]
|
|
|
|
sort_idx = flat_ids.argsort(stable=True)
|
|
sorted_ids = flat_ids[sort_idx]
|
|
sorted_weights = flat_weights[sort_idx]
|
|
sorted_token_ids = token_indices[sort_idx]
|
|
|
|
# Expert offsets (real token counts)
|
|
# CUDA-graph-safe: scatter_add_ instead of bincount (fixed shape, GPU-only)
|
|
self._tokens_per_expert_buf.zero_()
|
|
sorted_ids_i64 = sorted_ids.long()
|
|
n_slots = sorted_ids_i64.shape[0]
|
|
if not hasattr(self, '_ones_buf') or self._ones_buf.shape[0] < n_slots:
|
|
self._ones_buf = torch.ones(self.max_num_tokens * self.top_k, dtype=self._tokens_per_expert_buf.dtype, device=sorted_ids_i64.device)
|
|
self._tokens_per_expert_buf.scatter_add_(0, sorted_ids_i64, self._ones_buf[:n_slots])
|
|
tokens_per_expert = self._tokens_per_expert_buf[:self.num_experts]
|
|
expert_offsets = self._expert_offsets_buf
|
|
expert_offsets.zero_()
|
|
expert_offsets[1:self.num_experts + 1] = tokens_per_expert.cumsum(0)
|
|
|
|
# Pad each expert to 128-row alignment (GPU-only computation)
|
|
padded_tokens_per_expert = ((tokens_per_expert + 127) // 128) * 128
|
|
padded_expert_offsets = self._padded_expert_offsets_buf
|
|
padded_expert_offsets.zero_()
|
|
padded_expert_offsets[1:self.num_experts + 1] = padded_tokens_per_expert.cumsum(0)
|
|
total_padded_slots = padded_expert_offsets[self.num_experts]
|
|
|
|
# -- Gather hidden states into slot order, compute padded_dst --
|
|
slot_hidden = hidden_states[sorted_token_ids]
|
|
row_indices = self._row_indices_buf[:num_slots]
|
|
expert_assign = torch.searchsorted(
|
|
expert_offsets[1:], row_indices, right=True
|
|
).clamp(max=self.num_experts - 1)
|
|
local_row = row_indices - expert_offsets[expert_assign]
|
|
padded_dst = padded_expert_offsets[expert_assign] + local_row
|
|
|
|
# === L1: gate + up ===
|
|
# Fused amax + quantize: single kernel, zero CPU-GPU syncs.
|
|
# Computes amax on GPU → derives gsa → quantizes to NVFP4.
|
|
# gsa written to GPU buffer for GEMM global_scale_a.
|
|
if getattr(self, '_use_runtime_gsa', False):
|
|
from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
|
|
slot_x_fp4, slot_x_sf, gsa_l1_gpu = quantize_nvfp4_gpu_fused(slot_hidden)
|
|
self._l1_gsa_buf[0] = gsa_l1_gpu[0] # scalar GPU→GPU, no sync, graph-capturable
|
|
else:
|
|
slot_x_fp4, slot_x_sf = quantize_nvfp4_gpu(
|
|
slot_hidden, self._l1_activation_global_scale
|
|
)
|
|
# Scatter x_fp4 into padded layout for the GEMM
|
|
# Must scatter as uint8 (float4_e2m1fn_x2 doesn't support index_put)
|
|
padded_x_fp4 = self._shared_bufs['hidden_fp4']
|
|
padded_x_fp4.view(torch.uint8).zero_()
|
|
padded_x_fp4.view(torch.uint8)[padded_dst] = slot_x_fp4.view(torch.uint8)
|
|
|
|
l1_scale_a = self._assemble_scales_cudagraph_safe(
|
|
slot_x_sf, expert_offsets[:self.num_experts + 1],
|
|
padded_expert_offsets,
|
|
self._padded_x_sf_buf_l1, self._per_expert_scale_bufs_l1
|
|
)
|
|
l1_gsa = self._l1_gsa_buf # already filled by GPU compute (no .fill_ needed)
|
|
|
|
if self._fused_swiglu:
|
|
# === Fused L1 GEMM + SwiGLU in kernel registers ===
|
|
l1_out = run_fused_swiglu_grouped_gemm(
|
|
mat_a=padded_x_fp4, mat_b=self._l1_mat_b,
|
|
scale_a=l1_scale_a, scale_b=self._l1_scale_b,
|
|
expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
|
|
global_scale_a=l1_gsa, global_scale_b=self._l1_gsb,
|
|
swiglu_limit=self._swiglu_limit if self._swiglu_limit is not None else 0.0,
|
|
out=self._l1_out_buf,
|
|
)
|
|
l1_out_real = l1_out[padded_dst]
|
|
# Fused deinterleave + amax + quantize: zero CPU syncs.
|
|
# Computes gsa from de-interleaved SwiGLU output on GPU,
|
|
# quantizes in the same kernel. Writes gsa to GPU buffer.
|
|
if getattr(self, '_use_runtime_gsa', False):
|
|
from dsv4.ops.quantize import deinterleave_amax_quantize_nvfp4_fused
|
|
slot_l2_x_fp4, slot_l2_x_sf, gsa_l2_gpu = deinterleave_amax_quantize_nvfp4_fused(
|
|
l1_out_real, self.intermediate_size)
|
|
self._l2_gsa_buf[0] = gsa_l2_gpu[0] # scalar GPU→GPU, no sync, graph-capturable
|
|
else:
|
|
slot_l2_x_fp4, slot_l2_x_sf = deinterleave_quantize_nvfp4_cuda(
|
|
l1_out_real, self.intermediate_size, self._l2_activation_global_scale
|
|
)
|
|
else:
|
|
# === Non-fused L1 GEMM + PyTorch SiLU(gate)*up ===
|
|
l1_out = run_nvfp4_grouped_gemm(
|
|
mat_a=padded_x_fp4, mat_b=self._l1_mat_b,
|
|
scale_a=l1_scale_a, scale_b=self._l1_scale_b,
|
|
expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
|
|
global_scale_a=l1_gsa, global_scale_b=self._l1_gsb,
|
|
out=self._l1_out_buf,
|
|
)
|
|
l1_out_real = l1_out[padded_dst]
|
|
l1_deil = deinterleave_l1_weights(l1_out_real.unsqueeze(0).contiguous())[0]
|
|
gate = l1_deil[:, :self.intermediate_size]
|
|
up = l1_deil[:, self.intermediate_size:]
|
|
gate_silu = torch.nn.functional.silu(gate)
|
|
if self._swiglu_limit is not None:
|
|
gate_silu = gate_silu.clamp(max=self._swiglu_limit)
|
|
up = up.clamp(min=-self._swiglu_limit, max=self._swiglu_limit)
|
|
activated = gate_silu * up
|
|
|
|
# Compute runtime gsa for L2 from activated output (non-fused path)
|
|
# Fused amax + quantize: zero CPU syncs.
|
|
if not self._fused_swiglu and getattr(self, '_use_runtime_gsa', False):
|
|
from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
|
|
slot_l2_x_fp4, slot_l2_x_sf, gsa_l2_gpu = quantize_nvfp4_gpu_fused(activated)
|
|
self._l2_gsa_buf[0] = gsa_l2_gpu[0] # scalar GPU→GPU, no sync, graph-capturable
|
|
elif not self._fused_swiglu:
|
|
slot_l2_x_fp4, slot_l2_x_sf = quantize_nvfp4_gpu(
|
|
activated, self._l2_activation_global_scale
|
|
)
|
|
padded_activated_fp4 = self._shared_bufs['activated_fp4']
|
|
padded_activated_fp4.view(torch.uint8).zero_()
|
|
padded_activated_fp4.view(torch.uint8)[padded_dst] = slot_l2_x_fp4.view(torch.uint8)
|
|
|
|
l2_scale_a = self._assemble_scales_cudagraph_safe(
|
|
slot_l2_x_sf, expert_offsets[:self.num_experts + 1],
|
|
padded_expert_offsets,
|
|
self._padded_x_sf_buf_l2, self._per_expert_scale_bufs_l2
|
|
)
|
|
l2_gsa = self._l2_gsa_buf # already filled by GPU compute (no .fill_ needed)
|
|
|
|
l2_out = run_nvfp4_grouped_gemm(
|
|
mat_a=padded_activated_fp4, mat_b=self._l2_mat_b,
|
|
scale_a=l2_scale_a, scale_b=self._l2_scale_b,
|
|
expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
|
|
global_scale_a=l2_gsa, global_scale_b=self._l2_gsb,
|
|
out=self._l2_out_buf,
|
|
)
|
|
|
|
l2_out_real = l2_out[padded_dst]
|
|
|
|
# === Scatter -> final output ===
|
|
y = self._output_buf[:num_tokens]
|
|
y.zero_()
|
|
weighted_out = l2_out_real * sorted_weights.unsqueeze(1).to(l2_out_real.dtype)
|
|
y.scatter_add_(
|
|
0,
|
|
sorted_token_ids.unsqueeze(1).expand(-1, self.hidden_size),
|
|
weighted_out,
|
|
)
|
|
|
|
return y
|