[Model Runner V2] Enable CUDA graph for Eagle3 (#35040)

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
Woosuk Kwon
2026-02-21 21:42:50 -08:00
committed by GitHub
parent 30132cd144
commit 2cbf9656ce
2 changed files with 44 additions and 9 deletions

View File

@@ -25,10 +25,17 @@ from vllm.v1.worker.utils import AttentionGroup
class CudaGraphManager:
def __init__(self, vllm_config: VllmConfig, uses_mrope: bool, device: torch.device):
def __init__(
self,
vllm_config: VllmConfig,
uses_mrope: bool,
use_aux_hidden_state_outputs: bool,
device: torch.device,
):
self.vllm_config = vllm_config
self.scheduler_config = vllm_config.scheduler_config
self.uses_mrope = uses_mrope
self.use_aux_hidden_state_outputs = use_aux_hidden_state_outputs
self.device = device
self.max_model_len = vllm_config.model_config.max_model_len
@@ -63,6 +70,7 @@ class CudaGraphManager:
if self.cudagraph_mode != CUDAGraphMode.NONE:
self.pool = torch.cuda.graph_pool_handle()
self.hidden_states: torch.Tensor | None = None
self.aux_hidden_states: list[torch.Tensor] = []
# Return True when `self.cudagraph_sizes` holds at least one entry,
# and False when it is empty.
def needs_capture(self) -> bool:
return len(self.cudagraph_sizes) > 0
@@ -134,13 +142,22 @@ class CudaGraphManager:
num_tokens_across_dp=num_tokens_across_dp,
slot_mapping=slot_mappings,
):
hidden_states = model(
model_output = model(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
)
if self.hidden_states is None:
self.hidden_states = torch.empty_like(hidden_states)
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
hidden_states = model_output
aux_hidden_states = None
# Allocate output buffers if not already done.
if self.hidden_states is None:
self.hidden_states = torch.empty_like(hidden_states)
if self.use_aux_hidden_state_outputs and not self.aux_hidden_states:
self.aux_hidden_states = [torch.empty_like(x) for x in aux_hidden_states]
capture_fn(
num_tokens=num_tokens,
@@ -183,13 +200,23 @@ class CudaGraphManager:
),
torch.cuda.graph(graph, self.pool),
):
hidden_states = model(
model_output = model(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
)
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
hidden_states = model_output
aux_hidden_states = None
# Copy outputs to the output buffers.
assert self.hidden_states is not None
self.hidden_states[:num_tokens] = hidden_states
if self.use_aux_hidden_state_outputs:
for i, aux_hidden in enumerate(aux_hidden_states):
self.aux_hidden_states[i][:num_tokens] = aux_hidden
self.graphs[num_tokens] = graph
def _capture_piecewise_graph(
@@ -298,11 +325,16 @@ class CudaGraphManager:
cudagraph_size = None
return cudagraph_mode, cudagraph_size
# Replay the graph stored in `self.graphs` for `num_tokens`, then return the
# first `num_tokens` rows of the manager's output buffers.
#
# Returns `self.hidden_states[:num_tokens]` alone when
# `self.use_aux_hidden_state_outputs` is false; otherwise a tuple of that
# slice and a list holding every `self.aux_hidden_states` entry sliced to
# `num_tokens`.
# Fails an assertion if `num_tokens` has no entry in `self.graphs`, or if
# `self.hidden_states` is still None.
#
# NOTE(review): this listing appears to be a diff rendered without +/- markers;
# the one-line signature directly below and the bare
# `return self.hidden_states[:num_tokens]` look like the pre-change lines —
# confirm against the repository before treating this span as runnable code.
def run_fullgraph(self, num_tokens: int) -> torch.Tensor:
def run_fullgraph(
self, num_tokens: int
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
assert num_tokens in self.graphs, f"No cudagraph for {num_tokens} tokens"
self.graphs[num_tokens].replay()
assert self.hidden_states is not None
return self.hidden_states[:num_tokens]
hidden_states = self.hidden_states[:num_tokens]
# Without auxiliary outputs the return value is the bare tensor slice,
# not a tuple.
if not self.use_aux_hidden_state_outputs:
return hidden_states
return hidden_states, [x[:num_tokens] for x in self.aux_hidden_states]
def get_cudagraph_sizes(

View File

@@ -197,7 +197,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# CUDA graphs.
self.cudagraph_manager = CudaGraphManager(
self.vllm_config, self.uses_mrope, self.device
self.vllm_config,
self.uses_mrope,
self.use_aux_hidden_state_outputs,
self.device,
)
# Structured outputs worker.
self.structured_outputs_worker = StructuredOutputsWorker(
@@ -1044,7 +1047,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
aux_hidden_states,
input_batch,
kv_connector_output,
)
) # type: ignore
return None
@torch.inference_mode()