[ BugFix ] Prompt Logprobs Detokenization (#6223)
Co-authored-by: Zifei Tong <zifeitong@gmail.com>
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import Dict, List
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
@@ -139,6 +139,15 @@ def create_dummy_logprobs(
|
||||
} for token_id in complete_sequence_token_ids]
|
||||
|
||||
|
||||
def create_dummy_prompt_logprobs(
        complete_sequence_token_ids: List[int]
) -> List[Optional[Dict[int, Any]]]:
    """Build dummy prompt logprobs for a full token sequence.

    The entry for the first prompt token is always ``None``; every later
    position reuses the corresponding entry produced by
    ``create_dummy_logprobs`` for the same token ids.

    Args:
        complete_sequence_token_ids: Token ids of the whole sequence.

    Returns:
        A list whose first element is ``None``, followed by the dummy
        logprob entries for positions 1 onwards.
    """
    per_token_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
    # Drop the first generated entry: position 0 is represented by None.
    return [None, *per_token_logprobs[1:]]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("complete_sequence", TRUTH)
|
||||
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
|
||||
@pytest.mark.parametrize("skip_special_tokens", [True, False])
|
||||
@@ -177,13 +186,10 @@ def test_decode_sequence_logprobs(complete_sequence: str,
|
||||
|
||||
@pytest.mark.parametrize("complete_sequence", TRUTH)
|
||||
@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
|
||||
@pytest.mark.parametrize("skip_special_tokens", [True])
|
||||
def test_decode_prompt_logprobs(complete_sequence: str,
|
||||
complete_sequence_token_ids: List[int],
|
||||
detokenizer: Detokenizer,
|
||||
skip_special_tokens: bool):
|
||||
def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int],
|
||||
detokenizer: Detokenizer):
|
||||
"""Verify Detokenizer decodes prompt logprobs correctly."""
|
||||
sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
|
||||
sampling_params = SamplingParams(skip_special_tokens=True,
|
||||
prompt_logprobs=1)
|
||||
|
||||
# Run sequentially.
|
||||
@@ -192,19 +198,78 @@ def test_decode_prompt_logprobs(complete_sequence: str,
|
||||
seqs=[seq],
|
||||
sampling_params=sampling_params,
|
||||
arrival_time=0.0)
|
||||
dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
|
||||
detokenizer.decode_prompt_logprobs_inplace(seq_group, dummy_logprobs)
|
||||
decoded_prompt_logprobs = dummy_logprobs
|
||||
dummy_logprobs = create_dummy_prompt_logprobs(complete_sequence_token_ids)
|
||||
detokenizer.decode_prompt_logprobs_inplace(seq_group,
|
||||
dummy_logprobs,
|
||||
position_offset=0)
|
||||
# First logprob is None.
|
||||
decoded_prompt_logprobs: List[Dict[int, Any]] = dummy_logprobs[
|
||||
1:] # type: ignore
|
||||
|
||||
if skip_special_tokens:
|
||||
# Text for logprobs for the chosen token should be the same as the
|
||||
# prompt text. Note that this will only be true if we skip
|
||||
# special tokens.
|
||||
assert complete_sequence == "".join([
|
||||
logprobs[token_id].decoded_token for token_id, logprobs in zip(
|
||||
complete_sequence_token_ids, decoded_prompt_logprobs)
|
||||
])
|
||||
assert complete_sequence != "".join([
|
||||
logprobs[token_id + 1].decoded_token for token_id, logprobs in zip(
|
||||
complete_sequence_token_ids, decoded_prompt_logprobs)
|
||||
])
|
||||
# decoded_prompt_logprobs doesn't contain the first token.
|
||||
token_ids = complete_sequence_token_ids
|
||||
tokenzier = detokenizer.get_tokenizer_for_seq(seq)
|
||||
text_full = tokenzier.decode(token_ids, skip_special_tokens=True)
|
||||
text_first = tokenzier.decode(token_ids[0], skip_special_tokens=True)
|
||||
text = text_full[len(text_first):]
|
||||
|
||||
# Text for logprobs for the chosen token should be the same as the
|
||||
# prompt text. Note that the first logprob is None.
|
||||
assert text == "".join([
|
||||
logprobs[token_id].decoded_token
|
||||
for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
|
||||
])
|
||||
assert text != "".join([
|
||||
logprobs[token_id + 1].decoded_token
|
||||
for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs)
|
||||
])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 7, 16, -1])
def test_decode_prompt_logprobs_chunked_prefill(
    vllm_runner,
    model,
    chunked_prefill_token_size: int,
    example_prompts,
):
    """Check that detokenized prompt logprobs rebuild the original prompt.

    A ``chunked_prefill_token_size`` of ``-1`` runs with chunked prefill
    disabled; any other value enables it and is used both as the batched
    token budget and as an upper bound on the number of sequences.
    """
    default_max_num_seqs = 256
    if chunked_prefill_token_size == -1:
        chunked_prefill_enabled = False
        seq_limit = default_max_num_seqs
        batched_token_limit = None
    else:
        chunked_prefill_enabled = True
        seq_limit = min(chunked_prefill_token_size, default_max_num_seqs)
        batched_token_limit = chunked_prefill_token_size

    with vllm_runner(model,
                     dtype="half",
                     max_logprobs=5,
                     gpu_memory_utilization=0.5,
                     enable_chunked_prefill=chunked_prefill_enabled,
                     max_num_batched_tokens=batched_token_limit,
                     max_num_seqs=seq_limit) as vllm_model:

        sampling_params = SamplingParams(max_tokens=10,
                                         logprobs=5,
                                         prompt_logprobs=5,
                                         temperature=0.0)
        outputs = vllm_model.model.generate(example_prompts,
                                            sampling_params=sampling_params)

    for prompt_index, output in enumerate(outputs):
        assert output.prompt_logprobs is not None
        # The first prompt token has no logprob entry.
        assert output.prompt_logprobs[0] is None

        # Compare the detokenized prompt token ids to the original prompt.
        # Each logprobs entry is a dict keyed by token id; pick the entry
        # for the actual prompt token and take its decoded text.
        rebuilt_prompt = "".join(
            position_logprobs[token_id].decoded_token
            for token_id, position_logprobs in zip(
                output.prompt_token_ids[1:], output.prompt_logprobs[1:]))

        assert rebuilt_prompt == example_prompts[prompt_index], (
            "Detokenized prompt logprobs do not match original prompt")
|
||||
|
||||
Reference in New Issue
Block a user