diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index c4a55c837..d4eee19ad 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -380,7 +380,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int, swap_list: lis _compare_objs(input_batch, ref_input_batch) -def _construct_pooling_request(req_id_suffix: int): +def _construct_pooling_request(req_id_suffix: int, pooling_params=None): from vllm.pooling_params import PoolingParams prompt_token_ids = [ @@ -391,7 +391,7 @@ def _construct_pooling_request(req_id_suffix: int): req_id=f"pool_req_{req_id_suffix}", prompt_token_ids=prompt_token_ids, sampling_params=None, - pooling_params=PoolingParams(task="classify"), + pooling_params=pooling_params or PoolingParams(task="classify"), mm_features=[], block_ids=([],), generator=None, @@ -440,3 +440,48 @@ def test_pooling_prompt_lens_not_aliased(device: str): "mutations to num_prompt_tokens_cpu_tensor corrupted prompt_lens. " f"Expected {prompt_lens_snapshot}, got {metadata.prompt_lens}" ) + + +@pytest.mark.parametrize( + ("pooling_params", "expect_device_prompt_token_ids", "expect_cpu_prompt_token_ids"), + [ + ({"task": "classify"}, False, False), + ({"task": "classify", "requires_token_ids": True}, True, True), + ], +) +def test_pooling_metadata_token_id_buffers( + pooling_params: dict[str, object], + expect_device_prompt_token_ids: bool, + expect_cpu_prompt_token_ids: bool, +): + from vllm.pooling_params import PoolingParams + + input_batch = InputBatch( + max_num_reqs=1, + max_model_len=MAX_PROMPT_SIZE + NUM_OUTPUT_TOKENS, + max_num_batched_tokens=MAX_PROMPT_SIZE + NUM_OUTPUT_TOKENS, + device=torch.device("cpu"), + pin_memory=False, + vocab_size=VOCAB_SIZE, + block_sizes=[16], + kernel_block_sizes=[16], + is_pooling_model=True, + ) + req = _construct_pooling_request(0, PoolingParams(**pooling_params)) + input_batch.add_request(req) + input_batch.refresh_metadata() + + metadata = input_batch.get_pooling_metadata() + if expect_device_prompt_token_ids: + assert input_batch.sampling_metadata.prompt_token_ids is not None + assert metadata.prompt_token_ids is not None + assert metadata.get_prompt_token_ids()[0].tolist() == req.prompt_token_ids + else: + assert input_batch.sampling_metadata.prompt_token_ids is None + assert metadata.prompt_token_ids is None + + if expect_cpu_prompt_token_ids: + assert metadata.prompt_token_ids_cpu is not None + assert metadata.get_prompt_token_ids_cpu()[0].tolist() == req.prompt_token_ids + else: + assert metadata.prompt_token_ids_cpu is None diff --git a/vllm/model_executor/layers/pooler/common.py b/vllm/model_executor/layers/pooler/common.py index d8aa78e70..55fc4b457 100644 --- a/vllm/model_executor/layers/pooler/common.py +++ b/vllm/model_executor/layers/pooler/common.py @@ -18,7 +18,7 @@ ActivationFn = Callable[[_T], _T] @dataclass(frozen=True) class PoolingParamsUpdate: requires_token_ids: bool = False - """Set this flag to enable `get_prompt_token_ids` for your pooler.""" + """Set this flag to enable prompt token IDs for your pooler.""" def __or__(self, other: "PoolingParamsUpdate") -> "PoolingParamsUpdate": return PoolingParamsUpdate( diff --git a/vllm/model_executor/layers/pooler/special.py b/vllm/model_executor/layers/pooler/special.py index 686072632..d06663b5b 100644 --- a/vllm/model_executor/layers/pooler/special.py +++ b/vllm/model_executor/layers/pooler/special.py @@ -146,17 +146,19 @@ class BOSEOSFilter(Pooler): ) -> PoolerOutput: pooled_outputs = self.pooler(hidden_states, pooling_metadata) assert isinstance(pooled_outputs, list) + prompt_token_ids = pooling_metadata.get_prompt_token_ids_cpu() - for i, prompt_len in enumerate(pooling_metadata.prompt_lens): + for i, (prompt_len, token_ids) in enumerate( + zip(pooling_metadata.prompt_lens, prompt_token_ids) + ): pooled_data = pooled_outputs[i] assert ( isinstance(pooled_data, torch.Tensor) and pooled_data.shape[0] == prompt_len ) - token_ids = pooling_metadata.prompt_token_ids[i, :prompt_len] - if token_ids[0] == self.bos_token_id: + if int(token_ids[0]) == self.bos_token_id: pooled_data = pooled_data[1:] - if token_ids[-1] == self.eos_token_id: + if int(token_ids[-1]) == self.eos_token_id: pooled_data = pooled_data[:-1] pooled_outputs[i] = pooled_data.squeeze(-1) diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 0cdf4f70e..01854b96d 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -638,25 +638,26 @@ class SPLADESparsePooler(Pooler): lens: list[int] = lens_tensor.tolist() B: int = len(lens) - token_ids = pooling_metadata.prompt_token_ids + prompt_token_ids = pooling_metadata.get_prompt_token_ids_cpu() offset = 0 pooled_list: list[torch.Tensor] = [] for i in range(B): L = int(lens[i]) hs = hidden_states[offset : offset + L] + token_ids = prompt_token_ids[i] start_idx = 0 end_idx = L - if self.remove_cls_sep and token_ids is not None: + if self.remove_cls_sep: if ( self.cls_token_id is not None - and token_ids[i, 0].item() == self.cls_token_id + and int(token_ids[0]) == self.cls_token_id ): start_idx = 1 if ( self.sep_token_id is not None - and token_ids[i, L - 1].item() == self.sep_token_id + and int(token_ids[L - 1]) == self.sep_token_id ): end_idx = max(start_idx, L - 1) diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index b5c6946b6..4fb9bc7b0 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -156,10 +156,11 @@ class GritLMMeanPool(SequencePoolingMethod): pooling_metadata: PoolingMetadata, ) -> SequencePoolingMethodOutput: prompt_lens = pooling_metadata.prompt_lens + prompt_token_ids = pooling_metadata.get_prompt_token_ids_cpu() instr_lens = torch.tensor( [ - self._get_instruction_len(token_ids.cpu().numpy()) - for token_ids in pooling_metadata.get_prompt_token_ids() + self._get_instruction_len(token_ids.numpy()) + for token_ids in prompt_token_ids ], device="cpu", ) diff --git a/vllm/v1/pool/metadata.py b/vllm/v1/pool/metadata.py index c9fafe142..076c87526 100644 --- a/vllm/v1/pool/metadata.py +++ b/vllm/v1/pool/metadata.py @@ -50,7 +50,8 @@ class PoolingMetadata: """Tensors for pooling.""" prompt_lens: torch.Tensor # CPU Tensor - prompt_token_ids: torch.Tensor | None + prompt_token_ids: torch.Tensor | None # Model-device tensor + prompt_token_ids_cpu: torch.Tensor | None # CPU tensor pooling_params: list[PoolingParams] pooling_states: list[PoolingStates] pooling_cursor: PoolingCursor | None = None @@ -73,6 +74,9 @@ class PoolingMetadata: prompt_token_ids=None if self.prompt_token_ids is None else self.prompt_token_ids[indices], + prompt_token_ids_cpu=None + if self.prompt_token_ids_cpu is None + else self.prompt_token_ids_cpu[indices], pooling_params=self.pooling_params[indices], pooling_states=self.pooling_states[indices], pooling_cursor=None @@ -85,7 +89,13 @@ class PoolingMetadata: assert prompt_token_ids is not None, ( "Please set `requires_token_ids=True` in `get_pooling_updates`" ) + return [prompt_token_ids[i, :num] for i, num in enumerate(self.prompt_lens)] + def get_prompt_token_ids_cpu(self) -> list[torch.Tensor]: + prompt_token_ids = self.prompt_token_ids_cpu + assert prompt_token_ids is not None, ( + "Please set `requires_token_ids=True` in `get_pooling_updates`" + ) return [prompt_token_ids[i, :num] for i, num in enumerate(self.prompt_lens)] def get_pooling_cursor(self) -> PoolingCursor: diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 11d57f1d7..b9cd10544 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -833,8 +833,13 @@ class InputBatch: # step pooling during the sampling/pooling process. # Hence copy these tensors only when there are requests which # need penalties/step_pooler to be applied. + prompt_token_ids_cpu = ( + self._make_prompt_token_ids_cpu_tensor() if needs_prompt_token_ids else None + ) prompt_token_ids = ( - self._make_prompt_token_ids_tensor() if needs_prompt_token_ids else None + prompt_token_ids_cpu.to(device=self.device, non_blocking=True) + if prompt_token_ids_cpu is not None + else None ) # Only set output_token_ids if required by the current requests' @@ -891,15 +896,19 @@ class InputBatch: def get_pooling_metadata(self) -> PoolingMetadata: pooling_params = self.get_pooling_params() pooling_states = self.get_pooling_states() + prompt_token_ids_cpu = None + if any(p.requires_token_ids for p in pooling_params): + prompt_token_ids_cpu = self._make_prompt_token_ids_cpu_tensor() return PoolingMetadata( prompt_lens=self.num_prompt_tokens_cpu_tensor[: self.num_reqs].clone(), prompt_token_ids=self.sampling_metadata.prompt_token_ids, + prompt_token_ids_cpu=prompt_token_ids_cpu, pooling_params=pooling_params, pooling_states=pooling_states, ) - def _make_prompt_token_ids_tensor(self) -> torch.Tensor: + def _make_prompt_token_ids_cpu_tensor(self) -> torch.Tensor: num_reqs = self.num_reqs max_prompt_len = self.num_prompt_tokens[:num_reqs].max() prompt_token_ids_cpu_tensor = torch.empty( @@ -914,7 +923,7 @@ class InputBatch: # token_id of this value. for i in range(num_reqs): prompt_token_ids[i, self.num_prompt_tokens[i] :] = self.vocab_size - return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True) + return prompt_token_ids_cpu_tensor def make_lora_inputs( self, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d44bf74c3..42d784a47 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -5653,6 +5653,7 @@ class GPUModelRunner( dummy_metadata = PoolingMetadata( prompt_lens=dummy_prompt_lens, prompt_token_ids=dummy_token_ids, + prompt_token_ids_cpu=dummy_token_ids.cpu(), pooling_params=[dummy_pooling_params] * num_reqs, pooling_states=[PoolingStates() for i in range(num_reqs)], )