[Perf] Remove redundant device copies for CPU-only pooling token IDs, 48.9% E2E throughput improvement (#38139)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -380,7 +380,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, swap_list: lis
|
||||
_compare_objs(input_batch, ref_input_batch)
|
||||
|
||||
|
||||
def _construct_pooling_request(req_id_suffix: int):
|
||||
def _construct_pooling_request(req_id_suffix: int, pooling_params=None):
|
||||
from vllm.pooling_params import PoolingParams
|
||||
|
||||
prompt_token_ids = [
|
||||
@@ -391,7 +391,7 @@ def _construct_pooling_request(req_id_suffix: int):
|
||||
req_id=f"pool_req_{req_id_suffix}",
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
sampling_params=None,
|
||||
pooling_params=PoolingParams(task="classify"),
|
||||
pooling_params=pooling_params or PoolingParams(task="classify"),
|
||||
mm_features=[],
|
||||
block_ids=([],),
|
||||
generator=None,
|
||||
@@ -440,3 +440,48 @@ def test_pooling_prompt_lens_not_aliased(device: str):
|
||||
"mutations to num_prompt_tokens_cpu_tensor corrupted prompt_lens. "
|
||||
f"Expected {prompt_lens_snapshot}, got {metadata.prompt_lens}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("pooling_params", "expect_device_prompt_token_ids", "expect_cpu_prompt_token_ids"),
|
||||
[
|
||||
({"task": "classify"}, False, False),
|
||||
({"task": "classify", "requires_token_ids": True}, True, True),
|
||||
],
|
||||
)
|
||||
def test_pooling_metadata_token_id_buffers(
|
||||
pooling_params: dict[str, object],
|
||||
expect_device_prompt_token_ids: bool,
|
||||
expect_cpu_prompt_token_ids: bool,
|
||||
):
|
||||
from vllm.pooling_params import PoolingParams
|
||||
|
||||
input_batch = InputBatch(
|
||||
max_num_reqs=1,
|
||||
max_model_len=MAX_PROMPT_SIZE + NUM_OUTPUT_TOKENS,
|
||||
max_num_batched_tokens=MAX_PROMPT_SIZE + NUM_OUTPUT_TOKENS,
|
||||
device=torch.device("cpu"),
|
||||
pin_memory=False,
|
||||
vocab_size=VOCAB_SIZE,
|
||||
block_sizes=[16],
|
||||
kernel_block_sizes=[16],
|
||||
is_pooling_model=True,
|
||||
)
|
||||
req = _construct_pooling_request(0, PoolingParams(**pooling_params))
|
||||
input_batch.add_request(req)
|
||||
input_batch.refresh_metadata()
|
||||
|
||||
metadata = input_batch.get_pooling_metadata()
|
||||
if expect_device_prompt_token_ids:
|
||||
assert input_batch.sampling_metadata.prompt_token_ids is not None
|
||||
assert metadata.prompt_token_ids is not None
|
||||
assert metadata.get_prompt_token_ids()[0].tolist() == req.prompt_token_ids
|
||||
else:
|
||||
assert input_batch.sampling_metadata.prompt_token_ids is None
|
||||
assert metadata.prompt_token_ids is None
|
||||
|
||||
if expect_cpu_prompt_token_ids:
|
||||
assert metadata.prompt_token_ids_cpu is not None
|
||||
assert metadata.get_prompt_token_ids_cpu()[0].tolist() == req.prompt_token_ids
|
||||
else:
|
||||
assert metadata.prompt_token_ids_cpu is None
|
||||
|
||||
@@ -18,7 +18,7 @@ ActivationFn = Callable[[_T], _T]
|
||||
@dataclass(frozen=True)
|
||||
class PoolingParamsUpdate:
|
||||
requires_token_ids: bool = False
|
||||
"""Set this flag to enable `get_prompt_token_ids` for your pooler."""
|
||||
"""Set this flag to enable prompt token IDs for your pooler."""
|
||||
|
||||
def __or__(self, other: "PoolingParamsUpdate") -> "PoolingParamsUpdate":
|
||||
return PoolingParamsUpdate(
|
||||
|
||||
@@ -146,17 +146,19 @@ class BOSEOSFilter(Pooler):
|
||||
) -> PoolerOutput:
|
||||
pooled_outputs = self.pooler(hidden_states, pooling_metadata)
|
||||
assert isinstance(pooled_outputs, list)
|
||||
prompt_token_ids = pooling_metadata.get_prompt_token_ids_cpu()
|
||||
|
||||
for i, prompt_len in enumerate(pooling_metadata.prompt_lens):
|
||||
for i, (prompt_len, token_ids) in enumerate(
|
||||
zip(pooling_metadata.prompt_lens, prompt_token_ids)
|
||||
):
|
||||
pooled_data = pooled_outputs[i]
|
||||
assert (
|
||||
isinstance(pooled_data, torch.Tensor)
|
||||
and pooled_data.shape[0] == prompt_len
|
||||
)
|
||||
token_ids = pooling_metadata.prompt_token_ids[i, :prompt_len]
|
||||
if token_ids[0] == self.bos_token_id:
|
||||
if int(token_ids[0]) == self.bos_token_id:
|
||||
pooled_data = pooled_data[1:]
|
||||
if token_ids[-1] == self.eos_token_id:
|
||||
if int(token_ids[-1]) == self.eos_token_id:
|
||||
pooled_data = pooled_data[:-1]
|
||||
pooled_outputs[i] = pooled_data.squeeze(-1)
|
||||
|
||||
|
||||
@@ -638,25 +638,26 @@ class SPLADESparsePooler(Pooler):
|
||||
lens: list[int] = lens_tensor.tolist()
|
||||
B: int = len(lens)
|
||||
|
||||
token_ids = pooling_metadata.prompt_token_ids
|
||||
prompt_token_ids = pooling_metadata.get_prompt_token_ids_cpu()
|
||||
offset = 0
|
||||
pooled_list: list[torch.Tensor] = []
|
||||
|
||||
for i in range(B):
|
||||
L = int(lens[i])
|
||||
hs = hidden_states[offset : offset + L]
|
||||
token_ids = prompt_token_ids[i]
|
||||
|
||||
start_idx = 0
|
||||
end_idx = L
|
||||
if self.remove_cls_sep and token_ids is not None:
|
||||
if self.remove_cls_sep:
|
||||
if (
|
||||
self.cls_token_id is not None
|
||||
and token_ids[i, 0].item() == self.cls_token_id
|
||||
and int(token_ids[0]) == self.cls_token_id
|
||||
):
|
||||
start_idx = 1
|
||||
if (
|
||||
self.sep_token_id is not None
|
||||
and token_ids[i, L - 1].item() == self.sep_token_id
|
||||
and int(token_ids[L - 1]) == self.sep_token_id
|
||||
):
|
||||
end_idx = max(start_idx, L - 1)
|
||||
|
||||
|
||||
@@ -156,10 +156,11 @@ class GritLMMeanPool(SequencePoolingMethod):
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> SequencePoolingMethodOutput:
|
||||
prompt_lens = pooling_metadata.prompt_lens
|
||||
prompt_token_ids = pooling_metadata.get_prompt_token_ids_cpu()
|
||||
instr_lens = torch.tensor(
|
||||
[
|
||||
self._get_instruction_len(token_ids.cpu().numpy())
|
||||
for token_ids in pooling_metadata.get_prompt_token_ids()
|
||||
self._get_instruction_len(token_ids.numpy())
|
||||
for token_ids in prompt_token_ids
|
||||
],
|
||||
device="cpu",
|
||||
)
|
||||
|
||||
@@ -50,7 +50,8 @@ class PoolingMetadata:
|
||||
"""Tensors for pooling."""
|
||||
|
||||
prompt_lens: torch.Tensor # CPU Tensor
|
||||
prompt_token_ids: torch.Tensor | None
|
||||
prompt_token_ids: torch.Tensor | None # Model-device tensor
|
||||
prompt_token_ids_cpu: torch.Tensor | None # CPU tensor
|
||||
pooling_params: list[PoolingParams]
|
||||
pooling_states: list[PoolingStates]
|
||||
pooling_cursor: PoolingCursor | None = None
|
||||
@@ -73,6 +74,9 @@ class PoolingMetadata:
|
||||
prompt_token_ids=None
|
||||
if self.prompt_token_ids is None
|
||||
else self.prompt_token_ids[indices],
|
||||
prompt_token_ids_cpu=None
|
||||
if self.prompt_token_ids_cpu is None
|
||||
else self.prompt_token_ids_cpu[indices],
|
||||
pooling_params=self.pooling_params[indices],
|
||||
pooling_states=self.pooling_states[indices],
|
||||
pooling_cursor=None
|
||||
@@ -85,7 +89,13 @@ class PoolingMetadata:
|
||||
assert prompt_token_ids is not None, (
|
||||
"Please set `requires_token_ids=True` in `get_pooling_updates`"
|
||||
)
|
||||
return [prompt_token_ids[i, :num] for i, num in enumerate(self.prompt_lens)]
|
||||
|
||||
def get_prompt_token_ids_cpu(self) -> list[torch.Tensor]:
|
||||
prompt_token_ids = self.prompt_token_ids_cpu
|
||||
assert prompt_token_ids is not None, (
|
||||
"Please set `requires_token_ids=True` in `get_pooling_updates`"
|
||||
)
|
||||
return [prompt_token_ids[i, :num] for i, num in enumerate(self.prompt_lens)]
|
||||
|
||||
def get_pooling_cursor(self) -> PoolingCursor:
|
||||
|
||||
@@ -833,8 +833,13 @@ class InputBatch:
|
||||
# step pooling during the sampling/pooling process.
|
||||
# Hence copy these tensors only when there are requests which
|
||||
# need penalties/step_pooler to be applied.
|
||||
prompt_token_ids_cpu = (
|
||||
self._make_prompt_token_ids_cpu_tensor() if needs_prompt_token_ids else None
|
||||
)
|
||||
prompt_token_ids = (
|
||||
self._make_prompt_token_ids_tensor() if needs_prompt_token_ids else None
|
||||
prompt_token_ids_cpu.to(device=self.device, non_blocking=True)
|
||||
if prompt_token_ids_cpu is not None
|
||||
else None
|
||||
)
|
||||
|
||||
# Only set output_token_ids if required by the current requests'
|
||||
@@ -891,15 +896,19 @@ class InputBatch:
|
||||
def get_pooling_metadata(self) -> PoolingMetadata:
|
||||
pooling_params = self.get_pooling_params()
|
||||
pooling_states = self.get_pooling_states()
|
||||
prompt_token_ids_cpu = None
|
||||
if any(p.requires_token_ids for p in pooling_params):
|
||||
prompt_token_ids_cpu = self._make_prompt_token_ids_cpu_tensor()
|
||||
|
||||
return PoolingMetadata(
|
||||
prompt_lens=self.num_prompt_tokens_cpu_tensor[: self.num_reqs].clone(),
|
||||
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
|
||||
prompt_token_ids_cpu=prompt_token_ids_cpu,
|
||||
pooling_params=pooling_params,
|
||||
pooling_states=pooling_states,
|
||||
)
|
||||
|
||||
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
|
||||
def _make_prompt_token_ids_cpu_tensor(self) -> torch.Tensor:
|
||||
num_reqs = self.num_reqs
|
||||
max_prompt_len = self.num_prompt_tokens[:num_reqs].max()
|
||||
prompt_token_ids_cpu_tensor = torch.empty(
|
||||
@@ -914,7 +923,7 @@ class InputBatch:
|
||||
# token_id of this value.
|
||||
for i in range(num_reqs):
|
||||
prompt_token_ids[i, self.num_prompt_tokens[i] :] = self.vocab_size
|
||||
return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True)
|
||||
return prompt_token_ids_cpu_tensor
|
||||
|
||||
def make_lora_inputs(
|
||||
self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray
|
||||
|
||||
@@ -5653,6 +5653,7 @@ class GPUModelRunner(
|
||||
dummy_metadata = PoolingMetadata(
|
||||
prompt_lens=dummy_prompt_lens,
|
||||
prompt_token_ids=dummy_token_ids,
|
||||
prompt_token_ids_cpu=dummy_token_ids.cpu(),
|
||||
pooling_params=[dummy_pooling_params] * num_reqs,
|
||||
pooling_states=[PoolingStates() for i in range(num_reqs)],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user