Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Inference-only MiniMaxText01 model."""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from itertools import islice
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
@@ -18,25 +19,33 @@ from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
get_pp_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.mamba.linear_attn import (
|
||||
MiniMaxText01LinearAttention)
|
||||
from vllm.model_executor.layers.mamba.linear_attn import MiniMaxText01LinearAttention
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
||||
MambaStateDtypeCalculator,
|
||||
MambaStateShapeCalculator,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
DEFAULT_VOCAB_PADDING_SIZE,
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.utils import maybe_prefix
|
||||
from vllm.sequence import IntermediateTensors
|
||||
@@ -45,25 +54,22 @@ from .interfaces import HasInnerState, IsHybrid
|
||||
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
|
||||
|
||||
|
||||
def replace_weight_name(name: str,
|
||||
key: str = None,
|
||||
to: str = None,
|
||||
count: int = None,
|
||||
prefix: str = None) -> str:
|
||||
name = name.replace(key, to) if count is None else \
|
||||
name.replace(key, to, count)
|
||||
def replace_weight_name(
|
||||
name: str, key: str = None, to: str = None, count: int = None, prefix: str = None
|
||||
) -> str:
|
||||
name = name.replace(key, to) if count is None else name.replace(key, to, count)
|
||||
return name
|
||||
|
||||
|
||||
def weight_loader_with_alias(alias: str):
|
||||
|
||||
def wrapper(func: callable):
|
||||
|
||||
def inner_func(param: torch.Tensor,
|
||||
loaded_weight: torch.Tensor,
|
||||
*args,
|
||||
prefix: str = None,
|
||||
**kwargs):
|
||||
def inner_func(
|
||||
param: torch.Tensor,
|
||||
loaded_weight: torch.Tensor,
|
||||
*args,
|
||||
prefix: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
value = func(param, loaded_weight, *args, **kwargs)
|
||||
return value
|
||||
|
||||
@@ -73,7 +79,6 @@ def weight_loader_with_alias(alias: str):
|
||||
|
||||
|
||||
class MiniMaxText01MLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
@@ -103,7 +108,6 @@ class MiniMaxText01MLP(nn.Module):
|
||||
return
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.down_proj(x)
|
||||
@@ -111,7 +115,6 @@ class MiniMaxText01MLP(nn.Module):
|
||||
|
||||
|
||||
class MiniMaxText01MoE(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_experts: int,
|
||||
@@ -162,8 +165,7 @@ class MiniMaxText01MoE(nn.Module):
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def gate_weight_loader(param: nn.Parameter,
|
||||
loaded_weight: torch.Tensor) -> None:
|
||||
def gate_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor) -> None:
|
||||
assert param.size() == loaded_weight.size()
|
||||
param.data.copy_(loaded_weight.to(torch.float32))
|
||||
return
|
||||
@@ -173,13 +175,13 @@ class MiniMaxText01MoE(nn.Module):
|
||||
hidden_states = hidden_states.view(-1, self.hidden_size)
|
||||
router_logits_fp32, _ = self.gate(hidden_states.to(torch.float32))
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states, router_logits_fp32.to(hidden_states.dtype))
|
||||
hidden_states, router_logits_fp32.to(hidden_states.dtype)
|
||||
)
|
||||
final_hidden = final_hidden_states.view(num_tokens, hidden_size)
|
||||
return final_hidden
|
||||
|
||||
|
||||
class MiniMaxText01Attention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
@@ -253,8 +255,13 @@ class MiniMaxText01Attention(nn.Module):
|
||||
)
|
||||
return
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
|
||||
positions: torch.Tensor, **kwargs) -> None:
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
@@ -263,7 +270,6 @@ class MiniMaxText01Attention(nn.Module):
|
||||
|
||||
|
||||
class MiniMaxText01DecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MiniMaxConfig,
|
||||
@@ -288,14 +294,17 @@ class MiniMaxText01DecoderLayer(nn.Module):
|
||||
head_dim = getattr(config, "head_dim", None)
|
||||
if head_dim is None:
|
||||
head_dim = config.hidden_size // config.num_attention_heads
|
||||
if hasattr(config, "max_model_len") and isinstance(
|
||||
config.max_model_len, int):
|
||||
max_position_embeddings = min(config.max_position_embeddings,
|
||||
config.max_model_len)
|
||||
if hasattr(config, "max_model_len") and isinstance(config.max_model_len, int):
|
||||
max_position_embeddings = min(
|
||||
config.max_position_embeddings, config.max_model_len
|
||||
)
|
||||
if config.attention_type == 0:
|
||||
use_headxdim = True
|
||||
hidden_inner = (head_dim * config.num_attention_heads
|
||||
if use_headxdim else config.hidden_size)
|
||||
hidden_inner = (
|
||||
head_dim * config.num_attention_heads
|
||||
if use_headxdim
|
||||
else config.hidden_size
|
||||
)
|
||||
self.self_attn = MiniMaxText01LinearAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
hidden_inner_size=hidden_inner,
|
||||
@@ -309,14 +318,16 @@ class MiniMaxText01DecoderLayer(nn.Module):
|
||||
quant_config=quant_config,
|
||||
layer_idx=self._ilayer,
|
||||
linear_layer_idx=linear_layer_id,
|
||||
prefix=prefix)
|
||||
prefix=prefix,
|
||||
)
|
||||
elif config.attention_type == 1:
|
||||
self.self_attn = MiniMaxText01Attention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
head_dim=head_dim,
|
||||
rotary_dim=config.rotary_dim
|
||||
if hasattr(config, "rotary_dim") else head_dim,
|
||||
if hasattr(config, "rotary_dim")
|
||||
else head_dim,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
max_position=max_position_embeddings,
|
||||
rope_theta=rope_theta,
|
||||
@@ -324,10 +335,12 @@ class MiniMaxText01DecoderLayer(nn.Module):
|
||||
quant_config=quant_config,
|
||||
layer_idx=self._ilayer,
|
||||
cache_config=cache_config,
|
||||
prefix=prefix)
|
||||
prefix=prefix,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported attention type: {self.config.attention_type}")
|
||||
f"Unsupported attention type: {self.config.attention_type}"
|
||||
)
|
||||
|
||||
if expert_num == 1:
|
||||
self.mlp = MiniMaxText01MLP(
|
||||
@@ -335,7 +348,8 @@ class MiniMaxText01DecoderLayer(nn.Module):
|
||||
intermediate_size=config.intermediate_size,
|
||||
quant_config=quant_config,
|
||||
layer_idx=self._ilayer,
|
||||
prefix=prefix)
|
||||
prefix=prefix,
|
||||
)
|
||||
else:
|
||||
self.block_sparse_moe = MiniMaxText01MoE(
|
||||
num_experts=expert_num,
|
||||
@@ -344,39 +358,51 @@ class MiniMaxText01DecoderLayer(nn.Module):
|
||||
intermediate_size=config.intermediate_size,
|
||||
layer_idx=self._ilayer,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
prefix=prefix,
|
||||
)
|
||||
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
if config.attention_type == 0:
|
||||
self.layernorm_attention_alpha = getattr(
|
||||
config, 'layernorm_linear_attention_alpha',
|
||||
getattr(config, 'linear_attn_alpha_factor', 1))
|
||||
config,
|
||||
"layernorm_linear_attention_alpha",
|
||||
getattr(config, "linear_attn_alpha_factor", 1),
|
||||
)
|
||||
self.layernorm_attention_beta = getattr(
|
||||
config, 'layernorm_linear_attention_beta',
|
||||
getattr(config, 'linear_attn_beta_factor', 1))
|
||||
config,
|
||||
"layernorm_linear_attention_beta",
|
||||
getattr(config, "linear_attn_beta_factor", 1),
|
||||
)
|
||||
else:
|
||||
self.layernorm_attention_alpha = getattr(
|
||||
config, 'layernorm_full_attention_alpha',
|
||||
getattr(config, 'full_attn_alpha_factor', 1))
|
||||
config,
|
||||
"layernorm_full_attention_alpha",
|
||||
getattr(config, "full_attn_alpha_factor", 1),
|
||||
)
|
||||
self.layernorm_attention_beta = getattr(
|
||||
config, 'layernorm_full_attention_beta',
|
||||
getattr(config, 'full_attn_beta_factor', 1))
|
||||
config,
|
||||
"layernorm_full_attention_beta",
|
||||
getattr(config, "full_attn_beta_factor", 1),
|
||||
)
|
||||
self.layernorm_mlp_alpha = getattr(
|
||||
config, 'layernorm_mlp_alpha',
|
||||
getattr(config, 'mlp_alpha_factor', 1))
|
||||
config, "layernorm_mlp_alpha", getattr(config, "mlp_alpha_factor", 1)
|
||||
)
|
||||
self.layernorm_mlp_beta = getattr(
|
||||
config, 'layernorm_mlp_beta', getattr(config, 'mlp_beta_factor',
|
||||
1))
|
||||
self.postnorm = getattr(config, 'postnorm', False)
|
||||
config, "layernorm_mlp_beta", getattr(config, "mlp_beta_factor", 1)
|
||||
)
|
||||
self.postnorm = getattr(config, "postnorm", False)
|
||||
self.shared_moe = False
|
||||
|
||||
shared_intermediate = getattr(config, 'shared_intermediate_size', 0)
|
||||
shared_intermediate = getattr(config, "shared_intermediate_size", 0)
|
||||
if isinstance(shared_intermediate, list):
|
||||
shared_intermediate = shared_intermediate[
|
||||
layer_id] if layer_id < len(shared_intermediate) else 0
|
||||
shared_intermediate = (
|
||||
shared_intermediate[layer_id]
|
||||
if layer_id < len(shared_intermediate)
|
||||
else 0
|
||||
)
|
||||
if shared_intermediate > 0:
|
||||
self.shared_moe = True
|
||||
self.shared_mlp = MiniMaxText01MLP(
|
||||
@@ -384,7 +410,8 @@ class MiniMaxText01DecoderLayer(nn.Module):
|
||||
intermediate_size=shared_intermediate,
|
||||
quant_config=quant_config,
|
||||
layer_idx=self._ilayer,
|
||||
prefix=prefix)
|
||||
prefix=prefix,
|
||||
)
|
||||
self.coefficient = ReplicatedLinear(
|
||||
self.hidden_size,
|
||||
1,
|
||||
@@ -392,20 +419,19 @@ class MiniMaxText01DecoderLayer(nn.Module):
|
||||
quant_config=quant_config,
|
||||
params_dtype=torch.float32,
|
||||
)
|
||||
self.coefficient.weight.weight_loader = (
|
||||
self.shared_moe_coefficient_loader)
|
||||
self.shared_moe_mode = getattr(config, 'shared_moe_mode',
|
||||
'softmax')
|
||||
self.coefficient.weight.weight_loader = self.shared_moe_coefficient_loader
|
||||
self.shared_moe_mode = getattr(config, "shared_moe_mode", "softmax")
|
||||
return
|
||||
|
||||
def forward(self,
|
||||
hidden_states: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
is_warmup: bool = False,
|
||||
**kwargs) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
is_warmup: bool = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
layernorm_input = hidden_states
|
||||
layernorm_output = self.input_layernorm(layernorm_input)
|
||||
residual = layernorm_output if self.postnorm else layernorm_input
|
||||
@@ -417,8 +443,7 @@ class MiniMaxText01DecoderLayer(nn.Module):
|
||||
)
|
||||
|
||||
residual = residual * self.layernorm_attention_alpha
|
||||
self_attention_output = (self_attention_output *
|
||||
self.layernorm_attention_beta)
|
||||
self_attention_output = self_attention_output * self.layernorm_attention_beta
|
||||
|
||||
layernorm_input = residual + self_attention_output
|
||||
layernorm_output = self.post_attention_layernorm(layernorm_input)
|
||||
@@ -432,19 +457,16 @@ class MiniMaxText01DecoderLayer(nn.Module):
|
||||
if self.shared_moe:
|
||||
before_moe_dtype = layernorm_output.dtype
|
||||
moe_hidden_fp32 = moe_hidden_states.to(torch.float32)
|
||||
output_mlp = self.shared_mlp(layernorm_output).to(
|
||||
torch.float32)
|
||||
output_mlp = self.shared_mlp(layernorm_output).to(torch.float32)
|
||||
|
||||
coef, _ = self.coefficient(layernorm_output.to(torch.float32))
|
||||
|
||||
if self.shared_moe_mode == 'softmax':
|
||||
if self.shared_moe_mode == "softmax":
|
||||
coef = torch.nn.functional.softmax(coef, dim=-1)
|
||||
hidden_states = moe_hidden_fp32 * (
|
||||
1 - coef) + output_mlp * coef
|
||||
elif self.shared_moe_mode == 'sigmoid':
|
||||
hidden_states = moe_hidden_fp32 * (1 - coef) + output_mlp * coef
|
||||
elif self.shared_moe_mode == "sigmoid":
|
||||
coef = torch.nn.functional.sigmoid(coef)
|
||||
hidden_states = moe_hidden_fp32 * (
|
||||
1 - coef) + output_mlp * coef
|
||||
hidden_states = moe_hidden_fp32 * (1 - coef) + output_mlp * coef
|
||||
|
||||
hidden_states = hidden_states.to(before_moe_dtype)
|
||||
else:
|
||||
@@ -458,8 +480,9 @@ class MiniMaxText01DecoderLayer(nn.Module):
|
||||
return hidden_states, None
|
||||
|
||||
@staticmethod
|
||||
def shared_moe_coefficient_loader(param: torch.Tensor,
|
||||
loaded_weight: torch.Tensor) -> None:
|
||||
def shared_moe_coefficient_loader(
|
||||
param: torch.Tensor, loaded_weight: torch.Tensor
|
||||
) -> None:
|
||||
assert param.size() == loaded_weight.size()
|
||||
|
||||
param.data.copy_(loaded_weight.to(torch.float32))
|
||||
@@ -468,7 +491,6 @@ class MiniMaxText01DecoderLayer(nn.Module):
|
||||
|
||||
@support_torch_compile
|
||||
class MiniMaxText01Model(nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config: MiniMaxConfig = vllm_config.model_config.hf_config
|
||||
@@ -481,8 +503,8 @@ class MiniMaxText01Model(nn.Module):
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.decoder_attention_types = getattr(
|
||||
config, "attn_type_list", False) or getattr(
|
||||
config, "decoder_attention_types", False)
|
||||
config, "attn_type_list", False
|
||||
) or getattr(config, "decoder_attention_types", False)
|
||||
# The HF format uses "layer_types" instead of "attn_type_list"
|
||||
# where "linear_attention" is 0 and "full_attention" is 1
|
||||
if not self.decoder_attention_types and hasattr(config, "layer_types"):
|
||||
@@ -510,50 +532,57 @@ class MiniMaxText01Model(nn.Module):
|
||||
self.embed_tokens = PPMissingLayer()
|
||||
|
||||
def layer_fn(prefix):
|
||||
layer_idx = int(prefix.split('.')[-1])
|
||||
layer_idx = int(prefix.split(".")[-1])
|
||||
layer_config = config
|
||||
layer_config.attention_type = self.decoder_attention_types[
|
||||
layer_idx]
|
||||
layer_config.attention_type = self.decoder_attention_types[layer_idx]
|
||||
layer_config.layer_idx = layer_idx
|
||||
|
||||
decoder_kwargs = {
|
||||
"quant_config": quant_config,
|
||||
"layer_id": layer_idx,
|
||||
"model_config": model_config,
|
||||
"cache_config": cache_config
|
||||
"cache_config": cache_config,
|
||||
}
|
||||
|
||||
if layer_config.attention_type == 0:
|
||||
decoder_kwargs["linear_layer_id"] = sum(
|
||||
1 for i in range(layer_idx)
|
||||
if self.decoder_attention_types[i] == 0)
|
||||
1 for i in range(layer_idx) if self.decoder_attention_types[i] == 0
|
||||
)
|
||||
else:
|
||||
decoder_kwargs["linear_layer_id"] = None
|
||||
|
||||
if hasattr(config, "num_local_experts") and isinstance(
|
||||
config.num_local_experts, list):
|
||||
decoder_kwargs["expert_num"] = config.num_local_experts[
|
||||
layer_idx]
|
||||
config.num_local_experts, list
|
||||
):
|
||||
decoder_kwargs["expert_num"] = config.num_local_experts[layer_idx]
|
||||
elif hasattr(config, "num_local_experts") and isinstance(
|
||||
config.num_local_experts, int):
|
||||
config.num_local_experts, int
|
||||
):
|
||||
decoder_kwargs["expert_num"] = config.num_local_experts
|
||||
else:
|
||||
decoder_kwargs["expert_num"] = 1
|
||||
|
||||
return MiniMaxText01DecoderLayer(layer_config,
|
||||
**decoder_kwargs,
|
||||
prefix=prefix)
|
||||
return MiniMaxText01DecoderLayer(
|
||||
layer_config, **decoder_kwargs, prefix=prefix
|
||||
)
|
||||
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers, layer_fn, prefix=f"{prefix}.layers")
|
||||
config.num_hidden_layers, layer_fn, prefix=f"{prefix}.layers"
|
||||
)
|
||||
|
||||
linear_layer_nums = sum(1 for i in range(config.num_hidden_layers)
|
||||
if self.decoder_attention_types[i] == 0)
|
||||
linear_layer_nums = sum(
|
||||
1
|
||||
for i in range(config.num_hidden_layers)
|
||||
if self.decoder_attention_types[i] == 0
|
||||
)
|
||||
max_slots_number = scheduler_config.max_num_seqs
|
||||
self.cache_shape = (linear_layer_nums, max_slots_number,
|
||||
config.num_attention_heads //
|
||||
get_tensor_model_parallel_world_size(),
|
||||
config.head_dim, config.head_dim)
|
||||
self.cache_shape = (
|
||||
linear_layer_nums,
|
||||
max_slots_number,
|
||||
config.num_attention_heads // get_tensor_model_parallel_world_size(),
|
||||
config.head_dim,
|
||||
config.head_dim,
|
||||
)
|
||||
_dummy = torch.zeros(1)
|
||||
self._dtype = _dummy.dtype
|
||||
del _dummy
|
||||
@@ -568,12 +597,12 @@ class MiniMaxText01Model(nn.Module):
|
||||
self.embed_scale = 1.0
|
||||
return
|
||||
|
||||
def _clear_prefill_cache(self, attn_metadata,
|
||||
minimax_cache_tensors: torch.Tensor, **kwargs):
|
||||
def _clear_prefill_cache(
|
||||
self, attn_metadata, minimax_cache_tensors: torch.Tensor, **kwargs
|
||||
):
|
||||
seq_to_slot_maps = {}
|
||||
seq_id_map = sum(list(kwargs["request_ids_to_seq_ids"].values()), [])
|
||||
for _, seq_to_slot_map in (
|
||||
self.minimax_cache.cache_indices_mapping.items()):
|
||||
for _, seq_to_slot_map in self.minimax_cache.cache_indices_mapping.items():
|
||||
seq_to_slot_maps.update(seq_to_slot_map)
|
||||
|
||||
slots_to_clear = []
|
||||
@@ -581,25 +610,29 @@ class MiniMaxText01Model(nn.Module):
|
||||
if _prefill_id >= len(seq_id_map):
|
||||
break
|
||||
seq_id = seq_id_map[_prefill_id]
|
||||
if attn_metadata.context_lens_tensor[
|
||||
_prefill_id] == 0 and seq_id in seq_to_slot_maps:
|
||||
if (
|
||||
attn_metadata.context_lens_tensor[_prefill_id] == 0
|
||||
and seq_id in seq_to_slot_maps
|
||||
):
|
||||
slots_to_clear.append(seq_to_slot_maps[seq_id])
|
||||
|
||||
if slots_to_clear:
|
||||
slots_tensor = torch.tensor(slots_to_clear,
|
||||
device=minimax_cache_tensors.device,
|
||||
dtype=torch.long)
|
||||
slots_tensor = torch.tensor(
|
||||
slots_to_clear, device=minimax_cache_tensors.device, dtype=torch.long
|
||||
)
|
||||
minimax_cache_tensors[:, slots_tensor, ...] = 0
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
forward_context = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
|
||||
@@ -622,10 +655,9 @@ class MiniMaxText01Model(nn.Module):
|
||||
residual=residual,
|
||||
)
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
return IntermediateTensors(
|
||||
{"hidden_states": hidden_states, "residual": residual}
|
||||
)
|
||||
if residual is not None:
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
else:
|
||||
@@ -635,9 +667,7 @@ class MiniMaxText01Model(nn.Module):
|
||||
|
||||
|
||||
class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
lora_config = vllm_config.lora_config
|
||||
@@ -652,8 +682,9 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
self.unpadded_vocab_size = self.config.vocab_size
|
||||
if hasattr(vllm_config.model_config, "max_model_len"):
|
||||
self.config.max_model_len = vllm_config.model_config.max_model_len
|
||||
self.model = MiniMaxText01Model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.model = MiniMaxText01Model(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
)
|
||||
if get_pp_group().is_last_rank:
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.unpadded_vocab_size,
|
||||
@@ -663,37 +694,41 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
self.config.vocab_size)
|
||||
self.logits_processor = LogitsProcessor(
|
||||
self.unpadded_vocab_size, self.config.vocab_size
|
||||
)
|
||||
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
self.lm_head.float()
|
||||
flash_layer_count = sum(
|
||||
1 for attn_type in self.model.decoder_attention_types
|
||||
if attn_type == 1)
|
||||
1 for attn_type in self.model.decoder_attention_types if attn_type == 1
|
||||
)
|
||||
self.kv_cache = [torch.tensor([]) for _ in range(flash_layer_count)]
|
||||
return
|
||||
|
||||
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
|
||||
return self.model.minimax_cache.copy_inputs_before_cuda_graphs(
|
||||
input_buffers, **kwargs)
|
||||
input_buffers, **kwargs
|
||||
)
|
||||
|
||||
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
|
||||
return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs(
|
||||
batch_size)
|
||||
return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs(batch_size)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds, **kwargs)
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(
|
||||
input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -703,21 +738,20 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
return logits
|
||||
|
||||
def make_empty_intermediate_tensors(
|
||||
self, batch_size: int, dtype: torch.dtype,
|
||||
device: torch.device) -> IntermediateTensors:
|
||||
return IntermediateTensors({
|
||||
"hidden_states":
|
||||
torch.zeros((batch_size, self.config.hidden_size),
|
||||
dtype=dtype,
|
||||
device=device),
|
||||
"residual":
|
||||
torch.zeros((batch_size, self.config.hidden_size),
|
||||
dtype=dtype,
|
||||
device=device),
|
||||
})
|
||||
self, batch_size: int, dtype: torch.dtype, device: torch.device
|
||||
) -> IntermediateTensors:
|
||||
return IntermediateTensors(
|
||||
{
|
||||
"hidden_states": torch.zeros(
|
||||
(batch_size, self.config.hidden_size), dtype=dtype, device=device
|
||||
),
|
||||
"residual": torch.zeros(
|
||||
(batch_size, self.config.hidden_size), dtype=dtype, device=device
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
|
||||
@@ -729,7 +763,8 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
|
||||
def is_linear_attn_layer(layer_idx: int) -> bool:
|
||||
if layer_idx is None or layer_idx >= len(
|
||||
self.model.decoder_attention_types):
|
||||
self.model.decoder_attention_types
|
||||
):
|
||||
return False
|
||||
return self.model.decoder_attention_types[layer_idx] == 0
|
||||
|
||||
@@ -737,39 +772,48 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
return "block_sparse_moe" in name and not name.endswith(".bias")
|
||||
|
||||
def get_expert_id(param_name):
|
||||
pattern = r'model\.layers\.\d+\.block_sparse_moe\.experts\.(\d+)\.'
|
||||
pattern = r"model\.layers\.\d+\.block_sparse_moe\.experts\.(\d+)\."
|
||||
match = re.search(pattern, param_name)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return None
|
||||
|
||||
def load_sparse_moe_weight(name: str, loaded_weight: torch.Tensor,
|
||||
self) -> None:
|
||||
def load_sparse_moe_weight(
|
||||
name: str, loaded_weight: torch.Tensor, self
|
||||
) -> None:
|
||||
if isinstance(self.config.num_local_experts, list):
|
||||
expert_params_mapping = [
|
||||
("w13_weight"
|
||||
if weight_name in ["w1", "w3"] else "w2_weight",
|
||||
f"experts.{expert_id}.{weight_name}.weight", expert_id)
|
||||
(
|
||||
"w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
|
||||
f"experts.{expert_id}.{weight_name}.weight",
|
||||
expert_id,
|
||||
)
|
||||
for expert_id in range(max(self.config.num_local_experts))
|
||||
for weight_name in ["w1", "w2", "w3"]
|
||||
]
|
||||
else:
|
||||
expert_params_mapping = [
|
||||
("w13_scale" if weight_name in ["w1", "w3"] else
|
||||
"w2_scale", f"{expert_id}.{weight_name}.weight_scale",
|
||||
expert_id, weight_name)
|
||||
(
|
||||
"w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
|
||||
f"{expert_id}.{weight_name}.weight_scale",
|
||||
expert_id,
|
||||
weight_name,
|
||||
)
|
||||
for expert_id in range(self.config.num_local_experts)
|
||||
for weight_name in ["w1", "w2", "w3"]
|
||||
] + [("w13_weight" if weight_name in ["w1", "w3"] else
|
||||
"w2_weight", f"{expert_id}.{weight_name}.weight",
|
||||
expert_id, weight_name)
|
||||
for expert_id in range(self.config.num_local_experts)
|
||||
for weight_name in ["w1", "w2", "w3"]]
|
||||
for (param_name, weight_name, expert_id,
|
||||
shard_id) in expert_params_mapping:
|
||||
] + [
|
||||
(
|
||||
"w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
|
||||
f"{expert_id}.{weight_name}.weight",
|
||||
expert_id,
|
||||
weight_name,
|
||||
)
|
||||
for expert_id in range(self.config.num_local_experts)
|
||||
for weight_name in ["w1", "w2", "w3"]
|
||||
]
|
||||
for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
|
||||
name_expert_id = get_expert_id(name)
|
||||
if name_expert_id is not None and int(name_expert_id) != int(
|
||||
expert_id):
|
||||
if name_expert_id is not None and int(name_expert_id) != int(expert_id):
|
||||
continue
|
||||
if weight_name not in name:
|
||||
continue
|
||||
@@ -779,19 +823,20 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader = weight_loader_with_alias(name)(weight_loader)
|
||||
weight_loader(param,
|
||||
loaded_weight,
|
||||
weight_name,
|
||||
expert_id=expert_id,
|
||||
shard_id=shard_id)
|
||||
weight_loader(
|
||||
param,
|
||||
loaded_weight,
|
||||
weight_name,
|
||||
expert_id=expert_id,
|
||||
shard_id=shard_id,
|
||||
)
|
||||
loaded_params.add(name)
|
||||
break
|
||||
else:
|
||||
if is_pp_missing_parameter(name, self):
|
||||
return
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader = weight_loader_with_alias(name)(weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
@@ -800,8 +845,9 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
def is_shared_mlp_weight(name: str) -> bool:
|
||||
return "shared_mlp" in name and not name.endswith(".bias")
|
||||
|
||||
def load_shared_mlp_weight(name: str, loaded_weight: torch.Tensor,
|
||||
self) -> None:
|
||||
def load_shared_mlp_weight(
|
||||
name: str, loaded_weight: torch.Tensor, self
|
||||
) -> None:
|
||||
if not self.CONCAT_FFN:
|
||||
if "gate_proj" in name:
|
||||
name = name.replace("gate_proj", "w1", 1)
|
||||
@@ -819,8 +865,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
if is_pp_missing_parameter(name, self):
|
||||
return
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader = weight_loader_with_alias(name)(weight_loader)
|
||||
if not self.CONCAT_FFN:
|
||||
weight_loader(param, loaded_weight)
|
||||
@@ -830,31 +875,31 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
elif "down_proj" in name:
|
||||
weight_loader(param, loaded_weight)
|
||||
else:
|
||||
raise AssertionError(
|
||||
"MLP weight not in [gate_up_proj, down_proj]")
|
||||
raise AssertionError("MLP weight not in [gate_up_proj, down_proj]")
|
||||
loaded_params.add(name)
|
||||
return
|
||||
|
||||
def is_mha_weight(name: str) -> bool:
|
||||
return "self_attn" in name and not name.endswith(".bias")
|
||||
|
||||
def load_linear_attn_weight(name: str, loaded_weight: torch.Tensor,
|
||||
self) -> None:
|
||||
def load_linear_attn_weight(
|
||||
name: str, loaded_weight: torch.Tensor, self
|
||||
) -> None:
|
||||
if is_pp_missing_parameter(name, self):
|
||||
return
|
||||
param = params_dict[name]
|
||||
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader",
|
||||
MiniMaxText01LinearAttention.weight_direct_load)
|
||||
param, "weight_loader", MiniMaxText01LinearAttention.weight_direct_load
|
||||
)
|
||||
weight_loader = weight_loader_with_alias(name)(weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return
|
||||
|
||||
def load_flash_attn_weight(name: str, loaded_weight: torch.Tensor,
|
||||
self) -> None:
|
||||
|
||||
def load_flash_attn_weight(
|
||||
name: str, loaded_weight: torch.Tensor, self
|
||||
) -> None:
|
||||
flash_mha_params_mapping = [
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
@@ -862,16 +907,14 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
for (param_name, weight_name,
|
||||
shard_id) in flash_mha_params_mapping:
|
||||
for param_name, weight_name, shard_id in flash_mha_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
if is_pp_missing_parameter(name, self):
|
||||
return
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader = weight_loader_with_alias(name)(weight_loader)
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
loaded_params.add(name)
|
||||
@@ -881,36 +924,32 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
return
|
||||
param = params_dict[name]
|
||||
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader = weight_loader_with_alias(name)(weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return
|
||||
|
||||
def is_layer_norm_weight(name: str) -> bool:
|
||||
return "norm" in name and not name.endswith(
|
||||
".bias") and name in params_dict
|
||||
return "norm" in name and not name.endswith(".bias") and name in params_dict
|
||||
|
||||
def load_layer_norm_weight(name: str, loaded_weight: torch.Tensor,
|
||||
self) -> None:
|
||||
def load_layer_norm_weight(
|
||||
name: str, loaded_weight: torch.Tensor, self
|
||||
) -> None:
|
||||
if is_pp_missing_parameter(name, self):
|
||||
return
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader = weight_loader_with_alias(name)(weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return
|
||||
|
||||
def load_basic_weight(name: str, loaded_weight: torch.Tensor,
|
||||
self) -> None:
|
||||
def load_basic_weight(name: str, loaded_weight: torch.Tensor, self) -> None:
|
||||
if is_pp_missing_parameter(name, self):
|
||||
return
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader = weight_loader_with_alias(name)(weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
@@ -919,7 +958,8 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
for name, loaded_weight in weights:
|
||||
weight_at_layer = which_layer(name)
|
||||
if weight_at_layer and weight_at_layer >= len(
|
||||
self.model.decoder_attention_types):
|
||||
self.model.decoder_attention_types
|
||||
):
|
||||
continue
|
||||
|
||||
if is_layer_norm_weight(name):
|
||||
@@ -949,7 +989,6 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
cls,
|
||||
vllm_config: "VllmConfig",
|
||||
) -> tuple[torch.dtype, torch.dtype]:
|
||||
|
||||
return MambaStateDtypeCalculator.linear_attention_state_dtype(
|
||||
vllm_config.model_config.dtype,
|
||||
vllm_config.cache_config.mamba_cache_dtype,
|
||||
|
||||
Reference in New Issue
Block a user