"""CuTeDSL Shared Expert Pipeline NVFP4 inference for DeepSeek V4 shared experts. Uses ScaledGroupedGemmKernel with num_groups=1. Pipeline: 1. Quantize activation: BF16 → NVFP4 (using warmup gs) 2. L1 GEMM: NVFP4_act × NVFP4_weight(gate_up) → BF16 3. SiLU(gate) * up → BF16 4. Re-quantize: BF16 → NVFP4 (using warmup gs) 5. L2 GEMM: NVFP4_act × NVFP4_weight(down) → BF16 Unlike MoE, there's no routing, no scatter, no expert offsets. All tokens go through the same expert (the shared expert). Scale assembly is just: quantize activation → pad to 128-row alignment → Blackwell swizzle. CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs, no dynamic shapes. Padding rows are zeros that contribute nothing to GEMM output. """ import torch from dsv4.ops.quantize import ( quantize_activation_nvfp4, quantize_to_nvfp4, ) from dsv4.ops.layouts import ( make_b_k_major, assemble_scales_3d_side, ) from dsv4.ops.gemm_runner import ( run_nvfp4_grouped_gemm, ) from dsv4.kernels.gemm.grouped import ( ceil_div as cutedsl_ceil_div, pad_and_swizzle_single, ) class _SharedExpertApply(torch.autograd.Function): """Custom autograd function to make CuTeDSL runner opaque to torch.compile.""" @staticmethod def forward(ctx, runner, hidden_states): return runner._run_impl(hidden_states) class Nvfp4SharedExpert: """NVFP4 shared expert runner using CuTeDSL GEMM (num_groups=1). CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs. """ def __init__( self, hidden_size: int, intermediate_size: int, max_num_tokens: int = 8192, device: str = "cuda", swiglu_limit: float = 10.0, ): self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.max_num_tokens = max_num_tokens self.device = device self.swiglu_limit = swiglu_limit # Weights (set after construction, then call finalize_weights) self.l1_fp4 = None self.l1_sf = None self.l1_gs = None self.l2_fp4 = None self.l2_sf = None self.l2_gs = None # Processed weights (set by finalize_weights) self._l1_mat_b = None self._l2_mat_b = None self._l1_scale_b = None self._l2_scale_b = None self._l1_gsb = None self._l2_gsb = None # Activation global scales (set by compute_activation_global_scales) self._l1_activation_global_scale = 1.0 / (6.0 * 448.0) self._l2_activation_global_scale = 1.0 / (6.0 * 448.0) # Pre-allocated cudagraph buffers (set in _allocate_buffers) self._padded_x_fp4_buf_l1 = None self._padded_x_sf_buf_l1 = None self._padded_x_fp4_buf_l2 = None self._padded_x_sf_buf_l2 = None self._l1_gsa_buf = None self._l2_gsa_buf = None self._expert_offsets_buf = None self._buffers_allocated = False def set_swiglu_limit(self, limit: float): self.swiglu_limit = limit def finalize_weights(self): """Process weights for CuTeDSL GEMM. Must be called after setting l1/l2 weights.""" # Stack weights and convert to K-major # l1_fp4/l2_fp4 are lists with 1 element (the shared expert) self._l1_mat_b = make_b_k_major(torch.stack(self.l1_fp4)) # (1, K_packed, N_packed) self._l2_mat_b = make_b_k_major(torch.stack(self.l2_fp4)) self._l1_scale_b = assemble_scales_3d_side(self.l1_sf) # (1, N, K_sf_padded) self._l2_scale_b = assemble_scales_3d_side(self.l2_sf) self._l1_gsb = torch.tensor(self.l1_gs, dtype=torch.float32, device=self.device) self._l2_gsb = torch.tensor(self.l2_gs, dtype=torch.float32, device=self.device) # Free raw weights self.l1_fp4 = None self.l1_sf = None self.l1_gs = None self.l2_fp4 = None self.l2_sf = None self.l2_gs = None def _allocate_buffers(self): """Pre-allocate all buffers at max size for cudagraph compatibility.""" max_rows = cutedsl_ceil_div(self.max_num_tokens, 128) * 128 # pad to 128 # L1: hidden_size packed, L2: intermediate_size packed self._padded_x_fp4_buf_l1 = torch.zeros( max_rows, self.hidden_size // 2, dtype=torch.uint8, device=self.device ).view(torch.float4_e2m1fn_x2) self._padded_x_fp4_buf_l2 = torch.zeros( max_rows, self.intermediate_size // 2, dtype=torch.uint8, device=self.device ).view(torch.float4_e2m1fn_x2) # Padded scale buffers (need same padded dimensions as pad_and_swizzle_single produces) K_sf_l1 = cutedsl_ceil_div(self.hidden_size, 16) padded_cols_l1 = cutedsl_ceil_div(K_sf_l1, 4) * 4 K_sf_l2 = cutedsl_ceil_div(self.intermediate_size, 16) padded_cols_l2 = cutedsl_ceil_div(K_sf_l2, 4) * 4 self._padded_x_sf_buf_l1 = torch.zeros( max_rows, padded_cols_l1, dtype=torch.float16, device=self.device ).to(torch.float8_e4m3fn) self._padded_x_sf_buf_l2 = torch.zeros( max_rows, padded_cols_l2, dtype=torch.float16, device=self.device ).to(torch.float8_e4m3fn) # Global scale buffers self._l1_gsa_buf = torch.zeros(1, dtype=torch.float32, device=self.device) self._l2_gsa_buf = torch.zeros(1, dtype=torch.float32, device=self.device) # Expert offsets for num_groups=1: just [num_tokens_padded] # The GEMM expects expert_offsets as (num_experts,) cumulative offsets # For 1 expert: offsets = [num_tokens] (just one element) self._expert_offsets_buf = torch.zeros(1, dtype=torch.int32, device=self.device) self._buffers_allocated = True def _ensure_initialized(self): """Lazily initialize stacked weights and buffers.""" if self._l1_mat_b is None: self.finalize_weights() if not self._buffers_allocated: self._allocate_buffers() def _assemble_scales_single_group(self, x_sf, num_tokens, padded_x_sf_buf): """Assemble 2D-side activation scales for num_groups=1. For a single group, scale assembly is just: 1. Copy x_sf into a correctly-sized buffer (padded to 128 rows, 4 cols) 2. Apply pad_and_swizzle_single (Blackwell swizzle) 3. Reshape back to 2D (kernel expects 2D scale_a) The padded buffer must be sized exactly for 128-aligned num_tokens, NOT the max_num_tokens buffer (which would be way too large). """ num_rows, num_cols = x_sf.shape padded_rows = cutedsl_ceil_div(num_rows, 128) * 128 padded_cols = cutedsl_ceil_div(num_cols, 4) * 4 # Use a temp buffer sized for this exact token count buf = torch.zeros(padded_rows, padded_cols, dtype=torch.float16, device=x_sf.device).to(torch.float8_e4m3fn) buf[:num_rows, :num_cols] = x_sf swizzled_flat = pad_and_swizzle_single(buf) return swizzled_flat.reshape(padded_rows, padded_cols) def compute_activation_global_scales(self, hidden_states_sample): """Compute activation global scales from a warmup forward pass. Called BEFORE cudagraph capture. Uses quantize_to_nvfp4 to get the exact global_scale from the data, then runs L1 to compute L2 gs from actual SiLU(gate)*up output. """ self._ensure_initialized() with torch.no_grad(): # L1: exact gs from quantize_to_nvfp4 _, _, l1_gs = quantize_to_nvfp4(hidden_states_sample) self._l1_activation_global_scale = l1_gs # Run L1 GEMM to get intermediate for L2 gs num_tokens = hidden_states_sample.shape[0] l1_out = self._run_l1(hidden_states_sample) if l1_out is not None and not torch.isnan(l1_out).any(): gate = l1_out[:, :self.intermediate_size] up = l1_out[:, self.intermediate_size:] if self.swiglu_limit is not None: gate = gate.clamp(max=self.swiglu_limit) up = up.clamp(min=-self.swiglu_limit, max=self.swiglu_limit) activated = torch.nn.functional.silu(gate) * up _, _, l2_gs = quantize_to_nvfp4(activated) self._l2_activation_global_scale = l2_gs def _run_l1(self, hidden_states: torch.Tensor) -> torch.Tensor: """L1 GEMM: activation × gate_up_weight → BF16.""" num_tokens = hidden_states.shape[0] padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128 # Quantize activation x_fp4, x_sf = quantize_activation_nvfp4( hidden_states, self._l1_activation_global_scale ) # Scatter x_fp4 into padded buffer padded_x_fp4 = self._padded_x_fp4_buf_l1 padded_x_fp4.view(torch.uint8).zero_() padded_x_fp4.view(torch.uint8)[:num_tokens] = x_fp4.view(torch.uint8) # Assemble A-side scales scale_a = self._assemble_scales_single_group(x_sf, num_tokens, self._padded_x_sf_buf_l1) # Expert offsets: [padded_rows] for 1 group expert_offsets = self._expert_offsets_buf expert_offsets.fill_(padded_rows) # Global scales gsa = self._l1_gsa_buf.fill_(self._l1_activation_global_scale) # Run GEMM out = run_nvfp4_grouped_gemm( mat_a=padded_x_fp4, mat_b=self._l1_mat_b, scale_a=scale_a, scale_b=self._l1_scale_b, expert_offsets=expert_offsets, global_scale_a=gsa, global_scale_b=self._l1_gsb, ) # Extract real token outputs return out[:num_tokens] def _run_l2(self, intermediate: torch.Tensor) -> torch.Tensor: """L2 GEMM: intermediate × down_weight → BF16.""" num_tokens = intermediate.shape[0] padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128 # Quantize activation x_fp4, x_sf = quantize_activation_nvfp4( intermediate, self._l2_activation_global_scale ) # Scatter into padded buffer padded_x_fp4 = self._padded_x_fp4_buf_l2 padded_x_fp4.view(torch.uint8).zero_() padded_x_fp4.view(torch.uint8)[:num_tokens] = x_fp4.view(torch.uint8) # Assemble A-side scales scale_a = self._assemble_scales_single_group(x_sf, num_tokens, self._padded_x_sf_buf_l2) # Expert offsets expert_offsets = self._expert_offsets_buf expert_offsets.fill_(padded_rows) # Global scales gsa = self._l2_gsa_buf.fill_(self._l2_activation_global_scale) # Run GEMM out = run_nvfp4_grouped_gemm( mat_a=padded_x_fp4, mat_b=self._l2_mat_b, scale_a=scale_a, scale_b=self._l2_scale_b, expert_offsets=expert_offsets, global_scale_a=gsa, global_scale_b=self._l2_gsb, ) return out[:num_tokens] def run(self, hidden_states: torch.Tensor) -> torch.Tensor: """Full shared expert forward: L1 → SiLU → L2 → output.""" return _SharedExpertApply.apply(self, hidden_states) def _run_impl(self, hidden_states: torch.Tensor) -> torch.Tensor: """Actual implementation — called via custom autograd to be torch.compile-safe.""" self._ensure_initialized() l1_out = self._run_l1(hidden_states) gate = l1_out[:, :self.intermediate_size] up = l1_out[:, self.intermediate_size:] if self.swiglu_limit is not None: # Match SiluAndMulWithClamp: clamp gate BEFORE silu, clamp up to [-limit, limit] gate = gate.clamp(max=self.swiglu_limit) up = up.clamp(min=-self.swiglu_limit, max=self.swiglu_limit) intermediate = torch.nn.functional.silu(gate) * up output = self._run_l2(intermediate) return output