Add CuTeDSL warmup + CUDA sync after JIT compilation

CuTeDSL cute.compile corrupts GPU memory during JIT compilation. After
finalize_weights, run a warmup forward pass, call torch.cuda.synchronize(),
and perform a GPU health check, matching the MoE runner pattern.
This commit is contained in:
2026-05-19 01:11:44 +00:00
parent 1d9c0f996c
commit e1fcfc4f01

View File

@@ -188,6 +188,21 @@ class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel):
layer._cutedsl_global_scale_b = runner._gsb
layer._cutedsl_activation_global_scale = activation_global_scale
# Warmup: CuTeDSL cute.compile corrupts GPU memory during JIT.
# Run a warmup forward to trigger compilation, then synchronize
# and verify GPU health. Matches cutedsl/runner.py MoE pattern.
with torch.no_grad():
warmup_x = torch.randn(1, in_features, dtype=torch.bfloat16,
device=device)
_ = torch.ops.cutedsl.nvfp4_linear(
warmup_x, runner._mat_b, runner._scale_b, runner._gsb,
activation_global_scale,
)
torch.cuda.synchronize()
# Verify GPU is still healthy after CuTeDSL JIT
test = torch.ones(1, device=device) + torch.ones(1, device=device)
assert test.item() == 2.0, "GPU corruption after CuTeDSL JIT"
# Replace weight with dummy BF16 (vLLM module introspection may need it)
layer.weight = torch.nn.Parameter(
torch.zeros(out_features, in_features, dtype=torch.bfloat16,