[V1] TPU CI - Add basic perf regression test (#15414)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
This commit is contained in:
Alexander Matveev
2025-03-31 13:25:20 -04:00
committed by GitHub
parent 2de4118243
commit 9a2160fa55
5 changed files with 192 additions and 20 deletions

View File

@@ -77,9 +77,12 @@ class TPUModelRunner:
parallel_config = self.parallel_config
self.device = device
self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION
if self.check_recompilation:
self.num_xla_graphs = xr.get_num_cached_compilation_graph()
self.enforce_eager = model_config.enforce_eager
self.num_xla_graphs = 0
self._update_num_xla_graphs("init")
self.pin_memory = is_pin_memory_available()
self.dtype = self.model_config.dtype
self._hidden_states_dtype = self.dtype
@@ -180,6 +183,31 @@ class TPUModelRunner:
max_token_size=self.max_num_tokens,
padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
def _update_num_xla_graphs(self, case_str):
check_comp = self.check_recompilation and not self.enforce_eager
if not check_comp:
return
total_cached_graphs = xr.get_num_cached_compilation_graph()
new_compiled_graphs = total_cached_graphs - self.num_xla_graphs
if new_compiled_graphs == 0:
return
logger.info("Add new %d compiled XLA graphs due to %s",
new_compiled_graphs, case_str)
self.num_xla_graphs += new_compiled_graphs
def _verify_num_xla_graphs(self, case_str):
check_comp = self.check_recompilation and not self.enforce_eager
if not check_comp:
return
curr_cached_graph = xr.get_num_cached_compilation_graph()
assert self.num_xla_graphs == curr_cached_graph, (
"Recompilation after warm up is detected during {}."
" num_xla_graphs = {} curr_cached_graph = {}".format(
case_str, self.num_xla_graphs, curr_cached_graph))
def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
"""Update the cached states and the persistent batch with the scheduler
output.
@@ -694,12 +722,11 @@ class TPUModelRunner:
logprobs=None,
prompt_logprobs_dict=prompt_logprobs_dict,
)
# Check there is no new graph compilation, all the graphs should be
# captured and compiled during warming up.
if self.check_recompilation and not self.enforce_eager:
curr_cached_graph = xr.get_num_cached_compilation_graph()
assert self.num_xla_graphs == curr_cached_graph, (
"Recompilation after warm up is detected.")
# Check there are no new graphs compiled - all the graphs should be
# captured and compiled during warm up.
self._verify_num_xla_graphs("execute_model")
return model_runner_output
def load_model(self) -> None:
@@ -797,7 +824,9 @@ class TPUModelRunner:
xm.mark_step()
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in in %.2f [secs].", end - start)
self._update_num_xla_graphs("model")
logger.info("Compiling sampling with different input shapes.")
start = time.perf_counter()
@@ -832,15 +861,9 @@ class TPUModelRunner:
num_reqs_to_sample + 1, self.max_num_reqs)
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in %.2f [secs].", end - start)
# Record the number cached XLA graph after warming up, this will be
# used for checking there is no additional graph compilation during
# runtime execution.
if self.check_recompilation:
total_cached_graphs = xr.get_num_cached_compilation_graph()
num_compiled_graphs = total_cached_graphs - self.num_xla_graphs
logger.info("Compiled %d XLA graphs.", num_compiled_graphs)
self.num_xla_graphs += num_compiled_graphs
logger.info("Compilation finished in in %.2f [secs].", end - start)
self._update_num_xla_graphs("sampling")
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""