[V1] Aggregate chunked prompt logprobs in model runner (#14875)
Signed-off-by: Nick Hill <nhill@redhat.com>

@@ -11,6 +11,7 @@ from vllm.lora.request import LoRARequest
 from vllm.multimodal import MultiModalKwargs
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.utils import swap_dict_values
+from vllm.v1.outputs import LogprobsTensors
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.utils import copy_slice
 from vllm.v1.worker.block_table import BlockTable
@@ -197,6 +198,9 @@ class InputBatch:
         # that are currently in the prefill phase.
         self.num_prompt_logprobs: dict[str, int] = {}
 
+        # To accumulate prompt logprobs tensor chunks across prefill steps.
+        self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
+
         self.logit_bias: list[Optional[dict[int,
                                             float]]] = [None] * max_num_reqs
         self.has_allowed_token_ids: set[str] = set()
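
The diff indexes three fields on the accumulated `LogprobsTensors` (`logprob_token_ids`, `logprobs`, `selected_token_ranks`) and allocates it with `LogprobsTensors.empty_cpu(num_prompt_tokens - 1, num_prompt_logprobs + 1)`. A minimal sketch of what such a container could look like, assuming a NamedTuple and pinned CPU memory so the later `copy_(..., non_blocking=True)` calls can overlap with GPU work; the dtypes and the `pin_memory` choice here are assumptions, not taken from the PR:

```python
import torch
from typing import NamedTuple

class LogprobsTensors(NamedTuple):
    # One row per prompt position; one column per (top-k + sampled) token.
    logprob_token_ids: torch.Tensor     # [num_positions, num_logprobs + 1]
    logprobs: torch.Tensor              # [num_positions, num_logprobs + 1]
    selected_token_ranks: torch.Tensor  # [num_positions]

    @staticmethod
    def empty_cpu(num_positions: int,
                  num_tokens_per_position: int) -> "LogprobsTensors":
        # Pinned (page-locked) memory makes GPU->CPU copy_(non_blocking=True)
        # genuinely asynchronous; it requires a CUDA-capable build.
        return LogprobsTensors(
            logprob_token_ids=torch.empty(
                (num_positions, num_tokens_per_position),
                dtype=torch.int64, pin_memory=True),
            logprobs=torch.empty(
                (num_positions, num_tokens_per_position),
                dtype=torch.float32, pin_memory=True),
            selected_token_ranks=torch.empty(
                (num_positions,), dtype=torch.int64, pin_memory=True),
        )
```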

@@ -362,6 +366,7 @@ class InputBatch:
         self.generators.pop(req_index, None)
         self.num_logprobs.pop(req_id, None)
         self.num_prompt_logprobs.pop(req_id, None)
+        self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
 
         # LoRA
         lora_id = self.request_lora_mapping[req_index]

@@ -1191,6 +1191,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if not num_prompt_logprobs_dict:
             return {}
 
+        in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu
         prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}
 
         # Since prompt logprobs are a rare feature, prioritize simple,
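
One way to read the role of `in_progress_dict`: intermediate chunks fill the per-request accumulator but surface nothing to the caller; only the final chunk returns the aggregated tensors and drops the accumulator. A toy illustration of that lifecycle (the driver function and names here are invented for the example, not vLLM's scheduler):

```python
# Stand-in accumulator keyed by request id; a plain list plays the role
# of the preallocated LogprobsTensors buffer.
in_progress: dict[str, list[int]] = {}

def step(req_id: str, chunk: list[int], is_last: bool) -> dict[str, list[int]]:
    buf = in_progress.setdefault(req_id, [])  # created lazily on first chunk
    buf.extend(chunk)                         # analogous to the sliced copy_
    if not is_last:
        return {}                             # mid-prefill: nothing surfaced
    del in_progress[req_id]                   # done: drop the accumulator...
    return {req_id: buf}                      # ...and return the whole prompt

assert step("req-0", [1, 2], False) == {}
assert step("req-0", [3], True) == {"req-0": [1, 2, 3]}
```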

@@ -1206,16 +1207,36 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
                 self.device, non_blocking=True)
 
+            # Set up target LogprobsTensors object.
+            logprobs_tensors = in_progress_dict.get(req_id)
+            if not logprobs_tensors:
+                # Create empty logprobs CPU tensors for the entire prompt.
+                # If chunked, we'll copy in slice by slice.
+                logprobs_tensors = LogprobsTensors.empty_cpu(
+                    num_prompt_tokens - 1, num_prompt_logprobs + 1)
+                in_progress_dict[req_id] = logprobs_tensors
+
             # Determine number of logits to retrieve.
-            start_tok = request.num_computed_tokens + 1
+            start_idx = request.num_computed_tokens
+            start_tok = start_idx + 1
             num_remaining_tokens = num_prompt_tokens - start_tok
-            if num_tokens < num_remaining_tokens:
+            if num_tokens <= num_remaining_tokens:
                 # This is a chunk, more tokens remain.
+                # In the == case, there are no more prompt logprobs to produce
+                # but we want to defer returning them to the next step where we
+                # have new generated tokens to return.
                 num_logits = num_tokens
             else:
                 # This is the last chunk of prompt tokens to return.
                 num_logits = num_remaining_tokens
                 completed_prefill_reqs.append(req_id)
+                prompt_logprobs_dict[req_id] = logprobs_tensors
+
+            if num_logits <= 0:
+                # This can happen for the final chunk if we prefilled exactly
+                # (num_prompt_tokens - 1) tokens for this request in the prior
+                # step. There are no more prompt logprobs to produce.
+                continue
 
             # Get the logits corresponding to this req's prompt tokens.
             # If this is a partial request (i.e. chunked prefill),
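
To sanity-check the indexing above, a small worked example with hypothetical numbers: a 10-token prompt prefilled in chunks of 4, 4, and 2 tokens. Prompt logprobs exist only for positions 1..9, so the CPU buffer allocated earlier has `num_prompt_tokens - 1 = 9` rows:

```python
num_prompt_tokens = 10
num_computed_tokens = 0
for num_tokens in (4, 4, 2):                   # scheduled chunk sizes
    start_idx = num_computed_tokens            # first buffer row to fill
    start_tok = start_idx + 1                  # first token getting a logprob
    num_remaining_tokens = num_prompt_tokens - start_tok
    if num_tokens <= num_remaining_tokens:
        num_logits = num_tokens                # mid-prompt chunk
    else:
        num_logits = num_remaining_tokens      # final chunk
    if num_logits > 0:                         # mirrors the early `continue`
        print(f"rows {start_idx}..{start_idx + num_logits - 1}")
    num_computed_tokens += num_tokens
# Prints rows 0..3, rows 4..7, rows 8..8: exactly the 9 rows, once each.
# With chunk sizes (9, 1) instead, the second pass computes num_logits == 0
# and copies nothing, which is the case the diff's `continue` handles.
```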

@@ -1236,19 +1257,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 logprobs, num_prompt_logprobs, tgt_token_ids)
 
             # Transfer GPU->CPU async.
-            prompt_logprobs_dict[req_id] = LogprobsTensors(
-                token_ids.to("cpu", non_blocking=True),
-                logprobs.to("cpu", non_blocking=True),
-                ranks.to("cpu", non_blocking=True),
-            )
+            chunk_slice = slice(start_idx, start_idx + num_logits)
+            logprobs_tensors.logprob_token_ids[chunk_slice].copy_(
+                token_ids, non_blocking=True)
+            logprobs_tensors.logprobs[chunk_slice].copy_(logprobs,
+                                                         non_blocking=True)
+            logprobs_tensors.selected_token_ranks[chunk_slice].copy_(
+                ranks, non_blocking=True)
 
         # Remove requests that have completed prefill from the batch
         # num_prompt_logprobs_dict.
         for req_id in completed_prefill_reqs:
             del num_prompt_logprobs_dict[req_id]
+            del in_progress_dict[req_id]
 
         # Must synchronize the non-blocking GPU->CPU transfers.
-        torch.cuda.synchronize()
+        if prompt_logprobs_dict:
+            torch.cuda.synchronize()
 
         return prompt_logprobs_dict
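
The net effect of this last hunk: instead of shipping a fresh set of GPU tensors to the CPU on every step, each chunk is copied into its slice of one preallocated CPU buffer, and the final `torch.cuda.synchronize()` is skipped when no transfers were issued. A standalone sketch of that pattern, assuming a CUDA device (the helper below is hypothetical, not vLLM API):

```python
import torch

def gather_to_cpu(cpu_buf: torch.Tensor,
                  gpu_chunks: list[torch.Tensor]) -> torch.Tensor:
    """Copy per-step GPU chunks into consecutive slices of a CPU buffer.

    cpu_buf should be preallocated with pin_memory=True; otherwise the
    non_blocking copies silently degrade to synchronous ones.
    """
    start = 0
    for chunk in gpu_chunks:
        n = chunk.shape[0]
        # Asynchronous device->host copy into the target slice.
        cpu_buf[start:start + n].copy_(chunk, non_blocking=True)
        start += n
    # One synchronize covers all outstanding copies; skip it when nothing
    # was transferred, mirroring the `if prompt_logprobs_dict:` guard.
    if gpu_chunks:
        torch.cuda.synchronize()
    return cpu_buf[:start]
```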