[Model Runner V2] Refactor Prompt Logprobs (#32811)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -22,7 +22,6 @@ from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
|||||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
from vllm.v1.outputs import (
|
from vllm.v1.outputs import (
|
||||||
EMPTY_MODEL_RUNNER_OUTPUT,
|
EMPTY_MODEL_RUNNER_OUTPUT,
|
||||||
LogprobsTensors,
|
|
||||||
ModelRunnerOutput,
|
ModelRunnerOutput,
|
||||||
)
|
)
|
||||||
from vllm.v1.worker.gpu.async_utils import AsyncOutput
|
from vllm.v1.worker.gpu.async_utils import AsyncOutput
|
||||||
@@ -51,8 +50,8 @@ from vllm.v1.worker.gpu.input_batch import (
|
|||||||
)
|
)
|
||||||
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
|
from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
|
||||||
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
|
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
|
||||||
from vllm.v1.worker.gpu.sample.logprob import compute_prompt_logprobs
|
|
||||||
from vllm.v1.worker.gpu.sample.output import SamplerOutput
|
from vllm.v1.worker.gpu.sample.output import SamplerOutput
|
||||||
|
from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
|
||||||
from vllm.v1.worker.gpu.sample.sampler import Sampler
|
from vllm.v1.worker.gpu.sample.sampler import Sampler
|
||||||
from vllm.v1.worker.gpu.spec_decode import init_speculator
|
from vllm.v1.worker.gpu.spec_decode import init_speculator
|
||||||
from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
|
from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
|
||||||
@@ -156,6 +155,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
device=self.device,
|
device=self.device,
|
||||||
logprobs_mode=self.model_config.logprobs_mode,
|
logprobs_mode=self.model_config.logprobs_mode,
|
||||||
)
|
)
|
||||||
|
self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs)
|
||||||
|
|
||||||
# CUDA graphs.
|
# CUDA graphs.
|
||||||
self.cudagraph_manager = CudaGraphManager(
|
self.cudagraph_manager = CudaGraphManager(
|
||||||
@@ -416,10 +416,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
self.req_states.remove_request(req_id)
|
self.req_states.remove_request(req_id)
|
||||||
if self.supports_mm_inputs:
|
if self.supports_mm_inputs:
|
||||||
self.encoder_runner.remove_request(req_id)
|
self.encoder_runner.remove_request(req_id)
|
||||||
|
self.prompt_logprobs_worker.remove_request(req_id)
|
||||||
for req_id in scheduler_output.finished_req_ids:
|
for req_id in scheduler_output.finished_req_ids:
|
||||||
self.req_states.remove_request(req_id)
|
self.req_states.remove_request(req_id)
|
||||||
if self.supports_mm_inputs:
|
if self.supports_mm_inputs:
|
||||||
self.encoder_runner.remove_request(req_id)
|
self.encoder_runner.remove_request(req_id)
|
||||||
|
self.prompt_logprobs_worker.remove_request(req_id)
|
||||||
|
|
||||||
def free_states(self, scheduler_output: SchedulerOutput) -> None:
|
def free_states(self, scheduler_output: SchedulerOutput) -> None:
|
||||||
if self.supports_mm_inputs:
|
if self.supports_mm_inputs:
|
||||||
@@ -438,7 +440,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
prompt_len=prompt_len,
|
prompt_len=prompt_len,
|
||||||
prefill_token_ids=new_req_data.prefill_token_ids,
|
prefill_token_ids=new_req_data.prefill_token_ids,
|
||||||
num_computed_tokens=new_req_data.num_computed_tokens,
|
num_computed_tokens=new_req_data.num_computed_tokens,
|
||||||
sampling_params=new_req_data.sampling_params,
|
|
||||||
lora_request=new_req_data.lora_request,
|
lora_request=new_req_data.lora_request,
|
||||||
)
|
)
|
||||||
req_index = self.req_states.req_id_to_index[req_id]
|
req_index = self.req_states.req_id_to_index[req_id]
|
||||||
@@ -461,6 +462,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
self.sampler.add_request(
|
self.sampler.add_request(
|
||||||
req_index, prompt_len, new_req_data.sampling_params
|
req_index, prompt_len, new_req_data.sampling_params
|
||||||
)
|
)
|
||||||
|
self.prompt_logprobs_worker.add_request(
|
||||||
|
req_id, req_index, new_req_data.sampling_params
|
||||||
|
)
|
||||||
|
|
||||||
if scheduler_output.scheduled_new_reqs:
|
if scheduler_output.scheduled_new_reqs:
|
||||||
self.req_states.apply_staged_writes()
|
self.req_states.apply_staged_writes()
|
||||||
@@ -729,104 +733,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
)
|
)
|
||||||
return sampler_output, num_sampled, num_rejected
|
return sampler_output, num_sampled, num_rejected
|
||||||
|
|
||||||
def compute_prompt_logprobs(
    self,
    hidden_states: torch.Tensor,
    input_batch: InputBatch,
) -> dict[str, LogprobsTensors]:
    """Compute prompt logprobs for the requests in the batch that need them.

    Returns a dict mapping request id to its LogprobsTensors. A request only
    appears in the result once its prompt is no longer chunked after this
    step; until then its partial results are appended to
    ``self.req_states.extra_data[req_id].in_progress_prompt_logprobs`` and
    merged (then cleared) on the step that finishes the prompt.
    """
    req_states = self.req_states
    idx_mapping_np = input_batch.idx_mapping_np

    needs_prompt_logprobs = req_states.needs_prompt_logprobs[idx_mapping_np]
    if not needs_prompt_logprobs.any():
        # Nobody in this batch asked for prompt logprobs.
        return {}

    prompt_lens = req_states.prompt_len[idx_mapping_np]
    computed_prefill = req_states.num_computed_prefill_tokens[idx_mapping_np]
    prefill_lens = req_states.prefill_len.np[idx_mapping_np]
    # NOTE(woosuk): -1 because the last prompt token's hidden state is not
    # needed for prompt logprobs.
    still_in_prompt = computed_prefill < prompt_lens - 1
    # NOTE(woosuk): A request resumed after preemption (prefill longer than
    # the prompt) must already have had its prompt logprobs computed. Skip.
    not_resumed = prompt_lens >= prefill_lens
    needs_prompt_logprobs &= still_in_prompt & not_resumed
    if not needs_prompt_logprobs.any():
        return {}

    # Build the "next token" targets: the input ids shifted left by one.
    n = input_batch.num_tokens
    input_ids = input_batch.input_ids
    token_ids = torch.empty_like(input_ids[:n])
    token_ids[: n - 1] = input_ids[1:n]
    # To avoid out-of-bound access, set the last token id to 0.
    token_ids[n - 1] = 0

    # Handle chunked prompts: the target of a chunk's last position is the
    # next prompt token, which is not part of this step's input ids.
    pos_after_step = computed_prefill + input_batch.num_scheduled_tokens
    is_prompt_chunked = pos_after_step < prompt_lens
    prefill_token_ids = req_states.prefill_token_ids.gpu
    query_start_loc_np = input_batch.query_start_loc_np
    for i in range(len(input_batch.req_ids)):
        if not (needs_prompt_logprobs[i] and is_prompt_chunked[i]):
            continue
        req_idx = input_batch.idx_mapping_np[i]
        last_pos = int(query_start_loc_np[i + 1] - 1)
        # NOTE(woosuk): This triggers two GPU operations.
        token_ids[last_pos] = prefill_token_ids[req_idx, pos_after_step[i]]

    # NOTE(woosuk): We mask out logprobs for negative tokens.
    prompt_logprobs, prompt_ranks = compute_prompt_logprobs(
        token_ids,
        hidden_states[:n],
        self.model.compute_logits,
    )
    prompt_token_ids = token_ids.unsqueeze(-1)

    results: dict[str, LogprobsTensors] = {}
    for i, req_id in enumerate(input_batch.req_ids):
        if not needs_prompt_logprobs[i]:
            continue

        start_idx = query_start_loc_np[i]
        end_idx = query_start_loc_np[i + 1]
        assert start_idx < end_idx, (
            f"start_idx ({start_idx}) >= end_idx ({end_idx})"
        )
        window = slice(start_idx, end_idx)
        logprobs = LogprobsTensors(
            logprob_token_ids=prompt_token_ids[window],
            logprobs=prompt_logprobs[window],
            selected_token_ranks=prompt_ranks[window],
        )

        pending = req_states.extra_data[req_id].in_progress_prompt_logprobs
        if is_prompt_chunked[i]:
            # Prompt is chunked. Do not return the logprobs yet.
            pending.append(logprobs)
            continue

        if pending:
            # Merge the earlier chunks with this final one, in order.
            pieces = [*pending, logprobs]
            logprobs = LogprobsTensors(
                logprob_token_ids=torch.cat(
                    [p.logprob_token_ids for p in pieces]
                ),
                logprobs=torch.cat([p.logprobs for p in pieces]),
                selected_token_ranks=torch.cat(
                    [p.selected_token_ranks for p in pieces]
                ),
            )
            pending.clear()

        results[req_id] = logprobs
    return results
|
|
||||||
|
|
||||||
def postprocess(
|
def postprocess(
|
||||||
self,
|
self,
|
||||||
input_batch: InputBatch,
|
input_batch: InputBatch,
|
||||||
@@ -1002,7 +908,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
sampler_output, num_sampled, num_rejected = self.sample(
|
sampler_output, num_sampled, num_rejected = self.sample(
|
||||||
hidden_states, input_batch, grammar_output
|
hidden_states, input_batch, grammar_output
|
||||||
)
|
)
|
||||||
prompt_logprobs_dict = self.compute_prompt_logprobs(hidden_states, input_batch)
|
prompt_logprobs_dict = self.prompt_logprobs_worker.compute_prompt_logprobs(
|
||||||
|
self.model.compute_logits,
|
||||||
|
hidden_states,
|
||||||
|
input_batch,
|
||||||
|
self.req_states.prefill_token_ids.gpu,
|
||||||
|
self.req_states.num_computed_tokens.gpu,
|
||||||
|
self.req_states.prompt_len,
|
||||||
|
self.req_states.prefill_len.np,
|
||||||
|
self.req_states.num_computed_prefill_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
# Prepare the model runner output.
|
# Prepare the model runner output.
|
||||||
model_runner_output = ModelRunnerOutput(
|
model_runner_output = ModelRunnerOutput(
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from collections.abc import Callable
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -137,31 +136,3 @@ def compute_topk_logprobs(
|
|||||||
logprobs=logprobs,
|
logprobs=logprobs,
|
||||||
selected_token_ranks=token_ranks,
|
selected_token_ranks=token_ranks,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def compute_prompt_logprobs(
    prompt_token_ids: torch.Tensor,
    prompt_hidden_states: torch.Tensor,
    logits_fn: Callable[[torch.Tensor], torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return ``(logprobs, ranks)`` of the given tokens under the model.

    The logits are produced by ``logits_fn`` on fixed-size slices of
    ``prompt_hidden_states`` rather than all at once, because materializing
    the full prompt logits can take too much memory.
    """
    CHUNK_SIZE = 1024

    def _join(parts):
        # A single chunk is returned as-is; several are concatenated on dim 0.
        if len(parts) > 1:
            return torch.cat(parts, dim=0)
        return parts[0]

    token_ids = prompt_token_ids.to(torch.int64)
    num_tokens = token_ids.shape[0]

    logprob_parts = []
    rank_parts = []
    for lo in range(0, num_tokens, CHUNK_SIZE):
        window = slice(lo, lo + CHUNK_SIZE)
        # NOTE(woosuk): logits_fn can be slow because it involves all-gather.
        chunk_logits = logits_fn(prompt_hidden_states[window])
        chunk_out = compute_topk_logprobs(
            chunk_logits,
            0,  # num_logprobs
            token_ids[window],
        )
        logprob_parts.append(chunk_out.logprobs)
        rank_parts.append(chunk_out.selected_token_ranks)

    return _join(logprob_parts), _join(rank_parts)
|
|
||||||
|
|||||||
212
vllm/v1/worker/gpu/sample/prompt_logprob.py
Normal file
212
vllm/v1/worker/gpu/sample/prompt_logprob.py
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.sampling_params import SamplingParams
|
||||||
|
from vllm.triton_utils import tl, triton
|
||||||
|
from vllm.v1.outputs import LogprobsTensors
|
||||||
|
from vllm.v1.worker.gpu.input_batch import InputBatch
|
||||||
|
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
|
||||||
|
|
||||||
|
|
||||||
|
class PromptLogprobsWorker:
    """Tracks requests that asked for prompt logprobs and computes them.

    Holds a per-slot flag array (indexed by request-state index) and, for
    each flagged request id, the list of partial results accumulated while
    the prompt is still being processed in chunks.
    """

    def __init__(self, max_num_reqs: int):
        self.max_num_reqs = max_num_reqs

        # Indexed by request-state index: True if that slot's request asked
        # for prompt logprobs.
        self.uses_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
        # req_id -> list of in-progress LogprobsTensors
        self.in_progress_prompt_logprobs: dict[str, list[LogprobsTensors]] = {}

    def add_request(self, req_id: str, req_idx: int, sampling_params: SamplingParams):
        """Record whether the request placed in slot ``req_idx`` wants prompt logprobs."""
        # For now, only support prompt logprobs for the prompt tokens (not top-k).
        wants_prompt_logprobs = sampling_params.prompt_logprobs is not None
        # Always overwrite the flag so a reused slot never keeps a stale value.
        self.uses_prompt_logprobs[req_idx] = wants_prompt_logprobs
        if wants_prompt_logprobs:
            self.in_progress_prompt_logprobs[req_id] = []

    def remove_request(self, req_id: str) -> None:
        """Drop any partial results held for ``req_id`` (no-op if there are none)."""
        self.in_progress_prompt_logprobs.pop(req_id, None)

    def compute_prompt_logprobs(
        self,
        logits_fn: Callable[[torch.Tensor], torch.Tensor],
        hidden_states: torch.Tensor,
        input_batch: InputBatch,
        # [max_num_reqs, max_model_len]
        prefill_token_ids: torch.Tensor,
        # [max_num_reqs]
        num_computed_tokens: torch.Tensor,
        # [max_num_reqs]
        prompt_lens: np.ndarray,
        # [max_num_reqs]
        prefill_lens: np.ndarray,
        # [max_num_reqs]
        num_computed_prefill_tokens: np.ndarray,
    ) -> dict[str, LogprobsTensors]:
        """Compute prompt logprobs for the requests in the batch that need them.

        Returns a dict mapping request id to its LogprobsTensors. A request
        only appears in the result once its prompt is no longer chunked after
        this step; until then its partial results are kept in
        ``self.in_progress_prompt_logprobs`` and merged (then cleared) on the
        step that finishes the prompt.
        """
        idx_mapping_np = input_batch.idx_mapping_np
        needs_prompt_logprobs = self.uses_prompt_logprobs[idx_mapping_np]
        if not needs_prompt_logprobs.any():
            # Common case: No request asks for prompt logprobs.
            return {}

        # Gather the per-slot state for the requests in this batch.
        batch_prompt_lens = prompt_lens[idx_mapping_np]
        batch_prefill_lens = prefill_lens[idx_mapping_np]
        computed_prefill = num_computed_prefill_tokens[idx_mapping_np]

        # NOTE(woosuk): -1 because the last prompt token's hidden state is not
        # needed for prompt logprobs.
        still_in_prompt = computed_prefill < batch_prompt_lens - 1
        # NOTE(woosuk): A request resumed after preemption (prefill longer
        # than the prompt) must already have had its prompt logprobs
        # computed. Skip.
        not_resumed = batch_prompt_lens >= batch_prefill_lens
        needs_prompt_logprobs &= still_in_prompt & not_resumed
        if not needs_prompt_logprobs.any():
            return {}

        # A prompt is "chunked" if it still has tokens left after this step.
        pos_after_step = computed_prefill + input_batch.num_scheduled_tokens
        is_prompt_chunked = pos_after_step < batch_prompt_lens

        # Target token (the *next* token) for every position in the batch.
        num_tokens = input_batch.num_tokens
        prompt_logprobs_token_ids = get_prompt_logprobs_token_ids(
            num_tokens,
            input_batch.query_start_loc,
            input_batch.idx_mapping,
            num_computed_tokens,
            prefill_token_ids,
        )
        # Logprob and rank of each target token.
        prompt_logprobs, prompt_ranks = compute_prompt_logprobs_with_chunking(
            prompt_logprobs_token_ids,
            hidden_states[:num_tokens],
            logits_fn,
        )
        prompt_token_ids = prompt_logprobs_token_ids.unsqueeze(-1)

        query_start_loc_np = input_batch.query_start_loc_np
        results: dict[str, LogprobsTensors] = {}
        for i, req_id in enumerate(input_batch.req_ids):
            if not needs_prompt_logprobs[i]:
                continue

            chunked = is_prompt_chunked[i]
            start_idx = query_start_loc_np[i]
            end_idx = query_start_loc_np[i + 1]
            assert start_idx < end_idx, (
                f"start_idx ({start_idx}) >= end_idx ({end_idx})"
            )
            # When the prompt finishes in this step the last position is
            # dropped. NOTE(review): presumably because that position
            # predicts the first generated token rather than a prompt
            # token -- confirm.
            stop = end_idx if chunked else end_idx - 1
            window = slice(start_idx, stop)
            logprobs = LogprobsTensors(
                logprob_token_ids=prompt_token_ids[window],
                logprobs=prompt_logprobs[window],
                selected_token_ranks=prompt_ranks[window],
            )

            pending = self.in_progress_prompt_logprobs[req_id]
            if chunked:
                # Prompt is chunked. Do not return the logprobs yet.
                pending.append(logprobs)
                continue

            if pending:
                # Merge the earlier chunks with this final one, in order.
                pieces = [*pending, logprobs]
                logprobs = LogprobsTensors(
                    logprob_token_ids=torch.cat(
                        [p.logprob_token_ids for p in pieces]
                    ),
                    logprobs=torch.cat([p.logprobs for p in pieces]),
                    selected_token_ranks=torch.cat(
                        [p.selected_token_ranks for p in pieces]
                    ),
                )
                pending.clear()

            results[req_id] = logprobs
        return results
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
def _prompt_logprobs_token_ids_kernel(
    prompt_logprobs_token_ids_ptr,
    query_start_loc_ptr,
    idx_mapping_ptr,
    num_computed_tokens_ptr,
    prefill_token_ids_ptr,
    prefill_token_ids_stride,
    BLOCK_SIZE: tl.constexpr,
):
    # One program per request in the batch. For every query position of the
    # request, copy the token that FOLLOWS it in the request's row of the
    # prefill-token table into the flat output buffer.
    batch_idx = tl.program_id(0)
    # Map the batch position to the request's row in the per-request tables.
    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)

    # This request's tokens occupy [query_start, query_end) in the output.
    query_start = tl.load(query_start_loc_ptr + batch_idx)
    query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
    query_len = query_end - query_start

    num_computed_tokens = tl.load(num_computed_tokens_ptr + req_state_idx)

    # Widen the row index before multiplying by the row stride: with a large
    # [max_num_reqs, max_model_len] table the product can exceed the 32-bit
    # range and wrap if computed in 32-bit arithmetic.
    # NOTE(review): assumes idx_mapping holds 32-bit indices; the cast is a
    # no-op if it is already 64-bit.
    row_offset = req_state_idx.to(tl.int64) * prefill_token_ids_stride
    row_ptr = prefill_token_ids_ptr + row_offset

    for i in range(0, query_len, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        # Lanes past the end of the query neither load nor store.
        mask = block < query_len
        # NOTE(woosuk): We should shift the pos by one
        # because the logprob is computed for the next token.
        target_pos = num_computed_tokens + 1 + block
        token_ids = tl.load(row_ptr + target_pos, mask=mask)
        tl.store(
            prompt_logprobs_token_ids_ptr + query_start + block, token_ids, mask=mask
        )
|
||||||
|
|
||||||
|
|
||||||
|
def get_prompt_logprobs_token_ids(
    num_tokens: int,
    query_start_loc: torch.Tensor,
    idx_mapping: torch.Tensor,
    num_computed_tokens: torch.Tensor,
    prefill_token_ids: torch.Tensor,
) -> torch.Tensor:
    """Gather the target token id for every position in the batch.

    Launches ``_prompt_logprobs_token_ids_kernel`` with one program per entry
    of ``idx_mapping`` and returns a new int64 tensor of length
    ``num_tokens`` on ``idx_mapping``'s device, filled by that kernel.
    """
    # One kernel program per request in the batch.
    grid = (idx_mapping.shape[0],)
    # Uninitialized on allocation; the kernel writes the contents.
    out = torch.empty(num_tokens, dtype=torch.int64, device=idx_mapping.device)
    row_stride = prefill_token_ids.stride(0)
    _prompt_logprobs_token_ids_kernel[grid](
        out,
        query_start_loc,
        idx_mapping,
        num_computed_tokens,
        prefill_token_ids,
        row_stride,
        BLOCK_SIZE=1024,
    )
    return out
|
||||||
|
|
||||||
|
|
||||||
|
def compute_prompt_logprobs_with_chunking(
    prompt_token_ids: torch.Tensor,
    prompt_hidden_states: torch.Tensor,
    logits_fn: Callable[[torch.Tensor], torch.Tensor],
    *,
    chunk_size: int = 1024,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return ``(logprobs, ranks)`` of the given tokens under the model.

    Since materializing the full prompt logits can take too much memory, the
    logits are computed ``chunk_size`` positions at a time and the per-chunk
    results are concatenated along dim 0.

    Args:
        prompt_token_ids: Target token id per position; converted to int64.
        prompt_hidden_states: Hidden states, sliced along dim 0 with the same
            index ranges as ``prompt_token_ids``.
        logits_fn: Maps a slice of hidden states to logits.
        chunk_size: Maximum number of positions passed to ``logits_fn`` at
            once. Defaults to 1024.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
        IndexError: If ``prompt_token_ids`` is empty (there is no chunk to
            return).
    """
    if chunk_size <= 0:
        # A zero step makes range() fail and a negative one silently yields
        # no chunks; reject both up front with a clear message.
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    logprobs = []
    ranks = []
    prompt_token_ids = prompt_token_ids.to(torch.int64)
    for start_idx in range(0, prompt_token_ids.shape[0], chunk_size):
        # Slicing clamps at the end, so the last chunk may be shorter.
        end_idx = start_idx + chunk_size
        # NOTE(woosuk): logits_fn can be slow because it involves all-gather.
        prompt_logits = logits_fn(prompt_hidden_states[start_idx:end_idx])
        prompt_logprobs = compute_topk_logprobs(
            prompt_logits,
            0,  # num_logprobs
            prompt_token_ids[start_idx:end_idx],
        )
        logprobs.append(prompt_logprobs.logprobs)
        ranks.append(prompt_logprobs.selected_token_ranks)

    # Skip the concatenation when there is only one chunk.
    logprobs = torch.cat(logprobs, dim=0) if len(logprobs) > 1 else logprobs[0]
    ranks = torch.cat(ranks, dim=0) if len(ranks) > 1 else ranks[0]
    return logprobs, ranks
|
||||||
@@ -1,13 +1,11 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.sampling_params import SamplingParams
|
|
||||||
from vllm.v1.outputs import LogprobsTensors
|
|
||||||
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
|
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
|
||||||
|
|
||||||
NO_LORA_ID = 0
|
NO_LORA_ID = 0
|
||||||
@@ -76,8 +74,6 @@ class RequestState:
|
|||||||
self.lora_ids = np.zeros(self.max_num_reqs, dtype=np.int32)
|
self.lora_ids = np.zeros(self.max_num_reqs, dtype=np.int32)
|
||||||
self.lora_ids.fill(NO_LORA_ID)
|
self.lora_ids.fill(NO_LORA_ID)
|
||||||
|
|
||||||
self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_reqs(self) -> int:
|
def num_reqs(self) -> int:
|
||||||
return len(self.req_id_to_index)
|
return len(self.req_id_to_index)
|
||||||
@@ -88,7 +84,6 @@ class RequestState:
|
|||||||
prompt_len: int,
|
prompt_len: int,
|
||||||
prefill_token_ids: list[int],
|
prefill_token_ids: list[int],
|
||||||
num_computed_tokens: int,
|
num_computed_tokens: int,
|
||||||
sampling_params: SamplingParams,
|
|
||||||
lora_request: LoRARequest | None,
|
lora_request: LoRARequest | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert len(self.free_indices) > 0, "No free indices"
|
assert len(self.free_indices) > 0, "No free indices"
|
||||||
@@ -112,10 +107,6 @@ class RequestState:
|
|||||||
else:
|
else:
|
||||||
self.lora_ids[req_idx] = NO_LORA_ID
|
self.lora_ids[req_idx] = NO_LORA_ID
|
||||||
|
|
||||||
# For now, only support prompt logprobs for the prompt tokens.
|
|
||||||
needs_prompt_logprobs = sampling_params.prompt_logprobs is not None
|
|
||||||
self.needs_prompt_logprobs[req_idx] = needs_prompt_logprobs
|
|
||||||
|
|
||||||
def apply_staged_writes(self) -> None:
|
def apply_staged_writes(self) -> None:
|
||||||
self.prefill_len.copy_to_uva()
|
self.prefill_len.copy_to_uva()
|
||||||
self.prefill_token_ids.apply_write()
|
self.prefill_token_ids.apply_write()
|
||||||
@@ -151,4 +142,3 @@ class RequestState:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ExtraData:
|
class ExtraData:
|
||||||
lora_request: LoRARequest | None
|
lora_request: LoRARequest | None
|
||||||
in_progress_prompt_logprobs: list[LogprobsTensors] = field(default_factory=list)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user