[CORE] Adding support for insertion of soft-tuned prompts (#4645)

Co-authored-by: Swapnil Parekh <swapnilp@ibm.com>
Co-authored-by: Joe G <joseph.granados@h2o.ai>
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
commit 4d6ada947c (parent a0550cbc80)
Author: Swapnil Parekh
Date:   2024-07-09 16:26:36 -04:00
Committed by: GitHub
48 changed files with 1952 additions and 519 deletions
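
In broad strokes, this change teaches the OpenAI-compatible server to route requests to PEFT prompt-tuning ("soft prompt") adapters alongside LoRA modules. For the offline engine, the PR exposes a `PromptAdapterRequest`; the sketch below shows how it might be used. The engine flag and the request fields are taken from this PR's new prompt-adapter plumbing, while the checkpoint path, adapter name, and virtual-token count are illustrative:

```python
from vllm import LLM, SamplingParams
from vllm.prompt_adapter.request import PromptAdapterRequest

# Enable prompt-adapter support on the engine (flag added in this PR).
llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_prompt_adapter=True)

# A PEFT prompt-tuning checkpoint trained for the base model; the path
# and the number of virtual tokens here are illustrative.
adapter = PromptAdapterRequest(
    prompt_adapter_name="twitter-complaints",
    prompt_adapter_id=1,
    prompt_adapter_local_path="/adapters/twitter_complaints",
    prompt_adapter_num_virtual_tokens=8,
)

outputs = llm.generate("Tweet text: I hate my internet provider. Label:",
                       SamplingParams(max_tokens=8),
                       prompt_adapter_request=adapter)
print(outputs[0].outputs[0].text)
```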


@@ -22,7 +22,8 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
                                               TokenizeResponse, UsageInfo)
 # yapf: enable
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
-                                                    OpenAIServing)
+                                                    OpenAIServing,
+                                                    PromptAdapterPath)
 from vllm.logger import init_logger
 from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)
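
The newly imported `PromptAdapterPath` is defined elsewhere in this PR, in `serving_engine.py`. A minimal sketch of what such a container looks like, with field names assumed by analogy to `LoRAModulePath`:

```python
from dataclasses import dataclass


@dataclass
class PromptAdapterPath:
    """Maps a served adapter name to a prompt-tuning checkpoint on disk.

    A sketch of the container imported above; the authoritative definition
    lives in vllm/entrypoints/openai/serving_engine.py in this PR.
    """
    name: str        # model name clients pass in the OpenAI request
    local_path: str  # directory holding the soft-prompt weights/config
```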
@@ -67,11 +68,13 @@ class OpenAIServingCompletion(OpenAIServing):
     def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
                  served_model_names: List[str],
-                 lora_modules: Optional[List[LoRAModulePath]]):
+                 lora_modules: Optional[List[LoRAModulePath]],
+                 prompt_adapters: Optional[List[PromptAdapterPath]]):
         super().__init__(engine=engine,
                          model_config=model_config,
                          served_model_names=served_model_names,
-                         lora_modules=lora_modules)
+                         lora_modules=lora_modules,
+                         prompt_adapters=prompt_adapters)
 
     async def create_completion(self, request: CompletionRequest,
                                 raw_request: Request):
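
With the widened constructor, the API server can register prompt adapters at start-up the same way it registers LoRA modules. A sketch of that wiring, under the assumption that `--prompt-adapters` is parsed into `PromptAdapterPath` entries elsewhere in this PR; the model and adapter names are illustrative:

```python
from vllm.entrypoints.openai.serving_engine import PromptAdapterPath

openai_serving_completion = OpenAIServingCompletion(
    engine,
    model_config,
    served_model_names=["meta-llama/Llama-2-7b-hf"],
    lora_modules=None,
    prompt_adapters=[
        # Served name -> on-disk PEFT prompt-tuning checkpoint.
        PromptAdapterPath(name="twitter-complaints",
                          local_path="/adapters/twitter_complaints"),
    ],
)
```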
@@ -101,7 +104,12 @@ class OpenAIServingCompletion(OpenAIServing):
         generators: List[AsyncIterator[RequestOutput]] = []
         try:
             sampling_params = request.to_sampling_params()
-            lora_request = self._maybe_get_lora(request)
+            adapter_type, adapter_request = self._maybe_get_adapter(request)
+            lora_request, prompt_adapter_request = None, None
+            if adapter_type == 'LoRA':
+                lora_request, prompt_adapter_request = adapter_request, None
+            elif adapter_type == 'PromptAdapter':
+                lora_request, prompt_adapter_request = None, adapter_request
             decoding_config = await self.engine.get_decoding_config()
             guided_decoding_backend = request.guided_decoding_backend \
                 or decoding_config.guided_decoding_backend
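
The dispatch above relies on `_maybe_get_adapter`, which this PR adds to the shared `OpenAIServing` base class in place of `_maybe_get_lora`. A minimal sketch of the matching logic it performs; the attribute names `lora_requests` and `prompt_adapter_requests` are assumptions about the base class's registries:

```python
def _maybe_get_adapter(self, request):
    """Resolve request.model to a registered LoRA or prompt adapter.

    A sketch of the lookup added in this PR; the real method matches the
    requested model name against adapters registered at server start-up.
    """
    model_name = request.model
    for lora in self.lora_requests:
        if model_name == lora.lora_name:
            return 'LoRA', lora
    for prompt_adapter in self.prompt_adapter_requests:
        if model_name == prompt_adapter.prompt_adapter_name:
            return 'PromptAdapter', prompt_adapter
    # Base model, or an unknown name (validated elsewhere).
    return None, None
```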
@@ -147,6 +155,7 @@ class OpenAIServingCompletion(OpenAIServing):
                 sampling_params,
                 f"{request_id}-{i}",
                 lora_request=lora_request,
+                prompt_adapter_request=prompt_adapter_request,
                 trace_headers=trace_headers,
             )
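
Once the server is started with a prompt adapter registered (see the constructor sketch above), a client selects it by passing the adapter's served name as the model in a standard OpenAI completion request; `twitter-complaints` is the illustrative name used earlier:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
completion = client.completions.create(
    model="twitter-complaints",  # routed to the prompt adapter, not the base model
    prompt="@nationalgridus I have no water and the bill is current.",
    max_tokens=16,
)
print(completion.choices[0].text)
```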