[Model] Support Qwen2 embeddings and use tags to select model tests (#10184)
This commit is contained in:
@@ -11,7 +11,8 @@ import tempfile
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from functools import lru_cache
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
from typing import (AbstractSet, Callable, Dict, List, Optional, Tuple, Type,
|
||||
TypeVar, Union)
|
||||
|
||||
import cloudpickle
|
||||
import torch.nn as nn
|
||||
@@ -110,6 +111,8 @@ _EMBEDDING_MODELS = {
|
||||
},
|
||||
"MistralModel": ("llama", "LlamaEmbeddingModel"),
|
||||
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
|
||||
"Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
|
||||
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
|
||||
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
|
||||
"Qwen2ForSequenceClassification": ("qwen2_cls", "Qwen2ForSequenceClassification"), # noqa: E501
|
||||
# [Multimodal]
|
||||
@@ -301,8 +304,8 @@ class _ModelRegistry:
|
||||
# Keyed by model_arch
|
||||
models: Dict[str, _BaseRegisteredModel] = field(default_factory=dict)
|
||||
|
||||
def get_supported_archs(self) -> List[str]:
|
||||
return list(self.models.keys())
|
||||
def get_supported_archs(self) -> AbstractSet[str]:
|
||||
return self.models.keys()
|
||||
|
||||
def register_model(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user