[Model Runner V2] Refactor Prompt Logprobs (#32811)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Woosuk Kwon
2026-01-21 15:12:20 -08:00
committed by GitHub
parent 63227accf5
commit 408195ec59
4 changed files with 230 additions and 142 deletions
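In outline, this change moves prompt-logprobs handling out of GPUModelRunner and RequestState into a dedicated PromptLogprobsWorker owned by the runner. Below is a minimal sketch of the resulting lifecycle, using only names that appear in the diffs that follow; the surrounding objects (req_states, input_batch, model, sampling_params) stand in for the runner's real state, so this is illustrative rather than part of the commit.

    # Sketch only: requires a vLLM checkout containing this commit.
    from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker

    worker = PromptLogprobsWorker(max_num_reqs)  # created once in __init__

    # When a new request is scheduled: record whether it wants prompt logprobs.
    worker.add_request(req_id, req_index, sampling_params)

    # Every execute_model step, after the hidden states are computed:
    prompt_logprobs_dict = worker.compute_prompt_logprobs(
        model.compute_logits,                    # logits_fn
        hidden_states,
        input_batch,
        req_states.prefill_token_ids.gpu,        # [max_num_reqs, max_model_len]
        req_states.num_computed_tokens.gpu,      # [max_num_reqs]
        req_states.prompt_len,                   # [max_num_reqs]
        req_states.prefill_len.np,               # [max_num_reqs]
        req_states.num_computed_prefill_tokens,  # [max_num_reqs]
    )
    # Returns dict[str, LogprobsTensors] with entries only for requests whose
    # prompt finished this step; chunked prompts are accumulated internally.

    # When a request finishes or is preempted:
    worker.remove_request(req_id)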

View File

@@ -22,7 +22,6 @@ from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.outputs import (
     EMPTY_MODEL_RUNNER_OUTPUT,
-    LogprobsTensors,
     ModelRunnerOutput,
 )
 from vllm.v1.worker.gpu.async_utils import AsyncOutput
@@ -51,8 +50,8 @@ from vllm.v1.worker.gpu.input_batch import (
 )
 from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
 from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
-from vllm.v1.worker.gpu.sample.logprob import compute_prompt_logprobs
 from vllm.v1.worker.gpu.sample.output import SamplerOutput
+from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
 from vllm.v1.worker.gpu.sample.sampler import Sampler
 from vllm.v1.worker.gpu.spec_decode import init_speculator
 from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
@@ -156,6 +155,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             device=self.device,
             logprobs_mode=self.model_config.logprobs_mode,
         )
+        self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs)

         # CUDA graphs.
         self.cudagraph_manager = CudaGraphManager(
@@ -416,10 +416,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.req_states.remove_request(req_id)
             if self.supports_mm_inputs:
                 self.encoder_runner.remove_request(req_id)
+            self.prompt_logprobs_worker.remove_request(req_id)
         for req_id in scheduler_output.finished_req_ids:
             self.req_states.remove_request(req_id)
             if self.supports_mm_inputs:
                 self.encoder_runner.remove_request(req_id)
+            self.prompt_logprobs_worker.remove_request(req_id)

     def free_states(self, scheduler_output: SchedulerOutput) -> None:
         if self.supports_mm_inputs:
@@ -438,7 +440,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 prompt_len=prompt_len,
                 prefill_token_ids=new_req_data.prefill_token_ids,
                 num_computed_tokens=new_req_data.num_computed_tokens,
-                sampling_params=new_req_data.sampling_params,
                 lora_request=new_req_data.lora_request,
             )
             req_index = self.req_states.req_id_to_index[req_id]
@@ -461,6 +462,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.sampler.add_request(
                 req_index, prompt_len, new_req_data.sampling_params
             )
+            self.prompt_logprobs_worker.add_request(
+                req_id, req_index, new_req_data.sampling_params
+            )

         if scheduler_output.scheduled_new_reqs:
             self.req_states.apply_staged_writes()
@@ -729,104 +733,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         )
         return sampler_output, num_sampled, num_rejected

-    def compute_prompt_logprobs(
-        self,
-        hidden_states: torch.Tensor,
-        input_batch: InputBatch,
-    ) -> dict[str, LogprobsTensors]:
-        idx_mapping_np = input_batch.idx_mapping_np
-        needs_prompt_logprobs = self.req_states.needs_prompt_logprobs[idx_mapping_np]
-        if not np.any(needs_prompt_logprobs):
-            # No request asks for prompt logprobs.
-            return {}
-
-        prompt_lens = self.req_states.prompt_len[idx_mapping_np]
-        # NOTE(woosuk): -1 because the last prompt token's hidden state is not
-        # needed for prompt logprobs.
-        computed_prefill = self.req_states.num_computed_prefill_tokens[idx_mapping_np]
-        includes_prompt = computed_prefill < prompt_lens - 1
-        # NOTE(woosuk): If the request was resumed after preemption, its prompt
-        # logprobs must have been computed before preemption. Skip.
-        resumed_after_prompt = (
-            prompt_lens < self.req_states.prefill_len.np[idx_mapping_np]
-        )
-        needs_prompt_logprobs &= includes_prompt & ~resumed_after_prompt
-        if not np.any(needs_prompt_logprobs):
-            return {}
-
-        # Just to be safe, clone the input ids.
-        n = input_batch.num_tokens
-        # Shift the input ids by one.
-        token_ids = torch.empty_like(input_batch.input_ids[:n])
-        token_ids[: n - 1] = input_batch.input_ids[1:n]
-        # To avoid out-of-bound access, set the last token id to 0.
-        token_ids[n - 1] = 0
-
-        # Handle chunked prompts.
-        pos_after_step = computed_prefill + input_batch.num_scheduled_tokens
-        is_prompt_chunked = pos_after_step < prompt_lens
-        prefill_token_ids = self.req_states.prefill_token_ids.gpu
-        query_start_loc_np = input_batch.query_start_loc_np
-        for i, req_id in enumerate(input_batch.req_ids):
-            if not needs_prompt_logprobs[i]:
-                continue
-            if not is_prompt_chunked[i]:
-                continue
-            # The prompt is chunked. Get the next prompt token.
-            req_idx = input_batch.idx_mapping_np[i]
-            idx = int(query_start_loc_np[i + 1] - 1)
-            # NOTE(woosuk): This triggers two GPU operations.
-            next_prompt_token = prefill_token_ids[req_idx, pos_after_step[i]]
-            token_ids[idx] = next_prompt_token
-
-        # NOTE(woosuk): We mask out logprobs for negative tokens.
-        prompt_logprobs, prompt_ranks = compute_prompt_logprobs(
-            token_ids,
-            hidden_states[:n],
-            self.model.compute_logits,
-        )
-
-        prompt_token_ids = token_ids.unsqueeze(-1)
-        prompt_logprobs_dict: dict[str, LogprobsTensors] = {}
-        for i, req_id in enumerate(input_batch.req_ids):
-            if not needs_prompt_logprobs[i]:
-                continue
-            start_idx = query_start_loc_np[i]
-            end_idx = query_start_loc_np[i + 1]
-            assert start_idx < end_idx, (
-                f"start_idx ({start_idx}) >= end_idx ({end_idx})"
-            )
-            logprobs = LogprobsTensors(
-                logprob_token_ids=prompt_token_ids[start_idx:end_idx],
-                logprobs=prompt_logprobs[start_idx:end_idx],
-                selected_token_ranks=prompt_ranks[start_idx:end_idx],
-            )
-
-            req_extra_data = self.req_states.extra_data[req_id]
-            prompt_logprobs_list = req_extra_data.in_progress_prompt_logprobs
-            if is_prompt_chunked[i]:
-                # Prompt is chunked. Do not return the logprobs yet.
-                prompt_logprobs_list.append(logprobs)
-                continue
-
-            if prompt_logprobs_list:
-                # Merge the in-progress logprobs.
-                prompt_logprobs_list.append(logprobs)
-                logprobs = LogprobsTensors(
-                    logprob_token_ids=torch.cat(
-                        [x.logprob_token_ids for x in prompt_logprobs_list]
-                    ),
-                    logprobs=torch.cat([x.logprobs for x in prompt_logprobs_list]),
-                    selected_token_ranks=torch.cat(
-                        [x.selected_token_ranks for x in prompt_logprobs_list]
-                    ),
-                )
-                prompt_logprobs_list.clear()
-            prompt_logprobs_dict[req_id] = logprobs
-        return prompt_logprobs_dict
-
     def postprocess(
         self,
         input_batch: InputBatch,
@@ -1002,7 +908,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         sampler_output, num_sampled, num_rejected = self.sample(
             hidden_states, input_batch, grammar_output
         )
-        prompt_logprobs_dict = self.compute_prompt_logprobs(hidden_states, input_batch)
+        prompt_logprobs_dict = self.prompt_logprobs_worker.compute_prompt_logprobs(
+            self.model.compute_logits,
+            hidden_states,
+            input_batch,
+            self.req_states.prefill_token_ids.gpu,
+            self.req_states.num_computed_tokens.gpu,
+            self.req_states.prompt_len,
+            self.req_states.prefill_len.np,
+            self.req_states.num_computed_prefill_tokens,
+        )

         # Prepare the model runner output.
         model_runner_output = ModelRunnerOutput(

View File

@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections.abc import Callable

 import torch
@@ -137,31 +136,3 @@ def compute_topk_logprobs(
         logprobs=logprobs,
         selected_token_ranks=token_ranks,
     )
-
-
-def compute_prompt_logprobs(
-    prompt_token_ids: torch.Tensor,
-    prompt_hidden_states: torch.Tensor,
-    logits_fn: Callable[[torch.Tensor], torch.Tensor],
-) -> tuple[torch.Tensor, torch.Tensor]:
-    # Since materializing the full prompt logits can take too much memory,
-    # we compute it in chunks.
-    CHUNK_SIZE = 1024
-    logprobs = []
-    ranks = []
-    prompt_token_ids = prompt_token_ids.to(torch.int64)
-    for start_idx in range(0, prompt_token_ids.shape[0], CHUNK_SIZE):
-        end_idx = start_idx + CHUNK_SIZE
-        # NOTE(woosuk): logits_fn can be slow because it involves all-gather.
-        prompt_logits = logits_fn(prompt_hidden_states[start_idx:end_idx])
-        prompt_logprobs = compute_topk_logprobs(
-            prompt_logits,
-            0,  # num_logprobs
-            prompt_token_ids[start_idx:end_idx],
-        )
-        logprobs.append(prompt_logprobs.logprobs)
-        ranks.append(prompt_logprobs.selected_token_ranks)
-
-    logprobs = torch.cat(logprobs, dim=0) if len(logprobs) > 1 else logprobs[0]
-    ranks = torch.cat(ranks, dim=0) if len(ranks) > 1 else ranks[0]
-    return logprobs, ranks

View File

@@ -0,0 +1,212 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections.abc import Callable
+
+import numpy as np
+import torch
+
+from vllm.sampling_params import SamplingParams
+from vllm.triton_utils import tl, triton
+from vllm.v1.outputs import LogprobsTensors
+from vllm.v1.worker.gpu.input_batch import InputBatch
+from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
+
+
+class PromptLogprobsWorker:
+    def __init__(self, max_num_reqs: int):
+        self.max_num_reqs = max_num_reqs
+        self.uses_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
+        # req_idx -> list of in-progress LogprobsTensors
+        self.in_progress_prompt_logprobs: dict[str, list[LogprobsTensors]] = {}
+
+    def add_request(self, req_id: str, req_idx: int, sampling_params: SamplingParams):
+        # For now, only support prompt logprobs for the prompt tokens (not top-k).
+        uses_prompt_logprobs = sampling_params.prompt_logprobs is not None
+        if uses_prompt_logprobs:
+            self.uses_prompt_logprobs[req_idx] = True
+            self.in_progress_prompt_logprobs[req_id] = []
+        else:
+            self.uses_prompt_logprobs[req_idx] = False
+
+    def remove_request(self, req_id: str) -> None:
+        self.in_progress_prompt_logprobs.pop(req_id, None)
+
+    def compute_prompt_logprobs(
+        self,
+        logits_fn: Callable[[torch.Tensor], torch.Tensor],
+        hidden_states: torch.Tensor,
+        input_batch: InputBatch,
+        # [max_num_reqs, max_model_len]
+        prefill_token_ids: torch.Tensor,
+        # [max_num_reqs]
+        num_computed_tokens: torch.Tensor,
+        # [max_num_reqs]
+        prompt_lens: np.ndarray,
+        # [max_num_reqs]
+        prefill_lens: np.ndarray,
+        # [max_num_reqs]
+        num_computed_prefill_tokens: np.ndarray,
+    ) -> dict[str, LogprobsTensors]:
+        idx_mapping_np = input_batch.idx_mapping_np
+        needs_prompt_logprobs = self.uses_prompt_logprobs[idx_mapping_np]
+        if not np.any(needs_prompt_logprobs):
+            # Common case: No request asks for prompt logprobs.
+            return {}
+
+        prompt_lens = prompt_lens[idx_mapping_np]
+        # NOTE(woosuk): -1 because the last prompt token's hidden state is not
+        # needed for prompt logprobs.
+        computed_prefill = num_computed_prefill_tokens[idx_mapping_np]
+        includes_prompt = computed_prefill < prompt_lens - 1
+        # NOTE(woosuk): If the request was resumed after preemption, its prompt
+        # logprobs must have been computed before preemption. Skip.
+        resumed_after_prompt = prompt_lens < prefill_lens[idx_mapping_np]
+        needs_prompt_logprobs &= includes_prompt & ~resumed_after_prompt
+        if not np.any(needs_prompt_logprobs):
+            return {}
+
+        # Get the prompt logprobs token_ids.
+        prompt_logprobs_token_ids = get_prompt_logprobs_token_ids(
+            input_batch.num_tokens,
+            input_batch.query_start_loc,
+            input_batch.idx_mapping,
+            num_computed_tokens,
+            prefill_token_ids,
+        )
+        # Compute the prompt logprobs.
+        prompt_logprobs, prompt_ranks = compute_prompt_logprobs_with_chunking(
+            prompt_logprobs_token_ids,
+            hidden_states[: input_batch.num_tokens],
+            logits_fn,
+        )
+
+        pos_after_step = computed_prefill + input_batch.num_scheduled_tokens
+        is_prompt_chunked = pos_after_step < prompt_lens
+        query_start_loc_np = input_batch.query_start_loc_np
+        prompt_token_ids = prompt_logprobs_token_ids.unsqueeze(-1)
+        prompt_logprobs_dict: dict[str, LogprobsTensors] = {}
+        for i, req_id in enumerate(input_batch.req_ids):
+            if not needs_prompt_logprobs[i]:
+                continue
+            start_idx = query_start_loc_np[i]
+            end_idx = query_start_loc_np[i + 1]
+            assert start_idx < end_idx, (
+                f"start_idx ({start_idx}) >= end_idx ({end_idx})"
+            )
+            if not is_prompt_chunked[i]:
+                end_idx -= 1
+            logprobs = LogprobsTensors(
+                logprob_token_ids=prompt_token_ids[start_idx:end_idx],
+                logprobs=prompt_logprobs[start_idx:end_idx],
+                selected_token_ranks=prompt_ranks[start_idx:end_idx],
+            )
+
+            prompt_logprobs_list = self.in_progress_prompt_logprobs[req_id]
+            if is_prompt_chunked[i]:
+                # Prompt is chunked. Do not return the logprobs yet.
+                prompt_logprobs_list.append(logprobs)
+                continue
+
+            if prompt_logprobs_list:
+                # Merge the in-progress logprobs.
+                prompt_logprobs_list.append(logprobs)
+                logprobs = LogprobsTensors(
+                    logprob_token_ids=torch.cat(
+                        [x.logprob_token_ids for x in prompt_logprobs_list]
+                    ),
+                    logprobs=torch.cat([x.logprobs for x in prompt_logprobs_list]),
+                    selected_token_ranks=torch.cat(
+                        [x.selected_token_ranks for x in prompt_logprobs_list]
+                    ),
+                )
+                prompt_logprobs_list.clear()
+            prompt_logprobs_dict[req_id] = logprobs
+        return prompt_logprobs_dict
+
+
+@triton.jit
+def _prompt_logprobs_token_ids_kernel(
+    prompt_logprobs_token_ids_ptr,
+    query_start_loc_ptr,
+    idx_mapping_ptr,
+    num_computed_tokens_ptr,
+    prefill_token_ids_ptr,
+    prefill_token_ids_stride,
+    BLOCK_SIZE: tl.constexpr,
+):
+    batch_idx = tl.program_id(0)
+    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
+    query_start = tl.load(query_start_loc_ptr + batch_idx)
+    query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
+    query_len = query_end - query_start
+    num_computed_tokens = tl.load(num_computed_tokens_ptr + req_state_idx)
+
+    for i in range(0, query_len, BLOCK_SIZE):
+        block = i + tl.arange(0, BLOCK_SIZE)
+        mask = block < query_len
+        # NOTE(woosuk): We should shift the pos by one
+        # because the logprob is computed for the next token.
+        target_pos = num_computed_tokens + 1 + block
+        token_ids = tl.load(
+            prefill_token_ids_ptr
+            + req_state_idx * prefill_token_ids_stride
+            + target_pos,
+            mask=mask,
+        )
+        tl.store(
+            prompt_logprobs_token_ids_ptr + query_start + block, token_ids, mask=mask
+        )
+
+
+def get_prompt_logprobs_token_ids(
+    num_tokens: int,
+    query_start_loc: torch.Tensor,
+    idx_mapping: torch.Tensor,
+    num_computed_tokens: torch.Tensor,
+    prefill_token_ids: torch.Tensor,
+) -> torch.Tensor:
+    token_ids = torch.empty(num_tokens, dtype=torch.int64, device=idx_mapping.device)
+    num_reqs = idx_mapping.shape[0]
+    _prompt_logprobs_token_ids_kernel[(num_reqs,)](
+        token_ids,
+        query_start_loc,
+        idx_mapping,
+        num_computed_tokens,
+        prefill_token_ids,
+        prefill_token_ids.stride(0),
+        BLOCK_SIZE=1024,
+    )
+    return token_ids
+
+
+def compute_prompt_logprobs_with_chunking(
+    prompt_token_ids: torch.Tensor,
+    prompt_hidden_states: torch.Tensor,
+    logits_fn: Callable[[torch.Tensor], torch.Tensor],
+) -> tuple[torch.Tensor, torch.Tensor]:
+    # Since materializing the full prompt logits can take too much memory,
+    # we compute it in chunks.
+    CHUNK_SIZE = 1024
+    logprobs = []
+    ranks = []
+    prompt_token_ids = prompt_token_ids.to(torch.int64)
+    for start_idx in range(0, prompt_token_ids.shape[0], CHUNK_SIZE):
+        end_idx = start_idx + CHUNK_SIZE
+        # NOTE(woosuk): logits_fn can be slow because it involves all-gather.
+        prompt_logits = logits_fn(prompt_hidden_states[start_idx:end_idx])
+        prompt_logprobs = compute_topk_logprobs(
+            prompt_logits,
+            0,  # num_logprobs
+            prompt_token_ids[start_idx:end_idx],
+        )
+        logprobs.append(prompt_logprobs.logprobs)
+        ranks.append(prompt_logprobs.selected_token_ranks)
+
+    logprobs = torch.cat(logprobs, dim=0) if len(logprobs) > 1 else logprobs[0]
+    ranks = torch.cat(ranks, dim=0) if len(ranks) > 1 else ranks[0]
+    return logprobs, ranks

View File

@@ -1,13 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from dataclasses import dataclass, field
+from dataclasses import dataclass

 import numpy as np
 import torch

 from vllm.lora.request import LoRARequest
-from vllm.sampling_params import SamplingParams
-from vllm.v1.outputs import LogprobsTensors
 from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor

 NO_LORA_ID = 0
@@ -76,8 +74,6 @@ class RequestState:
         self.lora_ids = np.zeros(self.max_num_reqs, dtype=np.int32)
         self.lora_ids.fill(NO_LORA_ID)
-
-        self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)

     @property
     def num_reqs(self) -> int:
         return len(self.req_id_to_index)
@@ -88,7 +84,6 @@ class RequestState:
         prompt_len: int,
         prefill_token_ids: list[int],
         num_computed_tokens: int,
-        sampling_params: SamplingParams,
         lora_request: LoRARequest | None,
     ) -> None:
         assert len(self.free_indices) > 0, "No free indices"
@@ -112,10 +107,6 @@ class RequestState:
         else:
             self.lora_ids[req_idx] = NO_LORA_ID

-        # For now, only support prompt logprobs for the prompt tokens.
-        needs_prompt_logprobs = sampling_params.prompt_logprobs is not None
-        self.needs_prompt_logprobs[req_idx] = needs_prompt_logprobs
-
     def apply_staged_writes(self) -> None:
         self.prefill_len.copy_to_uva()
         self.prefill_token_ids.apply_write()
@@ -151,4 +142,3 @@ class RequestState:
 @dataclass
 class ExtraData:
     lora_request: LoRARequest | None
-    in_progress_prompt_logprobs: list[LogprobsTensors] = field(default_factory=list)