Support embedding models in V1 (#16188)

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Author: Maximilien de Bayser
Date: 2025-06-19 01:36:33 -03:00
Committed by: GitHub
Parent: 4959915089
Commit: 799397ee4f

56 changed files with 889 additions and 281 deletions

View File

@@ -146,7 +146,8 @@ class KVCacheManager:
         # Prefix caching is disabled or
         # When the request requires prompt logprobs, we skip prefix caching.
         if (not self.enable_caching
-                or request.sampling_params.prompt_logprobs is not None):
+                or (request.sampling_params is not None
+                    and request.sampling_params.prompt_logprobs is not None)):
             return self.create_empty_block_list(), 0
 
         # The block hashes for the request may already be computed
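A minimal, self-contained sketch of why the extra None check is needed (the stub classes below are hypothetical, not vLLM types): embedding/pooling requests in V1 carry no sampling_params, so dereferencing .prompt_logprobs on them would raise AttributeError, while the new guard lets them keep using prefix caching.

from dataclasses import dataclass
from typing import Optional

@dataclass
class SamplingParamsStub:
    prompt_logprobs: Optional[int] = None

@dataclass
class RequestStub:
    sampling_params: Optional[SamplingParamsStub] = None

def skips_prefix_caching(request: RequestStub, enable_caching: bool) -> bool:
    # Mirrors the updated condition: skip prefix caching when caching is
    # disabled or when prompt logprobs are requested; pooling requests
    # (sampling_params is None) take the cached path.
    return (not enable_caching
            or (request.sampling_params is not None
                and request.sampling_params.prompt_logprobs is not None))

assert not skips_prefix_caching(RequestStub(sampling_params=None), enable_caching=True)
assert skips_prefix_caching(RequestStub(SamplingParamsStub(prompt_logprobs=5)), enable_caching=True)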

View File

@@ -14,6 +14,7 @@ if TYPE_CHECKING:
         KVConnectorMetadata)
     from vllm.lora.request import LoRARequest
     from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
+    from vllm.pooling_params import PoolingParams
     from vllm.sampling_params import SamplingParams
     from vllm.v1.request import Request
@@ -26,7 +27,8 @@ class NewRequestData:
     mm_inputs: list[MultiModalKwargs]
     mm_hashes: list[str]
     mm_positions: list[PlaceholderRange]
-    sampling_params: SamplingParams
+    sampling_params: Optional[SamplingParams]
+    pooling_params: Optional[PoolingParams]
     block_ids: tuple[list[int], ...]
     num_computed_tokens: int
     lora_request: Optional[LoRARequest]
@@ -44,6 +46,7 @@ class NewRequestData:
             mm_hashes=request.mm_hashes,
             mm_positions=request.mm_positions,
             sampling_params=request.sampling_params,
+            pooling_params=request.pooling_params,
             block_ids=block_ids,
             num_computed_tokens=request.num_computed_tokens,
             lora_request=request.lora_request,
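A simplified sketch (hypothetical field subset and stub types, not the real dataclass) of the shape this change introduces: a new-request record now carries either sampling_params (generation) or pooling_params (embedding). The "exactly one is set" assertion below is an assumption for illustration, not something the diff enforces.

from dataclasses import dataclass
from typing import Optional

@dataclass
class SamplingParamsStub:
    temperature: float = 1.0

@dataclass
class PoolingParamsStub:
    dimensions: Optional[int] = None

@dataclass
class NewRequestDataStub:
    req_id: str
    sampling_params: Optional[SamplingParamsStub]
    pooling_params: Optional[PoolingParamsStub]

    def __post_init__(self) -> None:
        # Assumption: exactly one of the two parameter objects is present.
        assert (self.sampling_params is None) != (self.pooling_params is None)

NewRequestDataStub("gen-0", SamplingParamsStub(), None)    # text-generation request
NewRequestDataStub("emb-0", None, PoolingParamsStub())     # embedding/pooling request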

View File

@@ -402,6 +402,15 @@ class Scheduler(SchedulerInterface):
                         < num_new_tokens):
                     num_new_tokens = (
                         self.scheduler_config.long_prefill_token_threshold)
 
+                # chunked prefill has to be enabled explicitly to allow
+                # pooling requests to be chunked
+                if not self.scheduler_config.chunked_prefill_enabled and \
+                    num_new_tokens > token_budget:
+                    self.waiting.popleft()
+                    skipped_waiting_requests.appendleft(request)
+                    continue
+
                 num_new_tokens = min(num_new_tokens, token_budget)
                 assert num_new_tokens > 0
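A standalone sketch (hypothetical helper, not scheduler code) of the rule added above: unless chunked prefill is explicitly enabled, a pooling request whose remaining prompt exceeds the step's token budget is skipped and retried later rather than scheduled partially.

def tokens_to_schedule(num_new_tokens: int, token_budget: int,
                       chunked_prefill_enabled: bool) -> int:
    """Return how many tokens to schedule this step; 0 means skip the request."""
    if not chunked_prefill_enabled and num_new_tokens > token_budget:
        # The prompt cannot be chunked, so defer the whole request
        # (the scheduler re-queues it via skipped_waiting_requests).
        return 0
    return min(num_new_tokens, token_budget)

assert tokens_to_schedule(4096, 2048, chunked_prefill_enabled=False) == 0
assert tokens_to_schedule(4096, 2048, chunked_prefill_enabled=True) == 2048
assert tokens_to_schedule(1024, 2048, chunked_prefill_enabled=False) == 1024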
@@ -707,6 +716,7 @@ class Scheduler(SchedulerInterface):
         logprobs = model_runner_output.logprobs
         prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
         num_scheduled_tokens = scheduler_output.num_scheduled_tokens
+        pooler_outputs = model_runner_output.pooler_output
 
         new_running: list[Request] = []
         outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
@@ -724,7 +734,8 @@ class Scheduler(SchedulerInterface):
                 continue
 
             req_index = model_runner_output.req_id_to_index[req_id]
-            generated_token_ids = sampled_token_ids[req_index]
+            generated_token_ids = sampled_token_ids[
+                req_index] if sampled_token_ids else []
 
             scheduled_spec_token_ids = (
                 scheduler_output.scheduled_spec_decode_tokens.get(req_id))
@@ -776,8 +787,17 @@ class Scheduler(SchedulerInterface):
                     del new_token_ids[num_new:]  # Trim new tokens if needed.
                     break
 
+            pooler_output = None
+            if pooler_outputs:
+                pooler_output = pooler_outputs[req_index]
+                stopped = check_stop(request, self.max_model_len,
+                                     pooler_output)
+                if stopped:
+                    kv_transfer_params = self._free_request(request)
+
             # Extract sample logprobs if needed.
-            if request.sampling_params.logprobs is not None and logprobs:
+            if request.sampling_params is not None \
+                and request.sampling_params.logprobs is not None and logprobs:
                 # NOTE: once we support N tokens per step (spec decode),
                 # the outer lists can be of length > 1.
                 new_logprobs = logprobs.slice(req_index, req_index + 1)
@@ -802,7 +822,8 @@ class Scheduler(SchedulerInterface):
             # Get prompt logprobs for this request.
             prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
-            if new_token_ids or kv_transfer_params:
+            if new_token_ids or pooler_output is not None \
+                or kv_transfer_params:
 
                 # Add EngineCoreOutput for this Request.
                 outputs[request.client_index].append(
@@ -812,6 +833,7 @@ class Scheduler(SchedulerInterface):
                         finish_reason=request.get_finished_reason(),
                         new_logprobs=new_logprobs,
                         new_prompt_logprobs_tensors=prompt_logprobs_tensors,
+                        pooling_output=pooler_output,
                         stop_reason=request.stop_reason,
                         events=request.take_events(),
                         kv_transfer_params=kv_transfer_params,
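A simplified sketch (stub types; the real EngineCoreOutput has more fields) of the widened emission rule: an output record is now produced when a request has new sampled tokens or a pooled tensor, so embedding requests reach the frontend even though they generate no tokens.

from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class EngineCoreOutputStub:
    req_id: str
    new_token_ids: list[int]
    pooling_output: Optional[torch.Tensor]

def maybe_make_output(req_id: str, new_token_ids: list[int],
                      pooler_output: Optional[torch.Tensor]
                      ) -> Optional[EngineCoreOutputStub]:
    # Mirrors the widened condition (kv_transfer_params omitted here):
    # emit an output if there are new tokens OR a pooling result.
    if new_token_ids or pooler_output is not None:
        return EngineCoreOutputStub(req_id, new_token_ids, pooler_output)
    return None

assert maybe_make_output("emb-0", [], torch.zeros(8)) is not None
assert maybe_make_output("gen-0", [42], None) is not None
assert maybe_make_output("idle-0", [], None) is None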

View File

@@ -1,15 +1,28 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Optional
+
+import torch
+
 from vllm.v1.request import Request, RequestStatus
 
 
-def check_stop(request: Request, max_model_len: int) -> bool:
+def check_stop(request: Request,
+               max_model_len: int,
+               pooler_output: Optional[torch.Tensor] = None) -> bool:
     if (request.num_tokens >= max_model_len
             or request.num_output_tokens >= request.max_tokens):
         request.status = RequestStatus.FINISHED_LENGTH_CAPPED
         return True
 
+    if request.pooling_params:
+        if pooler_output is not None:
+            request.status = RequestStatus.FINISHED_STOPPED
+            return True
+        return False
+
     sampling_params = request.sampling_params
+    assert sampling_params is not None
     last_token_id = request.output_token_ids[-1]
     if (not sampling_params.ignore_eos
             and last_token_id == request.eos_token_id):
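A usage-level sketch (stub request/status types, not the vLLM classes) of the behaviour the new branch gives pooling requests: the request stays running until its pooler output arrives and is then finished immediately, without ever reaching the EOS/stop-token checks.

from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional

import torch

class StatusStub(Enum):
    RUNNING = auto()
    FINISHED_STOPPED = auto()

@dataclass
class PoolingRequestStub:
    pooling_params: object
    status: StatusStub = StatusStub.RUNNING

def check_stop_pooling(request: PoolingRequestStub,
                       pooler_output: Optional[torch.Tensor]) -> bool:
    # Same shape as the pooling branch above: finish once the pooled tensor
    # exists, otherwise keep the request running.
    if request.pooling_params:
        if pooler_output is not None:
            request.status = StatusStub.FINISHED_STOPPED
            return True
        return False
    return False

req = PoolingRequestStub(pooling_params=object())
assert check_stop_pooling(req, None) is False            # still waiting for the pooler
assert check_stop_pooling(req, torch.zeros(4)) is True   # embedding ready -> finished
assert req.status is StatusStub.FINISHED_STOPPED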