Add CuTeDSL warmup + CUDA sync after JIT compilation
CuTeDSL cute.compile corrupts GPU memory. Add warmup forward + torch.cuda.synchronize() + health check after finalize_weights, matching the MoE runner pattern.
This commit is contained in:
@@ -188,6 +188,21 @@ class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel):
|
||||
layer._cutedsl_global_scale_b = runner._gsb
|
||||
layer._cutedsl_activation_global_scale = activation_global_scale
|
||||
|
||||
# Warmup: CuTeDSL cute.compile corrupts GPU memory during JIT.
|
||||
# Run a warmup forward to trigger compilation, then synchronize
|
||||
# and verify GPU health. Matches cutedsl/runner.py MoE pattern.
|
||||
with torch.no_grad():
|
||||
warmup_x = torch.randn(1, in_features, dtype=torch.bfloat16,
|
||||
device=device)
|
||||
_ = torch.ops.cutedsl.nvfp4_linear(
|
||||
warmup_x, runner._mat_b, runner._scale_b, runner._gsb,
|
||||
activation_global_scale,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
# Verify GPU is still healthy after CuTeDSL JIT
|
||||
test = torch.ones(1, device=device) + torch.ones(1, device=device)
|
||||
assert test.item() == 2.0, "GPU corruption after CuTeDSL JIT"
|
||||
|
||||
# Replace weight with dummy BF16 (vLLM module introspection may need it)
|
||||
layer.weight = torch.nn.Parameter(
|
||||
torch.zeros(out_features, in_features, dtype=torch.bfloat16,
|
||||
|
||||
Reference in New Issue
Block a user