diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index e69f68fee..1e6776faa 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -519,6 +519,7 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) A
 | `LlamaModel`C, `LlamaForCausalLM`C, `MistralModel`C, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, etc. | ✅︎ | ✅︎ |
 | `Qwen2Model`C, `Qwen2ForCausalLM`C | Qwen2-based | `ssmits/Qwen2-7B-Instruct-embed-base` (see note), `Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. | ✅︎ | ✅︎ |
 | `Qwen3Model`C, `Qwen3ForCausalLM`C | Qwen3-based | `Qwen/Qwen3-Embedding-0.6B`, etc. | ✅︎ | ✅︎ |
+| `VoyageQwen3BidirectionalEmbedModel`C | Voyage Qwen3-based with bidirectional attention | `voyageai/voyage-4-nano`, etc. | ✅︎ | ✅︎ |
 | `RobertaModel`, `RobertaForMaskedLM` | RoBERTa-based | `sentence-transformers/all-roberta-large-v1`, etc. | | |
 | `*Model`C, `*ForCausalLM`C, etc. | Generative models | N/A | \* | \* |
 
diff --git a/tests/models/language/pooling_mteb_test/test_voyage.py b/tests/models/language/pooling_mteb_test/test_voyage.py
new file mode 100644
index 000000000..99ef1de9a
--- /dev/null
+++ b/tests/models/language/pooling_mteb_test/test_voyage.py
@@ -0,0 +1,56 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+
+from tests.models.language.pooling.embed_utils import correctness_test_embed_models
+from tests.models.utils import EmbedModelInfo
+
+from .mteb_embed_utils import mteb_test_embed_models
+
+MODELS = [
+    EmbedModelInfo(
+        "voyageai/voyage-4-nano",
+        architecture="VoyageQwen3BidirectionalEmbedModel",
+        enable_test=True,
+        seq_pooling_type="MEAN",
+        attn_type="encoder_only",
+        is_prefix_caching_supported=False,
+        is_chunked_prefill_supported=False,
+        hf_overrides={
+            "architectures": ["VoyageQwen3BidirectionalEmbedModel"],
+            "num_labels": 2048,
+        },
+        mteb_score=0.7054,
+        # === MTEB Results ===
+        # STS12: 0.6613
+        # STS13: 0.6906
+        # STS14: 0.6556
+        # STS15: 0.7843
+        # STS16: 0.7340
+        # STSBenchmark: 0.7063
+        # Average score: 0.7054
+    ),
+]
+
+
+@pytest.mark.parametrize("model_info", MODELS)
+def test_embed_models_mteb(hf_runner, vllm_runner, model_info: EmbedModelInfo) -> None:
+    # Encoder-only attention models need enforce_eager=True to avoid
+    # CUDA graph capture issues with piecewise compilation
+    mteb_test_embed_models(
+        hf_runner, vllm_runner, model_info, vllm_extra_kwargs={"enforce_eager": True}
+    )
+
+
+@pytest.mark.parametrize("model_info", MODELS)
+def test_embed_models_correctness(
+    hf_runner, vllm_runner, model_info: EmbedModelInfo, example_prompts
+) -> None:
+    correctness_test_embed_models(
+        hf_runner,
+        vllm_runner,
+        model_info,
+        example_prompts,
+        vllm_extra_kwargs={"enforce_eager": True},
+    )
diff --git a/tests/models/registry.py b/tests/models/registry.py
index ffa4f52f1..69da8c7af 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -565,6 +565,9 @@ _EMBEDDING_EXAMPLE_MODELS = {
     ),
     "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"),
     "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"),
+    "VoyageQwen3BidirectionalEmbedModel": _HfExamplesInfo(
+        "voyageai/voyage-4-nano", trust_remote_code=True
+    ),
     "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small"),
     "BertSpladeSparseEmbeddingModel": _HfExamplesInfo(
         "naver/splade-v3",
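For readers who want to try the new architecture end to end, here is a minimal usage sketch against vLLM's offline LLM.embed API referenced in the docs hunk above. It is illustrative only, not part of this PR: the runner="pooling" argument assumes a recent vLLM release, and the other flags mirror trust_remote_code=True from the registry entry and enforce_eager=True from the tests.

# Hedged usage sketch (assumptions noted inline), not part of this PR.
from vllm import LLM

llm = LLM(
    model="voyageai/voyage-4-nano",
    runner="pooling",        # assumed: recent vLLM pooling-runner selector
    trust_remote_code=True,  # as in the _HfExamplesInfo registry entry
    enforce_eager=True,      # encoder-only attention: skip CUDA graph capture
)

outputs = llm.embed(["The quick brown fox jumps over the lazy dog."])
print(len(outputs[0].outputs.embedding))  # expected: num_labels, i.e. 2048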
diff --git a/vllm/config/model.py b/vllm/config/model.py
index 86b484181..7c0b33443 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -1513,6 +1513,10 @@ class ModelConfig:
 
     @property
    def embedding_size(self):
+        # Honor an embedding_size override set on the HF config (e.g., Voyage models)
+        override = getattr(self.hf_config, "embedding_size", None)
+        if override is not None:
+            return override
         dense_modules = try_get_dense_modules(self.model, revision=self.revision)
         if dense_modules is not None:
             return dense_modules[-1]["out_features"]
diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py
index c41f5e18b..a6c244b6e 100644
--- a/vllm/model_executor/models/config.py
+++ b/vllm/model_executor/models/config.py
@@ -582,6 +582,13 @@ class NemotronHForCausalLMConfig(VerifyAndUpdateConfig):
         cache_config.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype
 
 
+class VoyageQwen3BidirectionalEmbedModelConfig(VerifyAndUpdateConfig):
+    @staticmethod
+    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
+        model_config.hf_config.is_causal = False
+        model_config.hf_config.embedding_size = model_config.hf_config.num_labels
+
+
 MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
     "GteModel": SnowflakeGteNewModelConfig,
     "GteNewModel": GteNewModelConfig,
@@ -604,4 +611,5 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
     "DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
     "NemotronHForCausalLM": NemotronHForCausalLMConfig,
     "NemotronHPuzzleForCausalLM": NemotronHForCausalLMConfig,
+    "VoyageQwen3BidirectionalEmbedModel": VoyageQwen3BidirectionalEmbedModelConfig,
 }
diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py
index 06df05144..43f330eb0 100644
--- a/vllm/model_executor/models/qwen3.py
+++ b/vllm/model_executor/models/qwen3.py
@@ -34,7 +34,10 @@ from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.model_executor.layers.attention import Attention
+from vllm.model_executor.layers.attention.encoder_only_attention import (
+    EncoderOnlyAttention,
+)
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -115,7 +118,12 @@ class Qwen3Attention(nn.Module):
             rope_parameters=rope_parameters,
             dual_chunk_attention_config=dual_chunk_attention_config,
         )
-        self.attn = Attention(
+        attn_cls = (
+            EncoderOnlyAttention
+            if attn_type == AttentionType.ENCODER_ONLY
+            else Attention
+        )
+        self.attn = attn_cls(
             self.num_heads,
             self.head_dim,
             self.scaling,
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 830a615ce..c310f6f17 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -237,6 +237,10 @@ _EMBEDDING_MODELS = {
     "RobertaModel": ("roberta", "RobertaEmbeddingModel"),
     "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
     "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
+    "VoyageQwen3BidirectionalEmbedModel": (
+        "voyage",
+        "VoyageQwen3BidirectionalEmbedModel",
+    ),
     "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"),
     "BgeM3EmbeddingModel": ("roberta", "BgeM3EmbeddingModel"),
     # [Multimodal]
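To see why the attn_cls dispatch in qwen3.py matters, the toy snippet below (plain PyTorch, no vLLM internals; all sizes made up) contrasts the causal mask a decoder applies with the unmasked attention an encoder-only model uses: with AttentionType.ENCODER_ONLY, every token may attend to every other token, which is what a mean-pooled embedding model expects.

# Conceptual sketch only: causal vs. bidirectional attention over toy scores.
import torch

T = 4                        # toy sequence length
scores = torch.randn(T, T)   # raw attention scores for a single head

# Decoder-style (causal): mask out positions j > i before the softmax.
causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
causal = torch.softmax(scores.masked_fill(causal_mask, float("-inf")), dim=-1)

# Encoder-only (bidirectional, as with EncoderOnlyAttention): no mask at all.
bidirectional = torch.softmax(scores, dim=-1)

print(causal[0])         # token 0 attends only to itself
print(bidirectional[0])  # token 0 attends to the whole sequence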
/dev/null
+++ b/vllm/model_executor/models/voyage.py
@@ -0,0 +1,130 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from __future__ import annotations
+
+import re
+from collections import defaultdict
+from collections.abc import Iterable
+
+import torch
+import torch.nn as nn
+
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.qwen3 import Qwen3Model
+from vllm.model_executor.models.utils import WeightsMapper
+
+WeightItem = tuple[str, torch.Tensor]
+
+_LAYER_RE = re.compile(r"^layers\.(\d+)\.(.+)$")
+
+
+class VoyageQwen3BidirectionalEmbedModel(Qwen3Model):
+    """
+    Qwen3Model + Voyage embedding head + bidirectional attention.
+
+    Checkpoint conventions (HF):
+    - MLP: gate_proj + up_proj (unfused)
+    - Attn: q_proj + k_proj + v_proj (unfused)
+    - Linear head: linear.weight
+    - Weights prefixed with "model." (e.g., model.layers.0...)
+
+    vLLM Qwen3Model expects:
+    - mlp.gate_up_proj (fused)
+    - self_attn.qkv_proj (fused)
+    - No "model." prefix
+
+    We remap/fuse weights using a generator pipeline and load them
+    directly, bypassing the parent's stacked_params_mapping, which would
+    re-apply the fusion to already-fused names (e.g., qkv_proj -> qkqkv_proj).
+    """
+
+    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # Embedding head (hidden_size -> num_labels, bias=False)
+        self.linear = nn.Linear(
+            self.config.hidden_size,
+            self.config.num_labels,
+            bias=False,
+        )
+
+    def forward(self, *args, **kwargs):
+        out = super().forward(*args, **kwargs)
+        return self.linear(out)
+
+    def _fuse_qkv_proj(self, weights: Iterable[WeightItem]) -> Iterable[WeightItem]:
+        """Fuse q_proj, k_proj, v_proj into qkv_proj."""
+        qkv_buf: dict[int, dict[str, torch.Tensor]] = defaultdict(dict)
+        qkv_suffixes = {
+            "self_attn.q_proj.weight": "q",
+            "self_attn.k_proj.weight": "k",
+            "self_attn.v_proj.weight": "v",
+        }
+
+        for name, tensor in weights:
+            m = _LAYER_RE.match(name)
+            if m and m.group(2) in qkv_suffixes:
+                layer_idx = int(m.group(1))
+                qkv_buf[layer_idx][qkv_suffixes[m.group(2)]] = tensor
+            else:
+                yield name, tensor
+
+        # Yield fused QKV weights
+        for layer_idx in sorted(qkv_buf.keys()):
+            parts = qkv_buf[layer_idx]
+            if all(p in parts for p in ("q", "k", "v")):
+                fused = torch.cat([parts["q"], parts["k"], parts["v"]], dim=0)
+                yield f"layers.{layer_idx}.self_attn.qkv_proj.weight", fused
+            elif parts:
+                missing = [p for p in ("q", "k", "v") if p not in parts]
+                raise ValueError(f"Layer {layer_idx} missing QKV parts: {missing}")
+
+    def _fuse_gate_up_proj(self, weights: Iterable[WeightItem]) -> Iterable[WeightItem]:
+        """Fuse gate_proj and up_proj into gate_up_proj."""
+        mlp_buf: dict[int, dict[str, torch.Tensor]] = defaultdict(dict)
+        mlp_suffixes = {
+            "mlp.gate_proj.weight": "gate",
+            "mlp.up_proj.weight": "up",
+        }
+
+        for name, tensor in weights:
+            m = _LAYER_RE.match(name)
+            if m and m.group(2) in mlp_suffixes:
+                layer_idx = int(m.group(1))
+                mlp_buf[layer_idx][mlp_suffixes[m.group(2)]] = tensor
+            else:
+                yield name, tensor
+
+        # Yield fused gate_up weights
+        for layer_idx in sorted(mlp_buf.keys()):
+            parts = mlp_buf[layer_idx]
+            if all(p in parts for p in ("gate", "up")):
+                fused = torch.cat([parts["gate"], parts["up"]], dim=0)
+                yield f"layers.{layer_idx}.mlp.gate_up_proj.weight", fused
+            elif parts:
+                missing = [p for p in ("gate", "up") if p not in parts]
+                raise ValueError(f"Layer {layer_idx} missing MLP parts: {missing}")
+
+    def load_weights(self, weights: Iterable[WeightItem]) -> set[str]:
+        """Remap, fuse, and load weights using a generator pipeline."""
+        # Chain weight transformations
+        weights = self.hf_to_vllm_mapper.apply(weights)
+        weights = self._fuse_qkv_proj(weights)
+        weights = self._fuse_gate_up_proj(weights)
+
+        # Load weights directly into model parameters
+        # (bypass parent's stacked_params_mapping)
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+
+        for name, loaded_weight in weights:
+            if name not in params_dict:
+                continue
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+
+        return loaded_params
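As a standalone sanity check of the weight-remapping contract in voyage.py, the snippet below uses only PyTorch with toy shapes (all names and sizes are hypothetical) to mimic what hf_to_vllm_mapper plus _fuse_qkv_proj should produce for one layer: a single qkv_proj tensor concatenated along the output dimension.

# Standalone sketch, not vLLM code: toy reproduction of the rename + fuse steps.
import torch

hidden, q_out, kv_out = 8, 8, 4  # toy sizes; real models derive these from config

hf_weights = {
    "model.layers.0.self_attn.q_proj.weight": torch.randn(q_out, hidden),
    "model.layers.0.self_attn.k_proj.weight": torch.randn(kv_out, hidden),
    "model.layers.0.self_attn.v_proj.weight": torch.randn(kv_out, hidden),
}

# Step 1: strip the "model." prefix (what hf_to_vllm_mapper does).
renamed = {name.removeprefix("model."): w for name, w in hf_weights.items()}

# Step 2: concatenate q/k/v along the output dim (what _fuse_qkv_proj does).
fused = torch.cat(
    [renamed[f"layers.0.self_attn.{p}_proj.weight"] for p in ("q", "k", "v")],
    dim=0,
)
print(fused.shape)  # torch.Size([16, 8]): q_out + 2 * kv_out rows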