Files
vllm/vllm/model_executor/models/modernbert.py
wang.yuqi 6d729c43fb [Bugfix] Fix ModernBert load & Enable sliding window attention for bidirectional attention. (#22637)
Signed-off-by: wang.yuqi <noooop@126.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: Max de Bayser <mbayser@br.ibm.com>
2025-08-12 00:23:17 -07:00

372 lines
14 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Set
from typing import Optional, Union
import torch
from torch import nn
from transformers import ModernBertConfig
from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.pooler import (ClassifierPooler,
DispatchPooler, Pooler,
PoolingMethod,
PoolingParamsUpdate,
PoolingType)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors
from vllm.tasks import PoolingTask
from .interfaces import SupportsCrossEncoding, default_pooling_type
from .utils import WeightsMapper, maybe_prefix
class ModernBertEmbeddings(nn.Module):
    """Token embedding lookup followed by a LayerNorm.

    The LayerNorm epsilon is read from ``config.norm_eps``, the same field the
    other ModernBERT norms in this file use (and the field declared by
    ``transformers.ModernBertConfig``).
    """

    def __init__(self, config: ModernBertConfig):
        super().__init__()
        self.config = config
        self.tok_embeddings = VocabParallelEmbedding(config.vocab_size,
                                                     config.hidden_size)
        # `norm_eps` is the declared config field; `layer_norm_eps` only
        # exists when a checkpoint's config.json happens to carry it.
        self.norm = nn.LayerNorm(config.hidden_size,
                                 eps=config.norm_eps,
                                 bias=config.norm_bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return normalized embeddings.

        Args:
            input_ids: token ids, looked up only when ``inputs_embeds`` is None.
            inputs_embeds: optional precomputed token embeddings; when given
                they are normalized directly and ``input_ids`` is ignored.
        """
        if inputs_embeds is None:
            inputs_embeds = self.tok_embeddings(input_ids)
        return self.norm(inputs_embeds)
class ModernBertRotaryEmbedding(RotaryEmbedding):
    """Neox-style rotary embedding parameterized from a ModernBERT config.

    Args:
        config: model config; supplies ``max_position_embeddings``.
        head_size: per-head dimension.
        dim: number of rotary dimensions.
        base: rope theta for this layer (local or global).
    """

    def __init__(self, config: ModernBertConfig, head_size: int, dim: int,
                 base: float):
        rope_kwargs = dict(
            head_size=head_size,
            rotary_dim=dim,
            max_position_embeddings=config.max_position_embeddings,
            base=base,
            is_neox_style=True,
            # NOTE(review): dtype is fixed to float16 regardless of the
            # model's dtype — confirm this is intended.
            dtype=torch.float16,
        )
        super().__init__(**rope_kwargs)
        self.config = config
class ModernBertAttention(nn.Module):
    """Encoder-only self-attention for one ModernBERT layer.

    Layers whose ``layer_id`` is not a multiple of
    ``config.global_attn_every_n_layers`` use a sliding window of
    ``config.local_attention // 2`` and ``config.local_rope_theta`` (falling
    back to ``config.global_rope_theta`` when that is None).  The remaining
    layers use full attention with ``config.global_rope_theta``.

    With tensor parallelism each rank owns
    ``config.num_attention_heads // tp_size`` heads.

    Args:
        config: ModernBERT model config.
        layer_id: index of the layer. Must be an int in practice: it is used
            with ``%`` below and in the attention prefix.
    """

    def __init__(self,
                 config: ModernBertConfig,
                 layer_id: Optional[int] = None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.layer_id = layer_id
        self.deterministic_flash_attn = config.deterministic_flash_attn
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        # QKVParallelLinear shards its output across TP ranks, so this rank
        # only sees its own share of heads; everything downstream of the
        # projection must use the local head count.
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = config.hidden_size // self.total_num_heads
        # Width of each local q / k / v slice of the fused projection.
        self.all_head_size = self.head_dim * self.num_heads
        self.scaling = self.head_dim**-0.5
        self.Wqkv = QKVParallelLinear(
            config.hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=config.attention_bias,
        )

        sliding_window = None
        if layer_id % config.global_attn_every_n_layers != 0:
            # Local-attention layer.
            sliding_window = config.local_attention // 2
            rope_theta = (config.local_rope_theta
                          if config.local_rope_theta is not None else
                          config.global_rope_theta)
        else:
            rope_theta = config.global_rope_theta

        self.rotary_emb = ModernBertRotaryEmbedding(config=config,
                                                    head_size=self.head_dim,
                                                    dim=self.head_dim,
                                                    base=rope_theta)
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              prefix=f"{layer_id}.attn",
                              attn_type=AttentionType.ENCODER_ONLY,
                              per_layer_sliding_window=sliding_window)
        # Input to Wo is the rank-local attention output (input_is_parallel
        # default), matching the per-rank head split above.
        self.Wo = RowParallelLinear(config.hidden_size,
                                    config.hidden_size,
                                    bias=config.attention_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Project to q/k/v, apply rotary embedding and attention, project out."""
        qkv, _ = self.Wqkv(hidden_states)
        q, k, v = qkv.split([self.all_head_size] * 3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.Wo(attn_output)
        return output
class ModernBertMLP(nn.Module):
    """Gated MLP: ``Wo(GELU(x_in) * gate)``, with ``Wi`` producing both halves.

    ``Wi`` is a plain ``nn.Linear``, so every tensor-parallel rank computes the
    full ``2 * intermediate_size`` activation.  ``Wo`` is therefore built with
    ``input_is_parallel=False`` so that it slices its own input shard.  The
    default of True would expect an already-split input and breaks when
    tp_size > 1.
    """

    def __init__(self, config: ModernBertConfig):
        super().__init__()
        self.config = config
        self.Wi = nn.Linear(config.hidden_size,
                            int(config.intermediate_size) * 2,
                            bias=config.mlp_bias)
        self.act = nn.GELU()
        self.Wo = RowParallelLinear(config.intermediate_size,
                                    config.hidden_size,
                                    bias=config.mlp_bias,
                                    input_is_parallel=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP and return the output projection."""
        activation_in, gate = self.Wi(hidden_states).chunk(2, dim=-1)
        output, _ = self.Wo(self.act(activation_in) * gate)
        return output
class ModernBertLayer(nn.Module):
    """One pre-norm ModernBERT block: residual attention, then residual MLP.

    Layer 0 uses an identity instead of a LayerNorm in front of attention.

    Args:
        config: ModernBERT model config.
        prefix: accepted for interface compatibility; not used here.
        layer_id: index of this layer, forwarded to the attention module.
    """

    def __init__(self,
                 config: ModernBertConfig,
                 prefix: str = "",
                 layer_id: Optional[int] = None):
        super().__init__()
        self.config = config
        is_first_layer = layer_id == 0
        self.attn_norm = (nn.Identity() if is_first_layer else nn.LayerNorm(
            config.hidden_size, eps=config.norm_eps, bias=config.norm_bias))
        self.attn = ModernBertAttention(config=config, layer_id=layer_id)
        self.mlp_norm = nn.LayerNorm(config.hidden_size,
                                     eps=config.norm_eps,
                                     bias=config.norm_bias)
        self.mlp = ModernBertMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention and MLP sub-blocks, each with a residual connection."""
        residual = hidden_states
        hidden_states = residual + self.attn(
            hidden_states=self.attn_norm(residual), position_ids=position_ids)
        return hidden_states + self.mlp(self.mlp_norm(hidden_states))
class ModernBertEncoderLayer(nn.Module):
    """Stack of ``config.num_hidden_layers`` ModernBERT layers run in order.

    Args:
        vllm_config: engine config; the HF model config is read from
            ``vllm_config.model_config.hf_config``.
        prefix: accepted for interface compatibility; not used here.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.layers = nn.ModuleList([
            ModernBertLayer(config=config, layer_id=layer_id)
            for layer_id in range(config.num_hidden_layers)
        ])

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Feed ``hidden_states`` through every layer sequentially."""
        for layer in self.layers:
            hidden_states = layer(hidden_states, position_ids)
        return hidden_states
@support_torch_compile
@default_pooling_type("CLS")
class ModernBertModel(nn.Module):
    """ModernBERT backbone: embeddings, layer stack, and a final LayerNorm.

    Checkpoint names beginning with ``layers.`` are remapped to
    ``encoder_layer.layers.`` to match this module's attribute layout.
    """

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"layers.": "encoder_layer.layers."})

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.config = config
        self.embeddings = ModernBertEmbeddings(config)
        self.encoder_layer = ModernBertEncoderLayer(vllm_config)
        self.final_norm = nn.LayerNorm(config.hidden_size,
                                       eps=config.norm_eps,
                                       bias=config.norm_bias)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Returns:
            The set of (remapped) parameter names that were loaded.

        Raises:
            KeyError: for a non-bias name with no matching parameter.
        """
        weights = self.hf_to_vllm_mapper.apply(weights)
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # Bias tensors with no matching parameter are skipped.
            # NOTE(review): presumably for layers built with bias=False — confirm.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode tokens (or precomputed embeddings) into final hidden states.

        Args:
            input_ids: token ids; ignored when ``inputs_embeds`` is given.
            positions: position ids passed to every layer's rotary embedding.
            intermediate_tensors: accepted for interface compatibility; unused.
            inputs_embeds: optional precomputed token embeddings.
        """
        # Both input kinds go through the embedding module, so precomputed
        # embeddings also receive the embedding LayerNorm (as in the Hugging
        # Face ModernBERT implementation) instead of skipping it.
        hidden_states = self.embeddings(input_ids=input_ids,
                                        inputs_embeds=inputs_embeds)
        outputs = self.encoder_layer(
            hidden_states=hidden_states,
            position_ids=positions,
        )
        return self.final_norm(outputs)
class ModernBertPooler(Pooler):
    """Pool hidden states, then apply a dense -> GELU -> LayerNorm head.

    The pooling method is chosen from ``config.classifier_pooling``
    (upper-cased and looked up in ``PoolingType``).
    """

    def __init__(self, config: ModernBertConfig):
        super().__init__()
        chosen_type = PoolingType[config.classifier_pooling.upper()]
        self.pooling = PoolingMethod.from_pooling_type(chosen_type)
        self.dense = nn.Linear(config.hidden_size,
                               config.hidden_size,
                               bias=config.classifier_bias)
        self.act = nn.GELU()
        self.norm = nn.LayerNorm(config.hidden_size,
                                 eps=config.norm_eps,
                                 bias=config.norm_bias)

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Delegate to the underlying pooling method."""
        return self.pooling.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Delegate to the underlying pooling method."""
        return self.pooling.get_pooling_updates(task)

    def _head(self, pooled_output: torch.Tensor):
        # Cast to the head's parameter dtype before the dense projection.
        projected = self.dense(pooled_output.to(self.dense.weight.dtype))
        return self.norm(self.act(projected))

    def forward(
        self,
        hidden_states: Union[torch.Tensor, list[torch.Tensor]],
        pooling_metadata: PoolingMetadata,
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Pool, then run the head on the tensor or on each list element."""
        pooled = self.pooling(hidden_states, pooling_metadata)
        if not isinstance(pooled, list):
            return self._head(pooled)
        return [self._head(item) for item in pooled]
@default_pooling_type("CLS")
class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
    """ModernBERT backbone plus pooler head and linear classifier.

    Exposes ``encode`` (plain pooling), ``classify`` and ``score`` poolers
    through a ``DispatchPooler``.  The latter two share ``self.pooling`` and
    ``self.classifier`` and differ only in their activation function.
    """

    is_pooling_model = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        model_config = vllm_config.model_config
        config = model_config.hf_config
        self.config = config
        self.model = ModernBertModel(vllm_config=vllm_config,
                                     prefix=maybe_prefix(prefix, "modernbert"))
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.pooling = ModernBertPooler(config)

        pooler_config = model_config.pooler_config
        assert pooler_config is not None

        def make_classifier_pooler(act_fn):
            # Both classification-style poolers share pooling + classifier.
            return ClassifierPooler(pooling=self.pooling,
                                    classifier=self.classifier,
                                    act_fn=act_fn)

        self.pooler = DispatchPooler({
            "encode":
            Pooler.for_encode(pooler_config),
            "classify":
            make_classifier_pooler(
                ClassifierPooler.act_fn_for_seq_cls(model_config)),
            "score":
            make_classifier_pooler(
                ClassifierPooler.act_fn_for_cross_encoder(model_config)),
        })

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load backbone weights (``model.*``) and the classifier/head weights.

        ``model.``-prefixed tensors are streamed, prefix stripped, into
        ``self.model.load_weights``.  Everything else is collected while that
        stream is consumed.  Afterwards ``classifier*`` names load as-is,
        ``head.<rest>`` loads into ``pooling.<rest>``, and other names are
        ignored.
        """
        backbone_prefix = "model."
        leftover: list[tuple[str, torch.Tensor]] = []

        def backbone_stream():
            for weight_name, tensor in weights:
                if not weight_name.startswith(backbone_prefix):
                    leftover.append((weight_name, tensor))
                    continue
                yield weight_name[len(backbone_prefix):], tensor

        # `leftover` is filled as a side effect of consuming the stream here.
        self.model.load_weights(backbone_stream())

        params = dict(self.named_parameters())
        for weight_name, tensor in leftover:
            if weight_name.startswith("classifier"):
                target_name = weight_name
            elif weight_name.startswith("head"):
                target_name = "pooling." + weight_name[len("head") + 1:]
            else:
                continue
            param = params[target_name]
            loader = getattr(param, "weight_loader", default_weight_loader)
            loader(param, tensor)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return backbone hidden states; pooling is done by ``self.pooler``."""
        return self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            positions=positions,
        )