[V1] Wrapper which plumbs request-level logits processors into vLLM batch-level logits processing (#23656)
Signed-off-by: Andrew Feldman <afeldman@redhat.com>
This commit is contained in:
@@ -15,6 +15,7 @@ from tests.v1.logits_processors.utils import (DUMMY_LOGITPROC_ARG,
|
||||
POOLING_MODEL_NAME, TEMP_GREEDY,
|
||||
CustomLogitprocSource,
|
||||
DummyLogitsProcessor,
|
||||
WrappedPerReqLogitsProcessor,
|
||||
dummy_module)
|
||||
from tests.v1.logits_processors.utils import entry_points as fake_entry_points
|
||||
from tests.v1.logits_processors.utils import prompts
|
||||
@@ -161,6 +162,38 @@ def test_custom_logitsprocs(monkeypatch,
|
||||
_run_test(kwargs, logitproc_loaded=True)
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
def test_custom_logitsprocs_req(monkeypatch):
|
||||
"""Test passing request-level logits processor to offline Python interface
|
||||
|
||||
Wrap a request-level logits processor to create a batch level logits
|
||||
processor that has a well-defined behavior (mask out all tokens except one
|
||||
`target_token`)
|
||||
|
||||
Construct an `LLM` instance which loads the wrapped logits processor. Pass
|
||||
the custom logitproc as a class object.
|
||||
|
||||
Construct a reference `LLM` instance with no custom logitproc
|
||||
|
||||
Pass in a batch of requests, 50% of which pass a `target_token` value
|
||||
in through `SamplingParams.extra_args`, 50% of which do not.
|
||||
|
||||
Validate that
|
||||
* Requests which do not activate the custom logitproc, yield the same
|
||||
results for both `LLM` instances
|
||||
* Requests which activate the custom logitproc, only output `target_token`
|
||||
|
||||
Args:
|
||||
monkeypatch: for setting env vars
|
||||
"""
|
||||
|
||||
# Test that logitproc info is passed to workers
|
||||
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1")
|
||||
random.seed(40)
|
||||
_run_test({"logits_processors": [WrappedPerReqLogitsProcessor]},
|
||||
logitproc_loaded=True)
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
@pytest.mark.parametrize("logitproc_source", [
|
||||
CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT,
|
||||
|
||||
Reference in New Issue
Block a user