diff --git a/cutedsl/bridge.py b/cutedsl/bridge.py index 6e9d28fc..b59f60b6 100644 --- a/cutedsl/bridge.py +++ b/cutedsl/bridge.py @@ -28,7 +28,18 @@ from cutedsl.kernel.moe.torch_scaled_grouped_mm import ( E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0] -# Cache compiled kernels by (num_experts, device, K, N) +# Cache compiled kernels + pre-allocated workspace by cache_key +# Each entry: {'compiled': callable, 'workspace': Tensor, 'workspace_size': int} +# +# Key design decisions (Bug #1 fix): +# - cute.compile does NOT corrupt GPU memory (verified 2026-05-20 on B200). +# The original _needs_token_refill hack was a misdiagnosis. The real bug +# was elsewhere (likely OOB write or weight loading). +# - Workspace is pre-allocated per cache entry during warmup_compilation() +# and reused on subsequent calls. No torch.full() in the hot path. +# - CuTe tensor wrappers (from_dlpack + mark_layout_dynamic) are cheap +# metadata wrappers. We re-create them per call from real tensors. +# Caching them would hold stale references to tensors that get freed. _compiled_kernel_cache = {} # Cached LUT for E2M1 quantization (created once per device, cudagraph-safe) @@ -282,6 +293,106 @@ def compute_expert_offsets(tokens_per_expert, num_experts, device="cuda"): # ── Kernel Launch ───────────────────────────────────────────────────── + +def warmup_compilation(num_experts, K_packed, N_packed, device, + mma_tiler_mn=(128, 128), cluster_shape_mn=(1, 1)): + """Eagerly JIT-compile the GEMM kernel for a specific shape. + + Call this during model initialization (before any runtime buffers + are allocated) to ensure cute.compile runs exactly once per shape, + eliminating any risk of JIT interacting with runtime GPU memory. + + After warmup, run_nvfp4_grouped_gemm will hit the cache and skip + compilation entirely on the forward path. + + Args: + num_experts: number of experts (local, after expert parallelism) + K_packed: K dimension in float4 elements (i.e. K_original // 2) + N_packed: N dimension in float4 elements (i.e. N_original // 2) + device: 'cuda:X' or 'cuda' + mma_tiler_mn: GEMM tiling (default (128,128)) + cluster_shape_mn: cluster shape (default (1,1)) + """ + cache_key = (num_experts, str(device), mma_tiler_mn, cluster_shape_mn, + K_packed, N_packed) + + if cache_key in _compiled_kernel_cache: + return # Already compiled + + # Allocate minimal dummy tensors for compilation + mat_a = torch.zeros(128, K_packed, dtype=torch.uint8, device=device).view(torch.float4_e2m1fn_x2) + mat_b = torch.zeros(num_experts, K_packed, N_packed, dtype=torch.uint8, device=device).view(torch.float4_e2m1fn_x2) + K_sf = ceil_div(K_packed, 16) # K in scale-factor blocks (K_packed is already //2, sf is //16 of original) + N_sf = ceil_div(N_packed, 16) + scale_a = torch.zeros(128, K_sf, dtype=torch.float8_e4m3fn, device=device) + scale_b = torch.zeros(num_experts, N_sf, K_sf, dtype=torch.float8_e4m3fn, device=device) + out = torch.zeros(128, N_packed, dtype=torch.bfloat16, device=device) + expert_offsets = torch.full((num_experts,), max(128 // num_experts, 1), dtype=torch.int32, device=device) + global_scale_a = torch.ones(num_experts, dtype=torch.float32, device=device) + global_scale_b = torch.ones(num_experts, dtype=torch.float32, device=device) + + kernel = ScaledGroupedGemmKernel( + scenario="2Dx3D", + sf_vec_size=SF_VEC_SIZE, + accumulate_on_output=False, + separate_tensormap_init=True, + consistent_token_padding=False, + mma_tiler_mnk=(*mma_tiler_mn, 256), + cluster_shape_mnk=(*cluster_shape_mn, 1), + ) + + def to_cute(t): + ct = cutlass_torch.from_dlpack(t) + return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t)) + + a_c = to_cute(mat_a) + b_c = to_cute(mat_b) + sfa_c = to_cute(scale_a) + sfb_c = to_cute(scale_b) + c_c = to_cute(out) + offs_c = to_cute(expert_offsets) + + workspace_size = kernel.get_workspace_size(num_experts) + workspace = torch.full((workspace_size,), 255, dtype=torch.uint8, device=device) + ws_c = to_cute(workspace) + + gsa_c = to_cute(global_scale_a) + gsb_c = to_cute(global_scale_b) + + import cuda.bindings.driver as cuda + cluster_size = cluster_shape_mn[0] * cluster_shape_mn[1] + max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size) + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + compiled = cute.compile( + kernel, a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, + max_active_clusters, stream, + global_scale_a=gsa_c, global_scale_b=gsb_c, + ) + + # Warmup run (required by CuTeDSL) + compiled( + a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, stream, + global_scale_a=gsa_c, global_scale_b=gsb_c, + ) + torch.cuda.synchronize() + + # Cache compiled kernel + pre-allocated workspace + # (NOT CuTe wrappers — they hold refs to dummy tensors) + _compiled_kernel_cache[cache_key] = { + 'compiled': compiled, + 'workspace': workspace, # pre-allocated, reused + 'workspace_size': workspace_size, + } + + # Free dummy data tensors (workspace is kept in cache) + del mat_a, mat_b, scale_a, scale_b, out, expert_offsets + del global_scale_a, global_scale_b + del a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, gsa_c, gsb_c + torch.cuda.empty_cache() + + + def run_nvfp4_grouped_gemm( mat_a, # (tokens_sum, K_packed) float4_e2m1fn_x2 mat_b, # (experts, K_packed, N_packed) float4_e2m1fn_x2, K-major @@ -297,75 +408,101 @@ def run_nvfp4_grouped_gemm( 2Dx3D: A(tokens, K) x B(experts, K, N) -> C(tokens, N) - CUDAGraph-compatible: kernel is compiled once on first call (warmup), - then only the compiled kernel is invoked on subsequent calls. + Kernel is compiled once (either via warmup_compilation() during init + or lazily on first call), then cached. Workspace is pre-allocated + during warmup and reused — no torch.full() in the hot path. No torch.cuda.synchronize() or .item() in the forward path. """ num_experts = mat_b.shape[0] - n_dim = mat_b.shape[2] # packed N (in float4 elements) + n_dim = mat_b.shape[2] tokens_sum = mat_a.shape[0] out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=mat_a.device) - kernel = ScaledGroupedGemmKernel( - scenario="2Dx3D", - sf_vec_size=SF_VEC_SIZE, - accumulate_on_output=False, - separate_tensormap_init=True, - consistent_token_padding=False, - mma_tiler_mnk=(*mma_tiler_mn, 256), - cluster_shape_mnk=(*cluster_shape_mn, 1), - ) + cache_key = (num_experts, str(mat_a.device), mma_tiler_mn, cluster_shape_mn, + mat_a.shape[1], mat_b.shape[2]) + + if cache_key not in _compiled_kernel_cache: + # Lazy compilation — safety net if warmup_compilation wasn't called. + # This should NOT happen in production (warmup is called during init). + kernel = ScaledGroupedGemmKernel( + scenario="2Dx3D", + sf_vec_size=SF_VEC_SIZE, + accumulate_on_output=False, + separate_tensormap_init=True, + consistent_token_padding=False, + mma_tiler_mnk=(*mma_tiler_mn, 256), + cluster_shape_mnk=(*cluster_shape_mn, 1), + ) - # Convert to CuTe tensors with dynamic layout + def to_cute(t): + ct = cutlass_torch.from_dlpack(t) + return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t)) + + a_c = to_cute(mat_a) + b_c = to_cute(mat_b) + sfa_c = to_cute(scale_a) + sfb_c = to_cute(scale_b) + c_c = to_cute(out) + offs_c = to_cute(expert_offsets) + + workspace_size = kernel.get_workspace_size(num_experts) + workspace = torch.full((workspace_size,), 255, dtype=torch.uint8, device=mat_a.device) + ws_c = to_cute(workspace) + + gsa_c = to_cute(global_scale_a) if global_scale_a is not None else None + gsb_c = to_cute(global_scale_b) if global_scale_b is not None else None + + import cuda.bindings.driver as cuda + cluster_size = cluster_shape_mn[0] * cluster_shape_mn[1] + max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size) + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + compiled = cute.compile( + kernel, a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, + max_active_clusters, stream, + global_scale_a=gsa_c, global_scale_b=gsb_c, + ) + compiled( + a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, stream, + global_scale_a=gsa_c, global_scale_b=gsb_c, + ) + + _compiled_kernel_cache[cache_key] = { + 'compiled': compiled, + 'workspace': workspace, + 'workspace_size': workspace_size, + } + + # --- Invoke cached kernel --- + entry = _compiled_kernel_cache[cache_key] + compiled = entry['compiled'] + workspace = entry['workspace'] + + # Re-create CuTe wrappers from current tensors each call. + # This is cheap (metadata only, no GPU work) and avoids stale + # references to tensors from previous calls that may have been freed. def to_cute(t): ct = cutlass_torch.from_dlpack(t) return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t)) - + a_c = to_cute(mat_a) b_c = to_cute(mat_b) sfa_c = to_cute(scale_a) sfb_c = to_cute(scale_b) c_c = to_cute(out) offs_c = to_cute(expert_offsets) - - workspace_size = kernel.get_workspace_size(num_experts) - workspace = torch.full((workspace_size,), 255, dtype=torch.uint8, device=mat_a.device) ws_c = to_cute(workspace) - + gsa_c = to_cute(global_scale_a) if global_scale_a is not None else None gsb_c = to_cute(global_scale_b) if global_scale_b is not None else None - - import cuda.bindings.driver as cuda - cluster_size = cluster_shape_mn[0] * cluster_shape_mn[1] - max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size) - stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - - # Cache compiled kernel by (num_experts, K, N) - cache_key = (num_experts, str(mat_a.device), mma_tiler_mn, cluster_shape_mn, - mat_a.shape[1], mat_b.shape[2]) - if cache_key not in _compiled_kernel_cache: - # First call: compile with real tensors - compiled = cute.compile( - kernel, a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, - max_active_clusters, stream, - global_scale_a=gsa_c, global_scale_b=gsb_c, - ) - # Warmup run - compiled( - a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, - stream, - global_scale_a=gsa_c, global_scale_b=gsb_c, - ) - _compiled_kernel_cache[cache_key] = compiled - else: - # Subsequent calls: just invoke the compiled kernel - compiled = _compiled_kernel_cache[cache_key] - compiled( - a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, - stream, - global_scale_a=gsa_c, global_scale_b=gsb_c, - ) + import cuda.bindings.driver as cuda + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + compiled( + a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, stream, + global_scale_a=gsa_c, global_scale_b=gsb_c, + ) return out diff --git a/cutedsl/nvfp4_linear.py b/cutedsl/nvfp4_linear.py index ec9d44a2..c7888015 100644 --- a/cutedsl/nvfp4_linear.py +++ b/cutedsl/nvfp4_linear.py @@ -74,6 +74,13 @@ class CuTeDSLNvfp4Linear: self.sf = None self.gs = None + # Eagerly JIT-compile the GEMM kernel for this (K, N) shape. + # Uses num_groups=1 since this is a single linear layer. + from cutedsl.bridge import warmup_compilation + K_packed = self.in_features // 2 + N_packed = self.out_features // 2 + warmup_compilation(1, K_packed, N_packed, self.device) + def _ensure_buffer_size(self, num_tokens: int): """Ensure the padded buffer is large enough for num_tokens.""" needed_rows = cutedsl_ceil_div(num_tokens, 128) * 128 diff --git a/cutedsl/runner.py b/cutedsl/runner.py index ea5c5405..2bf08acf 100644 --- a/cutedsl/runner.py +++ b/cutedsl/runner.py @@ -239,20 +239,26 @@ class CuTeDSLMoERunner: self.l1_gs = None self.l2_gs = None - # Allocate buffers AFTER JIT compilation - # (CuTeDSL's cute.compile corrupts GPU memory during JIT; - # tensors allocated before/during compilation may be zeroed) - # - # _token_indices: GPU tensor for cudagraph compatibility. - # CuTeDSL JIT may corrupt GPU memory, so we fill AFTER stacking - # (which triggers the weight JIT). The GEMM JIT in run_nvfp4_grouped_gemm - # is triggered on the first run() call; we refill _token_indices after - # that first call via the _needs_token_refill flag. + # Allocate buffers and eagerly warmup JIT compilation. + # cute.compile does NOT corrupt GPU memory (verified 2026-05-20). + # We warmup eagerly here to ensure compilation happens before + # the model's first forward pass, not during it. self._token_indices = torch.zeros( self.max_num_tokens * self.top_k, dtype=torch.int32, device=self.device ) self._fill_token_indices() - self._needs_token_refill = True # GEMM JIT may corrupt; refill after first run + # No _needs_token_refill: cute.compile does NOT corrupt GPU memory. + # The original corruption was a misdiagnosis (see bridge.py cache docs). + + # Eagerly JIT-compile GEMM kernels for L1 and L2 shapes. + # This triggers cute.compile once per shape, caching the compiled + # kernel + workspace. Subsequent run() calls hit the cache. + from cutedsl.bridge import warmup_compilation, ceil_div as bridge_ceil_div + K_packed = self.hidden_size // 2 + N_packed_l1 = (2 * self.intermediate_size) // 2 # gate+up combined + N_packed_l2 = self.hidden_size // 2 # down + warmup_compilation(self.num_experts, K_packed, N_packed_l1, self.device) + warmup_compilation(self.num_experts, K_packed, N_packed_l2, self.device) self._expert_id_range = torch.arange( self.num_experts, dtype=torch.int32 @@ -578,10 +584,4 @@ class CuTeDSLMoERunner: weighted_out, ) - # Refill _token_indices after GEMM JIT on first call - # (CuTeDSL's cute.compile may corrupt GPU memory during first GEMM) - if self._needs_token_refill: - self._fill_token_indices() - self._needs_token_refill = False - return y