[torch.compile] decouple compile sizes and cudagraph sizes (#12243)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2025-01-24 02:01:30 +08:00
committed by GitHub
parent 3f50c148fd
commit 6e650f56a1
7 changed files with 94 additions and 57 deletions

View File

@@ -1,6 +1,6 @@
import gc
import time
from typing import TYPE_CHECKING, Dict, List, Tuple, cast
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
import numpy as np
import torch
@@ -128,7 +128,8 @@ class GPUModelRunner:
# self.cudagraph_batch_sizes sorts in ascending order.
# The batch sizes in the config are in descending order.
self.cudagraph_batch_sizes = list(
reversed(self.vllm_config.compilation_config.capture_sizes))
reversed(
self.vllm_config.compilation_config.cudagraph_capture_sizes))
# Cache the device properties.
self.device_properties = torch.cuda.get_device_properties(self.device)
@@ -834,10 +835,12 @@ class GPUModelRunner:
@torch.inference_mode()
def _dummy_run(
self,
model: nn.Module,
num_tokens: int,
kv_caches: List[torch.Tensor],
kv_caches: Optional[List[torch.Tensor]] = None,
) -> torch.Tensor:
model = self.model
if kv_caches is None:
kv_caches = self.kv_caches
if self.is_multimodal_model:
input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens]
@@ -963,8 +966,7 @@ class GPUModelRunner:
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
# Trigger compilation for general shape.
hidden_states = self._dummy_run(self.model, self.max_num_tokens,
dummy_kv_caches)
hidden_states = self._dummy_run(self.max_num_tokens, dummy_kv_caches)
logits = self.model.compute_logits(hidden_states, None)
logits = logits[:self.max_num_tokens]
# TODO(woosuk): Consider the memory usage of the sampler.
@@ -990,8 +992,8 @@ class GPUModelRunner:
for num_tokens in reversed(self.cudagraph_batch_sizes):
for _ in range(self.vllm_config.compilation_config.
cudagraph_num_of_warmups):
self._dummy_run(self.model, num_tokens, self.kv_caches)
self._dummy_run(self.model, num_tokens, self.kv_caches)
self._dummy_run(num_tokens)
self._dummy_run(num_tokens)
end_time = time.perf_counter()
end_free_gpu_memory = torch.cuda.mem_get_info()[0]

View File

@@ -206,6 +206,18 @@ class Worker:
self.model_runner.initialize_kv_cache(kv_cache_config)
def compile_or_warm_up_model(self) -> None:
# warm up sizes that are not in cudagraph capture sizes,
# but users still want to compile for better performance,
# e.g. for the max-num-batched token size in chunked prefill.
warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
if not self.model_config.enforce_eager:
warmup_sizes = [
x for x in warmup_sizes if x not in
self.vllm_config.compilation_config.cudagraph_capture_sizes
]
for size in sorted(warmup_sizes, reverse=True):
logger.info("Compile and warming up model for size %d", size)
self.model_runner._dummy_run(size)
if not self.model_config.enforce_eager:
self.model_runner.capture_model()
# Reset the seed to ensure that the random state is not affected by