[New Model]: Support GteNewModelForSequenceClassification (#23524)

Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
wang.yuqi
2025-08-28 15:36:42 +08:00
committed by GitHub
parent 186aced5ff
commit 11a7fafaa8
13 changed files with 157 additions and 76 deletions

View File

@@ -3,7 +3,8 @@
import warnings
from collections.abc import Sequence
from typing import Any, NamedTuple, Optional, Union
from dataclasses import dataclass
from typing import Any, Optional, Union
import torch
import torch.nn.functional as F
@@ -339,36 +340,43 @@ def softmax(data):
return F.softmax(data, dim=-1)
class EmbedModelInfo(NamedTuple):
@dataclass
class ModelInfo:
name: str
is_matryoshka: bool = False
matryoshka_dimensions: Optional[list[int]] = None
architecture: str = ""
dtype: str = "auto"
hf_overrides: Optional[dict[str, Any]] = None
default_pooling_type: str = ""
enable_test: bool = True
@dataclass
class EmbedModelInfo(ModelInfo):
is_matryoshka: bool = False
matryoshka_dimensions: Optional[list[int]] = None
@dataclass
class CLSPoolingEmbedModelInfo(EmbedModelInfo):
default_pooling_type: str = "CLS"
@dataclass
class LASTPoolingEmbedModelInfo(EmbedModelInfo):
default_pooling_type: str = "LAST"
class RerankModelInfo(NamedTuple):
name: str
architecture: str = ""
dtype: str = "auto"
default_pooling_type: str = ""
enable_test: bool = True
@dataclass
class RerankModelInfo(ModelInfo):
pass
@dataclass
class CLSPoolingRerankModelInfo(RerankModelInfo):
default_pooling_type: str = "CLS"
@dataclass
class LASTPoolingRerankModelInfo(RerankModelInfo):
default_pooling_type: str = "LAST"