# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Callable
from dataclasses import dataclass
from typing import TypeAlias

import torch

from vllm.config.cache import MambaDType
from vllm.config.model import ModelDType
from vllm.distributed import divide
from vllm.utils.torch_utils import (
    STR_DTYPE_TO_TORCH_DTYPE,
    get_kv_cache_torch_dtype,
)


class MambaStateDtypeCalculator:
    @classmethod
    def linear_attention_state_dtype(
        cls,
        model_dtype: ModelDType | torch.dtype,
        mamba_cache_dtype: MambaDType,
    ) -> tuple[torch.dtype, ...]:
        # TODO (tdoublep) requires testing
        if mamba_cache_dtype == "float32":
            raise ValueError("fp32 state for minimax is not yet supported")
        state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
        return (state_dtype,)

    @classmethod
    def mamba1_state_dtype(
        cls,
        model_dtype: ModelDType | torch.dtype,
        mamba_cache_dtype: MambaDType,
        mamba_ssm_cache_dtype: MambaDType,
    ) -> tuple[torch.dtype, ...]:
        return cls._mamba_state_dtype(
            model_dtype, mamba_cache_dtype, mamba_ssm_cache_dtype
        )

    @classmethod
    def mamba2_state_dtype(
        cls,
        model_dtype: ModelDType | torch.dtype,
        mamba_cache_dtype: MambaDType,
        mamba_ssm_cache_dtype: MambaDType,
    ) -> tuple[torch.dtype, ...]:
        return cls._mamba_state_dtype(
            model_dtype, mamba_cache_dtype, mamba_ssm_cache_dtype
        )

    @classmethod
    def _mamba_state_dtype(
        cls,
        model_dtype: ModelDType | torch.dtype,
        mamba_cache_dtype: MambaDType,
        mamba_ssm_cache_dtype: MambaDType,
    ) -> tuple[torch.dtype, ...]:
        conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
        if mamba_ssm_cache_dtype == "auto":
            temporal_state_dtype = conv_state_dtype
        else:
            temporal_state_dtype = STR_DTYPE_TO_TORCH_DTYPE[mamba_ssm_cache_dtype]

        return (conv_state_dtype, temporal_state_dtype)

    @classmethod
    def short_conv_state_dtype(
        cls,
        model_dtype: ModelDType | torch.dtype,
        mamba_cache_dtype: MambaDType,
    ) -> tuple[torch.dtype, ...]:
        conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
        return (conv_state_dtype,)

    @classmethod
    def gated_delta_net_state_dtype(
        cls,
        model_dtype: ModelDType | torch.dtype,
        mamba_cache_dtype: MambaDType,
    ) -> tuple[torch.dtype, torch.dtype]:
        state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
        return (state_dtype, state_dtype)

    @classmethod
    def kda_state_dtype(
        cls,
        model_dtype: ModelDType | torch.dtype,
        mamba_cache_dtype: MambaDType,
    ) -> tuple[torch.dtype, torch.dtype, torch.dtype, torch.dtype]:
        state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
        return (state_dtype, state_dtype, state_dtype, torch.float32)


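# Illustrative sketch with hypothetical settings: "auto" is assumed to resolve
# to the model dtype via get_kv_cache_torch_dtype, while an explicit
# mamba_ssm_cache_dtype overrides only the temporal (SSM) state:
#
#   MambaStateDtypeCalculator.mamba2_state_dtype(
#       model_dtype=torch.bfloat16,
#       mamba_cache_dtype="auto",
#       mamba_ssm_cache_dtype="float32",
#   )
#   # -> (torch.bfloat16, torch.float32)
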
class MambaStateShapeCalculator:
    @classmethod
    def linear_attention_state_shape(
        cls,
        num_heads: int,
        tp_size: int,
        head_dim: int,
    ) -> tuple[tuple[int, int, int], ...]:
        state_shape = (num_heads // tp_size, head_dim, head_dim)
        return (state_shape,)

    @classmethod
    def mamba1_state_shape(
        cls,
        tp_world_size: int,
        intermediate_size: int,
        state_size: int,
        conv_kernel: int,
    ) -> tuple[tuple[int, int], tuple[int, int]]:
        conv_state_shape = (divide(intermediate_size, tp_world_size), conv_kernel - 1)
        temporal_state_shape = (divide(intermediate_size, tp_world_size), state_size)

        # reorder to (conv_kernel - 1, dim)
        conv_state_shape = conv_state_shape[1], conv_state_shape[0]

        return conv_state_shape, temporal_state_shape

    @classmethod
    def mamba2_state_shape(
        cls,
        tp_world_size: int,
        intermediate_size: int,
        n_groups: int,
        num_heads: int,
        head_dim: int,
        state_size: int,
        conv_kernel: int,
    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
        # if n_groups is not divisible by world_size, we need to extend the
        # shards to ensure all groups needed by a head are sharded along with it
        n_groups = n_groups + cls.extra_groups_for_head_shards(n_groups, tp_world_size)
        # heads and n_groups are TP-ed
        conv_dim = intermediate_size + 2 * n_groups * state_size

        # contiguous along 'dim' axis
        conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size))

        # These are not TP-ed as they depend on A, dt_bias, D
        # - they are typically small
        # e.g., (num_heads, head_dim, state_size) = (128, 64, 128)
        temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, state_size)
        return conv_state_shape, temporal_state_shape

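    # Illustrative sketch with hypothetical dimensions: for tp_world_size=4,
    # intermediate_size=8192, n_groups=8, num_heads=128, head_dim=64,
    # state_size=128, conv_kernel=4 (n_groups already divisible by tp):
    #   conv_dim = 8192 + 2 * 8 * 128 = 10240
    #   conv_state_shape     = (3, 2560)      # (conv_kernel - 1, conv_dim / tp)
    #   temporal_state_shape = (32, 64, 128)  # (num_heads / tp, head_dim, state_size)
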
    @classmethod
    def short_conv_state_shape(
        cls,
        tp_world_size: int,
        intermediate_size: int,
        conv_kernel: int,
    ) -> tuple[tuple[int, int]]:
        conv_dim = divide(intermediate_size, tp_world_size)
        conv_state_shape = (conv_kernel - 1, conv_dim)
        return (conv_state_shape,)

    @classmethod
    def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int):
        """Compute how many extra (replicated) groups are needed so that the
        groups can be sharded alongside the head shards."""

        # in the case ngroups % tp_size == 0, this will be zero
        if ngroups % tp_size == 0:
            return 0

        # for n_groups == 1, this is exactly tp_size - n_groups
        return tp_size - ngroups

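    # Illustrative sketch: with ngroups=1 and tp_size=8, this returns 7, so the
    # caller pads to 8 replicated groups and each TP rank holds one full group.
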
    @classmethod
    def gated_delta_net_state_shape(
        cls,
        tp_world_size: int,
        num_k_heads: int,
        num_v_heads: int,
        head_k_dim: int,
        head_v_dim: int,
        conv_kernel_size: int,
        num_spec: int = 0,
    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
        conv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads
        conv_state_shape = (
            divide(conv_dim, tp_world_size),
            conv_kernel_size - 1 + num_spec,
        )

        # reorder to (conv_kernel_size - 1 + num_spec, dim)
        conv_state_shape = conv_state_shape[1], conv_state_shape[0]

        temporal_state_shape = (
            divide(num_v_heads, tp_world_size),
            head_v_dim,
            head_k_dim,
        )
        return conv_state_shape, temporal_state_shape

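    # Illustrative sketch with hypothetical dimensions: num_spec extends the
    # conv state so speculative-decoding tokens have room. With tp_world_size=2,
    # num_k_heads=16, num_v_heads=32, head_k_dim=head_v_dim=128,
    # conv_kernel_size=4, num_spec=3:
    #   conv_dim = 128 * 16 * 2 + 128 * 32 = 8192
    #   conv_state_shape     = (6, 4096)       # (conv_kernel - 1 + num_spec, conv_dim / tp)
    #   temporal_state_shape = (16, 128, 128)  # (num_v_heads / tp, head_v_dim, head_k_dim)
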
    @classmethod
    def kda_state_shape(
        cls,
        tp_world_size: int,
        num_heads: int,
        head_dim: int,
        num_k_heads: int | None = None,
        head_k_dim: int | None = None,
        conv_kernel_size: int = 4,
        num_spec: int = 0,
    ) -> tuple[
        tuple[int, int], tuple[int, int], tuple[int, int], tuple[int, int, int]
    ]:
        if num_k_heads is None:
            num_k_heads = num_heads
        if head_k_dim is None:
            head_k_dim = head_dim

        proj_size = num_heads * head_dim
        proj_k_size = num_k_heads * head_k_dim

        conv_state_shape = (divide(proj_size, tp_world_size), conv_kernel_size - 1)
        conv_state_k_shape = (divide(proj_k_size, tp_world_size), conv_kernel_size - 1)
        recurrent_state_shape = (divide(num_heads, tp_world_size), head_dim, head_dim)

        conv_state_shape = conv_state_shape[1], conv_state_shape[0]
        conv_state_k_shape = conv_state_k_shape[1], conv_state_k_shape[0]
        return (
            conv_state_shape,
            conv_state_k_shape,
            conv_state_k_shape,
            recurrent_state_shape,
        )


@dataclass
class MambaCopySpec:
    """
    Data class specifying the memory-copy parameters for Mamba states used for
    prefix caching in align mode.

    Attributes:
        start_addr (int): Starting address for the memory copy operation.
        num_elements (int): Number of elements to copy from the starting address.
    """

    start_addr: int
    num_elements: int


MambaStateCopyFunc: TypeAlias = Callable[
    [torch.Tensor, list[int], int, int], MambaCopySpec
]
"""
Type alias for a function that computes a MambaCopySpec for copying state slices.

Parameters:
    state: torch.Tensor - the Mamba state tensor (e.g., conv or temporal states).
    block_ids: list[int] - the list of block indices for the state to copy.
    cur_block_idx: int - current block index within `block_ids` to copy from.
    num_accepted_tokens: int - number of accepted tokens used to compute the copy
        offset. Range: 1 .. 1 + num_speculative_tokens (inclusive).
"""


def get_conv_copy_spec(
    state: torch.Tensor,
    block_ids: list[int],
    cur_block_idx: int,
    num_accepted_tokens: int,
) -> MambaCopySpec:
    """Return a MambaCopySpec for copying a convolutional state slice."""
    src_block_id = block_ids[cur_block_idx]
    # skip the first (num_accepted_tokens - 1) rows along the token axis
    src_state = state[src_block_id, num_accepted_tokens - 1 :]
    return MambaCopySpec(
        start_addr=src_state.data_ptr(), num_elements=src_state.numel()
    )


def get_temporal_copy_spec(
    state: torch.Tensor,
    block_ids: list[int],
    cur_block_idx: int,
    num_accepted_tokens: int,
) -> MambaCopySpec:
    """Return a MambaCopySpec for copying a temporal state slice."""
    # the accepted-token count selects which block holds the source state
    src_block_id = block_ids[cur_block_idx + num_accepted_tokens - 1]
    src_state = state[src_block_id]
    return MambaCopySpec(
        start_addr=src_state.data_ptr(), num_elements=src_state.numel()
    )


get_full_copy_spec = get_temporal_copy_spec


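# Illustrative sketch with hypothetical values: a MambaStateCopyFunc is invoked
# per state tensor to locate the source slice for a prefix-cache copy, assuming
# a conv state laid out as (num_blocks, conv_kernel - 1, dim):
#
#   conv_state = torch.zeros(16, 3, 2560, dtype=torch.bfloat16)
#   spec = get_conv_copy_spec(
#       conv_state, block_ids=[4, 9], cur_block_idx=0, num_accepted_tokens=1
#   )
#   # spec.start_addr is conv_state[4].data_ptr(); spec.num_elements == 3 * 2560
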
class MambaStateCopyFuncCalculator:
    @classmethod
    def linear_attention_state_copy_func(cls):
        return (get_temporal_copy_spec,)

    @classmethod
    def mamba1_state_copy_func(cls):
        return (get_conv_copy_spec, get_temporal_copy_spec)

    @classmethod
    def mamba2_state_copy_func(cls):
        return (get_conv_copy_spec, get_temporal_copy_spec)

    @classmethod
    def short_conv_state_copy_func(cls):
        return (get_conv_copy_spec,)

    @classmethod
    def gated_delta_net_state_copy_func(cls):
        return (get_conv_copy_spec, get_temporal_copy_spec)

    @classmethod
    def kda_state_copy_func(cls):
        return (
            get_conv_copy_spec,
            get_conv_copy_spec,
            get_conv_copy_spec,
            get_temporal_copy_spec,
        )
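
# Illustrative end-to-end sketch with hypothetical config values, showing how
# the three helpers compose; the real call sites live elsewhere in vLLM:
#
#   dtypes = MambaStateDtypeCalculator.mamba2_state_dtype(
#       torch.bfloat16, "auto", "float32"
#   )
#   shapes = MambaStateShapeCalculator.mamba2_state_shape(
#       tp_world_size=4, intermediate_size=8192, n_groups=8, num_heads=128,
#       head_dim=64, state_size=128, conv_kernel=4,
#   )
#   copy_funcs = MambaStateCopyFuncCalculator.mamba2_state_copy_func()
#   # zip(dtypes, shapes, copy_funcs) yields one (dtype, shape, copy func)
#   # triple per per-block state tensor: the conv state and the temporal state.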