Fix OOM: use 1-token warmup sample + free immediately
8 tokens * 7168 hidden * ~40 NVFP4 layers = ~2.3 MiB per layer * 40 = 92 MiB But the dummy weight param (out_features * in_features * 2 bytes BF16) was the real killer — each layer allocated a BF16 dummy of its full weight shape. With 1 token the warmup still gets a valid gs, and empty_cache frees the sample tensor before KV cache allocation.
This commit is contained in:
@@ -103,10 +103,13 @@ class CuTeDSLNvfp4Method(LinearMethodBase):
|
||||
# match what quantize_activation_nvfp4 expects at runtime. Using it
|
||||
# produces garbage output (empty EOS tokens). The correct approach is
|
||||
# a warmup forward pass that measures the actual activation distribution.
|
||||
# Use only 1 token to minimize GPU memory overhead during weight loading.
|
||||
with torch.no_grad():
|
||||
sample = torch.randn(min(8, 256), in_features,
|
||||
sample = torch.randn(1, in_features,
|
||||
dtype=torch.bfloat16, device=device) * 2.0
|
||||
runner.compute_activation_global_scale(sample)
|
||||
del sample
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Replace weight with dummy BF16 (needed by vLLM module introspection)
|
||||
layer.weight = torch.nn.Parameter(
|
||||
|
||||
@@ -85,12 +85,15 @@ class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel):
|
||||
# match what quantize_activation_nvfp4 expects at runtime. Using it
|
||||
# produces garbage output (empty EOS tokens). The correct approach is
|
||||
# a warmup forward pass that measures the actual activation distribution.
|
||||
# Use only 1 token to minimize GPU memory overhead during weight loading.
|
||||
with torch.no_grad():
|
||||
sample = torch.randn(
|
||||
min(8, 256), in_features,
|
||||
1, in_features,
|
||||
dtype=torch.bfloat16, device=str(device),
|
||||
) * 2.0
|
||||
runner.compute_activation_global_scale(sample)
|
||||
del sample
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Register the runner and store the ID (not the runner itself)
|
||||
layer._cutedsl_runner_id = register_runner(runner)
|
||||
|
||||
Reference in New Issue
Block a user