[New Model]: google/embeddinggemma-300m (#24318)

Signed-off-by: wang.yuqi <noooop@126.com>
wang.yuqi authored on 2025-09-06 13:58:36 +08:00; committed by GitHub
parent 53b19ccdd5
commit 6d6c6b05d3
9 changed files with 73 additions and 29 deletions
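
As context for reviewers, a minimal offline usage sketch for the newly supported checkpoint, assuming the registry changes below are in place; the prompt text and the `task="embed"` invocation follow vLLM's pooling API as commonly documented and are illustrative, not part of this diff:

    # Minimal usage sketch (not part of this commit): embed text with the
    # newly registered google/embeddinggemma-300m checkpoint.
    from vllm import LLM

    llm = LLM(model="google/embeddinggemma-300m", task="embed")
    outputs = llm.embed(["The quick brown fox jumps over the lazy dog."])
    for out in outputs:
        print(len(out.outputs.embedding))  # pooled embedding dimensionality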

View File

@@ -49,26 +49,28 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
         if not dense_modules:
             return None
 
-        module = dense_modules[0]
-        folder = module.get("path", "")
+        layers = []
+        for module in dense_modules:
+            folder = module.get("path", "")
 
-        config_path = f"{folder}/config.json" if folder else "config.json"
-        layer_config = get_hf_file_to_dict(config_path, model_config.model,
-                                           model_config.revision)
-        if not layer_config:
-            return None
+            config_path = f"{folder}/config.json" if folder else "config.json"
+            layer_config = get_hf_file_to_dict(config_path, model_config.model,
+                                               model_config.revision)
+            if not layer_config:
+                continue
 
-        linear = nn.Linear(layer_config.get("in_features", 768),
-                           layer_config.get("out_features", 768),
-                           bias=layer_config.get("bias", True),
-                           dtype=torch.float32)
+            linear = nn.Linear(layer_config.get("in_features", 768),
+                               layer_config.get("out_features", 768),
+                               bias=layer_config.get("bias", True),
+                               dtype=torch.float32)
 
-        if _load_dense_weights(linear, folder, model_config):
-            layers = [linear]
+            if not _load_dense_weights(linear, folder, model_config):
+                continue
+
+            layers.append(linear)
             if act_name := layer_config.get("activation_function"):
                 layers.append(get_act_fn(act_name))
-            return nn.Sequential(*layers).to(dtype=torch.float32)
+
+        return nn.Sequential(*layers).to(dtype=torch.float32)
     except Exception:
         logger.exception("ST projector loading failed")
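
Why the loop matters: a Sentence-Transformers checkpoint can declare several Dense modules in modules.json, and each one contributes a Linear layer (plus optional activation) that has to be stacked in declaration order rather than loading only the first. A standalone sketch of that stacking pattern, with made-up layer configs standing in for the per-module config.json files:

    # Standalone sketch of chaining multiple ST "Dense" modules; the configs
    # below are invented for illustration and stand in for each module's
    # config.json from the checkpoint.
    import torch
    from torch import nn

    dense_configs = [
        {"in_features": 768, "out_features": 3072, "bias": False},
        {"in_features": 3072, "out_features": 768, "bias": False},
    ]

    layers = []
    for cfg in dense_configs:
        layers.append(
            nn.Linear(cfg["in_features"], cfg["out_features"],
                      bias=cfg.get("bias", True), dtype=torch.float32))

    projector = nn.Sequential(*layers).to(dtype=torch.float32)
    print(projector(torch.randn(2, 768)).shape)  # torch.Size([2, 768])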

View File

@@ -24,6 +24,14 @@ class VerifyAndUpdateConfig:
         raise NotImplementedError
 
 
+class Gemma3TextModelConfig:
+
+    @staticmethod
+    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
+        hf_config = vllm_config.model_config.hf_config
+        hf_config.is_causal = not hf_config.use_bidirectional_attention
+
+
 class GteNewModelConfig(VerifyAndUpdateConfig):
 
     @staticmethod
@@ -409,6 +417,7 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"GteModel": SnowflakeGteNewModelConfig,
"GteNewModel": GteNewModelConfig,
"GteNewForSequenceClassification": GteNewModelConfig,
"Gemma3TextModel": Gemma3TextModelConfig,
"NomicBertModel": NomicBertModelConfig,
"Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig,
"Qwen2ForRewardModel": Qwen2ForRewardModelConfig,

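For reference, the new hook's effect on an embedding-style checkpoint: a config that sets use_bidirectional_attention ends up non-causal, which is what later steers the model toward encoder-only attention. A toy illustration using a stand-in config object (SimpleNamespace here is only for the sketch):

    # Toy illustration of the Gemma3TextModelConfig hook: bidirectional
    # attention in the checkpoint config implies non-causal modelling.
    from types import SimpleNamespace

    hf_config = SimpleNamespace(use_bidirectional_attention=True)
    hf_config.is_causal = not hf_config.use_bidirectional_attention
    print(hf_config.is_causal)  # False -> encoder-only attention path
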
View File

@@ -24,7 +24,7 @@ import torch.nn.functional as F
 from torch import nn
 from transformers import Gemma3TextConfig
 
-from vllm.attention import Attention
+from vllm.attention import Attention, AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@@ -44,6 +44,7 @@ from vllm.model_executor.model_loader.weight_utils import (
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
+from ...attention.layers.encoder_only_attention import EncoderOnlyAttention
 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, extract_layer_index,
                     is_pp_missing_parameter,
@@ -169,16 +170,24 @@ class Gemma3Attention(nn.Module):
             rope_scaling=self.rope_scaling,
         )
 
         # Initialize the attention.
-        self.attn = Attention(self.num_heads,
-                              self.head_dim,
-                              self.scaling,
-                              num_kv_heads=self.num_kv_heads,
-                              cache_config=cache_config,
-                              quant_config=quant_config,
-                              logits_soft_cap=attn_logits_soft_cap,
-                              per_layer_sliding_window=sliding_window,
-                              prefix=f"{prefix}.attn")
+        if getattr(config, "is_causal", True):
+            attn_type = AttentionType.DECODER
+        else:
+            attn_type = AttentionType.ENCODER_ONLY
+
+        attn_cls = (EncoderOnlyAttention
+                    if attn_type == AttentionType.ENCODER_ONLY else Attention)
+        self.attn = attn_cls(self.num_heads,
+                             self.head_dim,
+                             self.scaling,
+                             num_kv_heads=self.num_kv_heads,
+                             cache_config=cache_config,
+                             quant_config=quant_config,
+                             attn_type=attn_type,
+                             logits_soft_cap=attn_logits_soft_cap,
+                             per_layer_sliding_window=sliding_window,
+                             prefix=f"{prefix}.attn")
 
     def forward(
         self,

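The practical difference between the two attention types chosen above: DECODER attention is causally masked, while ENCODER_ONLY attention lets every token attend to the whole sequence, which is what a text-embedding model needs. A mask-only sketch of that distinction (vLLM's attention backends build their masks inside the kernels, so this is purely illustrative):

    # Mask-only sketch: causal (DECODER) vs. bidirectional (ENCODER_ONLY)
    # visibility over a 4-token sequence.
    import torch

    seq_len = 4
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    bidirectional = torch.ones(seq_len, seq_len, dtype=torch.bool)

    print(causal)         # token i attends only to tokens <= i
    print(bidirectional)  # every token attends to the full sequence
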
View File

@@ -155,6 +155,7 @@ _EMBEDDING_MODELS = {
"BertModel": ("bert", "BertEmbeddingModel"),
"DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
"Gemma3TextModel": ("gemma3", "Gemma3Model"),
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
"GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
"GritLM": ("gritlm", "GritLM"),