[Model] Explicit interface for vLLM models and support OOT embedding models (#9108)

This commit is contained in:
Cyrus Leung
2024-10-07 14:10:35 +08:00
committed by GitHub
parent 18b296fdb2
commit 8c6de96ea1
10 changed files with 342 additions and 37 deletions

View File

@@ -9,6 +9,12 @@ def register():
ModelRegistry.register_model("MyOPTForCausalLM", MyOPTForCausalLM)
# Test passing lazy model
if "MyGemma2Embedding" not in ModelRegistry.get_supported_archs():
ModelRegistry.register_model(
"MyGemma2Embedding",
"vllm_add_dummy_model.my_gemma_embedding:MyGemma2Embedding",
)
if "MyLlava" not in ModelRegistry.get_supported_archs():
ModelRegistry.register_model("MyLlava",
"vllm_add_dummy_model.my_llava:MyLlava")

View File

@@ -0,0 +1,34 @@
from typing import List, Optional, Union
import torch
from vllm.attention import AttentionMetadata
from vllm.model_executor.models.gemma2_embedding import Gemma2EmbeddingModel
from vllm.sequence import IntermediateTensors
class MyGemma2Embedding(Gemma2EmbeddingModel):
    """Gemma2 embedding variant that returns all-zero embeddings.

    Delegates the full forward pass to ``Gemma2EmbeddingModel`` and then
    replaces the resulting hidden states with zeros of the same shape and
    dtype. Intermediate tensors from pipeline-parallel stages are passed
    through unchanged. (Presumably a dummy model used to verify that an
    out-of-tree registration is actually picked up — confirm with callers.)
    """

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the parent forward pass, then zero out the hidden states."""
        base_out = super().forward(
            input_ids,
            positions,
            kv_caches,
            attn_metadata,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        # IntermediateTensors means this rank is mid-pipeline: forward it
        # untouched; otherwise substitute all-zero embeddings.
        return (base_out if isinstance(base_out, IntermediateTensors) else
                torch.zeros_like(base_out))