[Model] Improve olmo and olmo2 (#23228)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Author: Jee Jee Li
Date: 2025-08-20 20:47:05 +08:00 (committed by GitHub)
Parent: 7cd17e22d7
Commit: c6d80a7a96
3 changed files with 36 additions and 7 deletions

@@ -47,7 +47,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
-from .interfaces import SupportsPP
+from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -91,6 +91,7 @@ class OlmoAttention(nn.Module):
             self.total_num_heads,
             bias=config.attention_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
 
         # Rotary embeddings.
@@ -114,6 +115,7 @@ class OlmoAttention(nn.Module):
             self.hidden_size,
             bias=config.attention_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
 
     def forward(
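
These new prefix arguments give each quantizable linear layer its fully qualified module name. vLLM's quantization configs decide per layer whether and how to quantize, and name-based rules can only match a layer that knows its own dotted path. A minimal sketch of the idea, using a hypothetical ignored_layers rule rather than any specific vLLM config class:

# Sketch only: a config that skips quantization for named modules.
# `ignored_layers` is an illustrative field, not a real vLLM API.
class SketchQuantConfig:
    def __init__(self, ignored_layers: list[str]):
        self.ignored_layers = set(ignored_layers)

    def should_quantize(self, prefix: str) -> bool:
        # Matching happens on the full dotted path, which is exactly
        # what the f"{prefix}.qkv_proj" plumbing above provides.
        return prefix not in self.ignored_layers

cfg = SketchQuantConfig(["model.layers.0.self_attn.o_proj"])
assert cfg.should_quantize("model.layers.0.self_attn.qkv_proj")
assert not cfg.should_quantize("model.layers.0.self_attn.o_proj")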
@@ -142,6 +144,7 @@ class OlmoMLP(nn.Module):
         self,
         config: OlmoConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -154,6 +157,7 @@ class OlmoMLP(nn.Module):
             [self.intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
 
         # Activation function.
@@ -165,6 +169,7 @@ class OlmoMLP(nn.Module):
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
 
     def forward(
@@ -197,7 +202,7 @@ class OlmoDecoderLayer(nn.Module):
                                        prefix=f"{prefix}.self_attn")
 
         # MLP block.
-        self.mlp = OlmoMLP(config, quant_config)
+        self.mlp = OlmoMLP(config, quant_config, prefix=f"{prefix}.mlp")
 
         # LayerNorm
         self.input_layernorm = nn.LayerNorm(config.hidden_size,
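
Because each constructor appends its own segment before handing the prefix down, the dotted names compose top-down through the module tree. A toy trace of what this wiring yields for layer 0, assuming the root module is named "model":

# Toy trace of the composed prefixes; the "model" root name is an
# assumption for illustration.
root = "model"
layer0 = f"{root}.layers.0"          # OlmoDecoderLayer
attn = f"{layer0}.self_attn"         # OlmoAttention
mlp = f"{layer0}.mlp"                # OlmoMLP, via the new prefix above
print(f"{attn}.qkv_proj")            # model.layers.0.self_attn.qkv_proj
print(f"{mlp}.gate_up_proj")         # model.layers.0.mlp.gate_up_proj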
@@ -326,10 +331,21 @@ class OlmoModel(nn.Module):
         return loaded_params
 
 
-class OlmoForCausalLM(nn.Module, SupportsPP):
+class OlmoForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
     """
     Extremely barebones HF model wrapper.
     """
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
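
SupportsLoRA marks the model as LoRA-capable, and packed_modules_mapping tells the adapter loader that checkpoints carry separate q_proj/k_proj/v_proj (and gate_proj/up_proj) weights while this implementation fuses each group into a single module, so per-projection LoRA matrices have to be mapped onto the corresponding output slices of the fused layer. A rough sketch of that slice layout, with illustrative shapes rather than OLMo's real dimensions:

import torch

# Illustrative shapes only, not OLMo's real dimensions.
head_dim, n_heads, rank = 16, 4, 8
out_dim = n_heads * head_dim
lora_b = {  # per-projection LoRA B matrices, shape (out_features, rank)
    "q_proj": torch.randn(out_dim, rank),
    "k_proj": torch.randn(out_dim, rank),
    "v_proj": torch.randn(out_dim, rank),
}
packed = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

# Stack the B matrices along the output dim, in mapping order, mirroring
# the row layout of the fused qkv_proj weight.
stacked_b = torch.cat([lora_b[name] for name in packed["qkv_proj"]], dim=0)
assert stacked_b.shape == (3 * out_dim, rank)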