[CORE] Adding support for insertion of soft-tuned prompts (#4645)
Co-authored-by: Swapnil Parekh <swapnilp@ibm.com>
Co-authored-by: Joe G <joseph.granados@h2o.ai>
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
@@ -16,12 +16,19 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               ModelPermission, TokenizeRequest)
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
+from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import Logprob
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
 logger = init_logger(__name__)
 
 
+@dataclass
+class PromptAdapterPath:
+    name: str
+    local_path: str
+
+
 @dataclass
 class LoRAModulePath:
     name: str
@@ -30,9 +37,14 @@ class LoRAModulePath:
 
 class OpenAIServing:
 
-    def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
-                 served_model_names: List[str],
-                 lora_modules: Optional[List[LoRAModulePath]]):
+    def __init__(
+        self,
+        engine: AsyncLLMEngine,
+        model_config: ModelConfig,
+        served_model_names: List[str],
+        lora_modules: Optional[List[LoRAModulePath]],
+        prompt_adapters: Optional[List[PromptAdapterPath]] = None,
+    ):
         super().__init__()
 
         self.engine = engine
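For context, a rough sketch of how the widened constructor might be wired up from server startup code. The model name, adapter name, and paths below are hypothetical, and engine / model_config are assumed to come from the usual AsyncLLMEngine setup (not shown in this diff):

from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
                                                    PromptAdapterPath)

def build_serving(engine, model_config) -> OpenAIServing:
    # Hypothetical wiring: one soft-tuned prompt registered alongside the
    # base model, no LoRA modules.
    return OpenAIServing(
        engine,
        model_config,
        served_model_names=["meta-llama/Llama-2-7b-hf"],
        lora_modules=None,
        prompt_adapters=[
            PromptAdapterPath(name="my_tuned_prompt",
                              local_path="adapters/my_tuned_prompt"),
        ],
    )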
@@ -49,9 +61,8 @@ class OpenAIServing:
 
         self.served_model_names = served_model_names
 
-        if lora_modules is None:
-            self.lora_requests = []
-        else:
+        self.lora_requests = []
+        if lora_modules is not None:
             self.lora_requests = [
                 LoRARequest(
                     lora_name=lora.name,
@@ -60,6 +71,20 @@ class OpenAIServing:
                 ) for i, lora in enumerate(lora_modules, start=1)
             ]
 
+        self.prompt_adapter_requests = []
+        if prompt_adapters is not None:
+            for i, prompt_adapter in enumerate(prompt_adapters, start=1):
+                with open(f"./{prompt_adapter.local_path}"
+                          f"/adapter_config.json") as f:
+                    adapter_config = json.load(f)
+                num_virtual_tokens = adapter_config["num_virtual_tokens"]
+                self.prompt_adapter_requests.append(
+                    PromptAdapterRequest(
+                        prompt_adapter_name=prompt_adapter.name,
+                        prompt_adapter_id=i,
+                        prompt_adapter_local_path=prompt_adapter.local_path,
+                        prompt_adapter_num_virtual_tokens=num_virtual_tokens))
+
     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
         model_cards = [
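The loader above takes num_virtual_tokens straight out of each adapter's adapter_config.json. For reference, a minimal sketch of the kind of PEFT-style prompt-tuning config this expects; the file contents here are assumed, and only the num_virtual_tokens key is actually consumed:

import json

# Assumed shape of a PEFT-style prompt-tuning adapter_config.json; the
# server only reads num_virtual_tokens from it.
example_config = json.loads("""
{
    "peft_type": "PROMPT_TUNING",
    "task_type": "CAUSAL_LM",
    "num_virtual_tokens": 8
}
""")
assert example_config["num_virtual_tokens"] == 8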
@@ -75,7 +100,14 @@ class OpenAIServing:
                       permission=[ModelPermission()])
             for lora in self.lora_requests
         ]
+        prompt_adapter_cards = [
+            ModelCard(id=prompt_adapter.prompt_adapter_name,
+                      root=self.served_model_names[0],
+                      permission=[ModelPermission()])
+            for prompt_adapter in self.prompt_adapter_requests
+        ]
         model_cards.extend(lora_cards)
+        model_cards.extend(prompt_adapter_cards)
         return ModelList(data=model_cards)
 
     def create_error_response(
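With the extra cards in place, prompt adapters show up in /v1/models alongside LoRA modules and the base model. An illustrative entry follows; the field values are hypothetical and the authoritative ModelCard schema lives in protocol.py:

# Illustrative /v1/models entry for a registered prompt adapter; id comes
# from prompt_adapter_name and root from served_model_names[0].
example_card = {
    "id": "my_tuned_prompt",
    "object": "model",
    "root": "meta-llama/Llama-2-7b-hf",
    "permission": [],  # one ModelPermission entry in practice
}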
@@ -109,20 +141,29 @@ class OpenAIServing:
             return None
         if request.model in [lora.lora_name for lora in self.lora_requests]:
             return None
+        if request.model in [
+                prompt_adapter.prompt_adapter_name
+                for prompt_adapter in self.prompt_adapter_requests
+        ]:
+            return None
         return self.create_error_response(
             message=f"The model `{request.model}` does not exist.",
             err_type="NotFoundError",
             status_code=HTTPStatus.NOT_FOUND)
 
-    def _maybe_get_lora(
+    def _maybe_get_adapter(
         self, request: Union[CompletionRequest, ChatCompletionRequest,
                              EmbeddingRequest]
-    ) -> Optional[LoRARequest]:
+    ) -> Tuple[Optional[str], Optional[Union[LoRARequest,
+                                             PromptAdapterRequest]]]:
         if request.model in self.served_model_names:
-            return None
+            return None, None
         for lora in self.lora_requests:
             if request.model == lora.lora_name:
-                return lora
+                return 'LoRA', lora
+        for prompt_adapter in self.prompt_adapter_requests:
+            if request.model == prompt_adapter.prompt_adapter_name:
+                return 'PromptAdapter', prompt_adapter
         # if _check_model has been called earlier, this will be unreachable
         raise ValueError(f"The model `{request.model}` does not exist.")
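Callers that previously did lora_request = self._maybe_get_lora(request) now get back a (type, request) tuple and route it to the matching engine keyword. A hedged sketch of a consuming handler; the handler shape and the prompt_adapter_request keyword on engine.generate are assumed from the rest of this PR, not shown in this hunk:

# Sketch of a caller inside a request handler; `request`, `prompt`,
# `sampling_params` and `request_id` come from the surrounding handler.
adapter_type, adapter_request = self._maybe_get_adapter(request)
lora_request = adapter_request if adapter_type == 'LoRA' else None
prompt_adapter_request = (adapter_request
                          if adapter_type == 'PromptAdapter' else None)
results_generator = self.engine.generate(
    prompt,
    sampling_params,
    request_id,
    lora_request=lora_request,
    prompt_adapter_request=prompt_adapter_request)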