standardize load_weights using AutoWeightsLoader for kimi_linear and minimax_text_01 (#37371)

Signed-off-by: XuLiu <xuliu40@gmail.com>
Co-authored-by: XuLiu <xuliu40@gmail.com>
This commit is contained in:
XLiu-2000
2026-03-18 23:05:37 +08:00
committed by GitHub
parent 296839a1b0
commit 17808394bc
2 changed files with 235 additions and 219 deletions

View File

@@ -46,6 +46,7 @@ from vllm.transformers_utils.configs.kimi_linear import KimiLinearConfig
from .interfaces import HasInnerState, IsHybrid, MixtureOfExperts, SupportsPP
from .utils import (
AutoWeightsLoader,
PPMissingLayer,
is_pp_missing_parameter,
make_layers,
@@ -472,94 +473,7 @@ class KimiLinearModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class KimiLinearForCausalLM(
nn.Module, HasInnerState, SupportsPP, MixtureOfExperts, IsHybrid
):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.model_config = vllm_config.model_config
self.vllm_config = vllm_config
self.config = self.model_config.hf_config
quant_config = vllm_config.quant_config
self.quant_config = quant_config
self.model = KimiLinearModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
if get_pp_group().is_last_rank:
self.lm_head = ParallelLMHead(
self.config.vocab_size,
self.config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
else:
self.lm_head = PPMissingLayer()
logit_scale = getattr(self.config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(
self.config.vocab_size, scale=logit_scale
)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor | IntermediateTensors:
hidden_states = self.model(
input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
)
return hidden_states
@classmethod
def get_mamba_state_dtype_from_config(
cls,
vllm_config: "VllmConfig",
) -> tuple[torch.dtype, torch.dtype, torch.dtype, torch.dtype]:
return MambaStateDtypeCalculator.kda_state_dtype(
vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype
)
@classmethod
def get_mamba_state_shape_from_config(
cls, vllm_config: "VllmConfig"
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
parallel_config = vllm_config.parallel_config
hf_config = vllm_config.model_config.hf_config
tp_size = parallel_config.tensor_parallel_size
num_spec = (
vllm_config.speculative_config.num_speculative_tokens
if vllm_config.speculative_config
else 0
)
return MambaStateShapeCalculator.kda_state_shape(
tp_size,
hf_config.linear_attn_config["num_heads"],
hf_config.linear_attn_config["head_dim"],
conv_kernel_size=hf_config.linear_attn_config["short_conv_kernel_size"],
num_spec=num_spec,
)
@classmethod
def get_mamba_state_copy_func(
cls,
) -> tuple[
MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc
]:
return MambaStateCopyFuncCalculator.kda_state_copy_func()
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
return self.logits_processor(self.lm_head, hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".gate_up_proj", ".gate_proj", 0),
@@ -653,6 +567,101 @@ class KimiLinearForCausalLM(
)
weight_loader(param, loaded_weight, **kwargs)
loaded_params.add(name)
return loaded_params
class KimiLinearForCausalLM(
nn.Module, HasInnerState, SupportsPP, MixtureOfExperts, IsHybrid
):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
"""Build the causal-LM wrapper: backbone model, LM head, logits processor."""
super().__init__()
self.model_config = vllm_config.model_config
self.vllm_config = vllm_config
self.config = self.model_config.hf_config
quant_config = vllm_config.quant_config
self.quant_config = quant_config
self.model = KimiLinearModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
# Only the last pipeline-parallel rank owns the LM head; earlier ranks
# get a PPMissingLayer placeholder so weight loading can skip it.
if get_pp_group().is_last_rank:
self.lm_head = ParallelLMHead(
self.config.vocab_size,
self.config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
else:
self.lm_head = PPMissingLayer()
# Default logit scale is 1.0 when the HF config does not set one.
logit_scale = getattr(self.config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(
self.config.vocab_size, scale=logit_scale
)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Delegate token-embedding lookup to the wrapped backbone model."""
    embeddings = self.model.embed_input_ids(input_ids)
    return embeddings
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor | IntermediateTensors:
"""Run the backbone model.

Returns final hidden states on the last pipeline-parallel rank, or
IntermediateTensors to forward to the next rank otherwise (the split
is handled inside self.model).
"""
hidden_states = self.model(
input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
)
return hidden_states
@classmethod
def get_mamba_state_dtype_from_config(
cls,
vllm_config: "VllmConfig",
) -> tuple[torch.dtype, torch.dtype, torch.dtype, torch.dtype]:
"""Return the four KDA (linear-attention) state dtypes for cache allocation."""
# Derived from the model compute dtype and the configured mamba cache dtype.
return MambaStateDtypeCalculator.kda_state_dtype(
vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype
)
@classmethod
def get_mamba_state_shape_from_config(
cls, vllm_config: "VllmConfig"
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
"""Return the four KDA state shapes, accounting for TP sharding and spec decode."""
parallel_config = vllm_config.parallel_config
hf_config = vllm_config.model_config.hf_config
tp_size = parallel_config.tensor_parallel_size
# Extra state slots are reserved for speculative tokens when a
# speculative-decoding config is active; otherwise none.
num_spec = (
vllm_config.speculative_config.num_speculative_tokens
if vllm_config.speculative_config
else 0
)
# NOTE(review): assumes hf_config.linear_attn_config is a dict with these
# keys — confirm against the KimiLinearConfig definition.
return MambaStateShapeCalculator.kda_state_shape(
tp_size,
hf_config.linear_attn_config["num_heads"],
hf_config.linear_attn_config["head_dim"],
conv_kernel_size=hf_config.linear_attn_config["short_conv_kernel_size"],
num_spec=num_spec,
)
@classmethod
def get_mamba_state_copy_func(
cls,
) -> tuple[
MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc
]:
"""Return the four copy functions used to move KDA state between cache slots."""
return MambaStateCopyFuncCalculator.kda_state_copy_func()
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None:
    """Project hidden states to vocabulary logits through the logits processor."""
    logits = self.logits_processor(self.lm_head, hidden_states)
    return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
"""Load checkpoint weights via AutoWeightsLoader; returns loaded param names."""
loader = AutoWeightsLoader(
self,
# With tied word embeddings the checkpoint's lm_head tensors duplicate
# embed_tokens, so they are skipped rather than loaded twice.
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)
def get_spec_layer_idx_from_weight_name(

View File

@@ -52,7 +52,12 @@ from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backend import AttentionMetadata
from .interfaces import HasInnerState, IsHybrid
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
from .utils import (
AutoWeightsLoader,
PPMissingLayer,
is_pp_missing_parameter,
make_layers,
)
def replace_weight_name(
@@ -494,6 +499,8 @@ class MiniMaxText01Model(nn.Module):
quant_config = vllm_config.quant_config
cache_config = vllm_config.cache_config
scheduler_config = vllm_config.scheduler_config
self.config = config
self.CONCAT_FFN = True
self.vocab_size = config.vocab_size
@@ -620,128 +627,6 @@ class MiniMaxText01Model(nn.Module):
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor | IntermediateTensors:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
if get_pp_group().is_first_rank:
if inputs_embeds is None:
hidden_states = self.embed_scale * self.embed_tokens(input_ids)
else:
hidden_states = inputs_embeds
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(
hidden_states=hidden_states,
positions=positions,
attn_metadata=attn_metadata,
residual=residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors(
{"hidden_states": hidden_states, "residual": residual}
)
if residual is not None:
hidden_states, _ = self.norm(hidden_states, residual)
else:
hidden_states = self.norm(hidden_states)
return hidden_states
class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
self.config = config
if not hasattr(config, "sliding_window"):
config.sliding_window = None
self.CONCAT_FFN = True
if hasattr(vllm_config.model_config, "max_model_len"):
self.config.max_model_len = vllm_config.model_config.max_model_len
self.model = MiniMaxText01Model(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
if get_pp_group().is_last_rank:
self.lm_head = ParallelLMHead(
config.vocab_size,
self.config.hidden_size,
prefix=maybe_prefix(prefix, "lm_head"),
)
self.logits_processor = LogitsProcessor(
config.vocab_size, self.config.vocab_size
)
else:
self.lm_head = PPMissingLayer()
self.lm_head.float()
flash_layer_count = sum(
1 for attn_type in self.model.decoder_attention_types if attn_type == 1
)
self.kv_cache = [torch.tensor([]) for _ in range(flash_layer_count)]
return
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
return self.model.minimax_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs
)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
hidden_states = self.model(
input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states.float())
return logits
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype, device: torch.device
) -> IntermediateTensors:
return IntermediateTensors(
{
"hidden_states": torch.zeros(
(batch_size, self.config.hidden_size), dtype=dtype, device=device
),
"residual": torch.zeros(
(batch_size, self.config.hidden_size), dtype=dtype, device=device
),
}
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
@@ -753,17 +638,15 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
return None
def is_linear_attn_layer(layer_idx: int) -> bool:
if layer_idx is None or layer_idx >= len(
self.model.decoder_attention_types
):
if layer_idx is None or layer_idx >= len(self.decoder_attention_types):
return False
return self.model.decoder_attention_types[layer_idx] == 0
return self.decoder_attention_types[layer_idx] == 0
def is_moe_weight(name: str) -> bool:
return "block_sparse_moe" in name and not name.endswith(".bias")
def get_expert_id(param_name):
pattern = r"model\.layers\.\d+\.block_sparse_moe\.experts\.(\d+)\."
pattern = r"layers\.\d+\.block_sparse_moe\.experts\.(\d+)\."
match = re.search(pattern, param_name)
if match:
return match.group(1)
@@ -948,9 +831,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
for name, loaded_weight in weights:
weight_at_layer = which_layer(name)
if weight_at_layer and weight_at_layer >= len(
self.model.decoder_attention_types
):
if weight_at_layer and weight_at_layer >= len(self.decoder_attention_types):
continue
if is_layer_norm_weight(name):
@@ -975,6 +856,128 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
load_basic_weight(name, loaded_weight, self)
return loaded_params
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor | IntermediateTensors:
"""Run the decoder stack for this pipeline-parallel rank's layer slice.

First rank embeds tokens (or uses inputs_embeds); later ranks resume
from intermediate_tensors. Non-last ranks return IntermediateTensors.
"""
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
if get_pp_group().is_first_rank:
if inputs_embeds is None:
hidden_states = self.embed_scale * self.embed_tokens(input_ids)
else:
hidden_states = inputs_embeds
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
# Only this rank's assigned layer range is executed.
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(
hidden_states=hidden_states,
positions=positions,
attn_metadata=attn_metadata,
residual=residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors(
{"hidden_states": hidden_states, "residual": residual}
)
# Final norm: fused residual-add variant when a residual is pending.
if residual is not None:
hidden_states, _ = self.norm(hidden_states, residual)
else:
hidden_states = self.norm(hidden_states)
return hidden_states
class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
"""Build the MiniMax causal-LM wrapper: backbone, LM head, flash KV cache list."""
super().__init__()
config = vllm_config.model_config.hf_config
self.config = config
# Ensure the attribute exists even for configs that omit it.
if not hasattr(config, "sliding_window"):
config.sliding_window = None
self.CONCAT_FFN = True
if hasattr(vllm_config.model_config, "max_model_len"):
self.config.max_model_len = vllm_config.model_config.max_model_len
self.model = MiniMaxText01Model(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
# LM head and logits processor exist only on the last PP rank.
# NOTE(review): self.logits_processor is undefined on non-last ranks —
# compute_logits must only be called on the last rank.
if get_pp_group().is_last_rank:
self.lm_head = ParallelLMHead(
config.vocab_size,
self.config.hidden_size,
prefix=maybe_prefix(prefix, "lm_head"),
)
self.logits_processor = LogitsProcessor(
config.vocab_size, self.config.vocab_size
)
else:
self.lm_head = PPMissingLayer()
# LM head is kept in float32 (logits are computed on float() inputs).
self.lm_head.float()
# One placeholder KV-cache tensor per flash-attention layer
# (decoder_attention_types value 1 marks flash layers).
flash_layer_count = sum(
1 for attn_type in self.model.decoder_attention_types if attn_type == 1
)
self.kv_cache = [torch.tensor([]) for _ in range(flash_layer_count)]
return
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
"""Delegate CUDA-graph input staging to the model's minimax cache manager."""
return self.model.minimax_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs
)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
"""Delegate CUDA-graph capture-input creation to the minimax cache manager."""
return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Delegate token-embedding lookup to the wrapped backbone model."""
    embeddings = self.model.embed_input_ids(input_ids)
    return embeddings
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
"""Run the backbone model and return its hidden states.

NOTE(review): return annotation says torch.Tensor, but on non-last PP
ranks the model returns IntermediateTensors — confirm annotation.
"""
hidden_states = self.model(
input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Compute vocabulary logits from float32-cast hidden states."""
    return self.logits_processor(self.lm_head, hidden_states.float())
def make_empty_intermediate_tensors(
    self, batch_size: int, dtype: torch.dtype, device: torch.device
) -> IntermediateTensors:
    """Allocate zero-filled hidden-state and residual buffers for PP hand-off."""
    shape = (batch_size, self.config.hidden_size)
    # Comprehension creates a distinct tensor per key (not a shared buffer).
    buffers = {
        key: torch.zeros(shape, dtype=dtype, device=device)
        for key in ("hidden_states", "residual")
    }
    return IntermediateTensors(buffers)
@classmethod
def get_mamba_state_dtype_from_config(
cls,
@@ -1011,3 +1014,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc]:
"""Return the copy function for moving linear-attention state between slots."""
return MambaStateCopyFuncCalculator.linear_attention_state_copy_func()
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
"""Load checkpoint weights via AutoWeightsLoader; returns loaded param names."""
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)