Support embedding models in V1 (#16188)
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com>
parent 4959915089 · commit 799397ee4f
```diff
@@ -146,7 +146,8 @@ class KVCacheManager:
         # Prefix caching is disabled or
         # When the request requires prompt logprobs, we skip prefix caching.
         if (not self.enable_caching
-                or request.sampling_params.prompt_logprobs is not None):
+                or (request.sampling_params is not None
+                    and request.sampling_params.prompt_logprobs is not None)):
             return self.create_empty_block_list(), 0
 
         # The block hashes for the request may already be computed
```
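The guard above has to survive requests that carry no `SamplingParams` at all, which is exactly what embedding (pooling) requests look like in V1. A minimal sketch of the None-safe pattern, using stand-in classes rather than vLLM's own:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SamplingParams:
    prompt_logprobs: Optional[int] = None


@dataclass
class Request:
    # None for pooling (embedding) requests, which have no sampling params.
    sampling_params: Optional[SamplingParams] = None


def skip_prefix_caching(request: Request, enable_caching: bool) -> bool:
    # The pre-change form, request.sampling_params.prompt_logprobs, raises
    # AttributeError as soon as sampling_params can be None.
    return (not enable_caching
            or (request.sampling_params is not None
                and request.sampling_params.prompt_logprobs is not None))


assert skip_prefix_caching(Request(), enable_caching=False)   # caching off
assert not skip_prefix_caching(Request(), enable_caching=True)  # pooling req
assert skip_prefix_caching(Request(SamplingParams(prompt_logprobs=5)), True)
```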
```diff
@@ -14,6 +14,7 @@ if TYPE_CHECKING:
                                   KVConnectorMetadata)
     from vllm.lora.request import LoRARequest
     from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
+    from vllm.pooling_params import PoolingParams
     from vllm.sampling_params import SamplingParams
     from vllm.v1.request import Request
 
```
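These imports sit under `if TYPE_CHECKING:`, so they are visible to the type checker but never executed at runtime; `PoolingParams` joins the list because the dataclass below needs it only for annotations. A generic sketch of the pattern, with a hypothetical module name:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime, which
    # avoids import cost and potential circular-import chains.
    from expensive_module import HeavyType  # hypothetical module


def summarize(value: HeavyType) -> str:
    # Annotation is resolved lazily thanks to `from __future__ import annotations`.
    return f"<{type(value).__name__}>"
```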
```diff
@@ -26,7 +27,8 @@ class NewRequestData:
     mm_inputs: list[MultiModalKwargs]
     mm_hashes: list[str]
     mm_positions: list[PlaceholderRange]
-    sampling_params: SamplingParams
+    sampling_params: Optional[SamplingParams]
+    pooling_params: Optional[PoolingParams]
     block_ids: tuple[list[int], ...]
     num_computed_tokens: int
     lora_request: Optional[LoRARequest]
```
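`sampling_params` becomes Optional and a parallel Optional `pooling_params` field is added: a generation request carries the former, an embedding request the latter. A hedged sketch of that shape (the exactly-one-of-two assertion is illustrative, not taken from the diff):

```python
from dataclasses import dataclass
from typing import Optional


class SamplingParams: ...
class PoolingParams: ...


@dataclass
class NewRequestData:
    sampling_params: Optional[SamplingParams]
    pooling_params: Optional[PoolingParams]

    def __post_init__(self) -> None:
        # Illustrative invariant: a request is either a generation request
        # or a pooling request, never both and never neither.
        assert (self.sampling_params is None) != (self.pooling_params is None)


NewRequestData(sampling_params=None, pooling_params=PoolingParams())   # embedding
NewRequestData(sampling_params=SamplingParams(), pooling_params=None)  # generation
```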
```diff
@@ -44,6 +46,7 @@ class NewRequestData:
             mm_hashes=request.mm_hashes,
             mm_positions=request.mm_positions,
             sampling_params=request.sampling_params,
+            pooling_params=request.pooling_params,
             block_ids=block_ids,
             num_computed_tokens=request.num_computed_tokens,
             lora_request=request.lora_request,
```
```diff
@@ -402,6 +402,15 @@ class Scheduler(SchedulerInterface):
                         < num_new_tokens):
                     num_new_tokens = (
                         self.scheduler_config.long_prefill_token_threshold)
+
+                # chunked prefill has to be enabled explicitly to allow
+                # pooling requests to be chunked
+                if not self.scheduler_config.chunked_prefill_enabled and \
+                        num_new_tokens > token_budget:
+                    self.waiting.popleft()
+                    skipped_waiting_requests.appendleft(request)
+                    continue
+
                 num_new_tokens = min(num_new_tokens, token_budget)
                 assert num_new_tokens > 0
```
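Without chunked prefill, a pooling request whose prompt exceeds the remaining token budget cannot be split across scheduler steps, so the scheduler defers the whole request rather than truncating it. A self-contained sketch of that decision, with hypothetical names:

```python
from collections import deque


def schedule_step(waiting: deque, token_budget: int,
                  chunked_prefill_enabled: bool) -> list:
    scheduled, skipped = [], deque()
    while waiting and token_budget > 0:
        request_id, num_new_tokens = waiting[0]
        if not chunked_prefill_enabled and num_new_tokens > token_budget:
            # Cannot split the prompt: defer the whole request to a later step.
            waiting.popleft()
            skipped.appendleft(request_id)
            continue
        waiting.popleft()
        num_new_tokens = min(num_new_tokens, token_budget)
        token_budget -= num_new_tokens
        scheduled.append((request_id, num_new_tokens))
    waiting.extendleft(skipped)  # put deferred requests back at the front
    return scheduled


queue = deque([("embed-1", 600), ("gen-1", 100)])
print(schedule_step(queue, token_budget=512, chunked_prefill_enabled=False))
# -> [('gen-1', 100)]; 'embed-1' is deferred intact for a later step
```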
```diff
@@ -707,6 +716,7 @@ class Scheduler(SchedulerInterface):
         logprobs = model_runner_output.logprobs
         prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
         num_scheduled_tokens = scheduler_output.num_scheduled_tokens
+        pooler_outputs = model_runner_output.pooler_output
 
         new_running: list[Request] = []
         outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
```
```diff
@@ -724,7 +734,8 @@ class Scheduler(SchedulerInterface):
                 continue
 
             req_index = model_runner_output.req_id_to_index[req_id]
-            generated_token_ids = sampled_token_ids[req_index]
+            generated_token_ids = sampled_token_ids[
+                req_index] if sampled_token_ids else []
 
             scheduled_spec_token_ids = (
                 scheduler_output.scheduled_spec_decode_tokens.get(req_id))
```
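A step that ran only pooling requests samples no tokens, so `sampled_token_ids` can be empty and unconditional indexing would raise an IndexError. The guard reduces to a one-liner, shown here standalone:

```python
def tokens_for(req_index: int, sampled_token_ids: list[list[int]]) -> list[int]:
    # An empty outer list means the step ran only pooling requests:
    # no tokens were sampled for anyone.
    return sampled_token_ids[req_index] if sampled_token_ids else []


assert tokens_for(0, [[42, 7]]) == [42, 7]  # normal generation step
assert tokens_for(0, []) == []              # pooling-only step
```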
```diff
@@ -776,8 +787,17 @@ class Scheduler(SchedulerInterface):
                     del new_token_ids[num_new:]  # Trim new tokens if needed.
                     break
 
+            pooler_output = None
+            if pooler_outputs:
+                pooler_output = pooler_outputs[req_index]
+                stopped = check_stop(request, self.max_model_len,
+                                     pooler_output)
+                if stopped:
+                    kv_transfer_params = self._free_request(request)
+
             # Extract sample logprobs if needed.
-            if request.sampling_params.logprobs is not None and logprobs:
+            if request.sampling_params is not None \
+                and request.sampling_params.logprobs is not None and logprobs:
                 # NOTE: once we support N tokens per step (spec decode),
                 # the outer lists can be of length > 1.
                 new_logprobs = logprobs.slice(req_index, req_index + 1)
```
```diff
@@ -802,7 +822,8 @@ class Scheduler(SchedulerInterface):
 
             # Get prompt logprobs for this request.
             prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
-            if new_token_ids or kv_transfer_params:
+            if new_token_ids or pooler_output is not None \
+                or kv_transfer_params:
 
                 # Add EngineCoreOutput for this Request.
                 outputs[request.client_index].append(
```
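An embedding request finishes having generated zero tokens, so the old `if new_token_ids or kv_transfer_params:` test would silently drop its result; the pooler output now also triggers emission. A minimal truth-table sketch (function name hypothetical):

```python
from typing import Optional


def should_emit(new_token_ids: list[int],
                pooler_output: Optional[object],
                kv_transfer_params: Optional[dict]) -> bool:
    # Pooling requests produce no tokens; their result is the pooled tensor.
    return bool(new_token_ids) or pooler_output is not None \
        or bool(kv_transfer_params)


assert should_emit([101], None, None)    # generation step with a new token
assert should_emit([], object(), None)   # finished embedding request
assert not should_emit([], None, None)   # nothing to report this step
```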
```diff
@@ -812,6 +833,7 @@ class Scheduler(SchedulerInterface):
                         finish_reason=request.get_finished_reason(),
                         new_logprobs=new_logprobs,
                         new_prompt_logprobs_tensors=prompt_logprobs_tensors,
+                        pooling_output=pooler_output,
                         stop_reason=request.stop_reason,
                         events=request.take_events(),
                         kv_transfer_params=kv_transfer_params,
```
```diff
@@ -1,15 +1,28 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Optional
+
+import torch
+
 from vllm.v1.request import Request, RequestStatus
 
 
-def check_stop(request: Request, max_model_len: int) -> bool:
+def check_stop(request: Request,
+               max_model_len: int,
+               pooler_output: Optional[torch.Tensor] = None) -> bool:
     if (request.num_tokens >= max_model_len
             or request.num_output_tokens >= request.max_tokens):
         request.status = RequestStatus.FINISHED_LENGTH_CAPPED
         return True
 
+    if request.pooling_params:
+        if pooler_output is not None:
+            request.status = RequestStatus.FINISHED_STOPPED
+            return True
+        return False
+
     sampling_params = request.sampling_params
+    assert sampling_params is not None
     last_token_id = request.output_token_ids[-1]
     if (not sampling_params.ignore_eos
             and last_token_id == request.eos_token_id):
```
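Recast outside vLLM, the new stop logic reads: length cap first; then, for pooling requests, finish exactly when the pooler has produced an output; otherwise fall through to the usual sampling stop conditions. A self-contained sketch with a simplified status enum and the sampling branch elided:

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional


class RequestStatus(Enum):
    RUNNING = auto()
    FINISHED_LENGTH_CAPPED = auto()
    FINISHED_STOPPED = auto()


@dataclass
class Request:
    num_tokens: int = 0
    num_output_tokens: int = 0
    max_tokens: int = 16  # illustrative default
    pooling_params: Optional[dict] = None
    status: RequestStatus = RequestStatus.RUNNING


def check_stop(request: Request, max_model_len: int,
               pooler_output: Optional[object] = None) -> bool:
    if (request.num_tokens >= max_model_len
            or request.num_output_tokens >= request.max_tokens):
        request.status = RequestStatus.FINISHED_LENGTH_CAPPED
        return True
    if request.pooling_params:
        # A pooling request is done as soon as its embedding exists.
        if pooler_output is not None:
            request.status = RequestStatus.FINISHED_STOPPED
            return True
        return False
    return False  # sampling stop conditions (EOS, stop tokens) elided here


req = Request(pooling_params={"task": "embed"})
assert not check_stop(req, max_model_len=512)  # still prefilling, no output yet
assert check_stop(req, max_model_len=512, pooler_output=[0.1, 0.2])
assert req.status is RequestStatus.FINISHED_STOPPED
```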