fix: warmup with 128 tokens (fills MMA tile), better error handling
The CuTeDSL kernel uses an MMA tiler of (128,128,256). With only 1 token, the warmup call cannot fill a tile and the kernel may access illegal memory, so the warmup now uses 128 tokens. Also improved the error message: after a CUDA illegal memory access, the context is corrupted and cannot be recovered.
This commit is contained in:
@@ -406,7 +406,7 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
|
||||
def finalize_weights(self) -> None:
|
||||
if self._cutedsl_runner is not None and self._cutedsl_runner.l1_fp4 is not None:
|
||||
return
|
||||
return # Already finalized
|
||||
|
||||
self._check_runtime_supported()
|
||||
|
||||
@@ -508,15 +508,21 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
# Warm up the CuTeDSL kernel (JIT compiles MLIR→PTX on first call).
|
||||
# This takes ~1-2 min but prevents the vLLM RPC timeout (5 min) from
|
||||
# killing the engine when the first inference request triggers compilation.
|
||||
# Note: The MMA tiler is (128,128,256) — we need >= 128 tokens to fill
|
||||
# a tile. Using 128 tokens, 1 expert for the warmup.
|
||||
print(" Compiling CuTeDSL NVFP4 MegaMoE kernel (one-time JIT, ~1-2 min)...", flush=True)
|
||||
dummy_hidden = torch.randn(1, self.hidden_size, dtype=torch.bfloat16, device=self._cutedsl_runner.l1_fp4[0].device)
|
||||
dummy_ids = torch.zeros(1, 1, dtype=torch.int32, device=dummy_hidden.device)
|
||||
dummy_weights = torch.ones(1, 1, dtype=torch.float32, device=dummy_hidden.device)
|
||||
try:
|
||||
device = self._cutedsl_runner.l1_fp4[0].device
|
||||
dummy_hidden = torch.randn(128, self.hidden_size, dtype=torch.bfloat16, device=device)
|
||||
dummy_ids = torch.zeros(128, 1, dtype=torch.int32, device=device) # all to expert 0
|
||||
dummy_weights = torch.ones(128, 1, dtype=torch.float32, device=device)
|
||||
self._cutedsl_runner.run(dummy_hidden, dummy_weights, dummy_ids, expert_indices=[0])
|
||||
print(" CuTeDSL NVFP4 MegaMoE kernel compiled ✓", flush=True)
|
||||
except Exception:
|
||||
print(" CuTeDSL warmup failed (will compile on first inference)", flush=True)
|
||||
except Exception as exc:
|
||||
# CUDA illegal memory access corrupts the context — can't recover.
|
||||
# Log the error clearly so the user knows to check the kernel.
|
||||
print(f" CuTeDSL warmup FAILED: {exc}", flush=True)
|
||||
print(" The CUDA context may be corrupted. Check kernel alignment/tiling.", flush=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
0
vllm/patches/lly be able to do atte
Normal file
0
vllm/patches/lly be able to do atte
Normal file
Reference in New Issue
Block a user