[Model] support MiniMax-VL-01 model (#16328)
Signed-off-by: qingjun <qingjun@minimaxi.com>
@@ -3,7 +3,7 @@
 import copy
 import math
 import re
-from typing import Dict, Iterable, List, Optional, Tuple, Union
+from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 import torch.distributed
@@ -110,7 +110,17 @@ class MiniMaxText01RMSNormTP(CustomOp):
             variance = tensor_model_parallel_all_reduce(
                 variance) / self.tp_world
         x = x * torch.rsqrt(variance + self.variance_epsilon)
-        x = x.to(orig_dtype) * self.weight
+
+        weight = self.weight
+        if x.size(-1) != self.weight.size(0):
+            if self.weight.size(0) < x.size(-1):
+                repeat_count = (x.size(-1) + self.weight.size(0)) // x.size(-1)
+                full_weight = self.weight.repeat(repeat_count)
+                weight = full_weight[:x.size(-1)]
+            else:
+                weight = self.weight[:x.size(-1)]
+
+        x = x.to(orig_dtype) * weight
         return x
 
     def forward(
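The new guard above tiles or truncates the RMSNorm weight when its length no longer matches the activation's last dimension (as can happen when the norm weight is sharded for tensor parallelism but the activation is not). Note the committed `repeat_count` formula divides by `x.size(-1)`; a general-purpose version would use ceiling division by the weight length instead, as in this standalone sketch (the helper name is hypothetical, not part of the commit):

    import torch

    def match_norm_weight(weight: torch.Tensor, hidden_size: int) -> torch.Tensor:
        # Hypothetical helper mirroring the guard above: tile or slice a norm
        # weight so it broadcasts against an activation whose last dimension
        # differs from the weight length. Ceiling division guarantees the
        # tiled weight covers hidden_size before truncation.
        if weight.size(0) == hidden_size:
            return weight
        if weight.size(0) < hidden_size:
            repeat_count = (hidden_size + weight.size(0) - 1) // weight.size(0)
            return weight.repeat(repeat_count)[:hidden_size]
        return weight[:hidden_size]

    x = torch.randn(2, 8)
    print(x * match_norm_weight(torch.ones(4), x.size(-1)))  # broadcasts cleanly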
@@ -421,6 +431,10 @@ class MiniMaxText01LinearAttention(nn.Module):
                       attn_metadata):
         hidden = []
         for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
+            if _prefill_idx >= len(attn_metadata.query_start_loc):
+                break
+            if _prefill_idx >= len(state_indices_tensor):
+                break
             _start = attn_metadata.query_start_loc[_prefill_idx]
             _end = attn_metadata.query_start_loc[_prefill_idx + 1]
             slot_id = state_indices_tensor[_prefill_idx]
@@ -443,6 +457,10 @@ class MiniMaxText01LinearAttention(nn.Module):
             hidden.append(
                 self._decode_infer(q, k, v, kv_cache, state_indices_tensor,
                                    attn_metadata))
+
+        if not hidden:
+            return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
+
         hidden = torch.concat(hidden, dim=0).contiguous()
         return hidden
 
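The two hunks above harden the prefill path: the loop stops early if `query_start_loc` or `state_indices_tensor` is shorter than `num_prefills`, and an empty `(0, head_dim)` tensor is returned rather than calling `torch.concat` on an empty list (which raises). A minimal self-contained sketch of the same pattern, with hypothetical names:

    import torch

    def gather_prefill_chunks(query_start_loc: torch.Tensor,
                              state_indices: torch.Tensor,
                              num_prefills: int,
                              q: torch.Tensor) -> torch.Tensor:
        # Simplified version of the guarded loop above (names hypothetical).
        hidden = []
        for idx in range(num_prefills):
            # Stop early instead of indexing past the metadata tensors.
            if idx + 1 >= len(query_start_loc) or idx >= len(state_indices):
                break
            start, end = query_start_loc[idx], query_start_loc[idx + 1]
            hidden.append(q[start:end])
        if not hidden:
            # torch.concat([]) raises, so return an empty result explicitly.
            return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
        return torch.concat(hidden, dim=0).contiguous()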
@@ -663,6 +681,9 @@ class MiniMaxText01DecoderLayer(nn.Module):
         self.shared_moe = False
 
         shared_intermediate = getattr(config, 'shared_intermediate_size', 0)
+        if isinstance(shared_intermediate, list):
+            shared_intermediate = shared_intermediate[
+                layer_id] if layer_id < len(shared_intermediate) else 0
         if shared_intermediate > 0:
             self.shared_moe = True
             self.shared_mlp = MiniMaxText01MLP(
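With this change `shared_intermediate_size` may be either a single int or a per-layer list in the HF config; layers past the end of a list fall back to 0, i.e. no shared MLP for that layer. The lookup in isolation (helper name hypothetical):

    # Standalone illustration of the per-layer lookup added above: a scalar
    # applies to every layer, a list is indexed by layer_id, and
    # out-of-range layers get 0 (no shared MLP).
    def resolve_shared_intermediate(value, layer_id: int) -> int:
        if isinstance(value, list):
            return value[layer_id] if layer_id < len(value) else 0
        return value

    assert resolve_shared_intermediate(2048, 3) == 2048
    assert resolve_shared_intermediate([0, 4096], 1) == 4096
    assert resolve_shared_intermediate([0, 4096], 5) == 0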
@@ -875,6 +896,8 @@ class MiniMaxText01Model(nn.Module):
 
         slots_to_clear = []
         for _prefill_id in range(getattr(attn_metadata, "num_prefills", 0)):
+            if _prefill_id >= len(seq_id_map):
+                break
             seq_id = seq_id_map[_prefill_id]
             if attn_metadata.context_lens_tensor[
                     _prefill_id] == 0 and seq_id in seq_to_slot_maps:
@@ -886,13 +909,18 @@ class MiniMaxText01Model(nn.Module):
                                         dtype=torch.long)
             minimax_cache_tensors[:, slots_tensor, ...] = 0
 
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
     def forward(self,
                 input_ids: Optional[torch.Tensor],
                 positions: torch.Tensor,
-                kv_caches: List[torch.Tensor],
-                intermediate_tensors=None,
-                **kwargs) -> torch.Tensor:
+                intermediate_tensors: Optional[IntermediateTensors] = None,
+                inputs_embeds: Optional[torch.Tensor] = None,
+                **kwargs) -> Union[torch.Tensor, IntermediateTensors]:
         forward_context = get_forward_context()
         attn_metadata = forward_context.attn_metadata
         if attn_metadata is None:
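The new `get_input_embeddings` hook and `inputs_embeds` parameter are what let a multimodal wrapper (here, MiniMax-VL-01) embed text tokens itself, merge in vision features, and hand the merged embeddings to the language model. A minimal sketch of that contract, not the vLLM classes themselves:

    from typing import Optional

    import torch
    import torch.nn as nn

    class TinyLM(nn.Module):
        # Toy stand-in for the language model: callers may supply
        # precomputed inputs_embeds (e.g. text embeddings merged with
        # vision features); otherwise input_ids are embedded here.
        def __init__(self, vocab_size: int, hidden_size: int):
            super().__init__()
            self.embed_tokens = nn.Embedding(vocab_size, hidden_size)

        def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
            return self.embed_tokens(input_ids)

        def forward(self,
                    input_ids: Optional[torch.Tensor],
                    inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings(input_ids)
            return inputs_embeds  # downstream layers would consume this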
@@ -901,6 +929,7 @@ class MiniMaxText01Model(nn.Module):
             kwargs["request_ids_to_seq_ids"] = {}
         if "finished_requests_ids" not in kwargs:
             kwargs["finished_requests_ids"] = []
+
         (
             minimax_cache_tensors,
             state_indices_tensor,
@@ -922,15 +951,11 @@ class MiniMaxText01Model(nn.Module):
             hidden_states = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]
 
-        kv_cache_index = 0
         minimax_cache_index = 0
         attn_metadata.rotary_emb = self.rotary_emb
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             _caches = None
-            if isinstance(layer.self_attn, MiniMaxText01Attention):
-                _caches = kv_caches[kv_cache_index]
-                kv_cache_index += 1
             if isinstance(layer.self_attn, MiniMaxText01LinearAttention):
                 current_state_layer = minimax_cache_index
                 _caches = minimax_cache_params.at_layer_idx(
@@ -1009,15 +1034,20 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
         return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs(
             batch_size)
 
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(self,
                 input_ids: torch.Tensor,
                 positions: torch.Tensor,
                 intermediate_tensors: Optional[IntermediateTensors] = None,
                 inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, self.kv_cache,
-                                   intermediate_tensors, inputs_embeds,
-                                   **kwargs)
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds, **kwargs)
 
         return hidden_states
 
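Both forwards also stop threading an explicit KV-cache list (`self.kv_cache`, `kv_caches`) through the call chain; attention layers now obtain their metadata and caches from the ambient forward context (`get_forward_context()` above). A rough sketch of that pattern with a hypothetical registry, not vLLM's actual implementation:

    from contextlib import contextmanager

    # Hypothetical stand-in for vLLM's forward context: per-step state is
    # registered once and read from inside layers, so forward() signatures
    # no longer need a kv_caches argument.
    _FORWARD_CTX: dict = {}

    @contextmanager
    def set_forward_context(**ctx):
        _FORWARD_CTX.update(ctx)
        try:
            yield
        finally:
            _FORWARD_CTX.clear()

    def get_forward_context() -> dict:
        return _FORWARD_CTX

    with set_forward_context(attn_metadata=None, kv_caches=[]):
        assert "kv_caches" in get_forward_context()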
@@ -1043,8 +1073,9 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
         })
 
     def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> None:
+                                                   torch.Tensor]]) -> Set[str]:
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
 
         def which_layer(name: str) -> int:
             if "layers" in name:
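`load_weights` now returns the set of parameter names it actually consumed; every loader branch below adds to `loaded_params` just before returning or breaking. The pattern in isolation (a simplified sketch, not the real loader):

    from typing import Iterable, Set, Tuple

    import torch
    import torch.nn as nn

    def load_weights_sketch(module: nn.Module,
                            weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        # Record every parameter name that is actually written so the caller
        # can detect checkpoint/model mismatches afterwards.
        params_dict = dict(module.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            param = params_dict[name]  # simplified: assume names line up
            param.data.copy_(loaded_weight)
            loaded_params.add(name)
        return loaded_params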
@@ -1108,6 +1139,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
                               weight_name,
                               expert_id=expert_id,
                               shard_id=shard_id)
+                loaded_params.add(name)
                 break
             else:
                 if is_pp_missing_parameter(name, self):
@@ -1117,6 +1149,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
                                             default_weight_loader)
                 weight_loader = weight_loader_with_alias(name)(weight_loader)
                 weight_loader(param, loaded_weight)
+                loaded_params.add(name)
             return
 
         def is_shared_mlp_weight(name: str) -> bool:
@@ -1154,6 +1187,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
             else:
                 raise AssertionError(
                     "MLP weight not in [gate_up_proj, down_proj]")
+            loaded_params.add(name)
             return
 
         def is_mha_weight(name: str) -> bool:
@@ -1170,6 +1204,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
                                 MiniMaxText01LinearAttention.weight_direct_load)
             weight_loader = weight_loader_with_alias(name)(weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
             return
 
         def load_flash_attn_weight(name: str, loaded_weight: torch.Tensor,
@@ -1194,6 +1229,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
                                             default_weight_loader)
                 weight_loader = weight_loader_with_alias(name)(weight_loader)
                 weight_loader(param, loaded_weight, shard_id)
+                loaded_params.add(name)
                 break
             else:
                 if is_pp_missing_parameter(name, self):
@@ -1204,6 +1240,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
                                             default_weight_loader)
                 weight_loader = weight_loader_with_alias(name)(weight_loader)
                 weight_loader(param, loaded_weight)
+                loaded_params.add(name)
             return
 
         def is_layer_norm_weight(name: str) -> bool:
@@ -1219,6 +1256,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
                                 default_weight_loader)
             weight_loader = weight_loader_with_alias(name)(weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
             return
 
         def load_basic_weight(name: str, loaded_weight: torch.Tensor,
@@ -1230,6 +1268,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
                                 default_weight_loader)
             weight_loader = weight_loader_with_alias(name)(weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
             return
 
         for name, loaded_weight in weights:
@@ -1258,4 +1297,4 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
                 continue
 
             load_basic_weight(name, loaded_weight, self)
-        return
+        return loaded_params
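Returning `loaded_params` lets the caller verify that no parameter was silently left at its random initialization, e.g. (hypothetical caller-side check):

    # Hypothetical caller-side check enabled by the new return value.
    loaded = model.load_weights(weights)
    missing = set(dict(model.named_parameters())) - loaded
    if missing:
        raise RuntimeError(f"weights never initialized: {sorted(missing)}")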