[Hardware][TPU] Add check for no additional graph compilation during runtime (#14710)

Signed-off-by: Siyuan Liu <lsiyuan@google.com>
This commit is contained in:
Siyuan Liu
2025-03-20 20:05:28 -07:00
committed by GitHub
parent e588ac237c
commit b15fd2be2a
3 changed files with 32 additions and 6 deletions

View File

@@ -11,6 +11,7 @@ import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.config import VllmConfig
@@ -73,6 +74,10 @@ class TPUModelRunner:
scheduler_config = self.scheduler_config
parallel_config = self.parallel_config
self.device = device
self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION
if self.check_recompilation:
self.num_xla_graphs = xr.get_num_cached_compilation_graph()
self.enforce_eager = model_config.enforce_eager
self.pin_memory = is_pin_memory_available()
self.dtype = self.model_config.dtype
@@ -671,6 +676,12 @@ class TPUModelRunner:
logprobs=None,
prompt_logprobs_dict=prompt_logprobs_dict,
)
# Check there is no new graph compilation, all the graphs should be
# captured and compiled during warming up.
if self.check_recompilation and not self.enforce_eager:
curr_cached_graph = xr.get_num_cached_compilation_graph()
assert self.num_xla_graphs == curr_cached_graph, (
"Recompilation after warm up is detected.")
return model_runner_output
def load_model(self) -> None:
@@ -810,6 +821,14 @@ class TPUModelRunner:
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in in %.2f [secs].", end - start)
# Record the number cached XLA graph after warming up, this will be
# used for checking there is no additional graph compilation during
# runtime execution.
if self.check_recompilation:
total_cached_graphs = xr.get_num_cached_compilation_graph()
num_compiled_graphs = total_cached_graphs - self.num_xla_graphs
logger.info("Compiled %d XLA graphs.", num_compiled_graphs)
self.num_xla_graphs += num_compiled_graphs
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""