[Model Runner V2] Minor fix for cudagraph_utils (#29256)
This commit is contained in:
@@ -1,7 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import gc
|
from unittest.mock import patch
|
||||||
from contextlib import contextmanager
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -140,6 +139,7 @@ class CudaGraphManager:
|
|||||||
attn_metadata,
|
attn_metadata,
|
||||||
self.vllm_config,
|
self.vllm_config,
|
||||||
num_tokens=batch_size,
|
num_tokens=batch_size,
|
||||||
|
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||||
num_tokens_across_dp=num_tokens_across_dp,
|
num_tokens_across_dp=num_tokens_across_dp,
|
||||||
):
|
):
|
||||||
hidden_states = model(
|
hidden_states = model(
|
||||||
@@ -148,15 +148,16 @@ class CudaGraphManager:
|
|||||||
)
|
)
|
||||||
if self.hidden_states is None:
|
if self.hidden_states is None:
|
||||||
self.hidden_states = torch.empty_like(hidden_states)
|
self.hidden_states = torch.empty_like(hidden_states)
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
# Capture the graph.
|
# Capture the graph.
|
||||||
graph = torch.cuda.CUDAGraph()
|
graph = torch.cuda.CUDAGraph()
|
||||||
with (
|
with (
|
||||||
|
patch("torch.cuda.empty_cache", lambda: None),
|
||||||
set_forward_context(
|
set_forward_context(
|
||||||
attn_metadata,
|
attn_metadata,
|
||||||
self.vllm_config,
|
self.vllm_config,
|
||||||
num_tokens=batch_size,
|
num_tokens=batch_size,
|
||||||
|
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||||
num_tokens_across_dp=num_tokens_across_dp,
|
num_tokens_across_dp=num_tokens_across_dp,
|
||||||
),
|
),
|
||||||
torch.cuda.graph(graph, self.pool),
|
torch.cuda.graph(graph, self.pool),
|
||||||
@@ -183,7 +184,7 @@ class CudaGraphManager:
|
|||||||
if is_global_first_rank():
|
if is_global_first_rank():
|
||||||
sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")
|
sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")
|
||||||
|
|
||||||
with freeze_gc(), graph_capture(device=self.device):
|
with graph_capture(device=self.device):
|
||||||
for batch_size in sizes_to_capture:
|
for batch_size in sizes_to_capture:
|
||||||
self.capture_graph(
|
self.capture_graph(
|
||||||
batch_size,
|
batch_size,
|
||||||
@@ -199,13 +200,3 @@ class CudaGraphManager:
|
|||||||
self.graphs[batch_size].replay()
|
self.graphs[batch_size].replay()
|
||||||
assert self.hidden_states is not None
|
assert self.hidden_states is not None
|
||||||
return self.hidden_states[:batch_size]
|
return self.hidden_states[:batch_size]
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def freeze_gc():
|
|
||||||
gc.collect()
|
|
||||||
gc.freeze()
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
gc.unfreeze()
|
|
||||||
|
|||||||
@@ -298,6 +298,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
|
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
|
||||||
|
|
||||||
with self.maybe_setup_dummy_loras(self.lora_config):
|
with self.maybe_setup_dummy_loras(self.lora_config):
|
||||||
|
|||||||
Reference in New Issue
Block a user