[MRV2] Skip hidden states allocation for PW CUDA graphs (#37818)
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
@@ -263,6 +263,7 @@ class ModelCudaGraphManager(CudaGraphManager):
|
||||
decode_query_len: int,
|
||||
):
|
||||
super().__init__(vllm_config, device, cudagraph_mode, decode_query_len)
|
||||
# Used for FULL CUDA graphs. PW CUDA graphs do not use these.
|
||||
self.hidden_states: torch.Tensor | None = None
|
||||
self.aux_hidden_states: list[torch.Tensor] = []
|
||||
self.use_aux_hidden_state_outputs = False
|
||||
@@ -326,6 +327,12 @@ class ModelCudaGraphManager(CudaGraphManager):
|
||||
**model_state.prepare_dummy_inputs(num_reqs, num_tokens),
|
||||
}
|
||||
model_output = model(**model_inputs)
|
||||
|
||||
if cg_mode == CUDAGraphMode.PIECEWISE:
|
||||
# PW CUDA graph internally handles the model outputs.
|
||||
# No need to keep track of the hidden states.
|
||||
return None
|
||||
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
hidden_states, aux_hidden_states = model_output
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user