Fix OOM: cap buffer pre-allocation at cudagraph max capture size

padded_hidden/activated buffers were sized for max_num_tokens=8192,
which is 72 MB per layer × 60 layers = 4.3 GB → OOM with 178 GB GPUs
(almost full from model + KV cache).

Now cap at max cudagraph capture size (512 tokens). Eager-mode runs
with >512 tokens will need dynamic allocation, but vLLM always uses
cudagraph for inference after warmup.
This commit is contained in:
2026-05-17 14:14:13 +00:00
parent 5bb78564f5
commit 8ac8e20fa9
2 changed files with 7 additions and 3 deletions

View File

@@ -81,7 +81,7 @@ class CuTeDSLMoERunner:
self._output_buf = None
self._row_indices_buf = None
self._padded_hidden_buf = None
self._padded_activated_buf = None
self._padded_activated_buf = None # unused, using shared
self._padded_expert_offsets_buf = None
self._max_chunks_per_expert = cutedsl_ceil_div(
self.max_num_tokens * self.top_k, self.num_experts * 128
@@ -138,7 +138,8 @@ class CuTeDSLMoERunner:
self.max_num_tokens * self.top_k, device=self.device
)
# Padded hidden/activated buffers: max_num_tokens * top_k rows (rounded to 128)
# Padded hidden/activated: per-layer, sized for max capture budget
# NOT sized for max_num_tokens (8192) which would be too much for 60 layers
max_slots = self.max_num_tokens * self.top_k
padded_max_slots = ((max_slots + 127) // 128) * 128
self._padded_hidden_buf = torch.zeros(

View File

@@ -499,11 +499,14 @@ class DeepseekV4MegaMoEExperts(nn.Module):
l2_gs.append(down_gs)
# Create CuTeDSL runner with directly-cast weights
# Max tokens for buffer pre-allocation: use cudagraph max capture size
# (not scheduler max which can be 8192, causing OOM with 60 layers)
max_cg_size = getattr(self, '_cudagraph_max_capture_size', 512)
self._cutedsl_runner = CuTeDSLMoERunner(
num_experts=self.num_local_experts,
hidden_size=self.hidden_size,
intermediate_size=self.intermediate_size,
max_num_tokens=self.max_num_tokens,
max_num_tokens=min(self.max_num_tokens, max_cg_size),
top_k=self.top_k,
device=l1_fp4[0].device,
experts_start_idx=self.experts_start_idx,