# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable

import torch
from torch import nn
from transformers import BertConfig

from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import DispatchPooler
from vllm.model_executor.layers.pooler.tokwise import pooler_for_token_classify
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.sequence import IntermediateTensors

from .bert import (
    TOKEN_TYPE_SHIFT,
    BertEmbedding,
    BertEmbeddingModel,
    BertModel,
    BertPoolingModel,
    _decode_token_type_ids,
    _encode_token_type_ids,
)
from .interfaces import SupportsCrossEncoding, SupportsQuant
from .interfaces_base import attn_type, default_pooling_type
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix

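# Some legacy checkpoints use the TF-era LayerNorm parameter names
# ".gamma"/".beta"; map them to PyTorch's ".weight"/".bias".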
_LEGACY_SUFFIX_MAPPER = WeightsMapper(
    orig_to_new_suffix={
        ".gamma": ".weight",
        ".beta": ".bias",
    }
)


class ErnieEmbedding(BertEmbedding):
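    """BertEmbedding extended with ERNIE's task-type embedding table."""
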
    def __init__(self, config: BertConfig):
        super().__init__(config)

        task_type_vocab_size = max(1, getattr(config, "task_type_vocab_size", 1))
        self.task_type_embeddings = VocabParallelEmbedding(
            task_type_vocab_size, config.hidden_size
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # Token-type ids arrive packed into the high bits of input_ids
        # (see _encode_token_type_ids in the forward() of the top-level
        # models below); unpack them before the embedding lookups.
        token_type_ids = _decode_token_type_ids(input_ids)
        # Task-type ids are fixed to 0 at inference time.
        task_type_ids = torch.zeros_like(token_type_ids)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        task_type_embeddings = self.task_type_embeddings(task_type_ids)

        # ERNIE adds a task-type embedding on top of the usual word,
        # token-type, and position embeddings.
        embeddings = (
            inputs_embeds
            + token_type_embeddings
            + task_type_embeddings
            + position_embeddings
        )
        embeddings = self.LayerNorm(embeddings)
        return embeddings


@default_pooling_type(seq_pooling_type="CLS")
class ErnieModel(BertModel):
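    """BertModel that uses ErnieEmbedding as its embedding layer."""
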
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            embedding_class=ErnieEmbedding,
        )


class ErniePoolingModel(BertPoolingModel):
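    """BertPoolingModel that uses ErnieEmbedding as its embedding layer."""
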
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            embedding_class=ErnieEmbedding,
        )


@default_pooling_type(seq_pooling_type="CLS")
class ErnieEmbeddingModel(BertEmbeddingModel):
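    """ERNIE text-embedding model: an ErnieModel backbone with CLS pooling."""
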
    def _build_model(self, vllm_config: VllmConfig, prefix: str = "") -> ErnieModel:
        return ErnieModel(vllm_config=vllm_config, prefix=prefix)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        weights_list = list(weights)
        has_model_prefix = any(name.startswith("model.") for name, _ in weights_list)
        has_ernie_prefix = any(name.startswith("ernie.") for name, _ in weights_list)

        # Checkpoints may name the backbone "model.", "ernie.", or use no
        # prefix at all; normalize everything to the "model." prefix.
        mapper: WeightsMapper | None = None
        if not has_model_prefix:
            if has_ernie_prefix:
                mapper = WeightsMapper(orig_to_new_prefix={"ernie.": "model."})
            else:
                mapper = WeightsMapper(orig_to_new_prefix={"": "model."})
        if mapper is None:
            mapper = _LEGACY_SUFFIX_MAPPER
        else:
            mapper = mapper | _LEGACY_SUFFIX_MAPPER

        loader = AutoWeightsLoader(self, skip_prefixes=["lm_head.", "cls."])
        return loader.load_weights(weights_list, mapper=mapper)


@default_pooling_type(seq_pooling_type="CLS")
class ErnieForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQuant):
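    """ERNIE cross-encoder for sequence classification: a pooled ERNIE
    backbone followed by a linear classification head.
    """
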
    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config

        self.num_labels = config.num_labels
        self.ernie = ErniePoolingModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "ernie"),
        )
        self.classifier = nn.Linear(
            config.hidden_size,
            config.num_labels,
            dtype=vllm_config.model_config.head_dtype,
        )

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.pooler = DispatchPooler.for_seq_cls(
            pooler_config,
            pooling=self.ernie.pooler,
            classifier=self.classifier,
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.ernie.embed_input_ids(input_ids)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        weights_list = list(weights)
        has_ernie_prefix = any(name.startswith("ernie.") for name, _ in weights_list)
        has_bert_prefix = any(name.startswith("bert.") for name, _ in weights_list)

        # Some checkpoints name the backbone "bert." instead of "ernie.";
        # remap it before applying the legacy suffix fixes.
        mapper: WeightsMapper | None = None
        if has_bert_prefix and not has_ernie_prefix:
            mapper = WeightsMapper(orig_to_new_prefix={"bert.": "ernie."})
        if mapper is None:
            mapper = _LEGACY_SUFFIX_MAPPER
        else:
            mapper = mapper | _LEGACY_SUFFIX_MAPPER

        loader = AutoWeightsLoader(self, skip_prefixes=["cls.", "lm_head."])
        return loader.load_weights(weights_list, mapper=mapper)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if token_type_ids is not None:
            # Pack the token-type ids into the unused high bits of input_ids;
            # ErnieEmbedding.forward unpacks them again.
            assert self.ernie.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
            assert input_ids is not None
            _encode_token_type_ids(input_ids, token_type_ids)

        return self.ernie(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )


@attn_type("encoder_only")
@default_pooling_type(tok_pooling_type="ALL")
class ErnieForTokenClassification(nn.Module):
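    """ERNIE token-classification model: per-token logits from a linear
    head over the ErnieModel hidden states.
    """
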
    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.head_dtype = vllm_config.model_config.head_dtype
        self.num_labels = config.num_labels
        self.ernie = ErnieModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "ernie"),
        )
        self.classifier = nn.Linear(
            config.hidden_size, config.num_labels, dtype=self.head_dtype
        )

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.pooler = pooler_for_token_classify(pooler_config)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.ernie.embed_input_ids(input_ids)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        weights_list = list(weights)
        has_ernie_prefix = any(name.startswith("ernie.") for name, _ in weights_list)
        has_bert_prefix = any(name.startswith("bert.") for name, _ in weights_list)

        # Same prefix normalization as ErnieForSequenceClassification:
        # remap a "bert." backbone prefix to "ernie." when present.
        mapper: WeightsMapper | None = None
        if has_bert_prefix and not has_ernie_prefix:
            mapper = WeightsMapper(orig_to_new_prefix={"bert.": "ernie."})
        if mapper is None:
            mapper = _LEGACY_SUFFIX_MAPPER
        else:
            mapper = mapper | _LEGACY_SUFFIX_MAPPER

        loader = AutoWeightsLoader(self, skip_prefixes=["cls.", "lm_head."])
        return loader.load_weights(weights_list, mapper=mapper)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if token_type_ids is not None:
            # Pack token-type ids into the high bits of input_ids (undone by
            # _decode_token_type_ids inside ErnieEmbedding.forward).
            assert self.ernie.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
            assert input_ids is not None
            _encode_token_type_ids(input_ids, token_type_ids)

        hidden_states = self.ernie(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            intermediate_tensors=intermediate_tensors,
        )

        # Match the classifier head's dtype before computing per-token logits.
        hidden_states = hidden_states.to(self.head_dtype)
        return self.classifier(hidden_states)