[V1] Aggregate chunked prompt logprobs in model runner (#14875)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill
2025-03-24 09:27:57 -07:00
committed by GitHub
parent 9cc645141d
commit 3aee6573dc
7 changed files with 68 additions and 44 deletions

View File

@@ -11,6 +11,7 @@ from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import copy_slice
from vllm.v1.worker.block_table import BlockTable
@@ -197,6 +198,9 @@ class InputBatch:
# that are currently in the prefill phase.
self.num_prompt_logprobs: dict[str, int] = {}
# To accumulate prompt logprobs tensor chunks across prefill steps.
self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
self.logit_bias: list[Optional[dict[int,
float]]] = [None] * max_num_reqs
self.has_allowed_token_ids: set[str] = set()
@@ -362,6 +366,7 @@ class InputBatch:
self.generators.pop(req_index, None)
self.num_logprobs.pop(req_id, None)
self.num_prompt_logprobs.pop(req_id, None)
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
# LoRA
lora_id = self.request_lora_mapping[req_index]

View File

@@ -1191,6 +1191,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if not num_prompt_logprobs_dict:
return {}
in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu
prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}
# Since prompt logprobs are a rare feature, prioritize simple,
@@ -1206,16 +1207,36 @@ class GPUModelRunner(LoRAModelRunnerMixin):
prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
self.device, non_blocking=True)
# Set up target LogprobsTensors object.
logprobs_tensors = in_progress_dict.get(req_id)
if not logprobs_tensors:
# Create empty logprobs CPU tensors for the entire prompt.
# If chunked, we'll copy in slice by slice.
logprobs_tensors = LogprobsTensors.empty_cpu(
num_prompt_tokens - 1, num_prompt_logprobs + 1)
in_progress_dict[req_id] = logprobs_tensors
# Determine number of logits to retrieve.
start_tok = request.num_computed_tokens + 1
start_idx = request.num_computed_tokens
start_tok = start_idx + 1
num_remaining_tokens = num_prompt_tokens - start_tok
if num_tokens < num_remaining_tokens:
if num_tokens <= num_remaining_tokens:
# This is a chunk, more tokens remain.
# In the == case, there are no more prompt logprobs to produce
# but we want to defer returning them to the next step where we
# have new generated tokens to return.
num_logits = num_tokens
else:
# This is the last chunk of prompt tokens to return.
num_logits = num_remaining_tokens
completed_prefill_reqs.append(req_id)
prompt_logprobs_dict[req_id] = logprobs_tensors
if num_logits <= 0:
# This can happen for the final chunk if we prefilled exactly
# (num_prompt_tokens - 1) tokens for this request in the prior
# step. There are no more prompt logprobs to produce.
continue
# Get the logits corresponding to this req's prompt tokens.
# If this is a partial request (i.e. chunked prefill),
@@ -1236,19 +1257,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logprobs, num_prompt_logprobs, tgt_token_ids)
# Transfer GPU->CPU async.
prompt_logprobs_dict[req_id] = LogprobsTensors(
token_ids.to("cpu", non_blocking=True),
logprobs.to("cpu", non_blocking=True),
ranks.to("cpu", non_blocking=True),
)
chunk_slice = slice(start_idx, start_idx + num_logits)
logprobs_tensors.logprob_token_ids[chunk_slice].copy_(
token_ids, non_blocking=True)
logprobs_tensors.logprobs[chunk_slice].copy_(logprobs,
non_blocking=True)
logprobs_tensors.selected_token_ranks[chunk_slice].copy_(
ranks, non_blocking=True)
# Remove requests that have completed prefill from the batch's
# num_prompt_logprobs_dict and from the in-progress prompt logprobs dict.
for req_id in completed_prefill_reqs:
del num_prompt_logprobs_dict[req_id]
del in_progress_dict[req_id]
# Must synchronize the non-blocking GPU->CPU transfers.
torch.cuda.synchronize()
if prompt_logprobs_dict:
torch.cuda.synchronize()
return prompt_logprobs_dict