diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py index 7bba7ffb9..5665937a0 100644 --- a/vllm/v1/worker/gpu/cudagraph_utils.py +++ b/vllm/v1/worker/gpu/cudagraph_utils.py @@ -25,10 +25,17 @@ from vllm.v1.worker.utils import AttentionGroup class CudaGraphManager: - def __init__(self, vllm_config: VllmConfig, uses_mrope: bool, device: torch.device): + def __init__( + self, + vllm_config: VllmConfig, + uses_mrope: bool, + use_aux_hidden_state_outputs: bool, + device: torch.device, + ): self.vllm_config = vllm_config self.scheduler_config = vllm_config.scheduler_config self.uses_mrope = uses_mrope + self.use_aux_hidden_state_outputs = use_aux_hidden_state_outputs self.device = device self.max_model_len = vllm_config.model_config.max_model_len @@ -63,6 +70,7 @@ class CudaGraphManager: if self.cudagraph_mode != CUDAGraphMode.NONE: self.pool = torch.cuda.graph_pool_handle() self.hidden_states: torch.Tensor | None = None + self.aux_hidden_states: list[torch.Tensor] = [] def needs_capture(self) -> bool: return len(self.cudagraph_sizes) > 0 @@ -134,13 +142,22 @@ class CudaGraphManager: num_tokens_across_dp=num_tokens_across_dp, slot_mapping=slot_mappings, ): - hidden_states = model( + model_output = model( input_ids=input_ids, positions=positions, inputs_embeds=inputs_embeds, ) - if self.hidden_states is None: - self.hidden_states = torch.empty_like(hidden_states) + if self.use_aux_hidden_state_outputs: + hidden_states, aux_hidden_states = model_output + else: + hidden_states = model_output + aux_hidden_states = None + + # Allocate output buffers if not already done. + if self.hidden_states is None: + self.hidden_states = torch.empty_like(hidden_states) + if self.use_aux_hidden_state_outputs and not self.aux_hidden_states: + self.aux_hidden_states = [torch.empty_like(x) for x in aux_hidden_states] capture_fn( num_tokens=num_tokens, @@ -183,13 +200,23 @@ class CudaGraphManager: ), torch.cuda.graph(graph, self.pool), ): - hidden_states = model( + model_output = model( input_ids=input_ids, positions=positions, inputs_embeds=inputs_embeds, ) + if self.use_aux_hidden_state_outputs: + hidden_states, aux_hidden_states = model_output + else: + hidden_states = model_output + aux_hidden_states = None + + # Copy outputs to the output buffers. assert self.hidden_states is not None self.hidden_states[:num_tokens] = hidden_states + if self.use_aux_hidden_state_outputs: + for i, aux_hidden in enumerate(aux_hidden_states): + self.aux_hidden_states[i][:num_tokens] = aux_hidden self.graphs[num_tokens] = graph def _capture_piecewise_graph( @@ -298,11 +325,16 @@ class CudaGraphManager: cudagraph_size = None return cudagraph_mode, cudagraph_size - def run_fullgraph(self, num_tokens: int) -> torch.Tensor: + def run_fullgraph( + self, num_tokens: int + ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: assert num_tokens in self.graphs, f"No cudagraph for {num_tokens} tokens" self.graphs[num_tokens].replay() assert self.hidden_states is not None - return self.hidden_states[:num_tokens] + hidden_states = self.hidden_states[:num_tokens] + if not self.use_aux_hidden_state_outputs: + return hidden_states + return hidden_states, [x[:num_tokens] for x in self.aux_hidden_states] def get_cudagraph_sizes( diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index b909b90ad..cdea0b2aa 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -197,7 +197,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): # CUDA graphs. self.cudagraph_manager = CudaGraphManager( - self.vllm_config, self.uses_mrope, self.device + self.vllm_config, + self.uses_mrope, + self.use_aux_hidden_state_outputs, + self.device, ) # Structured outputs worker. self.structured_outputs_worker = StructuredOutputsWorker( @@ -1044,7 +1047,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): aux_hidden_states, input_batch, kv_connector_output, - ) + ) # type: ignore return None @torch.inference_mode()