[V1] LogitsProcessor programming model (#16728)

Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Andrew Feldman <afeldman@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
afeldman-nm
2025-07-02 12:10:42 -04:00
committed by GitHub
parent c1909e7e8c
commit 48fb076cbc
13 changed files with 1401 additions and 393 deletions

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import inspect
from collections.abc import Sequence
from typing import Optional
import numpy as np
@@ -12,6 +13,7 @@ from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessorManager
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -26,13 +28,18 @@ CUDA_DEVICES = [
MAX_NUM_PROMPT_TOKENS = 64
def _compare_objs(obj1, obj2):
def _compare_objs(obj1,
obj2,
skip: Sequence = ("logitsprocs", "batch_update_builder")):
attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a)))
attr_names = set([
a[0] for a in attrs
if not (a[0].startswith('__') and a[0].endswith('__'))
])
for attr_name in attr_names:
if attr_name in skip:
continue
a = getattr(obj1, attr_name)
b = getattr(obj2, attr_name)
@@ -58,13 +65,11 @@ def _compare_objs(obj1, obj2):
f" in {obj1} and {obj2}: {a} != {b}"
def _remove_requests(
input_batch: InputBatch, batch_size: int,
reqs: list[CachedRequestState]) -> tuple[set[str], list[int]]:
def _remove_requests(input_batch: InputBatch, batch_size: int,
reqs: list[CachedRequestState]) -> set[str]:
"""
Remove some requests randomly from the batch and returns a tuple
of 1) set of request removed 2) indices of the requests removed
ordered in descending order
Remove some requests randomly from the batch and returns
set of request removed
"""
num_reqs_to_remove = np.random.randint(0, batch_size)
@@ -73,13 +78,11 @@ def _remove_requests(
req_index_to_remove = np.random.randint(0, batch_size)
req_indices_to_remove.add(req_index_to_remove)
req_indices_to_remove_list = list(req_indices_to_remove)
req_indices_to_remove_list.sort(reverse=True)
req_ids_to_remove: set[str] = set()
for index in req_indices_to_remove:
input_batch.remove_request(reqs[index].req_id)
req_ids_to_remove.add(reqs[index].req_id)
return req_ids_to_remove, req_indices_to_remove_list
return req_ids_to_remove
def _construct_expected_sampling_metadata(
@@ -100,7 +103,6 @@ def _construct_expected_sampling_metadata(
repetition_penalties = [1.0 for _ in range(num_reqs)]
top_k = [0 for _ in range(num_reqs)]
top_p = [0.0 for _ in range(num_reqs)]
min_p = [0.0 for _ in range(num_reqs)]
temperature = [0.0 for _ in range(num_reqs)]
min_tokens = {}
logit_bias = [None] * num_reqs
@@ -123,7 +125,6 @@ def _construct_expected_sampling_metadata(
req.sampling_params.repetition_penalty)
top_k[index_in_input_batch] = req.sampling_params.top_k
top_p[index_in_input_batch] = req.sampling_params.top_p
min_p[index_in_input_batch] = req.sampling_params.min_p
temperature[index_in_input_batch] = req.sampling_params.temperature
min_tokens[index_in_input_batch] = (
req.sampling_params.min_tokens,
@@ -145,8 +146,6 @@ def _construct_expected_sampling_metadata(
top_p, dtype=torch.float, device=device),
top_k=None if all(x == 0 for x in top_k) else torch.tensor(
top_k, dtype=torch.int, device=device),
min_p=None if all(x == 0.0 for x in min_p) else torch.tensor(
min_p, dtype=torch.float, device=device),
generators={},
max_num_logprobs=0,
prompt_token_ids=make_tensor_with_pad(
@@ -165,13 +164,12 @@ def _construct_expected_sampling_metadata(
dtype=torch.float,
device=device),
output_token_ids=output_token_ids,
min_tokens=min_tokens,
no_penalties=(all(x == 0 for x in presence_penalties)
and all(x == 0 for x in frequency_penalties)
and all(x == 1 for x in repetition_penalties)),
logit_bias=logit_bias,
allowed_token_ids_mask=allowed_token_ids_mask,
bad_words_token_ids=bad_words_token_ids,
logitsprocs=LogitsProcessorManager(),
)
@@ -225,6 +223,8 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
and the `make_sampling_metadata` method is invoked on the batch. The
output of `make_sampling_metadata` is then compared against the expected
results to ensure correctness.
Note: Ignore logits processor logic, which is tested separately
"""
input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
@@ -238,21 +238,22 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
reqs: list[CachedRequestState] = []
req_id_reqs = {}
req_id_output_token_ids = {}
# Add requests
for req_index in range(batch_size):
req: CachedRequestState = _construct_cached_request_state(req_index)
input_batch.add_request(req, req_index)
assigned_req_index = input_batch.add_request(req)
assert req_index == assigned_req_index
reqs.append(req)
req_id_reqs[req.req_id] = req
req_id_output_token_ids[req.req_id] = req.output_token_ids
# Remove some requests
req_ids_to_remove, req_indices_to_remove = _remove_requests(
input_batch, batch_size, reqs)
req_ids_to_remove = _remove_requests(input_batch, batch_size, reqs)
req_ids_retained = set(req_id_reqs.keys()) - req_ids_to_remove
# Compact the input batch
input_batch.condense(req_indices_to_remove)
input_batch.condense()
# Generate the sampling metadata
sampling_metadata = input_batch._make_sampling_metadata()
@@ -290,10 +291,8 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
sampling_metadata.prompt_token_ids)
assert (expected_sampling_metadata.output_token_ids ==
sampling_metadata.output_token_ids)
assert expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens
assert expected_sampling_metadata.no_penalties == \
sampling_metadata.no_penalties
assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias
if sampling_metadata.allowed_token_ids_mask:
assert torch.allclose(
expected_sampling_metadata.allowed_token_ids_mask,
@@ -315,6 +314,8 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
and the `make_sampling_metadata` method is invoked on the batch. The
output of `make_sampling_metadata` is then compared against the expected
results to ensure correctness.
Note: Ignore logits processor logic, which is tested separately
"""
input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
@@ -341,7 +342,8 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
# Add requests
for req_index in range(batch_size):
req: CachedRequestState = _construct_cached_request_state(req_index)
input_batch.add_request(req, req_index)
assigned_req_index = input_batch.add_request(req)
assert assigned_req_index == req_index
reqs.append(req)
req_id_reqs[req.req_id] = req
req_id_output_token_ids[req.req_id] = req.output_token_ids
@@ -354,9 +356,10 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
for req_index in range(batch_size):
req = reordered_reqs[req_index]
ref_input_batch.add_request(req, req_index)
assigned_req_index = ref_input_batch.add_request(req)
assert assigned_req_index == req_index
input_batch.refresh_sampling_metadata()
ref_input_batch.refresh_sampling_metadata()
input_batch.refresh_metadata()
ref_input_batch.refresh_metadata()
_compare_objs(input_batch, ref_input_batch)