[Misc] Log spec decode metrics (#6454)
This commit is contained in:
@@ -162,6 +162,11 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
|
||||
}
|
||||
test_name = request.node.name
|
||||
|
||||
model = kwargs["model"]
|
||||
draft_model = kwargs.get("speculative_model", None)
|
||||
same_draft_target_model = (draft_model is not None
|
||||
and draft_model == model)
|
||||
|
||||
def generator_inner():
|
||||
|
||||
wait_for_gpu_memory_to_clear(
|
||||
@@ -177,6 +182,13 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
|
||||
|
||||
print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
|
||||
llm = AsyncLLM(**kwargs) if use_async else LLM(**kwargs)
|
||||
|
||||
# Override logging interval to 0 for spec decode test run to
|
||||
# log all metrics in time.
|
||||
if (baseline_or_test == "test" and not use_async
|
||||
and llm.llm_engine.log_stats):
|
||||
for sate_logger in llm.llm_engine.stat_loggers.values():
|
||||
sate_logger.local_interval = 0
|
||||
set_random_seed(seed)
|
||||
|
||||
yield llm
|
||||
@@ -188,6 +200,9 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
|
||||
yield llm
|
||||
del llm
|
||||
|
||||
# Set an attribute to the generator_outer function to allow us to
|
||||
# determine whether to further check the acceptance rate in tests.
|
||||
generator_outer.same_draft_target_model = same_draft_target_model # type: ignore
|
||||
return generator_outer
|
||||
|
||||
|
||||
@@ -204,18 +219,26 @@ def maybe_assert_ngram_worker(llm):
|
||||
|
||||
def get_output_from_llm_generator(
        llm_generator, prompts,
        sampling_params) -> Tuple[List[str], List[List[int]], float]:
    """Generate completions with every LLM produced by ``llm_generator``.

    Args:
        llm_generator: Zero-argument callable whose result is iterated to
            obtain LLM objects.
        prompts: Prompts forwarded unchanged to ``llm.generate``.
        sampling_params: Sampling parameters forwarded unchanged to
            ``llm.generate``.

    Returns:
        A ``(tokens, token_ids, acceptance_rate)`` tuple. ``tokens`` and
        ``token_ids`` hold the text and token ids of the first completion
        (``outputs[0]``) of each request, taken from the last LLM that was
        yielded; both are empty lists if nothing was yielded.
        ``acceptance_rate`` is the value read from the draft acceptance rate
        gauge, and stays at ``-1.0`` when no yielded engine exposes a
        non-empty ``stat_loggers`` attribute.

    Raises:
        KeyError: presumably, when ``stat_loggers`` is a non-empty mapping
            without a ``"prometheus"`` entry (plain subscript lookup) --
            confirm against the engine's logger registry.
    """
    tokens: List[str] = []
    token_ids: List[List[int]] = []
    # Sentinel meaning "no acceptance rate was read".
    acceptance_rate: float = -1.0
    for llm in llm_generator():
        maybe_assert_ngram_worker(llm)

        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)

        # Only the first completion of each request is kept.
        token_ids = [output.outputs[0].token_ids for output in outputs]
        tokens = [output.outputs[0].text for output in outputs]

        # Fetch acceptance rate if logging is enabled.
        if stat_loggers := getattr(llm.llm_engine, "stat_loggers", None):
            stat_logger = stat_loggers["prometheus"]
            # NOTE(review): `_value` is a private attribute of the labelled
            # gauge child; this may break if the metrics library changes its
            # internals -- confirm there is no public accessor.
            acceptance_rate = (stat_logger.metrics.
                               gauge_spec_decode_draft_acceptance_rate.labels(
                                   **stat_logger.labels)._value.get())
        # Drop the loop's reference to the LLM before the generator resumes.
        del llm

    return tokens, token_ids, acceptance_rate
|
||||
|
||||
|
||||
def get_logprobs_from_llm_generator(
|
||||
@@ -237,7 +260,8 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||
batch_size,
|
||||
max_output_len,
|
||||
force_output_len: bool,
|
||||
print_tokens: bool = False):
|
||||
print_tokens: bool = False,
|
||||
ensure_all_accepted: bool = False):
|
||||
"""Helper method that compares the outputs of both the baseline LLM and
|
||||
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
|
||||
the same when temperature is zero.
|
||||
@@ -267,12 +291,13 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator(
|
||||
test_llm_generator, prompts, sampling_params)
|
||||
(spec_batch_tokens, spec_batch_token_ids,
|
||||
acceptance_rate) = get_output_from_llm_generator(test_llm_generator,
|
||||
prompts, sampling_params)
|
||||
|
||||
(baseline_batch_tokens,
|
||||
baseline_batch_token_ids) = get_output_from_llm_generator(
|
||||
baseline_llm_generator, prompts, sampling_params)
|
||||
(baseline_batch_tokens, baseline_batch_token_ids,
|
||||
_) = get_output_from_llm_generator(baseline_llm_generator, prompts,
|
||||
sampling_params)
|
||||
|
||||
assert len(baseline_batch_token_ids) == len(prompts)
|
||||
assert len(spec_batch_token_ids) == len(prompts)
|
||||
@@ -287,3 +312,6 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||
print(f'{i=} {baseline_token_ids=}')
|
||||
print(f'{i=} {spec_token_ids=}')
|
||||
assert baseline_token_ids == spec_token_ids
|
||||
|
||||
if ensure_all_accepted:
|
||||
assert acceptance_rate == 1.0
|
||||
|
||||
Reference in New Issue
Block a user