Support embedding models in V1 (#16188)
Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Signed-off-by: Max de Bayser <maxdebayser@gmail.com> Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com> Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
4959915089
commit
799397ee4f
@@ -5,6 +5,7 @@ import enum
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import is_list_of
|
||||
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
|
||||
@@ -25,7 +26,8 @@ class Request:
|
||||
multi_modal_inputs: Optional[list[MultiModalKwargs]],
|
||||
multi_modal_hashes: Optional[list[str]],
|
||||
multi_modal_placeholders: Optional[list[PlaceholderRange]],
|
||||
sampling_params: SamplingParams,
|
||||
sampling_params: Optional[SamplingParams],
|
||||
pooling_params: Optional[PoolingParams],
|
||||
eos_token_id: Optional[int],
|
||||
client_index: int = 0,
|
||||
lora_request: Optional["LoRARequest"] = None,
|
||||
@@ -35,18 +37,35 @@ class Request:
|
||||
self.request_id = request_id
|
||||
self.client_index = client_index
|
||||
self.sampling_params = sampling_params
|
||||
self.pooling_params = pooling_params
|
||||
# Because of LoRA, the eos token id can be different for each request.
|
||||
self.eos_token_id = eos_token_id
|
||||
self.lora_request = lora_request
|
||||
self.structured_output_request = structured_output_request
|
||||
|
||||
self.status = (RequestStatus.WAITING_FOR_FSM
|
||||
if sampling_params.guided_decoding is not None else
|
||||
RequestStatus.WAITING)
|
||||
self.status = RequestStatus.WAITING
|
||||
if sampling_params and sampling_params.guided_decoding is not None:
|
||||
self.status = RequestStatus.WAITING_FOR_FSM
|
||||
self.events: list[EngineCoreEvent] = []
|
||||
self.stop_reason: Union[int, str, None] = None
|
||||
assert sampling_params.max_tokens is not None
|
||||
self.max_tokens = sampling_params.max_tokens
|
||||
|
||||
# P/D: Connector-specific KV transfer parameters.
|
||||
self.kv_transfer_params: Optional[dict[str, Any]] = None
|
||||
|
||||
if pooling_params is not None:
|
||||
self.max_tokens = 1
|
||||
elif sampling_params is not None:
|
||||
assert sampling_params.max_tokens is not None
|
||||
self.max_tokens = sampling_params.max_tokens
|
||||
if sampling_params.guided_decoding is not None:
|
||||
self.status = RequestStatus.WAITING_FOR_FSM
|
||||
|
||||
if sampling_params.extra_args is not None:
|
||||
self.kv_transfer_params = \
|
||||
sampling_params.extra_args.get("kv_transfer_params")
|
||||
else:
|
||||
raise ValueError(
|
||||
"sampling_params and pooling_params can't both be unset")
|
||||
|
||||
self.prompt_token_ids = prompt_token_ids
|
||||
self.num_prompt_tokens = len(self.prompt_token_ids)
|
||||
@@ -63,11 +82,6 @@ class Request:
|
||||
self.num_encoder_inputs = len(self.mm_inputs)
|
||||
self.has_encoder_inputs = self.num_encoder_inputs > 0
|
||||
|
||||
# P/D: Connector-specific KV transfer parameters.
|
||||
kv_params = (None if sampling_params.extra_args is None else
|
||||
sampling_params.extra_args.get("kv_transfer_params"))
|
||||
self.kv_transfer_params: Optional[dict[str, Any]] = kv_params
|
||||
|
||||
# Sanity check
|
||||
assert len(self.mm_inputs) == len(self.mm_positions)
|
||||
if self.mm_hashes:
|
||||
@@ -98,10 +112,12 @@ class Request:
|
||||
multi_modal_hashes=request.mm_hashes,
|
||||
multi_modal_placeholders=request.mm_placeholders,
|
||||
sampling_params=request.sampling_params,
|
||||
pooling_params=request.pooling_params,
|
||||
eos_token_id=request.eos_token_id,
|
||||
lora_request=request.lora_request,
|
||||
structured_output_request=StructuredOutputRequest(
|
||||
sampling_params=request.sampling_params),
|
||||
sampling_params=request.sampling_params) \
|
||||
if request.sampling_params else None,
|
||||
cache_salt=request.cache_salt,
|
||||
)
|
||||
|
||||
@@ -141,7 +157,8 @@ class Request:
|
||||
|
||||
@property
|
||||
def use_structured_output(self) -> bool:
|
||||
return self.sampling_params.guided_decoding is not None
|
||||
return self.sampling_params is not None and \
|
||||
self.sampling_params.guided_decoding is not None
|
||||
|
||||
def record_event(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user