[Model] Add support for BERT-like Chinese ERNIE pooling models (#36385)

Signed-off-by: whyiug <whyiug@hotmail.com>
Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
Authored by whyiug on 2026-03-13 11:23:53 +08:00; committed by GitHub
parent 10f08dedfa
commit 1ce13cf992
7 changed files with 317 additions and 3 deletions

View File

@@ -514,6 +514,7 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) A
| ------------ | ------ | ----------------- | -------------------- | ------------------------- |
| `BertModel`<sup>C</sup> | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | |
| `BertSpladeSparseEmbeddingModel` | SPLADE | `naver/splade-v3` | | |
| `ErnieModel` | BERT-like Chinese ERNIE | `shibing624/text2vec-base-chinese-sentence` | | |
| `Gemma2Model`<sup>C</sup> | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | ✅︎ |
| `Gemma3TextModel`<sup>C</sup> | Gemma 3-based | `google/embeddinggemma-300m`, etc. | ✅︎ | ✅︎ |
| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ |
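The new `ErnieModel` row can be exercised offline through `LLM.embed`. A minimal sketch, assuming a recent vLLM where the pooling runner is selected with `runner="pooling"` (older releases spell this `task="embed"` instead):

```python
# Minimal sketch, assuming runner="pooling" exists in this vLLM version.
from vllm import LLM

llm = LLM(model="shibing624/text2vec-base-chinese-sentence", runner="pooling")

prompts = ["今天天气真好", "这家餐厅的菜很好吃"]
outputs = llm.embed(prompts)

for prompt, output in zip(prompts, outputs):
    # Each output carries the sentence vector as a list of floats.
    print(prompt, len(output.outputs.embedding))
```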
@@ -556,8 +557,9 @@ These models primarily support the [`LLM.classify`](./pooling_models.md#llmclass
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
| ------------ | ------ | ----------------- | -------------------- | ------------------------- |
| `ErnieForSequenceClassification` | BERT-like Chinese ERNIE | `Forrest20231206/ernie-3.0-base-zh-cls` | | |
| `GPT2ForSequenceClassification` | GPT2 | `nie3e/sentiment-polish-gpt2-small` | | |
| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ |
| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* |
<sup>C</sup> Automatically converted into a classification model via `--convert classify`. ([details](./pooling_models.md#model-conversion))
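Likewise, the `ErnieForSequenceClassification` entry maps onto `LLM.classify`. A hedged sketch (the `runner` argument is again an assumption about the current API):

```python
# Sketch of sequence classification; label names come from the HF config.
from vllm import LLM

llm = LLM(model="Forrest20231206/ernie-3.0-base-zh-cls", runner="pooling")

outputs = llm.classify(["这家餐厅的服务很好", "物流太慢了，体验很差"])
for output in outputs:
    # output.outputs.probs holds one probability per class label.
    print(output.outputs.probs)
```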
@@ -574,6 +576,7 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
| Architecture | Models | Example HF Models | Score template (see note) | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
| ------------ | ------ | ----------------- | ------------------------- | --------------------------- | --------------------------------------- |
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | N/A | | |
| `ErnieForSequenceClassification` | BERT-like Chinese ERNIE | `Forrest20231206/ernie-3.0-base-zh-cls` | N/A | | |
| `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | [bge-reranker-v2-gemma.jinja](../../examples/pooling/score/template/bge-reranker-v2-gemma.jinja) | ✅︎ | ✅︎ |
| `GteNewForSequenceClassification` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-reranker-base`, etc. | N/A | | |
| `LlamaBidirectionalForSequenceClassification`<sup>C</sup> | Llama-based with bidirectional attention | `nvidia/llama-nemotron-rerank-1b-v2`, etc. | [nemotron-rerank.jinja](../../examples/pooling/score/template/nemotron-rerank.jinja) | ✅︎ | ✅︎ |
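Because the same ERNIE architecture is listed here for scoring, query/document pairs can also be run through `LLM.score`. Treat the following purely as an API sketch; whether the scores are meaningful as relevance judgements depends on how the checkpoint was fine-tuned, and the `runner` argument is an assumption:

```python
# API sketch only; reranking quality depends on the checkpoint's fine-tuning.
from vllm import LLM

llm = LLM(model="Forrest20231206/ernie-3.0-base-zh-cls", runner="pooling")

query = "什么是人工智能"
documents = ["人工智能是计算机科学的一个分支", "今天晚饭吃什么"]
outputs = llm.score(query, documents)

for doc, output in zip(documents, outputs):
    print(doc, output.outputs.score)
```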
@@ -639,6 +642,7 @@ These models primarily support the [`LLM.encode`](./pooling_models.md#llmencode)
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
| ------------ | ------ | ----------------- | --------------------------- | --------------------------------------- |
| `BertForTokenClassification` | BERT-based | `boltuix/NeuroBERT-NER` (see note), etc. | | |
| `ErnieForTokenClassification` | BERT-like Chinese ERNIE | `gyr66/Ernie-3.0-base-chinese-finetuned-ner` | | |
| `ModernBertForTokenClassification` | ModernBERT-based | `disham993/electrical-ner-ModernBERT-base` | | |
!!! note

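For the token-classification entry, per-token label scores come back through `LLM.encode`. A rough sketch; depending on the vLLM version an explicit pooling-task argument may also be required (assumption):

```python
# Rough sketch; the default encode path is assumed to hit the token-level pooler.
from vllm import LLM

llm = LLM(model="gyr66/Ernie-3.0-base-chinese-finetuned-ner", runner="pooling")

(output,) = llm.encode(["华为总部位于深圳"])
# Token-classification models emit one score vector per input token.
print(output.outputs.data.shape)  # roughly (num_tokens, num_labels)
```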
View File

@@ -18,6 +18,7 @@ from vllm.platforms import current_platform
pytest.mark.slow_test,
],
),
pytest.param("Forrest20231206/ernie-3.0-base-zh-cls"),
],
)
@pytest.mark.parametrize("dtype", ["half"] if current_platform.is_rocm() else ["float"])
@@ -47,5 +48,6 @@ def test_models(
assert torch.allclose(
hf_output,
vllm_output,
atol=1e-3 if dtype == "float" else 1e-2,
rtol=2e-3 if dtype == "float" else 1e-2,
)

View File

@@ -25,11 +25,17 @@ def seed_everything():
yield
@pytest.mark.parametrize(
"model",
[
"boltuix/NeuroBERT-NER",
"gyr66/Ernie-3.0-base-chinese-finetuned-ner",
],
)
# The float32 is required for this tiny model to pass the test.
@pytest.mark.parametrize("dtype", ["float"])
@torch.inference_mode
def test_bert_like_models(
hf_runner,
vllm_runner,
example_prompts,

View File

@@ -0,0 +1,45 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from tests.models.language.pooling.embed_utils import correctness_test_embed_models
from tests.models.utils import EmbedModelInfo
from .mteb_embed_utils import mteb_test_embed_models
MODELS = [
EmbedModelInfo(
"shibing624/text2vec-base-chinese-sentence",
architecture="ErnieModel",
mteb_score=0.536523112,
seq_pooling_type="MEAN",
attn_type="encoder_only",
is_prefix_caching_supported=False,
is_chunked_prefill_supported=False,
enable_test=True,
),
]

@pytest.mark.parametrize("model_info", MODELS)
def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None:
mteb_test_embed_models(
hf_runner,
vllm_runner,
model_info,
vllm_extra_kwargs={"gpu_memory_utilization": 0.2},
)

@pytest.mark.parametrize("model_info", MODELS)
def test_embed_models_correctness(
hf_runner, vllm_runner, model_info: EmbedModelInfo, example_prompts
) -> None:
correctness_test_embed_models(
hf_runner,
vllm_runner,
model_info,
example_prompts,
vllm_extra_kwargs={"gpu_memory_utilization": 0.2},
)

View File

@@ -552,6 +552,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
_EMBEDDING_EXAMPLE_MODELS = {
# [Text-only]
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
"ErnieModel": _HfExamplesInfo("shibing624/text2vec-base-chinese-sentence"),
"BertSpladeSparseEmbeddingModel": _HfExamplesInfo( "BertSpladeSparseEmbeddingModel": _HfExamplesInfo(
"naver/splade-v3", "naver/splade-v3",
hf_overrides={"architectures": ["BertSpladeSparseEmbeddingModel"]}, hf_overrides={"architectures": ["BertSpladeSparseEmbeddingModel"]},
@@ -666,6 +667,9 @@ _REWARD_EXAMPLE_MODELS = {
_TOKEN_CLASSIFICATION_EXAMPLE_MODELS = {
"BertForTokenClassification": _HfExamplesInfo("boltuix/NeuroBERT-NER"),
"ErnieForTokenClassification": _HfExamplesInfo(
"gyr66/Ernie-3.0-base-chinese-finetuned-ner"
),
"ModernBertForTokenClassification": _HfExamplesInfo( "ModernBertForTokenClassification": _HfExamplesInfo(
"disham993/electrical-ner-ModernBERT-base" "disham993/electrical-ner-ModernBERT-base"
), ),
@@ -675,6 +679,9 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
"BertForSequenceClassification": _HfExamplesInfo( "BertForSequenceClassification": _HfExamplesInfo(
"cross-encoder/ms-marco-MiniLM-L-6-v2" "cross-encoder/ms-marco-MiniLM-L-6-v2"
), ),
"ErnieForSequenceClassification": _HfExamplesInfo(
"Forrest20231206/ernie-3.0-base-zh-cls",
),
"GPT2ForSequenceClassification": _HfExamplesInfo( "GPT2ForSequenceClassification": _HfExamplesInfo(
"nie3e/sentiment-polish-gpt2-small" "nie3e/sentiment-polish-gpt2-small"
), ),

View File

@@ -0,0 +1,247 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
import torch
from torch import nn
from transformers import BertConfig
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_classify
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.sequence import IntermediateTensors
from .bert import (
TOKEN_TYPE_SHIFT,
BertEmbedding,
BertEmbeddingModel,
BertModel,
BertPoolingModel,
_decode_token_type_ids,
_encode_token_type_ids,
)
from .interfaces import SupportsCrossEncoding, SupportsQuant
from .interfaces_base import attn_type, default_pooling_type
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
_LEGACY_SUFFIX_MAPPER = WeightsMapper(
orig_to_new_suffix={
".gamma": ".weight",
".beta": ".bias",
}
)
class ErnieEmbedding(BertEmbedding):
def __init__(self, config: BertConfig):
super().__init__(config)
task_type_vocab_size = max(1, getattr(config, "task_type_vocab_size", 1))
self.task_type_embeddings = VocabParallelEmbedding(
task_type_vocab_size, config.hidden_size
)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor:
token_type_ids = _decode_token_type_ids(input_ids)
task_type_ids = torch.zeros_like(token_type_ids)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
task_type_embeddings = self.task_type_embeddings(task_type_ids)
embeddings = (
inputs_embeds
+ token_type_embeddings
+ task_type_embeddings
+ position_embeddings
)
embeddings = self.LayerNorm(embeddings)
return embeddings
@default_pooling_type(seq_pooling_type="CLS")
class ErnieModel(BertModel):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(
vllm_config=vllm_config,
prefix=prefix,
embedding_class=ErnieEmbedding,
)
class ErniePoolingModel(BertPoolingModel):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(
vllm_config=vllm_config,
prefix=prefix,
embedding_class=ErnieEmbedding,
)
@default_pooling_type(seq_pooling_type="CLS")
class ErnieEmbeddingModel(BertEmbeddingModel):
def _build_model(self, vllm_config: VllmConfig, prefix: str = "") -> ErnieModel:
return ErnieModel(vllm_config=vllm_config, prefix=prefix)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
weights_list = list(weights)
has_model_prefix = any(name.startswith("model.") for name, _ in weights_list)
has_ernie_prefix = any(name.startswith("ernie.") for name, _ in weights_list)
mapper: WeightsMapper | None = None
if not has_model_prefix:
if has_ernie_prefix:
mapper = WeightsMapper(orig_to_new_prefix={"ernie.": "model."})
else:
mapper = WeightsMapper(orig_to_new_prefix={"": "model."})
if mapper is None:
mapper = _LEGACY_SUFFIX_MAPPER
else:
mapper = mapper | _LEGACY_SUFFIX_MAPPER
loader = AutoWeightsLoader(self, skip_prefixes=["lm_head.", "cls."])
return loader.load_weights(weights_list, mapper=mapper)
@default_pooling_type(seq_pooling_type="CLS")
class ErnieForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQuant):
is_pooling_model = True
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.num_labels = config.num_labels
self.ernie = ErniePoolingModel(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "ernie"),
)
self.classifier = nn.Linear(
config.hidden_size,
config.num_labels,
dtype=vllm_config.model_config.head_dtype,
)
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler.for_seq_cls(
pooler_config,
pooling=self.ernie.pooler,
classifier=self.classifier,
)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.ernie.embed_input_ids(input_ids)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
weights_list = list(weights)
has_ernie_prefix = any(name.startswith("ernie.") for name, _ in weights_list)
has_bert_prefix = any(name.startswith("bert.") for name, _ in weights_list)
mapper: WeightsMapper | None = None
if has_bert_prefix and not has_ernie_prefix:
mapper = WeightsMapper(orig_to_new_prefix={"bert.": "ernie."})
if mapper is None:
mapper = _LEGACY_SUFFIX_MAPPER
else:
mapper = mapper | _LEGACY_SUFFIX_MAPPER
loader = AutoWeightsLoader(self, skip_prefixes=["cls.", "lm_head."])
return loader.load_weights(weights_list, mapper=mapper)
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
token_type_ids: torch.Tensor | None = None,
) -> torch.Tensor:
if token_type_ids is not None:
assert self.ernie.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
assert input_ids is not None
_encode_token_type_ids(input_ids, token_type_ids)
return self.ernie(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors,
)
@attn_type("encoder_only")
@default_pooling_type(tok_pooling_type="ALL")
class ErnieForTokenClassification(nn.Module):
is_pooling_model = True
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.head_dtype = vllm_config.model_config.head_dtype
self.num_labels = config.num_labels
self.ernie = ErnieModel(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "ernie"),
)
self.classifier = nn.Linear(
config.hidden_size, config.num_labels, dtype=self.head_dtype
)
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = pooler_for_token_classify(pooler_config)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.ernie.embed_input_ids(input_ids)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
weights_list = list(weights)
has_ernie_prefix = any(name.startswith("ernie.") for name, _ in weights_list)
has_bert_prefix = any(name.startswith("bert.") for name, _ in weights_list)
mapper: WeightsMapper | None = None
if has_bert_prefix and not has_ernie_prefix:
mapper = WeightsMapper(orig_to_new_prefix={"bert.": "ernie."})
if mapper is None:
mapper = _LEGACY_SUFFIX_MAPPER
else:
mapper = mapper | _LEGACY_SUFFIX_MAPPER
loader = AutoWeightsLoader(self, skip_prefixes=["cls.", "lm_head."])
return loader.load_weights(weights_list, mapper=mapper)
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
token_type_ids: torch.Tensor | None = None,
) -> torch.Tensor:
if token_type_ids is not None:
assert self.ernie.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
assert input_ids is not None
_encode_token_type_ids(input_ids, token_type_ids)
hidden_states = self.ernie(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors,
)
hidden_states = hidden_states.to(self.head_dtype)
return self.classifier(hidden_states)
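
The ERNIE classes reuse the BERT trick of smuggling `token_type_ids` through the `input_ids` tensor, which is why `forward` asserts `vocab_size < (1 << TOKEN_TYPE_SHIFT)` before calling `_encode_token_type_ids`, and why `ErnieEmbedding.forward` decodes them back out. A self-contained illustration of that packing idea (hypothetical shift value; not the actual vLLM helpers, which modify the tensor in place):

```python
# Illustration only: token type ids stashed in the high bits of input_ids.
import torch

TOKEN_TYPE_SHIFT = 30  # hypothetical value, for illustration

def pack(input_ids: torch.Tensor, token_type_ids: torch.Tensor) -> torch.Tensor:
    # Safe only while every token id is below 1 << TOKEN_TYPE_SHIFT.
    return input_ids | (token_type_ids << TOKEN_TYPE_SHIFT)

def unpack(packed: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    token_type_ids = packed >> TOKEN_TYPE_SHIFT
    input_ids = packed & ((1 << TOKEN_TYPE_SHIFT) - 1)
    return input_ids, token_type_ids

packed = pack(torch.tensor([101, 2769, 102]), torch.tensor([0, 0, 1]))
print(unpack(packed))  # recovers both the token ids and the segment ids
```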

View File

@@ -214,6 +214,7 @@ _EMBEDDING_MODELS = {
# [Text-only]
"BertModel": ("bert", "BertEmbeddingModel"),
"BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
"ErnieModel": ("ernie", "ErnieEmbeddingModel"),
"BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"), "BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
"DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"), "DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"), "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
@@ -286,6 +287,7 @@ _REWARD_MODELS = {
_TOKEN_CLASSIFICATION_MODELS = {
"BertForTokenClassification": ("bert", "BertForTokenClassification"),
"ErnieForTokenClassification": ("ernie", "ErnieForTokenClassification"),
"ModernBertForTokenClassification": ( "ModernBertForTokenClassification": (
"modernbert", "modernbert",
"ModernBertForTokenClassification", "ModernBertForTokenClassification",
@@ -295,6 +297,7 @@ _TOKEN_CLASSIFICATION_MODELS = {
_SEQUENCE_CLASSIFICATION_MODELS = {
"BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
"GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
"ErnieForSequenceClassification": ("ernie", "ErnieForSequenceClassification"),
"GteNewForSequenceClassification": ( "GteNewForSequenceClassification": (
"bert_with_rope", "bert_with_rope",
"GteNewForSequenceClassification", "GteNewForSequenceClassification",