[Model] Introduce Kimi Linear to vLLM (#27809)

Signed-off-by: lizhiyuan <lizhiyuan@moonshot.cn>
Signed-off-by: Zhiyuan Li <uniartisan2017@gmail.com>
This commit is contained in:
Zhiyuan Li
2025-10-30 21:02:27 +08:00
committed by GitHub
parent 1994de99ea
commit 4e68cc9b6a
15 changed files with 1325 additions and 48 deletions

View File

@@ -1304,7 +1304,7 @@ def kda_gate_fwd_kernel(
tl.store(y_ptr, b_y.to(y.dtype.element_ty), boundary_check=(0, 1))
def kda_gate_fwd(
def fused_kda_gate(
g: torch.Tensor,
A: torch.Tensor,
head_k_dim: int,

View File

@@ -0,0 +1,426 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from einops import rearrange
from torch import nn
from vllm.attention import AttentionBackend
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed import (
divide,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import sharded_weight_loader
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
from .fla.ops.kda import (
FusedRMSNormGated,
chunk_kda,
fused_kda_gate,
fused_recurrent_kda,
)
from .linear import (
ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from .mamba.abstract import MambaBase
from .mamba.mamba_utils import MambaStateDtypeCalculator, MambaStateShapeCalculator
from .mamba.ops.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
from .quantization.base_config import QuantizationConfig
logger = init_logger(__name__)
def kda_attention(
hidden_states: torch.Tensor,
output: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self._forward(hidden_states=hidden_states, output=output)
def kda_attention_fake(
hidden_states: torch.Tensor,
output: torch.Tensor,
layer_name: str,
) -> None:
return
direct_register_custom_op(
op_name="kda_attention",
op_func=kda_attention,
mutates_args=["output"],
fake_impl=kda_attention_fake,
)
class KimiDeltaAttention(nn.Module, MambaBase):
@property
def mamba_type(self) -> str:
return "linear_attention"
def get_attn_backend(self) -> type["AttentionBackend"]:
from vllm.v1.attention.backends.gdn_attn import GDNAttentionBackend
return GDNAttentionBackend
def get_state_dtype(
self,
) -> tuple[torch.dtype, torch.dtype, torch.dtype, torch.dtype]:
if self.model_config is None or self.cache_config is None:
raise ValueError("model_config and cache_config must be set")
return MambaStateDtypeCalculator.kda_state_dtype(
self.model_config.dtype, self.cache_config.mamba_cache_dtype
)
def get_state_shape(
self,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
return MambaStateShapeCalculator.kda_state_shape(
self.tp_size, self.num_heads, self.head_dim, conv_kernel_size=self.conv_size
)
def __init__(
self,
layer_idx: int,
hidden_size: int,
quant_config: QuantizationConfig | None = None,
cache_config: CacheConfig | None = None,
model_config: ModelConfig | None = None,
rms_norm_eps: float = 1e-5,
prefix: str = "",
**kwargs,
) -> None:
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.hidden_size = hidden_size
self.model_config = model_config
self.cache_config = cache_config
if model_config is None:
raise ValueError("model_config must be provided")
kda_config = model_config.linear_attn_config
self.head_dim = kda_config["head_dim"]
self.num_heads = kda_config["num_heads"]
self.layer_idx = layer_idx
self.prefix = prefix
assert self.num_heads % self.tp_size == 0
self.local_num_heads = divide(self.num_heads, self.tp_size)
projection_size = self.head_dim * self.num_heads
self.conv_size = kda_config["short_conv_kernel_size"]
self.q_proj = ColumnParallelLinear(
self.hidden_size,
projection_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_proj",
)
self.k_proj = ColumnParallelLinear(
self.hidden_size,
projection_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.k_proj",
)
self.v_proj = ColumnParallelLinear(
self.hidden_size,
projection_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.v_proj",
)
self.f_a_proj = ReplicatedLinear(
self.hidden_size,
self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.f_a_proj",
)
self.f_b_proj = ColumnParallelLinear(
self.head_dim,
projection_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.f_b_proj",
)
self.dt_bias = nn.Parameter(
torch.empty(divide(projection_size, self.tp_size), dtype=torch.float32)
)
set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})
self.b_proj = ColumnParallelLinear(
self.hidden_size,
self.num_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.b_proj",
)
self.q_conv1d = ColumnParallelLinear(
input_size=self.conv_size,
output_size=projection_size,
bias=False,
params_dtype=torch.float32,
prefix=f"{prefix}.q_conv1d",
)
self.k_conv1d = ColumnParallelLinear(
input_size=self.conv_size,
output_size=projection_size,
bias=False,
params_dtype=torch.float32,
prefix=f"{prefix}.k_conv1d",
)
self.v_conv1d = ColumnParallelLinear(
input_size=self.conv_size,
output_size=projection_size,
bias=False,
params_dtype=torch.float32,
prefix=f"{prefix}.v_conv1d",
)
# unsqueeze to fit conv1d weights shape into the linear weights shape.
# Can't do this in `weight_loader` since it already exists in
# `ColumnParallelLinear` and `set_weight_attrs`
# doesn't allow to override it
self.q_conv1d.weight.data = self.q_conv1d.weight.data.unsqueeze(1)
self.k_conv1d.weight.data = self.k_conv1d.weight.data.unsqueeze(1)
self.v_conv1d.weight.data = self.v_conv1d.weight.data.unsqueeze(1)
self.A_log = nn.Parameter(
torch.empty(1, 1, self.local_num_heads, 1, dtype=torch.float32)
)
set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(2)})
self.g_a_proj = ReplicatedLinear(
self.hidden_size,
self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.g_a_proj",
)
self.g_b_proj = ColumnParallelLinear(
self.head_dim,
projection_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.g_b_proj",
)
self.o_norm = FusedRMSNormGated(
self.head_dim, eps=rms_norm_eps, activation="sigmoid"
)
self.o_proj = RowParallelLinear(
projection_size,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
def forward(
self,
hidden_states: torch.Tensor,
positions: torch.Tensor,
output: torch.Tensor,
) -> None:
return torch.ops.vllm.kda_attention(
hidden_states,
output,
self.prefix,
)
def _forward(
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
) -> None:
forward_context = get_forward_context()
attn_metadata: AttentionMetadata = forward_context.attn_metadata
if attn_metadata is None:
# V1 profile run
# Mimic the memory allocation in the real run
q = torch.empty_like(hidden_states)
k = torch.empty_like(hidden_states)
v = torch.empty_like(hidden_states)
g = hidden_states.new_empty(
hidden_states.size(0),
self.local_num_heads,
self.head_dim,
dtype=torch.float32,
)
beta = torch.empty(
hidden_states.size(0), self.local_num_heads, dtype=torch.float32
)
core_attn_out = torch.empty_like(hidden_states)
return
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, GDNAttentionMetadata)
has_initial_state = attn_metadata.has_initial_state
non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
constant_caches = self.kv_cache[forward_context.virtual_engine]
(conv_state_q, conv_state_k, conv_state_v, recurrent_state) = constant_caches
# deal with strides
conv_state_q = conv_state_q.transpose(-1, -2)
conv_state_k = conv_state_k.transpose(-1, -2)
conv_state_v = conv_state_v.transpose(-1, -2)
q_proj_states = self.q_proj(hidden_states)[0]
k_proj_states = self.k_proj(hidden_states)[0]
v_proj_states = self.v_proj(hidden_states)[0]
q_conv_weights = self.q_conv1d.weight.view(
self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2)
)
k_conv_weights = self.k_conv1d.weight.view(
self.k_conv1d.weight.size(0), self.k_conv1d.weight.size(2)
)
v_conv_weights = self.v_conv1d.weight.view(
self.v_conv1d.weight.size(0), self.v_conv1d.weight.size(2)
)
if attn_metadata.num_prefills > 0:
q_proj_states = q_proj_states.transpose(0, 1)
k_proj_states = k_proj_states.transpose(0, 1)
v_proj_states = v_proj_states.transpose(0, 1)
q = causal_conv1d_fn(
q_proj_states,
q_conv_weights,
self.q_conv1d.bias,
activation="silu",
conv_states=conv_state_q,
has_initial_state=has_initial_state,
cache_indices=non_spec_state_indices_tensor,
query_start_loc=non_spec_query_start_loc,
metadata=attn_metadata,
).transpose(0, 1)
k = causal_conv1d_fn(
k_proj_states,
k_conv_weights,
self.k_conv1d.bias,
activation="silu",
conv_states=conv_state_k,
has_initial_state=has_initial_state,
cache_indices=non_spec_state_indices_tensor,
query_start_loc=non_spec_query_start_loc,
metadata=attn_metadata,
).transpose(0, 1)
v = causal_conv1d_fn(
v_proj_states,
v_conv_weights,
self.v_conv1d.bias,
activation="silu",
conv_states=conv_state_v,
has_initial_state=has_initial_state,
cache_indices=non_spec_state_indices_tensor,
query_start_loc=non_spec_query_start_loc,
metadata=attn_metadata,
).transpose(0, 1)
else:
decode_conv_indices = non_spec_state_indices_tensor[
: attn_metadata.num_decodes
]
q = causal_conv1d_update(
q_proj_states,
conv_state_q,
q_conv_weights,
self.q_conv1d.bias,
activation="silu",
conv_state_indices=decode_conv_indices,
validate_data=True,
)
k = causal_conv1d_update(
k_proj_states,
conv_state_k,
k_conv_weights,
self.k_conv1d.bias,
activation="silu",
conv_state_indices=decode_conv_indices,
validate_data=True,
)
v = causal_conv1d_update(
v_proj_states,
conv_state_v,
v_conv_weights,
self.v_conv1d.bias,
activation="silu",
conv_state_indices=decode_conv_indices,
validate_data=True,
)
q, k, v = map(
lambda x: rearrange(x, "n (h d) -> 1 n h d", d=self.head_dim), (q, k, v)
)
beta = self.b_proj(hidden_states)[0].float().sigmoid()
g = self.f_b_proj(self.f_a_proj(hidden_states)[0])[0]
g = fused_kda_gate(g, self.A_log, self.head_dim, g_bias=self.dt_bias)
beta = beta.unsqueeze(0)
g = g.unsqueeze(0)
if attn_metadata.num_prefills > 0:
zero_idx = non_spec_state_indices_tensor[~has_initial_state]
recurrent_state[zero_idx] = 0
initial_state = recurrent_state[non_spec_state_indices_tensor].contiguous()
(
core_attn_out_non_spec,
last_recurrent_state,
) = chunk_kda(
q=q,
k=k,
v=v,
g=g,
beta=beta,
initial_state=initial_state,
output_final_state=True,
use_qk_l2norm_in_kernel=True,
cu_seqlens=non_spec_query_start_loc,
)
# Init cache
recurrent_state[non_spec_state_indices_tensor] = last_recurrent_state
else:
(
core_attn_out_non_spec,
last_recurrent_state,
) = fused_recurrent_kda(
q=q,
k=k,
v=v,
g=g,
beta=beta,
initial_state=recurrent_state,
use_qk_l2norm_in_kernel=True,
cu_seqlens=non_spec_query_start_loc,
ssm_state_indices=non_spec_state_indices_tensor,
)
g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0]
g = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)
core_attn_out = self.o_norm(core_attn_out_non_spec, g)
core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")
output[:] = self.o_proj(core_attn_out)[0]

View File

@@ -80,6 +80,15 @@ class MambaStateDtypeCalculator:
state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
return (state_dtype, state_dtype)
@classmethod
def kda_state_dtype(
cls,
model_dtype: ModelDType | torch.dtype,
mamba_cache_dtype: MambaDType,
):
state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
return (state_dtype, state_dtype, state_dtype, torch.float32)
class MambaStateShapeCalculator:
@classmethod
@@ -182,3 +191,35 @@ class MambaStateShapeCalculator:
head_v_dim,
)
return conv_state_shape, temporal_state_shape
@classmethod
def kda_state_shape(
cls,
tp_world_size: int,
num_heads: int,
head_dim: int,
num_k_heads: int | None = None,
head_k_dim: int | None = None,
conv_kernel_size: int = 4,
num_spec: int = 0,
) -> tuple[tuple[int, int], tuple[int, int], tuple[int, int], tuple[int, int, int]]:
if num_k_heads is None:
num_k_heads = num_heads
if head_k_dim is None:
head_k_dim = head_dim
proj_size = num_heads * head_dim
proj_k_size = num_k_heads * head_k_dim
conv_state_shape = (divide(proj_size, tp_world_size), conv_kernel_size - 1)
conv_state_k_shape = (divide(proj_k_size, tp_world_size), conv_kernel_size - 1)
recurrent_state_shape = (divide(num_heads, tp_world_size), head_dim, head_dim)
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
conv_state_k_shape = conv_state_k_shape[1], conv_state_k_shape[0]
return (
conv_state_shape,
conv_state_k_shape,
conv_state_k_shape,
recurrent_state_shape,
)

View File

@@ -147,9 +147,10 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
# Add head dim of 1 to k_pe
k_pe = k_pe.unsqueeze(1)
q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb(
positions, q[..., self.qk_nope_head_dim :], k_pe
)
if self.rotary_emb is not None:
q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb(
positions, q[..., self.qk_nope_head_dim :], k_pe
)
if self.indexer and self.is_sparse:
_topk_indices = self.indexer(hidden_states, q_c, positions, self.rotary_emb)

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from copy import deepcopy
from math import lcm
from typing import TYPE_CHECKING
import vllm.envs as envs
@@ -8,7 +9,7 @@ from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec
if TYPE_CHECKING:
from vllm.config import VllmConfig
@@ -347,12 +348,28 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
# get attention page size (for 1 token)
attn_page_size_1_token = FullAttentionSpec(
block_size=1,
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
dtype=kv_cache_dtype,
).page_size_bytes
# Attention backend constraints:
# - FlashAttention (FA) requires block size to be multiple of 16
# - MLA (Multi-head Latent Attention) requires larger alignment:
# * CUTLASS_MLA backend: kernel_block_size 128 alignment
# * Other MLA backends: kernel_block_size 64 alignment
if model_config.use_mla:
use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
kernel_block_alignment_size = 128 if use_cutlass_mla else 64
attn_page_size_1_token = MLAAttentionSpec(
block_size=1,
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
dtype=kv_cache_dtype,
).page_size_bytes
else:
kernel_block_alignment_size = 16
attn_page_size_1_token = FullAttentionSpec(
block_size=1,
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
dtype=kv_cache_dtype,
).page_size_bytes
model_cls, _ = ModelRegistry.resolve_model_cls(
model_config.architecture,
@@ -372,17 +389,6 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
if mamba_page_size == 0:
return
# Attention backend constraints:
# - FlashAttention (FA) requires block size to be multiple of 16
# - MLA (Multi-head Latent Attention) requires larger alignment:
# * CUTLASS_MLA backend: 128-byte alignment
# * Other MLA backends: 64-byte alignment
if model_config.use_mla:
use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
kernel_block_alignment_size = 128 if use_cutlass_mla else 64
else:
kernel_block_alignment_size = 16
if cache_config.enable_prefix_caching:
# With prefix caching, select attention block size to
# optimize for mamba kernel performance
@@ -400,15 +406,8 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
# easily by changing the way we layout chunks in the
# mamba2 kernels.
from math import gcd
def lcm(a, b):
return a * b // gcd(a, b)
base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
base_chunk_size = model_config.get_mamba_chunk_size()
attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
cache_config.mamba_block_size = attn_block_size

View File

@@ -0,0 +1,663 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from typing import Any
import torch
from torch import nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig
from vllm.distributed import (
get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.kda import KimiDeltaAttention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.kimi_linear import KimiLinearConfig
from .interfaces import HasInnerState, IsHybrid, MixtureOfExperts, SupportsPP
from .utils import (
PPMissingLayer,
is_pp_missing_parameter,
make_layers,
maybe_prefix,
)
logger = init_logger(__name__)
class KimiMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: QKVParallelLinear | None = None,
reduce_results: bool = True,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
raise ValueError(
f"Unsupported activation: {hidden_act}. Only silu is supported for now."
)
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class KimiMoE(nn.Module):
def __init__(
self,
config: KimiLinearConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
layer_idx: int = 0,
):
super().__init__()
hidden_size = config.hidden_size
intermediate_size = config.intermediate_size
moe_intermediate_size = config.moe_intermediate_size
num_experts = config.num_experts
moe_renormalize = config.moe_renormalize
self.tp_size = get_tensor_model_parallel_world_size()
self.routed_scaling_factor = config.routed_scaling_factor
self.num_shared_experts = config.num_shared_experts
self.layer_idx = layer_idx
if config.hidden_act != "silu":
raise ValueError(
f"Unsupported activation: {config.hidden_act}. "
"Only silu is supported for now."
)
# Gate always runs at half / full precision for now.
self.gate = ReplicatedLinear(
hidden_size,
num_experts,
bias=False,
quant_config=None,
prefix=f"{prefix}.gate",
)
self.gate.e_score_correction_bias = nn.Parameter(torch.empty(num_experts))
self.experts = FusedMoE(
num_experts=num_experts,
top_k=config.num_experts_per_token,
hidden_size=hidden_size,
intermediate_size=moe_intermediate_size,
reduce_results=False,
renormalize=moe_renormalize,
quant_config=quant_config,
use_grouped_topk=config.use_grouped_topk,
num_expert_group=config.num_expert_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.moe_router_activation_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
)
if self.num_shared_experts is not None:
intermediate_size = moe_intermediate_size * self.num_shared_experts
self.shared_experts = KimiMLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_size)
if self.num_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = (
self.experts(hidden_states=hidden_states, router_logits=router_logits)
* self.routed_scaling_factor
)
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_size)
class KimiMLAAttention(nn.Module):
"""
Main reference: DeepseekV2 vllm Implementation
"""
def __init__(
self,
config: KimiLinearConfig,
hidden_size: int,
num_heads: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: int | None,
kv_lora_rank: int,
rope_theta: float = 10000,
use_nope: bool = False,
rope_scaling: dict[str, Any] | None = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
**kwargs,
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.num_heads = num_heads
tp_size = get_tensor_model_parallel_world_size()
self.num_local_heads = num_heads // tp_size
self.scaling = self.qk_head_dim**-0.5
self.rope_theta = rope_theta
self.use_nope = use_nope
assert self.use_nope is True
assert self.q_lora_rank is None
assert rope_scaling is None
assert num_heads % tp_size == 0
self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_a_proj_with_mqa",
)
self.q_proj = ColumnParallelLinear(
self.hidden_size,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_proj",
)
self.kv_a_layernorm = RMSNorm(
self.kv_lora_rank,
eps=config.rms_norm_eps,
)
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_b_proj",
)
self.o_proj = RowParallelLinear(
self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
mla_modules = MLAModules(
kv_a_layernorm=self.kv_a_layernorm,
kv_b_proj=self.kv_b_proj,
rotary_emb=None,
o_proj=self.o_proj,
fused_qkv_a_proj=None,
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
q_a_layernorm=None,
q_b_proj=None,
q_proj=self.q_proj,
indexer=None,
is_sparse=False,
topk_indices_buffer=None,
)
self.mla_attn = MultiHeadLatentAttentionWrapper(
self.hidden_size,
self.num_local_heads,
self.scaling,
self.qk_nope_head_dim,
self.qk_rope_head_dim,
self.v_head_dim,
self.q_lora_rank,
self.kv_lora_rank,
mla_modules,
cache_config,
quant_config,
prefix,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
output: torch.Tensor,
) -> None:
output[:] = self.mla_attn(positions, hidden_states)
class KimiDecoderLayer(nn.Module):
def __init__(
self,
config: KimiLinearConfig,
layer_idx: int,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
parallel_config: ParallelConfig | None = None,
model_config: ModelConfig | None = None,
prefix: str = "",
**kwargs,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
self.is_moe = config.is_moe
if config.is_kda_layer(layer_idx):
self.self_attn = KimiDeltaAttention(
layer_idx=layer_idx,
hidden_size=config.hidden_size,
quant_config=quant_config,
cache_config=cache_config,
model_config=config,
prefix=f"{prefix}.self_attn",
)
else:
self.self_attn = KimiMLAAttention(
layer_idx=layer_idx,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
quant_config=quant_config,
cache_config=cache_config,
model_config=model_config,
prefix=f"{prefix}.self_attn",
config=config,
qk_nope_head_dim=config.qk_nope_head_dim,
qk_rope_head_dim=config.qk_rope_head_dim,
v_head_dim=config.v_head_dim,
q_lora_rank=config.q_lora_rank,
kv_lora_rank=config.kv_lora_rank,
use_nope=config.mla_use_nope,
)
if (
self.is_moe
and config.num_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0
):
self.block_sparse_moe = KimiMoE(
config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.mlp = self.block_sparse_moe
else:
self.mlp = KimiMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
attn_output = torch.empty_like(hidden_states)
self.self_attn(
hidden_states=hidden_states,
positions=positions,
output=attn_output,
)
hidden_states = attn_output
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@support_torch_compile
class KimiLinearModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_text_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
if get_pp_group().is_first_rank:
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
prefix=f"{prefix}.embed_tokens",
)
else:
self.embed_tokens = PPMissingLayer()
extra_kwargs = {}
def get_layer(prefix: str):
layer_idx = int(prefix.rsplit(".", 1)[1])
return KimiDecoderLayer(
config,
layer_idx,
cache_config,
quant_config,
parallel_config,
model_config,
prefix,
**extra_kwargs,
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
get_layer,
prefix=f"{prefix}.layers",
)
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
world_size = get_tensor_model_parallel_world_size()
assert config.num_attention_heads % world_size == 0, (
"num_attention_heads must be divisible by world_size"
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for _, layer in enumerate(self.layers[self.start_layer : self.end_layer]):
hidden_states, residual = layer(
positions=positions,
hidden_states=hidden_states,
residual=residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors(
{"hidden_states": hidden_states, "residual": residual}
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class KimiLinearForCausalLM(
nn.Module, HasInnerState, SupportsPP, MixtureOfExperts, IsHybrid
):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.model_config = vllm_config.model_config
self.vllm_config = vllm_config
self.config = self.model_config.hf_config
quant_config = vllm_config.quant_config
self.quant_config = quant_config
self.model = KimiLinearModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
if get_pp_group().is_last_rank:
self.lm_head = ParallelLMHead(
self.config.vocab_size,
self.config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
else:
self.lm_head = PPMissingLayer()
logit_scale = getattr(self.config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(
self.config.vocab_size, scale=logit_scale
)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor | IntermediateTensors:
hidden_states = self.model(
input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
)
return hidden_states
@classmethod
def get_mamba_state_dtype_from_config(
cls,
vllm_config: "VllmConfig",
) -> tuple[torch.dtype, torch.dtype, torch.dtype, torch.dtype]:
return MambaStateDtypeCalculator.kda_state_dtype(
vllm_config.model_config.dtype, vllm_config.cache_config.mamba_cache_dtype
)
@classmethod
def get_mamba_state_shape_from_config(
cls, vllm_config: "VllmConfig"
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
parallel_config = vllm_config.parallel_config
hf_config = vllm_config.model_config.hf_config
tp_size = parallel_config.tensor_parallel_size
num_spec = (
vllm_config.speculative_config.num_speculative_tokens
if vllm_config.speculative_config
else 0
)
return MambaStateShapeCalculator.kda_state_shape(
tp_size,
hf_config.linear_attn_config["num_heads"],
hf_config.linear_attn_config["head_dim"],
conv_kernel_size=hf_config.linear_attn_config["short_conv_kernel_size"],
num_spec=num_spec,
)
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
return self.logits_processor(self.lm_head, hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
if self.config.is_moe:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="w1",
ckpt_down_proj_name="w2",
ckpt_up_proj_name="w3",
num_experts=self.config.num_experts,
)
else:
expert_params_mapping = []
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for args in weights:
name, loaded_weight = args[:2]
kwargs = args[2] if len(args) > 2 else {}
if "rotary_emb.inv_freq" in name:
continue
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is not None:
continue # skip spec decode layers for main model
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for idx, (param_name, weight_name, expert_id, shard_id) in enumerate(
expert_params_mapping
):
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name,
expert_id=expert_id,
shard_id=shard_id,
)
break
else:
# Skip loading extra bias for GPTQ models.
if (
name.endswith(".bias")
and name not in params_dict
and not self.config.is_linear_attn
): # noqa: E501
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight, **kwargs)
loaded_params.add(name)
def get_spec_layer_idx_from_weight_name(
config: KimiLinearConfig, weight_name: str
) -> int | None:
if hasattr(config, "num_nextn_predict_layers") and (
config.num_nextn_predict_layers > 0
):
layer_idx = config.num_hidden_layers
for i in range(config.num_nextn_predict_layers):
if weight_name.startswith(f"model.layers.{layer_idx + i}."):
return layer_idx + i
return None

View File

@@ -118,6 +118,7 @@ _TEXT_GENERATION_MODELS = {
"InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
"JambaForCausalLM": ("jamba", "JambaForCausalLM"),
"KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"), # noqa: E501
"Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
"Lfm2MoeForCausalLM": ("lfm2_moe", "Lfm2MoeForCausalLM"),
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),