[Bugfix] fix when skip tokenizer init (#21922)
Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
@@ -213,3 +213,29 @@ def test_engine_metrics(vllm_runner, monkeypatch, example_prompts):
     assert len(num_accepted_tokens_per_pos) == 1
     assert isinstance(num_accepted_tokens_per_pos[0], Vector)
     assert len(num_accepted_tokens_per_pos[0].values) == 5
+
+
+@pytest.mark.parametrize("model", ["meta-llama/Llama-3.2-1B-Instruct"])
+def test_skip_tokenizer_initialization(model: str,
+                                       monkeypatch: pytest.MonkeyPatch):
+    monkeypatch.setenv("VLLM_USE_V1", "1")
+    # This test checks if the flag skip_tokenizer_init skips the initialization
+    # of tokenizer and detokenizer. The generated output is expected to contain
+    # token ids.
+    llm = LLM(
+        model=model,
+        skip_tokenizer_init=True,
+        enforce_eager=True,
+    )
+    sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)
+
+    with pytest.raises(ValueError, match="cannot pass text prompts when"):
+        llm.generate("abc", sampling_params)
+
+    outputs = llm.generate({"prompt_token_ids": [1, 2, 3]},
+                           sampling_params=sampling_params)
+    assert len(outputs) > 0
+    completions = outputs[0].outputs
+    assert len(completions) > 0
+    assert completions[0].text == ""
+    assert completions[0].token_ids
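For context, a minimal end-to-end sketch of the behavior this test locks in (illustrative, not part of the commit): with skip_tokenizer_init=True the engine holds no tokenizer or detokenizer, so the caller must pre-tokenize prompts and decode the returned token ids itself. The AutoTokenizer pairing and the max_tokens value below are assumptions, not taken from the diff.

# Minimal usage sketch (illustrative, not from this commit). Assumes the
# matching Hugging Face tokenizer is loaded externally, since the engine
# skips its own tokenizer/detokenizer when skip_tokenizer_init=True.
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams

model = "meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model)

llm = LLM(model=model, skip_tokenizer_init=True, enforce_eager=True)

# Text prompts raise ValueError in this mode; pass token ids instead,
# using the same dict form the test exercises.
prompt_ids = tokenizer.encode("Hello, world!")
outputs = llm.generate({"prompt_token_ids": prompt_ids},
                       sampling_params=SamplingParams(max_tokens=16))

# completion.text is empty because nothing is detokenized in-engine;
# decode the returned token ids with the external tokenizer.
completion = outputs[0].outputs[0]
print(tokenizer.decode(completion.token_ids))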