[CORE] Adding support for insertion of soft-tuned prompts (#4645)
Co-authored-by: Swapnil Parekh <swapnilp@ibm.com>
Co-authored-by: Joe G <joseph.granados@h2o.ai>
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
@@ -18,6 +18,7 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.outputs import EmbeddingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
+from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.usage.usage_lib import UsageContext
@@ -264,6 +265,7 @@ class _AsyncLLMEngine(LLMEngine):
         request_id: str,
         inputs: PromptInputs,
         lora_request: Optional[LoRARequest] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> LLMInputs:
         if isinstance(inputs, str):
             inputs = {"prompt": inputs}
@@ -279,6 +281,12 @@ class _AsyncLLMEngine(LLMEngine):
         else:
             prompt_token_ids = inputs["prompt_token_ids"]

+        if prompt_adapter_request:
+            prompt_token_ids = [
+                0
+            ] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + \
+                prompt_token_ids
+
         llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
                                prompt=inputs.get("prompt"),
                                multi_modal_data=inputs.get("multi_modal_data"))
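The hunk above is the core of the change: when a prompt-adapter request is attached, the prompt token ids are left-padded with one placeholder id (0) per virtual token, reserving positions for the soft-tuned prompt embeddings. Below is a minimal self-contained sketch of that padding rule; the helper name and the stub request class are illustrative, not vLLM code.

# Minimal sketch of the padding rule from the hunk above; names are illustrative.
from dataclasses import dataclass
from typing import List


@dataclass
class StubPromptAdapterRequest:
    # Stand-in for PromptAdapterRequest; only the attribute used by the hunk.
    prompt_adapter_num_virtual_tokens: int


def pad_for_prompt_adapter(prompt_token_ids: List[int],
                           request: StubPromptAdapterRequest) -> List[int]:
    # One placeholder token id (0) per virtual token, prepended to the prompt.
    return [0] * request.prompt_adapter_num_virtual_tokens + prompt_token_ids


assert pad_for_prompt_adapter([101, 102], StubPromptAdapterRequest(4)) == \
    [0, 0, 0, 0, 101, 102]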
@@ -286,13 +294,14 @@ class _AsyncLLMEngine(LLMEngine):
         return self.input_processor(llm_inputs)

     async def add_request_async(
         self,
         request_id: str,
         inputs: PromptInputs,
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Dict[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None
     ) -> None:
         if lora_request is not None and not self.lora_config:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
@@ -301,7 +310,10 @@ class _AsyncLLMEngine(LLMEngine):
             arrival_time = time.time()

         processed_inputs = await self.process_model_inputs_async(
-            request_id=request_id, inputs=inputs, lora_request=lora_request)
+            request_id=request_id,
+            inputs=inputs,
+            lora_request=lora_request,
+            prompt_adapter_request=prompt_adapter_request)

         self._add_processed_request(
             request_id=request_id,
@@ -309,6 +321,7 @@ class _AsyncLLMEngine(LLMEngine):
             params=params,
             arrival_time=arrival_time,
             lora_request=lora_request,
+            prompt_adapter_request=prompt_adapter_request,
             trace_headers=trace_headers,
         )

@@ -627,6 +640,7 @@ class AsyncLLMEngine:
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Dict[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None
     ) -> AsyncStream:
         if self.log_requests:
             if isinstance(inputs, str):
@@ -669,7 +683,7 @@ class AsyncLLMEngine:
             arrival_time=arrival_time,
             lora_request=lora_request,
             trace_headers=trace_headers,
-        )
+            prompt_adapter_request=prompt_adapter_request)

         return stream

@@ -680,6 +694,7 @@ class AsyncLLMEngine:
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Dict[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None
     ) -> AsyncIterator[RequestOutput]:
         """Generate outputs for a request.

@@ -695,6 +710,8 @@ class AsyncLLMEngine:
             request_id: The unique id of the request.
             lora_request: LoRA request to use for generation, if any.
             trace_headers: OpenTelemetry trace headers.
+            prompt_adapter_request: Prompt Adapter request to use
+                                            for generation, if any.

         Yields:
             The output `RequestOutput` objects from the LLMEngine
@@ -749,6 +766,7 @@ class AsyncLLMEngine:
             sampling_params,
             lora_request=lora_request,
             trace_headers=trace_headers,
+            prompt_adapter_request=prompt_adapter_request,
         ):
             yield LLMEngine.validate_output(output, RequestOutput)

@@ -837,6 +855,7 @@ class AsyncLLMEngine:
         *,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Dict[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
         """Common logic to process requests with SamplingParams or
         PoolingParams."""
@@ -849,6 +868,7 @@ class AsyncLLMEngine:
             arrival_time=arrival_time,
             lora_request=lora_request,
             trace_headers=trace_headers,
+            prompt_adapter_request=prompt_adapter_request,
         )

         try:
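Taken together, the hunks above thread an optional prompt_adapter_request keyword from AsyncLLMEngine.generate() through add_request_async() down to input processing. A hedged usage sketch follows; the engine setup is omitted, and the PromptAdapterRequest constructor fields other than prompt_adapter_num_virtual_tokens are assumptions not confirmed by this diff.

# Hedged usage sketch, not part of the commit. Assumes an AsyncLLMEngine has
# already been built; fields marked "assumed" are not shown in this diff.
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams


async def generate_with_soft_prompt(engine) -> None:
    # engine: an initialized AsyncLLMEngine
    adapter = PromptAdapterRequest(
        prompt_adapter_name="my-soft-prompt",          # assumed field
        prompt_adapter_id=1,                           # assumed field
        prompt_adapter_local_path="/path/to/adapter",  # assumed field
        prompt_adapter_num_virtual_tokens=8,           # attribute used in the diff
    )
    async for output in engine.generate(
            "Summarize the quarterly report.",
            SamplingParams(max_tokens=64),
            request_id="req-0",
            prompt_adapter_request=adapter,            # new keyword added by this PR
    ):
        print(output)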