fix: only suppress compile message, still warmup all layers

CuTeDSL caches kernels by (M, N, K) shape. Different layer shapes
(L1 vs L2, different expert counts) trigger new compiles. We can't
skip the warmup call — only suppress the print spam.

Flag now gates the message, not the warmup.
This commit is contained in:
2026-05-16 05:18:10 +00:00
parent f19932d8db
commit 3838561c19

View File

@@ -506,26 +506,22 @@ class DeepseekV4MegaMoEExperts(nn.Module):
self.w2_input_scale = None
# Warm up the CuTeDSL kernel (JIT compiles MLIR→PTX on first call).
# Only compile once per process — the kernel is cached after that.
# This takes ~1-2 min but prevents the vLLM RPC timeout (5 min) from
# killing the engine when the first inference request triggers compilation.
# Note: The MMA tiler is (128,128,256) — we need >= 128 tokens to fill
# a tile. Using 128 tokens, 1 expert for the warmup.
# CuTeDSL caches compiled kernels by (M, N, K) shape, so different
# layer shapes may trigger additional compiles. We only print the
# compile message once to avoid spam (61 layers × 8 ranks).
# The MMA tiler needs >= 128 tokens; using 128 for the warmup.
if not DeepseekV4MegaMoEExperts._cutedsl_compiled:
DeepseekV4MegaMoEExperts._cutedsl_compiled = True
print(" Compiling CuTeDSL NVFP4 MegaMoE kernel (one-time JIT, ~1-2 min)...", flush=True)
try:
device = self._cutedsl_runner.l1_fp4[0].device
dummy_hidden = torch.randn(128, self.hidden_size, dtype=torch.bfloat16, device=device)
dummy_ids = torch.zeros(128, 1, dtype=torch.int32, device=device) # all to expert 0
dummy_weights = torch.ones(128, 1, dtype=torch.float32, device=device)
self._cutedsl_runner.run(dummy_hidden, dummy_weights, dummy_ids, expert_indices=[0])
print(" CuTeDSL NVFP4 MegaMoE kernel compiled ✓", flush=True)
except Exception as exc:
# CUDA illegal memory access corrupts the context — can't recover.
# Log the error clearly so the user knows to check the kernel.
print(f" CuTeDSL warmup FAILED: {exc}", flush=True)
print(" The CUDA context may be corrupted. Check kernel alignment/tiling.", flush=True)
print(" Compiling CuTeDSL NVFP4 MegaMoE kernels (one-time JIT, ~1-2 min)...", flush=True)
try:
device = self._cutedsl_runner.l1_fp4[0].device
dummy_hidden = torch.randn(128, self.hidden_size, dtype=torch.bfloat16, device=device)
dummy_ids = torch.zeros(128, 1, dtype=torch.int32, device=device)
dummy_weights = torch.ones(128, 1, dtype=torch.float32, device=device)
self._cutedsl_runner.run(dummy_hidden, dummy_weights, dummy_ids, expert_indices=[0])
except Exception as exc:
print(f" CuTeDSL warmup FAILED: {exc}", flush=True)
print(" The CUDA context may be corrupted. Check kernel alignment/tiling.", flush=True)
def forward(
self,