[Perf] Remove redundant device copies for CPU-only pooling token IDs, 48.9% E2E throughput improvement (#38139)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2026-03-29 14:12:50 -04:00
committed by GitHub
parent 8c0b6267d7
commit 995dea1354
8 changed files with 86 additions and 17 deletions

View File

@@ -638,25 +638,26 @@ class SPLADESparsePooler(Pooler):
lens: list[int] = lens_tensor.tolist()
B: int = len(lens)
token_ids = pooling_metadata.prompt_token_ids
prompt_token_ids = pooling_metadata.get_prompt_token_ids_cpu()
offset = 0
pooled_list: list[torch.Tensor] = []
for i in range(B):
L = int(lens[i])
hs = hidden_states[offset : offset + L]
token_ids = prompt_token_ids[i]
start_idx = 0
end_idx = L
if self.remove_cls_sep and token_ids is not None:
if self.remove_cls_sep:
if (
self.cls_token_id is not None
and token_ids[i, 0].item() == self.cls_token_id
and int(token_ids[0]) == self.cls_token_id
):
start_idx = 1
if (
self.sep_token_id is not None
and token_ids[i, L - 1].item() == self.sep_token_id
and int(token_ids[L - 1]) == self.sep_token_id
):
end_idx = max(start_idx, L - 1)

View File

@@ -156,10 +156,11 @@ class GritLMMeanPool(SequencePoolingMethod):
pooling_metadata: PoolingMetadata,
) -> SequencePoolingMethodOutput:
prompt_lens = pooling_metadata.prompt_lens
prompt_token_ids = pooling_metadata.get_prompt_token_ids_cpu()
instr_lens = torch.tensor(
[
self._get_instruction_len(token_ids.cpu().numpy())
for token_ids in pooling_metadata.get_prompt_token_ids()
self._get_instruction_len(token_ids.numpy())
for token_ids in prompt_token_ids
],
device="cpu",
)