[Performance] V1 Pooling Models E2E Performance Optimization (#23162)
Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
@@ -6,15 +6,40 @@ from typing import Optional
|
||||
import torch
|
||||
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.utils import is_pin_memory_available
|
||||
|
||||
pin_memory = is_pin_memory_available()
|
||||
|
||||
|
||||
@dataclass
|
||||
class PoolingCursor:
|
||||
index: list[int]
|
||||
first_token_indices_gpu: torch.Tensor
|
||||
last_token_indices_gpu: torch.Tensor
|
||||
prompt_lens_cpu: torch.Tensor
|
||||
num_scheduled_tokens_cpu: torch.Tensor
|
||||
|
||||
def __getitem__(self, indices: slice):
|
||||
return PoolingCursor(
|
||||
index=self.index[indices],
|
||||
first_token_indices_gpu=self.first_token_indices_gpu[indices],
|
||||
last_token_indices_gpu=self.last_token_indices_gpu[indices],
|
||||
prompt_lens_cpu=self.prompt_lens_cpu[indices],
|
||||
num_scheduled_tokens_cpu=self.num_scheduled_tokens_cpu[indices],
|
||||
)
|
||||
|
||||
def is_partial_prefill(self):
|
||||
return not torch.all(
|
||||
self.prompt_lens_cpu == self.num_scheduled_tokens_cpu)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PoolingMetadata:
|
||||
"""Tensors for pooling."""
|
||||
|
||||
prompt_lens: torch.Tensor
|
||||
prompt_lens: torch.Tensor # CPU Tensor
|
||||
prompt_token_ids: Optional[torch.Tensor]
|
||||
pooling_params: list[PoolingParams]
|
||||
pooling_cursor: Optional[PoolingCursor] = None
|
||||
|
||||
def __getitem__(self, indices: slice):
|
||||
return PoolingMetadata(
|
||||
@@ -22,4 +47,31 @@ class PoolingMetadata:
|
||||
prompt_token_ids=None if self.prompt_token_ids is None else
|
||||
self.prompt_token_ids[indices],
|
||||
pooling_params=self.pooling_params[indices],
|
||||
pooling_cursor=None
|
||||
if self.pooling_cursor is None else self.pooling_cursor[indices],
|
||||
)
|
||||
|
||||
def build_pooling_cursor(self, num_scheduled_tokens: list[int],
|
||||
device: torch.device):
|
||||
self.pooling_cursor = build_pooling_cursor(num_scheduled_tokens,
|
||||
self.prompt_lens, device)
|
||||
|
||||
|
||||
def build_pooling_cursor(num_scheduled_tokens: list[int],
|
||||
prompt_lens: torch.Tensor, device: torch.device):
|
||||
assert len(prompt_lens) == len(num_scheduled_tokens)
|
||||
|
||||
n_seq = len(num_scheduled_tokens)
|
||||
index = list(range(n_seq))
|
||||
num_scheduled_tokens = torch.tensor(num_scheduled_tokens, device="cpu")
|
||||
cumsum = torch.zeros(n_seq + 1,
|
||||
dtype=torch.int64,
|
||||
pin_memory=pin_memory,
|
||||
device="cpu")
|
||||
torch.cumsum(num_scheduled_tokens, dim=0, out=cumsum[1:])
|
||||
cumsum = cumsum.to(device, non_blocking=True)
|
||||
return PoolingCursor(index=index,
|
||||
first_token_indices_gpu=cumsum[:n_seq],
|
||||
last_token_indices_gpu=cumsum[1:] - 1,
|
||||
prompt_lens_cpu=prompt_lens,
|
||||
num_scheduled_tokens_cpu=num_scheduled_tokens)
|
||||
|
||||
@@ -713,7 +713,7 @@ class InputBatch:
|
||||
|
||||
return PoolingMetadata(
|
||||
prompt_lens=torch.from_numpy(
|
||||
self.num_prompt_tokens[:self.num_reqs]).to(self.device),
|
||||
self.num_prompt_tokens[:self.num_reqs]),
|
||||
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
|
||||
pooling_params=pooling_params,
|
||||
)
|
||||
|
||||
@@ -1476,23 +1476,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
"Either all or none of the requests in" \
|
||||
" a batch must be pooling request"
|
||||
|
||||
extracted_hidden_states = list(
|
||||
torch.split(hidden_states[:num_scheduled_tokens],
|
||||
num_scheduled_tokens_np.tolist()))
|
||||
|
||||
hidden_states = hidden_states[:num_scheduled_tokens]
|
||||
pooling_metadata = self.input_batch.pooling_metadata
|
||||
pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(),
|
||||
device=hidden_states.device)
|
||||
seq_lens_cpu = self.seq_lens_cpu[:self.input_batch.num_reqs]
|
||||
|
||||
# Pooling models D2H & synchronize occurs in pooler.py:build_output
|
||||
raw_pooler_output = self.model.pooler(
|
||||
hidden_states=extracted_hidden_states,
|
||||
pooling_metadata=pooling_metadata)
|
||||
hidden_states=hidden_states, pooling_metadata=pooling_metadata)
|
||||
|
||||
pooler_output: list[Optional[torch.Tensor]] = []
|
||||
seq_lens = self.seq_lens[:self.input_batch.num_reqs]
|
||||
for raw_output, seq_len, prompt_len in zip(
|
||||
raw_pooler_output, seq_lens, pooling_metadata.prompt_lens):
|
||||
raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens):
|
||||
|
||||
if seq_len == prompt_len:
|
||||
pooler_output.append(raw_output.data.cpu())
|
||||
pooler_output.append(raw_output.data)
|
||||
else:
|
||||
pooler_output.append(None)
|
||||
|
||||
@@ -2524,13 +2523,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
assert sum(num_scheduled_tokens_list) == num_tokens
|
||||
assert len(num_scheduled_tokens_list) == num_reqs
|
||||
|
||||
hidden_states_list = list(
|
||||
torch.split(hidden_states, num_scheduled_tokens_list))
|
||||
req_num_tokens = num_tokens // num_reqs
|
||||
|
||||
dummy_prompt_lens = torch.tensor(
|
||||
[h.shape[0] for h in hidden_states_list],
|
||||
device=self.device,
|
||||
num_scheduled_tokens_list,
|
||||
device="cpu",
|
||||
)
|
||||
dummy_token_ids = torch.zeros((num_reqs, req_num_tokens),
|
||||
dtype=torch.int32,
|
||||
@@ -2547,8 +2544,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
pooling_params=[dummy_pooling_params] * num_reqs,
|
||||
)
|
||||
|
||||
dummy_metadata.build_pooling_cursor(num_scheduled_tokens_list,
|
||||
device=hidden_states.device)
|
||||
|
||||
try:
|
||||
return model.pooler(hidden_states=hidden_states_list,
|
||||
return model.pooler(hidden_states=hidden_states,
|
||||
pooling_metadata=dummy_metadata)
|
||||
except RuntimeError as e:
|
||||
if 'out of memory' in str(e):
|
||||
@@ -3316,10 +3316,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
dummy_block_table = torch.zeros((num_reqs, 1),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
pin_memory=self.pin_memory,
|
||||
device="cpu").to(self.device,
|
||||
non_blocking=True)
|
||||
dummy_slot_mapping = torch.zeros((total_num_scheduled_tokens, ),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
pin_memory=self.pin_memory,
|
||||
device="cpu").to(self.device,
|
||||
non_blocking=True)
|
||||
|
||||
group_metadata = dict[str, tuple[CommonAttentionMetadata, Any]]()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user