[New Model]: google/embeddinggemma-300m (#24318)
Signed-off-by: wang.yuqi <noooop@126.com>
This commit is contained in:
@@ -49,26 +49,28 @@ def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]:
|
||||
if not dense_modules:
|
||||
return None
|
||||
|
||||
module = dense_modules[0]
|
||||
folder = module.get("path", "")
|
||||
layers = []
|
||||
for module in dense_modules:
|
||||
folder = module.get("path", "")
|
||||
|
||||
config_path = f"{folder}/config.json" if folder else "config.json"
|
||||
layer_config = get_hf_file_to_dict(config_path, model_config.model,
|
||||
model_config.revision)
|
||||
if not layer_config:
|
||||
return None
|
||||
config_path = f"{folder}/config.json" if folder else "config.json"
|
||||
layer_config = get_hf_file_to_dict(config_path, model_config.model,
|
||||
model_config.revision)
|
||||
if not layer_config:
|
||||
continue
|
||||
|
||||
linear = nn.Linear(layer_config.get("in_features", 768),
|
||||
layer_config.get("out_features", 768),
|
||||
bias=layer_config.get("bias", True),
|
||||
dtype=torch.float32)
|
||||
linear = nn.Linear(layer_config.get("in_features", 768),
|
||||
layer_config.get("out_features", 768),
|
||||
bias=layer_config.get("bias", True),
|
||||
dtype=torch.float32)
|
||||
|
||||
if _load_dense_weights(linear, folder, model_config):
|
||||
layers = [linear]
|
||||
if not _load_dense_weights(linear, folder, model_config):
|
||||
continue
|
||||
|
||||
layers.append(linear)
|
||||
if act_name := layer_config.get("activation_function"):
|
||||
layers.append(get_act_fn(act_name))
|
||||
return nn.Sequential(*layers).to(dtype=torch.float32)
|
||||
|
||||
return nn.Sequential(*layers).to(dtype=torch.float32)
|
||||
except Exception:
|
||||
logger.exception("ST projector loading failed")
|
||||
|
||||
|
||||
@@ -24,6 +24,14 @@ class VerifyAndUpdateConfig:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Gemma3TextModelConfig:
    """Config hook for the ``Gemma3TextModel`` architecture.

    Derives ``hf_config.is_causal`` from the Hugging Face config's
    ``use_bidirectional_attention`` flag.
    """

    # NOTE(review): the neighbouring config classes subclass
    # VerifyAndUpdateConfig and MODELS_CONFIG_MAP is annotated as
    # ``type[VerifyAndUpdateConfig]``; consider deriving from it as well.

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Set ``is_causal`` on the HF config of ``vllm_config`` in place.

        Args:
            vllm_config: Object exposing ``model_config.hf_config``; the
                ``hf_config`` attribute is mutated.

        ``is_causal`` becomes the logical negation of
        ``use_bidirectional_attention``. A config that does not define the
        flag is treated as not bidirectional, so ``is_causal`` is ``True``
        instead of the lookup raising ``AttributeError``.
        """
        hf_config = vllm_config.model_config.hf_config
        # Default to False so configs without the flag stay causal.
        bidirectional = getattr(hf_config, "use_bidirectional_attention",
                                False)
        hf_config.is_causal = not bidirectional
|
||||
|
||||
|
||||
class GteNewModelConfig(VerifyAndUpdateConfig):
|
||||
|
||||
@staticmethod
|
||||
@@ -409,6 +417,7 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
|
||||
"GteModel": SnowflakeGteNewModelConfig,
|
||||
"GteNewModel": GteNewModelConfig,
|
||||
"GteNewForSequenceClassification": GteNewModelConfig,
|
||||
"Gemma3TextModel": Gemma3TextModelConfig,
|
||||
"NomicBertModel": NomicBertModelConfig,
|
||||
"Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig,
|
||||
"Qwen2ForRewardModel": Qwen2ForRewardModelConfig,
|
||||
|
||||
@@ -24,7 +24,7 @@ import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from transformers import Gemma3TextConfig
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.attention import Attention, AttentionType
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
@@ -44,6 +44,7 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from ...attention.layers.encoder_only_attention import EncoderOnlyAttention
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, extract_layer_index,
|
||||
is_pp_missing_parameter,
|
||||
@@ -169,16 +170,24 @@ class Gemma3Attention(nn.Module):
|
||||
rope_scaling=self.rope_scaling,
|
||||
)
|
||||
|
||||
# Initialize the attention.
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
logits_soft_cap=attn_logits_soft_cap,
|
||||
per_layer_sliding_window=sliding_window,
|
||||
prefix=f"{prefix}.attn")
|
||||
if getattr(config, "is_causal", True):
|
||||
attn_type = AttentionType.DECODER
|
||||
else:
|
||||
attn_type = AttentionType.ENCODER_ONLY
|
||||
|
||||
attn_cls = (EncoderOnlyAttention
|
||||
if attn_type == AttentionType.ENCODER_ONLY else Attention)
|
||||
|
||||
self.attn = attn_cls(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
attn_type=attn_type,
|
||||
logits_soft_cap=attn_logits_soft_cap,
|
||||
per_layer_sliding_window=sliding_window,
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -155,6 +155,7 @@ _EMBEDDING_MODELS = {
|
||||
"BertModel": ("bert", "BertEmbeddingModel"),
|
||||
"DeciLMForCausalLM": ("nemotron_nas", "DeciLMForCausalLM"),
|
||||
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
|
||||
"Gemma3TextModel": ("gemma3", "Gemma3Model"),
|
||||
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
|
||||
"GPT2ForSequenceClassification": ("gpt2", "GPT2ForSequenceClassification"),
|
||||
"GritLM": ("gritlm", "GritLM"),
|
||||
|
||||
Reference in New Issue
Block a user