diff --git a/vllm/v1/executor/ray_utils.py b/vllm/v1/executor/ray_utils.py
index 21910d116..dadf55006 100644
--- a/vllm/v1/executor/ray_utils.py
+++ b/vllm/v1/executor/ray_utils.py
@@ -104,8 +104,11 @@ try:
                 scheduler_output, intermediate_tensors
             )
             if isinstance(output, IntermediateTensors):
-                output = scheduler_output, grammar_output, output
-            elif not get_pp_group().is_last_rank:
+                return scheduler_output, grammar_output, output
+
+            if isinstance(output, AsyncModelRunnerOutput):
+                output = output.get_output()
+            if not get_pp_group().is_last_rank:
                 # Case where there are no scheduled requests
                 # but may still be finished requests.
                 assert not output or not output.req_ids
diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py
index bea9e5846..2ac44e3bb 100644
--- a/vllm/v1/outputs.py
+++ b/vllm/v1/outputs.py
@@ -151,21 +151,23 @@ class ModelRunnerOutput:
     # num_generated_tokens is the number of tokens
     # generated in the current step. It can be different for
     # each request due to speculative/jump decoding.
-    sampled_token_ids: list[list[int]]
+    sampled_token_ids: list[list[int]] = field(default_factory=list)
 
     # [num_reqs, max_num_logprobs + 1]
     # [num_reqs, max_num_logprobs + 1]
     # [num_reqs]
-    logprobs: LogprobsLists | None
+    logprobs: LogprobsLists | None = None
 
     # req_id -> (token_ids, logprobs, ranks)
     # [prompt_len, num_prompt_logprobs]
     # [prompt_len, num_prompt_logprobs]
     # [prompt_len]
-    prompt_logprobs_dict: dict[str, LogprobsTensors | None]
+    prompt_logprobs_dict: dict[str, LogprobsTensors | None] = field(
+        default_factory=dict
+    )
 
     # [num_reqs, hidden_size]
-    pooler_output: list[torch.Tensor | None]
+    pooler_output: list[torch.Tensor | None] | None = None
 
     kv_connector_output: KVConnectorOutput | None = None
 
@@ -225,21 +227,8 @@ def make_empty_encoder_model_runner_output(
         req_ids=req_ids,
         req_id_to_index=req_id_to_index,
         sampled_token_ids=sampled_token_ids,
-        logprobs=None,
-        prompt_logprobs_dict={},
         pooler_output=pooler_output,
-        kv_connector_output=None,
-        ec_connector_output=None,
-        num_nans_in_logits=None,
     )
 
 
-EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
-    req_ids=[],
-    req_id_to_index={},
-    sampled_token_ids=[],
-    logprobs=None,
-    prompt_logprobs_dict={},
-    pooler_output=[],
-    num_nans_in_logits=None,
-)
+EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[], req_id_to_index={})
diff --git a/vllm/v1/pool/metadata.py b/vllm/v1/pool/metadata.py
index acd1a00e8..7ed022bb9 100644
--- a/vllm/v1/pool/metadata.py
+++ b/vllm/v1/pool/metadata.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from dataclasses import dataclass
 
+import numpy as np
 import torch
 
 from vllm.pooling_params import PoolingParams
@@ -91,36 +92,27 @@ class PoolingMetadata:
 
     def build_pooling_cursor(
         self,
-        num_scheduled_tokens: list[int],
+        num_scheduled_tokens_np: np.ndarray,
         seq_lens_cpu: torch.Tensor,
         device: torch.device,
     ):
-        self.pooling_cursor = build_pooling_cursor(
-            num_scheduled_tokens, seq_lens_cpu, self.prompt_lens, device
+        n_seq = len(num_scheduled_tokens_np)
+        prompt_lens = self.prompt_lens
+
+        assert len(prompt_lens) == n_seq
+
+        index = list(range(n_seq))
+        num_scheduled_tokens_cpu = torch.from_numpy(num_scheduled_tokens_np)
+        cumsum = torch.zeros(
+            n_seq + 1, dtype=torch.int64, pin_memory=pin_memory, device="cpu"
+        )
+        torch.cumsum(num_scheduled_tokens_cpu, dim=0, out=cumsum[1:])
+        cumsum = cumsum.to(device, non_blocking=True)
+        self.pooling_cursor = PoolingCursor(
+            index=index,
+            first_token_indices_gpu=cumsum[:n_seq],
+            last_token_indices_gpu=cumsum[1:] - 1,
+            prompt_lens_cpu=prompt_lens,
+            seq_lens_cpu=seq_lens_cpu,
+            num_scheduled_tokens_cpu=num_scheduled_tokens_cpu,
         )
-
-
-def build_pooling_cursor(
-    num_scheduled_tokens: list[int],
-    seq_lens_cpu: torch.Tensor,
-    prompt_lens: torch.Tensor,
-    device: torch.device,
-):
-    assert len(prompt_lens) == len(num_scheduled_tokens)
-
-    n_seq = len(num_scheduled_tokens)
-    index = list(range(n_seq))
-    num_scheduled_tokens_cpu = torch.tensor(num_scheduled_tokens, device="cpu")
-    cumsum = torch.zeros(
-        n_seq + 1, dtype=torch.int64, pin_memory=pin_memory, device="cpu"
-    )
-    torch.cumsum(num_scheduled_tokens_cpu, dim=0, out=cumsum[1:])
-    cumsum = cumsum.to(device, non_blocking=True)
-    return PoolingCursor(
-        index=index,
-        first_token_indices_gpu=cumsum[:n_seq],
-        last_token_indices_gpu=cumsum[1:] - 1,
-        prompt_lens_cpu=prompt_lens,
-        seq_lens_cpu=seq_lens_cpu,
-        num_scheduled_tokens_cpu=num_scheduled_tokens_cpu,
-    )
diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index 9f4c6edfb..06b4aed56 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -968,11 +968,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # Only for compatibility with the existing model runner and scheduler.
             req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
             sampled_token_ids=None,  # type: ignore
-            logprobs=None,
-            prompt_logprobs_dict=prompt_logprobs_dict,  # type: ignore
-            pooler_output=[],
-            kv_connector_output=None,
-            num_nans_in_logits=None,
+            prompt_logprobs_dict=prompt_logprobs_dict,  # type: ignore[arg-type]
         )
         async_output = AsyncOutput(
             model_runner_output=model_runner_output,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 7e7409d27..de8e168e4 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -254,6 +254,50 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
         return output
 
 
+class AsyncGPUPoolingModelRunnerOutput(AsyncModelRunnerOutput):
+    def __init__(
+        self,
+        model_runner_output: ModelRunnerOutput,
+        raw_pooler_output: PoolerOutput,
+        finished_mask: list[bool],
+        async_output_copy_stream: torch.cuda.Stream,
+    ):
+        self._model_runner_output = model_runner_output
+        self._finished_mask = finished_mask
+
+        # Event on the copy stream so we can synchronize the non-blocking copy.
+        self.async_copy_ready_event = torch.Event()
+
+        # Keep a reference to the device tensors to avoid them being
+        # deallocated until we finish copying them to the host.
+        self._raw_pooler_output = raw_pooler_output
+
+        # Initiate the copy on a separate stream, but do not synchronize it.
+        default_stream = torch.cuda.current_stream()
+        with torch.cuda.stream(async_output_copy_stream):
+            async_output_copy_stream.wait_stream(default_stream)
+            self._raw_pooler_output_cpu = json_map_leaves(
+                lambda x: None if x is None else x.to("cpu", non_blocking=True),
+                self._raw_pooler_output,
+            )
+            self.async_copy_ready_event.record()
+
+    def get_output(self) -> ModelRunnerOutput:
+        """Copy the device tensors to the host and return a ModelRunnerOutput.
+        This function blocks until the copy is finished.
+        """
+        self.async_copy_ready_event.synchronize()
+
+        # Release the device tensors once the copy has completed.
+        del self._raw_pooler_output
+
+        self._model_runner_output.pooler_output = [
+            out if include else None
+            for out, include in zip(self._raw_pooler_output_cpu, self._finished_mask)
+        ]
+        return self._model_runner_output
+
+
 class ExecuteModelState(NamedTuple):
     """Ephemeral cached state transferred between execute_model()
     and sample_tokens(), after execute_model() returns None."""
@@ -2476,17 +2520,19 @@ class GPUModelRunner(
         hidden_states: torch.Tensor,
         num_scheduled_tokens: int,
         num_scheduled_tokens_np: np.ndarray,
-    ) -> ModelRunnerOutput:
-        assert self.input_batch.num_reqs == len(self.input_batch.pooling_params), (
+        kv_connector_output: KVConnectorOutput | None,
+    ) -> ModelRunnerOutput | AsyncModelRunnerOutput:
+        num_reqs = self.input_batch.num_reqs
+        assert num_reqs == len(self.input_batch.pooling_params), (
             "Either all or none of the requests in a batch must be pooling request"
         )
 
         hidden_states = hidden_states[:num_scheduled_tokens]
-        seq_lens_cpu = self.seq_lens.cpu[: self.input_batch.num_reqs]
+        seq_lens_cpu = self.seq_lens.cpu[:num_reqs]
 
         pooling_metadata = self.input_batch.get_pooling_metadata()
         pooling_metadata.build_pooling_cursor(
-            num_scheduled_tokens_np.tolist(), seq_lens_cpu, device=hidden_states.device
+            num_scheduled_tokens_np, seq_lens_cpu, device=hidden_states.device
         )
 
         model = cast(VllmModelForPooling, self.model)
@@ -2494,27 +2540,41 @@ class GPUModelRunner(
             hidden_states=hidden_states,
             pooling_metadata=pooling_metadata,
         )
+
+        finished_mask = [
+            seq_len == prompt_len
+            for seq_len, prompt_len in zip(seq_lens_cpu, pooling_metadata.prompt_lens)
+        ]
+
+        model_runner_output = ModelRunnerOutput(
+            req_ids=self.input_batch.req_ids.copy(),
+            req_id_to_index=self.input_batch.req_id_to_index.copy(),
+            kv_connector_output=kv_connector_output,
+        )
+
+        if raw_pooler_output is None or not any(finished_mask):
+            model_runner_output.pooler_output = [None] * num_reqs
+            return model_runner_output
+
+        if self.use_async_scheduling:
+            return AsyncGPUPoolingModelRunnerOutput(
+                model_runner_output=model_runner_output,
+                raw_pooler_output=raw_pooler_output,
+                finished_mask=finished_mask,
+                async_output_copy_stream=self.async_output_copy_stream,
+            )
+
         raw_pooler_output = json_map_leaves(
-            lambda x: x.to("cpu", non_blocking=True) if x is not None else x,
+            lambda x: None if x is None else x.to("cpu", non_blocking=True),
             raw_pooler_output,
         )
         self._sync_device()
 
-        pooler_output: list[torch.Tensor | None] = []
-        for raw_output, seq_len, prompt_len in zip(
-            raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens
-        ):
-            output = raw_output if seq_len == prompt_len else None
-            pooler_output.append(output)
-
-        return ModelRunnerOutput(
-            req_ids=self.input_batch.req_ids,
-            req_id_to_index=self.input_batch.req_id_to_index,
-            sampled_token_ids=[],
-            logprobs=None,
-            prompt_logprobs_dict={},
-            pooler_output=pooler_output,
-        )
+        model_runner_output.pooler_output = [
+            out if include else None
+            for out, include in zip(raw_pooler_output, finished_mask)
+        ]
+        return model_runner_output
 
     def _pad_for_sequence_parallelism(self, num_scheduled_tokens: int) -> int:
         # Pad tokens to multiple of tensor_parallel_size when
@@ -3036,7 +3096,7 @@ class GPUModelRunner(
         self,
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: IntermediateTensors | None = None,
-    ) -> ModelRunnerOutput | IntermediateTensors | None:
+    ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors | None:
         if self.execute_model_state is not None:
             raise RuntimeError(
                 "State error: sample_tokens() must be called "
@@ -3244,11 +3304,12 @@ class GPUModelRunner(
 
         if self.is_pooling_model:
             # Return the pooling output.
-            output = self._pool(
-                hidden_states, num_scheduled_tokens, num_scheduled_tokens_np
+            return self._pool(
+                hidden_states,
+                num_scheduled_tokens,
+                num_scheduled_tokens_np,
+                kv_connector_output,
             )
-            output.kv_connector_output = kv_connector_output
-            return output
 
         sample_hidden_states = hidden_states[logits_indices]
         logits = self.model.compute_logits(sample_hidden_states)
@@ -3437,7 +3498,6 @@ class GPUModelRunner(
             sampled_token_ids=valid_sampled_token_ids,
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
-            pooler_output=[],
             kv_connector_output=kv_connector_output,
             ec_connector_output=ec_connector_output
             if self.supports_mm_inputs
@@ -4508,17 +4568,14 @@ class GPUModelRunner(
         max_num_reqs = self.scheduler_config.max_num_seqs
         num_reqs = min(num_tokens, max_num_reqs)
         min_tokens_per_req = num_tokens // num_reqs
-        num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
-        num_scheduled_tokens_list[-1] += num_tokens % num_reqs
-        assert sum(num_scheduled_tokens_list) == num_tokens
-        assert len(num_scheduled_tokens_list) == num_reqs
+        num_scheduled_tokens_np = np.full(num_reqs, min_tokens_per_req)
+        num_scheduled_tokens_np[-1] += num_tokens % num_reqs
+        assert np.sum(num_scheduled_tokens_np) == num_tokens
+        assert len(num_scheduled_tokens_np) == num_reqs
 
         req_num_tokens = num_tokens // num_reqs
 
-        dummy_prompt_lens = torch.tensor(
-            num_scheduled_tokens_list,
-            device="cpu",
-        )
+        dummy_prompt_lens = torch.from_numpy(num_scheduled_tokens_np)
         dummy_token_ids = torch.zeros(
             (num_reqs, req_num_tokens), dtype=torch.int32, device=self.device
         )
@@ -4537,7 +4594,7 @@ class GPUModelRunner(
         )
 
         dummy_metadata.build_pooling_cursor(
-            num_scheduled_tokens_list,
+            num_scheduled_tokens_np,
             seq_lens_cpu=dummy_prompt_lens,
             device=hidden_states.device,
         )
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index fae7fa620..2416c8eba 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -575,7 +575,7 @@ class Worker(WorkerBase):
     @torch.inference_mode()
     def execute_model(
         self, scheduler_output: "SchedulerOutput"
-    ) -> ModelRunnerOutput | None:
+    ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
         intermediate_tensors = None
         forward_pass = scheduler_output.total_num_scheduled_tokens > 0
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@@ -624,7 +624,9 @@ class Worker(WorkerBase):
             output = self.model_runner.execute_model(
                 scheduler_output, intermediate_tensors
             )
-            if isinstance(output, ModelRunnerOutput | NoneType):
+            if isinstance(
+                output, ModelRunnerOutput | AsyncModelRunnerOutput | NoneType
+            ):
                 return output
 
             assert isinstance(output, IntermediateTensors)
diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py
index 57e7037e9..d06ae2fdf 100644
--- a/vllm/v1/worker/worker_base.py
+++ b/vllm/v1/worker/worker_base.py
@@ -124,7 +124,7 @@ class WorkerBase:
 
     def execute_model(
         self, scheduler_output: SchedulerOutput
-    ) -> ModelRunnerOutput | None:
+    ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
        """If this method returns None, sample_tokens should be called
        immediately after to obtain the ModelRunnerOutput.
 
@@ -362,7 +362,7 @@ class WorkerWrapperBase:
         scheduler_output: SchedulerOutput,
         *args,
         **kwargs,
-    ) -> ModelRunnerOutput | None:
+    ) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
         self._apply_mm_cache(scheduler_output)
 
         assert self.worker is not None