[V0 Deprecation] Remove VLLM_USE_V1 from tests (#26341)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
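Every hunk below applies the same mechanical cleanup: with the V0 engine gone, pinning `VLLM_USE_V1=1` through the `monkeypatch` fixture is dead code, so the environment-variable wrapper is deleted and each test body moves up one indentation level. A minimal before/after sketch of the pattern (the test name and body here are illustrative only, not taken from the file):

import pytest

# Before: the test pinned the V1 engine through an environment variable.
def test_example_before(monkeypatch: pytest.MonkeyPatch) -> None:
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        assert 2 + 2 == 4  # test body ran inside the monkeypatch context

# After: V1 is the only engine, so the fixture and wrapper are dropped
# and the unchanged body is dedented by one level.
def test_example_after() -> None:
    assert 2 + 2 == 4  # same body, one indentation level shallower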
@@ -280,7 +280,6 @@ def test_get_logprobs_and_prompt_logprobs(
     batch_logprobs_composition: BatchLogprobsComposition,
     temperature: float,
     example_prompts: list[str],
-    monkeypatch: pytest.MonkeyPatch,
 ) -> None:
     """Test V1 Engine logprobs & prompt logprobs

@@ -308,220 +307,204 @@ def test_get_logprobs_and_prompt_logprobs(
         temperature: "temperature" sampling parameter
         example_prompts: example prompt fixture
     """
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-        do_apc = vllm_model.llm.llm_engine.cache_config.enable_prefix_caching
-        if do_apc and (
-            temperature < 2.0 or batch_logprobs_composition != SAMPLE_PROMPT
-        ):
-            # Skip some test-cases to save time.
-            pytest.skip()
-        test_prompts = example_prompts
+    do_apc = vllm_model.llm.llm_engine.cache_config.enable_prefix_caching
+    if do_apc and (temperature < 2.0 or batch_logprobs_composition != SAMPLE_PROMPT):
+        # Skip some test-cases to save time.
+        pytest.skip()
+    test_prompts = example_prompts

-        max_tokens = 5
-        hf_outputs = hf_model.generate_greedy(
-            test_prompts,
-            max_tokens=max_tokens,
-        )
-        hf_logprobs = hf_model.generate_greedy_logprobs(
-            test_prompts,
-            max_tokens=max_tokens,
-        )
+    max_tokens = 5
+    hf_outputs = hf_model.generate_greedy(
+        test_prompts,
+        max_tokens=max_tokens,
+    )
+    hf_logprobs = hf_model.generate_greedy_logprobs(
+        test_prompts,
+        max_tokens=max_tokens,
+    )

-        # Batch has mixed sample params
-        # (different logprobs/prompt logprobs combos)
-        logprob_prompt_logprob_list = get_test_batch(batch_logprobs_composition)
-
-        # Ensure that each test prompt has a logprob config for testing
-        logprob_prompt_logprob_list = _repeat_logprob_config(
-            test_prompts, logprob_prompt_logprob_list
-        )
-        # Generate SamplingParams
-        vllm_sampling_params = [
-            SamplingParams(
-                max_tokens=max_tokens,
-                logprobs=num_lp,
-                prompt_logprobs=num_plp,
-                temperature=temperature,
-                seed=1984,
-            )
-            for num_lp, num_plp in logprob_prompt_logprob_list
-        ]
-        for _ in range(2 if do_apc else 1):
-            _run_and_validate(
-                vllm_model=vllm_model,
-                test_prompts=test_prompts,
-                vllm_sampling_params=vllm_sampling_params,
-                hf_logprobs=hf_logprobs,
-                hf_outputs=hf_outputs,
-                logprob_prompt_logprob_list=logprob_prompt_logprob_list,
-                temperature=temperature,
-                max_tokens=max_tokens,
-                do_apc=do_apc,
-            )
+    # Batch has mixed sample params
+    # (different logprobs/prompt logprobs combos)
+    logprob_prompt_logprob_list = get_test_batch(batch_logprobs_composition)
+
+    # Ensure that each test prompt has a logprob config for testing
+    logprob_prompt_logprob_list = _repeat_logprob_config(
+        test_prompts, logprob_prompt_logprob_list
+    )
+    # Generate SamplingParams
+    vllm_sampling_params = [
+        SamplingParams(
+            max_tokens=max_tokens,
+            logprobs=num_lp,
+            prompt_logprobs=num_plp,
+            temperature=temperature,
+            seed=1984,
+        )
+        for num_lp, num_plp in logprob_prompt_logprob_list
+    ]
+    for _ in range(2 if do_apc else 1):
+        _run_and_validate(
+            vllm_model=vllm_model,
+            test_prompts=test_prompts,
+            vllm_sampling_params=vllm_sampling_params,
+            hf_logprobs=hf_logprobs,
+            hf_outputs=hf_outputs,
+            logprob_prompt_logprob_list=logprob_prompt_logprob_list,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            do_apc=do_apc,
+        )


-def test_max_logprobs(monkeypatch: pytest.MonkeyPatch):
+def test_max_logprobs():
     """vLLM v1 engine should fail a request with `logprobs > max_logprobs`
     Should also fail for `prompt_logprobs > max_logprobs`
     APC should not matter as this test checks basic request validation.
     """
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-        runner = VllmRunner(
-            "facebook/opt-125m",
-            max_logprobs=1,
-            enable_prefix_caching=False,
-            # 2 other llms alive during whole session
-            gpu_memory_utilization=0.15,
-            max_model_len=256,
-        )
-        vllm_sampling_params = SamplingParams(logprobs=1)
-        # should pass
-        runner.generate(["Hello world"], sampling_params=vllm_sampling_params)
+    runner = VllmRunner(
+        "facebook/opt-125m",
+        max_logprobs=1,
+        enable_prefix_caching=False,
+        # 2 other llms alive during whole session
+        gpu_memory_utilization=0.15,
+        max_model_len=256,
+    )
+    vllm_sampling_params = SamplingParams(logprobs=1)
+    # should pass
+    runner.generate(["Hello world"], sampling_params=vllm_sampling_params)

-        bad_sampling_params = SamplingParams(logprobs=2)
-        with pytest.raises(ValueError):
-            runner.generate(["Hello world"], sampling_params=bad_sampling_params)
+    bad_sampling_params = SamplingParams(logprobs=2)
+    with pytest.raises(ValueError):
+        runner.generate(["Hello world"], sampling_params=bad_sampling_params)


-def test_none_logprobs(vllm_model, example_prompts, monkeypatch: pytest.MonkeyPatch):
+def test_none_logprobs(vllm_model, example_prompts):
     """Engine should return `logprobs` and `prompt_logprobs` as `None`

     Args:
         vllm_model: vLLM model fixture
         example_prompts: list of example prompts (test fixture)
     """
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-        max_tokens = 5
+    max_tokens = 5

-        sampling_params_logprobs_none = SamplingParams(
-            max_tokens=max_tokens,
-            logprobs=None,
-            prompt_logprobs=None,
-            temperature=0.0,
-        )
-        results_logprobs_none = vllm_model.llm.generate(
-            example_prompts,
-            sampling_params=sampling_params_logprobs_none,
-        )
+    sampling_params_logprobs_none = SamplingParams(
+        max_tokens=max_tokens,
+        logprobs=None,
+        prompt_logprobs=None,
+        temperature=0.0,
+    )
+    results_logprobs_none = vllm_model.llm.generate(
+        example_prompts,
+        sampling_params=sampling_params_logprobs_none,
+    )

-        for i in range(len(results_logprobs_none)):
-            # Check sample logprobs are None
-            assert results_logprobs_none[i].outputs[0].logprobs is None
-            assert results_logprobs_none[i].outputs[0].cumulative_logprob is None
-            # Check prompt logprobs are None
-            assert results_logprobs_none[i].prompt_logprobs is None
+    for i in range(len(results_logprobs_none)):
+        # Check sample logprobs are None
+        assert results_logprobs_none[i].outputs[0].logprobs is None
+        assert results_logprobs_none[i].outputs[0].cumulative_logprob is None
+        # Check prompt logprobs are None
+        assert results_logprobs_none[i].prompt_logprobs is None


-def test_zero_logprobs(vllm_model, example_prompts, monkeypatch: pytest.MonkeyPatch):
+def test_zero_logprobs(vllm_model, example_prompts):
     """Engine should return sampled token and prompt token logprobs

     Args:
         vllm_model: vLLM model fixture
         example_prompts: list of example prompts (test fixture)
     """
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-        max_tokens = 5
+    max_tokens = 5

-        sampling_params_logprobs_zero = SamplingParams(
-            max_tokens=max_tokens, logprobs=0, prompt_logprobs=0, temperature=0.0
-        )
-        results_logprobs_zero = vllm_model.llm.generate(
-            example_prompts, sampling_params=sampling_params_logprobs_zero
-        )
+    sampling_params_logprobs_zero = SamplingParams(
+        max_tokens=max_tokens, logprobs=0, prompt_logprobs=0, temperature=0.0
+    )
+    results_logprobs_zero = vllm_model.llm.generate(
+        example_prompts, sampling_params=sampling_params_logprobs_zero
+    )

-        for i in range(len(results_logprobs_zero)):
-            # Check that there is one sample logprob dict for each
-            # sample token
-            logprobs = results_logprobs_zero[i].outputs[0].logprobs
-            prompt_logprobs = results_logprobs_zero[i].prompt_logprobs
-            sampled_token_ids = results_logprobs_zero[i].outputs[0].token_ids
-            prompt_token_ids = results_logprobs_zero[i].prompt_token_ids
-            assert logprobs is not None
-            assert len(sampled_token_ids) == len(logprobs)
-            assert results_logprobs_zero[i].outputs[0].cumulative_logprob is not None
-            # Check that there is one prompt logprob dict for each
-            # prompt token
-            assert prompt_logprobs is not None
-            assert len(prompt_token_ids) == len(prompt_logprobs)
+    for i in range(len(results_logprobs_zero)):
+        # Check that there is one sample logprob dict for each
+        # sample token
+        logprobs = results_logprobs_zero[i].outputs[0].logprobs
+        prompt_logprobs = results_logprobs_zero[i].prompt_logprobs
+        sampled_token_ids = results_logprobs_zero[i].outputs[0].token_ids
+        prompt_token_ids = results_logprobs_zero[i].prompt_token_ids
+        assert logprobs is not None
+        assert len(sampled_token_ids) == len(logprobs)
+        assert results_logprobs_zero[i].outputs[0].cumulative_logprob is not None
+        # Check that there is one prompt logprob dict for each
+        # prompt token
+        assert prompt_logprobs is not None
+        assert len(prompt_token_ids) == len(prompt_logprobs)


-def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch):
+def test_all_logprobs(example_prompts):
     """Engine should return all vocabulary logprobs and prompt logprobs

     Args:
         example_prompts: list of example prompts (test fixture)
     """
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-        runner = VllmRunner(
-            "facebook/opt-125m",
-            max_logprobs=-1,
-            enable_prefix_caching=False,
-            # 2 other llms alive during whole session
-            gpu_memory_utilization=0.15,
-            max_model_len=256,
-        )
+    runner = VllmRunner(
+        "facebook/opt-125m",
+        max_logprobs=-1,
+        enable_prefix_caching=False,
+        # 2 other llms alive during whole session
+        gpu_memory_utilization=0.15,
+        max_model_len=256,
+    )

-        sampling_params_logprobs_all = SamplingParams(
-            max_tokens=5, logprobs=-1, prompt_logprobs=-1
-        )
-        results_logprobs_all = runner.llm.generate(
-            example_prompts, sampling_params=sampling_params_logprobs_all
-        )
-        vocab_size = runner.llm.llm_engine.get_model_config().get_vocab_size()
+    sampling_params_logprobs_all = SamplingParams(
+        max_tokens=5, logprobs=-1, prompt_logprobs=-1
+    )
+    results_logprobs_all = runner.llm.generate(
+        example_prompts, sampling_params=sampling_params_logprobs_all
+    )
+    vocab_size = runner.llm.llm_engine.get_model_config().get_vocab_size()

-        for i in range(len(results_logprobs_all)):
-            logprobs = results_logprobs_all[i].outputs[0].logprobs
-            prompt_logprobs = results_logprobs_all[i].prompt_logprobs
-            assert logprobs is not None
-            for logprob in logprobs:
-                assert len(logprob) == vocab_size
-            assert prompt_logprobs is not None
-            assert prompt_logprobs[0] is None
-            for prompt_logprob in prompt_logprobs[1:]:
-                assert len(prompt_logprob) == vocab_size
+    for i in range(len(results_logprobs_all)):
+        logprobs = results_logprobs_all[i].outputs[0].logprobs
+        prompt_logprobs = results_logprobs_all[i].prompt_logprobs
+        assert logprobs is not None
+        for logprob in logprobs:
+            assert len(logprob) == vocab_size
+        assert prompt_logprobs is not None
+        assert prompt_logprobs[0] is None
+        for prompt_logprob in prompt_logprobs[1:]:
+            assert len(prompt_logprob) == vocab_size


 @pytest.mark.parametrize("logprobs_mode", get_args(LogprobsMode))
-def test_logprobs_mode(logprobs_mode: LogprobsMode, monkeypatch: pytest.MonkeyPatch):
+def test_logprobs_mode(logprobs_mode: LogprobsMode):
     """Test with LLM engine with different logprobs_mode.
     For logprobs, we should have non-positive values.
     For logits, we should expect at least one positive values.
     """
     from vllm import LLM

-    with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-        llm = LLM(
-            "facebook/opt-125m",
-            max_logprobs=5,
-            enable_prefix_caching=False,
-            # 2 other llms alive during whole session
-            gpu_memory_utilization=0.05,
-            max_model_len=16,
-            logprobs_mode=logprobs_mode,
-        )
-        vllm_sampling_params = SamplingParams(logprobs=1)
-        results = llm.generate(["Hello world"], sampling_params=vllm_sampling_params)
+    llm = LLM(
+        "facebook/opt-125m",
+        max_logprobs=5,
+        enable_prefix_caching=False,
+        # 2 other llms alive during whole session
+        gpu_memory_utilization=0.05,
+        max_model_len=16,
+        logprobs_mode=logprobs_mode,
+    )
+    vllm_sampling_params = SamplingParams(logprobs=1)
+    results = llm.generate(["Hello world"], sampling_params=vllm_sampling_params)

-        total_token_with_logprobs = 0
-        positive_values = 0
-        for output in results[0].outputs:
-            for logprobs in output.logprobs:
-                for token_id in logprobs:
-                    logprob = logprobs[token_id]
-                    if logprobs_mode in ("raw_logprobs", "processed_logprobs"):
-                        assert logprob.logprob <= 0
-                    if logprob.logprob > 0:
-                        positive_values = positive_values + 1
-                    total_token_with_logprobs = total_token_with_logprobs + 1
-        assert total_token_with_logprobs >= len(results[0].outputs)
-        if logprobs_mode in ("raw_logits", "processed_logits"):
-            assert positive_values > 0
-        del llm
+    total_token_with_logprobs = 0
+    positive_values = 0
+    for output in results[0].outputs:
+        for logprobs in output.logprobs:
+            for token_id in logprobs:
+                logprob = logprobs[token_id]
+                if logprobs_mode in ("raw_logprobs", "processed_logprobs"):
+                    assert logprob.logprob <= 0
+                if logprob.logprob > 0:
+                    positive_values = positive_values + 1
+                total_token_with_logprobs = total_token_with_logprobs + 1
+    assert total_token_with_logprobs >= len(results[0].outputs)
+    if logprobs_mode in ("raw_logits", "processed_logits"):
+        assert positive_values > 0
+    del llm
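The checks in test_logprobs_mode rest on a basic property: logprobs are log-softmax values, so they are never positive, while raw logits can be. A small self-contained illustration in plain Python (not vLLM code; the values are arbitrary):

import math

logits = [2.0, -1.0, 0.5]  # raw scores: may contain positive values
log_total = math.log(sum(math.exp(x) for x in logits))
logprobs = [x - log_total for x in logits]  # log-softmax of the logits
assert any(x > 0 for x in logits)  # the "*_logits" modes can observe positives
assert all(lp <= 0 for lp in logprobs)  # the "*_logprobs" modes never do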