From 7b01d97a22c977aaad2f38fa03cd11742d2893e8 Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Tue, 31 Mar 2026 23:54:58 -0400 Subject: [PATCH] [Perf] Optimize mean pooling using chunks and index_add, 5.9% E2E throughput improvement (#38559) Signed-off-by: yewentao256 --- .../layers/pooler/seqwise/methods.py | 42 +++++++++++++++---- 1 file changed, 33 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/layers/pooler/seqwise/methods.py b/vllm/model_executor/layers/pooler/seqwise/methods.py index f3c7f29d6..b967ff4ed 100644 --- a/vllm/model_executor/layers/pooler/seqwise/methods.py +++ b/vllm/model_executor/layers/pooler/seqwise/methods.py @@ -14,6 +14,8 @@ from vllm.v1.pool.metadata import PoolingMetadata SequencePoolingMethodOutput: TypeAlias = torch.Tensor | list[torch.Tensor] +_MEAN_POOL_ACCUMULATION_CHUNK_BYTES = 16 * 1024 * 1024 # 16MB + class SequencePoolingMethod(nn.Module, ABC): def get_supported_tasks(self) -> Set[PoolingTask]: @@ -67,19 +69,41 @@ class MeanPool(SequencePoolingMethod): ) prompt_lens = pooling_cursor.prompt_lens_cpu.to( - hidden_states.device, non_blocking=True + hidden_states.device, dtype=torch.int64, non_blocking=True ) - # Use float32 for torch.cumsum in MeanPool, - # otherwise precision will be lost significantly. - cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32) + num_seqs = prompt_lens.numel() + hidden_size = hidden_states.shape[-1] - start_indices = pooling_cursor.first_token_indices_gpu - end_indices = pooling_cursor.last_token_indices_gpu + if num_seqs == 0: + # early return for empty batch + return hidden_states.new_empty((0, hidden_size), dtype=torch.float32) - return ( - cumsum[end_indices] - cumsum[start_indices] + hidden_states[start_indices] - ) / prompt_lens.unsqueeze(1) + # eg. [2, 1, 3] -> [0, 0, 1, 2, 2, 2] + segment_ids = torch.repeat_interleave( + torch.arange(num_seqs, device=hidden_states.device, dtype=torch.long), + prompt_lens, + ) + segment_sums = torch.zeros( + (num_seqs, hidden_size), + dtype=torch.float32, + device=hidden_states.device, + ) + + bytes_per_token = hidden_size * torch.finfo(torch.float32).bits // 8 + chunk_size = max(1, _MEAN_POOL_ACCUMULATION_CHUNK_BYTES // bytes_per_token) + + # iterate over the batch in chunks + for start in range(0, hidden_states.shape[0], chunk_size): + end = min(start + chunk_size, hidden_states.shape[0]) + # using index_add_ to accumulate for each segment + segment_sums.index_add_( + 0, + segment_ids[start:end], + hidden_states[start:end].to(dtype=torch.float32), + ) + + return segment_sums / prompt_lens.unsqueeze(1) def get_seq_pooling_method(pooling_type: SequencePoolingType | str):