diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index 4aad46385..7a6b1732b 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -22,7 +22,6 @@ from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.outputs import (
     EMPTY_MODEL_RUNNER_OUTPUT,
-    LogprobsTensors,
     ModelRunnerOutput,
 )
 from vllm.v1.worker.gpu.async_utils import AsyncOutput
@@ -51,8 +50,8 @@ from vllm.v1.worker.gpu.input_batch import (
 )
 from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
 from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
-from vllm.v1.worker.gpu.sample.logprob import compute_prompt_logprobs
 from vllm.v1.worker.gpu.sample.output import SamplerOutput
+from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
 from vllm.v1.worker.gpu.sample.sampler import Sampler
 from vllm.v1.worker.gpu.spec_decode import init_speculator
 from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
@@ -156,6 +155,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             device=self.device,
             logprobs_mode=self.model_config.logprobs_mode,
         )
+        self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs)
 
         # CUDA graphs.
         self.cudagraph_manager = CudaGraphManager(
@@ -416,10 +416,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.req_states.remove_request(req_id)
             if self.supports_mm_inputs:
                 self.encoder_runner.remove_request(req_id)
+            self.prompt_logprobs_worker.remove_request(req_id)
         for req_id in scheduler_output.finished_req_ids:
             self.req_states.remove_request(req_id)
             if self.supports_mm_inputs:
                 self.encoder_runner.remove_request(req_id)
+            self.prompt_logprobs_worker.remove_request(req_id)
 
     def free_states(self, scheduler_output: SchedulerOutput) -> None:
         if self.supports_mm_inputs:
@@ -438,7 +440,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 prompt_len=prompt_len,
                 prefill_token_ids=new_req_data.prefill_token_ids,
                 num_computed_tokens=new_req_data.num_computed_tokens,
-                sampling_params=new_req_data.sampling_params,
                 lora_request=new_req_data.lora_request,
             )
             req_index = self.req_states.req_id_to_index[req_id]
@@ -461,6 +462,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.sampler.add_request(
                 req_index, prompt_len, new_req_data.sampling_params
             )
+            self.prompt_logprobs_worker.add_request(
+                req_id, req_index, new_req_data.sampling_params
+            )
 
         if scheduler_output.scheduled_new_reqs:
             self.req_states.apply_staged_writes()
@@ -729,104 +733,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         )
         return sampler_output, num_sampled, num_rejected
 
-    def compute_prompt_logprobs(
-        self,
-        hidden_states: torch.Tensor,
-        input_batch: InputBatch,
-    ) -> dict[str, LogprobsTensors]:
-        idx_mapping_np = input_batch.idx_mapping_np
-        needs_prompt_logprobs = self.req_states.needs_prompt_logprobs[idx_mapping_np]
-        if not np.any(needs_prompt_logprobs):
-            # No request asks for prompt logprobs.
-            return {}
-
-        prompt_lens = self.req_states.prompt_len[idx_mapping_np]
-        # NOTE(woosuk): -1 because the last prompt token's hidden state is not
-        # needed for prompt logprobs.
-        computed_prefill = self.req_states.num_computed_prefill_tokens[idx_mapping_np]
-        includes_prompt = computed_prefill < prompt_lens - 1
-        # NOTE(woosuk): If the request was resumed after preemption, its prompt
-        # logprobs must have been computed before preemption. Skip.
-        resumed_after_prompt = (
-            prompt_lens < self.req_states.prefill_len.np[idx_mapping_np]
-        )
-        needs_prompt_logprobs &= includes_prompt & ~resumed_after_prompt
-        if not np.any(needs_prompt_logprobs):
-            return {}
-
-        # Just to be safe, clone the input ids.
-        n = input_batch.num_tokens
-        # Shift the input ids by one.
-        token_ids = torch.empty_like(input_batch.input_ids[:n])
-        token_ids[: n - 1] = input_batch.input_ids[1:n]
-        # To avoid out-of-bound access, set the last token id to 0.
-        token_ids[n - 1] = 0
-
-        # Handle chunked prompts.
-        pos_after_step = computed_prefill + input_batch.num_scheduled_tokens
-        is_prompt_chunked = pos_after_step < prompt_lens
-        prefill_token_ids = self.req_states.prefill_token_ids.gpu
-        query_start_loc_np = input_batch.query_start_loc_np
-        for i, req_id in enumerate(input_batch.req_ids):
-            if not needs_prompt_logprobs[i]:
-                continue
-            if not is_prompt_chunked[i]:
-                continue
-            # The prompt is chunked. Get the next prompt token.
-            req_idx = input_batch.idx_mapping_np[i]
-            idx = int(query_start_loc_np[i + 1] - 1)
-            # NOTE(woosuk): This triggers two GPU operations.
-            next_prompt_token = prefill_token_ids[req_idx, pos_after_step[i]]
-            token_ids[idx] = next_prompt_token
-
-        # NOTE(woosuk): We mask out logprobs for negative tokens.
-        prompt_logprobs, prompt_ranks = compute_prompt_logprobs(
-            token_ids,
-            hidden_states[:n],
-            self.model.compute_logits,
-        )
-
-        prompt_token_ids = token_ids.unsqueeze(-1)
-        prompt_logprobs_dict: dict[str, LogprobsTensors] = {}
-        for i, req_id in enumerate(input_batch.req_ids):
-            if not needs_prompt_logprobs[i]:
-                continue
-
-            start_idx = query_start_loc_np[i]
-            end_idx = query_start_loc_np[i + 1]
-            assert start_idx < end_idx, (
-                f"start_idx ({start_idx}) >= end_idx ({end_idx})"
-            )
-            logprobs = LogprobsTensors(
-                logprob_token_ids=prompt_token_ids[start_idx:end_idx],
-                logprobs=prompt_logprobs[start_idx:end_idx],
-                selected_token_ranks=prompt_ranks[start_idx:end_idx],
-            )
-
-            req_extra_data = self.req_states.extra_data[req_id]
-            prompt_logprobs_list = req_extra_data.in_progress_prompt_logprobs
-            if is_prompt_chunked[i]:
-                # Prompt is chunked. Do not return the logprobs yet.
-                prompt_logprobs_list.append(logprobs)
-                continue
-
-            if prompt_logprobs_list:
-                # Merge the in-progress logprobs.
-                prompt_logprobs_list.append(logprobs)
-                logprobs = LogprobsTensors(
-                    logprob_token_ids=torch.cat(
-                        [x.logprob_token_ids for x in prompt_logprobs_list]
-                    ),
-                    logprobs=torch.cat([x.logprobs for x in prompt_logprobs_list]),
-                    selected_token_ranks=torch.cat(
-                        [x.selected_token_ranks for x in prompt_logprobs_list]
-                    ),
-                )
-                prompt_logprobs_list.clear()
-
-            prompt_logprobs_dict[req_id] = logprobs
-        return prompt_logprobs_dict
-
     def postprocess(
         self,
         input_batch: InputBatch,
@@ -1002,7 +908,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         sampler_output, num_sampled, num_rejected = self.sample(
             hidden_states, input_batch, grammar_output
         )
-        prompt_logprobs_dict = self.compute_prompt_logprobs(hidden_states, input_batch)
+        prompt_logprobs_dict = self.prompt_logprobs_worker.compute_prompt_logprobs(
+            self.model.compute_logits,
+            hidden_states,
+            input_batch,
+            self.req_states.prefill_token_ids.gpu,
+            self.req_states.num_computed_tokens.gpu,
+            self.req_states.prompt_len,
+            self.req_states.prefill_len.np,
+            self.req_states.num_computed_prefill_tokens,
+        )
 
         # Prepare the model runner output.
         model_runner_output = ModelRunnerOutput(
diff --git a/vllm/v1/worker/gpu/sample/logprob.py b/vllm/v1/worker/gpu/sample/logprob.py
index 25448b387..3c8f89f21 100644
--- a/vllm/v1/worker/gpu/sample/logprob.py
+++ b/vllm/v1/worker/gpu/sample/logprob.py
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections.abc import Callable
 
 import torch
 
@@ -137,31 +136,3 @@
         logprobs=logprobs,
         selected_token_ranks=token_ranks,
     )
-
-
-def compute_prompt_logprobs(
-    prompt_token_ids: torch.Tensor,
-    prompt_hidden_states: torch.Tensor,
-    logits_fn: Callable[[torch.Tensor], torch.Tensor],
-) -> tuple[torch.Tensor, torch.Tensor]:
-    # Since materializing the full prompt logits can take too much memory,
-    # we compute it in chunks.
-    CHUNK_SIZE = 1024
-    logprobs = []
-    ranks = []
-    prompt_token_ids = prompt_token_ids.to(torch.int64)
-    for start_idx in range(0, prompt_token_ids.shape[0], CHUNK_SIZE):
-        end_idx = start_idx + CHUNK_SIZE
-        # NOTE(woosuk): logits_fn can be slow because it involves all-gather.
-        prompt_logits = logits_fn(prompt_hidden_states[start_idx:end_idx])
-        prompt_logprobs = compute_topk_logprobs(
-            prompt_logits,
-            0,  # num_logprobs
-            prompt_token_ids[start_idx:end_idx],
-        )
-        logprobs.append(prompt_logprobs.logprobs)
-        ranks.append(prompt_logprobs.selected_token_ranks)
-
-    logprobs = torch.cat(logprobs, dim=0) if len(logprobs) > 1 else logprobs[0]
-    ranks = torch.cat(ranks, dim=0) if len(ranks) > 1 else ranks[0]
-    return logprobs, ranks
diff --git a/vllm/v1/worker/gpu/sample/prompt_logprob.py b/vllm/v1/worker/gpu/sample/prompt_logprob.py
new file mode 100644
index 000000000..896541af1
--- /dev/null
+++ b/vllm/v1/worker/gpu/sample/prompt_logprob.py
@@ -0,0 +1,212 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections.abc import Callable
+
+import numpy as np
+import torch
+
+from vllm.sampling_params import SamplingParams
+from vllm.triton_utils import tl, triton
+from vllm.v1.outputs import LogprobsTensors
+from vllm.v1.worker.gpu.input_batch import InputBatch
+from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
+
+
+class PromptLogprobsWorker:
+    def __init__(self, max_num_reqs: int):
+        self.max_num_reqs = max_num_reqs
+
+        self.uses_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
+        # req_id -> list of in-progress LogprobsTensors
+        self.in_progress_prompt_logprobs: dict[str, list[LogprobsTensors]] = {}
+
+    def add_request(self, req_id: str, req_idx: int, sampling_params: SamplingParams):
+        # For now, only support prompt logprobs for the prompt tokens (not top-k).
+        uses_prompt_logprobs = sampling_params.prompt_logprobs is not None
+        if uses_prompt_logprobs:
+            self.uses_prompt_logprobs[req_idx] = True
+            self.in_progress_prompt_logprobs[req_id] = []
+        else:
+            self.uses_prompt_logprobs[req_idx] = False
+
+    def remove_request(self, req_id: str) -> None:
+        self.in_progress_prompt_logprobs.pop(req_id, None)
+
+    def compute_prompt_logprobs(
+        self,
+        logits_fn: Callable[[torch.Tensor], torch.Tensor],
+        hidden_states: torch.Tensor,
+        input_batch: InputBatch,
+        # [max_num_reqs, max_model_len]
+        prefill_token_ids: torch.Tensor,
+        # [max_num_reqs]
+        num_computed_tokens: torch.Tensor,
+        # [max_num_reqs]
+        prompt_lens: np.ndarray,
+        # [max_num_reqs]
+        prefill_lens: np.ndarray,
+        # [max_num_reqs]
+        num_computed_prefill_tokens: np.ndarray,
+    ) -> dict[str, LogprobsTensors]:
+        idx_mapping_np = input_batch.idx_mapping_np
+        needs_prompt_logprobs = self.uses_prompt_logprobs[idx_mapping_np]
+        if not np.any(needs_prompt_logprobs):
+            # Common case: No request asks for prompt logprobs.
+            return {}
+
+        prompt_lens = prompt_lens[idx_mapping_np]
+        # NOTE(woosuk): -1 because the last prompt token's hidden state is not
+        # needed for prompt logprobs.
+        computed_prefill = num_computed_prefill_tokens[idx_mapping_np]
+        includes_prompt = computed_prefill < prompt_lens - 1
+        # NOTE(woosuk): If the request was resumed after preemption, its prompt
+        # logprobs must have been computed before preemption. Skip.
+        resumed_after_prompt = prompt_lens < prefill_lens[idx_mapping_np]
+        needs_prompt_logprobs &= includes_prompt & ~resumed_after_prompt
+        if not np.any(needs_prompt_logprobs):
+            return {}
+
+        # Get the prompt logprobs token_ids.
+        prompt_logprobs_token_ids = get_prompt_logprobs_token_ids(
+            input_batch.num_tokens,
+            input_batch.query_start_loc,
+            input_batch.idx_mapping,
+            num_computed_tokens,
+            prefill_token_ids,
+        )
+        # Compute the prompt logprobs.
+        prompt_logprobs, prompt_ranks = compute_prompt_logprobs_with_chunking(
+            prompt_logprobs_token_ids,
+            hidden_states[: input_batch.num_tokens],
+            logits_fn,
+        )
+
+        pos_after_step = computed_prefill + input_batch.num_scheduled_tokens
+        is_prompt_chunked = pos_after_step < prompt_lens
+
+        query_start_loc_np = input_batch.query_start_loc_np
+        prompt_token_ids = prompt_logprobs_token_ids.unsqueeze(-1)
+        prompt_logprobs_dict: dict[str, LogprobsTensors] = {}
+        for i, req_id in enumerate(input_batch.req_ids):
+            if not needs_prompt_logprobs[i]:
+                continue
+
+            start_idx = query_start_loc_np[i]
+            end_idx = query_start_loc_np[i + 1]
+            assert start_idx < end_idx, (
+                f"start_idx ({start_idx}) >= end_idx ({end_idx})"
+            )
+            if not is_prompt_chunked[i]:
+                end_idx -= 1
+            logprobs = LogprobsTensors(
+                logprob_token_ids=prompt_token_ids[start_idx:end_idx],
+                logprobs=prompt_logprobs[start_idx:end_idx],
+                selected_token_ranks=prompt_ranks[start_idx:end_idx],
+            )
+
+            prompt_logprobs_list = self.in_progress_prompt_logprobs[req_id]
+            if is_prompt_chunked[i]:
+                # Prompt is chunked. Do not return the logprobs yet.
+                prompt_logprobs_list.append(logprobs)
+                continue
+
+            if prompt_logprobs_list:
+                # Merge the in-progress logprobs.
+                prompt_logprobs_list.append(logprobs)
+                logprobs = LogprobsTensors(
+                    logprob_token_ids=torch.cat(
+                        [x.logprob_token_ids for x in prompt_logprobs_list]
+                    ),
+                    logprobs=torch.cat([x.logprobs for x in prompt_logprobs_list]),
+                    selected_token_ranks=torch.cat(
+                        [x.selected_token_ranks for x in prompt_logprobs_list]
+                    ),
+                )
+                prompt_logprobs_list.clear()
+
+            prompt_logprobs_dict[req_id] = logprobs
+        return prompt_logprobs_dict
+
+
+@triton.jit
+def _prompt_logprobs_token_ids_kernel(
+    prompt_logprobs_token_ids_ptr,
+    query_start_loc_ptr,
+    idx_mapping_ptr,
+    num_computed_tokens_ptr,
+    prefill_token_ids_ptr,
+    prefill_token_ids_stride,
+    BLOCK_SIZE: tl.constexpr,
+):
+    batch_idx = tl.program_id(0)
+    req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
+
+    query_start = tl.load(query_start_loc_ptr + batch_idx)
+    query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
+    query_len = query_end - query_start
+
+    num_computed_tokens = tl.load(num_computed_tokens_ptr + req_state_idx)
+    for i in range(0, query_len, BLOCK_SIZE):
+        block = i + tl.arange(0, BLOCK_SIZE)
+        mask = block < query_len
+        # NOTE(woosuk): We should shift the pos by one
+        # because the logprob is computed for the next token.
+        target_pos = num_computed_tokens + 1 + block
+        token_ids = tl.load(
+            prefill_token_ids_ptr
+            + req_state_idx * prefill_token_ids_stride
+            + target_pos,
+            mask=mask,
+        )
+        tl.store(
+            prompt_logprobs_token_ids_ptr + query_start + block, token_ids, mask=mask
+        )
+
+
+def get_prompt_logprobs_token_ids(
+    num_tokens: int,
+    query_start_loc: torch.Tensor,
+    idx_mapping: torch.Tensor,
+    num_computed_tokens: torch.Tensor,
+    prefill_token_ids: torch.Tensor,
+) -> torch.Tensor:
+    token_ids = torch.empty(num_tokens, dtype=torch.int64, device=idx_mapping.device)
+    num_reqs = idx_mapping.shape[0]
+    _prompt_logprobs_token_ids_kernel[(num_reqs,)](
+        token_ids,
+        query_start_loc,
+        idx_mapping,
+        num_computed_tokens,
+        prefill_token_ids,
+        prefill_token_ids.stride(0),
+        BLOCK_SIZE=1024,
+    )
+    return token_ids
+
+
+def compute_prompt_logprobs_with_chunking(
+    prompt_token_ids: torch.Tensor,
+    prompt_hidden_states: torch.Tensor,
+    logits_fn: Callable[[torch.Tensor], torch.Tensor],
+) -> tuple[torch.Tensor, torch.Tensor]:
+    # Since materializing the full prompt logits can take too much memory,
+    # we compute it in chunks.
+    CHUNK_SIZE = 1024
+    logprobs = []
+    ranks = []
+    prompt_token_ids = prompt_token_ids.to(torch.int64)
+    for start_idx in range(0, prompt_token_ids.shape[0], CHUNK_SIZE):
+        end_idx = start_idx + CHUNK_SIZE
+        # NOTE(woosuk): logits_fn can be slow because it involves all-gather.
+        prompt_logits = logits_fn(prompt_hidden_states[start_idx:end_idx])
+        prompt_logprobs = compute_topk_logprobs(
+            prompt_logits,
+            0,  # num_logprobs
+            prompt_token_ids[start_idx:end_idx],
+        )
+        logprobs.append(prompt_logprobs.logprobs)
+        ranks.append(prompt_logprobs.selected_token_ranks)
+
+    logprobs = torch.cat(logprobs, dim=0) if len(logprobs) > 1 else logprobs[0]
+    ranks = torch.cat(ranks, dim=0) if len(ranks) > 1 else ranks[0]
+    return logprobs, ranks
diff --git a/vllm/v1/worker/gpu/states.py b/vllm/v1/worker/gpu/states.py
index f11b03ab6..b73d41a78 100644
--- a/vllm/v1/worker/gpu/states.py
+++ b/vllm/v1/worker/gpu/states.py
@@ -1,13 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 
 import numpy as np
 import torch
 
 from vllm.lora.request import LoRARequest
-from vllm.sampling_params import SamplingParams
-from vllm.v1.outputs import LogprobsTensors
 from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
 
 NO_LORA_ID = 0
@@ -76,8 +74,6 @@
         self.lora_ids = np.zeros(self.max_num_reqs, dtype=np.int32)
         self.lora_ids.fill(NO_LORA_ID)
 
-        self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
-
     @property
     def num_reqs(self) -> int:
         return len(self.req_id_to_index)
@@ -88,7 +84,6 @@
         prompt_len: int,
         prefill_token_ids: list[int],
         num_computed_tokens: int,
-        sampling_params: SamplingParams,
         lora_request: LoRARequest | None,
     ) -> None:
         assert len(self.free_indices) > 0, "No free indices"
@@ -112,10 +107,6 @@
         else:
             self.lora_ids[req_idx] = NO_LORA_ID
 
-        # For now, only support prompt logprobs for the prompt tokens.
-        needs_prompt_logprobs = sampling_params.prompt_logprobs is not None
-        self.needs_prompt_logprobs[req_idx] = needs_prompt_logprobs
-
     def apply_staged_writes(self) -> None:
         self.prefill_len.copy_to_uva()
         self.prefill_token_ids.apply_write()
@@ -151,4 +142,3 @@
 @dataclass
 class ExtraData:
     lora_request: LoRARequest | None
-    in_progress_prompt_logprobs: list[LogprobsTensors] = field(default_factory=list)
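The chunking idea behind compute_prompt_logprobs_with_chunking (carried over unchanged from the removed logprob.compute_prompt_logprobs helper) can be sketched outside of vLLM. The snippet below is a minimal, self-contained illustration in plain PyTorch: a toy linear projection stands in for model.compute_logits, the helper name, dimensions, and chunk size are made up for the example, and the rank convention is an assumption rather than the vLLM implementation.

import torch

def chunked_prompt_logprobs(token_ids, hidden_states, logits_fn, chunk_size=1024):
    # Compute the logprob and rank of each target token chunk by chunk, so the
    # full [num_tokens, vocab] logits matrix is never materialized at once.
    logprobs, ranks = [], []
    token_ids = token_ids.to(torch.int64)
    for start in range(0, token_ids.shape[0], chunk_size):
        end = start + chunk_size
        logits = logits_fn(hidden_states[start:end])       # [chunk, vocab]
        logp = torch.log_softmax(logits.float(), dim=-1)
        tgt = token_ids[start:end].unsqueeze(-1)            # [chunk, 1]
        logprobs.append(logp.gather(-1, tgt).squeeze(-1))   # logprob of each target
        tgt_logits = logits.gather(-1, tgt)
        # 1-indexed rank of the target among all vocab logits (assumed convention).
        ranks.append((logits > tgt_logits).sum(dim=-1) + 1)
    return torch.cat(logprobs), torch.cat(ranks)

# Toy usage: random weights, an 8-dim hidden size, and a 5-token "prompt".
vocab_size, hidden_size = 32, 8
logits_fn = torch.nn.Linear(hidden_size, vocab_size, bias=False)
hidden = torch.randn(5, hidden_size)
targets = torch.randint(0, vocab_size, (5,))
lp, rk = chunked_prompt_logprobs(targets, hidden, logits_fn, chunk_size=2)
print(lp.shape, rk.shape)  # torch.Size([5]) torch.Size([5])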