[CORE] Adding support for insertion of soft-tuned prompts (#4645)

Co-authored-by: Swapnil Parekh <swapnilp@ibm.com>
Co-authored-by: Joe G <joseph.granados@h2o.ai>
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
This commit is contained in:
Swapnil Parekh
2024-07-09 16:26:36 -04:00
committed by GitHub
parent a0550cbc80
commit 4d6ada947c
48 changed files with 1952 additions and 519 deletions

View File

@@ -10,6 +10,7 @@ import torch
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
if TYPE_CHECKING:
@@ -238,21 +239,25 @@ class Sequence:
block_size: The block size of the sequence. Should be the same as the
block size used by the block manager and cache engine.
lora_request: LoRA request.
prompt_adapter_request: Prompt Adapter request.
"""
def __init__(
self,
seq_id: int,
inputs: "LLMInputs",
block_size: int,
eos_token_id: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
self,
seq_id: int,
inputs: "LLMInputs",
block_size: int,
eos_token_id: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> None:
self.seq_id = seq_id
self.inputs = inputs
self.block_size = block_size
self.eos_token_id = eos_token_id
self.lora_request = lora_request
self.prompt_adapter_request = prompt_adapter_request
self.data = SequenceData(self.prompt_token_ids)
self.output_logprobs: SampleLogprobs = []
@@ -287,6 +292,11 @@ class Sequence:
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0
@property
def prompt_adapter_id(self) -> int:
    """Return the ID of the prompt adapter attached to this sequence.

    Returns:
        ``prompt_adapter_request.prompt_adapter_id`` when a (truthy)
        prompt adapter request is set, otherwise ``0``.
    """
    request = self.prompt_adapter_request
    # A missing (falsy) request maps to the sentinel ID 0.
    if not request:
        return 0
    return request.prompt_adapter_id
def get_output_text_to_return(self, buffer_length: int):
# We return the full output text if the sequence is finished.
truncate = buffer_length and not self.is_finished()
@@ -414,6 +424,7 @@ class SequenceGroup:
encoder_seq: Optional, the single encoder sequence. Should be None
unless you are working with an encoder/decoder model.
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request.
"""
def __init__(
@@ -427,6 +438,7 @@ class SequenceGroup:
pooling_params: Optional[PoolingParams] = None,
encoder_seq: Optional[Sequence] = None,
trace_headers: Optional[Dict[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
self.request_id = request_id
self.seqs_dict = {seq.seq_id: seq for seq in seqs}
@@ -441,6 +453,7 @@ class SequenceGroup:
self.state = SequenceGroupState()
self.embeddings = embeddings
self.pooling_params = pooling_params
self.prompt_adapter_request = prompt_adapter_request
self.encoder_seq = encoder_seq
self.trace_headers = trace_headers
@@ -466,6 +479,16 @@ class SequenceGroup:
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0
@property
def prompt_adapter_id(self) -> int:
    """Return the prompt adapter ID for this sequence group.

    Returns:
        The ``prompt_adapter_id`` of ``self.prompt_adapter_request`` if
        that request is set (truthy); ``0`` otherwise.
    """
    adapter_request = self.prompt_adapter_request
    if adapter_request:
        return adapter_request.prompt_adapter_id
    # No prompt adapter request attached: fall back to 0.
    return 0
@property
def prompt_adapter_num_virtual_tokens(self) -> int:
    """Return the number of virtual tokens of the group's prompt adapter.

    Returns:
        ``prompt_adapter_request.prompt_adapter_num_virtual_tokens`` when
        a (truthy) prompt adapter request is set, otherwise ``0``.
    """
    adapter_request = self.prompt_adapter_request
    # Without a prompt adapter request the value reported is 0.
    if not adapter_request:
        return 0
    return adapter_request.prompt_adapter_num_virtual_tokens
def get_last_latency(self, now: float) -> Optional[float]:
"""Sets the last token time for Request level timings."""
# If still in prefill phase, raise Error.
@@ -624,6 +647,7 @@ class SequenceGroupMetadata:
(SequenceGroup.encoder_seq). Should be None
unless you are working with an encoder/decoder
model.
prompt_adapter_request: Prompt Adapter request.
"""
def __init__(
@@ -642,6 +666,7 @@ class SequenceGroupMetadata:
multi_modal_data: Optional["MultiModalDataDict"] = None,
encoder_seq_data: Optional[SequenceData] = None,
cross_block_table: Optional[List[int]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
self.request_id = request_id
self.is_prompt = is_prompt
@@ -650,6 +675,7 @@ class SequenceGroupMetadata:
self.block_tables = block_tables
self.pooling_params = pooling_params
self.lora_request = lora_request
self.prompt_adapter_request = prompt_adapter_request
self.computed_block_nums = computed_block_nums
self.multi_modal_data = multi_modal_data
self.state = SequenceGroupState() if state is None else state
@@ -674,6 +700,16 @@ class SequenceGroupMetadata:
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0
@property
def prompt_adapter_id(self) -> int:
    """Return the prompt adapter ID recorded in this metadata object.

    Returns:
        The ``prompt_adapter_id`` of ``self.prompt_adapter_request`` when
        the request is set (truthy), else ``0``.
    """
    request = self.prompt_adapter_request
    if request:
        return request.prompt_adapter_id
    # No request attached, so report the fallback ID 0.
    return 0
@property
def prompt_adapter_num_virtual_tokens(self) -> int:
    """Return the prompt adapter's virtual-token count for this metadata.

    Returns:
        ``prompt_adapter_request.prompt_adapter_num_virtual_tokens`` if a
        (truthy) prompt adapter request is present; ``0`` otherwise.
    """
    request = self.prompt_adapter_request
    if not request:
        # Nothing attached: zero virtual tokens.
        return 0
    return request.prompt_adapter_num_virtual_tokens
@property
def token_chunk_size(self) -> int:
"""Return the number of tokens to be processed (chunk size)."""