fix: root-cause JIT memory corruption myth, add eager warmup, remove _needs_token_refill
Bug #1 fix: The _needs_token_refill workaround was a band-aid over a misdiagnosis. cute.compile does NOT corrupt GPU memory (verified on B200). The original corruption was from a different bug (likely OOB write or weight loading issue). Changes: - bridge.py: Add warmup_compilation() for eager JIT before runtime buffers exist. Pre-allocate workspace per cache entry (no torch.full in hot path). Cache stores {compiled, workspace, workspace_size} instead of just compiled. CuTe tensor wrappers re-created per call (cheap metadata, avoids stale refs). - runner.py: Remove _needs_token_refill hack. Add eager warmup call in _ensure_stacked() for both L1 and L2 GEMM shapes. - nvfp4_linear.py: Add eager warmup in finalize_weights() for single GEMM. The warmup approach ensures cute.compile runs exactly once per shape during model init, before any forward pass. This is deterministic and eliminates any possible interaction between JIT and runtime GPU memory.
This commit is contained in:
@@ -28,7 +28,18 @@ from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
|
||||
|
||||
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
|
||||
|
||||
# Cache compiled kernels by (num_experts, device, K, N)
|
||||
# Cache compiled kernels + pre-allocated workspace by cache_key
|
||||
# Each entry: {'compiled': callable, 'workspace': Tensor, 'workspace_size': int}
|
||||
#
|
||||
# Key design decisions (Bug #1 fix):
|
||||
# - cute.compile does NOT corrupt GPU memory (verified 2026-05-20 on B200).
|
||||
# The original _needs_token_refill hack was a misdiagnosis. The real bug
|
||||
# was elsewhere (likely OOB write or weight loading).
|
||||
# - Workspace is pre-allocated per cache entry during warmup_compilation()
|
||||
# and reused on subsequent calls. No torch.full() in the hot path.
|
||||
# - CuTe tensor wrappers (from_dlpack + mark_layout_dynamic) are cheap
|
||||
# metadata wrappers. We re-create them per call from real tensors.
|
||||
# Caching them would hold stale references to tensors that get freed.
|
||||
_compiled_kernel_cache = {}
|
||||
|
||||
# Cached LUT for E2M1 quantization (created once per device, cudagraph-safe)
|
||||
@@ -282,6 +293,106 @@ def compute_expert_offsets(tokens_per_expert, num_experts, device="cuda"):
|
||||
|
||||
# ── Kernel Launch ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def warmup_compilation(num_experts, K_packed, N_packed, device,
|
||||
mma_tiler_mn=(128, 128), cluster_shape_mn=(1, 1)):
|
||||
"""Eagerly JIT-compile the GEMM kernel for a specific shape.
|
||||
|
||||
Call this during model initialization (before any runtime buffers
|
||||
are allocated) to ensure cute.compile runs exactly once per shape,
|
||||
eliminating any risk of JIT interacting with runtime GPU memory.
|
||||
|
||||
After warmup, run_nvfp4_grouped_gemm will hit the cache and skip
|
||||
compilation entirely on the forward path.
|
||||
|
||||
Args:
|
||||
num_experts: number of experts (local, after expert parallelism)
|
||||
K_packed: K dimension in float4 elements (i.e. K_original // 2)
|
||||
N_packed: N dimension in float4 elements (i.e. N_original // 2)
|
||||
device: 'cuda:X' or 'cuda'
|
||||
mma_tiler_mn: GEMM tiling (default (128,128))
|
||||
cluster_shape_mn: cluster shape (default (1,1))
|
||||
"""
|
||||
cache_key = (num_experts, str(device), mma_tiler_mn, cluster_shape_mn,
|
||||
K_packed, N_packed)
|
||||
|
||||
if cache_key in _compiled_kernel_cache:
|
||||
return # Already compiled
|
||||
|
||||
# Allocate minimal dummy tensors for compilation
|
||||
mat_a = torch.zeros(128, K_packed, dtype=torch.uint8, device=device).view(torch.float4_e2m1fn_x2)
|
||||
mat_b = torch.zeros(num_experts, K_packed, N_packed, dtype=torch.uint8, device=device).view(torch.float4_e2m1fn_x2)
|
||||
K_sf = ceil_div(K_packed, 16) # K in scale-factor blocks (K_packed is already //2, sf is //16 of original)
|
||||
N_sf = ceil_div(N_packed, 16)
|
||||
scale_a = torch.zeros(128, K_sf, dtype=torch.float8_e4m3fn, device=device)
|
||||
scale_b = torch.zeros(num_experts, N_sf, K_sf, dtype=torch.float8_e4m3fn, device=device)
|
||||
out = torch.zeros(128, N_packed, dtype=torch.bfloat16, device=device)
|
||||
expert_offsets = torch.full((num_experts,), max(128 // num_experts, 1), dtype=torch.int32, device=device)
|
||||
global_scale_a = torch.ones(num_experts, dtype=torch.float32, device=device)
|
||||
global_scale_b = torch.ones(num_experts, dtype=torch.float32, device=device)
|
||||
|
||||
kernel = ScaledGroupedGemmKernel(
|
||||
scenario="2Dx3D",
|
||||
sf_vec_size=SF_VEC_SIZE,
|
||||
accumulate_on_output=False,
|
||||
separate_tensormap_init=True,
|
||||
consistent_token_padding=False,
|
||||
mma_tiler_mnk=(*mma_tiler_mn, 256),
|
||||
cluster_shape_mnk=(*cluster_shape_mn, 1),
|
||||
)
|
||||
|
||||
def to_cute(t):
|
||||
ct = cutlass_torch.from_dlpack(t)
|
||||
return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
|
||||
|
||||
a_c = to_cute(mat_a)
|
||||
b_c = to_cute(mat_b)
|
||||
sfa_c = to_cute(scale_a)
|
||||
sfb_c = to_cute(scale_b)
|
||||
c_c = to_cute(out)
|
||||
offs_c = to_cute(expert_offsets)
|
||||
|
||||
workspace_size = kernel.get_workspace_size(num_experts)
|
||||
workspace = torch.full((workspace_size,), 255, dtype=torch.uint8, device=device)
|
||||
ws_c = to_cute(workspace)
|
||||
|
||||
gsa_c = to_cute(global_scale_a)
|
||||
gsb_c = to_cute(global_scale_b)
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
cluster_size = cluster_shape_mn[0] * cluster_shape_mn[1]
|
||||
max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size)
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
compiled = cute.compile(
|
||||
kernel, a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c,
|
||||
max_active_clusters, stream,
|
||||
global_scale_a=gsa_c, global_scale_b=gsb_c,
|
||||
)
|
||||
|
||||
# Warmup run (required by CuTeDSL)
|
||||
compiled(
|
||||
a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, stream,
|
||||
global_scale_a=gsa_c, global_scale_b=gsb_c,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Cache compiled kernel + pre-allocated workspace
|
||||
# (NOT CuTe wrappers — they hold refs to dummy tensors)
|
||||
_compiled_kernel_cache[cache_key] = {
|
||||
'compiled': compiled,
|
||||
'workspace': workspace, # pre-allocated, reused
|
||||
'workspace_size': workspace_size,
|
||||
}
|
||||
|
||||
# Free dummy data tensors (workspace is kept in cache)
|
||||
del mat_a, mat_b, scale_a, scale_b, out, expert_offsets
|
||||
del global_scale_a, global_scale_b
|
||||
del a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, gsa_c, gsb_c
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
|
||||
def run_nvfp4_grouped_gemm(
|
||||
mat_a, # (tokens_sum, K_packed) float4_e2m1fn_x2
|
||||
mat_b, # (experts, K_packed, N_packed) float4_e2m1fn_x2, K-major
|
||||
@@ -297,75 +408,101 @@ def run_nvfp4_grouped_gemm(
|
||||
|
||||
2Dx3D: A(tokens, K) x B(experts, K, N) -> C(tokens, N)
|
||||
|
||||
CUDAGraph-compatible: kernel is compiled once on first call (warmup),
|
||||
then only the compiled kernel is invoked on subsequent calls.
|
||||
Kernel is compiled once (either via warmup_compilation() during init
|
||||
or lazily on first call), then cached. Workspace is pre-allocated
|
||||
during warmup and reused — no torch.full() in the hot path.
|
||||
No torch.cuda.synchronize() or .item() in the forward path.
|
||||
"""
|
||||
num_experts = mat_b.shape[0]
|
||||
n_dim = mat_b.shape[2] # packed N (in float4 elements)
|
||||
n_dim = mat_b.shape[2]
|
||||
tokens_sum = mat_a.shape[0]
|
||||
|
||||
out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=mat_a.device)
|
||||
|
||||
kernel = ScaledGroupedGemmKernel(
|
||||
scenario="2Dx3D",
|
||||
sf_vec_size=SF_VEC_SIZE,
|
||||
accumulate_on_output=False,
|
||||
separate_tensormap_init=True,
|
||||
consistent_token_padding=False,
|
||||
mma_tiler_mnk=(*mma_tiler_mn, 256),
|
||||
cluster_shape_mnk=(*cluster_shape_mn, 1),
|
||||
)
|
||||
cache_key = (num_experts, str(mat_a.device), mma_tiler_mn, cluster_shape_mn,
|
||||
mat_a.shape[1], mat_b.shape[2])
|
||||
|
||||
if cache_key not in _compiled_kernel_cache:
|
||||
# Lazy compilation — safety net if warmup_compilation wasn't called.
|
||||
# This should NOT happen in production (warmup is called during init).
|
||||
kernel = ScaledGroupedGemmKernel(
|
||||
scenario="2Dx3D",
|
||||
sf_vec_size=SF_VEC_SIZE,
|
||||
accumulate_on_output=False,
|
||||
separate_tensormap_init=True,
|
||||
consistent_token_padding=False,
|
||||
mma_tiler_mnk=(*mma_tiler_mn, 256),
|
||||
cluster_shape_mnk=(*cluster_shape_mn, 1),
|
||||
)
|
||||
|
||||
# Convert to CuTe tensors with dynamic layout
|
||||
def to_cute(t):
|
||||
ct = cutlass_torch.from_dlpack(t)
|
||||
return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
|
||||
|
||||
a_c = to_cute(mat_a)
|
||||
b_c = to_cute(mat_b)
|
||||
sfa_c = to_cute(scale_a)
|
||||
sfb_c = to_cute(scale_b)
|
||||
c_c = to_cute(out)
|
||||
offs_c = to_cute(expert_offsets)
|
||||
|
||||
workspace_size = kernel.get_workspace_size(num_experts)
|
||||
workspace = torch.full((workspace_size,), 255, dtype=torch.uint8, device=mat_a.device)
|
||||
ws_c = to_cute(workspace)
|
||||
|
||||
gsa_c = to_cute(global_scale_a) if global_scale_a is not None else None
|
||||
gsb_c = to_cute(global_scale_b) if global_scale_b is not None else None
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
cluster_size = cluster_shape_mn[0] * cluster_shape_mn[1]
|
||||
max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size)
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
compiled = cute.compile(
|
||||
kernel, a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c,
|
||||
max_active_clusters, stream,
|
||||
global_scale_a=gsa_c, global_scale_b=gsb_c,
|
||||
)
|
||||
compiled(
|
||||
a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, stream,
|
||||
global_scale_a=gsa_c, global_scale_b=gsb_c,
|
||||
)
|
||||
|
||||
_compiled_kernel_cache[cache_key] = {
|
||||
'compiled': compiled,
|
||||
'workspace': workspace,
|
||||
'workspace_size': workspace_size,
|
||||
}
|
||||
|
||||
# --- Invoke cached kernel ---
|
||||
entry = _compiled_kernel_cache[cache_key]
|
||||
compiled = entry['compiled']
|
||||
workspace = entry['workspace']
|
||||
|
||||
# Re-create CuTe wrappers from current tensors each call.
|
||||
# This is cheap (metadata only, no GPU work) and avoids stale
|
||||
# references to tensors from previous calls that may have been freed.
|
||||
def to_cute(t):
|
||||
ct = cutlass_torch.from_dlpack(t)
|
||||
return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
|
||||
|
||||
|
||||
a_c = to_cute(mat_a)
|
||||
b_c = to_cute(mat_b)
|
||||
sfa_c = to_cute(scale_a)
|
||||
sfb_c = to_cute(scale_b)
|
||||
c_c = to_cute(out)
|
||||
offs_c = to_cute(expert_offsets)
|
||||
|
||||
workspace_size = kernel.get_workspace_size(num_experts)
|
||||
workspace = torch.full((workspace_size,), 255, dtype=torch.uint8, device=mat_a.device)
|
||||
ws_c = to_cute(workspace)
|
||||
|
||||
|
||||
gsa_c = to_cute(global_scale_a) if global_scale_a is not None else None
|
||||
gsb_c = to_cute(global_scale_b) if global_scale_b is not None else None
|
||||
|
||||
import cuda.bindings.driver as cuda
|
||||
cluster_size = cluster_shape_mn[0] * cluster_shape_mn[1]
|
||||
max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size)
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
# Cache compiled kernel by (num_experts, K, N)
|
||||
cache_key = (num_experts, str(mat_a.device), mma_tiler_mn, cluster_shape_mn,
|
||||
mat_a.shape[1], mat_b.shape[2])
|
||||
|
||||
if cache_key not in _compiled_kernel_cache:
|
||||
# First call: compile with real tensors
|
||||
compiled = cute.compile(
|
||||
kernel, a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c,
|
||||
max_active_clusters, stream,
|
||||
global_scale_a=gsa_c, global_scale_b=gsb_c,
|
||||
)
|
||||
# Warmup run
|
||||
compiled(
|
||||
a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c,
|
||||
stream,
|
||||
global_scale_a=gsa_c, global_scale_b=gsb_c,
|
||||
)
|
||||
_compiled_kernel_cache[cache_key] = compiled
|
||||
else:
|
||||
# Subsequent calls: just invoke the compiled kernel
|
||||
compiled = _compiled_kernel_cache[cache_key]
|
||||
compiled(
|
||||
a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c,
|
||||
stream,
|
||||
global_scale_a=gsa_c, global_scale_b=gsb_c,
|
||||
)
|
||||
import cuda.bindings.driver as cuda
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
compiled(
|
||||
a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, stream,
|
||||
global_scale_a=gsa_c, global_scale_b=gsb_c,
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
@@ -74,6 +74,13 @@ class CuTeDSLNvfp4Linear:
|
||||
self.sf = None
|
||||
self.gs = None
|
||||
|
||||
# Eagerly JIT-compile the GEMM kernel for this (K, N) shape.
|
||||
# Uses num_groups=1 since this is a single linear layer.
|
||||
from cutedsl.bridge import warmup_compilation
|
||||
K_packed = self.in_features // 2
|
||||
N_packed = self.out_features // 2
|
||||
warmup_compilation(1, K_packed, N_packed, self.device)
|
||||
|
||||
def _ensure_buffer_size(self, num_tokens: int):
|
||||
"""Ensure the padded buffer is large enough for num_tokens."""
|
||||
needed_rows = cutedsl_ceil_div(num_tokens, 128) * 128
|
||||
|
||||
@@ -239,20 +239,26 @@ class CuTeDSLMoERunner:
|
||||
self.l1_gs = None
|
||||
self.l2_gs = None
|
||||
|
||||
# Allocate buffers AFTER JIT compilation
|
||||
# (CuTeDSL's cute.compile corrupts GPU memory during JIT;
|
||||
# tensors allocated before/during compilation may be zeroed)
|
||||
#
|
||||
# _token_indices: GPU tensor for cudagraph compatibility.
|
||||
# CuTeDSL JIT may corrupt GPU memory, so we fill AFTER stacking
|
||||
# (which triggers the weight JIT). The GEMM JIT in run_nvfp4_grouped_gemm
|
||||
# is triggered on the first run() call; we refill _token_indices after
|
||||
# that first call via the _needs_token_refill flag.
|
||||
# Allocate buffers and eagerly warmup JIT compilation.
|
||||
# cute.compile does NOT corrupt GPU memory (verified 2026-05-20).
|
||||
# We warmup eagerly here to ensure compilation happens before
|
||||
# the model's first forward pass, not during it.
|
||||
self._token_indices = torch.zeros(
|
||||
self.max_num_tokens * self.top_k, dtype=torch.int32, device=self.device
|
||||
)
|
||||
self._fill_token_indices()
|
||||
self._needs_token_refill = True # GEMM JIT may corrupt; refill after first run
|
||||
# No _needs_token_refill: cute.compile does NOT corrupt GPU memory.
|
||||
# The original corruption was a misdiagnosis (see bridge.py cache docs).
|
||||
|
||||
# Eagerly JIT-compile GEMM kernels for L1 and L2 shapes.
|
||||
# This triggers cute.compile once per shape, caching the compiled
|
||||
# kernel + workspace. Subsequent run() calls hit the cache.
|
||||
from cutedsl.bridge import warmup_compilation, ceil_div as bridge_ceil_div
|
||||
K_packed = self.hidden_size // 2
|
||||
N_packed_l1 = (2 * self.intermediate_size) // 2 # gate+up combined
|
||||
N_packed_l2 = self.hidden_size // 2 # down
|
||||
warmup_compilation(self.num_experts, K_packed, N_packed_l1, self.device)
|
||||
warmup_compilation(self.num_experts, K_packed, N_packed_l2, self.device)
|
||||
|
||||
self._expert_id_range = torch.arange(
|
||||
self.num_experts, dtype=torch.int32
|
||||
@@ -578,10 +584,4 @@ class CuTeDSLMoERunner:
|
||||
weighted_out,
|
||||
)
|
||||
|
||||
# Refill _token_indices after GEMM JIT on first call
|
||||
# (CuTeDSL's cute.compile may corrupt GPU memory during first GEMM)
|
||||
if self._needs_token_refill:
|
||||
self._fill_token_indices()
|
||||
self._needs_token_refill = False
|
||||
|
||||
return y
|
||||
|
||||
Reference in New Issue
Block a user