[Model] Improve olmo and olmo2 (#23228)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
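Summary: threads prefix arguments through the OlmoAttention and OlmoMLP projection layers so each linear module carries its fully qualified name, and adds LoRA support to OlmoForCausalLM via the SupportsLoRA interface plus a packed_modules_mapping for the fused qkv_proj and gate_up_proj weights.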
@@ -47,7 +47,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsPP
+from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -91,6 +91,7 @@ class OlmoAttention(nn.Module):
             self.total_num_heads,
             bias=config.attention_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
 
         # Rotary embeddings.
@@ -114,6 +115,7 @@ class OlmoAttention(nn.Module):
             self.hidden_size,
             bias=config.attention_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
 
     def forward(
@@ -142,6 +144,7 @@ class OlmoMLP(nn.Module):
         self,
         config: OlmoConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -154,6 +157,7 @@ class OlmoMLP(nn.Module):
             [self.intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
 
         # Activation function.
@@ -165,6 +169,7 @@ class OlmoMLP(nn.Module):
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
 
     def forward(
@@ -197,7 +202,7 @@ class OlmoDecoderLayer(nn.Module):
                                        prefix=f"{prefix}.self_attn")
 
         # MLP block.
-        self.mlp = OlmoMLP(config, quant_config)
+        self.mlp = OlmoMLP(config, quant_config, prefix=f"{prefix}.mlp")
 
         # LayerNorm
         self.input_layernorm = nn.LayerNorm(config.hidden_size,
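
For context on the prefix plumbing above: the new prefix arguments give every linear layer its fully qualified module name (e.g. model.layers.0.mlp.gate_up_proj), which quantization configs can key per-layer decisions off. A minimal sketch of the idea, assuming a config that excludes certain layers by name (the helper and layer names are illustrative, not vLLM's actual API):

    def should_quantize(prefix: str, ignored_layers: set[str]) -> bool:
        # Per-layer quantization decisions are keyed off the module's
        # dotted name; without a prefix, every layer looks identical.
        return prefix not in ignored_layers

    ignored = {"model.layers.0.mlp.down_proj"}
    print(should_quantize("model.layers.0.mlp.down_proj", ignored))  # False
    print(should_quantize("model.layers.1.mlp.down_proj", ignored))  # True
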
@@ -326,10 +331,21 @@ class OlmoModel(nn.Module):
         return loaded_params
 
 
-class OlmoForCausalLM(nn.Module, SupportsPP):
+class OlmoForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
     """
     Extremely barebones HF model wrapper.
     """
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
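
With SupportsLoRA declared, packed_modules_mapping tells the LoRA loader that checkpoint weights named q_proj/k_proj/v_proj and gate_proj/up_proj target the fused qkv_proj and gate_up_proj modules. A usage sketch: the adapter name, path, and rank below are placeholders, and the model id is one plausible Olmo checkpoint:

    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    # enable_lora switches on the adapter machinery; max_lora_rank must be
    # at least the rank of the adapter being loaded.
    llm = LLM(model="allenai/OLMo-1B-hf", enable_lora=True, max_lora_rank=16)
    outputs = llm.generate(
        ["Why is the sky blue?"],
        SamplingParams(temperature=0.0, max_tokens=64),
        lora_request=LoRARequest("olmo-adapter", 1, "/path/to/olmo-lora"),
    )
    print(outputs[0].outputs[0].text)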