[Model] Add Olmo3 model implementation (#24534)
Signed-off-by: Shane A <shanea@allenai.org> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -52,10 +52,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
|
||||
from vllm.model_executor.models.utils import (
|
||||
AutoWeightsLoader, is_pp_missing_parameter,
|
||||
AutoWeightsLoader, extract_layer_index, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs import Olmo3Config
|
||||
|
||||
|
||||
class Olmo2Attention(nn.Module):
|
||||
@@ -68,7 +69,7 @@ class Olmo2Attention(nn.Module):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
self.config = vllm_config.model_config.hf_config
|
||||
assert isinstance(self.config, Olmo2Config)
|
||||
assert isinstance(self.config, (Olmo2Config, Olmo3Config))
|
||||
|
||||
hidden_size = self.config.hidden_size
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
@@ -111,14 +112,14 @@ class Olmo2Attention(nn.Module):
|
||||
self.q_norm = RMSNorm(self.config.hidden_size,
|
||||
eps=self.config.rms_norm_eps)
|
||||
|
||||
# Rotary embeddings.
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=self.max_position_embeddings,
|
||||
base=self.rope_theta, # type: ignore
|
||||
)
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
layer_idx = extract_layer_index(prefix)
|
||||
sliding_window = None
|
||||
if ((layer_types := getattr(self.config, "layer_types", None))
|
||||
is not None and layer_types[layer_idx] == "sliding_attention"):
|
||||
sliding_window = self.config.sliding_window
|
||||
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
@@ -126,7 +127,20 @@ class Olmo2Attention(nn.Module):
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=vllm_config.cache_config,
|
||||
quant_config=vllm_config.quant_config,
|
||||
prefix=prefix,
|
||||
per_layer_sliding_window=sliding_window,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
|
||||
# Rotary embeddings. Rope scaling is only applied on full attention
|
||||
# layers.
|
||||
self.rope_scaling = (self.config.rope_scaling
|
||||
if sliding_window is None else None)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=self.max_position_embeddings,
|
||||
base=self.rope_theta, # type: ignore
|
||||
rope_scaling=self.rope_scaling,
|
||||
)
|
||||
|
||||
# Attention output projection.
|
||||
@@ -176,7 +190,7 @@ class Olmo2MLP(nn.Module):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
assert isinstance(config, Olmo2Config)
|
||||
assert isinstance(config, (Olmo2Config, Olmo3Config))
|
||||
hidden_size = config.hidden_size
|
||||
intermediate_size = config.intermediate_size
|
||||
|
||||
@@ -221,7 +235,7 @@ class Olmo2DecoderLayer(nn.Module):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
assert isinstance(config, Olmo2Config)
|
||||
assert isinstance(config, (Olmo2Config, Olmo3Config))
|
||||
# Attention block.
|
||||
self.self_attn = Olmo2Attention(vllm_config=vllm_config,
|
||||
prefix=f"{prefix}.self_attn")
|
||||
@@ -261,7 +275,7 @@ class Olmo2Model(nn.Module):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
self.config = vllm_config.model_config.hf_config
|
||||
assert isinstance(self.config, Olmo2Config)
|
||||
assert isinstance(self.config, (Olmo2Config, Olmo3Config))
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
self.config.vocab_size,
|
||||
@@ -376,7 +390,7 @@ class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
assert isinstance(config, Olmo2Config)
|
||||
assert isinstance(config, (Olmo2Config, Olmo3Config))
|
||||
self.config = config
|
||||
self.model = Olmo2Model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
|
||||
@@ -120,6 +120,7 @@ _TEXT_GENERATION_MODELS = {
|
||||
"NemotronHForCausalLM": ("nemotron_h", "NemotronHForCausalLM"),
|
||||
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
|
||||
"Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
|
||||
"Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
|
||||
"OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
|
||||
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
|
||||
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
|
||||
|
||||
Reference in New Issue
Block a user