[V1] TPU CI - Add basic perf regression test (#15414)
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
This commit is contained in:
committed by
GitHub
parent
2de4118243
commit
9a2160fa55
@@ -77,9 +77,12 @@ class TPUModelRunner:
|
||||
parallel_config = self.parallel_config
|
||||
self.device = device
|
||||
self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION
|
||||
if self.check_recompilation:
|
||||
self.num_xla_graphs = xr.get_num_cached_compilation_graph()
|
||||
|
||||
self.enforce_eager = model_config.enforce_eager
|
||||
|
||||
self.num_xla_graphs = 0
|
||||
self._update_num_xla_graphs("init")
|
||||
|
||||
self.pin_memory = is_pin_memory_available()
|
||||
self.dtype = self.model_config.dtype
|
||||
self._hidden_states_dtype = self.dtype
|
||||
@@ -180,6 +183,31 @@ class TPUModelRunner:
|
||||
max_token_size=self.max_num_tokens,
|
||||
padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
|
||||
|
||||
def _update_num_xla_graphs(self, case_str):
    """Fold newly compiled XLA graphs into the running graph count.

    Does nothing unless ``self.check_recompilation`` is set and
    ``self.enforce_eager`` is not. Otherwise the total reported by
    ``xr.get_num_cached_compilation_graph()`` is compared with
    ``self.num_xla_graphs``; a non-zero difference is logged together
    with ``case_str`` and added to ``self.num_xla_graphs``.

    Args:
        case_str: Label for the phase that produced the new graphs;
            only used in the log message.
    """
    # Same predicate as "check_recompilation and not enforce_eager",
    # negated; attribute read order is kept.
    if not self.check_recompilation or self.enforce_eager:
        return

    delta = xr.get_num_cached_compilation_graph() - self.num_xla_graphs
    if delta != 0:
        logger.info("Add new %d compiled XLA graphs due to %s",
                    delta, case_str)
        self.num_xla_graphs += delta
|
||||
|
||||
def _verify_num_xla_graphs(self, case_str: str) -> None:
    """Fail if the XLA graph cache no longer matches the recorded count.

    Does nothing unless ``self.check_recompilation`` is set and
    ``self.enforce_eager`` is not. Otherwise the value returned by
    ``xr.get_num_cached_compilation_graph()`` must equal
    ``self.num_xla_graphs``.

    Args:
        case_str: Label for the call site being checked; included in the
            error message.

    Raises:
        AssertionError: If the current cached-graph count differs from
            ``self.num_xla_graphs``.
    """
    if not self.check_recompilation or self.enforce_eager:
        return

    curr_cached_graph = xr.get_num_cached_compilation_graph()
    # Raise explicitly instead of using an ``assert`` statement: asserts
    # are removed when Python runs with -O, which would silently disable
    # this opt-in check. The exception type is unchanged for callers.
    if self.num_xla_graphs != curr_cached_graph:
        raise AssertionError(
            "Recompilation after warm up is detected during {}."
            " num_xla_graphs = {} curr_cached_graph = {}".format(
                case_str, self.num_xla_graphs, curr_cached_graph))
|
||||
|
||||
def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
|
||||
"""Update the cached states and the persistent batch with the scheduler
|
||||
output.
|
||||
@@ -694,12 +722,11 @@ class TPUModelRunner:
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict=prompt_logprobs_dict,
|
||||
)
|
||||
# Check there is no new graph compilation, all the graphs should be
|
||||
# captured and compiled during warming up.
|
||||
if self.check_recompilation and not self.enforce_eager:
|
||||
curr_cached_graph = xr.get_num_cached_compilation_graph()
|
||||
assert self.num_xla_graphs == curr_cached_graph, (
|
||||
"Recompilation after warm up is detected.")
|
||||
|
||||
# Check there are no new graphs compiled - all the graphs should be
|
||||
# captured and compiled during warm up.
|
||||
self._verify_num_xla_graphs("execute_model")
|
||||
|
||||
return model_runner_output
|
||||
|
||||
def load_model(self) -> None:
|
||||
@@ -797,7 +824,9 @@ class TPUModelRunner:
|
||||
xm.mark_step()
|
||||
xm.wait_device_ops()
|
||||
end = time.perf_counter()
|
||||
|
||||
logger.info("Compilation finished in in %.2f [secs].", end - start)
|
||||
self._update_num_xla_graphs("model")
|
||||
|
||||
logger.info("Compiling sampling with different input shapes.")
|
||||
start = time.perf_counter()
|
||||
@@ -832,15 +861,9 @@ class TPUModelRunner:
|
||||
num_reqs_to_sample + 1, self.max_num_reqs)
|
||||
xm.wait_device_ops()
|
||||
end = time.perf_counter()
|
||||
logger.info("Compilation finished in %.2f [secs].", end - start)
|
||||
# Record the number cached XLA graph after warming up, this will be
|
||||
# used for checking there is no additional graph compilation during
|
||||
# runtime execution.
|
||||
if self.check_recompilation:
|
||||
total_cached_graphs = xr.get_num_cached_compilation_graph()
|
||||
num_compiled_graphs = total_cached_graphs - self.num_xla_graphs
|
||||
logger.info("Compiled %d XLA graphs.", num_compiled_graphs)
|
||||
self.num_xla_graphs += num_compiled_graphs
|
||||
|
||||
logger.info("Compilation finished in in %.2f [secs].", end - start)
|
||||
self._update_num_xla_graphs("sampling")
|
||||
|
||||
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user