fix: warmup with 128 tokens (fills MMA tile), better error handling

The CuTeDSL kernel uses MMA tiler (128,128,256). With only 1 token,
the kernel can't fill a tile and may access illegal memory. Using 128
tokens for the warmup.

Also improved error message — after CUDA illegal memory access, the
context is corrupted and can't recover.
This commit is contained in:
2026-05-16 04:56:45 +00:00
parent a70d2d3984
commit cf0731cf4b
2 changed files with 12 additions and 6 deletions

View File

@@ -406,7 +406,7 @@ class DeepseekV4MegaMoEExperts(nn.Module):
def finalize_weights(self) -> None:
if self._cutedsl_runner is not None and self._cutedsl_runner.l1_fp4 is not None:
return
return # Already finalized
self._check_runtime_supported()
@@ -508,15 +508,21 @@ class DeepseekV4MegaMoEExperts(nn.Module):
# Warm up the CuTeDSL kernel (JIT compiles MLIR→PTX on first call).
# This takes ~1-2 min but prevents the vLLM RPC timeout (5 min) from
# killing the engine when the first inference request triggers compilation.
# Note: The MMA tiler is (128,128,256) — we need >= 128 tokens to fill
# a tile. Using 128 tokens, 1 expert for the warmup.
print(" Compiling CuTeDSL NVFP4 MegaMoE kernel (one-time JIT, ~1-2 min)...", flush=True)
dummy_hidden = torch.randn(1, self.hidden_size, dtype=torch.bfloat16, device=self._cutedsl_runner.l1_fp4[0].device)
dummy_ids = torch.zeros(1, 1, dtype=torch.int32, device=dummy_hidden.device)
dummy_weights = torch.ones(1, 1, dtype=torch.float32, device=dummy_hidden.device)
try:
device = self._cutedsl_runner.l1_fp4[0].device
dummy_hidden = torch.randn(128, self.hidden_size, dtype=torch.bfloat16, device=device)
dummy_ids = torch.zeros(128, 1, dtype=torch.int32, device=device) # all to expert 0
dummy_weights = torch.ones(128, 1, dtype=torch.float32, device=device)
self._cutedsl_runner.run(dummy_hidden, dummy_weights, dummy_ids, expert_indices=[0])
print(" CuTeDSL NVFP4 MegaMoE kernel compiled ✓", flush=True)
except Exception:
print(" CuTeDSL warmup failed (will compile on first inference)", flush=True)
except Exception as exc:
# CUDA illegal memory access corrupts the context — can't recover.
# Log the error clearly so the user knows to check the kernel.
print(f" CuTeDSL warmup FAILED: {exc}", flush=True)
print(" The CUDA context may be corrupted. Check kernel alignment/tiling.", flush=True)
def forward(
self,

View File