[Perf] Optimize mean pooling using chunks and index_add, 5.9% E2E throughput improvement (#38559)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2026-03-31 23:54:58 -04:00
committed by GitHub
parent 17b72fd1c8
commit 7b01d97a22

View File

@@ -14,6 +14,8 @@ from vllm.v1.pool.metadata import PoolingMetadata
SequencePoolingMethodOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
_MEAN_POOL_ACCUMULATION_CHUNK_BYTES = 16 * 1024 * 1024 # 16MB
class SequencePoolingMethod(nn.Module, ABC):
def get_supported_tasks(self) -> Set[PoolingTask]:
@@ -67,19 +69,41 @@ class MeanPool(SequencePoolingMethod):
)
prompt_lens = pooling_cursor.prompt_lens_cpu.to(
hidden_states.device, non_blocking=True
hidden_states.device, dtype=torch.int64, non_blocking=True
)
# Use float32 for torch.cumsum in MeanPool,
# otherwise precision will be lost significantly.
cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)
num_seqs = prompt_lens.numel()
hidden_size = hidden_states.shape[-1]
start_indices = pooling_cursor.first_token_indices_gpu
end_indices = pooling_cursor.last_token_indices_gpu
if num_seqs == 0:
# early return for empty batch
return hidden_states.new_empty((0, hidden_size), dtype=torch.float32)
return (
cumsum[end_indices] - cumsum[start_indices] + hidden_states[start_indices]
) / prompt_lens.unsqueeze(1)
# e.g. prompt_lens [2, 1, 3] -> segment_ids [0, 0, 1, 2, 2, 2]
segment_ids = torch.repeat_interleave(
torch.arange(num_seqs, device=hidden_states.device, dtype=torch.long),
prompt_lens,
)
segment_sums = torch.zeros(
(num_seqs, hidden_size),
dtype=torch.float32,
device=hidden_states.device,
)
bytes_per_token = hidden_size * torch.finfo(torch.float32).bits // 8
chunk_size = max(1, _MEAN_POOL_ACCUMULATION_CHUNK_BYTES // bytes_per_token)
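# Iterate over the tokens in chunks so that the float32 copy made for each
# chunk stays within roughly _MEAN_POOL_ACCUMULATION_CHUNK_BYTES.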
for start in range(0, hidden_states.shape[0], chunk_size):
end = min(start + chunk_size, hidden_states.shape[0])
# Use index_add_ to add each token's hidden state into the row of its segment.
segment_sums.index_add_(
0,
segment_ids[start:end],
hidden_states[start:end].to(dtype=torch.float32),
)
return segment_sums / prompt_lens.unsqueeze(1)
def get_seq_pooling_method(pooling_type: SequencePoolingType | str):