# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from collections.abc import Iterable

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.configuration_utils import PretrainedConfig

from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.fla.ops.layernorm_guard import (
    RMSNormGated,
    layernorm_fn,
)
from vllm.model_executor.layers.fused_moe import FusedMoE, SharedFusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.linear_attn import (
    MiniMaxText01LinearAttention,
    MiniMaxText01LinearKernel,
    MiniMaxText01RMSNormTP,
    clear_linear_attention_cache_for_new_sequences,
    linear_attention_decode,
    linear_attention_prefill_and_mix,
)
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateCopyFuncCalculator,
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.bailing_moe import BailingMLP
from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata

from .interfaces import HasInnerState, IsHybrid, SupportsPP
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    is_pp_missing_parameter,
    make_layers,
    maybe_prefix,
)

logger = init_logger(__name__)


def is_linear_layer(layer_idx, layer_group_size):
    if layer_idx is None:
        return False
    if layer_group_size > 0:
        return (layer_idx + 1) % layer_group_size != 0
    else:
        return False


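# Example: with layer_group_size=4, every fourth layer uses full attention and
# the rest use linear attention:
#
#   [is_linear_layer(i, 4) for i in range(8)]
#   -> [True, True, True, False, True, True, True, False]

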
def _build_rope_parameters(config: PretrainedConfig) -> dict | None:
    rope_parameters = copy.deepcopy(getattr(config, "rope_parameters", None)) or {}
    if "rope_theta" not in rope_parameters and hasattr(config, "rope_theta"):
        rope_parameters["rope_theta"] = config.rope_theta
    if "partial_rotary_factor" not in rope_parameters and hasattr(
        config, "partial_rotary_factor"
    ):
        rope_parameters["partial_rotary_factor"] = config.partial_rotary_factor

    rope_scaling = getattr(config, "rope_scaling", None)
    if isinstance(rope_scaling, dict):
        rope_scaling = copy.deepcopy(rope_scaling)
        if "type" in rope_scaling and "rope_type" not in rope_scaling:
            rope_scaling["rope_type"] = rope_scaling.pop("type")
        rope_parameters.update(rope_scaling)

    return rope_parameters or None


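# Example: for a config that defines rope_theta=600000 and
# rope_scaling={"type": "yarn", "factor": 4.0} (and no rope_parameters dict),
# the helper returns
#   {"rope_theta": 600000, "rope_type": "yarn", "factor": 4.0}
# with "partial_rotary_factor" added when the config defines it.

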
class BailingMoeV25MLAAttention(nn.Module):
    """
    MLA Attention for BailingMoeV2.5 full attention layers.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        layer_id: int = 0,
        prefix: str = "attention",
        cache_config: CacheConfig | None = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.layer_id = layer_id
        self.prefix = prefix

        # MLA dimensions
        self.qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 128)
        self.qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 64)
        self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
        self.v_head_dim = getattr(config, "v_head_dim", 128)

        # LoRA ranks
        self.q_lora_rank = getattr(config, "q_lora_rank", None)
        self.kv_lora_rank = getattr(config, "kv_lora_rank", 512)

        tp_size = get_tensor_model_parallel_world_size()
        assert self.num_heads % tp_size == 0
        self.num_local_heads = self.num_heads // tp_size

        self.scaling = self.qk_head_dim**-0.5

        # KV projections
        self.kv_a_layernorm = RMSNorm(
            self.kv_lora_rank,
            eps=config.rms_norm_eps,
        )
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj",
        )

        # Output projection
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        if self.q_lora_rank is not None:
            # Use fused_qkv_a_proj when q_lora_rank is set
            self.fused_qkv_a_proj = MergedColumnParallelLinear(
                self.hidden_size,
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.fused_qkv_a_proj",
                disable_tp=True,
            )
            self.q_a_layernorm = RMSNorm(
                self.q_lora_rank,
                eps=config.rms_norm_eps,
            )
            self.q_b_proj = ColumnParallelLinear(
                self.q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_b_proj",
            )
            self.q_proj = None
            self.kv_a_proj_with_mqa = None
        else:
            # Direct projections when no q_lora_rank
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_proj",
            )
            self.kv_a_proj_with_mqa = ReplicatedLinear(
                self.hidden_size,
                self.kv_lora_rank + self.qk_rope_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.kv_a_proj_with_mqa",
            )
            self.fused_qkv_a_proj = None
            self.q_a_layernorm = None
            self.q_b_proj = None

        rope_parameters = _build_rope_parameters(config)
        max_position = getattr(config, "max_position_embeddings", 8192)
        self.rotary_emb = get_rope(
            head_size=self.qk_rope_head_dim,
            max_position=max_position,
            is_neox_style=False,
            rope_parameters=rope_parameters or None,
            dtype=torch.float32,
        )

        # Build MLAModules for MultiHeadLatentAttentionWrapper
        mla_modules = MLAModules(
            kv_a_layernorm=self.kv_a_layernorm,
            kv_b_proj=self.kv_b_proj,
            rotary_emb=self.rotary_emb,
            o_proj=self.o_proj,
            fused_qkv_a_proj=self.fused_qkv_a_proj,
            kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
            q_a_layernorm=self.q_a_layernorm,
            q_b_proj=self.q_b_proj,
            q_proj=self.q_proj,
            indexer=None,
            is_sparse=False,
            topk_indices_buffer=None,
        )

        self.mla_attn = MultiHeadLatentAttentionWrapper(
            self.hidden_size,
            self.num_local_heads,
            self.scaling,
            self.qk_nope_head_dim,
            self.qk_rope_head_dim,
            self.v_head_dim,
            self.q_lora_rank,
            self.kv_lora_rank,
            mla_modules,
            cache_config,
            quant_config,
            prefix,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Forward pass for MLA attention."""
        return self.mla_attn(positions, hidden_states)


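# With the default MLA dimensions above (qk_nope_head_dim=128, qk_rope_head_dim=64,
# v_head_dim=128, kv_lora_rank=512), each query/key head is
# qk_head_dim = 128 + 64 = 192 wide and attention is scaled by 192 ** -0.5, while
# the MLA KV cache only holds the compressed kv_lora_rank + qk_rope_head_dim
# = 512 + 64 = 576 values per token.

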
class BailingMoEGate(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        params_dtype: torch.dtype | None = None,
        prefix: str = "",
    ):
        super().__init__()
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        self.weight = nn.Parameter(
            torch.empty(
                (config.num_experts, config.hidden_size),
                dtype=self.params_dtype,
            ),
        )
        if getattr(config, "moe_router_enable_expert_bias", False):
            self.expert_bias = nn.Parameter(
                torch.empty((config.num_experts,), dtype=torch.float32),
            )
        else:
            self.expert_bias = None

    def forward(self, hidden_states):
        logits = F.linear(hidden_states.to(self.weight.dtype), self.weight, None).to(
            hidden_states.dtype
        )
        return logits


class BailingMoeV25(nn.Module):
    """Bailing MoE v2.5 - standalone implementation for linear attention model."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        layer_id: int = 0,
        prefix: str = "",
    ):
        super().__init__()

        self.layer_id = layer_id
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok
        norm_topk_prob = getattr(config, "norm_topk_prob", None)
        # Ring-2.5 reference implementations normalize routing weights by default.
        self.norm_expert_prob = True if norm_topk_prob is None else bool(norm_topk_prob)
        self.hidden_size = config.hidden_size
        self.quant_config = quant_config
        self.num_shared_experts = config.num_shared_experts
        self.score_function = getattr(config, "score_function", None)
        self.n_group = getattr(config, "n_group", None)
        self.topk_group = getattr(config, "topk_group", None)
        self.use_grouped_topk = self.n_group is not None and self.topk_group is not None
        self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)

        router_dtype = getattr(config, "router_dtype", None)
        if router_dtype is None or router_dtype == "fp32":
            self.router_dtype = torch.float32
        else:
            self.router_dtype = torch.bfloat16

        # Gate for routing
        self.gate = BailingMoEGate(
            config=config,
            params_dtype=self.router_dtype,
            prefix=f"{prefix}.gate",
        )
        correction_bias = (
            self.gate.expert_bias if self.gate.expert_bias is not None else None
        )
        if self.score_function is not None:
            assert (self.score_function == "softmax" and correction_bias is None) or (
                self.score_function == "sigmoid" and correction_bias is not None
            ), (
                "score_function and correction_bias should be "
                "(softmax, None) or (sigmoid, not None)"
            )

        # Shared experts (using BailingMLP)
        if self.num_shared_experts > 0:
            if hasattr(config, "moe_shared_expert_intermediate_size"):
                intermediate_size = config.moe_shared_expert_intermediate_size
            else:
                intermediate_size = config.moe_intermediate_size
                intermediate_size *= config.num_shared_experts
            self.shared_experts = BailingMLP(
                intermediate_size=intermediate_size,
                config=config,
                quant_config=quant_config,
                reduce_results=False,
                prefix=f"{prefix}.shared_experts",
            )
        else:
            self.shared_experts = None

        # Routed experts using SharedFusedMoE
        self.experts = SharedFusedMoE(
            shared_experts=self.shared_experts,
            num_experts=self.num_experts,
            top_k=self.top_k,
            hidden_size=self.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=self.norm_expert_prob,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            scoring_func=self.score_function,
            e_score_correction_bias=correction_bias,
            num_expert_group=self.n_group,
            topk_group=self.topk_group,
            use_grouped_topk=self.use_grouped_topk,
            router_logits_dtype=self.router_dtype,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_size = hidden_states.shape
        # Ensure contiguous token-major layout before router/projections.
        hidden_states = hidden_states.contiguous().view(-1, hidden_size)

        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(hidden_states.to(self.router_dtype))
        router_logits = router_logits.to(hidden_states.dtype)

        final_hidden_states = self.experts(
            hidden_states=hidden_states, router_logits=router_logits
        )

        # Handle tuple return from SharedFusedMoE
        if self.shared_experts is not None:
            shared_output, final_hidden_states = final_hidden_states
        else:
            shared_output = None

        final_hidden_states *= self.routed_scaling_factor

        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output

        if self.tp_size > 1:
            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
                final_hidden_states
            )

        return final_hidden_states.view(num_tokens, hidden_size)


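# Sketch of the routing math selected by the SharedFusedMoE flags above, assuming
# the "sigmoid" score function with an expert bias (the bias only steers expert
# selection; the unbiased scores are used as mixing weights):
#
#   scores = router_logits.sigmoid()
#   topk_ids = (scores + expert_bias).topk(top_k, dim=-1).indices
#   topk_weights = scores.gather(-1, topk_ids)
#   topk_weights = topk_weights / topk_weights.sum(-1, keepdim=True)  # renormalize
#   out = routed_scaling_factor * sum_i(topk_weights[..., i] * expert_i(x))
#         + shared_experts(x)

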
BailingRMSNormTP = MiniMaxText01RMSNormTP


class BailingGroupRMSNormGate(RMSNormGated):
    def __init__(
        self,
        hidden_size,
        eps=1e-5,
        group_size=None,
        norm_before_gate=True,
        device=None,
        dtype=None,
    ):
        super().__init__(
            hidden_size,
            eps=eps,
            group_size=group_size,
            norm_before_gate=norm_before_gate,
            device=device,
            dtype=dtype,
            activation="sigmoid",
        )
        # Add custom weight loader for TP sharding
        self.weight.weight_loader = self._weight_loader

    @staticmethod
    def _weight_loader(param: torch.nn.Parameter, loaded_weight: torch.Tensor) -> None:
        """Load weight with TP sharding."""
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        shard_size = loaded_weight.shape[0] // tp_size
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        param.data.copy_(loaded_weight[shard].contiguous())


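# Example: with tp_size=4 and a checkpoint g_norm weight of shape (4096,), the
# loader above hands each rank a contiguous 1024-element slice; rank 1 copies
# loaded_weight[1024:2048].

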
class BailingMoELinearAttention(nn.Module, MambaBase):
    """
    Bailing MoE Linear Attention implementation using the minimax backend.

    This implements the linear attention mechanism from SGLang, adapted for vLLM's
    v1 engine with MambaBase interface support.
    """

    @property
    def mamba_type(self) -> str:
        return "linear_attention"

    def get_state_shape(self) -> tuple[tuple[int, ...], ...]:
        """Return state shape for linear attention cache.

        Must match the calculation in get_mamba_state_shape_from_config.
        """
        return MambaStateShapeCalculator.linear_attention_state_shape(
            num_heads=self.total_num_heads,
            tp_size=self.tp_size,
            head_dim=self.head_dim,
        )

    def get_state_dtype(self) -> tuple[torch.dtype, ...]:
        """Return state dtype for linear attention cache.

        Must match the calculation in get_mamba_state_dtype_from_config.
        """
        return MambaStateDtypeCalculator.linear_attention_state_dtype(
            self.model_config.dtype,
            self.cache_config.mamba_cache_dtype,
        )

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        layer_id: int = 0,
        prefix: str = "linear_attn",
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
    ):
        super().__init__()

        self.layer_id = layer_id
        self.hidden_size = config.hidden_size
        self.total_num_heads = config.num_attention_heads
        self.total_kv_heads = config.num_attention_heads  # MHA
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.model_config = model_config
        self.cache_config = cache_config
        self.prefix = prefix

        self.head_dim = (
            config.head_dim
            if hasattr(config, "head_dim")
            else config.hidden_size // self.total_num_heads
        )

        self.hidden_inner_size = self.head_dim * self.total_num_heads
        self.scaling = self.head_dim**-0.5

        assert self.total_num_heads % self.tp_size == 0
        self.tp_heads = self.total_num_heads // self.tp_size

        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = getattr(config, "rope_theta", 600000)

        self.tp_kv_heads = self.total_kv_heads // self.tp_size
        self.q_size_per_rank = self.head_dim * self.tp_heads
        self.kv_size_per_rank = self.head_dim * self.tp_kv_heads

        self.use_qk_norm = getattr(config, "use_qk_norm", False)
        self.linear_backend = "minimax"
        self.linear_scale = self.linear_backend == "minimax"
        self.linear_rope = getattr(config, "linear_rope", True)
        if hasattr(config, "use_linear_silu"):
            self.linear_silu = config.use_linear_silu
        elif hasattr(config, "linear_silu"):
            self.linear_silu = config.linear_silu
        else:
            self.linear_silu = False

        # Block size for lightning attention
        self.BLOCK = getattr(config, "block", 256)

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_heads,  # MHA: kv_heads = num_heads
            bias=(config.use_bias or config.use_qkv_bias),
            quant_config=quant_config,
            prefix=f"{prefix}.query_key_value",
        )

        if self.use_qk_norm:
            self.query_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.key_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)

        self.g_proj = ColumnParallelLinear(
            self.hidden_size,
            self.hidden_inner_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.g_proj",
        )
        self.dense = RowParallelLinear(
            self.hidden_inner_size,
            self.hidden_size,
            bias=config.use_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
            reduce_results=True,
        )

        self.group_norm_size = getattr(config, "group_norm_size", 1)
        self.rms_norm_eps = float(getattr(config, "rms_norm_eps", 1e-5))
        assert self.tp_size <= self.group_norm_size, (
            "tp_size must be <= group_norm_size for local rms norm"
        )
        assert self.group_norm_size % self.tp_size == 0, (
            "group_norm_size must be divisible by tp_size"
        )

        # When group_norm_size == 1, group_size equals hidden_size // tp_size
        self.g_norm = BailingGroupRMSNormGate(
            hidden_size=self.hidden_inner_size // self.tp_size,
            eps=self.rms_norm_eps,
            group_size=(
                self.hidden_inner_size // self.group_norm_size
                if self.group_norm_size > 1
                else self.hidden_inner_size // self.tp_size
            ),
        )

        # Use fp32 rotary embedding
        rope_parameters = _build_rope_parameters(config)

        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=self.max_position_embeddings,
            is_neox_style=True,
            dtype=torch.float32,
            rope_parameters=rope_parameters or None,
        )

        # Build slope tensor for linear attention decay
        num_hidden_layers = config.num_hidden_layers
        slope_rate = MiniMaxText01LinearAttention._build_slope_tensor(
            self.total_num_heads
        )
        if num_hidden_layers <= 1:
            self.slope_rate = slope_rate * (1 + 1e-5)
        else:
            self.slope_rate = slope_rate * (
                1 - layer_id / (num_hidden_layers - 1) + 1e-5
            )
        self.tp_slope = self.slope_rate[
            self.tp_rank * self.tp_heads : (self.tp_rank + 1) * self.tp_heads
        ].contiguous()

        # Register for compilation
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

    @staticmethod
    def weight_direct_load(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
        """Load weight for linear attention layers.

        For FP8 quantized parameters, we need to use the weight_loader if available,
        as it handles special cases like tensor parallelism sharding.
        """
        # Check if param has a weight_loader (for vLLM ModelWeightParameter)
        weight_loader = getattr(param, "weight_loader", None)
        if weight_loader is not None:
            # Use the weight_loader which handles TP sharding and quantization
            weight_loader(param, loaded_weight)
        else:
            # Fall back to direct copy for standard tensors
            assert param.size() == loaded_weight.size(), (
                f"Shape mismatch: {param.shape} vs {loaded_weight.shape}"
            )
            param.data.copy_(loaded_weight)

    def forward(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
        positions: torch.Tensor,
    ) -> None:
        """Dispatch to the torch.ops.vllm.linear_attention custom op."""
        torch.ops.vllm.linear_attention(
            hidden_states,
            output,
            positions,
            self.prefix,
        )

    def _forward(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
        positions: torch.Tensor,
    ) -> None:
        """Actual forward implementation."""
        forward_context = get_forward_context()
        attn_metadata: AttentionMetadata = forward_context.attn_metadata
        if attn_metadata is not None:
            assert isinstance(attn_metadata, dict)
            attn_metadata = attn_metadata[self.prefix]
            assert isinstance(attn_metadata, LinearAttentionMetadata)
            num_actual_tokens = (
                attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens
            )
        else:
            num_actual_tokens = hidden_states.shape[0]

        # QKV projection
        qkv, _ = self.query_key_value(hidden_states[:num_actual_tokens])

        # Cast to fp32; the rotary embedding and linear-attention kernels run in fp32
        qkv = qkv.to(torch.float32)
        if self.linear_silu:
            qkv = F.silu(qkv)

        # Split q, k, v
        q, k, v = torch.split(
            qkv,
            [self.q_size_per_rank, self.kv_size_per_rank, self.kv_size_per_rank],
            dim=-1,
        )

        # Apply QK norm if needed
        if self.use_qk_norm:
            q = q.reshape(-1, self.tp_heads, self.head_dim)
            k = k.reshape(-1, self.tp_kv_heads, self.head_dim)
            q = layernorm_fn(
                q,
                self.query_layernorm.weight.data,
                bias=None,
                eps=self.rms_norm_eps,
                is_rms_norm=True,
            )
            k = layernorm_fn(
                k,
                self.key_layernorm.weight.data,
                bias=None,
                eps=self.rms_norm_eps,
                is_rms_norm=True,
            )
            q = q.reshape(-1, self.q_size_per_rank)
            k = k.reshape(-1, self.kv_size_per_rank)

        # Apply rotary embeddings
        if self.linear_rope:
            q, k = self.rotary_emb(positions[:num_actual_tokens], q, k)

        # Reshape to [num_tokens, num_heads, head_dim]
        q = q.view((qkv.shape[0], self.tp_heads, self.head_dim))
        k = k.view((qkv.shape[0], self.tp_kv_heads, self.head_dim))
        v = v.view((qkv.shape[0], self.tp_kv_heads, self.head_dim))

        # Apply scaling if using minimax backend
        if self.linear_scale:
            q = q * self.scaling

        # Get KV cache and state indices
        if attn_metadata is not None:
            kv_cache = self.kv_cache[forward_context.virtual_engine][0]
            state_indices_tensor = attn_metadata.state_indices_tensor
            clear_linear_attention_cache_for_new_sequences(
                kv_cache, state_indices_tensor, attn_metadata
            )

        # Compute attention
        decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
        if attn_metadata is None:
            hidden = torch.empty(
                (q.shape[0], q.shape[1] * q.shape[2]), device=q.device, dtype=q.dtype
            )
        else:
            if not decode_only:
                hidden = self._prefill_and_mix_infer(
                    q, k, v, kv_cache, state_indices_tensor, attn_metadata
                )
            else:
                hidden = self._decode_infer(
                    q, k, v, kv_cache, state_indices_tensor, attn_metadata
                )

        # Apply group norm and gate (matching SGLang behavior)
        gate, _ = self.g_proj(hidden_states[:num_actual_tokens])

        if self.group_norm_size > 1:
            hidden = self.g_norm(hidden, gate)
        else:
            hidden = self.g_norm(hidden)
            hidden = F.sigmoid(gate) * hidden

        hidden = hidden.to(hidden_states.dtype)

        # Output projection
        dense_out, _ = self.dense(hidden)
        output[:num_actual_tokens] = dense_out

    def _prefill_and_mix_infer(
        self, q, k, v, kv_cache, state_indices_tensor, attn_metadata
    ):
        """Handle prefill (mixed with decode if any)."""
        return linear_attention_prefill_and_mix(
            q=q,
            k=k,
            v=v,
            kv_cache=kv_cache,
            state_indices_tensor=state_indices_tensor,
            attn_metadata=attn_metadata,
            slope_rate=self.tp_slope,
            block_size=self.BLOCK,
            decode_fn=self._decode_infer,
            prefix_fn=MiniMaxText01LinearKernel.jit_linear_forward_prefix,
            layer_idx=self.layer_id,
        )

    def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, attn_metadata):
        """Handle decode (single token per sequence)."""
        num_prefill_tokens = attn_metadata.num_prefill_tokens
        num_prefills = attn_metadata.num_prefills
        hidden = linear_attention_decode(
            q,
            k,
            v,
            kv_cache,
            self.tp_slope,
            state_indices_tensor,
            q_start=num_prefill_tokens,
            q_end=None,
            slot_start=num_prefills,
            slot_end=None,
            block_size=32,
        )
        return hidden


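# Sketch of the per-head recurrence the minimax/lightning kernels above compute
# (illustrative only; the fused kernels work block-wise on the cached state):
#
#   state = kv_cache[slot]                                # [head_dim, head_dim]
#   for t in range(seq_len):
#       state = torch.exp(-slope) * state + k[t, :, None] @ v[t, None, :]
#       out[t] = q[t] @ state
#
# where `slope` is the per-head decay rate from _build_slope_tensor, so different
# heads forget old context at different rates.

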
class BailingMoeV25DecoderLayer(nn.Module):
    """Decoder layer supporting both linear and full attention."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        layer_id: int = 0,
        prefix: str = "layer",
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
    ) -> None:
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = config.hidden_size

        # Determine attention type (0 = linear, 1 = full)
        self.attention_type = getattr(config, "attention_type", 1)

        if self.attention_type == 0:  # Linear attention
            self.self_attn = BailingMoELinearAttention(
                config,
                quant_config=quant_config,
                layer_id=layer_id,
                prefix=f"{prefix}.self_attn",
                model_config=model_config,
                cache_config=cache_config,
            )
        else:  # Full attention
            self.self_attn = BailingMoeV25MLAAttention(
                config,
                quant_config=quant_config,
                layer_id=layer_id,
                prefix=f"{prefix}.self_attn",
                cache_config=cache_config,
            )

        # MLP/MoE
        is_moe_layer = config.num_experts > 1 and layer_id >= getattr(
            config, "first_k_dense_replace", 0
        )

        if is_moe_layer:
            self.mlp = BailingMoeV25(
                config,
                quant_config=quant_config,
                layer_id=layer_id,
                prefix=f"{prefix}.mlp",
            )
        else:
            self.mlp = BailingMLP(
                intermediate_size=config.intermediate_size,
                config=config,
                quant_config=quant_config,
                reduce_results=True,
                prefix=f"{prefix}.mlp",
            )

        # Layer norms
        rms_norm_eps = float(getattr(config, "rms_norm_eps", 1e-5))
        self.input_layernorm = RMSNorm(self.hidden_size, eps=rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(self.hidden_size, eps=rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        positions: torch.Tensor,
        attn_metadata: AttentionMetadata | None = None,
        residual: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        # Input layernorm
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        # Self attention
        if self.attention_type == 0:
            # Linear attention writes into a preallocated output tensor
            self_attention_output = torch.zeros_like(hidden_states)
            self.self_attn(
                hidden_states=hidden_states,
                output=self_attention_output,
                positions=positions,
            )
        else:
            # Full attention
            self_attention_output = self.self_attn(hidden_states, positions)

        hidden_states, residual = self.post_attention_layernorm(
            self_attention_output, residual
        )
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    }
)
class BailingMoeV25Model(nn.Module):
    """Bailing MoE v2.5 Model with hybrid attention support."""

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        quant_config = vllm_config.quant_config
        cache_config = vllm_config.cache_config

        self.config = config
        self.vocab_size = config.vocab_size
        self.embed_dim = config.hidden_size

        # Determine layer types based on layer_group_size
        self.layer_group_size = getattr(config, "layer_group_size", 1)
        self.num_layers = config.num_hidden_layers

        # decoder_attention_types: 0 = linear, 1 = full
        self.decoder_attention_types = [
            0 if is_linear_layer(i, self.layer_group_size) else 1
            for i in range(self.num_layers)
        ]

        # Embeddings
        if get_pp_group().is_first_rank:
            self.word_embeddings = VocabParallelEmbedding(
                self.vocab_size,
                self.embed_dim,
                org_num_embeddings=self.vocab_size,
            )
        else:
            self.word_embeddings = PPMissingLayer()

        # Layers
        def layer_fn(prefix):
            layer_idx = int(prefix.split(".")[-1])
            layer_config = copy.deepcopy(config)
            layer_config.attention_type = self.decoder_attention_types[layer_idx]

            return BailingMoeV25DecoderLayer(
                config=layer_config,
                quant_config=quant_config,
                layer_id=layer_idx,
                prefix=prefix,
                model_config=model_config,
                cache_config=cache_config,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            self.num_layers, layer_fn, prefix=f"{prefix}.layers"
        )

        # Final norm
        norm_kwargs = {}
        if hasattr(config, "rms_norm_eps"):
            norm_kwargs["eps"] = config.rms_norm_eps
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, **norm_kwargs)
        else:
            self.norm = PPMissingLayer()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.word_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        forward_context = get_forward_context()
        attn_metadata = forward_context.attn_metadata

        if get_pp_group().is_first_rank:
            if inputs_embeds is None:
                hidden_states = self.word_embeddings(input_ids)
            else:
                hidden_states = inputs_embeds
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in self.layers[self.start_layer : self.end_layer]:
            hidden_states, residual = layer(
                hidden_states=hidden_states,
                positions=positions,
                attn_metadata=attn_metadata,
                residual=residual,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        else:
            if residual is not None:
                hidden_states, _ = self.norm(hidden_states, residual)
            else:
                hidden_states = self.norm(hidden_states)
            return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Get expert parameter mapping for MoE layers."""
        return FusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts,
            num_redundant_experts=0,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights with simplified mapping."""

        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        # Stacked parameter mappings (fused projections)
        stacked_mappings = [
            (".fused_qkv_a_proj", ".q_a_proj", 0),
            (".fused_qkv_a_proj", ".kv_a_proj_with_mqa", 1),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]

        # Expert parameter mappings from FusedMoE
        expert_mappings = list(self.get_expert_mapping())

        def load_param(name: str, tensor: torch.Tensor, shard_id=None) -> bool:
            """Load a single parameter."""
            if name not in params_dict or is_pp_missing_parameter(name, self):
                return False
            if name.endswith(".bias") and name not in params_dict:
                return False

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)

            if shard_id is None:
                weight_loader(param, tensor)
            elif isinstance(shard_id, int):
                weight_loader(param, tensor, shard_id)
            else:
                # Expert param: (expert_id, shard_id)
                weight_loader(
                    param, tensor, name, expert_id=shard_id[0], shard_id=shard_id[1]
                )

            loaded_params.add(name)
            return True

        def normalize_name(name: str) -> str | None:
            """Normalize checkpoint name to model parameter name."""
            # Skip special weights
            if name.startswith("model.mtp"):
                return None
            # Remove 'model.' prefix if present
            # (e.g., 'model.layers.0...' -> 'layers.0...')
            name = name.removeprefix("model.")
            # Map attention.dense based on layer type
            if "attention.dense" in name:
                layer_idx = (
                    int(name.split("layers.")[1].split(".")[0])
                    if "layers." in name
                    else 0
                )
                attn_name = (
                    "self_attn.dense"
                    if is_linear_layer(layer_idx, self.config.layer_group_size)
                    else "self_attn.o_proj"
                )
                name = name.replace("attention.dense", attn_name)

            # Standard mappings
            name = name.replace("attention.", "self_attn.")
            name = name.replace(
                "mlp.gate.e_score_correction_bias", "mlp.gate.expert_bias"
            )

            return maybe_remap_kv_scale_name(name, params_dict)

        for orig_name, weight in weights:
            norm_name = normalize_name(orig_name)
            if norm_name is None:
                continue

            # Try stacked mappings
            loaded = False
            for param_suf, weight_suf, shard_id in stacked_mappings:
                if weight_suf not in norm_name:
                    continue
                mapped = norm_name.replace(weight_suf, param_suf).replace(
                    "attention.", "self_attn."
                )
                if load_param(mapped, weight, shard_id):
                    loaded = True
                    break
            if loaded:
                continue

            # Handle expert weights
            if "mlp.experts" in norm_name:
                # Expert bias
                if (
                    "mlp.experts.e_score_correction_bias" in norm_name
                    or "mlp.experts.expert_bias" in norm_name
                ):
                    alt = norm_name.replace(
                        "mlp.experts.e_score_correction_bias", "mlp.gate.expert_bias"
                    ).replace("mlp.experts.expert_bias", "mlp.gate.expert_bias")
                    if load_param(alt, weight) or load_param(norm_name, weight):
                        continue

                # Routed experts
                for param_name, weight_name, expert_id, shard_id in expert_mappings:
                    if weight_name not in norm_name:
                        continue
                    mapped = norm_name.replace(weight_name, param_name)
                    if load_param(mapped, weight, (expert_id, shard_id)):
                        break
                continue

            # General parameters
            load_param(norm_name, weight)

        return loaded_params


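# Example of the name normalization above, assuming layer_group_size=4:
#   "model.layers.0.attention.dense.weight"  (linear-attention layer)
#       -> "layers.0.self_attn.dense.weight"
#   "model.layers.3.attention.dense.weight"  (full-attention layer)
#       -> "layers.3.self_attn.o_proj.weight"

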
class BailingMoeV25ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsPP):
    """Bailing MoE v2.5 For CausalLM."""

    packed_modules_mapping = {
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> None:
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config

        self.model = BailingMoeV25Model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )

        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
            )
            self.logits_processor = LogitsProcessor(config.vocab_size)
        else:
            self.lm_head = PPMissingLayer()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        return self.logits_processor(self.lm_head, hidden_states)

    def make_empty_intermediate_tensors(
        self, batch_size: int, dtype: torch.dtype, device: torch.device
    ) -> IntermediateTensors:
        return IntermediateTensors(
            {
                "hidden_states": torch.zeros(
                    (batch_size, self.config.hidden_size), dtype=dtype, device=device
                ),
                "residual": torch.zeros(
                    (batch_size, self.config.hidden_size), dtype=dtype, device=device
                ),
            }
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,
        vllm_config: VllmConfig,
    ) -> tuple[tuple[int, ...], ...]:
        """Calculate shape for linear attention cache."""
        config = vllm_config.model_config.hf_config
        tp_size = vllm_config.parallel_config.tensor_parallel_size

        head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )

        # Return base state shape from linear attention (no padding)
        return MambaStateShapeCalculator.linear_attention_state_shape(
            num_heads=config.num_attention_heads,
            tp_size=tp_size,
            head_dim=head_dim,
        )

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: VllmConfig,
    ) -> tuple[torch.dtype, ...]:
        return MambaStateDtypeCalculator.linear_attention_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
        )

    @classmethod
    def get_mamba_state_copy_func(cls) -> tuple:
        return MambaStateCopyFuncCalculator.linear_attention_state_copy_func()

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        return self.model.get_expert_mapping()