fix: only suppress compile message, still warmup all layers
CuTeDSL caches kernels by (M, N, K) shape. Different layer shapes (L1 vs L2, different expert counts) trigger new compiles. We can't skip the warmup call — only suppress the print spam. Flag now gates the message, not the warmup.
This commit is contained in:
@@ -506,26 +506,22 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
self.w2_input_scale = None
|
||||
|
||||
# Warm up the CuTeDSL kernel (JIT compiles MLIR→PTX on first call).
|
||||
# Only compile once per process — the kernel is cached after that.
|
||||
# This takes ~1-2 min but prevents the vLLM RPC timeout (5 min) from
|
||||
# killing the engine when the first inference request triggers compilation.
|
||||
# Note: The MMA tiler is (128,128,256) — we need >= 128 tokens to fill
|
||||
# a tile. Using 128 tokens, 1 expert for the warmup.
|
||||
# CuTeDSL caches compiled kernels by (M, N, K) shape, so different
|
||||
# layer shapes may trigger additional compiles. We only print the
|
||||
# compile message once to avoid spam (61 layers × 8 ranks).
|
||||
# The MMA tiler needs >= 128 tokens; using 128 for the warmup.
|
||||
if not DeepseekV4MegaMoEExperts._cutedsl_compiled:
|
||||
DeepseekV4MegaMoEExperts._cutedsl_compiled = True
|
||||
print(" Compiling CuTeDSL NVFP4 MegaMoE kernel (one-time JIT, ~1-2 min)...", flush=True)
|
||||
try:
|
||||
device = self._cutedsl_runner.l1_fp4[0].device
|
||||
dummy_hidden = torch.randn(128, self.hidden_size, dtype=torch.bfloat16, device=device)
|
||||
dummy_ids = torch.zeros(128, 1, dtype=torch.int32, device=device) # all to expert 0
|
||||
dummy_weights = torch.ones(128, 1, dtype=torch.float32, device=device)
|
||||
self._cutedsl_runner.run(dummy_hidden, dummy_weights, dummy_ids, expert_indices=[0])
|
||||
print(" CuTeDSL NVFP4 MegaMoE kernel compiled ✓", flush=True)
|
||||
except Exception as exc:
|
||||
# CUDA illegal memory access corrupts the context — can't recover.
|
||||
# Log the error clearly so the user knows to check the kernel.
|
||||
print(f" CuTeDSL warmup FAILED: {exc}", flush=True)
|
||||
print(" The CUDA context may be corrupted. Check kernel alignment/tiling.", flush=True)
|
||||
print(" Compiling CuTeDSL NVFP4 MegaMoE kernels (one-time JIT, ~1-2 min)...", flush=True)
|
||||
try:
|
||||
device = self._cutedsl_runner.l1_fp4[0].device
|
||||
dummy_hidden = torch.randn(128, self.hidden_size, dtype=torch.bfloat16, device=device)
|
||||
dummy_ids = torch.zeros(128, 1, dtype=torch.int32, device=device)
|
||||
dummy_weights = torch.ones(128, 1, dtype=torch.float32, device=device)
|
||||
self._cutedsl_runner.run(dummy_hidden, dummy_weights, dummy_ids, expert_indices=[0])
|
||||
except Exception as exc:
|
||||
print(f" CuTeDSL warmup FAILED: {exc}", flush=True)
|
||||
print(" The CUDA context may be corrupted. Check kernel alignment/tiling.", flush=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user