diff --git a/vllm/v1/worker/gpu/async_utils.py b/vllm/v1/worker/gpu/async_utils.py
index e628e38bd..f87459efa 100644
--- a/vllm/v1/worker/gpu/async_utils.py
+++ b/vllm/v1/worker/gpu/async_utils.py
@@ -70,6 +70,42 @@ class AsyncOutput(AsyncModelRunnerOutput):
         return self.model_runner_output
 
 
+class AsyncPoolingOutput(AsyncModelRunnerOutput):
+    def __init__(
+        self,
+        model_runner_output: ModelRunnerOutput,
+        pooler_output: torch.Tensor,
+        is_valid: torch.Tensor | None,
+        main_stream: torch.cuda.Stream,
+        copy_stream: torch.cuda.Stream,
+        copy_event: torch.cuda.Event,
+    ):
+        self.model_runner_output = model_runner_output
+        self.pooler_output = pooler_output
+        self.is_valid = is_valid
+        self.copy_event = copy_event
+
+        with stream(copy_stream, main_stream):
+            copy_stream.wait_stream(main_stream)
+            self.pooler_output_cpu = self.pooler_output.to("cpu", non_blocking=True)
+            if self.is_valid is not None:
+                self.is_valid_cpu = self.is_valid.to("cpu", non_blocking=True)
+            else:
+                self.is_valid_cpu = None
+        self.copy_event.record(copy_stream)
+
+    def get_output(self) -> ModelRunnerOutput:
+        self.copy_event.synchronize()
+        pooler_output = list(self.pooler_output_cpu.unbind(dim=0))
+        if self.is_valid_cpu is not None:
+            is_valid_cpu = self.is_valid_cpu.tolist()
+            for i, is_valid in enumerate(is_valid_cpu):
+                if not is_valid:
+                    pooler_output[i] = None
+        self.model_runner_output.pooler_output = pooler_output
+        return self.model_runner_output
+
+
 def async_copy_to_np(x: torch.Tensor) -> np.ndarray:
     return x.to("cpu", non_blocking=True).numpy()
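Note on the copy pattern above: `AsyncPoolingOutput` schedules the device-to-host copy on a side stream and records an event, so `get_output()` only has to wait on that event. (`get_output()` also converts the `unbind()` tuple to a list, since invalid entries are overwritten with `None`.) Below is a minimal standalone sketch of the same ordering, using a plain `torch.cuda.stream` context in place of the module's `stream()` helper; the function and variable names are illustrative, not part of this PR:

```python
import torch

def async_d2h_copy(
    x: torch.Tensor,
    main_stream: torch.cuda.Stream,
    copy_stream: torch.cuda.Stream,
    copy_event: torch.cuda.Event,
) -> torch.Tensor:
    with torch.cuda.stream(copy_stream):
        # Order the copy after everything already queued on the main stream.
        copy_stream.wait_stream(main_stream)
        # non_blocking=True only truly overlaps when the destination is
        # pinned host memory; otherwise it degrades to a synchronous copy.
        x_cpu = x.to("cpu", non_blocking=True)
        copy_event.record(copy_stream)
    # Callers must copy_event.synchronize() before reading x_cpu.
    return x_cpu
```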
diff --git a/vllm/v1/worker/gpu/input_batch.py b/vllm/v1/worker/gpu/input_batch.py
index 75655258c..5918cc374 100644
--- a/vllm/v1/worker/gpu/input_batch.py
+++ b/vllm/v1/worker/gpu/input_batch.py
@@ -499,6 +499,38 @@ def post_update(
     )
 
 
+@triton.jit
+def _post_update_pool_kernel(
+    idx_mapping_ptr,
+    num_computed_tokens_ptr,
+    query_start_loc_ptr,
+):
+    batch_id = tl.program_id(0)
+    query_start = tl.load(query_start_loc_ptr + batch_id)
+    query_end = tl.load(query_start_loc_ptr + batch_id + 1)
+    query_len = query_end - query_start
+
+    req_state_idx = tl.load(idx_mapping_ptr + batch_id)
+    num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
+    tl.store(num_computed_tokens_ptr + req_state_idx, num_computed + query_len)
+
+
+def post_update_pool(
+    # [num_reqs]
+    idx_mapping: torch.Tensor,
+    # [max_num_reqs]
+    num_computed_tokens: torch.Tensor,
+    # [num_reqs + 1]
+    query_start_loc: torch.Tensor,
+) -> None:
+    num_reqs = idx_mapping.shape[0]
+    _post_update_pool_kernel[(num_reqs,)](
+        idx_mapping,
+        num_computed_tokens,
+        query_start_loc,
+    )
+
+
 @triton.jit
 def _expand_idx_mapping_kernel(
     idx_mapping_ptr,
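For reference, the Triton kernel above is equivalent to the following pure-PyTorch update (a readability sketch only; `idx_mapping` entries are unique per batch position, so the indexed in-place add is safe):

```python
import torch

def post_update_pool_reference(
    idx_mapping: torch.Tensor,          # [num_reqs]
    num_computed_tokens: torch.Tensor,  # [max_num_reqs], updated in place
    query_start_loc: torch.Tensor,      # [num_reqs + 1]
) -> None:
    # Per-request query lengths from the prefix sums.
    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    # Credit each request's state slot with the tokens it ran this step.
    num_computed_tokens[idx_mapping] += query_lens
```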
diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index 7dcdaf1d2..8bca1a17f 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -38,13 +38,14 @@ from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model_loader
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors
+from vllm.tasks import SupportedTask
 from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
 from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
 from vllm.v1.worker.cp_utils import check_attention_cp_compatibility
-from vllm.v1.worker.gpu.async_utils import AsyncOutput
+from vllm.v1.worker.gpu.async_utils import AsyncOutput, AsyncPoolingOutput
 from vllm.v1.worker.gpu.attn_utils import (
     build_slot_mappings_by_layer,
     get_kv_cache_spec,
@@ -66,6 +67,7 @@ from vllm.v1.worker.gpu.input_batch import (
     expand_idx_mapping,
     get_num_sampled_and_rejected,
     post_update,
+    post_update_pool,
     prepare_pos_seq_lens,
     prepare_prefill_inputs,
 )
@@ -77,6 +79,7 @@ from vllm.v1.worker.gpu.kv_connector import (
 from vllm.v1.worker.gpu.lora_utils import LoraState
 from vllm.v1.worker.gpu.mm.encoder_runner import EncoderRunner
 from vllm.v1.worker.gpu.model_states import ModelState
+from vllm.v1.worker.gpu.pool.pooling_runner import PoolingRunner
 from vllm.v1.worker.gpu.pp_utils import pp_broadcast, pp_receive
 from vllm.v1.worker.gpu.sample.output import SamplerOutput
 from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
@@ -119,7 +122,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
             self.cache_config.cache_dtype
         ]
-        self.is_pooling_model = False
 
         self.vocab_size = self.model_config.get_vocab_size()
         self.max_model_len = self.model_config.max_model_len
@@ -217,6 +219,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # KV Connector if configured.
         self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR
 
+        # Pooling models.
+        self.is_pooling_model = self.model_config.runner_type == "pooling"
+        self.pooling_runner: PoolingRunner | None = None
+
         # For transferring state from execute_model to subsequent sample_tokens call.
         self.execute_model_state: tuple | None = None
@@ -224,9 +230,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.max_model_len = max_model_len
         self.req_states.max_model_len = max_model_len
 
-    @staticmethod
-    def get_supported_tasks() -> tuple[str]:
-        return ("generate",)
+    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
+        tasks: list[SupportedTask] = []
+        if self.model_config.runner_type == "generate":
+            tasks.append("generate")
+        if self.pooling_runner is not None:
+            tasks.extend(self.pooling_runner.get_supported_pooling_tasks())
+        return tuple(tasks)
 
     def load_model(self, *args, **kwargs) -> None:
         time_before_load = time.perf_counter()
@@ -263,6 +273,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
         # Initialize the components that require the model.
         self.model_state = ModelState(self.vllm_config, self.model, self.device)
+        if self.is_pooling_model:
+            self.pooling_runner = PoolingRunner(self.model)
 
     def get_model(self) -> nn.Module:
         return self.model
@@ -388,16 +400,24 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             expanded_local_pos,
         )
 
+    @torch.inference_mode()
+    def _dummy_pooler_run(self, hidden_states: torch.Tensor) -> None:
+        assert self.pooling_runner is not None
+        self.pooling_runner.dummy_pooler_run(hidden_states)
+
     @torch.inference_mode()
     def profile_run(self) -> None:
         hidden_states, sample_hidden_states = self._dummy_run(
             self.max_num_tokens, skip_attn=True
         )
-        # Only run sampler on last PP rank (non-last ranks return None).
+        # Only run sampler/pooler on last PP rank (non-last ranks return None).
         if self.is_last_pp_rank:
             assert sample_hidden_states is not None
-            self._dummy_sampler_run(sample_hidden_states)
+            if self.pooling_runner is None:
+                self._dummy_sampler_run(sample_hidden_states)
+            else:
+                self._dummy_pooler_run(hidden_states)
 
         if self.speculator is not None:
             num_tokens_across_dp = make_num_tokens_across_dp(
@@ -505,7 +525,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         for new_req_data in scheduler_output.scheduled_new_reqs:
             assert new_req_data.prompt_token_ids is not None
             assert new_req_data.prefill_token_ids is not None
-            assert new_req_data.sampling_params is not None
             req_id = new_req_data.req_id
             prompt_len = len(new_req_data.prompt_token_ids)
             self.req_states.add_request(
@@ -523,14 +542,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             self.block_tables.append_block_ids(
                 req_index, new_req_data.block_ids, overwrite=True
             )
-            self.sampler.add_request(
-                req_index, prompt_len, new_req_data.sampling_params
-            )
-            self.prompt_logprobs_worker.add_request(
-                req_id, req_index, new_req_data.sampling_params
-            )
             self.lora_state.add_request(req_id, req_index, new_req_data.lora_request)
 
+            if new_req_data.sampling_params is not None:
+                self.sampler.add_request(
+                    req_index, prompt_len, new_req_data.sampling_params
+                )
+                self.prompt_logprobs_worker.add_request(
+                    req_id, req_index, new_req_data.sampling_params
+                )
+
         if scheduler_output.scheduled_new_reqs:
             self.req_states.apply_staged_writes()
             self.sampler.apply_staged_writes()
@@ -1083,3 +1104,58 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
     def take_draft_token_ids(self) -> DraftTokenIds | None:
         return self.draft_tokens_handler.get_draft_tokens()
+
+    @torch.inference_mode()
+    def pool(self) -> AsyncPoolingOutput | ModelRunnerOutput | None:
+        if self.execute_model_state is None:
+            # The prior execute_model call must have failed.
+            return None
+
+        input_batch, _, _, _, hidden_states, _, kv_connector_output = (
+            self.execute_model_state
+        )
+        self.execute_model_state = None
+
+        if not self.is_last_pp_rank:
+            self.postprocess_pool(input_batch)
+            return None
+
+        assert self.pooling_runner is not None
+        pooler_output, is_valid = self.pooling_runner.pool(
+            hidden_states, input_batch, self.req_states
+        )
+        self.postprocess_pool(input_batch)
+
+        # Build the model runner output.
+        model_runner_output = ModelRunnerOutput(
+            req_ids=input_batch.req_ids,
+            req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
+            kv_connector_output=kv_connector_output,
+        )
+        async_output = AsyncPoolingOutput(
+            model_runner_output=model_runner_output,
+            pooler_output=pooler_output,
+            is_valid=is_valid,
+            main_stream=self.main_stream,
+            copy_stream=self.output_copy_stream,
+            copy_event=self.output_copy_event,
+        )
+        if self.use_async_scheduling:
+            return async_output
+        return async_output.get_output()
+
+    def postprocess_pool(self, input_batch: InputBatch) -> None:
+        # Update the number of computed tokens.
+        post_update_pool(
+            input_batch.idx_mapping,
+            self.req_states.num_computed_tokens.gpu,
+            input_batch.query_start_loc,
+        )
+
+        # Update the number of computed prefill tokens.
+        idx_mapping_np = input_batch.idx_mapping_np
+        computed_prefill = self.req_states.num_computed_prefill_tokens
+        computed_prefill[idx_mapping_np] += input_batch.num_scheduled_tokens
+        np.minimum(
+            computed_prefill, self.req_states.prefill_len.np, out=computed_prefill
+        )
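The prefill bookkeeping at the end of `postprocess_pool` clamps progress at the prompt length. A toy run with made-up numbers, modeling the `req_states` fields as bare numpy arrays:

```python
import numpy as np

prefill_len      = np.array([8, 4])  # total prompt tokens per request slot
computed_prefill = np.array([6, 2])  # prompt tokens processed so far
idx_mapping      = np.array([0, 1])  # batch position -> request slot
num_scheduled    = np.array([4, 2])  # tokens scheduled this step

computed_prefill[idx_mapping] += num_scheduled
# Guard against counting past the end of the prompt.
np.minimum(computed_prefill, prefill_len, out=computed_prefill)
print(computed_prefill)  # [8 4]: slot 0 is clamped (10 -> 8), slot 1 exactly finishes
```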
diff --git a/vllm/v1/worker/gpu/pool/__init__.py b/vllm/v1/worker/gpu/pool/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/vllm/v1/worker/gpu/pool/pooling_runner.py b/vllm/v1/worker/gpu/pool/pooling_runner.py
new file mode 100644
index 000000000..7098aad54
--- /dev/null
+++ b/vllm/v1/worker/gpu/pool/pooling_runner.py
@@ -0,0 +1,45 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import cast
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from vllm.model_executor.models import VllmModelForPooling, is_pooling_model
+from vllm.tasks import PoolingTask
+from vllm.v1.worker.gpu.input_batch import InputBatch
+from vllm.v1.worker.gpu.states import RequestState
+
+
+# NOTE(woosuk): Currently, this class only supports the "LAST" pooling task
+# on decoder-only models. How to support other pooling tasks and models
+# is to be determined.
+class PoolingRunner:
+    def __init__(self, model: nn.Module):
+        self.model = cast(VllmModelForPooling, model)
+
+    def get_supported_pooling_tasks(self) -> list[PoolingTask]:
+        if not is_pooling_model(self.model):
+            return []
+        assert "embed" in self.model.pooler.get_supported_tasks()
+        return ["embed"]
+
+    def pool(
+        self,
+        hidden_states: torch.Tensor,
+        input_batch: InputBatch,
+        req_states: RequestState,
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        # TODO(woosuk): Support different types of pooling tasks.
+        last_hidden_states = hidden_states[input_batch.logits_indices]
+        # TODO(woosuk): Make normalization optional.
+        last_hidden_states = F.normalize(last_hidden_states, p=2, dim=-1)
+
+        prompt_len = req_states.prompt_len.gpu[input_batch.idx_mapping]
+        is_valid = input_batch.seq_lens == prompt_len
+        return last_hidden_states, is_valid
+
+    def dummy_pooler_run(self, hidden_states: torch.Tensor) -> None:
+        F.normalize(hidden_states, p=2, dim=-1)
+        return
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index fcc0fdf88..06410b2eb 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -700,6 +700,12 @@ class Worker(WorkerBase):
             output = self.model_runner.execute_model(
                 scheduler_output, intermediate_tensors
             )
+            if (
+                self.use_v2_model_runner
+                and self.model_runner.is_pooling_model
+                and output is None
+            ):
+                output = self.model_runner.pool()  # type: ignore
             if isinstance(
                 output, ModelRunnerOutput | AsyncModelRunnerOutput | NoneType
             ):
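Taken together: the worker falls back to `model_runner.pool()` whenever a pooling model's `execute_model` produces no sampled output, and the pooling itself reduces to last-token gathering plus L2 normalization. A self-contained sketch of the semantics of `PoolingRunner.pool` (shapes and names are illustrative, not the vLLM interfaces):

```python
import torch
import torch.nn.functional as F

def last_token_embed(
    hidden_states: torch.Tensor,   # [num_tokens, hidden_size]
    logits_indices: torch.Tensor,  # [num_reqs], row of each request's last token
    seq_lens: torch.Tensor,        # [num_reqs], tokens computed so far
    prompt_lens: torch.Tensor,     # [num_reqs], full prompt lengths
) -> tuple[torch.Tensor, torch.Tensor]:
    # "LAST" pooling: one hidden state per request, L2-normalized for "embed".
    emb = F.normalize(hidden_states[logits_indices], p=2, dim=-1)
    # Under chunked prefill a request may not have seen its whole prompt yet;
    # its embedding only becomes valid on the chunk that reaches the prompt end.
    is_valid = seq_lens == prompt_lens
    return emb, is_valid
```

Requests with `is_valid == False` are the ones whose entries `AsyncPoolingOutput.get_output()` replaces with `None` on the CPU side.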