[Perf] Remove redundant device copies for CPU-only pooling token IDs, 48.9% E2E throughput improvement (#38139)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2026-03-29 14:12:50 -04:00
committed by GitHub
parent 8c0b6267d7
commit 995dea1354
8 changed files with 86 additions and 17 deletions

View File

@@ -380,7 +380,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, swap_list: lis
_compare_objs(input_batch, ref_input_batch)
def _construct_pooling_request(req_id_suffix: int):
def _construct_pooling_request(req_id_suffix: int, pooling_params=None):
from vllm.pooling_params import PoolingParams
prompt_token_ids = [
@@ -391,7 +391,7 @@ def _construct_pooling_request(req_id_suffix: int):
req_id=f"pool_req_{req_id_suffix}",
prompt_token_ids=prompt_token_ids,
sampling_params=None,
pooling_params=PoolingParams(task="classify"),
pooling_params=pooling_params or PoolingParams(task="classify"),
mm_features=[],
block_ids=([],),
generator=None,
@@ -440,3 +440,48 @@ def test_pooling_prompt_lens_not_aliased(device: str):
"mutations to num_prompt_tokens_cpu_tensor corrupted prompt_lens. "
f"Expected {prompt_lens_snapshot}, got {metadata.prompt_lens}"
)
@pytest.mark.parametrize(
    ("pooling_params", "expect_device_prompt_token_ids", "expect_cpu_prompt_token_ids"),
    [
        # Plain classification pooling: no token-id buffers should be built.
        ({"task": "classify"}, False, False),
        # A task that requires token ids must materialize both the device
        # and the CPU prompt-token-id buffers.
        ({"task": "classify", "requires_token_ids": True}, True, True),
    ],
)
def test_pooling_metadata_token_id_buffers(
    pooling_params: dict[str, object],
    expect_device_prompt_token_ids: bool,
    expect_cpu_prompt_token_ids: bool,
):
    """Check that prompt token-id buffers are built only when the task needs them.

    Guards the optimization that removes redundant device copies of pooling
    token ids: when ``requires_token_ids`` is not set on ``PoolingParams``,
    neither the device buffer (``prompt_token_ids``) nor the CPU buffer
    (``prompt_token_ids_cpu``) should be materialized; when it is set, both
    buffers must exist and round-trip the request's prompt token ids.
    """
    from vllm.pooling_params import PoolingParams

    input_batch = InputBatch(
        max_num_reqs=1,
        max_model_len=MAX_PROMPT_SIZE + NUM_OUTPUT_TOKENS,
        max_num_batched_tokens=MAX_PROMPT_SIZE + NUM_OUTPUT_TOKENS,
        device=torch.device("cpu"),
        pin_memory=False,
        vocab_size=VOCAB_SIZE,
        block_sizes=[16],
        kernel_block_sizes=[16],
        is_pooling_model=True,
    )
    req = _construct_pooling_request(0, PoolingParams(**pooling_params))
    input_batch.add_request(req)
    input_batch.refresh_metadata()

    metadata = input_batch.get_pooling_metadata()

    if expect_device_prompt_token_ids:
        assert input_batch.sampling_metadata.prompt_token_ids is not None
        assert metadata.prompt_token_ids is not None
        # The device buffer must round-trip the original prompt token ids.
        assert metadata.get_prompt_token_ids()[0].tolist() == req.prompt_token_ids
    else:
        assert input_batch.sampling_metadata.prompt_token_ids is None
        assert metadata.prompt_token_ids is None

    if expect_cpu_prompt_token_ids:
        assert metadata.prompt_token_ids_cpu is not None
        # The CPU buffer must round-trip the original prompt token ids.
        assert metadata.get_prompt_token_ids_cpu()[0].tolist() == req.prompt_token_ids
    else:
        assert metadata.prompt_token_ids_cpu is None