[Perf] Remove redundant device copies for CPU-only pooling token IDs, 48.9% E2E throughput improvement (#38139)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -380,7 +380,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, swap_list: lis
|
||||
_compare_objs(input_batch, ref_input_batch)
|
||||
|
||||
|
||||
def _construct_pooling_request(req_id_suffix: int):
|
||||
def _construct_pooling_request(req_id_suffix: int, pooling_params=None):
|
||||
from vllm.pooling_params import PoolingParams
|
||||
|
||||
prompt_token_ids = [
|
||||
@@ -391,7 +391,7 @@ def _construct_pooling_request(req_id_suffix: int):
|
||||
req_id=f"pool_req_{req_id_suffix}",
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
sampling_params=None,
|
||||
pooling_params=PoolingParams(task="classify"),
|
||||
pooling_params=pooling_params or PoolingParams(task="classify"),
|
||||
mm_features=[],
|
||||
block_ids=([],),
|
||||
generator=None,
|
||||
@@ -440,3 +440,48 @@ def test_pooling_prompt_lens_not_aliased(device: str):
|
||||
"mutations to num_prompt_tokens_cpu_tensor corrupted prompt_lens. "
|
||||
f"Expected {prompt_lens_snapshot}, got {metadata.prompt_lens}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("pooling_params", "expect_device_prompt_token_ids", "expect_cpu_prompt_token_ids"),
    [
        ({"task": "classify"}, False, False),
        ({"task": "classify", "requires_token_ids": True}, True, True),
    ],
)
def test_pooling_metadata_token_id_buffers(
    pooling_params: dict[str, object],
    expect_device_prompt_token_ids: bool,
    expect_cpu_prompt_token_ids: bool,
):
    """Check which prompt-token-id buffers the pooling metadata materializes.

    A plain "classify" task should leave both the device-side and the CPU-side
    prompt token id buffers unset; a task created with
    ``requires_token_ids=True`` should populate both, and the buffered ids must
    round-trip back to the request's original ``prompt_token_ids``.
    """
    from vllm.pooling_params import PoolingParams

    batch = InputBatch(
        max_num_reqs=1,
        max_model_len=MAX_PROMPT_SIZE + NUM_OUTPUT_TOKENS,
        max_num_batched_tokens=MAX_PROMPT_SIZE + NUM_OUTPUT_TOKENS,
        device=torch.device("cpu"),
        pin_memory=False,
        vocab_size=VOCAB_SIZE,
        block_sizes=[16],
        kernel_block_sizes=[16],
        is_pooling_model=True,
    )
    request = _construct_pooling_request(0, PoolingParams(**pooling_params))
    batch.add_request(request)
    batch.refresh_metadata()

    pooling_md = batch.get_pooling_metadata()

    # Device-side buffer: present only when the task requires token ids.
    if not expect_device_prompt_token_ids:
        assert batch.sampling_metadata.prompt_token_ids is None
        assert pooling_md.prompt_token_ids is None
    else:
        assert batch.sampling_metadata.prompt_token_ids is not None
        assert pooling_md.prompt_token_ids is not None
        assert (
            pooling_md.get_prompt_token_ids()[0].tolist() == request.prompt_token_ids
        )

    # CPU-side buffer: mirrors the same expectation independently.
    if not expect_cpu_prompt_token_ids:
        assert pooling_md.prompt_token_ids_cpu is None
    else:
        assert pooling_md.prompt_token_ids_cpu is not None
        assert (
            pooling_md.get_prompt_token_ids_cpu()[0].tolist()
            == request.prompt_token_ids
        )
|
||||
|
||||
Reference in New Issue
Block a user