diff --git a/vllm/model_executor/models/kimi_linear.py b/vllm/model_executor/models/kimi_linear.py index e36ff0227..4cd7b63c1 100644 --- a/vllm/model_executor/models/kimi_linear.py +++ b/vllm/model_executor/models/kimi_linear.py @@ -46,6 +46,7 @@ from vllm.transformers_utils.configs.kimi_linear import KimiLinearConfig from .interfaces import HasInnerState, IsHybrid, MixtureOfExperts, SupportsPP from .utils import ( + AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, make_layers, @@ -472,94 +473,7 @@ class KimiLinearModel(nn.Module): hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - -class KimiLinearForCausalLM( - nn.Module, HasInnerState, SupportsPP, MixtureOfExperts, IsHybrid -): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - self.model_config = vllm_config.model_config - self.vllm_config = vllm_config - self.config = self.model_config.hf_config - quant_config = vllm_config.quant_config - self.quant_config = quant_config - self.model = KimiLinearModel( - vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") - ) - if get_pp_group().is_last_rank: - self.lm_head = ParallelLMHead( - self.config.vocab_size, - self.config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix(prefix, "lm_head"), - ) - else: - self.lm_head = PPMissingLayer() - logit_scale = getattr(self.config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor( - self.config.vocab_size, scale=logit_scale - ) - - def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.embed_input_ids(input_ids) - - def forward( - self, - input_ids: torch.Tensor | None, - positions: torch.Tensor, - intermediate_tensors: IntermediateTensors | None = None, - inputs_embeds: torch.Tensor | None = None, - **kwargs, - ) -> torch.Tensor | IntermediateTensors: - hidden_states = self.model( - input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs - ) - return hidden_states - - @classmethod - def get_mamba_state_dtype_from_config( - cls, - vllm_config: "VllmConfig", - ) -> tuple[torch.dtype, torch.dtype, torch.dtype, torch.dtype]: - return MambaStateDtypeCalculator.kda_state_dtype( - vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype - ) - - @classmethod - def get_mamba_state_shape_from_config( - cls, vllm_config: "VllmConfig" - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...]]: - parallel_config = vllm_config.parallel_config - hf_config = vllm_config.model_config.hf_config - tp_size = parallel_config.tensor_parallel_size - num_spec = ( - vllm_config.speculative_config.num_speculative_tokens - if vllm_config.speculative_config - else 0 - ) - return MambaStateShapeCalculator.kda_state_shape( - tp_size, - hf_config.linear_attn_config["num_heads"], - hf_config.linear_attn_config["head_dim"], - conv_kernel_size=hf_config.linear_attn_config["short_conv_kernel_size"], - num_spec=num_spec, - ) - - @classmethod - def get_mamba_state_copy_func( - cls, - ) -> tuple[ - MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc - ]: - return MambaStateCopyFuncCalculator.kda_state_copy_func() - - def compute_logits( - self, - hidden_states: torch.Tensor, - ) -> torch.Tensor | None: - return self.logits_processor(self.lm_head, hidden_states) - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".gate_up_proj", ".gate_proj", 0), @@ -653,6 +567,101 @@ class KimiLinearForCausalLM( ) weight_loader(param, loaded_weight, **kwargs) loaded_params.add(name) + return loaded_params + + +class KimiLinearForCausalLM( + nn.Module, HasInnerState, SupportsPP, MixtureOfExperts, IsHybrid +): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.model_config = vllm_config.model_config + self.vllm_config = vllm_config + self.config = self.model_config.hf_config + quant_config = vllm_config.quant_config + self.quant_config = quant_config + self.model = KimiLinearModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + self.config.vocab_size, + self.config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + else: + self.lm_head = PPMissingLayer() + logit_scale = getattr(self.config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor( + self.config.vocab_size, scale=logit_scale + ) + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor | IntermediateTensors: + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs + ) + return hidden_states + + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: "VllmConfig", + ) -> tuple[torch.dtype, torch.dtype, torch.dtype, torch.dtype]: + return MambaStateDtypeCalculator.kda_state_dtype( + vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype + ) + + @classmethod + def get_mamba_state_shape_from_config( + cls, vllm_config: "VllmConfig" + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...]]: + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + tp_size = parallel_config.tensor_parallel_size + num_spec = ( + vllm_config.speculative_config.num_speculative_tokens + if vllm_config.speculative_config + else 0 + ) + return MambaStateShapeCalculator.kda_state_shape( + tp_size, + hf_config.linear_attn_config["num_heads"], + hf_config.linear_attn_config["head_dim"], + conv_kernel_size=hf_config.linear_attn_config["short_conv_kernel_size"], + num_spec=num_spec, + ) + + @classmethod + def get_mamba_state_copy_func( + cls, + ) -> tuple[ + MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc + ]: + return MambaStateCopyFuncCalculator.kda_state_copy_func() + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + return self.logits_processor(self.lm_head, hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) def get_spec_layer_idx_from_weight_name( diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 80c0342cc..21d74d8b0 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -52,7 +52,12 @@ from vllm.sequence import IntermediateTensors from vllm.v1.attention.backend import AttentionMetadata from .interfaces import HasInnerState, IsHybrid -from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + is_pp_missing_parameter, + make_layers, +) def replace_weight_name( @@ -494,6 +499,8 @@ class MiniMaxText01Model(nn.Module): quant_config = vllm_config.quant_config cache_config = vllm_config.cache_config scheduler_config = vllm_config.scheduler_config + self.config = config + self.CONCAT_FFN = True self.vocab_size = config.vocab_size @@ -620,128 +627,6 @@ class MiniMaxText01Model(nn.Module): def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) - def forward( - self, - input_ids: torch.Tensor | None, - positions: torch.Tensor, - intermediate_tensors: IntermediateTensors | None = None, - inputs_embeds: torch.Tensor | None = None, - **kwargs, - ) -> torch.Tensor | IntermediateTensors: - forward_context = get_forward_context() - attn_metadata = forward_context.attn_metadata - - if get_pp_group().is_first_rank: - if inputs_embeds is None: - hidden_states = self.embed_scale * self.embed_tokens(input_ids) - else: - hidden_states = inputs_embeds - residual = None - else: - assert intermediate_tensors is not None - hidden_states = intermediate_tensors["hidden_states"] - residual = intermediate_tensors["residual"] - - for layer in islice(self.layers, self.start_layer, self.end_layer): - hidden_states, residual = layer( - hidden_states=hidden_states, - positions=positions, - attn_metadata=attn_metadata, - residual=residual, - ) - if not get_pp_group().is_last_rank: - return IntermediateTensors( - {"hidden_states": hidden_states, "residual": residual} - ) - if residual is not None: - hidden_states, _ = self.norm(hidden_states, residual) - else: - hidden_states = self.norm(hidden_states) - - return hidden_states - - -class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: - super().__init__() - config = vllm_config.model_config.hf_config - - self.config = config - - if not hasattr(config, "sliding_window"): - config.sliding_window = None - - self.CONCAT_FFN = True - - if hasattr(vllm_config.model_config, "max_model_len"): - self.config.max_model_len = vllm_config.model_config.max_model_len - self.model = MiniMaxText01Model( - vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") - ) - if get_pp_group().is_last_rank: - self.lm_head = ParallelLMHead( - config.vocab_size, - self.config.hidden_size, - prefix=maybe_prefix(prefix, "lm_head"), - ) - - self.logits_processor = LogitsProcessor( - config.vocab_size, self.config.vocab_size - ) - - else: - self.lm_head = PPMissingLayer() - self.lm_head.float() - flash_layer_count = sum( - 1 for attn_type in self.model.decoder_attention_types if attn_type == 1 - ) - self.kv_cache = [torch.tensor([]) for _ in range(flash_layer_count)] - return - - def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): - return self.model.minimax_cache.copy_inputs_before_cuda_graphs( - input_buffers, **kwargs - ) - - def get_seqlen_agnostic_capture_inputs(self, batch_size: int): - return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs(batch_size) - - def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.embed_input_ids(input_ids) - - def forward( - self, - input_ids: torch.Tensor | None, - positions: torch.Tensor, - intermediate_tensors: IntermediateTensors | None = None, - inputs_embeds: torch.Tensor | None = None, - **kwargs, - ) -> torch.Tensor: - hidden_states = self.model( - input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs - ) - - return hidden_states - - def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states.float()) - - return logits - - def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, device: torch.device - ) -> IntermediateTensors: - return IntermediateTensors( - { - "hidden_states": torch.zeros( - (batch_size, self.config.hidden_size), dtype=dtype, device=device - ), - "residual": torch.zeros( - (batch_size, self.config.hidden_size), dtype=dtype, device=device - ), - } - ) - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() @@ -753,17 +638,15 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): return None def is_linear_attn_layer(layer_idx: int) -> bool: - if layer_idx is None or layer_idx >= len( - self.model.decoder_attention_types - ): + if layer_idx is None or layer_idx >= len(self.decoder_attention_types): return False - return self.model.decoder_attention_types[layer_idx] == 0 + return self.decoder_attention_types[layer_idx] == 0 def is_moe_weight(name: str) -> bool: return "block_sparse_moe" in name and not name.endswith(".bias") def get_expert_id(param_name): - pattern = r"model\.layers\.\d+\.block_sparse_moe\.experts\.(\d+)\." + pattern = r"layers\.\d+\.block_sparse_moe\.experts\.(\d+)\." match = re.search(pattern, param_name) if match: return match.group(1) @@ -948,9 +831,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): for name, loaded_weight in weights: weight_at_layer = which_layer(name) - if weight_at_layer and weight_at_layer >= len( - self.model.decoder_attention_types - ): + if weight_at_layer and weight_at_layer >= len(self.decoder_attention_types): continue if is_layer_norm_weight(name): @@ -975,6 +856,128 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): load_basic_weight(name, loaded_weight, self) return loaded_params + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor | IntermediateTensors: + forward_context = get_forward_context() + attn_metadata = forward_context.attn_metadata + + if get_pp_group().is_first_rank: + if inputs_embeds is None: + hidden_states = self.embed_scale * self.embed_tokens(input_ids) + else: + hidden_states = inputs_embeds + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states, residual = layer( + hidden_states=hidden_states, + positions=positions, + attn_metadata=attn_metadata, + residual=residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + if residual is not None: + hidden_states, _ = self.norm(hidden_states, residual) + else: + hidden_states = self.norm(hidden_states) + + return hidden_states + + +class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__() + config = vllm_config.model_config.hf_config + + self.config = config + + if not hasattr(config, "sliding_window"): + config.sliding_window = None + + self.CONCAT_FFN = True + + if hasattr(vllm_config.model_config, "max_model_len"): + self.config.max_model_len = vllm_config.model_config.max_model_len + self.model = MiniMaxText01Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + config.vocab_size, + self.config.hidden_size, + prefix=maybe_prefix(prefix, "lm_head"), + ) + + self.logits_processor = LogitsProcessor( + config.vocab_size, self.config.vocab_size + ) + + else: + self.lm_head = PPMissingLayer() + self.lm_head.float() + flash_layer_count = sum( + 1 for attn_type in self.model.decoder_attention_types if attn_type == 1 + ) + self.kv_cache = [torch.tensor([]) for _ in range(flash_layer_count)] + return + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.model.minimax_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs + ) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs(batch_size) + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs + ) + + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states.float()) + + return logits + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, device: torch.device + ) -> IntermediateTensors: + return IntermediateTensors( + { + "hidden_states": torch.zeros( + (batch_size, self.config.hidden_size), dtype=dtype, device=device + ), + "residual": torch.zeros( + (batch_size, self.config.hidden_size), dtype=dtype, device=device + ), + } + ) + @classmethod def get_mamba_state_dtype_from_config( cls, @@ -1011,3 +1014,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid): @classmethod def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc]: return MambaStateCopyFuncCalculator.linear_attention_state_copy_func() + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights)