"""CuTeDSL NVFP4 Linear (single GEMM).

Generic NVFP4 GEMM runner for attention projections and any single linear
layer. Uses ScaledGroupedGemmKernel with num_groups=1.

CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs.
"""

import torch

from cutedsl.bridge import (
    quantize_activation_nvfp4,
    quantize_to_nvfp4,
    make_b_k_major,
    assemble_scales_3d_side,
    run_nvfp4_grouped_gemm,
)
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
    ceil_div as cutedsl_ceil_div,
    pad_and_swizzle_single,
)
from cutedsl.custom_ops import register_runner, nvfp4_linear_gemm

# The activation (A) operand and its scale tensor have their row count padded
# up to a multiple of _ROW_ALIGN; scale columns are padded to _SCALE_COL_ALIGN.
_ROW_ALIGN = 128
_SCALE_COL_ALIGN = 4


def _round_up(value: int, multiple: int) -> int:
    """Round ``value`` up to the next multiple of ``multiple``."""
    return cutedsl_ceil_div(value, multiple) * multiple


class CuTeDSLNvfp4Linear:
    """Single NVFP4 GEMM using CuTeDSL (num_groups=1).

    Handles any (K, N) weight matrix in NVFP4 format.
    Simple: quantize activation -> GEMM -> BF16 output.
    No SiLU, no fusion, no routing.

    Weights are supplied by assigning ``fp4``, ``sf`` and ``gs`` after
    construction; ``finalize_weights()`` converts them (it is also invoked
    lazily by the first forward).

    CUDA-graph-compatible: all buffers pre-allocated, no CPU-GPU syncs.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        max_num_tokens: int = 8192,
        device: str = "cuda",
    ):
        self.in_features = in_features
        self.out_features = out_features
        self.max_num_tokens = max_num_tokens
        self.device = device

        # Raw weights (set after construction, then call finalize_weights).
        self.fp4 = None  # list of 1 tensor
        self.sf = None  # list of 1 tensor
        self.gs = None  # list of 1 float

        # Processed weights (populated by finalize_weights).
        self._mat_b = None
        self._scale_b = None
        self._gsb = None

        # Activation global scale. 6.0 and 448.0 are the largest finite
        # magnitudes of FP4 E2M1 and FP8 E4M3. This default can be replaced
        # by calling compute_activation_global_scale().
        self._activation_global_scale = 1.0 / (6.0 * 448.0)

        # Pre-allocated buffers (populated by _allocate_buffers).
        self._padded_x_fp4_buf = None
        self._expert_offsets_buf = None
        self._gsa_buf = None
        self._buffers_allocated = False

    def finalize_weights(self) -> None:
        """Convert the raw NVFP4 weights into the GEMM's B-operand layout.

        Consumes ``self.fp4``, ``self.sf`` and ``self.gs`` and fills
        ``_mat_b``, ``_scale_b`` and ``_gsb``. The raw lists are released
        afterwards, so this succeeds only once per weight assignment.

        Raises:
            RuntimeError: if any of ``fp4``/``sf``/``gs`` is unset. The check
                runs before any processing, so a failure never leaves the
                processed weights half-populated.
        """
        missing = [name for name in ("fp4", "sf", "gs") if getattr(self, name) is None]
        if missing:
            raise RuntimeError(
                "CuTeDSLNvfp4Linear.finalize_weights(): weights not set: "
                f"{', '.join(missing)} (raw weights are released after the "
                "first finalize_weights() call)"
            )

        self._mat_b = make_b_k_major(torch.stack(self.fp4))  # (1, K_packed, N_packed)
        self._scale_b = assemble_scales_3d_side(self.sf)
        self._gsb = torch.tensor(self.gs, dtype=torch.float32, device=self.device)

        # Free raw weights.
        self.fp4 = None
        self.sf = None
        self.gs = None

    def _allocate_buffers(self) -> None:
        """Pre-allocate buffers at max size for cudagraph compatibility."""
        max_rows = _round_up(self.max_num_tokens, _ROW_ALIGN)
        # float4_e2m1fn_x2 packs two FP4 values per byte, hence in_features // 2.
        self._padded_x_fp4_buf = torch.zeros(
            max_rows, self.in_features // 2, dtype=torch.uint8, device=self.device
        ).view(torch.float4_e2m1fn_x2)
        self._expert_offsets_buf = torch.zeros(1, dtype=torch.int32, device=self.device)
        self._gsa_buf = torch.zeros(1, dtype=torch.float32, device=self.device)
        self._buffers_allocated = True

    def _ensure_initialized(self) -> None:
        """Finalize weights and allocate buffers on first use."""
        if self._mat_b is None:
            self.finalize_weights()
        if not self._buffers_allocated:
            self._allocate_buffers()

    def _assemble_scales_single_group(self, x_sf: torch.Tensor) -> torch.Tensor:
        """Assemble 2D-side activation scales for num_groups=1.

        Zero-pads ``x_sf`` to (rows rounded up to 128, cols rounded up to 4),
        swizzles it with ``pad_and_swizzle_single``, and returns the result
        reshaped to that padded 2D shape.
        """
        num_rows, num_cols = x_sf.shape
        padded_rows = _round_up(num_rows, _ROW_ALIGN)
        padded_cols = _round_up(num_cols, _SCALE_COL_ALIGN)
        # The zero buffer is created in fp16 and cast to FP8 E4M3, as in the
        # original implementation.
        buf = torch.zeros(
            padded_rows, padded_cols, dtype=torch.float16, device=x_sf.device
        ).to(torch.float8_e4m3fn)
        buf[:num_rows, :num_cols] = x_sf
        swizzled_flat = pad_and_swizzle_single(buf)
        return swizzled_flat.reshape(padded_rows, padded_cols)

    def compute_activation_global_scale(self, hidden_states_sample: torch.Tensor) -> None:
        """Compute activation global scale from a warmup forward.

        Stores the global scale that ``quantize_to_nvfp4`` derives for
        ``hidden_states_sample``. Subsequent forwards quantize activations
        with it.
        """
        self._ensure_initialized()
        with torch.no_grad():
            _, _, gs = quantize_to_nvfp4(hidden_states_sample)
            self._activation_global_scale = gs

    def run(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Forward: BF16 input -> NVFP4 GEMM -> BF16 output.

        Uses torch.library.custom_op (nvfp4::linear_gemm) so torch.compile
        treats this as an opaque op. The custom op calls _run_impl internally.
        """
        # Register lazily, once per instance. The returned id is passed to the
        # custom op (presumably so it can look this runner up).
        if not hasattr(self, "_runner_id"):
            self._runner_id = register_runner(self)
        return nvfp4_linear_gemm(hidden_states, self._runner_id, self.out_features)

    def _run_impl(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Actual implementation, invoked through the nvfp4::linear_gemm
        custom op (see ``run``) so it stays opaque to torch.compile.

        Args:
            hidden_states: activations; dim 0 is the token dimension.

        Returns:
            The first ``num_tokens`` rows of the GEMM output.

        Raises:
            ValueError: if the number of tokens exceeds the pre-allocated
                row capacity (``max_num_tokens`` rounded up to 128).
        """
        self._ensure_initialized()

        num_tokens = hidden_states.shape[0]
        capacity = self._padded_x_fp4_buf.shape[0]
        if num_tokens > capacity:
            # Without this check, the copy into the staging buffer below
            # fails with an opaque shape-mismatch error.
            raise ValueError(
                f"CuTeDSLNvfp4Linear: got {num_tokens} tokens, but buffers hold "
                f"at most {capacity} rows (max_num_tokens={self.max_num_tokens})"
            )
        padded_rows = _round_up(num_tokens, _ROW_ALIGN)

        # Quantize activation.
        x_fp4, x_sf = quantize_activation_nvfp4(
            hidden_states, self._activation_global_scale
        )

        # Copy x_fp4 into the padded staging buffer. The whole buffer is zeroed
        # first, so rows past num_tokens hold zeros rather than data left over
        # from an earlier call.
        padded_x_fp4 = self._padded_x_fp4_buf
        padded_bytes = padded_x_fp4.view(torch.uint8)
        padded_bytes.zero_()
        padded_bytes[:num_tokens] = x_fp4.view(torch.uint8)

        # Assemble A-side scales.
        scale_a = self._assemble_scales_single_group(x_sf)

        # Expert offsets: [padded_rows] for the single group.
        expert_offsets = self._expert_offsets_buf
        expert_offsets.fill_(padded_rows)

        # A-side global scale, written into the persistent device buffer.
        gsa = self._gsa_buf.fill_(self._activation_global_scale)

        out = run_nvfp4_grouped_gemm(
            mat_a=padded_x_fp4,
            mat_b=self._mat_b,
            scale_a=scale_a,
            scale_b=self._scale_b,
            expert_offsets=expert_offsets,
            global_scale_a=gsa,
            global_scale_b=self._gsb,
        )
        # Keep only the rows that correspond to real tokens.
        return out[:num_tokens]

    def __call__(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Alias for ``run``."""
        return self.run(hidden_states)