[Model] Introduce Kimi Linear to vLLM (#27809)
Signed-off-by: lizhiyuan <lizhiyuan@moonshot.cn>
Signed-off-by: Zhiyuan Li <uniartisan2017@gmail.com>
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from copy import deepcopy
+from math import lcm
 from typing import TYPE_CHECKING
 
 import vllm.envs as envs
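Note: the top-level from math import lcm added above replaces a local gcd-based helper that a later hunk in this diff deletes; math.lcm has been in the standard library since Python 3.9. A minimal sketch of the equivalence (illustration only, not part of the commit):

    from math import gcd, lcm

    def lcm_local(a: int, b: int) -> int:
        # The helper removed later in this diff computed lcm via gcd.
        return a * b // gcd(a, b)

    assert lcm(256, 64) == lcm_local(256, 64) == 256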
@@ -8,7 +9,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.models import ModelRegistry
 from vllm.utils.math_utils import cdiv, round_up
 from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
-from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
+from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec
 
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
@@ -347,12 +348,28 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
         kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
 
-        # get attention page size (for 1 token)
-        attn_page_size_1_token = FullAttentionSpec(
-            block_size=1,
-            num_kv_heads=model_config.get_num_kv_heads(parallel_config),
-            head_size=model_config.get_head_size(),
-            dtype=kv_cache_dtype,
-        ).page_size_bytes
+        # Attention backend constraints:
+        # - FlashAttention (FA) requires block size to be multiple of 16
+        # - MLA (Multi-head Latent Attention) requires larger alignment:
+        #   * CUTLASS_MLA backend: kernel_block_size 128 alignment
+        #   * Other MLA backends: kernel_block_size 64 alignment
+        if model_config.use_mla:
+            use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
+            kernel_block_alignment_size = 128 if use_cutlass_mla else 64
+            attn_page_size_1_token = MLAAttentionSpec(
+                block_size=1,
+                num_kv_heads=model_config.get_num_kv_heads(parallel_config),
+                head_size=model_config.get_head_size(),
+                dtype=kv_cache_dtype,
+            ).page_size_bytes
+        else:
+            kernel_block_alignment_size = 16
+            attn_page_size_1_token = FullAttentionSpec(
+                block_size=1,
+                num_kv_heads=model_config.get_num_kv_heads(parallel_config),
+                head_size=model_config.get_head_size(),
+                dtype=kv_cache_dtype,
+            ).page_size_bytes
 
         model_cls, _ = ModelRegistry.resolve_model_cls(
             model_config.architecture,
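Note: for intuition about why the MLA branch exists, page_size_bytes for one token differs between the two specs because full attention stores K and V for every KV head, while MLA caches a single latent vector per token. The arithmetic below is a rough sketch under assumed layouts, not the actual FullAttentionSpec/MLAAttentionSpec implementation:

    import torch

    def full_attn_page_size_1_token(num_kv_heads: int, head_size: int,
                                    dtype: torch.dtype) -> int:
        # K and V planes, one per KV head (assumed layout).
        return 2 * num_kv_heads * head_size * dtype.itemsize

    def mla_page_size_1_token(head_size: int, dtype: torch.dtype) -> int:
        # One shared latent KV vector per token (assumed layout).
        return head_size * dtype.itemsize

    # Made-up shapes: 8 KV heads of 128 dims vs a 576-dim MLA latent.
    print(full_attn_page_size_1_token(8, 128, torch.bfloat16))  # 4096
    print(mla_page_size_1_token(576, torch.bfloat16))           # 1152

The smaller per-token page is part of why the MLA kernels pair with the coarser kernel_block_alignment_size of 64 or 128 chosen above.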
@@ -372,17 +389,6 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
         if mamba_page_size == 0:
             return
 
-        # Attention backend constraints:
-        # - FlashAttention (FA) requires block size to be multiple of 16
-        # - MLA (Multi-head Latent Attention) requires larger alignment:
-        #   * CUTLASS_MLA backend: 128-byte alignment
-        #   * Other MLA backends: 64-byte alignment
-        if model_config.use_mla:
-            use_cutlass_mla = envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA"
-            kernel_block_alignment_size = 128 if use_cutlass_mla else 64
-        else:
-            kernel_block_alignment_size = 16
-
         if cache_config.enable_prefix_caching:
             # With prefix caching, select attention block size to
             # optimize for mamba kernel performance
@@ -400,15 +406,8 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
             # easily by changing the way we layout chunks in the
             # mamba2 kernels.
-
-            from math import gcd
-
-            def lcm(a, b):
-                return a * b // gcd(a, b)
-
-            base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
+            base_chunk_size = model_config.get_mamba_chunk_size()
             attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
             chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
             attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
         cache_config.mamba_block_size = attn_block_size
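Note: a worked example of the selection above, with made-up inputs (the real values come from the model and cache configs); cdiv here mirrors the semantics of the cdiv imported from vllm.utils.math_utils:

    from math import lcm

    def cdiv(a: int, b: int) -> int:
        # Ceiling division.
        return -(a // -b)

    # Assumed inputs, for illustration only.
    mamba_page_size = 400_000         # bytes for one mamba state page
    attn_page_size_1_token = 1_152    # bytes of attention KV per token
    base_chunk_size = 256             # mamba2 chunk size
    kernel_block_alignment_size = 64  # non-CUTLASS MLA alignment

    attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)   # 348
    chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)                # 256
    attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)  # 512

The resulting attn_block_size is the smallest multiple of chunk_size whose attention page covers one mamba state, so each mamba page lines up with whole attention blocks.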