diff --git a/vllm/kernels/linear/nvfp4/cutedsl.py b/vllm/kernels/linear/nvfp4/cutedsl.py index 648aaad1..4c7ae95e 100644 --- a/vllm/kernels/linear/nvfp4/cutedsl.py +++ b/vllm/kernels/linear/nvfp4/cutedsl.py @@ -188,6 +188,21 @@ class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel): layer._cutedsl_global_scale_b = runner._gsb layer._cutedsl_activation_global_scale = activation_global_scale + # Warmup: CuTeDSL cute.compile corrupts GPU memory during JIT. + # Run a warmup forward to trigger compilation, then synchronize + # and verify GPU health. Matches cutedsl/runner.py MoE pattern. + with torch.no_grad(): + warmup_x = torch.randn(1, in_features, dtype=torch.bfloat16, + device=device) + _ = torch.ops.cutedsl.nvfp4_linear( + warmup_x, runner._mat_b, runner._scale_b, runner._gsb, + activation_global_scale, + ) + torch.cuda.synchronize() + # Verify GPU is still healthy after CuTeDSL JIT + test = torch.ones(1, device=device) + torch.ones(1, device=device) + assert test.item() == 2.0, "GPU corruption after CuTeDSL JIT" + # Replace weight with dummy BF16 (vLLM module introspection may need it) layer.weight = torch.nn.Parameter( torch.zeros(out_features, in_features, dtype=torch.bfloat16,