fix: root-cause JIT memory corruption myth, add eager warmup, remove _needs_token_refill

Bug #1 fix: The _needs_token_refill workaround was a band-aid over a
misdiagnosis. cute.compile does NOT corrupt GPU memory (verified on B200).
The original corruption was from a different bug (likely OOB write or
weight loading issue).

Changes:
- bridge.py: Add warmup_compilation() for eager JIT before runtime buffers
  exist. Pre-allocate workspace per cache entry (no torch.full in hot path).
  Cache stores {compiled, workspace, workspace_size} instead of just compiled.
  CuTe tensor wrappers re-created per call (cheap metadata, avoids stale refs).
- runner.py: Remove _needs_token_refill hack. Add eager warmup call in
  _ensure_stacked() for both L1 and L2 GEMM shapes.
- nvfp4_linear.py: Add eager warmup in finalize_weights() for single GEMM.

The warmup approach ensures cute.compile runs exactly once per shape during
model init, before any forward pass. This is deterministic and eliminates
any possible interaction between JIT and runtime GPU memory.
This commit is contained in:
2026-05-20 02:08:01 +00:00
parent 039a9e27d6
commit cc6b094450
3 changed files with 210 additions and 66 deletions

View File

@@ -28,7 +28,18 @@ from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
# Cache compiled kernels by (num_experts, device, K, N)
# Cache compiled kernels + pre-allocated workspace by cache_key
# Each entry: {'compiled': callable, 'workspace': Tensor, 'workspace_size': int}
#
# Key design decisions (Bug #1 fix):
# - cute.compile does NOT corrupt GPU memory (verified 2026-05-20 on B200).
# The original _needs_token_refill hack was a misdiagnosis. The real bug
# was elsewhere (likely OOB write or weight loading).
# - Workspace is pre-allocated per cache entry during warmup_compilation()
# and reused on subsequent calls. No torch.full() in the hot path.
# - CuTe tensor wrappers (from_dlpack + mark_layout_dynamic) are cheap
# metadata wrappers. We re-create them per call from real tensors.
# Caching them would hold stale references to tensors that get freed.
_compiled_kernel_cache = {}
# Cached LUT for E2M1 quantization (created once per device, cudagraph-safe)
@@ -282,6 +293,106 @@ def compute_expert_offsets(tokens_per_expert, num_experts, device="cuda"):
# ── Kernel Launch ─────────────────────────────────────────────────────
def warmup_compilation(num_experts, K_packed, N_packed, device,
mma_tiler_mn=(128, 128), cluster_shape_mn=(1, 1)):
"""Eagerly JIT-compile the GEMM kernel for a specific shape.
Call this during model initialization (before any runtime buffers
are allocated) to ensure cute.compile runs exactly once per shape,
eliminating any risk of JIT interacting with runtime GPU memory.
After warmup, run_nvfp4_grouped_gemm will hit the cache and skip
compilation entirely on the forward path.
Args:
num_experts: number of experts (local, after expert parallelism)
K_packed: K dimension in float4 elements (i.e. K_original // 2)
N_packed: N dimension in float4 elements (i.e. N_original // 2)
device: 'cuda:X' or 'cuda'
mma_tiler_mn: GEMM tiling (default (128,128))
cluster_shape_mn: cluster shape (default (1,1))
"""
cache_key = (num_experts, str(device), mma_tiler_mn, cluster_shape_mn,
K_packed, N_packed)
if cache_key in _compiled_kernel_cache:
return # Already compiled
# Allocate minimal dummy tensors for compilation
mat_a = torch.zeros(128, K_packed, dtype=torch.uint8, device=device).view(torch.float4_e2m1fn_x2)
mat_b = torch.zeros(num_experts, K_packed, N_packed, dtype=torch.uint8, device=device).view(torch.float4_e2m1fn_x2)
K_sf = ceil_div(K_packed, 16) # K in scale-factor blocks (K_packed is already //2, sf is //16 of original)
N_sf = ceil_div(N_packed, 16)
scale_a = torch.zeros(128, K_sf, dtype=torch.float8_e4m3fn, device=device)
scale_b = torch.zeros(num_experts, N_sf, K_sf, dtype=torch.float8_e4m3fn, device=device)
out = torch.zeros(128, N_packed, dtype=torch.bfloat16, device=device)
expert_offsets = torch.full((num_experts,), max(128 // num_experts, 1), dtype=torch.int32, device=device)
global_scale_a = torch.ones(num_experts, dtype=torch.float32, device=device)
global_scale_b = torch.ones(num_experts, dtype=torch.float32, device=device)
kernel = ScaledGroupedGemmKernel(
scenario="2Dx3D",
sf_vec_size=SF_VEC_SIZE,
accumulate_on_output=False,
separate_tensormap_init=True,
consistent_token_padding=False,
mma_tiler_mnk=(*mma_tiler_mn, 256),
cluster_shape_mnk=(*cluster_shape_mn, 1),
)
def to_cute(t):
ct = cutlass_torch.from_dlpack(t)
return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
a_c = to_cute(mat_a)
b_c = to_cute(mat_b)
sfa_c = to_cute(scale_a)
sfb_c = to_cute(scale_b)
c_c = to_cute(out)
offs_c = to_cute(expert_offsets)
workspace_size = kernel.get_workspace_size(num_experts)
workspace = torch.full((workspace_size,), 255, dtype=torch.uint8, device=device)
ws_c = to_cute(workspace)
gsa_c = to_cute(global_scale_a)
gsb_c = to_cute(global_scale_b)
import cuda.bindings.driver as cuda
cluster_size = cluster_shape_mn[0] * cluster_shape_mn[1]
max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size)
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
compiled = cute.compile(
kernel, a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c,
max_active_clusters, stream,
global_scale_a=gsa_c, global_scale_b=gsb_c,
)
# Warmup run (required by CuTeDSL)
compiled(
a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, stream,
global_scale_a=gsa_c, global_scale_b=gsb_c,
)
torch.cuda.synchronize()
# Cache compiled kernel + pre-allocated workspace
# (NOT CuTe wrappers — they hold refs to dummy tensors)
_compiled_kernel_cache[cache_key] = {
'compiled': compiled,
'workspace': workspace, # pre-allocated, reused
'workspace_size': workspace_size,
}
# Free dummy data tensors (workspace is kept in cache)
del mat_a, mat_b, scale_a, scale_b, out, expert_offsets
del global_scale_a, global_scale_b
del a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, gsa_c, gsb_c
torch.cuda.empty_cache()
def run_nvfp4_grouped_gemm(
mat_a, # (tokens_sum, K_packed) float4_e2m1fn_x2
mat_b, # (experts, K_packed, N_packed) float4_e2m1fn_x2, K-major
@@ -297,75 +408,101 @@ def run_nvfp4_grouped_gemm(
2Dx3D: A(tokens, K) x B(experts, K, N) -> C(tokens, N)
CUDAGraph-compatible: kernel is compiled once on first call (warmup),
then only the compiled kernel is invoked on subsequent calls.
Kernel is compiled once (either via warmup_compilation() during init
or lazily on first call), then cached. Workspace is pre-allocated
during warmup and reused — no torch.full() in the hot path.
No torch.cuda.synchronize() or .item() in the forward path.
"""
num_experts = mat_b.shape[0]
n_dim = mat_b.shape[2] # packed N (in float4 elements)
n_dim = mat_b.shape[2]
tokens_sum = mat_a.shape[0]
out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=mat_a.device)
kernel = ScaledGroupedGemmKernel(
scenario="2Dx3D",
sf_vec_size=SF_VEC_SIZE,
accumulate_on_output=False,
separate_tensormap_init=True,
consistent_token_padding=False,
mma_tiler_mnk=(*mma_tiler_mn, 256),
cluster_shape_mnk=(*cluster_shape_mn, 1),
)
cache_key = (num_experts, str(mat_a.device), mma_tiler_mn, cluster_shape_mn,
mat_a.shape[1], mat_b.shape[2])
if cache_key not in _compiled_kernel_cache:
# Lazy compilation — safety net if warmup_compilation wasn't called.
# This should NOT happen in production (warmup is called during init).
kernel = ScaledGroupedGemmKernel(
scenario="2Dx3D",
sf_vec_size=SF_VEC_SIZE,
accumulate_on_output=False,
separate_tensormap_init=True,
consistent_token_padding=False,
mma_tiler_mnk=(*mma_tiler_mn, 256),
cluster_shape_mnk=(*cluster_shape_mn, 1),
)
# Convert to CuTe tensors with dynamic layout
def to_cute(t):
ct = cutlass_torch.from_dlpack(t)
return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
a_c = to_cute(mat_a)
b_c = to_cute(mat_b)
sfa_c = to_cute(scale_a)
sfb_c = to_cute(scale_b)
c_c = to_cute(out)
offs_c = to_cute(expert_offsets)
workspace_size = kernel.get_workspace_size(num_experts)
workspace = torch.full((workspace_size,), 255, dtype=torch.uint8, device=mat_a.device)
ws_c = to_cute(workspace)
gsa_c = to_cute(global_scale_a) if global_scale_a is not None else None
gsb_c = to_cute(global_scale_b) if global_scale_b is not None else None
import cuda.bindings.driver as cuda
cluster_size = cluster_shape_mn[0] * cluster_shape_mn[1]
max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size)
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
compiled = cute.compile(
kernel, a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c,
max_active_clusters, stream,
global_scale_a=gsa_c, global_scale_b=gsb_c,
)
compiled(
a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, stream,
global_scale_a=gsa_c, global_scale_b=gsb_c,
)
_compiled_kernel_cache[cache_key] = {
'compiled': compiled,
'workspace': workspace,
'workspace_size': workspace_size,
}
# --- Invoke cached kernel ---
entry = _compiled_kernel_cache[cache_key]
compiled = entry['compiled']
workspace = entry['workspace']
# Re-create CuTe wrappers from current tensors each call.
# This is cheap (metadata only, no GPU work) and avoids stale
# references to tensors from previous calls that may have been freed.
def to_cute(t):
ct = cutlass_torch.from_dlpack(t)
return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
a_c = to_cute(mat_a)
b_c = to_cute(mat_b)
sfa_c = to_cute(scale_a)
sfb_c = to_cute(scale_b)
c_c = to_cute(out)
offs_c = to_cute(expert_offsets)
workspace_size = kernel.get_workspace_size(num_experts)
workspace = torch.full((workspace_size,), 255, dtype=torch.uint8, device=mat_a.device)
ws_c = to_cute(workspace)
gsa_c = to_cute(global_scale_a) if global_scale_a is not None else None
gsb_c = to_cute(global_scale_b) if global_scale_b is not None else None
import cuda.bindings.driver as cuda
cluster_size = cluster_shape_mn[0] * cluster_shape_mn[1]
max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size)
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
# Cache compiled kernel by (num_experts, K, N)
cache_key = (num_experts, str(mat_a.device), mma_tiler_mn, cluster_shape_mn,
mat_a.shape[1], mat_b.shape[2])
if cache_key not in _compiled_kernel_cache:
# First call: compile with real tensors
compiled = cute.compile(
kernel, a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c,
max_active_clusters, stream,
global_scale_a=gsa_c, global_scale_b=gsb_c,
)
# Warmup run
compiled(
a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c,
stream,
global_scale_a=gsa_c, global_scale_b=gsb_c,
)
_compiled_kernel_cache[cache_key] = compiled
else:
# Subsequent calls: just invoke the compiled kernel
compiled = _compiled_kernel_cache[cache_key]
compiled(
a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c,
stream,
global_scale_a=gsa_c, global_scale_b=gsb_c,
)
import cuda.bindings.driver as cuda
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
compiled(
a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, stream,
global_scale_a=gsa_c, global_scale_b=gsb_c,
)
return out

View File

@@ -74,6 +74,13 @@ class CuTeDSLNvfp4Linear:
self.sf = None
self.gs = None
# Eagerly JIT-compile the GEMM kernel for this (K, N) shape.
# Uses num_groups=1 since this is a single linear layer.
from cutedsl.bridge import warmup_compilation
K_packed = self.in_features // 2
N_packed = self.out_features // 2
warmup_compilation(1, K_packed, N_packed, self.device)
def _ensure_buffer_size(self, num_tokens: int):
"""Ensure the padded buffer is large enough for num_tokens."""
needed_rows = cutedsl_ceil_div(num_tokens, 128) * 128

View File

@@ -239,20 +239,26 @@ class CuTeDSLMoERunner:
self.l1_gs = None
self.l2_gs = None
# Allocate buffers AFTER JIT compilation
# (CuTeDSL's cute.compile corrupts GPU memory during JIT;
# tensors allocated before/during compilation may be zeroed)
#
# _token_indices: GPU tensor for cudagraph compatibility.
# CuTeDSL JIT may corrupt GPU memory, so we fill AFTER stacking
# (which triggers the weight JIT). The GEMM JIT in run_nvfp4_grouped_gemm
# is triggered on the first run() call; we refill _token_indices after
# that first call via the _needs_token_refill flag.
# Allocate buffers and eagerly warmup JIT compilation.
# cute.compile does NOT corrupt GPU memory (verified 2026-05-20).
# We warmup eagerly here to ensure compilation happens before
# the model's first forward pass, not during it.
self._token_indices = torch.zeros(
self.max_num_tokens * self.top_k, dtype=torch.int32, device=self.device
)
self._fill_token_indices()
self._needs_token_refill = True # GEMM JIT may corrupt; refill after first run
# No _needs_token_refill: cute.compile does NOT corrupt GPU memory.
# The original corruption was a misdiagnosis (see bridge.py cache docs).
# Eagerly JIT-compile GEMM kernels for L1 and L2 shapes.
# This triggers cute.compile once per shape, caching the compiled
# kernel + workspace. Subsequent run() calls hit the cache.
from cutedsl.bridge import warmup_compilation, ceil_div as bridge_ceil_div
K_packed = self.hidden_size // 2
N_packed_l1 = (2 * self.intermediate_size) // 2 # gate+up combined
N_packed_l2 = self.hidden_size // 2 # down
warmup_compilation(self.num_experts, K_packed, N_packed_l1, self.device)
warmup_compilation(self.num_experts, K_packed, N_packed_l2, self.device)
self._expert_id_range = torch.arange(
self.num_experts, dtype=torch.int32
@@ -578,10 +584,4 @@ class CuTeDSLMoERunner:
weighted_out,
)
# Refill _token_indices after GEMM JIT on first call
# (CuTeDSL's cute.compile may corrupt GPU memory during first GEMM)
if self._needs_token_refill:
self._fill_token_indices()
self._needs_token_refill = False
return y