From 3838561c19faa482ada59e2bcf2121672a4ad05a Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 16 May 2026 05:18:10 +0000 Subject: [PATCH] fix: only suppress compile message, still warmup all layers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit CuTeDSL caches compiled kernels by (M, N, K) shape, so different layer shapes (L1 vs L2, different expert counts) trigger new compiles. We therefore can't skip the warmup call for any layer — we can only suppress the print spam. The _cutedsl_compiled flag now gates the message, not the warmup. --- vllm/patches/deepseek_v4.py | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/vllm/patches/deepseek_v4.py b/vllm/patches/deepseek_v4.py index 939a4737..0ad96432 100644 --- a/vllm/patches/deepseek_v4.py +++ b/vllm/patches/deepseek_v4.py @@ -506,26 +506,22 @@ class DeepseekV4MegaMoEExperts(nn.Module): self.w2_input_scale = None # Warm up the CuTeDSL kernel (JIT compiles MLIR→PTX on first call). - # Only compile once per process — the kernel is cached after that. - # This takes ~1-2 min but prevents the vLLM RPC timeout (5 min) from - # killing the engine when the first inference request triggers compilation. - # Note: The MMA tiler is (128,128,256) — we need >= 128 tokens to fill - # a tile. Using 128 tokens, 1 expert for the warmup. + # CuTeDSL caches compiled kernels by (M, N, K) shape, so different + # layer shapes may trigger additional compiles. We only print the + # compile message once to avoid spam (61 layers × 8 ranks). + # The MMA tiler needs >= 128 tokens; using 128 for the warmup. 
if not DeepseekV4MegaMoEExperts._cutedsl_compiled: DeepseekV4MegaMoEExperts._cutedsl_compiled = True - print(" Compiling CuTeDSL NVFP4 MegaMoE kernel (one-time JIT, ~1-2 min)...", flush=True) - try: - device = self._cutedsl_runner.l1_fp4[0].device - dummy_hidden = torch.randn(128, self.hidden_size, dtype=torch.bfloat16, device=device) - dummy_ids = torch.zeros(128, 1, dtype=torch.int32, device=device) # all to expert 0 - dummy_weights = torch.ones(128, 1, dtype=torch.float32, device=device) - self._cutedsl_runner.run(dummy_hidden, dummy_weights, dummy_ids, expert_indices=[0]) - print(" CuTeDSL NVFP4 MegaMoE kernel compiled ✓", flush=True) - except Exception as exc: - # CUDA illegal memory access corrupts the context — can't recover. - # Log the error clearly so the user knows to check the kernel. - print(f" CuTeDSL warmup FAILED: {exc}", flush=True) - print(" The CUDA context may be corrupted. Check kernel alignment/tiling.", flush=True) + print(" Compiling CuTeDSL NVFP4 MegaMoE kernels (one-time JIT, ~1-2 min)...", flush=True) + try: + device = self._cutedsl_runner.l1_fp4[0].device + dummy_hidden = torch.randn(128, self.hidden_size, dtype=torch.bfloat16, device=device) + dummy_ids = torch.zeros(128, 1, dtype=torch.int32, device=device) + dummy_weights = torch.ones(128, 1, dtype=torch.float32, device=device) + self._cutedsl_runner.run(dummy_hidden, dummy_weights, dummy_ids, expert_indices=[0]) + except Exception as exc: + print(f" CuTeDSL warmup FAILED: {exc}", flush=True) + print(" The CUDA context may be corrupted. Check kernel alignment/tiling.", flush=True) def forward( self,