[Bugfix] Fix Dense module loading for sentence-transformers embedding models (simplified V2) (#23408)
Signed-off-by: FFFfff1FFFfff <yifanli0919@gmail.com>
This commit is contained in:
@@ -7,15 +7,21 @@ from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.models.config import VerifyAndUpdateConfig
|
||||
from vllm.transformers_utils.config import (get_hf_file_bytes,
|
||||
get_hf_file_to_dict)
|
||||
|
||||
from .interfaces_base import VllmModelForPooling, is_pooling_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
|
||||
_T = TypeVar("_T", bound=type[nn.Module])
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_GENERATE_SUFFIXES = [
|
||||
"ForCausalLM",
|
||||
"ForConditionalGeneration",
|
||||
@@ -24,6 +30,96 @@ _GENERATE_SUFFIXES = [
|
||||
]
|
||||
|
||||
|
||||
def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
|
||||
"""Load Sentence-Transformers Dense projection layers."""
|
||||
|
||||
try:
|
||||
modules = get_hf_file_to_dict("modules.json", model_config.model,
|
||||
model_config.revision)
|
||||
if not modules:
|
||||
return None
|
||||
|
||||
if isinstance(modules, dict):
|
||||
modules = modules.get("modules", [])
|
||||
|
||||
dense_modules = [
|
||||
m for m in modules
|
||||
if m.get("type") == "sentence_transformers.models.Dense"
|
||||
]
|
||||
if not dense_modules:
|
||||
return None
|
||||
|
||||
module = dense_modules[0]
|
||||
folder = module.get("path", "")
|
||||
|
||||
config_path = f"{folder}/config.json" if folder else "config.json"
|
||||
layer_config = get_hf_file_to_dict(config_path, model_config.model,
|
||||
model_config.revision)
|
||||
if not layer_config:
|
||||
return None
|
||||
|
||||
linear = nn.Linear(layer_config.get("in_features", 768),
|
||||
layer_config.get("out_features", 768),
|
||||
bias=layer_config.get("bias", True),
|
||||
dtype=torch.float32)
|
||||
|
||||
if _load_dense_weights(linear, folder, model_config):
|
||||
layers = [linear]
|
||||
if act_name := layer_config.get("activation_function"):
|
||||
layers.append(get_act_fn(act_name))
|
||||
return nn.Sequential(*layers).to(dtype=torch.float32)
|
||||
|
||||
except Exception:
|
||||
logger.exception("ST projector loading failed")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _load_dense_weights(linear: nn.Linear, folder: str,
|
||||
model_config: "ModelConfig") -> bool:
|
||||
"""Load weights using vLLM's weight_loader pattern."""
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader)
|
||||
|
||||
for filename in ["model.safetensors", "pytorch_model.bin"]:
|
||||
file_path = f"{folder}/{filename}" if folder else filename
|
||||
|
||||
try:
|
||||
file_bytes = get_hf_file_bytes(file_path, model_config.model,
|
||||
model_config.revision)
|
||||
if not file_bytes:
|
||||
continue
|
||||
|
||||
if filename.endswith(".safetensors"):
|
||||
from safetensors.torch import load as load_safetensors
|
||||
state_dict = load_safetensors(file_bytes)
|
||||
else:
|
||||
import io
|
||||
state_dict = torch.load(io.BytesIO(file_bytes),
|
||||
map_location="cpu",
|
||||
weights_only=True)
|
||||
|
||||
for weight_key in ["weight", "linear.weight", "dense.weight"]:
|
||||
if weight_key in state_dict:
|
||||
weight_loader = getattr(linear.weight, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(linear.weight,
|
||||
state_dict[weight_key].to(torch.float32))
|
||||
|
||||
bias_key = weight_key.replace("weight", "bias")
|
||||
if linear.bias is not None and bias_key in state_dict:
|
||||
bias_loader = getattr(linear.bias, "weight_loader",
|
||||
default_weight_loader)
|
||||
bias_loader(linear.bias,
|
||||
state_dict[bias_key].to(torch.float32))
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("Failed to load %s", filename)
|
||||
continue
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
|
||||
model_name = orig_model_name
|
||||
|
||||
|
||||
Reference in New Issue
Block a user