diff --git a/vllm/v1/pool/metadata.py b/vllm/v1/pool/metadata.py index 0764d5e6f..cb386decc 100644 --- a/vllm/v1/pool/metadata.py +++ b/vllm/v1/pool/metadata.py @@ -101,6 +101,7 @@ class PoolingMetadata: num_scheduled_tokens_np: np.ndarray, seq_lens_cpu: torch.Tensor, device: torch.device, + query_start_loc_gpu: torch.Tensor | None = None, ): n_seq = len(num_scheduled_tokens_np) prompt_lens = self.prompt_lens @@ -109,11 +110,25 @@ class PoolingMetadata: index = list(range(n_seq)) num_scheduled_tokens_cpu = torch.from_numpy(num_scheduled_tokens_np) - cumsum = torch.zeros( - n_seq + 1, dtype=torch.int64, pin_memory=pin_memory, device="cpu" - ) - torch.cumsum(num_scheduled_tokens_cpu, dim=0, out=cumsum[1:]) - cumsum = cumsum.to(device, non_blocking=True) + if query_start_loc_gpu is None: + cumsum = torch.zeros( + n_seq + 1, dtype=torch.int64, pin_memory=pin_memory, device="cpu" + ) + torch.cumsum(num_scheduled_tokens_cpu, dim=0, out=cumsum[1:]) + cumsum = cumsum.to(device, non_blocking=True) + else: + if query_start_loc_gpu.shape[0] != n_seq + 1: + raise ValueError( + "query_start_loc_gpu length does not match " + f"the number of sequences: {query_start_loc_gpu.shape[0]} " + f"!= {n_seq + 1}." + ) + if query_start_loc_gpu.device != device: + raise ValueError( + "query_start_loc_gpu must be on the same device as the " + f"hidden states: {query_start_loc_gpu.device} != {device}." + ) + cumsum = query_start_loc_gpu self.pooling_cursor = PoolingCursor( index=index, first_token_indices_gpu=cumsum[:n_seq], diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index af5dca71f..595e8cc39 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2928,7 +2928,10 @@ class GPUModelRunner( pooling_metadata = self.input_batch.get_pooling_metadata() pooling_metadata.build_pooling_cursor( - num_scheduled_tokens_np, seq_lens_cpu, device=hidden_states.device + num_scheduled_tokens_np, + seq_lens_cpu, + device=hidden_states.device, + query_start_loc_gpu=self.query_start_loc.gpu[: num_reqs + 1], ) model = cast(VllmModelForPooling, self.model)