[CORE] Adding support for insertion of soft-tuned prompts (#4645)

Co-authored-by: Swapnil Parekh <swapnilp@ibm.com>
Co-authored-by: Joe G <joseph.granados@h2o.ai>
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
Author: Swapnil Parekh
Authored: 2024-07-09 16:26:36 -04:00
Committed by: GitHub
Parent: a0550cbc80
Commit: 4d6ada947c
48 changed files with 1952 additions and 519 deletions


@@ -16,12 +16,19 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               ModelPermission, TokenizeRequest)
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
+from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import Logprob
 from vllm.transformers_utils.tokenizer import get_tokenizer

 logger = init_logger(__name__)

+@dataclass
+class PromptAdapterPath:
+    name: str
+    local_path: str
+
 @dataclass
 class LoRAModulePath:
     name: str
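The new PromptAdapterPath dataclass mirrors the existing LoRAModulePath pattern: a display name plus the local directory that holds a trained soft prompt. A minimal sketch of declaring one (the adapter name and directory are made-up placeholders, and the import assumes this file is vllm/entrypoints/openai/serving_engine.py):

from vllm.entrypoints.openai.serving_engine import PromptAdapterPath

# Hypothetical soft-prompt checkpoint produced by PEFT prompt tuning.
prompt_adapters = [
    PromptAdapterPath(name="tweet-summary",
                      local_path="adapters/tweet_summary"),
]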
@@ -30,9 +37,14 @@ class LoRAModulePath:
 class OpenAIServing:

-    def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
-                 served_model_names: List[str],
-                 lora_modules: Optional[List[LoRAModulePath]]):
+    def __init__(
+        self,
+        engine: AsyncLLMEngine,
+        model_config: ModelConfig,
+        served_model_names: List[str],
+        lora_modules: Optional[List[LoRAModulePath]],
+        prompt_adapters: Optional[List[PromptAdapterPath]] = None,
+    ):
         super().__init__()

         self.engine = engine
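Because prompt_adapters defaults to None, existing call sites are unaffected; new callers pass the adapter list alongside lora_modules. A rough usage sketch (engine and model config construction elided; the served model name and adapter entry are illustrative placeholders):

serving = OpenAIServing(
    engine=engine,                  # an AsyncLLMEngine built elsewhere
    model_config=model_config,      # the engine's ModelConfig
    served_model_names=["meta-llama/Llama-2-7b-hf"],
    lora_modules=None,
    prompt_adapters=[
        PromptAdapterPath(name="tweet-summary",
                          local_path="adapters/tweet_summary"),
    ],
)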
@@ -49,9 +61,8 @@ class OpenAIServing:
         self.served_model_names = served_model_names

-        if lora_modules is None:
-            self.lora_requests = []
-        else:
+        self.lora_requests = []
+        if lora_modules is not None:
             self.lora_requests = [
                 LoRARequest(
                     lora_name=lora.name,
@@ -60,6 +71,20 @@
                 ) for i, lora in enumerate(lora_modules, start=1)
             ]
+
+        self.prompt_adapter_requests = []
+        if prompt_adapters is not None:
+            for i, prompt_adapter in enumerate(prompt_adapters, start=1):
+                with open(f"./{prompt_adapter.local_path}"
+                          f"/adapter_config.json") as f:
+                    adapter_config = json.load(f)
+                    num_virtual_tokens = adapter_config["num_virtual_tokens"]
+                    self.prompt_adapter_requests.append(
+                        PromptAdapterRequest(
+                            prompt_adapter_name=prompt_adapter.name,
+                            prompt_adapter_id=i,
+                            prompt_adapter_local_path=prompt_adapter.local_path,
+                            prompt_adapter_num_virtual_tokens=num_virtual_tokens))

     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
         model_cards = [
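Each adapter directory is expected to contain an adapter_config.json (as written by PEFT prompt tuning), and only num_virtual_tokens is consumed from it. A quick sketch of what the loader above does for a single adapter, using a placeholder path:

import json

# Assumed layout: adapters/tweet_summary/adapter_config.json containing
# something like {"peft_type": "PROMPT_TUNING", "num_virtual_tokens": 8, ...}.
with open("./adapters/tweet_summary/adapter_config.json") as f:
    adapter_config = json.load(f)
num_virtual_tokens = adapter_config["num_virtual_tokens"]
print(num_virtual_tokens)  # length of the soft prompt prepended at inference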
@@ -75,7 +100,14 @@
                       permission=[ModelPermission()])
             for lora in self.lora_requests
         ]
+        prompt_adapter_cards = [
+            ModelCard(id=prompt_adapter.prompt_adapter_name,
+                      root=self.served_model_names[0],
+                      permission=[ModelPermission()])
+            for prompt_adapter in self.prompt_adapter_requests
+        ]
         model_cards.extend(lora_cards)
+        model_cards.extend(prompt_adapter_cards)
         return ModelList(data=model_cards)

     def create_error_response(
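With adapters registered, the models listing now returns one card per LoRA and per prompt adapter in addition to the base model, all sharing the base model as their root. A hedged client-side sketch of inspecting /v1/models (the URL and expected ids are placeholders):

import requests

resp = requests.get("http://localhost:8000/v1/models").json()
for card in resp["data"]:
    # e.g. "meta-llama/Llama-2-7b-hf", "sql-lora", "tweet-summary",
    # each with root set to the first served model name.
    print(card["id"], "->", card["root"])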
@@ -109,20 +141,29 @@
             return None
         if request.model in [lora.lora_name for lora in self.lora_requests]:
             return None
+        if request.model in [
+                prompt_adapter.prompt_adapter_name
+                for prompt_adapter in self.prompt_adapter_requests
+        ]:
+            return None
         return self.create_error_response(
             message=f"The model `{request.model}` does not exist.",
             err_type="NotFoundError",
             status_code=HTTPStatus.NOT_FOUND)

-    def _maybe_get_lora(
+    def _maybe_get_adapter(
         self, request: Union[CompletionRequest, ChatCompletionRequest,
                              EmbeddingRequest]
-    ) -> Optional[LoRARequest]:
+    ) -> Tuple[Optional[str], Optional[Union[LoRARequest,
+                                             PromptAdapterRequest]]]:
         if request.model in self.served_model_names:
-            return None
+            return None, None
         for lora in self.lora_requests:
             if request.model == lora.lora_name:
-                return lora
+                return 'LoRA', lora
+        for prompt_adapter in self.prompt_adapter_requests:
+            if request.model == prompt_adapter.prompt_adapter_name:
+                return 'PromptAdapter', prompt_adapter
         # if _check_model has been called earlier, this will be unreachable
         raise ValueError(f"The model `{request.model}` does not exist.")