[Model] Introduce Kimi Linear to vLLM (#27809)
Signed-off-by: lizhiyuan <lizhiyuan@moonshot.cn>
Signed-off-by: Zhiyuan Li <uniartisan2017@gmail.com>
@@ -80,6 +80,15 @@ class MambaStateDtypeCalculator:
         state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
         return (state_dtype, state_dtype)
 
+    @classmethod
+    def kda_state_dtype(
+        cls,
+        model_dtype: ModelDType | torch.dtype,
+        mamba_cache_dtype: MambaDType,
+    ):
+        state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
+        return (state_dtype, state_dtype, state_dtype, torch.float32)
+
 
 class MambaStateShapeCalculator:
     @classmethod
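For context: the four dtypes returned by kda_state_dtype line up with the four cache tensors that kda_state_shape reports in the next hunk, i.e. three short-convolution states stored in the Mamba cache dtype and a recurrent state pinned to float32. A minimal usage sketch, assuming a mamba_cache_dtype of "auto" resolves to the model dtype via get_kv_cache_torch_dtype:

    import torch

    # Sketch only: assumes "auto" makes get_kv_cache_torch_dtype fall
    # back to the model dtype, so the three conv states follow the
    # model while the recurrent state stays in float32.
    dtypes = MambaStateDtypeCalculator.kda_state_dtype(
        model_dtype=torch.bfloat16,
        mamba_cache_dtype="auto",
    )
    assert dtypes == (
        torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.float32
    )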
@@ -182,3 +191,35 @@ class MambaStateShapeCalculator:
             head_v_dim,
         )
         return conv_state_shape, temporal_state_shape
+
+    @classmethod
+    def kda_state_shape(
+        cls,
+        tp_world_size: int,
+        num_heads: int,
+        head_dim: int,
+        num_k_heads: int | None = None,
+        head_k_dim: int | None = None,
+        conv_kernel_size: int = 4,
+        num_spec: int = 0,
+    ) -> tuple[tuple[int, int], tuple[int, int], tuple[int, int], tuple[int, int, int]]:
+        if num_k_heads is None:
+            num_k_heads = num_heads
+        if head_k_dim is None:
+            head_k_dim = head_dim
+
+        proj_size = num_heads * head_dim
+        proj_k_size = num_k_heads * head_k_dim
+
+        conv_state_shape = (divide(proj_size, tp_world_size), conv_kernel_size - 1)
+        conv_state_k_shape = (divide(proj_k_size, tp_world_size), conv_kernel_size - 1)
+        recurrent_state_shape = (divide(num_heads, tp_world_size), head_dim, head_dim)
+
+        conv_state_shape = conv_state_shape[1], conv_state_shape[0]
+        conv_state_k_shape = conv_state_k_shape[1], conv_state_k_shape[0]
+        return (
+            conv_state_shape,
+            conv_state_k_shape,
+            conv_state_k_shape,
+            recurrent_state_shape,
+        )
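A worked example of the shape arithmetic above, with illustrative numbers rather than a real Kimi Linear config: 32 heads of dim 128 sharded over 2 tensor-parallel ranks with the default conv kernel of 4 gives divide(32 * 128, 2) = 2048 channels per rank and conv_kernel_size - 1 = 3 cached conv positions.

    # Illustrative config, not taken from an actual model checkpoint.
    shapes = MambaStateShapeCalculator.kda_state_shape(
        tp_world_size=2,
        num_heads=32,
        head_dim=128,
    )
    # Conv states come back as (conv_kernel_size - 1, channels_per_rank)
    # after the swap at the end of the method; the recurrent state is
    # (heads_per_rank, head_dim, head_dim).
    assert shapes == (
        (3, 2048),       # conv state for the full num_heads * head_dim projection
        (3, 2048),       # k-sized conv state (num_k_heads defaults to num_heads)
        (3, 2048),       # second k-sized conv state in the return tuple
        (16, 128, 128),  # recurrent state, kept in float32 per kda_state_dtype
    )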