[Model Runner V2] Enable CUDA graph for Eagle3 (#35040)
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
@@ -25,10 +25,17 @@ from vllm.v1.worker.utils import AttentionGroup
|
||||
|
||||
|
||||
class CudaGraphManager:
|
||||
def __init__(self, vllm_config: VllmConfig, uses_mrope: bool, device: torch.device):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
uses_mrope: bool,
|
||||
use_aux_hidden_state_outputs: bool,
|
||||
device: torch.device,
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.uses_mrope = uses_mrope
|
||||
self.use_aux_hidden_state_outputs = use_aux_hidden_state_outputs
|
||||
self.device = device
|
||||
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
@@ -63,6 +70,7 @@ class CudaGraphManager:
|
||||
if self.cudagraph_mode != CUDAGraphMode.NONE:
|
||||
self.pool = torch.cuda.graph_pool_handle()
|
||||
self.hidden_states: torch.Tensor | None = None
|
||||
self.aux_hidden_states: list[torch.Tensor] = []
|
||||
|
||||
def needs_capture(self) -> bool:
|
||||
return len(self.cudagraph_sizes) > 0
|
||||
@@ -134,13 +142,22 @@ class CudaGraphManager:
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
slot_mapping=slot_mappings,
|
||||
):
|
||||
hidden_states = model(
|
||||
model_output = model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
if self.hidden_states is None:
|
||||
self.hidden_states = torch.empty_like(hidden_states)
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
hidden_states, aux_hidden_states = model_output
|
||||
else:
|
||||
hidden_states = model_output
|
||||
aux_hidden_states = None
|
||||
|
||||
# Allocate output buffers if not already done.
|
||||
if self.hidden_states is None:
|
||||
self.hidden_states = torch.empty_like(hidden_states)
|
||||
if self.use_aux_hidden_state_outputs and not self.aux_hidden_states:
|
||||
self.aux_hidden_states = [torch.empty_like(x) for x in aux_hidden_states]
|
||||
|
||||
capture_fn(
|
||||
num_tokens=num_tokens,
|
||||
@@ -183,13 +200,23 @@ class CudaGraphManager:
|
||||
),
|
||||
torch.cuda.graph(graph, self.pool),
|
||||
):
|
||||
hidden_states = model(
|
||||
model_output = model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
hidden_states, aux_hidden_states = model_output
|
||||
else:
|
||||
hidden_states = model_output
|
||||
aux_hidden_states = None
|
||||
|
||||
# Copy outputs to the output buffers.
|
||||
assert self.hidden_states is not None
|
||||
self.hidden_states[:num_tokens] = hidden_states
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
for i, aux_hidden in enumerate(aux_hidden_states):
|
||||
self.aux_hidden_states[i][:num_tokens] = aux_hidden
|
||||
self.graphs[num_tokens] = graph
|
||||
|
||||
def _capture_piecewise_graph(
|
||||
@@ -298,11 +325,16 @@ class CudaGraphManager:
|
||||
cudagraph_size = None
|
||||
return cudagraph_mode, cudagraph_size
|
||||
|
||||
def run_fullgraph(self, num_tokens: int) -> torch.Tensor:
|
||||
def run_fullgraph(
|
||||
self, num_tokens: int
|
||||
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
|
||||
assert num_tokens in self.graphs, f"No cudagraph for {num_tokens} tokens"
|
||||
self.graphs[num_tokens].replay()
|
||||
assert self.hidden_states is not None
|
||||
return self.hidden_states[:num_tokens]
|
||||
hidden_states = self.hidden_states[:num_tokens]
|
||||
if not self.use_aux_hidden_state_outputs:
|
||||
return hidden_states
|
||||
return hidden_states, [x[:num_tokens] for x in self.aux_hidden_states]
|
||||
|
||||
|
||||
def get_cudagraph_sizes(
|
||||
|
||||
@@ -197,7 +197,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
# CUDA graphs.
|
||||
self.cudagraph_manager = CudaGraphManager(
|
||||
self.vllm_config, self.uses_mrope, self.device
|
||||
self.vllm_config,
|
||||
self.uses_mrope,
|
||||
self.use_aux_hidden_state_outputs,
|
||||
self.device,
|
||||
)
|
||||
# Structured outputs worker.
|
||||
self.structured_outputs_worker = StructuredOutputsWorker(
|
||||
@@ -1044,7 +1047,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
aux_hidden_states,
|
||||
input_batch,
|
||||
kv_connector_output,
|
||||
)
|
||||
) # type: ignore
|
||||
return None
|
||||
|
||||
@torch.inference_mode()
|
||||
|
||||
Reference in New Issue
Block a user