[Core] Support Lora lineage and base model metadata management (#6315)

Author: Jiaxin Shan
Date: 2024-09-19 23:20:56 -07:00
Committed by: GitHub
Parent: 9e5ec35b1f
Commit: 260d40b5ea
15 changed files with 337 additions and 45 deletions
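For orientation, here is a minimal sketch of the two metadata shapes this change works with and the base-model fallback that OpenAIServing applies when building LoRARequest objects. The standalone resolve_base_model helper below is illustrative only and is not part of the commit.

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class BaseModelPath:
    name: str        # name the model is served under
    model_path: str  # local or Hub path backing that name


@dataclass
class LoRAModulePath:
    name: str
    path: str
    base_model_name: Optional[str] = None  # optional lineage pointer


def resolve_base_model(lora: LoRAModulePath,
                       base_models: List[BaseModelPath]) -> str:
    # Mirrors the fallback in OpenAIServing.__init__: keep the declared base
    # model only if it is actually served, otherwise use the first base model.
    if lora.base_model_name and any(
            base.name == lora.base_model_name for base in base_models):
        return lora.base_model_name
    return base_models[0].name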


@@ -39,6 +39,12 @@ from vllm.utils import AtomicCounter
 logger = init_logger(__name__)
 
 
+@dataclass
+class BaseModelPath:
+    name: str
+    model_path: str
+
+
 @dataclass
 class PromptAdapterPath:
     name: str
@@ -49,6 +55,7 @@ class PromptAdapterPath:
 class LoRAModulePath:
     name: str
     path: str
+    base_model_name: Optional[str] = None
 
 
 AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
@@ -66,7 +73,7 @@ class OpenAIServing:
         self,
         engine_client: EngineClient,
         model_config: ModelConfig,
-        served_model_names: List[str],
+        base_model_paths: List[BaseModelPath],
         *,
         lora_modules: Optional[List[LoRAModulePath]],
         prompt_adapters: Optional[List[PromptAdapterPath]],
@@ -79,17 +86,20 @@ class OpenAIServing:
         self.model_config = model_config
         self.max_model_len = model_config.max_model_len
-        self.served_model_names = served_model_names
+        self.base_model_paths = base_model_paths
 
         self.lora_id_counter = AtomicCounter(0)
         self.lora_requests = []
         if lora_modules is not None:
             self.lora_requests = [
-                LoRARequest(
-                    lora_name=lora.name,
-                    lora_int_id=i,
-                    lora_path=lora.path,
-                ) for i, lora in enumerate(lora_modules, start=1)
+                LoRARequest(lora_name=lora.name,
+                            lora_int_id=i,
+                            lora_path=lora.path,
+                            base_model_name=lora.base_model_name
+                            if lora.base_model_name
+                            and self._is_model_supported(lora.base_model_name)
+                            else self.base_model_paths[0].name)
+                for i, lora in enumerate(lora_modules, start=1)
             ]
 
         self.prompt_adapter_requests = []
@@ -112,21 +122,23 @@
     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
         model_cards = [
-            ModelCard(id=served_model_name,
+            ModelCard(id=base_model.name,
                       max_model_len=self.max_model_len,
-                      root=self.served_model_names[0],
+                      root=base_model.model_path,
                       permission=[ModelPermission()])
-            for served_model_name in self.served_model_names
+            for base_model in self.base_model_paths
         ]
         lora_cards = [
             ModelCard(id=lora.lora_name,
-                      root=self.served_model_names[0],
+                      root=lora.local_path,
+                      parent=lora.base_model_name if lora.base_model_name else
+                      self.base_model_paths[0].name,
                       permission=[ModelPermission()])
             for lora in self.lora_requests
         ]
         prompt_adapter_cards = [
             ModelCard(id=prompt_adapter.prompt_adapter_name,
-                      root=self.served_model_names[0],
+                      root=self.base_model_paths[0].name,
                       permission=[ModelPermission()])
             for prompt_adapter in self.prompt_adapter_requests
         ]
@@ -169,7 +181,7 @@
         self,
         request: AnyRequest,
     ) -> Optional[ErrorResponse]:
-        if request.model in self.served_model_names:
+        if self._is_model_supported(request.model):
             return None
         if request.model in [lora.lora_name for lora in self.lora_requests]:
             return None
@@ -187,7 +199,7 @@
         self, request: AnyRequest
     ) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[
             None, PromptAdapterRequest]]:
-        if request.model in self.served_model_names:
+        if self._is_model_supported(request.model):
             return None, None
         for lora in self.lora_requests:
             if request.model == lora.lora_name:
@@ -480,3 +492,6 @@
             if lora_request.lora_name != lora_name
         ]
         return f"Success: LoRA adapter '{lora_name}' removed successfully."
+
+    def _is_model_supported(self, model_name):
+        return any(model.name == model_name for model in self.base_model_paths)
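The net effect is visible in the /v1/models listing: base models report their model path as root, and each LoRA adapter card carries its local path as root plus a parent naming its base model. A quick way to inspect this, assuming an OpenAI-compatible vLLM server is already running locally (the URL below is a placeholder):

import requests

# Fetch the model list from a running server and print the lineage fields.
resp = requests.get("http://localhost:8000/v1/models")
resp.raise_for_status()

for card in resp.json()["data"]:
    print(card["id"], "| root:", card.get("root"), "| parent:", card.get("parent"))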