[Hardware][TPU] Add check for no additional graph compilation during runtime (#14710)
Signed-off-by: Siyuan Liu <lsiyuan@google.com>
This commit is contained in:
@@ -11,6 +11,7 @@ import torch.nn as nn
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.runtime as xr
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import VllmConfig
|
||||
@@ -73,6 +74,10 @@ class TPUModelRunner:
|
||||
scheduler_config = self.scheduler_config
|
||||
parallel_config = self.parallel_config
|
||||
self.device = device
|
||||
self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION
|
||||
if self.check_recompilation:
|
||||
self.num_xla_graphs = xr.get_num_cached_compilation_graph()
|
||||
self.enforce_eager = model_config.enforce_eager
|
||||
self.pin_memory = is_pin_memory_available()
|
||||
self.dtype = self.model_config.dtype
|
||||
|
||||
@@ -671,6 +676,12 @@ class TPUModelRunner:
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict=prompt_logprobs_dict,
|
||||
)
|
||||
# Check there is no new graph compilation, all the graphs should be
|
||||
# captured and compiled during warming up.
|
||||
if self.check_recompilation and not self.enforce_eager:
|
||||
curr_cached_graph = xr.get_num_cached_compilation_graph()
|
||||
assert self.num_xla_graphs == curr_cached_graph, (
|
||||
"Recompilation after warm up is detected.")
|
||||
return model_runner_output
|
||||
|
||||
def load_model(self) -> None:
|
||||
@@ -810,6 +821,14 @@ class TPUModelRunner:
|
||||
xm.wait_device_ops()
|
||||
end = time.perf_counter()
|
||||
logger.info("Compilation finished in in %.2f [secs].", end - start)
|
||||
# Record the number cached XLA graph after warming up, this will be
|
||||
# used for checking there is no additional graph compilation during
|
||||
# runtime execution.
|
||||
if self.check_recompilation:
|
||||
total_cached_graphs = xr.get_num_cached_compilation_graph()
|
||||
num_compiled_graphs = total_cached_graphs - self.num_xla_graphs
|
||||
logger.info("Compiled %d XLA graphs.", num_compiled_graphs)
|
||||
self.num_xla_graphs += num_compiled_graphs
|
||||
|
||||
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user