[Core] add and implement VLLM_LOGITS_PROCESSOR_THREADS (#12368)

Signed-off-by: Aviv Keshet <akeshet@scaledcognition.com>
This commit is contained in:
Aviv Keshet
2025-02-04 18:46:26 -08:00
committed by GitHub
parent 75e94309e8
commit b3a0d01e45
2 changed files with 44 additions and 11 deletions

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
"""A layer that compute logits from hidden_stats."""
import inspect
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
import torch
@@ -15,6 +16,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
_logits_processor_threadpool: Optional[ThreadPoolExecutor] = None
if envs.VLLM_LOGITS_PROCESSOR_THREADS is not None:
_logits_processor_threadpool = ThreadPoolExecutor(
envs.VLLM_LOGITS_PROCESSOR_THREADS)
class LogitsProcessor(nn.Module):
"""Process logits and apply logits processors from sampling metadata.
@@ -135,6 +141,7 @@ def _apply_logits_processors(
) -> torch.Tensor:
found_logits_processors = False
logits_processed = 0
logits_row_ids_and_logits_row_futures = []
for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids
sampling_params = seq_group.sampling_params
@@ -148,22 +155,39 @@ def _apply_logits_processors(
past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids
prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids
for logits_processor in logits_processors:
parameters = inspect.signature(logits_processor).parameters
if len(parameters) == 3:
logits_row = logits_processor(prompt_tokens_ids,
past_tokens_ids,
logits_row)
else:
logits_row = logits_processor(past_tokens_ids,
logits_row)
logits[logits_row_idx] = logits_row
if _logits_processor_threadpool is not None:
logits_row_ids_and_logits_row_futures.append(
(logits_row_idx,
_logits_processor_threadpool.submit(
_apply_logits_processors_single_seq, logits_row,
logits_processors, past_tokens_ids,
prompt_tokens_ids)))
else:
logits[logits_row_idx] = \
_apply_logits_processors_single_seq(
logits_row, logits_processors, past_tokens_ids,
prompt_tokens_ids)
logits_processed += len(seq_group.sample_indices) + len(
seq_group.prompt_logprob_indices)
for logits_row_idx, future in logits_row_ids_and_logits_row_futures:
logits[logits_row_idx] = future.result()
if found_logits_processors:
# verifies that no rows in logits were missed unexpectedly
assert logits_processed == logits.shape[0]
return logits
def _apply_logits_processors_single_seq(logits_row, logits_processors,
past_tokens_ids,
prompt_tokens_ids) -> torch.Tensor:
for logits_processor in logits_processors:
parameters = inspect.signature(logits_processor).parameters
if len(parameters) == 3:
logits_row = logits_processor(prompt_tokens_ids, past_tokens_ids,
logits_row)
else:
logits_row = logits_processor(past_tokens_ids, logits_row)
return logits_row