[Perf] Remove redundant device copies for CPU-only pooling token IDs, 48.9% E2E throughput improvement (#38139)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2026-03-29 14:12:50 -04:00
committed by GitHub
parent 8c0b6267d7
commit 995dea1354
8 changed files with 86 additions and 17 deletions

View File

@@ -380,7 +380,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, swap_list: lis
_compare_objs(input_batch, ref_input_batch)
def _construct_pooling_request(req_id_suffix: int):
def _construct_pooling_request(req_id_suffix: int, pooling_params=None):
from vllm.pooling_params import PoolingParams
prompt_token_ids = [
@@ -391,7 +391,7 @@ def _construct_pooling_request(req_id_suffix: int):
req_id=f"pool_req_{req_id_suffix}",
prompt_token_ids=prompt_token_ids,
sampling_params=None,
pooling_params=PoolingParams(task="classify"),
pooling_params=pooling_params or PoolingParams(task="classify"),
mm_features=[],
block_ids=([],),
generator=None,
@@ -440,3 +440,48 @@ def test_pooling_prompt_lens_not_aliased(device: str):
"mutations to num_prompt_tokens_cpu_tensor corrupted prompt_lens. "
f"Expected {prompt_lens_snapshot}, got {metadata.prompt_lens}"
)
@pytest.mark.parametrize(
    ("pooling_params", "expect_device_prompt_token_ids", "expect_cpu_prompt_token_ids"),
    [
        ({"task": "classify"}, False, False),
        ({"task": "classify", "requires_token_ids": True}, True, True),
    ],
)
def test_pooling_metadata_token_id_buffers(
    pooling_params: dict[str, object],
    expect_device_prompt_token_ids: bool,
    expect_cpu_prompt_token_ids: bool,
):
    """Check which prompt-token-ID buffers exist for a single pooling request.

    A one-request ``InputBatch`` is built on the CPU device. Depending on the
    parametrized expectations, ``sampling_metadata.prompt_token_ids``,
    ``PoolingMetadata.prompt_token_ids`` and
    ``PoolingMetadata.prompt_token_ids_cpu`` must either all be ``None`` or
    hold exactly the request's prompt token IDs.
    """
    from vllm.pooling_params import PoolingParams

    max_len = MAX_PROMPT_SIZE + NUM_OUTPUT_TOKENS
    batch = InputBatch(
        max_num_reqs=1,
        max_model_len=max_len,
        max_num_batched_tokens=max_len,
        device=torch.device("cpu"),
        pin_memory=False,
        vocab_size=VOCAB_SIZE,
        block_sizes=[16],
        kernel_block_sizes=[16],
        is_pooling_model=True,
    )

    request = _construct_pooling_request(0, PoolingParams(**pooling_params))
    batch.add_request(request)
    batch.refresh_metadata()
    pooling_metadata = batch.get_pooling_metadata()

    # Buffer exposed through the sampling metadata and mirrored on the
    # pooling metadata.
    if not expect_device_prompt_token_ids:
        assert batch.sampling_metadata.prompt_token_ids is None
        assert pooling_metadata.prompt_token_ids is None
    else:
        assert batch.sampling_metadata.prompt_token_ids is not None
        assert pooling_metadata.prompt_token_ids is not None
        first_row = pooling_metadata.get_prompt_token_ids()[0]
        assert first_row.tolist() == request.prompt_token_ids

    # Separate CPU-side buffer carried by the pooling metadata.
    if not expect_cpu_prompt_token_ids:
        assert pooling_metadata.prompt_token_ids_cpu is None
    else:
        assert pooling_metadata.prompt_token_ids_cpu is not None
        first_row_cpu = pooling_metadata.get_prompt_token_ids_cpu()[0]
        assert first_row_cpu.tolist() == request.prompt_token_ids

View File

@@ -18,7 +18,7 @@ ActivationFn = Callable[[_T], _T]
@dataclass(frozen=True)
class PoolingParamsUpdate:
requires_token_ids: bool = False
"""Set this flag to enable `get_prompt_token_ids` for your pooler."""
"""Set this flag to enable prompt token IDs for your pooler."""
def __or__(self, other: "PoolingParamsUpdate") -> "PoolingParamsUpdate":
return PoolingParamsUpdate(

View File

@@ -146,17 +146,19 @@ class BOSEOSFilter(Pooler):
) -> PoolerOutput:
pooled_outputs = self.pooler(hidden_states, pooling_metadata)
assert isinstance(pooled_outputs, list)
prompt_token_ids = pooling_metadata.get_prompt_token_ids_cpu()
for i, prompt_len in enumerate(pooling_metadata.prompt_lens):
for i, (prompt_len, token_ids) in enumerate(
zip(pooling_metadata.prompt_lens, prompt_token_ids)
):
pooled_data = pooled_outputs[i]
assert (
isinstance(pooled_data, torch.Tensor)
and pooled_data.shape[0] == prompt_len
)
token_ids = pooling_metadata.prompt_token_ids[i, :prompt_len]
if token_ids[0] == self.bos_token_id:
if int(token_ids[0]) == self.bos_token_id:
pooled_data = pooled_data[1:]
if token_ids[-1] == self.eos_token_id:
if int(token_ids[-1]) == self.eos_token_id:
pooled_data = pooled_data[:-1]
pooled_outputs[i] = pooled_data.squeeze(-1)

View File

@@ -638,25 +638,26 @@ class SPLADESparsePooler(Pooler):
lens: list[int] = lens_tensor.tolist()
B: int = len(lens)
token_ids = pooling_metadata.prompt_token_ids
prompt_token_ids = pooling_metadata.get_prompt_token_ids_cpu()
offset = 0
pooled_list: list[torch.Tensor] = []
for i in range(B):
L = int(lens[i])
hs = hidden_states[offset : offset + L]
token_ids = prompt_token_ids[i]
start_idx = 0
end_idx = L
if self.remove_cls_sep and token_ids is not None:
if self.remove_cls_sep:
if (
self.cls_token_id is not None
and token_ids[i, 0].item() == self.cls_token_id
and int(token_ids[0]) == self.cls_token_id
):
start_idx = 1
if (
self.sep_token_id is not None
and token_ids[i, L - 1].item() == self.sep_token_id
and int(token_ids[L - 1]) == self.sep_token_id
):
end_idx = max(start_idx, L - 1)

View File

@@ -156,10 +156,11 @@ class GritLMMeanPool(SequencePoolingMethod):
pooling_metadata: PoolingMetadata,
) -> SequencePoolingMethodOutput:
prompt_lens = pooling_metadata.prompt_lens
prompt_token_ids = pooling_metadata.get_prompt_token_ids_cpu()
instr_lens = torch.tensor(
[
self._get_instruction_len(token_ids.cpu().numpy())
for token_ids in pooling_metadata.get_prompt_token_ids()
self._get_instruction_len(token_ids.numpy())
for token_ids in prompt_token_ids
],
device="cpu",
)

View File

@@ -50,7 +50,8 @@ class PoolingMetadata:
"""Tensors for pooling."""
prompt_lens: torch.Tensor # CPU Tensor
prompt_token_ids: torch.Tensor | None
prompt_token_ids: torch.Tensor | None # Model-device tensor
prompt_token_ids_cpu: torch.Tensor | None # CPU tensor
pooling_params: list[PoolingParams]
pooling_states: list[PoolingStates]
pooling_cursor: PoolingCursor | None = None
@@ -73,6 +74,9 @@ class PoolingMetadata:
prompt_token_ids=None
if self.prompt_token_ids is None
else self.prompt_token_ids[indices],
prompt_token_ids_cpu=None
if self.prompt_token_ids_cpu is None
else self.prompt_token_ids_cpu[indices],
pooling_params=self.pooling_params[indices],
pooling_states=self.pooling_states[indices],
pooling_cursor=None
@@ -85,7 +89,13 @@ class PoolingMetadata:
assert prompt_token_ids is not None, (
"Please set `requires_token_ids=True` in `get_pooling_updates`"
)
return [prompt_token_ids[i, :num] for i, num in enumerate(self.prompt_lens)]
def get_prompt_token_ids_cpu(self) -> list[torch.Tensor]:
    """Return each request's prompt token IDs as slices of the CPU buffer.

    Row ``i`` of ``prompt_token_ids_cpu`` is truncated to ``prompt_lens[i]``
    entries, so trailing columns beyond a request's prompt length are not
    included in its slice.

    Returns:
        One tensor slice per entry of ``prompt_lens``, in the same order.

    Raises:
        AssertionError: If ``prompt_token_ids_cpu`` is ``None``.
    """
    prompt_token_ids = self.prompt_token_ids_cpu
    assert prompt_token_ids is not None, (
        "Please set `requires_token_ids=True` in `get_pooling_updates`"
    )
    # Convert the lengths to plain Python ints once instead of iterating the
    # tensor, which would create a 0-dim tensor per request and convert it
    # implicitly each time it is used as a slice bound.
    prompt_lens = self.prompt_lens.tolist()
    return [prompt_token_ids[i, :num] for i, num in enumerate(prompt_lens)]
def get_pooling_cursor(self) -> PoolingCursor:

View File

@@ -833,8 +833,13 @@ class InputBatch:
# step pooling during the sampling/pooling process.
# Hence copy these tensors only when there are requests which
# need penalties/step_pooler to be applied.
prompt_token_ids_cpu = (
self._make_prompt_token_ids_cpu_tensor() if needs_prompt_token_ids else None
)
prompt_token_ids = (
self._make_prompt_token_ids_tensor() if needs_prompt_token_ids else None
prompt_token_ids_cpu.to(device=self.device, non_blocking=True)
if prompt_token_ids_cpu is not None
else None
)
# Only set output_token_ids if required by the current requests'
@@ -891,15 +896,19 @@ class InputBatch:
def get_pooling_metadata(self) -> PoolingMetadata:
    """Assemble the ``PoolingMetadata`` for the requests currently in the batch.

    The CPU prompt-token-ID tensor is built only when at least one request's
    pooling params set ``requires_token_ids``; otherwise that field is
    ``None``. ``prompt_token_ids`` is taken as-is from the current sampling
    metadata, and ``prompt_lens`` is a clone of the first ``num_reqs``
    entries of ``num_prompt_tokens_cpu_tensor``.
    """
    params = self.get_pooling_params()
    states = self.get_pooling_states()

    wants_token_ids = any(p.requires_token_ids for p in params)
    cpu_token_ids = (
        self._make_prompt_token_ids_cpu_tensor() if wants_token_ids else None
    )

    prompt_lens = self.num_prompt_tokens_cpu_tensor[: self.num_reqs].clone()
    return PoolingMetadata(
        prompt_lens=prompt_lens,
        prompt_token_ids=self.sampling_metadata.prompt_token_ids,
        prompt_token_ids_cpu=cpu_token_ids,
        pooling_params=params,
        pooling_states=states,
    )
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
def _make_prompt_token_ids_cpu_tensor(self) -> torch.Tensor:
num_reqs = self.num_reqs
max_prompt_len = self.num_prompt_tokens[:num_reqs].max()
prompt_token_ids_cpu_tensor = torch.empty(
@@ -914,7 +923,7 @@ class InputBatch:
# token_id of this value.
for i in range(num_reqs):
prompt_token_ids[i, self.num_prompt_tokens[i] :] = self.vocab_size
return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True)
return prompt_token_ids_cpu_tensor
def make_lora_inputs(
self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray

View File

@@ -5653,6 +5653,7 @@ class GPUModelRunner(
dummy_metadata = PoolingMetadata(
prompt_lens=dummy_prompt_lens,
prompt_token_ids=dummy_token_ids,
prompt_token_ids_cpu=dummy_token_ids.cpu(),
pooling_params=[dummy_pooling_params] * num_reqs,
pooling_states=[PoolingStates() for i in range(num_reqs)],
)