[CORE] Adding support for insertion of soft-tuned prompts (#4645)
Co-authored-by: Swapnil Parekh <swapnilp@ibm.com>
Co-authored-by: Joe G <joseph.granados@h2o.ai>
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
vllm/entrypoints/openai/serving_completion.py

@@ -22,7 +22,8 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
                                               TokenizeResponse, UsageInfo)
 # yapf: enable
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
-                                                     OpenAIServing)
+                                                     OpenAIServing,
+                                                     PromptAdapterPath)
 from vllm.logger import init_logger
 from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)
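The import block now pulls in PromptAdapterPath alongside the existing LoRAModulePath. For orientation, here is a minimal sketch of what such a record plausibly looks like, by analogy with LoRAModulePath; the field names are assumptions for illustration, not a quote of serving_engine.py:

    from dataclasses import dataclass

    @dataclass
    class PromptAdapterPath:
        name: str        # assumed: the model name clients send to select this adapter
        local_path: str  # assumed: filesystem location of the trained soft-prompt weights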
@@ -67,11 +68,13 @@ class OpenAIServingCompletion(OpenAIServing):
 
     def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
                  served_model_names: List[str],
-                 lora_modules: Optional[List[LoRAModulePath]]):
+                 lora_modules: Optional[List[LoRAModulePath]],
+                 prompt_adapters: Optional[List[PromptAdapterPath]]):
         super().__init__(engine=engine,
                          model_config=model_config,
                          served_model_names=served_model_names,
-                         lora_modules=lora_modules)
+                         lora_modules=lora_modules,
+                         prompt_adapters=prompt_adapters)
 
     async def create_completion(self, request: CompletionRequest,
                                 raw_request: Request):
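With the widened constructor, prompt adapters are registered at server construction time the same way LoRA modules already are. A hedged wiring sketch, with the engine and model_config setup elided and the adapter name and path made up for illustration:

    # Hypothetical wiring, analogous to the existing lora_modules path.
    serving_completion = OpenAIServingCompletion(
        engine=engine,                      # an initialized AsyncLLMEngine
        model_config=model_config,          # the engine's ModelConfig
        served_model_names=["base-model"],
        lora_modules=None,
        prompt_adapters=[
            PromptAdapterPath(name="tweet-summary",
                              local_path="/adapters/tweet-summary"),
        ],
    )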
@@ -101,7 +104,12 @@ class OpenAIServingCompletion(OpenAIServing):
         generators: List[AsyncIterator[RequestOutput]] = []
         try:
             sampling_params = request.to_sampling_params()
-            lora_request = self._maybe_get_lora(request)
+            adapter_type, adapter_request = self._maybe_get_adapter(request)
+            lora_request, prompt_adapter_request = None, None
+            if adapter_type == 'LoRA':
+                lora_request, prompt_adapter_request = adapter_request, None
+            elif adapter_type == 'PromptAdapter':
+                lora_request, prompt_adapter_request = None, adapter_request
             decoding_config = await self.engine.get_decoding_config()
             guided_decoding_backend = request.guided_decoding_backend \
                 or decoding_config.guided_decoding_backend
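_maybe_get_adapter generalizes the old LoRA-only _maybe_get_lora lookup: it resolves the request's model name to either a LoRA or a prompt adapter and tags which kind it found. A minimal standalone sketch of that contract, assuming per-kind registries like the ones the base class keeps for LoRA; the parameter and attribute names here are illustrative, not the actual serving_engine.py code:

    def maybe_get_adapter(model_name, served_model_names,
                          lora_requests, prompt_adapter_requests):
        # Returns (adapter_type, adapter_request); both None for the base model.
        if model_name in served_model_names:
            return None, None
        for lora in lora_requests:
            if model_name == lora.lora_name:
                return 'LoRA', lora
        for adapter in prompt_adapter_requests:
            if model_name == adapter.prompt_adapter_name:
                return 'PromptAdapter', adapter
        raise ValueError(f"The model `{model_name}` does not exist.")

The string tags match the branches in the hunk above, so at most one of lora_request and prompt_adapter_request is non-None for any given request.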
@@ -147,6 +155,7 @@ class OpenAIServingCompletion(OpenAIServing):
                 sampling_params,
                 f"{request_id}-{i}",
                 lora_request=lora_request,
+                prompt_adapter_request=prompt_adapter_request,
                 trace_headers=trace_headers,
             )
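End to end, a client opts into a soft-tuned prompt simply by naming the registered adapter in an ordinary completions call; the server resolves the name and threads the resulting prompt_adapter_request into engine.generate() as shown above. A hedged client-side example, where the server URL and adapter name are placeholders:

    import requests

    resp = requests.post(
        "http://localhost:8000/v1/completions",
        json={
            "model": "tweet-summary",  # name registered as a prompt adapter
            "prompt": "Summarize: vLLM now supports soft-tuned prompts.",
            "max_tokens": 64,
        },
    )
    print(resp.json()["choices"][0]["text"])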