[V1] Logit processors for rejection sampler (#19482)
Signed-off-by: southfreebird <yvorott@gmail.com> Signed-off-by: Sergei Skvortsov <sergeyskv@nebius.com> Signed-off-by: Sergei Skvortsov <yvorott@gmail.com> Co-authored-by: Sergei Skvortsov <sergeyskv@nebius.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import random
|
||||
import sys
|
||||
from typing import Union
|
||||
from typing import Any, Union
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -25,6 +25,7 @@ from tests.v1.logits_processors.utils import entry_points as fake_entry_points
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.v1.sample.logits_processor import (
|
||||
STR_POOLING_REJECTS_LOGITSPROCS,
|
||||
STR_SPEC_DEC_REJECTS_LOGITSPROCS,
|
||||
LogitsProcessor,
|
||||
)
|
||||
|
||||
@@ -205,6 +206,7 @@ def test_custom_logitsprocs_req(monkeypatch):
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
@pytest.mark.parametrize("model_scenario", ["pooling", "spec_dec"])
|
||||
@pytest.mark.parametrize(
|
||||
"logitproc_source",
|
||||
[
|
||||
@@ -213,11 +215,12 @@ def test_custom_logitsprocs_req(monkeypatch):
|
||||
CustomLogitprocSource.LOGITPROC_SOURCE_CLASS,
|
||||
],
|
||||
)
|
||||
def test_pooling_rejects_custom_logitsprocs(
|
||||
monkeypatch, logitproc_source: CustomLogitprocSource
|
||||
def test_rejects_custom_logitsprocs(
|
||||
monkeypatch, model_scenario: str, logitproc_source: CustomLogitprocSource
|
||||
):
|
||||
"""Validate that vLLM engine initialization properly rejects custom
|
||||
logitsprocs when the model is a pooling model.
|
||||
logitsprocs when the model is a pooling model or speculative decoding
|
||||
is enabled.
|
||||
|
||||
Use `LLM` entrypoint. We expect `LLM` initialization to fail before the
|
||||
logitproc is actually loaded.
|
||||
@@ -241,8 +244,32 @@ def test_pooling_rejects_custom_logitsprocs(
|
||||
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||
random.seed(40)
|
||||
|
||||
test_params: dict[str, dict[str, Any]] = {
|
||||
"pooling": {
|
||||
"runner": "pooling",
|
||||
"model": POOLING_MODEL_NAME,
|
||||
"error_message": STR_POOLING_REJECTS_LOGITSPROCS,
|
||||
"speculative_config": None,
|
||||
},
|
||||
"spec_dec": {
|
||||
"runner": "auto",
|
||||
"model": MODEL_NAME,
|
||||
"error_message": STR_SPEC_DEC_REJECTS_LOGITSPROCS,
|
||||
"speculative_config": {"model": "ngram", "num_speculative_tokens": 1},
|
||||
},
|
||||
}
|
||||
|
||||
config = test_params[model_scenario]
|
||||
|
||||
llm_kwargs: dict[str, Any] = {
|
||||
"runner": config["runner"],
|
||||
"model": config["model"],
|
||||
"gpu_memory_utilization": 0.1,
|
||||
"speculative_config": config["speculative_config"],
|
||||
}
|
||||
|
||||
if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT:
|
||||
# Scenario: vLLM loads a pooling model and ignores a logitproc that is
|
||||
# Scenario: vLLM loads a model and ignores a logitproc that is
|
||||
# available at a preconfigured entrypoint
|
||||
|
||||
# Patch in dummy logitproc entrypoint
|
||||
@@ -254,30 +281,20 @@ def test_pooling_rejects_custom_logitsprocs(
|
||||
# although they should ignore the entrypoint patch anyway
|
||||
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "fork")
|
||||
|
||||
llm = LLM(
|
||||
runner="pooling",
|
||||
model=POOLING_MODEL_NAME,
|
||||
gpu_memory_utilization=0.1,
|
||||
)
|
||||
llm = LLM(**llm_kwargs)
|
||||
# Require that no logitsprocs have been loaded
|
||||
worker = llm.llm_engine.model_executor.driver_worker.worker
|
||||
assert sum([1 for _ in worker.model_runner.input_batch.logitsprocs.all]) == 0
|
||||
return
|
||||
|
||||
kwargs: dict[str, list[Union[str, type[LogitsProcessor]]]] = {}
|
||||
if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN:
|
||||
# Scenario: load logitproc based on fully-qualified class name (FQCN)
|
||||
kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN]
|
||||
llm_kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN]
|
||||
elif logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_CLASS:
|
||||
# Scenario: load logitproc from provided class object
|
||||
kwargs["logits_processors"] = [DummyLogitsProcessor]
|
||||
llm_kwargs["logits_processors"] = [DummyLogitsProcessor]
|
||||
|
||||
with pytest.raises(ValueError, match=STR_POOLING_REJECTS_LOGITSPROCS):
|
||||
# Require that loading a pooling model alongside the logitproc raises
|
||||
with pytest.raises(ValueError, match=config["error_message"]):
|
||||
# Require that loading a model alongside the logitproc raises
|
||||
# the appropriate exception.
|
||||
LLM(
|
||||
runner="pooling",
|
||||
model=POOLING_MODEL_NAME,
|
||||
gpu_memory_utilization=0.1,
|
||||
**kwargs,
|
||||
)
|
||||
LLM(**llm_kwargs)
|
||||
|
||||
Reference in New Issue
Block a user