[Model] Support Qwen2 embeddings and use tags to select model tests (#10184)

This commit is contained in:
Cyrus Leung
2024-11-15 12:23:09 +08:00
committed by GitHub
parent 2885ba0e24
commit b40cf6402e
19 changed files with 252 additions and 178 deletions

View File

@@ -11,7 +11,8 @@ import tempfile
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import lru_cache
from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
from typing import (AbstractSet, Callable, Dict, List, Optional, Tuple, Type,
TypeVar, Union)
import cloudpickle
import torch.nn as nn
@@ -110,6 +111,8 @@ _EMBEDDING_MODELS = {
},
"MistralModel": ("llama", "LlamaEmbeddingModel"),
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
"Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
"Qwen2ForSequenceClassification": ("qwen2_cls", "Qwen2ForSequenceClassification"), # noqa: E501
# [Multimodal]
@@ -301,8 +304,8 @@ class _ModelRegistry:
# Keyed by model_arch
models: Dict[str, _BaseRegisteredModel] = field(default_factory=dict)
def get_supported_archs(self) -> List[str]:
return list(self.models.keys())
def get_supported_archs(self) -> AbstractSet[str]:
return self.models.keys()
def register_model(
self,