# --- Non-code scrape residue from the hosting UI, preserved as a comment ---
# Files
# vllm/vllm/model_executor/models/phi4flash.py
# 2025-07-15 21:12:40 -07:00
#
# 747 lines
# 30 KiB
# Python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Iterable
from typing import Optional, Union
import torch
import torch.nn as nn
from transformers.activations import ACT2FN
import vllm.envs as envs
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.selector import _Backend
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_state_update)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
SupportsV0Only)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .utils import make_layers, maybe_prefix
logger = init_logger(__name__)
class SwiGLUActivation(nn.Module):
    """SwiGLU gate: element-wise product of the first input with SiLU of the second."""

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        # silu(x2) acts as a data-dependent gate applied to x1.
        gate = nn.functional.silu(x2)
        return x1 * gate
class SambaYMLP(nn.Module):
    """Gated Linear Unit.

    Reference:
        Language Modeling with Gated Convolutional Networks.
        https://arxiv.org/pdf/1612.08083v3.pdf.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # fc1 produces both the gate and the value halves in one projection.
        self.fc1 = nn.Linear(config.hidden_size,
                             2 * config.intermediate_size,
                             bias=False)
        self.fc2 = nn.Linear(config.intermediate_size,
                             config.hidden_size,
                             bias=False)
        self.activation_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states):
        # Split the doubled projection into (gate, value) and gate the value.
        gate, value = self.fc1(hidden_states).chunk(2, dim=-1)
        gated = value * self.activation_fn(gate)
        return self.fc2(gated)
def get_virtual_engine():
    """Return the virtual engine id of the current forward context."""
    return get_forward_context().virtual_engine
class SambaYAttention(nn.Module):
    """Differential flash-attention layer for the SambaY architecture.

    Layers in the first half of the model use sliding-window attention and
    write their own KV cache. Layers constructed with ``yoco_cross=True``
    project only queries and reuse the KV cache of layer
    ``num_hidden_layers // 2 + 1`` (YOCO-style cache sharing).
    """

    def __init__(self,
                 config,
                 layer_idx: Optional[int] = None,
                 yoco_cross: bool = False,
                 cache_config: Optional[CacheConfig] = None,
                 prefix: str = ""):
        super().__init__()
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing "
                "a `layer_idx` is not recommended and will lead to errors "
                "during the forward call if caching is used. Please make "
                "sure to provide a `layer_idx` when creating this class.")
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.yoco_cross = yoco_cross

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError("hidden_size must be divisible by num_heads "
                             f"(got `hidden_size`: {self.hidden_size} and "
                             f"`num_heads`: {self.num_heads}).")

        # Combined projection width for q + k + v when the layer owns its
        # own KV cache.
        op_size = self.num_heads * self.head_dim + 2 * (
            self.num_key_value_heads * self.head_dim)
        self.out_proj = nn.Linear(self.num_heads * self.head_dim,
                                  self.hidden_size,
                                  bias=True)
        if yoco_cross:
            # Query-only projection: K/V come from the shared KV cache.
            self.Wqkv = nn.Linear(self.hidden_size,
                                  self.num_heads * self.head_dim,
                                  bias=True)
        else:
            self.Wqkv = nn.Linear(self.hidden_size, op_size, bias=True)

        # disable sliding window for the second half of the model
        sliding_window = config.interleaved_sliding_window[layer_idx]
        if layer_idx >= config.num_hidden_layers // 2:
            assert sliding_window is None, \
                "sliding_window must be none for the second decoder"
        else:
            assert sliding_window is not None, \
                "sliding_window must be set for the first decoder"

        # Differential attention operates on head pairs, so both head
        # counts must be even.
        assert self.num_heads % 2 == 0, 'num_heads should be even'
        assert self.num_key_value_heads % 2 == 0, 'num_heads should be even'

        # Learnable lambda re-weighting parameters of differential
        # attention; lambda_init follows the depth-dependent schedule of
        # lambda_init_fn below.
        self.lambda_init = self.lambda_init_fn(layer_idx)
        self.lambda_q1 = nn.Parameter(
            torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
                                                                    std=0.1))
        self.lambda_k1 = nn.Parameter(
            torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
                                                                    std=0.1))
        self.lambda_q2 = nn.Parameter(
            torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
                                                                    std=0.1))
        self.lambda_k2 = nn.Parameter(
            torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,
                                                                    std=0.1))
        # Sub-layer norm applied inside the differential attention backend.
        self.subln = nn.RMSNorm(2 * self.head_dim,
                                eps=1e-5,
                                elementwise_affine=True)

        # Extra config forwarded to the differential flash-attention backend.
        params = {
            'differential_flash_attention_config': {
                'lambda_init': self.lambda_init,
                'lambda_q1': self.lambda_q1,
                'lambda_k1': self.lambda_k1,
                'lambda_q2': self.lambda_q2,
                'lambda_k2': self.lambda_k2,
                "subln": self.subln,
            }
        }

        if yoco_cross:
            # Reuse the KV cache written by the last self-attention layer
            # before the YOCO-cross region.
            kv_shared_layer_index = config.num_hidden_layers // 2 + 1
            kv_sharing_target_layer_name = \
                f"model.layers.{kv_shared_layer_index}.self_attn.attn"
        else:
            kv_sharing_target_layer_name = None
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.head_dim**-0.5,
            num_kv_heads=self.num_key_value_heads,
            cache_config=cache_config,
            per_layer_sliding_window=sliding_window,
            prefix=f"{prefix}.attn",
            attn_type=AttentionType.DECODER,
            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
            **params)
        # This model only works with the differential flash-attention
        # backend; fail fast if another backend was selected.
        assert self.attn.backend == _Backend.DIFFERENTIAL_FLASH_ATTN,\
            "DIFFERENTIAL_FLASH_ATTN required"

    def lambda_init_fn(self, depth):
        # Depth-dependent initialization for the differential-attention
        # lambda, increasing towards 0.8 with depth.
        return 0.8 - 0.6 * math.exp(-0.3 * depth)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Apply (differential) attention and the output projection."""
        if not self.yoco_cross:  # need to generate kv-cache
            qkv = self.Wqkv(hidden_states)
            q, k, v = qkv.split([
                self.hidden_size, self.num_key_value_heads * self.head_dim,
                self.num_key_value_heads * self.head_dim
            ],
                                dim=-1)
            attn_output = self.attn(q, k, v)
        else:  # reuse the kv cache, full attention
            q = self.Wqkv(hidden_states)
            attn_output = self.attn(q, None, None)
        attn_output = attn_output.view(-1, self.num_heads * self.head_dim)
        return self.out_proj(attn_output)
class Phi4Mamba(nn.Module):
    """Mamba (selective state-space) mixer used on designated layers.

    Two operating modes:

    * normal (``yoco_cross=False``): full conv1d + selective-scan pipeline,
      maintaining conv and SSM state in the Mamba cache.
    * ``yoco_cross=True``: the SSM output of an earlier layer is passed in
      as ``yoco_key_values`` and only a SwiGLU-gated in/out projection is
      applied (no scan, no cache writes).

    When ``yoco_kv=True`` the layer additionally returns its (ungated) scan
    output so later ``yoco_cross`` layers can reuse it.
    """

    def __init__(
        self,
        d_model,
        d_state=16,
        d_conv=4,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",  # difference
        dt_scale=1.0,  # difference
        dt_init_floor=1e-4,
        conv_bias=True,
        bias=False,
        use_fast_path=True,  # Fused kernel options
        layer_idx=None,
        device=None,
        dtype=None,
        yoco_cross=False,
        yoco_kv=False,
    ):
        factory_kwargs = {"params_dtype": dtype}  # difference
        super().__init__()
        self.yoco_cross = yoco_cross
        self.yoco_kv = yoco_kv
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model /
                                 16) if dt_rank == "auto" else dt_rank
        self.use_fast_path = use_fast_path
        self.layer_idx = layer_idx
        self.swiGluActivation = SwiGLUActivation()
        if self.yoco_cross:
            # Cross mode only needs the gated in/out projections; no conv,
            # scan, or SSM parameters are created.
            self.in_proj = MergedColumnParallelLinear(self.d_model,
                                                      [self.d_inner],
                                                      bias=bias,
                                                      **factory_kwargs)
            self.out_proj = RowParallelLinear(self.d_inner,
                                              self.d_model,
                                              bias=bias,
                                              **factory_kwargs)
            return
        self.conv1d = ColumnParallelLinear(
            input_size=d_conv,
            output_size=self.d_inner,
            bias=conv_bias,
            params_dtype=dtype,
        )
        # unsqueeze to fit conv1d weights shape into the linear weights shape.
        # Can't do this in `weight_loader` since it already exists in
        # `ColumnParallelLinear` and `set_weight_attrs`
        # doesn't allow to override it
        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

        self.in_proj = MergedColumnParallelLinear(
            self.d_model,
            [self.d_inner] * 2,
            bias=bias,
            params_dtype=dtype,
        )

        # selective projection used to make dt, B and C input dependent
        self.x_proj = RowParallelLinear(
            self.d_inner,
            self.dt_rank + self.d_state * 2,
            bias=False,
            params_dtype=dtype,
        )

        # time step projection (discretization) -
        # In the forward we need to apply dt_proj without the bias,
        # as the bias is added in the selective scan kernel.
        self.dt_proj = ColumnParallelLinear(
            self.dt_rank,
            self.d_inner,
            bias=True,
            skip_bias_add=True,
            params_dtype=dtype,
        )

        # A (state transition matrix) is kept in fp32; load_weights stores
        # -exp(A_log) here as expected by the scan kernels.
        self.A = nn.Parameter(
            torch.empty(
                self.d_inner,
                self.d_state,
                dtype=torch.float32,
            ))
        # D "skip" parameter, kept in fp32.
        self.D = nn.Parameter(torch.ones(self.d_inner, dtype=torch.float32))

        self.out_proj = RowParallelLinear(
            self.d_inner,
            self.d_model,
            bias=bias,
            input_is_parallel=True,
            params_dtype=dtype,
        )
        self.activation = "silu"

    def forward(self,
                hidden_states: torch.Tensor,
                attn_metadata: AttentionMetadata,
                mamba_cache_params: MambaCacheParams,
                yoco_key_values=None) -> torch.Tensor:
        # NOTE(review): despite the annotation, every path returns a
        # (hidden_states, yoco_key_values) tuple.
        if self.yoco_cross:
            # Reuse the SSM output of the donor layer as the SwiGLU gate
            # input; no state is read or written here.
            out = self.in_proj(hidden_states)[0]
            out = self.swiGluActivation(yoco_key_values, out)
            out = self.out_proj(out)
            return out[0], yoco_key_values

        # 1. Gated MLP's linear projection
        projected_states = self.in_proj(
            hidden_states.to(self.in_proj.weight.dtype))[0].transpose(-2, -1)
        hidden_states, gate = projected_states.chunk(2, dim=-2)

        # 2. Convolution sequence transformation
        conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                               self.conv1d.weight.size(2))

        # Presence of query_start_loc/context_lens_tensor distinguishes the
        # prefill-style path (varlen scan) from single-token decode.
        if attn_metadata.query_start_loc is not None \
                and attn_metadata.context_lens_tensor is not None:
            # |---------- N-1 iteration --------|
            # |---------------- N iteration ---------------------|
            # |- tokenA -|......................|-- newTokens ---|
            # |---------- context_len ----------|
            # |-------------------- seq_len ---------------------|
            #                                   |-- query_len ---|
            hidden_states = causal_conv1d_fn(
                hidden_states,
                conv_weights,
                self.conv1d.bias,
                activation=self.activation,
                conv_states=mamba_cache_params.conv_state,
                has_initial_state=attn_metadata.context_lens_tensor > 0,
                cache_indices=mamba_cache_params.state_indices_tensor,
                query_start_loc=attn_metadata.query_start_loc)
        else:
            # Decode: single-token conv update against the cached state.
            hidden_states = causal_conv1d_update(
                hidden_states.transpose(0, 1),
                mamba_cache_params.conv_state,
                conv_weights,
                self.conv1d.bias,
                self.activation,
                conv_state_indices=mamba_cache_params.state_indices_tensor)
            hidden_states = hidden_states.transpose(0, 1)

        # 3. State Space Model sequence transformation
        # 3.a. input varying initialization of time_step, B and C
        ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
        time_step, B, C = torch.split(
            ssm_parameters,
            [self.dt_rank, self.d_state, self.d_state],
            dim=-1,
        )

        # Note that Jamba normalizes B, C, and time_step here but Mamba doesn't.
        discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)

        # 3.c perform the recurrence y <- SSM(A, B, C)(x)
        # dt_proj uses skip_bias_add, so the bias is handed to the scan
        # kernel separately.
        time_proj_bias = (self.dt_proj.bias.float() if hasattr(
            self.dt_proj, "bias") else None)
        if attn_metadata.query_start_loc is not None \
                and attn_metadata.context_lens_tensor is not None:
            scan_outputs = selective_scan_fn(
                hidden_states,
                mamba_cache_params.ssm_state,
                discrete_time_step,
                self.A,
                B.transpose(-2, -1),
                C.transpose(-2, -1),
                self.D.float(),
                # z (gate): in YOCO mode gating is deferred until after the
                # scan so the ungated output can be shared.
                None if self.yoco_kv else gate,
                time_proj_bias,
                delta_softplus=True,
                cache_indices=mamba_cache_params.state_indices_tensor,
                has_initial_state=attn_metadata.context_lens_tensor > 0,
                query_start_loc=attn_metadata.query_start_loc)
        else:
            scan_outputs = selective_state_update(
                mamba_cache_params.ssm_state,
                hidden_states.transpose(0, 1),
                discrete_time_step.transpose(0, 1),
                self.A,
                B,
                C,
                self.D,
                # z (gate): deferred in YOCO mode, see above.
                None if self.yoco_kv else gate.transpose(0, 1),
                time_proj_bias,
                dt_softplus=True,
                state_batch_indices=mamba_cache_params.state_indices_tensor)
            scan_outputs = scan_outputs.transpose(0, 1)

        # 4. Final linear projection
        if self.yoco_kv:
            # Expose the ungated scan output for later yoco_cross layers,
            # then apply the deferred SwiGLU gate.
            yoco_key_values = scan_outputs.transpose(-2, -1)
            scan_outputs = self.swiGluActivation(scan_outputs, gate)
        contextualized_states = self.out_proj(scan_outputs.transpose(-2,
                                                                     -1))[0]
        return contextualized_states, yoco_key_values
class SambaYDecoderLayer(nn.Module):
    """One decoder layer: (attention | Mamba) mixer followed by a gated MLP.

    Layer roles by index (YOCO layout):

    * first half of the layers: sliding-window attention or Mamba with
      their own caches;
    * from ``num_hidden_layers // 2`` on (``yoco_mb``): full attention;
    * from ``num_hidden_layers // 2 + 2`` on (``yoco_cross``): layers reuse
      the shared KV cache / SSM output produced earlier.
    """

    def __init__(
        self,
        config,
        layer_idx,
        cache_config,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        self.mlp = SambaYMLP(config)
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)

        self.yoco_mb = False
        self.yoco_cross = False
        if layer_idx >= config.num_hidden_layers // 2:
            # Second half of the model: YOCO region.
            self.yoco_mb = True
            self.yoco_cross = (layer_idx
                               >= (config.num_hidden_layers // 2 + 2))
        # Every `mb_per_layer`-th layer is a Mamba mixer instead of
        # attention.
        self.use_mamba = config.mb_per_layer > 0 and \
            layer_idx % config.mb_per_layer == 0
        if self.use_mamba:
            factory_kwargs = {"dtype": None}
            self.attn = Phi4Mamba(config.hidden_size,
                                  layer_idx=layer_idx,
                                  yoco_cross=self.yoco_cross,
                                  yoco_kv=self.yoco_mb,
                                  **factory_kwargs)
        else:
            self.attn = SambaYAttention(config,
                                        layer_idx=layer_idx,
                                        yoco_cross=self.yoco_cross,
                                        cache_config=cache_config,
                                        prefix=f"{prefix}.self_attn")
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        positions: torch.Tensor,
        attn_metadata: AttentionMetadata,
        mamba_cache_params: MambaCacheParams,
        ssm_output: Optional[torch.LongTensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run mixer + MLP with pre-norm residual connections.

        Returns a (hidden_states, ssm_output) tuple; ssm_output carries the
        shared Mamba output between layers (unchanged for attention layers).
        """
        # Mamba layers require cache params; attention layers must not
        # receive them (see SambaYModel.forward).
        if self.use_mamba:
            assert mamba_cache_params is not None
        else:
            assert mamba_cache_params is None

        residual = hidden_states
        # LayerNorm weights may be in a different dtype; cast the input to
        # match before normalizing.
        hidden_states = self.input_layernorm(
            hidden_states.to(dtype=self.input_layernorm.weight.dtype))

        if self.use_mamba:
            attn_outputs, ssm_output = self.attn(hidden_states,
                                                 attn_metadata,
                                                 mamba_cache_params,
                                                 yoco_key_values=ssm_output)
            # Mamba outputs are fp32; lift the residual to match.
            residual = residual.to(torch.float32)
        else:
            attn_outputs = self.attn(hidden_states, )
        hidden_states = residual + attn_outputs

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(
            hidden_states.to(dtype=self.post_attention_layernorm.weight.dtype))
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, ssm_output
class SambaYModel(nn.Module):
    """Decoder stack implementing the SambaY / YOCO hybrid architecture.

    The second half of the layers share one KV cache and one Mamba state,
    which also enables truncating the prefill hidden states once the shared
    caches have been produced.
    """

    def __init__(self,
                 config,
                 cache_config=None,
                 quant_config=None,
                 lora_config=None,
                 prefix: str = "") -> None:
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        # Pipeline parallel is not supported since the second half of
        # the layers share the kv cache.
        if get_pp_group().world_size != 1:
            raise ValueError("Pipeline Parallel not supported")

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: SambaYDecoderLayer(config,
                                              int(prefix.split('.')[-1]),
                                              cache_config,
                                              prefix=prefix),
            prefix=f"{prefix}.layers")
        self.final_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        attn_metadata: AttentionMetadata,
        mamba_cache_params: MambaCacheParams,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Embed tokens (first rank) and run all decoder layers."""
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        # Index into the Mamba cache; only non-cross Mamba layers get a
        # fresh slot (see below).
        mamba_state_idx = 0
        ssm_output = None
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            if i == self.config.num_hidden_layers // 2 + 2:
                # profile run: the shared KV cache is still empty, so the
                # yoco-cross layers cannot run; stop here.
                kv_cache_idx = self.config.num_hidden_layers // 2 + 1
                cache_layer = self.layers[kv_cache_idx]
                kv_cache = cache_layer.attn.attn.kv_cache
                if kv_cache[0].numel() == 0:
                    break

                # Starting from this layer, we do not need to calculate
                # the kv cache since we reuse the kv cache from last layer.
                # If in prefill phase, we can truncate the hidden states
                # down to the last token of each sequence to save
                # computation cost.
                if attn_metadata.prefill_metadata and not envs.VLLM_USE_V1:
                    selected_token_indices = torch.cumsum(
                        attn_metadata.seq_lens_tensor, dim=0) - 1
                    hidden_states = hidden_states.index_select(
                        0, selected_token_indices)
                    ssm_output = ssm_output.index_select(
                        0, selected_token_indices)

            if layer.use_mamba:
                if i < self.config.num_hidden_layers // 2 or \
                        not layer.yoco_cross:
                    # Own state slot.
                    mamba_cache = mamba_cache_params.at_layer_idx(
                        mamba_state_idx)
                    mamba_state_idx += 1
                else:
                    # Cross layers reuse the last allocated Mamba state.
                    mamba_cache = mamba_cache_params.at_layer_idx(
                        mamba_state_idx - 1)

                hidden_states, ssm_output = layer(hidden_states,
                                                  positions,
                                                  attn_metadata,
                                                  mamba_cache,
                                                  ssm_output=ssm_output)
            else:
                hidden_states, ssm_output = layer(
                    hidden_states,
                    positions,
                    attn_metadata,
                    None,  # mamba_cache_params
                    ssm_output=ssm_output)

        hidden_states = self.final_layernorm(
            hidden_states.to(dtype=self.final_layernorm.weight.dtype))
        return hidden_states
class Phi4FlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
    """Phi4-flash causal language model (SambaY hybrid decoder).

    Combines Mamba and differential-attention layers with YOCO KV-cache
    sharing. V0-only: prefix caching and chunked prefill are unsupported.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Initialize nn.Module first so attribute assignment follows the
        # standard Module protocol.
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        lora_config = vllm_config.lora_config
        quant_config = vllm_config.quant_config
        scheduler_config = vllm_config.scheduler_config
        self.compilation_config = vllm_config.compilation_config
        self.vllm_config = vllm_config
        # Prefix caching and chunked prefill are not supported for this model.
        assert not cache_config.enable_prefix_caching, \
            "Phi4flash currently does not support prefix caching"
        # BUGFIX: this message previously said "prefix caching" although the
        # condition checks chunked prefill.
        assert not scheduler_config.chunked_prefill_enabled, \
            "Phi4Flash currently does not support chunked prefill"
        self.config = config
        self.model_config = vllm_config.model_config
        self.scheduler_config = scheduler_config
        self.model = SambaYModel(config,
                                 cache_config=cache_config,
                                 prefix=maybe_prefix(prefix, "model"))
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=(
                DEFAULT_VOCAB_PADDING_SIZE
                # We need bigger padding if using lora for kernel
                # compatibility
                if not lora_config else lora_config.lora_vocab_padding_size),
            quant_config=quant_config,
        )
        self.embedding_bias = None
        # Used to track and store by the Mamba cache between steps.
        self.mamba_cache: Optional[MambaCacheManager] = None
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size,
                                                logits_as_input=False)
        self.sampler = get_sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the model, lazily creating the Mamba cache on first call."""
        if self.mamba_cache is None:
            # One slot per Mamba layer in the first half, plus one shared
            # slot reused by the yoco-cross Mamba layers.
            num_mamba_layers = self.config.num_hidden_layers \
                // 2 // self.config.mb_per_layer + 1
            self.mamba_cache = MambaCacheManager(
                self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
                *self._get_mamba_cache_shape())
        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
        attn_metadata = get_forward_context().attn_metadata
        # input_ids and hidden_states isn't a one-to-one mapping in prefill
        # stage due to YOCO optimization.
        hidden_states = self.model(input_ids, positions, attn_metadata,
                                   mamba_cache_params, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def _get_mamba_cache_shape(
        self
    ) -> tuple[Optional[tuple[int, int]], Optional[tuple[int, int]]]:
        """Return (conv_state_shape, temporal_state_shape) per TP rank."""
        world_size = get_tensor_model_parallel_world_size()
        hidden_size = self.config.hidden_size
        mamba_expand = self.config.mamba_expand  # 2
        mamba_d_conv = self.config.mamba_d_conv  # 4
        mamba_d_state = self.config.mamba_d_state  # 16
        conv_state_shape = (
            mamba_expand * hidden_size // world_size,
            mamba_d_conv - 1,
        )
        temporal_state_shape = (
            mamba_expand * hidden_size // world_size,
            mamba_d_state,
        )
        return conv_state_shape, temporal_state_shape

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        # Delegated to the Mamba cache manager (CUDA-graph support).
        return self.mamba_cache.copy_inputs_before_cuda_graphs(
            input_buffers, **kwargs)

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        # Delegated to the Mamba cache manager (CUDA-graph support).
        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute vocab logits, pruning hidden states only if needed.

        If the shape already matches selected_token_indices, the hidden
        states were pruned manually during the YOCO prefill truncation.
        """
        prune_hidden_states = hidden_states.size(
            0) != sampling_metadata.selected_token_indices.size(0)
        processed_logits = self.logits_processor(
            self.lm_head,
            hidden_states,
            sampling_metadata,
            self.embedding_bias,
            prune_hidden_states=prune_hidden_states)
        return processed_logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from the logits."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ):
        """Load checkpoint weights with the model-specific renames.

        * ``A_log`` tensors become ``A = -exp(A_log)`` (fp32), the form the
          selective-scan kernels expect.
        * ``inner_cross_attn.`` prefixes are flattened away.
        * ``lm_head.weight`` is tied to ``model.embed_tokens.weight``.

        Returns the set of parameter names considered loaded.
        """
        weights = dict(weights)
        adjusted_weights = {}
        for name, weight in weights.items():
            if "A_log" in name:
                name = name.replace("A_log", "A")
                weight = -torch.exp(weight.float())
            if "inner_cross_attn." in name:
                name = name.replace("inner_cross_attn.", "")
            adjusted_weights[name] = weight
        # Tie the LM head to the input embeddings.
        adjusted_weights["lm_head.weight"] = weights[
            "model.embed_tokens.weight"]
        loaded_params: set[str] = set()
        for name, param in self.named_parameters():
            weight = adjusted_weights.get(name)
            if weight is not None and weight.shape != param.shape:
                logger.warning("Shape mismatch: %s %s %s", name, weight.shape,
                               param.shape)
            loaded_params.add(name)
        # strict=False tolerates extra/missing keys; we enforce both
        # explicitly below for a clearer error message.
        missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights,
                                                             strict=False)
        assert len(unexpected_keys) == 0, f"Unexpected keys: {unexpected_keys}"
        assert len(missing_keys) == 0, f"Missing keys: {missing_keys}"
        return loaded_params