[V1][PP] Cache Intermediate Tensors (#13353)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import gc
|
import gc
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
|
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -149,6 +149,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.positions = torch.zeros(self.max_num_tokens,
|
self.positions = torch.zeros(self.max_num_tokens,
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
device=self.device)
|
device=self.device)
|
||||||
|
# self.intermediate_tensors # Set after load_model
|
||||||
|
|
||||||
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
||||||
if self.uses_mrope:
|
if self.uses_mrope:
|
||||||
@@ -869,7 +870,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self,
|
self,
|
||||||
scheduler_output: "SchedulerOutput",
|
scheduler_output: "SchedulerOutput",
|
||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
) -> ModelRunnerOutput:
|
) -> Union[ModelRunnerOutput, torch.Tensor]:
|
||||||
batch_changed = self._update_states(scheduler_output)
|
batch_changed = self._update_states(scheduler_output)
|
||||||
|
|
||||||
if self.is_multimodal_model:
|
if self.is_multimodal_model:
|
||||||
@@ -919,6 +920,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
else:
|
else:
|
||||||
positions = self.positions[:num_input_tokens]
|
positions = self.positions[:num_input_tokens]
|
||||||
|
|
||||||
|
if get_pp_group().is_first_rank:
|
||||||
|
intermediate_tensors = None
|
||||||
|
else:
|
||||||
|
intermediate_tensors = IntermediateTensors({
|
||||||
|
k: v[:num_input_tokens]
|
||||||
|
for k, v in self.intermediate_tensors.items()
|
||||||
|
})
|
||||||
|
|
||||||
# Run the decoder.
|
# Run the decoder.
|
||||||
# Use persistent buffers for CUDA graphs.
|
# Use persistent buffers for CUDA graphs.
|
||||||
with set_forward_context(attn_metadata, self.vllm_config):
|
with set_forward_context(attn_metadata, self.vllm_config):
|
||||||
@@ -931,7 +940,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
)
|
)
|
||||||
if not get_pp_group().is_last_rank:
|
if not get_pp_group().is_last_rank:
|
||||||
|
# For mid-pipeline stages, return the hidden states.
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
hidden_states = hidden_states[:num_scheduled_tokens]
|
hidden_states = hidden_states[:num_scheduled_tokens]
|
||||||
sample_hidden_states = hidden_states[logits_indices]
|
sample_hidden_states = hidden_states[logits_indices]
|
||||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||||
@@ -1118,12 +1129,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
positions = self.mrope_positions[:, :num_tokens]
|
positions = self.mrope_positions[:, :num_tokens]
|
||||||
else:
|
else:
|
||||||
positions = self.positions[:num_tokens]
|
positions = self.positions[:num_tokens]
|
||||||
|
|
||||||
|
if get_pp_group().is_first_rank:
|
||||||
intermediate_tensors = None
|
intermediate_tensors = None
|
||||||
if not get_pp_group().is_first_rank:
|
else:
|
||||||
intermediate_tensors = self.model.make_empty_intermediate_tensors(
|
if not hasattr(self, "intermediate_tensors"):
|
||||||
batch_size=num_tokens,
|
self.intermediate_tensors = (
|
||||||
|
self.model.make_empty_intermediate_tensors(
|
||||||
|
batch_size=self.max_num_tokens,
|
||||||
dtype=self.model_config.dtype,
|
dtype=self.model_config.dtype,
|
||||||
device=self.device)
|
device=self.device))
|
||||||
|
intermediate_tensors = IntermediateTensors({
|
||||||
|
k: v[:num_tokens]
|
||||||
|
for k, v in self.intermediate_tensors.items()
|
||||||
|
})
|
||||||
|
|
||||||
with set_forward_context(None, self.vllm_config):
|
with set_forward_context(None, self.vllm_config):
|
||||||
hidden_states = model(
|
hidden_states = model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
|
|||||||
Reference in New Issue
Block a user