[New Model]: Support GteNewModelForSequenceClassification (#23524)
Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
@@ -3,7 +3,8 @@
|
||||
|
||||
import warnings
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, NamedTuple, Optional, Union
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -339,36 +340,43 @@ def softmax(data):
|
||||
return F.softmax(data, dim=-1)
|
||||
|
||||
|
||||
class EmbedModelInfo(NamedTuple):
|
||||
@dataclass
|
||||
class ModelInfo:
|
||||
name: str
|
||||
is_matryoshka: bool = False
|
||||
matryoshka_dimensions: Optional[list[int]] = None
|
||||
architecture: str = ""
|
||||
dtype: str = "auto"
|
||||
hf_overrides: Optional[dict[str, Any]] = None
|
||||
default_pooling_type: str = ""
|
||||
enable_test: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbedModelInfo(ModelInfo):
|
||||
is_matryoshka: bool = False
|
||||
matryoshka_dimensions: Optional[list[int]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class CLSPoolingEmbedModelInfo(EmbedModelInfo):
|
||||
default_pooling_type: str = "CLS"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LASTPoolingEmbedModelInfo(EmbedModelInfo):
|
||||
default_pooling_type: str = "LAST"
|
||||
|
||||
|
||||
class RerankModelInfo(NamedTuple):
|
||||
name: str
|
||||
architecture: str = ""
|
||||
dtype: str = "auto"
|
||||
default_pooling_type: str = ""
|
||||
enable_test: bool = True
|
||||
@dataclass
|
||||
class RerankModelInfo(ModelInfo):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class CLSPoolingRerankModelInfo(RerankModelInfo):
|
||||
default_pooling_type: str = "CLS"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LASTPoolingRerankModelInfo(RerankModelInfo):
|
||||
default_pooling_type: str = "LAST"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user