[Bugfix] Fix Dense module loading for sentence-transformers embedding models (simplified V2) (#23408)

Signed-off-by: FFFfff1FFFfff <yifanli0919@gmail.com>
This commit is contained in:
LIYIFAN_liyifan
2025-08-24 22:39:24 -07:00
committed by GitHub
parent 787cdb3829
commit c9abb10489
5 changed files with 175 additions and 2 deletions

View File

@@ -7,15 +7,21 @@ from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
import torch
import torch.nn as nn
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.models.config import VerifyAndUpdateConfig
from vllm.transformers_utils.config import (get_hf_file_bytes,
get_hf_file_to_dict)
from .interfaces_base import VllmModelForPooling, is_pooling_model
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config import ModelConfig, VllmConfig
_T = TypeVar("_T", bound=type[nn.Module])
logger = init_logger(__name__)
_GENERATE_SUFFIXES = [
"ForCausalLM",
"ForConditionalGeneration",
@@ -24,6 +30,96 @@ _GENERATE_SUFFIXES = [
]
def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
    """Build the Sentence-Transformers Dense projection head, if one exists.

    Reads ``modules.json`` for ``model_config.model`` at
    ``model_config.revision``, picks the first entry whose ``type`` is
    ``sentence_transformers.models.Dense``, reads that module's
    ``config.json`` and constructs a float32 ``nn.Linear`` from it. The
    linear layer is then filled via ``_load_dense_weights``; when the layer
    config names an ``activation_function`` the matching activation is
    appended after the linear layer.

    Args:
        model_config: Supplies the ``model`` and ``revision`` used for every
            file lookup.

    Returns:
        An ``nn.Sequential`` in float32 holding the linear layer (plus the
        optional activation), or ``None`` when there is no Dense module, a
        required file is missing/empty, the weights could not be loaded, or
        any exception was raised along the way (the exception is logged).
    """
    try:
        module_list = get_hf_file_to_dict("modules.json", model_config.model,
                                          model_config.revision)
        if not module_list:
            return None

        # Some layouts wrap the module list in a {"modules": [...]} object.
        if isinstance(module_list, dict):
            module_list = module_list.get("modules", [])

        dense_entries = [
            entry for entry in module_list
            if entry.get("type") == "sentence_transformers.models.Dense"
        ]
        if not dense_entries:
            return None

        # Only the first Dense entry is used.
        folder = dense_entries[0].get("path", "")

        if folder:
            config_path = f"{folder}/config.json"
        else:
            config_path = "config.json"

        layer_config = get_hf_file_to_dict(config_path, model_config.model,
                                           model_config.revision)
        if not layer_config:
            return None

        in_features = layer_config.get("in_features", 768)
        out_features = layer_config.get("out_features", 768)
        use_bias = layer_config.get("bias", True)
        linear = nn.Linear(in_features,
                           out_features,
                           bias=use_bias,
                           dtype=torch.float32)

        # Without loaded weights the projector is not returned at all.
        if not _load_dense_weights(linear, folder, model_config):
            return None

        stages = [linear]
        act_name = layer_config.get("activation_function")
        if act_name:
            stages.append(get_act_fn(act_name))
        return nn.Sequential(*stages).to(dtype=torch.float32)
    except Exception:
        # Best effort: log the failure and report "no projector".
        logger.exception("ST projector loading failed")
        return None
def _load_dense_weights(linear: nn.Linear, folder: str,
                        model_config: "ModelConfig") -> bool:
    """Fill ``linear`` from a Dense-module checkpoint file.

    Tries ``model.safetensors`` first and ``pytorch_model.bin`` second, both
    resolved under ``folder`` (or at the top level when ``folder`` is empty).
    In the first file that yields bytes, the weight is looked up under the
    keys ``weight``, ``linear.weight`` and ``dense.weight`` in that order;
    the bias key is derived from the matching weight key. Tensors are cast
    to float32 and copied with the parameter's own ``weight_loader``
    attribute when it has one, otherwise with ``default_weight_loader``.

    Args:
        linear: Layer whose ``weight`` (and ``bias``, when it has one and the
            checkpoint provides it) is populated in place.
        folder: Sub-directory of the Dense module; may be empty.
        model_config: Supplies the ``model`` and ``revision`` used for the
            file lookups.

    Returns:
        ``True`` as soon as a weight tensor has been loaded, ``False`` when
        neither candidate file produced one. Exceptions raised while handling
        a file are logged and the next candidate file is tried.
    """
    from vllm.model_executor.model_loader.weight_utils import (
        default_weight_loader)

    def _copy_into(param, tensor) -> None:
        # Prefer a loader attached to the parameter over the default one.
        loader = getattr(param, "weight_loader", default_weight_loader)
        loader(param, tensor.to(torch.float32))

    prefix = f"{folder}/" if folder else ""

    for filename in ("model.safetensors", "pytorch_model.bin"):
        file_path = prefix + filename

        try:
            payload = get_hf_file_bytes(file_path, model_config.model,
                                        model_config.revision)
            if not payload:
                continue

            if not filename.endswith(".safetensors"):
                import io
                tensors = torch.load(io.BytesIO(payload),
                                     map_location="cpu",
                                     weights_only=True)
            else:
                from safetensors.torch import load as load_safetensors
                tensors = load_safetensors(payload)

            # First key naming convention present in the checkpoint wins.
            found_key = next(
                (key for key in ("weight", "linear.weight", "dense.weight")
                 if key in tensors),
                None,
            )
            if found_key is None:
                continue

            _copy_into(linear.weight, tensors[found_key])

            bias_key = found_key.replace("weight", "bias")
            if linear.bias is not None and bias_key in tensors:
                _copy_into(linear.bias, tensors[bias_key])
            return True
        except Exception:
            # Log and fall through to the next candidate file.
            logger.exception("Failed to load %s", filename)

    return False
def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
model_name = orig_model_name