[Model] Add support for BERT-like Chinese ERNIE pooling models (#36385)
Signed-off-by: whyiug <whyiug@hotmail.com> Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io>
This commit is contained in:
@@ -514,6 +514,7 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) A
|
||||
| ------------ | ------ | ----------------- | -------------------- | ------------------------- |
|
||||
| `BertModel`<sup>C</sup> | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | |
|
||||
| `BertSpladeSparseEmbeddingModel` | SPLADE | `naver/splade-v3` | | |
|
||||
| `ErnieModel` | BERT-like Chinese ERNIE | `shibing624/text2vec-base-chinese-sentence` | | |
|
||||
| `Gemma2Model`<sup>C</sup> | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | ✅︎ |
|
||||
| `Gemma3TextModel`<sup>C</sup> | Gemma 3-based | `google/embeddinggemma-300m`, etc. | ✅︎ | ✅︎ |
|
||||
| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ |
|
||||
@@ -556,8 +557,9 @@ These models primarily support the [`LLM.classify`](./pooling_models.md#llmclass
|
||||
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
|
||||
| ------------ | ------ | ----------------- | -------------------- | ------------------------- |
|
||||
| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ |
|
||||
| `ErnieForSequenceClassification` | BERT-like Chinese ERNIE | `Forrest20231206/ernie-3.0-base-zh-cls` | | |
|
||||
| `GPT2ForSequenceClassification` | GPT2 | `nie3e/sentiment-polish-gpt2-small` | | |
|
||||
| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ |
|
||||
| `*Model`<sup>C</sup>, `*ForCausalLM`<sup>C</sup>, etc. | Generative models | N/A | \* | \* |
|
||||
|
||||
<sup>C</sup> Automatically converted into a classification model via `--convert classify`. ([details](./pooling_models.md#model-conversion))
|
||||
@@ -574,6 +576,7 @@ These models primarily support the [`LLM.score`](./pooling_models.md#llmscore) A
|
||||
| Architecture | Models | Example HF Models | Score template (see note) | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
|
||||
| ------------ | ------ | ----------------- | ------------------------- | --------------------------- | --------------------------------------- |
|
||||
| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | N/A | | |
|
||||
| `ErnieForSequenceClassification` | BERT-like Chinese ERNIE | `Forrest20231206/ernie-3.0-base-zh-cls` | N/A | | |
|
||||
| `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma`(see note), etc. | [bge-reranker-v2-gemma.jinja](../../examples/pooling/score/template/bge-reranker-v2-gemma.jinja) | ✅︎ | ✅︎ |
|
||||
| `GteNewForSequenceClassification` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-reranker-base`, etc. | N/A | | |
|
||||
| `LlamaBidirectionalForSequenceClassification`<sup>C</sup> | Llama-based with bidirectional attention | `nvidia/llama-nemotron-rerank-1b-v2`, etc. | [nemotron-rerank.jinja](../../examples/pooling/score/template/nemotron-rerank.jinja) | ✅︎ | ✅︎ |
|
||||
@@ -639,6 +642,7 @@ These models primarily support the [`LLM.encode`](./pooling_models.md#llmencode)
|
||||
| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
|
||||
| ------------ | ------ | ----------------- | --------------------------- | --------------------------------------- |
|
||||
| `BertForTokenClassification` | bert-based | `boltuix/NeuroBERT-NER` (see note), etc. | | |
|
||||
| `ErnieForTokenClassification` | BERT-like Chinese ERNIE | `gyr66/Ernie-3.0-base-chinese-finetuned-ner` | | |
|
||||
| `ModernBertForTokenClassification` | ModernBERT-based | `disham993/electrical-ner-ModernBERT-base` | | |
|
||||
|
||||
!!! note
|
||||
|
||||
@@ -18,6 +18,7 @@ from vllm.platforms import current_platform
|
||||
pytest.mark.slow_test,
|
||||
],
|
||||
),
|
||||
pytest.param("Forrest20231206/ernie-3.0-base-zh-cls"),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", ["half"] if current_platform.is_rocm() else ["float"])
|
||||
@@ -47,5 +48,6 @@ def test_models(
|
||||
assert torch.allclose(
|
||||
hf_output,
|
||||
vllm_output,
|
||||
atol=1e-3 if dtype == "float" else 1e-2,
|
||||
rtol=2e-3 if dtype == "float" else 1e-2,
|
||||
)
|
||||
|
||||
@@ -25,11 +25,17 @@ def seed_everything():
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ["boltuix/NeuroBERT-NER"])
|
||||
@pytest.mark.parametrize(
|
||||
"model",
|
||||
[
|
||||
"boltuix/NeuroBERT-NER",
|
||||
"gyr66/Ernie-3.0-base-chinese-finetuned-ner",
|
||||
],
|
||||
)
|
||||
# The float32 is required for this tiny model to pass the test.
|
||||
@pytest.mark.parametrize("dtype", ["float"])
|
||||
@torch.inference_mode
|
||||
def test_bert_models(
|
||||
def test_bert_like_models(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
|
||||
45
tests/models/language/pooling_mteb_test/test_ernie.py
Normal file
45
tests/models/language/pooling_mteb_test/test_ernie.py
Normal file
@@ -0,0 +1,45 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.models.language.pooling.embed_utils import correctness_test_embed_models
|
||||
from tests.models.utils import EmbedModelInfo
|
||||
|
||||
from .mteb_embed_utils import mteb_test_embed_models
|
||||
|
||||
MODELS = [
|
||||
EmbedModelInfo(
|
||||
"shibing624/text2vec-base-chinese-sentence",
|
||||
architecture="ErnieModel",
|
||||
mteb_score=0.536523112,
|
||||
seq_pooling_type="MEAN",
|
||||
attn_type="encoder_only",
|
||||
is_prefix_caching_supported=False,
|
||||
is_chunked_prefill_supported=False,
|
||||
enable_test=True,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_info", MODELS)
|
||||
def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None:
|
||||
mteb_test_embed_models(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
model_info,
|
||||
vllm_extra_kwargs={"gpu_memory_utilization": 0.2},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_info", MODELS)
|
||||
def test_embed_models_correctness(
|
||||
hf_runner, vllm_runner, model_info: EmbedModelInfo, example_prompts
|
||||
) -> None:
|
||||
correctness_test_embed_models(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
model_info,
|
||||
example_prompts,
|
||||
vllm_extra_kwargs={"gpu_memory_utilization": 0.2},
|
||||
)
|
||||
@@ -552,6 +552,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
_EMBEDDING_EXAMPLE_MODELS = {
|
||||
# [Text-only]
|
||||
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
|
||||
"ErnieModel": _HfExamplesInfo("shibing624/text2vec-base-chinese-sentence"),
|
||||
"BertSpladeSparseEmbeddingModel": _HfExamplesInfo(
|
||||
"naver/splade-v3",
|
||||
hf_overrides={"architectures": ["BertSpladeSparseEmbeddingModel"]},
|
||||
@@ -666,6 +667,9 @@ _REWARD_EXAMPLE_MODELS = {
|
||||
|
||||
_TOKEN_CLASSIFICATION_EXAMPLE_MODELS = {
|
||||
"BertForTokenClassification": _HfExamplesInfo("boltuix/NeuroBERT-NER"),
|
||||
"ErnieForTokenClassification": _HfExamplesInfo(
|
||||
"gyr66/Ernie-3.0-base-chinese-finetuned-ner"
|
||||
),
|
||||
"ModernBertForTokenClassification": _HfExamplesInfo(
|
||||
"disham993/electrical-ner-ModernBERT-base"
|
||||
),
|
||||
@@ -675,6 +679,9 @@ _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
|
||||
"BertForSequenceClassification": _HfExamplesInfo(
|
||||
"cross-encoder/ms-marco-MiniLM-L-6-v2"
|
||||
),
|
||||
"ErnieForSequenceClassification": _HfExamplesInfo(
|
||||
"Forrest20231206/ernie-3.0-base-zh-cls",
|
||||
),
|
||||
"GPT2ForSequenceClassification": _HfExamplesInfo(
|
||||
"nie3e/sentiment-polish-gpt2-small"
|
||||
),
|
||||
|
||||
247
vllm/model_executor/models/ernie.py
Normal file
247
vllm/model_executor/models/ernie.py
Normal file
@@ -0,0 +1,247 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import BertConfig
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.pooler import DispatchPooler
|
||||
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_classify
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .bert import (
|
||||
TOKEN_TYPE_SHIFT,
|
||||
BertEmbedding,
|
||||
BertEmbeddingModel,
|
||||
BertModel,
|
||||
BertPoolingModel,
|
||||
_decode_token_type_ids,
|
||||
_encode_token_type_ids,
|
||||
)
|
||||
from .interfaces import SupportsCrossEncoding, SupportsQuant
|
||||
from .interfaces_base import attn_type, default_pooling_type
|
||||
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
|
||||
|
||||
_LEGACY_SUFFIX_MAPPER = WeightsMapper(
|
||||
orig_to_new_suffix={
|
||||
".gamma": ".weight",
|
||||
".beta": ".bias",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class ErnieEmbedding(BertEmbedding):
|
||||
def __init__(self, config: BertConfig):
|
||||
super().__init__(config)
|
||||
|
||||
task_type_vocab_size = max(1, getattr(config, "task_type_vocab_size", 1))
|
||||
self.task_type_embeddings = VocabParallelEmbedding(
|
||||
task_type_vocab_size, config.hidden_size
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
token_type_ids = _decode_token_type_ids(input_ids)
|
||||
task_type_ids = torch.zeros_like(token_type_ids)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.word_embeddings(input_ids)
|
||||
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||
task_type_embeddings = self.task_type_embeddings(task_type_ids)
|
||||
|
||||
embeddings = (
|
||||
inputs_embeds
|
||||
+ token_type_embeddings
|
||||
+ task_type_embeddings
|
||||
+ position_embeddings
|
||||
)
|
||||
embeddings = self.LayerNorm(embeddings)
|
||||
return embeddings
|
||||
|
||||
|
||||
@default_pooling_type(seq_pooling_type="CLS")
|
||||
class ErnieModel(BertModel):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(
|
||||
vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
embedding_class=ErnieEmbedding,
|
||||
)
|
||||
|
||||
|
||||
class ErniePoolingModel(BertPoolingModel):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(
|
||||
vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
embedding_class=ErnieEmbedding,
|
||||
)
|
||||
|
||||
|
||||
@default_pooling_type(seq_pooling_type="CLS")
|
||||
class ErnieEmbeddingModel(BertEmbeddingModel):
|
||||
def _build_model(self, vllm_config: VllmConfig, prefix: str = "") -> ErnieModel:
|
||||
return ErnieModel(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
weights_list = list(weights)
|
||||
has_model_prefix = any(name.startswith("model.") for name, _ in weights_list)
|
||||
has_ernie_prefix = any(name.startswith("ernie.") for name, _ in weights_list)
|
||||
|
||||
mapper: WeightsMapper | None = None
|
||||
if not has_model_prefix:
|
||||
if has_ernie_prefix:
|
||||
mapper = WeightsMapper(orig_to_new_prefix={"ernie.": "model."})
|
||||
else:
|
||||
mapper = WeightsMapper(orig_to_new_prefix={"": "model."})
|
||||
if mapper is None:
|
||||
mapper = _LEGACY_SUFFIX_MAPPER
|
||||
else:
|
||||
mapper = mapper | _LEGACY_SUFFIX_MAPPER
|
||||
|
||||
loader = AutoWeightsLoader(self, skip_prefixes=["lm_head.", "cls."])
|
||||
return loader.load_weights(weights_list, mapper=mapper)
|
||||
|
||||
|
||||
@default_pooling_type(seq_pooling_type="CLS")
|
||||
class ErnieForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQuant):
|
||||
is_pooling_model = True
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
self.num_labels = config.num_labels
|
||||
self.ernie = ErniePoolingModel(
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "ernie"),
|
||||
)
|
||||
self.classifier = nn.Linear(
|
||||
config.hidden_size,
|
||||
config.num_labels,
|
||||
dtype=vllm_config.model_config.head_dtype,
|
||||
)
|
||||
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = DispatchPooler.for_seq_cls(
|
||||
pooler_config,
|
||||
pooling=self.ernie.pooler,
|
||||
classifier=self.classifier,
|
||||
)
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.ernie.embed_input_ids(input_ids)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
weights_list = list(weights)
|
||||
has_ernie_prefix = any(name.startswith("ernie.") for name, _ in weights_list)
|
||||
has_bert_prefix = any(name.startswith("bert.") for name, _ in weights_list)
|
||||
|
||||
mapper: WeightsMapper | None = None
|
||||
if has_bert_prefix and not has_ernie_prefix:
|
||||
mapper = WeightsMapper(orig_to_new_prefix={"bert.": "ernie."})
|
||||
if mapper is None:
|
||||
mapper = _LEGACY_SUFFIX_MAPPER
|
||||
else:
|
||||
mapper = mapper | _LEGACY_SUFFIX_MAPPER
|
||||
|
||||
loader = AutoWeightsLoader(self, skip_prefixes=["cls.", "lm_head."])
|
||||
return loader.load_weights(weights_list, mapper=mapper)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
token_type_ids: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if token_type_ids is not None:
|
||||
assert self.ernie.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
|
||||
assert input_ids is not None
|
||||
_encode_token_type_ids(input_ids, token_type_ids)
|
||||
|
||||
return self.ernie(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
)
|
||||
|
||||
|
||||
@attn_type("encoder_only")
|
||||
@default_pooling_type(tok_pooling_type="ALL")
|
||||
class ErnieForTokenClassification(nn.Module):
|
||||
is_pooling_model = True
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.head_dtype = vllm_config.model_config.head_dtype
|
||||
self.num_labels = config.num_labels
|
||||
self.ernie = ErnieModel(
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "ernie"),
|
||||
)
|
||||
self.classifier = nn.Linear(
|
||||
config.hidden_size, config.num_labels, dtype=self.head_dtype
|
||||
)
|
||||
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.pooler = pooler_for_token_classify(pooler_config)
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.ernie.embed_input_ids(input_ids)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
weights_list = list(weights)
|
||||
has_ernie_prefix = any(name.startswith("ernie.") for name, _ in weights_list)
|
||||
has_bert_prefix = any(name.startswith("bert.") for name, _ in weights_list)
|
||||
|
||||
mapper: WeightsMapper | None = None
|
||||
if has_bert_prefix and not has_ernie_prefix:
|
||||
mapper = WeightsMapper(orig_to_new_prefix={"bert.": "ernie."})
|
||||
if mapper is None:
|
||||
mapper = _LEGACY_SUFFIX_MAPPER
|
||||
else:
|
||||
mapper = mapper | _LEGACY_SUFFIX_MAPPER
|
||||
|
||||
loader = AutoWeightsLoader(self, skip_prefixes=["cls.", "lm_head."])
|
||||
return loader.load_weights(weights_list, mapper=mapper)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
token_type_ids: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if token_type_ids is not None:
|
||||
assert self.ernie.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
|
||||
assert input_ids is not None
|
||||
_encode_token_type_ids(input_ids, token_type_ids)
|
||||
|
||||
hidden_states = self.ernie(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.to(self.head_dtype)
|
||||
return self.classifier(hidden_states)
|
||||
@@ -214,6 +214,7 @@ _EMBEDDING_MODELS = {
|
||||
# [Text-only]
|
||||
"BertModel": ("bert", "BertEmbeddingModel"),
|
||||
"BertSpladeSparseEmbeddingModel": ("bert", "BertSpladeSparseEmbeddingModel"),
|
||||
"ErnieModel": ("ernie", "ErnieEmbeddingModel"),
|
||||
"BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
|
||||
"DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
|
||||
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
|
||||
@@ -286,6 +287,7 @@ _REWARD_MODELS = {
|
||||
|
||||
_TOKEN_CLASSIFICATION_MODELS = {
|
||||
"BertForTokenClassification": ("bert", "BertForTokenClassification"),
|
||||
"ErnieForTokenClassification": ("ernie", "ErnieForTokenClassification"),
|
||||
"ModernBertForTokenClassification": (
|
||||
"modernbert",
|
||||
"ModernBertForTokenClassification",
|
||||
@@ -295,6 +297,7 @@ _TOKEN_CLASSIFICATION_MODELS = {
|
||||
_SEQUENCE_CLASSIFICATION_MODELS = {
|
||||
"BertForSequenceClassification": ("bert", "BertForSequenceClassification"),
|
||||
"GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
|
||||
"ErnieForSequenceClassification": ("ernie", "ErnieForSequenceClassification"),
|
||||
"GteNewForSequenceClassification": (
|
||||
"bert_with_rope",
|
||||
"GteNewForSequenceClassification",
|
||||
|
||||
Reference in New Issue
Block a user