[V1] Logprobs and prompt logprobs support (#9880)
This PR adds support for sample logprobs and prompt logprobs to vLLM v1.

New behavior:

- During model execution, the model runner computes sample logprobs (if the user-provided logprobs setting is not None) and prompt logprobs (if the user-provided prompt_logprobs setting is not None). For both sample and prompt logprobs, the engine core returns three vectors: token ids, token logprob values, and token ranks. A token's rank is its 1-indexed position in the vocabulary after sorting the vocabulary by log probability in descending order, so rank 1 is the most likely token.
- In scheduler.update_from_output(), sample and prompt logprobs are incorporated into the EngineCoreOutput data structure, which is transferred to the engine client. If multiprocessing is enabled, sample and prompt logprobs are (de)serialized along with the EngineCoreOutput data structure.
- During output processing, the LogprobsProcessor transforms the triplet of token ids, token logprob values, and token ranks into the OpenAI-compatible List[Dict[int, Logprob]] format (keyed by token id), for sample and prompt logprobs respectively.
- Each Logprob instance (whether sample or prompt) consists of a token's log probability, rank, and detokenized string representation. Note that logprob detokenization is handled by the LogprobsProcessor, not the detokenizer.

Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
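For context, a minimal offline-inference sketch of how these settings surface to users, using the standard vLLM LLM/SamplingParams interface (the model name and the parameter values chosen here are illustrative only, not part of this change):

    from vllm import LLM, SamplingParams

    # Request the top-5 sample logprobs per generated token and the
    # top-5 prompt logprobs per prompt position (illustrative values).
    params = SamplingParams(max_tokens=8, logprobs=5, prompt_logprobs=5)
    llm = LLM(model="facebook/opt-125m")  # illustrative model choice

    output = llm.generate(["Hello, my name is"], params)[0]

    # Sample logprobs: one Dict[int, Logprob] per generated token.
    for pos, logprob_dict in enumerate(output.outputs[0].logprobs):
        for token_id, logprob in logprob_dict.items():
            # Each Logprob carries the log probability, the rank
            # (1 == most likely token), and the detokenized string.
            print(pos, token_id, logprob.logprob, logprob.rank,
                  logprob.decoded_token)

    # Prompt logprobs: one Dict[int, Logprob] per prompt position;
    # the first entry is None (no preceding context to condition on).
    print(output.prompt_logprobs[1])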
@@ -176,7 +176,9 @@ class InputBatch:
         self.generators: Dict[int, torch.Generator] = {}

         self.num_logprobs: Dict[str, int] = {}
-        self.prompt_logprob_reqs: Set[str] = set()
+        # NOTE(rob): num_prompt_logprobs only includes reqs
+        # that are currently in the prefill phase.
+        self.num_prompt_logprobs: Dict[str, int] = {}

     def add_request(
         self,
@@ -238,11 +240,10 @@ class InputBatch:
         if request.generator is not None:
             self.generators[req_index] = request.generator

-        num_logprobs = sampling_params.logprobs
-        if num_logprobs is not None and num_logprobs > 0:
-            self.num_logprobs[req_id] = num_logprobs
-        if sampling_params.prompt_logprobs:
-            self.prompt_logprob_reqs.add(req_id)
+        if sampling_params.logprobs is not None:
+            self.num_logprobs[req_id] = sampling_params.logprobs
+        if sampling_params.prompt_logprobs is not None:
+            self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs

         # Add request lora ID
         if request.lora_request:
@@ -272,7 +273,7 @@ class InputBatch:
         self.repetition_penalties_reqs.discard(req_id)
         self.generators.pop(req_index, None)
         self.num_logprobs.pop(req_id, None)
-        self.prompt_logprob_reqs.discard(req_id)
+        self.num_prompt_logprobs.pop(req_id, None)

         # LoRA
         lora_id = self.request_lora_mapping[req_index]
@@ -297,7 +298,7 @@ class InputBatch:
         self.repetition_penalties_reqs.clear()
         self.generators.clear()
         self.num_logprobs.clear()
-        self.prompt_logprob_reqs.clear()
+        self.num_prompt_logprobs.clear()
         self.request_lora_mapping.fill(0)
         self.lora_id_to_lora_request.clear()
         self.lora_id_to_request_ids.clear()
@@ -489,13 +490,9 @@ class InputBatch:
                 and len(self.repetition_penalties_reqs) == 0)

     @property
-    def max_num_logprobs(self) -> int:
-        return max(self.num_logprobs.values()) if self.num_logprobs else 0
-
-    @property
-    def no_logprob(self) -> bool:
-        return len(self.num_logprobs) == 0
+    def max_num_logprobs(self) -> Optional[int]:
+        return max(self.num_logprobs.values()) if self.num_logprobs else None

     @property
     def no_prompt_logprob(self) -> bool:
-        return len(self.prompt_logprob_reqs) == 0
+        return not self.num_prompt_logprobs
@@ -29,7 +29,7 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
-from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -804,8 +804,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             inputs_embeds=inputs_embeds,
         )
         hidden_states = hidden_states[:num_scheduled_tokens]
-        hidden_states = hidden_states[logits_indices]
-        logits = self.model.compute_logits(hidden_states, None)
+        sample_hidden_states = hidden_states[logits_indices]
+        logits = self.model.compute_logits(sample_hidden_states, None)

         # Sample the next token and get logprobs if needed.
         sampling_metadata = self._prepare_sampling(batch_changed)
@@ -818,7 +818,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # the requests one by one. Optimize.
         num_reqs = self.input_batch.num_reqs
         request_seq_lens: List[Tuple[int, CachedRequestState, int]] = []
-        for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
+        for i, req_id in enumerate(  # type: ignore[assignment]
+                self.input_batch.req_ids[:num_reqs]):
             assert req_id is not None
             req_state = self.requests[req_id]
             seq_len = (req_state.num_computed_tokens +
@@ -847,27 +848,28 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # NOTE: GPU -> CPU Sync happens here.
         # Move as many CPU operations as possible before this sync point.
         sampled_token_ids = sampler_output.sampled_token_ids.tolist()
+        logprobs_tensors = sampler_output.logprobs_tensors
+        logprobs_lists = logprobs_tensors.tolists() \
+            if logprobs_tensors is not None else None
+
+        # Compute prompt logprobs if needed.
+        prompt_logprobs_dict = self._get_prompt_logprobs_dict(
+            hidden_states,
+            scheduler_output,
+        )

         # Update with the actual token ids
         for i, req_state, seq_len in request_seq_lens:
             token_id = sampled_token_ids[i]
             self.input_batch.token_ids_cpu[i, seq_len] = token_id
             req_state.output_token_ids[-1] = token_id

-        if sampler_output.logprob_token_ids is None:
-            logprob_token_ids = None
-        else:
-            logprob_token_ids = sampler_output.logprob_token_ids.cpu()
-        if sampler_output.logprobs is None:
-            logprobs = None
-        else:
-            logprobs = sampler_output.logprobs.cpu()
-
         model_runner_output = ModelRunnerOutput(
             req_ids=req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,
             sampled_token_ids=sampled_token_ids,
-            logprob_token_ids_cpu=logprob_token_ids,
-            logprobs_cpu=logprobs,
+            logprobs=logprobs_lists,
+            prompt_logprobs_dict=prompt_logprobs_dict,
         )
         return model_runner_output
@@ -886,6 +888,76 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         logger.info("Loading model weights took %.4f GB",
                     self.model_memory_usage / float(2**30))

+    def _get_prompt_logprobs_dict(
+        self,
+        hidden_states: torch.Tensor,
+        scheduler_output: "SchedulerOutput",
+    ) -> Dict[str, LogprobsTensors]:
+        num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
+        if not num_prompt_logprobs_dict:
+            return {}
+
+        prompt_logprobs_dict: Dict[str, LogprobsTensors] = {}
+
+        # Since prompt logprobs are a rare feature, prioritize simple,
+        # maintainable loop over optimal performance.
+        completed_prefill_reqs = []
+        for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():
+
+            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
+
+            # Get metadata for this request.
+            request = self.requests[req_id]
+            num_prompt_tokens = len(request.prompt_token_ids)
+            prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
+                self.device, non_blocking=True)
+
+            # Determine number of logits to retrieve.
+            start_tok = request.num_computed_tokens + 1
+            num_remaining_tokens = num_prompt_tokens - start_tok
+            if num_tokens < num_remaining_tokens:
+                # This is a chunk, more tokens remain.
+                num_logits = num_tokens
+            else:
+                # This is the last chunk of prompt tokens to return.
+                num_logits = num_remaining_tokens
+                completed_prefill_reqs.append(req_id)
+
+            # Get the logits corresponding to this req's prompt tokens.
+            # If this is a partial request (i.e. chunked prefill),
+            # then there is prompt logprob generated for each index.
+            req_idx = self.input_batch.req_id_to_index[req_id]
+            offset = self.query_start_loc_np[req_idx].item()
+            prompt_hidden_states = hidden_states[offset:offset + num_logits]
+            logits = self.model.compute_logits(prompt_hidden_states, None)
+
+            # Get the "target" tokens for each index. For prompt at index i,
+            # the token at prompt index i+1 is the "sampled" token we want
+            # to gather the logprob for.
+            tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits]
+
+            # Compute prompt logprobs.
+            logprobs = self.model.sampler.compute_logprobs(logits)
+            token_ids, logprobs, ranks = self.model.sampler.gather_logprobs(
+                logprobs, num_prompt_logprobs, tgt_token_ids)
+
+            # Transfer GPU->CPU async.
+            prompt_logprobs_dict[req_id] = LogprobsTensors(
+                token_ids.to("cpu", non_blocking=True),
+                logprobs.to("cpu", non_blocking=True),
+                ranks.to("cpu", non_blocking=True),
+            )
+
+        # Remove requests that have completed prefill from the batch
+        # num_prompt_logprobs_dict.
+        for req_id in completed_prefill_reqs:
+            del num_prompt_logprobs_dict[req_id]
+
+        # Must synchronize the non-blocking GPU->CPU transfers.
+        torch.cuda.synchronize()
+
+        return prompt_logprobs_dict
+
     @torch.inference_mode()
     def _dummy_run(
         self,
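For readers unfamiliar with the rank convention used by gather_logprobs above, the following is a standalone sketch of the gather step's semantics (an illustrative simplification under assumed tensor shapes, not the sampler code added by this PR): given per-position logprobs over the vocabulary and the chosen (sampled or prompt) token ids, it produces the three vectors described in the commit message, with the chosen token's entry placed first.

    import torch

    def gather_logprobs_sketch(logprobs: torch.Tensor, num_logprobs: int,
                               token_ids: torch.Tensor):
        """Sketch only. logprobs: (N, vocab_size), token_ids: (N,)."""
        # Top-k tokens by log probability; their ranks are 1..k by definition.
        topk_logprobs, topk_ids = logprobs.topk(num_logprobs, dim=-1)
        topk_ranks = torch.arange(1, num_logprobs + 1,
                                  device=logprobs.device).expand_as(topk_ids)
        # Log probability of the chosen token at each position.
        chosen_logprobs = logprobs.gather(-1, token_ids.unsqueeze(-1))
        # Rank = 1 + number of tokens with a strictly greater log probability.
        chosen_ranks = (logprobs > chosen_logprobs).sum(dim=-1, keepdim=True) + 1
        # Chosen token first, then the top-k, mirroring the three vectors
        # (token ids, logprob values, ranks) returned by the engine core.
        out_ids = torch.cat([token_ids.unsqueeze(-1), topk_ids], dim=-1)
        out_logprobs = torch.cat([chosen_logprobs, topk_logprobs], dim=-1)
        out_ranks = torch.cat([chosen_ranks, topk_ranks], dim=-1)
        return out_ids, out_logprobs, out_ranks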