[Perf] Optimize mean pooling using chunks and index_add, 5.9% E2E throughput improvement (#38559)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -14,6 +14,8 @@ from vllm.v1.pool.metadata import PoolingMetadata
|
||||
|
||||
SequencePoolingMethodOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
|
||||
|
||||
_MEAN_POOL_ACCUMULATION_CHUNK_BYTES = 16 * 1024 * 1024 # 16MB
|
||||
|
||||
|
||||
class SequencePoolingMethod(nn.Module, ABC):
|
||||
def get_supported_tasks(self) -> Set[PoolingTask]:
|
||||
@@ -67,19 +69,41 @@ class MeanPool(SequencePoolingMethod):
|
||||
)
|
||||
|
||||
prompt_lens = pooling_cursor.prompt_lens_cpu.to(
|
||||
hidden_states.device, non_blocking=True
|
||||
hidden_states.device, dtype=torch.int64, non_blocking=True
|
||||
)
|
||||
|
||||
# Use float32 for torch.cumsum in MeanPool,
|
||||
# otherwise precision will be lost significantly.
|
||||
cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)
|
||||
num_seqs = prompt_lens.numel()
|
||||
hidden_size = hidden_states.shape[-1]
|
||||
|
||||
start_indices = pooling_cursor.first_token_indices_gpu
|
||||
end_indices = pooling_cursor.last_token_indices_gpu
|
||||
if num_seqs == 0:
|
||||
# early return for empty batch
|
||||
return hidden_states.new_empty((0, hidden_size), dtype=torch.float32)
|
||||
|
||||
return (
|
||||
cumsum[end_indices] - cumsum[start_indices] + hidden_states[start_indices]
|
||||
) / prompt_lens.unsqueeze(1)
|
||||
# eg. [2, 1, 3] -> [0, 0, 1, 2, 2, 2]
|
||||
segment_ids = torch.repeat_interleave(
|
||||
torch.arange(num_seqs, device=hidden_states.device, dtype=torch.long),
|
||||
prompt_lens,
|
||||
)
|
||||
segment_sums = torch.zeros(
|
||||
(num_seqs, hidden_size),
|
||||
dtype=torch.float32,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
|
||||
bytes_per_token = hidden_size * torch.finfo(torch.float32).bits // 8
|
||||
chunk_size = max(1, _MEAN_POOL_ACCUMULATION_CHUNK_BYTES // bytes_per_token)
|
||||
|
||||
# iterate over the batch in chunks
|
||||
for start in range(0, hidden_states.shape[0], chunk_size):
|
||||
end = min(start + chunk_size, hidden_states.shape[0])
|
||||
# using index_add_ to accumulate for each segment
|
||||
segment_sums.index_add_(
|
||||
0,
|
||||
segment_ids[start:end],
|
||||
hidden_states[start:end].to(dtype=torch.float32),
|
||||
)
|
||||
|
||||
return segment_sums / prompt_lens.unsqueeze(1)
|
||||
|
||||
|
||||
def get_seq_pooling_method(pooling_type: SequencePoolingType | str):
|
||||
|
||||
Reference in New Issue
Block a user