[Model] Add support for BERT-like Chinese ERNIE pooling models (#36385)
Signed-off-by: whyiug <whyiug@hotmail.com> Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
This commit is contained in:
@@ -514,6 +514,7 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) A
|
|||||||
| ------------ | ------ | ----------------- | -------------------- | ------------------------- |
|
| ------------ | ------ | ----------------- | -------------------- | ------------------------- |
|
||||||
| `BertModel`<sup>C</sup> | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | |
|
| `BertModel`<sup>C</sup> | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | |
|
||||||
| `BertSpladeSparseEmbeddingModel` | SPLADE | `naver/splade-v3` | | |
|
| `BertSpladeSparseEmbeddingModel` | SPLADE | `naver/splade-v3` | | |
|
||||||
|
| `ErnieModel` | BERT-like Chinese ERNIE | `shibing624/text2vec-base-chinese-sentence` | | |
|
||||||
| `Gemma2Model`<sup>C</sup> | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | ✅︎ |
|
| `Gemma2Model`<sup>C</sup> | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | ✅︎ |
|
||||||
| `Gemma3TextModel`<sup>C</sup> | Gemma 3-based | `google/embeddinggemma-300m`, etc. | ✅︎ | ✅︎ |
|
| `Gemma3TextModel`<sup>C</sup> | Gemma 3-based | `google/embeddinggemma-300m`, etc. | ✅︎ | ✅︎ |
|
||||||
| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ |
|
| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ |
|
||||||
@@ -556,8 +557,9 @@ These models primarily support the [`LLM.classify`](./pooling_models.md#llmclass
|
|||||||
|
|
||||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
|
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
|
||||||
| ------------ | ------ | ----------------- | -------------------- | ------------------------- |
|
| ------------ | ------ | ----------------- | -------------------- | ------------------------- |
|
||||||
| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ |
|
| `ErnieForSequenceClassification` | BERT-like Chinese ERNIE | `Forrest20231206/ernie-3.0-base-zh-cls` | | |
|
||||||
| `GPT2ForSequenceClassification` | GPT2 | `nie3e/sentiment-polish-gpt2-small` | | |
|
| `GPT2ForSequenceClassification` | GPT2 | `nie3e/sentiment-polish-gpt2-small` | | |
|
||||||
|
| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ |
|
||||||
| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* |
|
| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* |
|
||||||
|
|
||||||
<sup>C</sup> Automatically converted into a classification model via `--convert classify`. ([details](./pooling_models.md#model-conversion))
|
<sup>C</sup> Automatically converted into a classification model via `--convert classify`. ([details](./pooling_models.md#model-conversion))
|
||||||
@@ -574,6 +576,7 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
|
|||||||
| Architecture | Models | Example HF Models | Score template (see note) | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
|
| Architecture | Models | Example HF Models | Score template (see note) | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
|
||||||
| ------------ | ------ | ----------------- | ------------------------- | --------------------------- | --------------------------------------- |
|
| ------------ | ------ | ----------------- | ------------------------- | --------------------------- | --------------------------------------- |
|
||||||
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | N/A | | |
|
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | N/A | | |
|
||||||
|
| `ErnieForSequenceClassification` | BERT-like Chinese ERNIE | `Forrest20231206/ernie-3.0-base-zh-cls` | N/A | | |
|
||||||
| `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma`(see note), etc. | [bge-reranker-v2-gemma.jinja](../../examples/pooling/score/template/bge-reranker-v2-gemma.jinja) | ✅︎ | ✅︎ |
|
| `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma`(see note), etc. | [bge-reranker-v2-gemma.jinja](../../examples/pooling/score/template/bge-reranker-v2-gemma.jinja) | ✅︎ | ✅︎ |
|
||||||
| `GteNewForSequenceClassification` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-reranker-base`, etc. | N/A | | |
|
| `GteNewForSequenceClassification` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-reranker-base`, etc. | N/A | | |
|
||||||
| `LlamaBidirectionalForSequenceClassification`<sup>C</sup> | Llama-based with bidirectional attention | `nvidia/llama-nemotron-rerank-1b-v2`, etc. | [nemotron-rerank.jinja](../../examples/pooling/score/template/nemotron-rerank.jinja) | ✅︎ | ✅︎ |
|
| `LlamaBidirectionalForSequenceClassification`<sup>C</sup> | Llama-based with bidirectional attention | `nvidia/llama-nemotron-rerank-1b-v2`, etc. | [nemotron-rerank.jinja](../../examples/pooling/score/template/nemotron-rerank.jinja) | ✅︎ | ✅︎ |
|
||||||
@@ -639,6 +642,7 @@ These models primarily support the [`LLM.encode`](./pooling_models.md#llmencode)
|
|||||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
|
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
|
||||||
| ------------ | ------ | ----------------- | --------------------------- | --------------------------------------- |
|
| ------------ | ------ | ----------------- | --------------------------- | --------------------------------------- |
|
||||||
| `BertForTokenClassification` | bert-based | `boltuix/NeuroBERT-NER` (see note), etc. | | |
|
| `BertForTokenClassification` | bert-based | `boltuix/NeuroBERT-NER` (see note), etc. | | |
|
||||||
|
| `ErnieForTokenClassification` | BERT-like Chinese ERNIE | `gyr66/Ernie-3.0-base-chinese-finetuned-ner` | | |
|
||||||
| `ModernBertForTokenClassification` | ModernBERT-based | `disham993/electrical-ner-ModernBERT-base` | | |
|
| `ModernBertForTokenClassification` | ModernBERT-based | `disham993/electrical-ner-ModernBERT-base` | | |
|
||||||
|
|
||||||
!!! note
|
!!! note
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from vllm.platforms import current_platform
|
|||||||
pytest.mark.slow_test,
|
pytest.mark.slow_test,
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
|
pytest.param("Forrest20231206/ernie-3.0-base-zh-cls"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize("dtype", ["half"] if current_platform.is_rocm() else ["float"])
|
@pytest.mark.parametrize("dtype", ["half"] if current_platform.is_rocm() else ["float"])
|
||||||
@@ -47,5 +48,6 @@ def test_models(
|
|||||||
assert torch.allclose(
|
assert torch.allclose(
|
||||||
hf_output,
|
hf_output,
|
||||||
vllm_output,
|
vllm_output,
|
||||||
|
atol=1e-3 if dtype == "float" else 1e-2,
|
||||||
rtol=2e-3 if dtype == "float" else 1e-2,
|
rtol=2e-3 if dtype == "float" else 1e-2,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -25,11 +25,17 @@ def seed_everything():
|
|||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model", ["boltuix/NeuroBERT-NER"])
|
@pytest.mark.parametrize(
|
||||||
|
"model",
|
||||||
|
[
|
||||||
|
"boltuix/NeuroBERT-NER",
|
||||||
|
"gyr66/Ernie-3.0-base-chinese-finetuned-ner",
|
||||||
|
],
|
||||||
|
)
|
||||||
# The float32 is required for this tiny model to pass the test.
|
# The float32 is required for this tiny model to pass the test.
|
||||||
@pytest.mark.parametrize("dtype", ["float"])
|
@pytest.mark.parametrize("dtype", ["float"])
|
||||||
@torch.inference_mode
|
@torch.inference_mode
|
||||||
def test_bert_models(
|
def test_bert_like_models(
|
||||||
hf_runner,
|
hf_runner,
|
||||||
vllm_runner,
|
vllm_runner,
|
||||||
example_prompts,
|
example_prompts,
|
||||||
|
|||||||
45
tests/models/language/pooling_mteb_test/test_ernie.py
Normal file
45
tests/models/language/pooling_mteb_test/test_ernie.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tests.models.language.pooling.embed_utils import correctness_test_embed_models
|
||||||
|
from tests.models.utils import EmbedModelInfo
|
||||||
|
|
||||||
|
from .mteb_embed_utils import mteb_test_embed_models
|
||||||
|
|
||||||
|
MODELS = [
|
||||||
|
EmbedModelInfo(
|
||||||
|
"shibing624/text2vec-base-chinese-sentence",
|
||||||
|
architecture="ErnieModel",
|
||||||
|
mteb_score=0.536523112,
|
||||||
|
seq_pooling_type="MEAN",
|
||||||
|
attn_type="encoder_only",
|
||||||
|
is_prefix_caching_supported=False,
|
||||||
|
is_chunked_prefill_supported=False,
|
||||||
|
enable_test=True,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model_info", MODELS)
|
||||||
|
def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None:
|
||||||
|
mteb_test_embed_models(
|
||||||
|
hf_runner,
|
||||||
|
vllm_runner,
|
||||||
|
model_info,
|
||||||
|
vllm_extra_kwargs={"gpu_memory_utilization": 0.2},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model_info", MODELS)
|
||||||
|
def test_embed_models_correctness(
|
||||||
|
hf_runner, vllm_runner, model_info: EmbedModelInfo, example_prompts
|
||||||
|
) -> None:
|
||||||
|
correctness_test_embed_models(
|
||||||
|
hf_runner,
|
||||||
|
vllm_runner,
|
||||||
|
model_info,
|
||||||
|
example_prompts,
|
||||||
|
vllm_extra_kwargs={"gpu_memory_utilization": 0.2},
|
||||||
|
)
|
||||||
@@ -552,6 +552,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
|||||||
_EMBEDDING_EXAMPLE_MODELS = {
|
_EMBEDDING_EXAMPLE_MODELS = {
|
||||||
# [Text-only]
|
# [Text-only]
|
||||||
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
|
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
|
||||||
|
"ErnieModel": _HfExamplesInfo("shibing624/text2vec-base-chinese-sentence"),
|
||||||
"BertSpladeSparseEmbeddingModel": _HfExamplesInfo(
|
"BertSpladeSparseEmbeddingModel": _HfExamplesInfo(
|
||||||
"naver/splade-v3",
|
"naver/splade-v3",
|
||||||
hf_overrides={"architectures": ["BertSpladeSparseEmbeddingModel"]},
|
hf_overrides={"architectures": ["BertSpladeSparseEmbeddingModel"]},
|
||||||
@@ -666,6 +667,9 @@ _REWARD_EXAMPLE_MODELS = {
|
|||||||
|
|
||||||
_TOKEN_CLASSIFICATION_EXAMPLE_MODELS = {
|
_TOKEN_CLASSIFICATION_EXAMPLE_MODELS = {
|
||||||
"BertForTokenClassification": _HfExamplesInfo("boltuix/NeuroBERT-NER"),
|
"BertForTokenClassification": _HfExamplesInfo("boltuix/NeuroBERT-NER"),
|
||||||
|
"ErnieForTokenClassification": _HfExamplesInfo(
|
||||||
|
"gyr66/Ernie-3.0-base-chinese-finetuned-ner"
|
||||||
|
),
|
||||||
"ModernBertForTokenClassification": _HfExamplesInfo(
|
"ModernBertForTokenClassification": _HfExamplesInfo(
|
||||||
"disham993/electrical-ner-ModernBERT-base"
|
"disham993/electrical-ner-ModernBERT-base"
|
||||||
),
|
),
|
||||||
@@ -675,6 +679,9 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
|
|||||||
"BertForSequenceClassification": _HfExamplesInfo(
|
"BertForSequenceClassification": _HfExamplesInfo(
|
||||||
"cross-encoder/ms-marco-MiniLM-L-6-v2"
|
"cross-encoder/ms-marco-MiniLM-L-6-v2"
|
||||||
),
|
),
|
||||||
|
"ErnieForSequenceClassification": _HfExamplesInfo(
|
||||||
|
"Forrest20231206/ernie-3.0-base-zh-cls",
|
||||||
|
),
|
||||||
"GPT2ForSequenceClassification": _HfExamplesInfo(
|
"GPT2ForSequenceClassification": _HfExamplesInfo(
|
||||||
"nie3e/sentiment-polish-gpt2-small"
|
"nie3e/sentiment-polish-gpt2-small"
|
||||||
),
|
),
|
||||||
|
|||||||
247
vllm/model_executor/models/ernie.py
Normal file
247
vllm/model_executor/models/ernie.py
Normal file
@@ -0,0 +1,247 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from collections.abc import Iterable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from transformers import BertConfig
|
||||||
|
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.model_executor.layers.pooler import DispatchPooler
|
||||||
|
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_classify
|
||||||
|
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||||
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
|
from .bert import (
|
||||||
|
TOKEN_TYPE_SHIFT,
|
||||||
|
BertEmbedding,
|
||||||
|
BertEmbeddingModel,
|
||||||
|
BertModel,
|
||||||
|
BertPoolingModel,
|
||||||
|
_decode_token_type_ids,
|
||||||
|
_encode_token_type_ids,
|
||||||
|
)
|
||||||
|
from .interfaces import SupportsCrossEncoding, SupportsQuant
|
||||||
|
from .interfaces_base import attn_type, default_pooling_type
|
||||||
|
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
|
||||||
|
|
||||||
|
_LEGACY_SUFFIX_MAPPER = WeightsMapper(
|
||||||
|
orig_to_new_suffix={
|
||||||
|
".gamma": ".weight",
|
||||||
|
".beta": ".bias",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ErnieEmbedding(BertEmbedding):
|
||||||
|
def __init__(self, config: BertConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
task_type_vocab_size = max(1, getattr(config, "task_type_vocab_size", 1))
|
||||||
|
self.task_type_embeddings = VocabParallelEmbedding(
|
||||||
|
task_type_vocab_size, config.hidden_size
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
inputs_embeds: torch.Tensor | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
token_type_ids = _decode_token_type_ids(input_ids)
|
||||||
|
task_type_ids = torch.zeros_like(token_type_ids)
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.word_embeddings(input_ids)
|
||||||
|
|
||||||
|
position_embeddings = self.position_embeddings(position_ids)
|
||||||
|
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||||
|
task_type_embeddings = self.task_type_embeddings(task_type_ids)
|
||||||
|
|
||||||
|
embeddings = (
|
||||||
|
inputs_embeds
|
||||||
|
+ token_type_embeddings
|
||||||
|
+ task_type_embeddings
|
||||||
|
+ position_embeddings
|
||||||
|
)
|
||||||
|
embeddings = self.LayerNorm(embeddings)
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
|
||||||
|
@default_pooling_type(seq_pooling_type="CLS")
|
||||||
|
class ErnieModel(BertModel):
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
|
super().__init__(
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
prefix=prefix,
|
||||||
|
embedding_class=ErnieEmbedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ErniePoolingModel(BertPoolingModel):
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
|
super().__init__(
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
prefix=prefix,
|
||||||
|
embedding_class=ErnieEmbedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@default_pooling_type(seq_pooling_type="CLS")
|
||||||
|
class ErnieEmbeddingModel(BertEmbeddingModel):
|
||||||
|
def _build_model(self, vllm_config: VllmConfig, prefix: str = "") -> ErnieModel:
|
||||||
|
return ErnieModel(vllm_config=vllm_config, prefix=prefix)
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||||
|
weights_list = list(weights)
|
||||||
|
has_model_prefix = any(name.startswith("model.") for name, _ in weights_list)
|
||||||
|
has_ernie_prefix = any(name.startswith("ernie.") for name, _ in weights_list)
|
||||||
|
|
||||||
|
mapper: WeightsMapper | None = None
|
||||||
|
if not has_model_prefix:
|
||||||
|
if has_ernie_prefix:
|
||||||
|
mapper = WeightsMapper(orig_to_new_prefix={"ernie.": "model."})
|
||||||
|
else:
|
||||||
|
mapper = WeightsMapper(orig_to_new_prefix={"": "model."})
|
||||||
|
if mapper is None:
|
||||||
|
mapper = _LEGACY_SUFFIX_MAPPER
|
||||||
|
else:
|
||||||
|
mapper = mapper | _LEGACY_SUFFIX_MAPPER
|
||||||
|
|
||||||
|
loader = AutoWeightsLoader(self, skip_prefixes=["lm_head.", "cls."])
|
||||||
|
return loader.load_weights(weights_list, mapper=mapper)
|
||||||
|
|
||||||
|
|
||||||
|
@default_pooling_type(seq_pooling_type="CLS")
|
||||||
|
class ErnieForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQuant):
|
||||||
|
is_pooling_model = True
|
||||||
|
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
|
super().__init__()
|
||||||
|
config = vllm_config.model_config.hf_config
|
||||||
|
|
||||||
|
self.num_labels = config.num_labels
|
||||||
|
self.ernie = ErniePoolingModel(
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
prefix=maybe_prefix(prefix, "ernie"),
|
||||||
|
)
|
||||||
|
self.classifier = nn.Linear(
|
||||||
|
config.hidden_size,
|
||||||
|
config.num_labels,
|
||||||
|
dtype=vllm_config.model_config.head_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
|
assert pooler_config is not None
|
||||||
|
|
||||||
|
self.pooler = DispatchPooler.for_seq_cls(
|
||||||
|
pooler_config,
|
||||||
|
pooling=self.ernie.pooler,
|
||||||
|
classifier=self.classifier,
|
||||||
|
)
|
||||||
|
|
||||||
|
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.ernie.embed_input_ids(input_ids)
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||||
|
weights_list = list(weights)
|
||||||
|
has_ernie_prefix = any(name.startswith("ernie.") for name, _ in weights_list)
|
||||||
|
has_bert_prefix = any(name.startswith("bert.") for name, _ in weights_list)
|
||||||
|
|
||||||
|
mapper: WeightsMapper | None = None
|
||||||
|
if has_bert_prefix and not has_ernie_prefix:
|
||||||
|
mapper = WeightsMapper(orig_to_new_prefix={"bert.": "ernie."})
|
||||||
|
if mapper is None:
|
||||||
|
mapper = _LEGACY_SUFFIX_MAPPER
|
||||||
|
else:
|
||||||
|
mapper = mapper | _LEGACY_SUFFIX_MAPPER
|
||||||
|
|
||||||
|
loader = AutoWeightsLoader(self, skip_prefixes=["cls.", "lm_head."])
|
||||||
|
return loader.load_weights(weights_list, mapper=mapper)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor | None,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
intermediate_tensors: IntermediateTensors | None = None,
|
||||||
|
inputs_embeds: torch.Tensor | None = None,
|
||||||
|
token_type_ids: torch.Tensor | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if token_type_ids is not None:
|
||||||
|
assert self.ernie.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
|
||||||
|
assert input_ids is not None
|
||||||
|
_encode_token_type_ids(input_ids, token_type_ids)
|
||||||
|
|
||||||
|
return self.ernie(
|
||||||
|
input_ids=input_ids,
|
||||||
|
positions=positions,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
intermediate_tensors=intermediate_tensors,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@attn_type("encoder_only")
|
||||||
|
@default_pooling_type(tok_pooling_type="ALL")
|
||||||
|
class ErnieForTokenClassification(nn.Module):
|
||||||
|
is_pooling_model = True
|
||||||
|
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
|
super().__init__()
|
||||||
|
config = vllm_config.model_config.hf_config
|
||||||
|
self.head_dtype = vllm_config.model_config.head_dtype
|
||||||
|
self.num_labels = config.num_labels
|
||||||
|
self.ernie = ErnieModel(
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
prefix=maybe_prefix(prefix, "ernie"),
|
||||||
|
)
|
||||||
|
self.classifier = nn.Linear(
|
||||||
|
config.hidden_size, config.num_labels, dtype=self.head_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
pooler_config = vllm_config.model_config.pooler_config
|
||||||
|
assert pooler_config is not None
|
||||||
|
|
||||||
|
self.pooler = pooler_for_token_classify(pooler_config)
|
||||||
|
|
||||||
|
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.ernie.embed_input_ids(input_ids)
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||||
|
weights_list = list(weights)
|
||||||
|
has_ernie_prefix = any(name.startswith("ernie.") for name, _ in weights_list)
|
||||||
|
has_bert_prefix = any(name.startswith("bert.") for name, _ in weights_list)
|
||||||
|
|
||||||
|
mapper: WeightsMapper | None = None
|
||||||
|
if has_bert_prefix and not has_ernie_prefix:
|
||||||
|
mapper = WeightsMapper(orig_to_new_prefix={"bert.": "ernie."})
|
||||||
|
if mapper is None:
|
||||||
|
mapper = _LEGACY_SUFFIX_MAPPER
|
||||||
|
else:
|
||||||
|
mapper = mapper | _LEGACY_SUFFIX_MAPPER
|
||||||
|
|
||||||
|
loader = AutoWeightsLoader(self, skip_prefixes=["cls.", "lm_head."])
|
||||||
|
return loader.load_weights(weights_list, mapper=mapper)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor | None,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
intermediate_tensors: IntermediateTensors | None = None,
|
||||||
|
inputs_embeds: torch.Tensor | None = None,
|
||||||
|
token_type_ids: torch.Tensor | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if token_type_ids is not None:
|
||||||
|
assert self.ernie.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
|
||||||
|
assert input_ids is not None
|
||||||
|
_encode_token_type_ids(input_ids, token_type_ids)
|
||||||
|
|
||||||
|
hidden_states = self.ernie(
|
||||||
|
input_ids=input_ids,
|
||||||
|
positions=positions,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
intermediate_tensors=intermediate_tensors,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = hidden_states.to(self.head_dtype)
|
||||||
|
return self.classifier(hidden_states)
|
||||||
@@ -214,6 +214,7 @@ _EMBEDDING_MODELS = {
|
|||||||
# [Text-only]
|
# [Text-only]
|
||||||
"BertModel": ("bert", "BertEmbeddingModel"),
|
"BertModel": ("bert", "BertEmbeddingModel"),
|
||||||
"BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
|
"BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
|
||||||
|
"ErnieModel": ("ernie", "ErnieEmbeddingModel"),
|
||||||
"BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
|
"BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
|
||||||
"DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
|
"DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
|
||||||
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
|
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
|
||||||
@@ -286,6 +287,7 @@ _REWARD_MODELS = {
|
|||||||
|
|
||||||
_TOKEN_CLASSIFICATION_MODELS = {
|
_TOKEN_CLASSIFICATION_MODELS = {
|
||||||
"BertForTokenClassification": ("bert", "BertForTokenClassification"),
|
"BertForTokenClassification": ("bert", "BertForTokenClassification"),
|
||||||
|
"ErnieForTokenClassification": ("ernie", "ErnieForTokenClassification"),
|
||||||
"ModernBertForTokenClassification": (
|
"ModernBertForTokenClassification": (
|
||||||
"modernbert",
|
"modernbert",
|
||||||
"ModernBertForTokenClassification",
|
"ModernBertForTokenClassification",
|
||||||
@@ -295,6 +297,7 @@ _TOKEN_CLASSIFICATION_MODELS = {
|
|||||||
_SEQUENCE_CLASSIFICATION_MODELS = {
|
_SEQUENCE_CLASSIFICATION_MODELS = {
|
||||||
"BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
|
"BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
|
||||||
"GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
|
"GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
|
||||||
|
"ErnieForSequenceClassification": ("ernie", "ErnieForSequenceClassification"),
|
||||||
"GteNewForSequenceClassification": (
|
"GteNewForSequenceClassification": (
|
||||||
"bert_with_rope",
|
"bert_with_rope",
|
||||||
"GteNewForSequenceClassification",
|
"GteNewForSequenceClassification",
|
||||||
|
|||||||
Reference in New Issue
Block a user