standardize load_weights using AutoWeightsLoader for kimi_linear and minimax_text_01 (#37371)
Signed-off-by: XuLiu <xuliu40@gmail.com>
Co-authored-by: XuLiu <xuliu40@gmail.com>
This commit is contained in:
@@ -46,6 +46,7 @@ from vllm.transformers_utils.configs.kimi_linear import KimiLinearConfig
|
||||
|
||||
from .interfaces import HasInnerState, IsHybrid, MixtureOfExperts, SupportsPP
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
PPMissingLayer,
|
||||
is_pp_missing_parameter,
|
||||
make_layers,
|
||||
@@ -472,94 +473,7 @@ class KimiLinearModel(nn.Module):
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class KimiLinearForCausalLM(
|
||||
nn.Module, HasInnerState, SupportsPP, MixtureOfExperts, IsHybrid
|
||||
):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
self.model_config = vllm_config.model_config
|
||||
self.vllm_config = vllm_config
|
||||
self.config = self.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.quant_config = quant_config
|
||||
self.model = KimiLinearModel(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
)
|
||||
if get_pp_group().is_last_rank:
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.config.vocab_size,
|
||||
self.config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
logit_scale = getattr(self.config, "logit_scale", 1.0)
|
||||
self.logits_processor = LogitsProcessor(
|
||||
self.config.vocab_size, scale=logit_scale
|
||||
)
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor | IntermediateTensors:
|
||||
hidden_states = self.model(
|
||||
input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
@classmethod
|
||||
def get_mamba_state_dtype_from_config(
|
||||
cls,
|
||||
vllm_config: "VllmConfig",
|
||||
) -> tuple[torch.dtype, torch.dtype, torch.dtype, torch.dtype]:
|
||||
return MambaStateDtypeCalculator.kda_state_dtype(
|
||||
vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_mamba_state_shape_from_config(
|
||||
cls, vllm_config: "VllmConfig"
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
parallel_config = vllm_config.parallel_config
|
||||
hf_config = vllm_config.model_config.hf_config
|
||||
tp_size = parallel_config.tensor_parallel_size
|
||||
num_spec = (
|
||||
vllm_config.speculative_config.num_speculative_tokens
|
||||
if vllm_config.speculative_config
|
||||
else 0
|
||||
)
|
||||
return MambaStateShapeCalculator.kda_state_shape(
|
||||
tp_size,
|
||||
hf_config.linear_attn_config["num_heads"],
|
||||
hf_config.linear_attn_config["head_dim"],
|
||||
conv_kernel_size=hf_config.linear_attn_config["short_conv_kernel_size"],
|
||||
num_spec=num_spec,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_mamba_state_copy_func(
|
||||
cls,
|
||||
) -> tuple[
|
||||
MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc
|
||||
]:
|
||||
return MambaStateCopyFuncCalculator.kda_state_copy_func()
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor | None:
|
||||
return self.logits_processor(self.lm_head, hidden_states)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".gate_up_proj", ".gate_proj", 0),
|
||||
@@ -653,6 +567,101 @@ class KimiLinearForCausalLM(
|
||||
)
|
||||
weight_loader(param, loaded_weight, **kwargs)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class KimiLinearForCausalLM(
    nn.Module, HasInnerState, SupportsPP, MixtureOfExperts, IsHybrid
):
    """Causal-LM wrapper around ``KimiLinearModel``.

    Owns the inner model, the LM head (a ``PPMissingLayer`` placeholder on
    every pipeline rank except the last) and the logits processor. It also
    exposes the classmethods that report the KDA state dtype, shape and
    copy functions, and loads checkpoints through ``AutoWeightsLoader``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the inner model, LM head and logits processor.

        Args:
            vllm_config: Engine configuration; the model, HF and quantization
                configs are all read from it.
            prefix: Dotted name prefix used for the ``model`` and ``lm_head``
                submodules.
        """
        super().__init__()
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.config = self.model_config.hf_config
        self.quant_config = vllm_config.quant_config

        self.model = KimiLinearModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        # Only the last pipeline rank materializes a real LM head.
        if not get_pp_group().is_last_rank:
            self.lm_head = PPMissingLayer()
        else:
            self.lm_head = ParallelLMHead(
                self.config.vocab_size,
                self.config.hidden_size,
                quant_config=self.quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )

        # ``logit_scale`` is optional on the HF config; fall back to 1.0.
        self.logits_processor = LogitsProcessor(
            self.config.vocab_size,
            scale=getattr(self.config, "logit_scale", 1.0),
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the embedding lookup to the inner model."""
        inner_model = self.model
        return inner_model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the inner model and return its output unchanged.

        All arguments, including ``**kwargs``, are passed straight through to
        ``self.model``; logits are produced separately by ``compute_logits``.
        """
        return self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
        )

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype, torch.dtype, torch.dtype]:
        """Return the KDA state dtypes for the configured model/cache dtypes."""
        model_dtype = vllm_config.model_config.dtype
        cache_dtype = vllm_config.cache_config.mamba_cache_dtype
        return MambaStateDtypeCalculator.kda_state_dtype(model_dtype, cache_dtype)

    @classmethod
    def get_mamba_state_shape_from_config(
        cls, vllm_config: "VllmConfig"
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        """Return the KDA state shapes derived from ``vllm_config``.

        Reads the tensor-parallel size, the ``linear_attn_config`` entries
        ``num_heads``, ``head_dim`` and ``short_conv_kernel_size``, and the
        number of speculative tokens (0 when speculative decoding is not
        configured).
        """
        spec_config = vllm_config.speculative_config
        num_spec = spec_config.num_speculative_tokens if spec_config else 0

        linear_attn = vllm_config.model_config.hf_config.linear_attn_config
        return MambaStateShapeCalculator.kda_state_shape(
            vllm_config.parallel_config.tensor_parallel_size,
            linear_attn["num_heads"],
            linear_attn["head_dim"],
            conv_kernel_size=linear_attn["short_conv_kernel_size"],
            num_spec=num_spec,
        )

    @classmethod
    def get_mamba_state_copy_func(
        cls,
    ) -> tuple[
        MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc
    ]:
        """Return the copy functions for the KDA state tensors."""
        copy_funcs = MambaStateCopyFuncCalculator.kda_state_copy_func()
        return copy_funcs

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project ``hidden_states`` through the LM head via the logits processor."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        When the config ties word embeddings, checkpoint entries under
        ``lm_head.`` are skipped.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` returns (annotated as
            the set of loaded parameter names).
        """
        skipped = ["lm_head."] if self.config.tie_word_embeddings else None
        return AutoWeightsLoader(self, skip_prefixes=skipped).load_weights(weights)
|
||||
|
||||
|
||||
def get_spec_layer_idx_from_weight_name(
|
||||
|
||||
@@ -52,7 +52,12 @@ from vllm.sequence import IntermediateTensors
|
||||
from vllm.v1.attention.backend import AttentionMetadata
|
||||
|
||||
from .interfaces import HasInnerState, IsHybrid
|
||||
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
PPMissingLayer,
|
||||
is_pp_missing_parameter,
|
||||
make_layers,
|
||||
)
|
||||
|
||||
|
||||
def replace_weight_name(
|
||||
@@ -494,6 +499,8 @@ class MiniMaxText01Model(nn.Module):
|
||||
quant_config = vllm_config.quant_config
|
||||
cache_config = vllm_config.cache_config
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
self.config = config
|
||||
self.CONCAT_FFN = True
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
@@ -620,128 +627,6 @@ class MiniMaxText01Model(nn.Module):
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Apply this model's ``embed_tokens`` module to ``input_ids``.

    The result is returned as-is; no scaling is applied here.
    """
    token_embeddings = self.embed_tokens(input_ids)
    return token_embeddings
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor | IntermediateTensors:
|
||||
forward_context = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is None:
|
||||
hidden_states = self.embed_scale * self.embed_tokens(input_ids)
|
||||
else:
|
||||
hidden_states = inputs_embeds
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for layer in islice(self.layers, self.start_layer, self.end_layer):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states=hidden_states,
|
||||
positions=positions,
|
||||
attn_metadata=attn_metadata,
|
||||
residual=residual,
|
||||
)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors(
|
||||
{"hidden_states": hidden_states, "residual": residual}
|
||||
)
|
||||
if residual is not None:
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
else:
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
|
||||
self.config = config
|
||||
|
||||
if not hasattr(config, "sliding_window"):
|
||||
config.sliding_window = None
|
||||
|
||||
self.CONCAT_FFN = True
|
||||
|
||||
if hasattr(vllm_config.model_config, "max_model_len"):
|
||||
self.config.max_model_len = vllm_config.model_config.max_model_len
|
||||
self.model = MiniMaxText01Model(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
)
|
||||
if get_pp_group().is_last_rank:
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
self.config.hidden_size,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
|
||||
self.logits_processor = LogitsProcessor(
|
||||
config.vocab_size, self.config.vocab_size
|
||||
)
|
||||
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
self.lm_head.float()
|
||||
flash_layer_count = sum(
|
||||
1 for attn_type in self.model.decoder_attention_types if attn_type == 1
|
||||
)
|
||||
self.kv_cache = [torch.tensor([]) for _ in range(flash_layer_count)]
|
||||
return
|
||||
|
||||
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
|
||||
return self.model.minimax_cache.copy_inputs_before_cuda_graphs(
|
||||
input_buffers, **kwargs
|
||||
)
|
||||
|
||||
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
|
||||
return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs(batch_size)
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(
|
||||
input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states.float())
|
||||
|
||||
return logits
|
||||
|
||||
def make_empty_intermediate_tensors(
|
||||
self, batch_size: int, dtype: torch.dtype, device: torch.device
|
||||
) -> IntermediateTensors:
|
||||
return IntermediateTensors(
|
||||
{
|
||||
"hidden_states": torch.zeros(
|
||||
(batch_size, self.config.hidden_size), dtype=dtype, device=device
|
||||
),
|
||||
"residual": torch.zeros(
|
||||
(batch_size, self.config.hidden_size), dtype=dtype, device=device
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
@@ -753,17 +638,15 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
return None
|
||||
|
||||
def is_linear_attn_layer(layer_idx: int) -> bool:
|
||||
if layer_idx is None or layer_idx >= len(
|
||||
self.model.decoder_attention_types
|
||||
):
|
||||
if layer_idx is None or layer_idx >= len(self.decoder_attention_types):
|
||||
return False
|
||||
return self.model.decoder_attention_types[layer_idx] == 0
|
||||
return self.decoder_attention_types[layer_idx] == 0
|
||||
|
||||
def is_moe_weight(name: str) -> bool:
|
||||
return "block_sparse_moe" in name and not name.endswith(".bias")
|
||||
|
||||
def get_expert_id(param_name):
|
||||
pattern = r"model\.layers\.\d+\.block_sparse_moe\.experts\.(\d+)\."
|
||||
pattern = r"layers\.\d+\.block_sparse_moe\.experts\.(\d+)\."
|
||||
match = re.search(pattern, param_name)
|
||||
if match:
|
||||
return match.group(1)
|
||||
@@ -948,9 +831,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
weight_at_layer = which_layer(name)
|
||||
if weight_at_layer and weight_at_layer >= len(
|
||||
self.model.decoder_attention_types
|
||||
):
|
||||
if weight_at_layer and weight_at_layer >= len(self.decoder_attention_types):
|
||||
continue
|
||||
|
||||
if is_layer_norm_weight(name):
|
||||
@@ -975,6 +856,128 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
load_basic_weight(name, loaded_weight, self)
|
||||
return loaded_params
|
||||
|
||||
def forward(
    self,
    input_ids: torch.Tensor | None,
    positions: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: torch.Tensor | None = None,
    **kwargs,
) -> torch.Tensor | IntermediateTensors:
    """Run this pipeline stage's decoder layers.

    On the first pipeline rank the hidden states come from ``inputs_embeds``
    when given, otherwise from ``embed_scale * embed_tokens(input_ids)``.
    Other ranks read ``hidden_states`` and ``residual`` from
    ``intermediate_tensors``. Ranks other than the last return an
    ``IntermediateTensors`` for the next stage; the last rank returns the
    normalized hidden states. ``**kwargs`` is accepted but not used here.
    """
    attn_metadata = get_forward_context().attn_metadata

    if not get_pp_group().is_first_rank:
        # Later stages resume from what the previous stage handed over.
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]
    else:
        residual = None
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.embed_scale * self.embed_tokens(input_ids)

    # Only the layers in [start_layer, end_layer) are run on this rank.
    for decoder_layer in islice(self.layers, self.start_layer, self.end_layer):
        hidden_states, residual = decoder_layer(
            hidden_states=hidden_states,
            positions=positions,
            attn_metadata=attn_metadata,
            residual=residual,
        )

    if not get_pp_group().is_last_rank:
        return IntermediateTensors(
            {"hidden_states": hidden_states, "residual": residual}
        )

    # The norm returns a pair only when it is also given the residual.
    if residual is None:
        return self.norm(hidden_states)
    hidden_states, _ = self.norm(hidden_states, residual)
    return hidden_states
|
||||
|
||||
|
||||
class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
    """Build the inner model, LM head and per-layer cache placeholders.

    Args:
        vllm_config: Engine configuration; the HF config is read from
            ``vllm_config.model_config.hf_config`` and mutated in place
            (``sliding_window`` default, ``max_model_len``).
        prefix: Dotted name prefix used for the ``model`` and ``lm_head``
            submodules.
    """
    super().__init__()
    model_config = vllm_config.model_config
    config = model_config.hf_config
    self.config = config
    self.CONCAT_FFN = True

    # Make sure the attribute exists without overwriting a configured value.
    if not hasattr(config, "sliding_window"):
        config.sliding_window = None

    if hasattr(model_config, "max_model_len"):
        config.max_model_len = model_config.max_model_len

    self.model = MiniMaxText01Model(
        vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
    )

    # The logits processor is only created on the last pipeline rank.
    if not get_pp_group().is_last_rank:
        self.lm_head = PPMissingLayer()
    else:
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(
            config.vocab_size, config.vocab_size
        )

    # compute_logits feeds the head float32 hidden states, so cast it too.
    self.lm_head.float()

    # One empty placeholder tensor per layer whose attention type is 1.
    type_one_layers = [
        attn_type
        for attn_type in self.model.decoder_attention_types
        if attn_type == 1
    ]
    self.kv_cache = [torch.tensor([]) for _ in type_one_layers]
|
||||
|
||||
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
    """Forward to ``self.model.minimax_cache.copy_inputs_before_cuda_graphs``.

    ``input_buffers`` and all keyword arguments are passed through unchanged,
    and the cache's return value is returned as-is.
    """
    cache = self.model.minimax_cache
    return cache.copy_inputs_before_cuda_graphs(input_buffers, **kwargs)
|
||||
|
||||
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
    """Forward to ``self.model.minimax_cache.get_seqlen_agnostic_capture_inputs``.

    Returns the cache's result for ``batch_size`` unchanged.
    """
    cache = self.model.minimax_cache
    return cache.get_seqlen_agnostic_capture_inputs(batch_size)
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Delegate the embedding lookup to the inner model."""
    inner_model = self.model
    return inner_model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
    self,
    input_ids: torch.Tensor | None,
    positions: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: torch.Tensor | None = None,
    **kwargs,
) -> torch.Tensor:
    """Run the inner model and return its output unchanged.

    All arguments, including ``**kwargs``, are passed straight through to
    ``self.model``; logits are produced separately by ``compute_logits``.
    """
    return self.model(
        input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
    )
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Compute logits from ``hidden_states`` cast to float32.

    The cast happens before the logits processor is invoked with the LM head.
    """
    hidden_fp32 = hidden_states.float()
    return self.logits_processor(self.lm_head, hidden_fp32)
|
||||
|
||||
def make_empty_intermediate_tensors(
    self, batch_size: int, dtype: torch.dtype, device: torch.device
) -> IntermediateTensors:
    """Create zero-filled ``hidden_states`` and ``residual`` tensors.

    Both tensors have shape ``(batch_size, self.config.hidden_size)`` and use
    the given ``dtype`` and ``device``.
    """
    shape = (batch_size, self.config.hidden_size)
    return IntermediateTensors(
        {
            key: torch.zeros(shape, dtype=dtype, device=device)
            for key in ("hidden_states", "residual")
        }
    )
|
||||
|
||||
@classmethod
|
||||
def get_mamba_state_dtype_from_config(
|
||||
cls,
|
||||
@@ -1011,3 +1014,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc]:
    """Return the copy function(s) for the linear-attention state."""
    copy_funcs = MambaStateCopyFuncCalculator.linear_attention_state_copy_func()
    return copy_funcs
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load checkpoint weights through ``AutoWeightsLoader``.

    Args:
        weights: Iterable of ``(name, tensor)`` checkpoint entries.

    Returns:
        Whatever ``AutoWeightsLoader.load_weights`` returns (annotated as the
        set of loaded parameter names).
    """
    return AutoWeightsLoader(self).load_weights(weights)
|
||||
|
||||
Reference in New Issue
Block a user