[Perf] Remove redundant device copies for CPU-only pooling token IDs, 48.9% E2E throughput improvement (#38139)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -638,25 +638,26 @@ class SPLADESparsePooler(Pooler):
|
||||
lens: list[int] = lens_tensor.tolist()
|
||||
B: int = len(lens)
|
||||
|
||||
token_ids = pooling_metadata.prompt_token_ids
|
||||
prompt_token_ids = pooling_metadata.get_prompt_token_ids_cpu()
|
||||
offset = 0
|
||||
pooled_list: list[torch.Tensor] = []
|
||||
|
||||
for i in range(B):
|
||||
L = int(lens[i])
|
||||
hs = hidden_states[offset : offset + L]
|
||||
token_ids = prompt_token_ids[i]
|
||||
|
||||
start_idx = 0
|
||||
end_idx = L
|
||||
if self.remove_cls_sep and token_ids is not None:
|
||||
if self.remove_cls_sep:
|
||||
if (
|
||||
self.cls_token_id is not None
|
||||
and token_ids[i, 0].item() == self.cls_token_id
|
||||
and int(token_ids[0]) == self.cls_token_id
|
||||
):
|
||||
start_idx = 1
|
||||
if (
|
||||
self.sep_token_id is not None
|
||||
and token_ids[i, L - 1].item() == self.sep_token_id
|
||||
and int(token_ids[L - 1]) == self.sep_token_id
|
||||
):
|
||||
end_idx = max(start_idx, L - 1)
|
||||
|
||||
|
||||
@@ -156,10 +156,11 @@ class GritLMMeanPool(SequencePoolingMethod):
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> SequencePoolingMethodOutput:
|
||||
prompt_lens = pooling_metadata.prompt_lens
|
||||
prompt_token_ids = pooling_metadata.get_prompt_token_ids_cpu()
|
||||
instr_lens = torch.tensor(
|
||||
[
|
||||
self._get_instruction_len(token_ids.cpu().numpy())
|
||||
for token_ids in pooling_metadata.get_prompt_token_ids()
|
||||
self._get_instruction_len(token_ids.numpy())
|
||||
for token_ids in prompt_token_ids
|
||||
],
|
||||
device="cpu",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user