[CORE] Adding support for insertion of soft-tuned prompts (#4645)
Co-authored-by: Swapnil Parekh <swapnilp@ibm.com>
Co-authored-by: Joe G <joseph.granados@h2o.ai>
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
@@ -10,6 +10,7 @@ import torch
 from vllm.lora.request import LoRARequest
 from vllm.pooling_params import PoolingParams
+from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams

 if TYPE_CHECKING:
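The only change in this hunk is the new import. The request object's shape is not shown here, but the properties added further down read `prompt_adapter_id` and `prompt_adapter_num_virtual_tokens` from it, so a minimal stand-in for local experimentation might look like the sketch below (anything beyond those two fields is an assumption, not part of this diff):

```python
# Minimal stand-in mirroring only the attributes this diff actually reads
# from PromptAdapterRequest; the real class lives in
# vllm/prompt_adapter/request.py and may carry more fields.
from dataclasses import dataclass


@dataclass
class FakePromptAdapterRequest:
    prompt_adapter_id: int                  # 0 is reserved for "no adapter"
    prompt_adapter_num_virtual_tokens: int  # length of the soft prompt in tokens


req = FakePromptAdapterRequest(prompt_adapter_id=1,
                               prompt_adapter_num_virtual_tokens=8)
```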
@@ -238,21 +239,25 @@ class Sequence:
         block_size: The block size of the sequence. Should be the same as the
             block size used by the block manager and cache engine.
         lora_request: LoRA request.
+        prompt_adapter_request: Prompt Adapter request.
     """

     def __init__(
-        self,
-        seq_id: int,
-        inputs: "LLMInputs",
-        block_size: int,
-        eos_token_id: Optional[int] = None,
-        lora_request: Optional[LoRARequest] = None,
+        self,
+        seq_id: int,
+        inputs: "LLMInputs",
+        block_size: int,
+        eos_token_id: Optional[int] = None,
+        lora_request: Optional[LoRARequest] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None
     ) -> None:
         self.seq_id = seq_id
         self.inputs = inputs
         self.block_size = block_size
         self.eos_token_id = eos_token_id
         self.lora_request = lora_request
+        self.prompt_adapter_request = prompt_adapter_request

         self.data = SequenceData(self.prompt_token_ids)
         self.output_logprobs: SampleLogprobs = []
@@ -287,6 +292,11 @@ class Sequence:
     def lora_int_id(self) -> int:
         return self.lora_request.lora_int_id if self.lora_request else 0

+    @property
+    def prompt_adapter_id(self) -> int:
+        return self.prompt_adapter_request.prompt_adapter_id \
+            if self.prompt_adapter_request else 0
+
     def get_output_text_to_return(self, buffer_length: int):
         # We return the full output text if the sequence is finished.
         truncate = buffer_length and not self.is_finished()
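The pattern introduced here recurs in every class this change touches: the adapter request is an optional constructor argument defaulting to `None`, mirroring `lora_request`, and a property collapses the missing case to id 0 so existing call sites keep working unchanged. A self-contained toy version of that pattern (not vLLM code) looks like this:

```python
# Toy illustration of the pattern above (not vLLM code): an optional adapter
# request threaded through the constructor, plus a property that falls back
# to 0 when no adapter is attached.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyAdapterRequest:
    prompt_adapter_id: int


class ToySequence:
    def __init__(self,
                 seq_id: int,
                 prompt_adapter_request: Optional[ToyAdapterRequest] = None):
        self.seq_id = seq_id
        self.prompt_adapter_request = prompt_adapter_request

    @property
    def prompt_adapter_id(self) -> int:
        return (self.prompt_adapter_request.prompt_adapter_id
                if self.prompt_adapter_request else 0)


assert ToySequence(0).prompt_adapter_id == 0                        # no adapter
assert ToySequence(0, ToyAdapterRequest(7)).prompt_adapter_id == 7  # adapter 7
```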
@@ -414,6 +424,7 @@ class SequenceGroup:
         encoder_seq: Optional, the single encoder sequence. Should be None
             unless you are working with an encoder/decoder model.
         trace_headers: OpenTelemetry trace headers.
+        prompt_adapter_request: Prompt Adapter request.
     """

     def __init__(
@@ -427,6 +438,7 @@ class SequenceGroup:
         pooling_params: Optional[PoolingParams] = None,
         encoder_seq: Optional[Sequence] = None,
         trace_headers: Optional[Dict[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> None:
         self.request_id = request_id
         self.seqs_dict = {seq.seq_id: seq for seq in seqs}
@@ -441,6 +453,7 @@ class SequenceGroup:
         self.state = SequenceGroupState()
         self.embeddings = embeddings
         self.pooling_params = pooling_params
+        self.prompt_adapter_request = prompt_adapter_request
         self.encoder_seq = encoder_seq
         self.trace_headers = trace_headers

@@ -466,6 +479,16 @@ class SequenceGroup:
     def lora_int_id(self) -> int:
         return self.lora_request.lora_int_id if self.lora_request else 0

+    @property
+    def prompt_adapter_id(self) -> int:
+        return self.prompt_adapter_request.prompt_adapter_id \
+            if self.prompt_adapter_request else 0
+
+    @property
+    def prompt_adapter_num_virtual_tokens(self) -> int:
+        return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\
+            if self.prompt_adapter_request else 0
+
     def get_last_latency(self, now: float) -> Optional[float]:
         """Sets the last token time for Request level timings."""
         # If still in prefill phase, raise Error.
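`SequenceGroup` additionally exposes `prompt_adapter_num_virtual_tokens`. Presumably this lets downstream scheduling and memory code account for the learned soft-prompt embeddings prepended to the real prompt; a hedged arithmetic sketch of that accounting follows, using a hypothetical helper name that is not part of this diff:

```python
# Hedged sketch: one likely consumer of prompt_adapter_num_virtual_tokens is
# code that sizes prompts, since a soft-tuned prompt prepends that many
# learned "virtual token" embeddings. effective_prompt_len is illustrative
# only; it is not a helper added by this diff.
def effective_prompt_len(num_prompt_tokens: int,
                         num_virtual_tokens: int) -> int:
    return num_prompt_tokens + num_virtual_tokens


assert effective_prompt_len(100, 0) == 100  # no adapter attached
assert effective_prompt_len(100, 8) == 108  # adapter with 8 virtual tokens
```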
@@ -624,6 +647,7 @@ class SequenceGroupMetadata:
             (SequenceGroup.encoder_seq). Should be None
             unless you are working with an encoder/decoder
             model.
+        prompt_adapter_request: Prompt Adapter request.
     """

     def __init__(
@@ -642,6 +666,7 @@ class SequenceGroupMetadata:
         multi_modal_data: Optional["MultiModalDataDict"] = None,
         encoder_seq_data: Optional[SequenceData] = None,
         cross_block_table: Optional[List[int]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> None:
         self.request_id = request_id
         self.is_prompt = is_prompt
@@ -650,6 +675,7 @@ class SequenceGroupMetadata:
         self.block_tables = block_tables
         self.pooling_params = pooling_params
         self.lora_request = lora_request
+        self.prompt_adapter_request = prompt_adapter_request
         self.computed_block_nums = computed_block_nums
         self.multi_modal_data = multi_modal_data
         self.state = SequenceGroupState() if state is None else state
@@ -674,6 +700,16 @@ class SequenceGroupMetadata:
     def lora_int_id(self) -> int:
         return self.lora_request.lora_int_id if self.lora_request else 0

+    @property
+    def prompt_adapter_id(self) -> int:
+        return self.prompt_adapter_request.prompt_adapter_id \
+            if self.prompt_adapter_request else 0
+
+    @property
+    def prompt_adapter_num_virtual_tokens(self) -> int:
+        return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \
+            if self.prompt_adapter_request else 0
+
     @property
     def token_chunk_size(self) -> int:
         """Return the number of tokens to be processed (chunk size)."""
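`SequenceGroupMetadata` gets the same two properties, so per-step batch construction can key off the adapter without reaching back into the full `SequenceGroup`. One plausible, purely illustrative use is bucketing scheduled groups by adapter id:

```python
# Illustrative sketch (not code from this PR): with prompt_adapter_id exposed
# on SequenceGroupMetadata, a model runner could bucket scheduled groups by
# adapter before building a batch; id 0 collects the "no adapter" groups,
# matching the property fallback above.
from collections import defaultdict
from typing import Dict, List


def bucket_by_adapter(metadata_list) -> Dict[int, List]:
    buckets: Dict[int, List] = defaultdict(list)
    for meta in metadata_list:
        buckets[meta.prompt_adapter_id].append(meta)
    return dict(buckets)
```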