[V1] Adding min tokens/repetition/presence/frequence penalties to V1 sampler (#10681)

Signed-off-by: Sourashis Roy <sroy@roblox.com>
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
sroy745
2024-12-26 02:02:58 -08:00
committed by GitHub
parent 7492a36207
commit dcb1a944d4
11 changed files with 879 additions and 49 deletions

View File

@@ -139,3 +139,41 @@ def test_engine_core(monkeypatch):
engine_core.abort_requests([req2.request_id, req0.request_id])
assert len(engine_core.scheduler.waiting) == 0
assert len(engine_core.scheduler.running) == 0
def test_engine_core_advanced_sampling(monkeypatch):
"""
A basic end-to-end test to verify that the engine functions correctly
when additional sampling parameters, such as min_tokens and
presence_penalty, are set.
"""
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
"""Setup the EngineCore."""
engine_args = EngineArgs(model=MODEL_NAME)
vllm_config = engine_args.create_engine_config(
usage_context=UsageContext.UNKNOWN_CONTEXT)
executor_class = AsyncLLM._get_executor_cls(vllm_config)
engine_core = EngineCore(vllm_config=vllm_config,
executor_class=executor_class,
usage_context=UsageContext.UNKNOWN_CONTEXT)
"""Test basic request lifecycle."""
# First request.
request: EngineCoreRequest = make_request()
request.sampling_params = SamplingParams(
min_tokens=4,
presence_penalty=1.0,
frequency_penalty=1.0,
repetition_penalty=0.1,
stop_token_ids=[1001, 1002],
)
engine_core.add_request(request)
assert len(engine_core.scheduler.waiting) == 1
assert len(engine_core.scheduler.running) == 0
# Loop through until they are all done.
while len(engine_core.step()) > 0:
pass
assert len(engine_core.scheduler.waiting) == 0
assert len(engine_core.scheduler.running) == 0