[Misc] unify variable for LLM instance (#20996)
Signed-off-by: Andy Xie <andy.xning@gmail.com>
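Every hunk in this commit follows one pattern: tests that previously reached the engine through vllm_model.model now use vllm_model.llm, matching the class name of the wrapped object. A minimal sketch of the wrapper shape the tests assume (the class name and constructor below are illustrative stand-ins, not taken from this commit):

from vllm import LLM

class VllmRunner:
    # Hypothetical stand-in for the shared test-runner fixture; only the
    # attribute name matters here.
    def __init__(self, model_name: str, **kwargs):
        # The wrapped engine is an instance of vllm.LLM, hence the new
        # name ".llm" instead of the old, ambiguous ".model".
        self.llm = LLM(model=model_name, **kwargs)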
@@ -112,7 +112,7 @@ def _run_and_validate(
     max_tokens: int,
     do_apc: bool,
 ) -> None:
-    vllm_results = vllm_model.model.generate(
+    vllm_results = vllm_model.llm.generate(
        test_prompts, sampling_params=vllm_sampling_params)
 
     for vllm_result, hf_logprob, hf_output, logprob_prompt_logprob in zip(
@@ -288,7 +288,7 @@ def test_get_logprobs_and_prompt_logprobs(
     """
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
-        do_apc = vllm_model.model.llm_engine.cache_config.enable_prefix_caching
+        do_apc = vllm_model.llm.llm_engine.cache_config.enable_prefix_caching
         if do_apc and (temperature < 2.0
                        or batch_logprobs_composition != SAMPLE_PROMPT):
             # Skip some test-cases to save time.
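For reference, the renamed path to the engine config exercised in the hunk above also works on a directly constructed LLM. A self-contained sketch (the model id is an arbitrary small example, not from this commit):

from vllm import LLM

llm = LLM("facebook/opt-125m", enforce_eager=True, enable_prefix_caching=False)
# cache_config records whether automatic prefix caching (APC) is enabled,
# which is what do_apc reads in the test above.
assert llm.llm_engine.cache_config.enable_prefix_caching is False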
@@ -378,7 +378,7 @@ def test_none_logprobs(vllm_model, example_prompts,
         prompt_logprobs=None,
         temperature=0.0,
     )
-    results_logprobs_none = vllm_model.model.generate(
+    results_logprobs_none = vllm_model.llm.generate(
         example_prompts,
         sampling_params=sampling_params_logprobs_none,
     )
@@ -408,7 +408,7 @@ def test_zero_logprobs(vllm_model, example_prompts,
         logprobs=0,
         prompt_logprobs=0,
         temperature=0.0)
-    results_logprobs_zero = vllm_model.model.generate(
+    results_logprobs_zero = vllm_model.llm.generate(
         example_prompts, sampling_params=sampling_params_logprobs_zero)
 
     for i in range(len(results_logprobs_zero)):
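The four hunks above all touch the logprobs tests, and the knobs they exercise can be reproduced standalone. A sketch under the same assumptions (arbitrary small model id; logprobs=0 and prompt_logprobs=0 are the edge cases test_zero_logprobs covers):

from vllm import LLM, SamplingParams

llm = LLM("facebook/opt-125m", enforce_eager=True)
params = SamplingParams(max_tokens=8, logprobs=0, prompt_logprobs=0,
                        temperature=0.0)
out = llm.generate(["Hello my name is Robert and I"], params)
# With logprobs=0, each generated token carries only its own logprob.
print(out[0].outputs[0].logprobs)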
@@ -14,30 +14,30 @@ PROMPT = "Hello my name is Robert and I"
 
 
 @pytest.fixture(scope="module")
-def model() -> LLM:
+def llm() -> LLM:
     # Disable prefix caching so that we can test prompt logprobs.
     # TODO remove this after https://github.com/vllm-project/vllm/pull/13949
     # is merged
     return LLM(MODEL, enforce_eager=True, enable_prefix_caching=False)
 
 
-def test_n_gt_1(model):
+def test_n_gt_1(llm):
     """ParallelSampling is supported."""
 
     params = SamplingParams(n=3)
-    outputs = model.generate(PROMPT, params)
+    outputs = llm.generate(PROMPT, params)
     assert len(outputs[0].outputs) == 3
 
 
-def test_best_of(model):
+def test_best_of(llm):
     """Raise a ValueError since best_of is deprecated."""
 
     params = SamplingParams(n=2, best_of=3)
     with pytest.raises(ValueError):
-        _ = model.generate(PROMPT, params)
+        _ = llm.generate(PROMPT, params)
 
 
-def test_penalties(model):
+def test_penalties(llm):
     """Check that we do not get errors if applied."""
 
     params = SamplingParams(
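A note on the fixture mechanics behind this hunk: pytest resolves fixtures by parameter name, so renaming the fixture from model to llm forces every test signature in the file to change with it, which is what the remaining hunks do. A self-contained sketch of that mechanism (illustrative test, not from this commit):

import pytest

@pytest.fixture(scope="module")  # built once, shared by every test in the module
def llm():
    return object()  # stands in for the real vllm.LLM instance

def test_uses_fixture(llm):  # injected by matching the parameter name
    assert llm is not None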
@@ -49,18 +49,18 @@ def test_penalties(model):
         top_p=0.5,
         top_k=3,
     )
-    _ = model.generate(PROMPT, params)
+    _ = llm.generate(PROMPT, params)
 
 
-def test_stop(model):
+def test_stop(llm):
     """Check that we respect the stop words."""
 
-    output = model.generate(PROMPT, SamplingParams(temperature=0))
+    output = llm.generate(PROMPT, SamplingParams(temperature=0))
     split_text = output[0].outputs[0].text.split()
 
     STOP_IDX = 5
     params = SamplingParams(temperature=0, stop=split_text[STOP_IDX])
-    output = model.generate(PROMPT, params)
+    output = llm.generate(PROMPT, params)
     new_split_text = output[0].outputs[0].text.split()
 
     # Output should not contain the stop word.
@@ -69,40 +69,40 @@ def test_stop(model):
     params = SamplingParams(temperature=0,
                             stop=split_text[STOP_IDX],
                             include_stop_str_in_output=True)
-    output = model.generate(PROMPT, params)
+    output = llm.generate(PROMPT, params)
     new_split_text = output[0].outputs[0].text.split()
 
     # Output should contain the stop word.
     assert len(new_split_text) == STOP_IDX + 1
 
 
-def test_stop_token_ids(model):
+def test_stop_token_ids(llm):
     """Check that we respect the stop token ids."""
 
-    output = model.generate(PROMPT, SamplingParams(temperature=0))
+    output = llm.generate(PROMPT, SamplingParams(temperature=0))
 
     stop_token_id_0 = output[0].outputs[0].token_ids[5]
     stop_token_id_1 = output[0].outputs[0].token_ids[6]
 
     stop_token_ids = [stop_token_id_1, stop_token_id_0]
     params = SamplingParams(temperature=0, stop_token_ids=stop_token_ids)
-    output = model.generate(PROMPT, params)
+    output = llm.generate(PROMPT, params)
     assert output[0].outputs[0].token_ids[-1] == stop_token_id_0
 
     stop_token_ids = [stop_token_id_0, stop_token_id_1]
     params = SamplingParams(temperature=0, stop_token_ids=stop_token_ids)
-    output = model.generate(PROMPT, params)
+    output = llm.generate(PROMPT, params)
     assert output[0].outputs[0].token_ids[-1] == stop_token_id_0
 
 
-def test_detokenize_false(model):
+def test_detokenize_false(llm):
     """Check that detokenize=False option works."""
 
-    output = model.generate(PROMPT, SamplingParams(detokenize=False))
+    output = llm.generate(PROMPT, SamplingParams(detokenize=False))
     assert len(output[0].outputs[0].token_ids) > 0
     assert len(output[0].outputs[0].text) == 0
 
-    output = model.generate(
+    output = llm.generate(
         PROMPT, SamplingParams(detokenize=False, logprobs=3,
                                prompt_logprobs=3))
     assert len(output[0].outputs[0].token_ids) > 0
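Both stop_token_ids assertions above expect stop_token_id_0 because, under greedy decoding, the model replays the same sequence and the token at index 5 is generated before the one at index 6; the order of entries in stop_token_ids is irrelevant. A pure-Python sketch of that invariant (toy token ids, no vLLM involved):

def first_stop(token_ids, stop_token_ids):
    # Generation ends at the earliest generated member of stop_token_ids.
    for t in token_ids:
        if t in stop_token_ids:
            return t

replay = [11, 22, 33, 44, 55, 66, 77]  # stands in for the greedy replay
assert first_stop(replay, {77, 66}) == 66  # the token at index 5 wins either way
assert first_stop(replay, {66, 77}) == 66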
@@ -118,28 +118,28 @@ def test_detokenize_false(model):
     assert all(lp.decoded_token is None for lp in logprobs.values())
 
 
-def test_bad_words(model):
+def test_bad_words(llm):
     """Check that we respect bad words."""
 
-    output = model.generate(PROMPT, SamplingParams(temperature=0))
+    output = llm.generate(PROMPT, SamplingParams(temperature=0))
     split_text = output[0].outputs[0].text.split()
 
     bad_words_1 = " ".join(split_text[:2])
     params = SamplingParams(temperature=0, bad_words=[bad_words_1])
-    output = model.generate(PROMPT, params)
+    output = llm.generate(PROMPT, params)
     new_text = output[0].outputs[0].text
     assert bad_words_1 not in new_text
 
     bad_words_2 = new_text.split()[-1]
     params = SamplingParams(temperature=0,
                             bad_words=[bad_words_1, bad_words_2])
-    output = model.generate(PROMPT, params)
+    output = llm.generate(PROMPT, params)
     new_text = output[0].outputs[0].text
     assert bad_words_1 not in new_text
     assert bad_words_2 not in new_text
 
 
-def test_logits_processor(model):
+def test_logits_processor(llm):
     """Check that we reject logits processor."""
 
     # This sample logits processor gives infinite score to the i-th token,
@@ -150,47 +150,45 @@ def test_logits_processor(model):
         return logits
 
     with pytest.raises(ValueError):
-        _ = model.generate(PROMPT,
-                           SamplingParams(logits_processors=[pick_ith]))
+        _ = llm.generate(PROMPT, SamplingParams(logits_processors=[pick_ith]))
 
 
-def test_allowed_token_ids(model):
+def test_allowed_token_ids(llm):
     """Check that we can use allowed_token_ids."""
 
     TOKEN_ID = 10
     allowed_token_ids = [TOKEN_ID]
-    output = model.generate(
-        PROMPT, SamplingParams(allowed_token_ids=allowed_token_ids))
+    output = llm.generate(PROMPT,
+                          SamplingParams(allowed_token_ids=allowed_token_ids))
     assert output[0].outputs[0].token_ids[-1] == TOKEN_ID
 
     # Reject empty allowed_token_ids.
     with pytest.raises(ValueError):
-        _ = model.generate(PROMPT, SamplingParams(allowed_token_ids=[]))
+        _ = llm.generate(PROMPT, SamplingParams(allowed_token_ids=[]))
 
     # Reject negative token id.
     with pytest.raises(ValueError):
-        _ = model.generate(PROMPT, SamplingParams(allowed_token_ids=[-1]))
+        _ = llm.generate(PROMPT, SamplingParams(allowed_token_ids=[-1]))
 
     # Reject out of vocabulary.
     with pytest.raises(ValueError):
-        _ = model.generate(PROMPT,
-                           SamplingParams(allowed_token_ids=[10000000]))
+        _ = llm.generate(PROMPT, SamplingParams(allowed_token_ids=[10000000]))
 
 
-def test_priority(model):
+def test_priority(llm):
     """Check that we reject requests with priority."""
 
     # Reject all allowed token ids
     with pytest.raises(ValueError):
-        _ = model.generate(PROMPT, priority=[1])
+        _ = llm.generate(PROMPT, priority=[1])
 
 
-def test_seed(model):
+def test_seed(llm):
     """Check that seed impacts randomness."""
 
-    out_1 = model.generate(PROMPT, SamplingParams(seed=42))
-    out_2 = model.generate(PROMPT, SamplingParams(seed=42))
-    out_3 = model.generate(PROMPT, SamplingParams(seed=43))
+    out_1 = llm.generate(PROMPT, SamplingParams(seed=42))
+    out_2 = llm.generate(PROMPT, SamplingParams(seed=42))
+    out_3 = llm.generate(PROMPT, SamplingParams(seed=43))
 
     assert out_1[0].outputs[0].text == out_2[0].outputs[0].text
     assert out_1[0].outputs[0].text != out_3[0].outputs[0].text
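test_seed encodes the seeding contract: identical seeds reproduce a sample, and a different seed is expected to diverge. A self-contained sketch of the same contract (arbitrary small model id, as in the sketches above):

from vllm import LLM, SamplingParams

llm = LLM("facebook/opt-125m", enforce_eager=True)
prompt = "Hello my name is Robert and I"
a = llm.generate([prompt], SamplingParams(seed=42))
b = llm.generate([prompt], SamplingParams(seed=42))
c = llm.generate([prompt], SamplingParams(seed=43))
assert a[0].outputs[0].text == b[0].outputs[0].text
# A different seed is overwhelmingly likely, though not strictly
# guaranteed, to produce different text.
assert a[0].outputs[0].text != c[0].outputs[0].text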