diff --git a/vllm/patches/deepseek_v4.py b/vllm/patches/deepseek_v4.py index 2c1b5afc..056ec18a 100644 --- a/vllm/patches/deepseek_v4.py +++ b/vllm/patches/deepseek_v4.py @@ -406,7 +406,7 @@ class DeepseekV4MegaMoEExperts(nn.Module): def finalize_weights(self) -> None: if self._cutedsl_runner is not None and self._cutedsl_runner.l1_fp4 is not None: - return + return # Already finalized self._check_runtime_supported() @@ -508,15 +508,21 @@ class DeepseekV4MegaMoEExperts(nn.Module): # Warm up the CuTeDSL kernel (JIT compiles MLIR→PTX on first call). # This takes ~1-2 min but prevents the vLLM RPC timeout (5 min) from # killing the engine when the first inference request triggers compilation. + # Note: The MMA tiler is (128,128,256) — we need >= 128 tokens to fill + # a tile. Using 128 tokens, 1 expert for the warmup. print(" Compiling CuTeDSL NVFP4 MegaMoE kernel (one-time JIT, ~1-2 min)...", flush=True) - dummy_hidden = torch.randn(1, self.hidden_size, dtype=torch.bfloat16, device=self._cutedsl_runner.l1_fp4[0].device) - dummy_ids = torch.zeros(1, 1, dtype=torch.int32, device=dummy_hidden.device) - dummy_weights = torch.ones(1, 1, dtype=torch.float32, device=dummy_hidden.device) try: + device = self._cutedsl_runner.l1_fp4[0].device + dummy_hidden = torch.randn(128, self.hidden_size, dtype=torch.bfloat16, device=device) + dummy_ids = torch.zeros(128, 1, dtype=torch.int32, device=device) # all to expert 0 + dummy_weights = torch.ones(128, 1, dtype=torch.float32, device=device) self._cutedsl_runner.run(dummy_hidden, dummy_weights, dummy_ids, expert_indices=[0]) print(" CuTeDSL NVFP4 MegaMoE kernel compiled ✓", flush=True) - except Exception: - print(" CuTeDSL warmup failed (will compile on first inference)", flush=True) + except Exception as exc: + # CUDA illegal memory access corrupts the context — can't recover. + # Log the error clearly so the user knows to check the kernel. + print(f" CuTeDSL warmup FAILED: {exc}", flush=True) + print(" The CUDA context may be corrupted. Check kernel alignment/tiling.", flush=True) def forward( self, diff --git a/vllm/patches/lly be able to do atte b/vllm/patches/lly be able to do atte new file mode 100644 index 00000000..e69de29b