[V1] Logit processors for rejection sampler (#19482)

Signed-off-by: southfreebird <yvorott@gmail.com>
Signed-off-by: Sergei Skvortsov <sergeyskv@nebius.com>
Signed-off-by: Sergei Skvortsov <yvorott@gmail.com>
Co-authored-by: Sergei Skvortsov <sergeyskv@nebius.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Sergei Skvortsov
2025-10-07 21:02:49 +01:00
committed by GitHub
parent 0c824fc46f
commit 6ebaf43ee4
12 changed files with 471 additions and 92 deletions

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import sys
from typing import Union
from typing import Any, Union
import pytest
@@ -25,6 +25,7 @@ from tests.v1.logits_processors.utils import entry_points as fake_entry_points
from vllm import LLM, SamplingParams
from vllm.v1.sample.logits_processor import (
STR_POOLING_REJECTS_LOGITSPROCS,
STR_SPEC_DEC_REJECTS_LOGITSPROCS,
LogitsProcessor,
)
@@ -205,6 +206,7 @@ def test_custom_logitsprocs_req(monkeypatch):
@create_new_process_for_each_test()
@pytest.mark.parametrize("model_scenario", ["pooling", "spec_dec"])
@pytest.mark.parametrize(
"logitproc_source",
[
@@ -213,11 +215,12 @@ def test_custom_logitsprocs_req(monkeypatch):
CustomLogitprocSource.LOGITPROC_SOURCE_CLASS,
],
)
def test_pooling_rejects_custom_logitsprocs(
monkeypatch, logitproc_source: CustomLogitprocSource
def test_rejects_custom_logitsprocs(
monkeypatch, model_scenario: str, logitproc_source: CustomLogitprocSource
):
"""Validate that vLLM engine initialization properly rejects custom
logitsprocs when the model is a pooling model.
logitsprocs when the model is a pooling model or speculative decoding
enabled.
Use `LLM` entrypoint. We expect `LLM` initialization to fail before the
logitproc is actually loaded.
@@ -241,8 +244,32 @@ def test_pooling_rejects_custom_logitsprocs(
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
random.seed(40)
test_params: dict[str, dict[str, Any]] = {
"pooling": {
"runner": "pooling",
"model": POOLING_MODEL_NAME,
"error_message": STR_POOLING_REJECTS_LOGITSPROCS,
"speculative_config": None,
},
"spec_dec": {
"runner": "auto",
"model": MODEL_NAME,
"error_message": STR_SPEC_DEC_REJECTS_LOGITSPROCS,
"speculative_config": {"model": "ngram", "num_speculative_tokens": 1},
},
}
config = test_params[model_scenario]
llm_kwargs: dict[str, Any] = {
"runner": config["runner"],
"model": config["model"],
"gpu_memory_utilization": 0.1,
"speculative_config": config["speculative_config"],
}
if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT:
# Scenario: vLLM loads a pooling model and ignores a logitproc that is
# Scenario: vLLM loads a model and ignores a logitproc that is
# available at a preconfigured entrypoint
# Patch in dummy logitproc entrypoint
@@ -254,30 +281,20 @@ def test_pooling_rejects_custom_logitsprocs(
# although they should ignore the entrypoint patch anyway
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "fork")
llm = LLM(
runner="pooling",
model=POOLING_MODEL_NAME,
gpu_memory_utilization=0.1,
)
llm = LLM(**llm_kwargs)
# Require that no logitsprocs have been loaded
worker = llm.llm_engine.model_executor.driver_worker.worker
assert sum([1 for _ in worker.model_runner.input_batch.logitsprocs.all]) == 0
return
kwargs: dict[str, list[Union[str, type[LogitsProcessor]]]] = {}
if logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_FQCN:
# Scenario: load logitproc based on fully-qualified class name (FQCN)
kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN]
llm_kwargs["logits_processors"] = [DUMMY_LOGITPROC_FQCN]
elif logitproc_source == CustomLogitprocSource.LOGITPROC_SOURCE_CLASS:
# Scenario: load logitproc from provided class object
kwargs["logits_processors"] = [DummyLogitsProcessor]
llm_kwargs["logits_processors"] = [DummyLogitsProcessor]
with pytest.raises(ValueError, match=STR_POOLING_REJECTS_LOGITSPROCS):
# Require that loading a pooling model alongside the logitproc raises
with pytest.raises(ValueError, match=config["error_message"]):
# Require that loading a model alongside the logitproc raises
# the appropriate exception.
LLM(
runner="pooling",
model=POOLING_MODEL_NAME,
gpu_memory_utilization=0.1,
**kwargs,
)
LLM(**llm_kwargs)