[Model Runner V2] Minor fix for cudagraph_utils (#29256)

This commit is contained in:
Woosuk Kwon
2025-11-22 20:12:50 -08:00
committed by GitHub
parent 389aa1b2eb
commit 20ee418adc
2 changed files with 6 additions and 14 deletions

View File

@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import gc
-from contextlib import contextmanager
+from unittest.mock import patch
 
 import numpy as np
 import torch
@@ -140,6 +139,7 @@ class CudaGraphManager:
             attn_metadata,
             self.vllm_config,
             num_tokens=batch_size,
+            cudagraph_runtime_mode=CUDAGraphMode.NONE,
             num_tokens_across_dp=num_tokens_across_dp,
         ):
             hidden_states = model(
@@ -148,15 +148,16 @@ class CudaGraphManager:
         )
         if self.hidden_states is None:
             self.hidden_states = torch.empty_like(hidden_states)
-        torch.cuda.synchronize()
 
         # Capture the graph.
         graph = torch.cuda.CUDAGraph()
         with (
+            patch("torch.cuda.empty_cache", lambda: None),
             set_forward_context(
                 attn_metadata,
                 self.vllm_config,
                 num_tokens=batch_size,
+                cudagraph_runtime_mode=CUDAGraphMode.NONE,
                 num_tokens_across_dp=num_tokens_across_dp,
             ),
             torch.cuda.graph(graph, self.pool),
@@ -183,7 +184,7 @@ class CudaGraphManager:
         if is_global_first_rank():
             sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")
-        with freeze_gc(), graph_capture(device=self.device):
+        with graph_capture(device=self.device):
             for batch_size in sizes_to_capture:
                 self.capture_graph(
                     batch_size,
@@ -199,13 +200,3 @@ class CudaGraphManager:
         self.graphs[batch_size].replay()
         assert self.hidden_states is not None
         return self.hidden_states[:batch_size]
-
-
-@contextmanager
-def freeze_gc():
-    gc.collect()
-    gc.freeze()
-    try:
-        yield
-    finally:
-        gc.unfreeze()

View File

@@ -298,6 +298,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             return 0
         start_time = time.perf_counter()
+        torch.cuda.empty_cache()
         start_free_gpu_memory = torch.cuda.mem_get_info()[0]
         with self.maybe_setup_dummy_loras(self.lora_config):