[V1][Hybrid] Mamba Prefix Caching with align mode (#30877)
Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
@@ -31,6 +31,7 @@ CacheDType = Literal[
    "fp8_ds_mla",
]
MambaDType = Literal["auto", "float32", "float16"]
MambaCacheMode = Literal["all", "align", "none"]
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor", "xxhash", "xxhash_cbor"]
KVOffloadingBackend = Literal["native", "lmcache"]
@@ -123,6 +124,15 @@ class CacheConfig:
    """The data type to use for the Mamba cache (ssm state only, conv state will
    still be controlled by mamba_cache_dtype). If set to 'auto', the data type
    for the ssm state will be determined by mamba_cache_dtype."""

    mamba_cache_mode: MambaCacheMode = "none"
    """The cache strategy for Mamba layers.
    - "none": set when prefix caching is disabled.
    - "all": cache the mamba state at every position i * block_size. This is the
      default behavior (for models that support it) when prefix caching is
      enabled.
    - "align": only cache the mamba state of the last token of each scheduler
      step, and only when that token is at position i * block_size.
    """

    # Will be set after profiling.
    num_gpu_blocks: int | None = field(default=None, init=False)
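To make the three modes concrete, here is a small illustrative sketch (not vLLM code; the helper and its numbers are invented for this note): under "all", every block boundary crossed during prefill gets a cached state, while "align" only checkpoints a scheduler step's final token, and only when that token lands exactly on a boundary.

def cached_state_positions(
    num_prompt_tokens: int,
    chunk_ends: list[int],  # cumulative token count after each scheduler step
    mode: str,
    block_size: int = 16,
) -> list[int]:
    """Token positions whose Mamba state would be cached, per mode."""
    boundaries = list(range(block_size, num_prompt_tokens + 1, block_size))
    if mode == "none":
        return []
    if mode == "all":
        # Every position i * block_size reached during prefill is checkpointed.
        return boundaries
    # "align": only a step's last token counts, and only if block-aligned.
    return [end for end in chunk_ends if end in boundaries]

# block_size=16, two scheduler steps covering tokens 1..32 and 33..45:
print(cached_state_positions(45, [32, 45], "all"))    # [16, 32]
print(cached_state_positions(45, [32, 45], "align"))  # [32]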
@@ -999,6 +999,17 @@ class VllmConfig:
        # Default to enable HMA if not explicitly disabled by user or logic above.
        self.scheduler_config.disable_hybrid_kv_cache_manager = False

        if self.cache_config.mamba_cache_mode == "align":
            if self.scheduler_config.long_prefill_token_threshold > 0:
                assert (
                    self.scheduler_config.long_prefill_token_threshold
                    >= self.cache_config.block_size
                )
            assert not self.scheduler_config.disable_chunked_mm_input, (
                "Chunked MM input is required because we need the flexibility to "
                "schedule a multiple of block_size tokens even if they are in the "
                "middle of an mm input"
            )
        if self.compilation_config.debug_dump_path:
            self.compilation_config.debug_dump_path = (
                self.compilation_config.debug_dump_path.absolute().expanduser()
@@ -60,6 +60,7 @@ from vllm.config.cache import (
    BlockSize,
    CacheDType,
    KVOffloadingBackend,
    MambaCacheMode,
    MambaDType,
    PrefixCachingHashAlgo,
)
@@ -556,6 +557,7 @@ class EngineArgs:
    mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
    mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
    mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size")
    mamba_cache_mode: MambaCacheMode = CacheConfig.mamba_cache_mode

    additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config")
@@ -939,6 +941,9 @@ class EngineArgs:
        cache_group.add_argument(
            "--mamba-block-size", **cache_kwargs["mamba_block_size"]
        )
        cache_group.add_argument(
            "--mamba-cache-mode", **cache_kwargs["mamba_cache_mode"]
        )
        cache_group.add_argument(
            "--kv-offloading-size", **cache_kwargs["kv_offloading_size"]
        )
@@ -1416,6 +1421,7 @@ class EngineArgs:
            mamba_cache_dtype=self.mamba_cache_dtype,
            mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
            mamba_block_size=self.mamba_block_size,
            mamba_cache_mode=self.mamba_cache_mode,
            kv_offloading_size=self.kv_offloading_size,
            kv_offloading_backend=self.kv_offloading_backend,
        )
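With the flag plumbed through EngineArgs, align mode can be requested at startup, e.g. `vllm serve <model> --enable-prefix-caching --mamba-cache-mode align`. A hedged offline sketch (the model name is a placeholder; this assumes the `LLM` entrypoint forwards these engine arguments as keyword arguments, as it does for other CacheConfig fields):

from vllm import LLM

# Placeholder model name; any hybrid/Mamba architecture applies.
llm = LLM(
    model="some-org/some-hybrid-mamba-model",
    enable_prefix_caching=True,
    mamba_cache_mode="align",  # or "all" for models that support it
)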
@@ -56,6 +56,7 @@ class MambaBase(AttentionLayerBase):
            block_size=mamba_block_size,
            page_size_padded=page_size_padded,
            mamba_type=self.mamba_type,
            mamba_cache_mode=vllm_config.cache_config.mamba_cache_mode,
            num_speculative_blocks=(
                vllm_config.speculative_config.num_speculative_tokens
                if vllm_config.speculative_config
@@ -255,7 +255,7 @@ class MambaMixer(MambaBase, CustomOp):
        assert self.cache_config is not None
        mamba_block_size = self.cache_config.mamba_block_size
        prefix_caching_enabled = self.cache_config.enable_prefix_caching
        is_mamba_cache_all = self.cache_config.mamba_cache_mode == "all"

        if attn_metadata is not None:
            assert isinstance(attn_metadata, dict)
@@ -304,7 +304,7 @@ class MambaMixer(MambaBase, CustomOp):
        state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p
        state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d

        if prefix_caching_enabled:
        if is_mamba_cache_all:
            block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
                torch.split(
                    attn_metadata.block_idx_last_computed_token,
@@ -380,7 +380,7 @@ class MambaMixer(MambaBase, CustomOp):
            ssm_outputs.append(scan_out_p)

        if has_decode:
            if prefix_caching_enabled:
            if is_mamba_cache_all:
                state_indices_tensor_d_input = state_indices_tensor_d.gather(
                    1, block_idx_last_computed_token_d.unsqueeze(1)
                ).squeeze(1)
@@ -570,7 +570,7 @@ class MambaMixer2(MambaBase, CustomOp):
        assert self.cache_config is not None
        mamba_block_size = self.cache_config.mamba_block_size
        prefix_caching_enabled = self.cache_config.enable_prefix_caching
        is_mamba_cache_all = self.cache_config.mamba_cache_mode == "all"
        if attn_metadata is not None:
            assert isinstance(attn_metadata, dict)
            attn_metadata = attn_metadata[self.prefix]
@@ -622,7 +622,7 @@ class MambaMixer2(MambaBase, CustomOp):
                dim=0,
            )

        if prefix_caching_enabled:
        if is_mamba_cache_all:
            # If prefix caching is enabled, retrieve the relevant variables
            # for prefill and decode
            block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
@@ -701,7 +701,7 @@ class MambaMixer2(MambaBase, CustomOp):
            initial_states = None
            if has_initial_states_p is not None and prep_initial_states:
                kernel_ssm_indices = state_indices_tensor_p
                if prefix_caching_enabled:
                if is_mamba_cache_all:
                    kernel_ssm_indices = state_indices_tensor_p.gather(
                        1, block_idx_last_computed_token_p.unsqueeze(1)
                    ).squeeze(1)
@@ -729,14 +729,14 @@ class MambaMixer2(MambaBase, CustomOp):
                cu_chunk_seqlens=cu_chunk_seqlen_p,
                last_chunk_indices=last_chunk_indices_p,
                initial_states=initial_states,
                return_intermediate_states=prefix_caching_enabled,
                return_intermediate_states=is_mamba_cache_all,
                dt_softplus=True,
                dt_limit=(0.0, float("inf")),
                out=preallocated_ssm_out_p.view(num_prefill_tokens, -1, self.head_dim),
                state_dtype=ssm_state.dtype,
            )

            if prefix_caching_enabled:
            if is_mamba_cache_all:
                # The chunk_stride is the number of chunks per mamba block,
                # e.g., if mamba_block_size = 512 and chunk_size = 256,
                # then chunk_stride = 2
@@ -815,7 +815,7 @@ class MambaMixer2(MambaBase, CustomOp):
        # Process decode requests
        if has_decode:
            if prefix_caching_enabled:
            if is_mamba_cache_all:
                state_indices_tensor_d_input = state_indices_tensor_d.gather(
                    1, block_idx_last_computed_token_d.unsqueeze(1)
                ).squeeze(1)
@@ -1,6 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Callable
from dataclasses import dataclass
from typing import TypeAlias

import torch

from vllm.config.cache import MambaDType
@@ -223,3 +227,94 @@ class MambaStateShapeCalculator:
            conv_state_k_shape,
            recurrent_state_shape,
        )

@dataclass
class MambaCopySpec:
    """
    Data class specifying the memory-copy parameters for Mamba states used for
    prefix caching in align mode.

    Attributes:
        start_addr (int): Starting address for the memory copy operation.
        num_elements (int): Number of elements to copy from the starting address.
    """

    start_addr: int
    num_elements: int


MambaStateCopyFunc: TypeAlias = Callable[
    [torch.Tensor, list[int], int, int], MambaCopySpec
]
"""
Type alias for a function that computes a MambaCopySpec for copying state slices.
Parameters:
    state: torch.Tensor - the Mamba state tensor (e.g., conv or temporal states).
    block_ids: list[int] - the list of block indices for the state to copy.
    cur_block_idx: int - current block index within `block_ids` to copy from.
    num_accepted_tokens: int - number of accepted tokens used to compute the copy
        offset. Range: 1 .. 1 + num_speculative_tokens (inclusive).
"""


def get_conv_copy_spec(
    state: torch.Tensor,
    block_ids: list[int],
    cur_block_idx: int,
    num_accepted_tokens: int,
) -> MambaCopySpec:
    """Return a MambaCopySpec for copying a convolutional state slice."""
    src_block_id = block_ids[cur_block_idx]
    src_state = state[src_block_id, num_accepted_tokens - 1 :]
    return MambaCopySpec(
        start_addr=src_state.data_ptr(), num_elements=src_state.numel()
    )


def get_temporal_copy_spec(
    state: torch.Tensor,
    block_ids: list[int],
    cur_block_idx: int,
    num_accepted_tokens: int,
) -> MambaCopySpec:
    """Return a MambaCopySpec for copying a temporal state slice."""
    src_block_id = block_ids[cur_block_idx + num_accepted_tokens - 1]
    src_state = state[src_block_id]
    return MambaCopySpec(
        start_addr=src_state.data_ptr(), num_elements=src_state.numel()
    )


get_full_copy_spec = get_temporal_copy_spec


class MambaStateCopyFuncCalculator:
    @classmethod
    def linear_attention_state_copy_func(cls):
        return (get_temporal_copy_spec,)

    @classmethod
    def mamba1_state_copy_func(cls):
        return (get_conv_copy_spec, get_temporal_copy_spec)

    @classmethod
    def mamba2_state_copy_func(cls):
        return (get_conv_copy_spec, get_temporal_copy_spec)

    @classmethod
    def short_conv_state_copy_func(cls):
        return (get_conv_copy_spec,)

    @classmethod
    def gated_delta_net_state_copy_func(cls):
        return (get_conv_copy_spec, get_temporal_copy_spec)

    @classmethod
    def kda_state_copy_func(cls):
        return (
            get_conv_copy_spec,
            get_conv_copy_spec,
            get_conv_copy_spec,
            get_temporal_copy_spec,
        )
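A hedged usage sketch of the copy-spec helpers above (the tensor shape is invented; real conv-state shapes come from MambaStateShapeCalculator):

import torch

# Invented layout: 8 blocks, each holding a (conv_width=4, hidden=32) conv state.
conv_state = torch.zeros(8, 4, 32)

# Copy from the current block (index 1 -> physical block 5), skipping the
# rows consumed by already-accepted tokens (num_accepted_tokens=2 -> offset 1).
spec = get_conv_copy_spec(
    conv_state, block_ids=[3, 5, 7], cur_block_idx=1, num_accepted_tokens=2
)
assert spec.num_elements == conv_state[5, 1:].numel()  # 3 * 32 = 96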
@@ -24,6 +24,8 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateCopyFunc,
    MambaStateCopyFuncCalculator,
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
@@ -455,6 +457,10 @@ class BambaForCausalLM(
            conv_kernel=hf_config.mamba_d_conv,
        )

    @classmethod
    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
        return MambaStateCopyFuncCalculator.mamba2_state_copy_func()

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        self.vllm_config = vllm_config
@@ -330,26 +330,54 @@ class MambaModelConfig(VerifyAndUpdateConfig):
        cache_config = vllm_config.cache_config

        if cache_config.enable_prefix_caching:
            if model_config.supports_mamba_prefix_caching:
                logger.info(
                    "Warning: Prefix caching is currently enabled. "
                    "Its support for Mamba layers is experimental. "
                    "Please report any issues you may observe."
            if cache_config.mamba_cache_mode == "none":
                cache_config.mamba_cache_mode = (
                    "all" if model_config.supports_mamba_prefix_caching else "align"
                )
                # By default, mamba block size will be set to max_model_len (see
                # below). When enabling prefix caching, we align mamba block size
                # to the block size as the basic granularity for prefix caching.
                if cache_config.mamba_block_size is None:
                    cache_config.mamba_block_size = cache_config.block_size
            else:
                logger.info(
                    "Hybrid or mamba-based model detected without "
                    "support for prefix caching: disabling."
                logger.warning(
                    "Mamba cache mode is set to '%s' for %s by default "
                    "when prefix caching is enabled",
                    cache_config.mamba_cache_mode,
                    model_config.architecture,
                )
                cache_config.enable_prefix_caching = False

            if cache_config.mamba_block_size is None:
                cache_config.mamba_block_size = model_config.max_model_len
            if (
                cache_config.mamba_cache_mode == "all"
                and not model_config.supports_mamba_prefix_caching
            ):
                cache_config.mamba_cache_mode = "align"
                logger.warning(
                    "Hybrid or mamba-based model detected without support "
                    "for prefix caching with Mamba cache 'all' mode: "
                    "falling back to 'align' mode."
                )
            if cache_config.mamba_cache_mode == "align":
                assert vllm_config.scheduler_config.enable_chunked_prefill, (
                    "Chunked prefill is required for mamba cache mode 'align'."
                )
                assert not vllm_config.speculative_config, (
                    "Mamba cache mode 'align' is currently not compatible "
                    "with speculative decoding."
                )
            logger.info(
                "Prefix caching in Mamba cache '%s' mode is currently enabled. "
                "Its support for Mamba layers is experimental. "
                "Please report any issues you may observe.",
                cache_config.mamba_cache_mode,
            )
            # By default, mamba block size will be set to max_model_len (see
            # below). When enabling prefix caching, we align mamba block size
            # to the block size as the basic granularity for prefix caching.
            if cache_config.mamba_block_size is None:
                cache_config.mamba_block_size = cache_config.block_size
        else:
            if cache_config.mamba_cache_mode != "none":
                cache_config.mamba_cache_mode = "none"
                logger.warning(
                    "Mamba cache mode is set to 'none' when prefix caching is disabled"
                )
            if cache_config.mamba_block_size is None:
                cache_config.mamba_block_size = model_config.max_model_len


class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
@@ -426,7 +454,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
        mamba_page_size = MambaSpec(
            shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
            dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
            block_size=model_config.max_model_len,
            block_size=-1,  # block_size doesn't matter for mamba page size
        ).page_size_bytes

        # Model may be marked as is_hybrid
@@ -435,7 +463,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
        if mamba_page_size == 0:
            return

        if cache_config.enable_prefix_caching:
        if cache_config.mamba_cache_mode == "all":
            # With prefix caching, select attention block size to
            # optimize for mamba kernel performance
@@ -479,6 +507,13 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
                attn_block_size,
            )

        # By default, mamba block size will be set to max_model_len.
        # When enabling prefix caching and using align mamba cache
        # mode, we align mamba block size to the block size as the
        # basic granularity for prefix caching.
        if cache_config.mamba_cache_mode == "align":
            cache_config.mamba_block_size = cache_config.block_size

        # compute new attention page size
        attn_page_size = cache_config.block_size * attn_page_size_1_token
@@ -24,6 +24,8 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateCopyFunc,
    MambaStateCopyFuncCalculator,
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
@@ -551,6 +553,10 @@ class FalconH1ForCausalLM(
            conv_kernel=hf_config.mamba_d_conv,
        )

    @classmethod
    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
        return MambaStateCopyFuncCalculator.mamba2_state_copy_func()

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        self.vllm_config = vllm_config
@@ -19,6 +19,8 @@ from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLine
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateCopyFunc,
    MambaStateCopyFuncCalculator,
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
@@ -641,6 +643,10 @@ class GraniteMoeHybridForCausalLM(
            conv_kernel=hf_config.mamba_d_conv,
        )

    @classmethod
    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
        return MambaStateCopyFuncCalculator.mamba2_state_copy_func()

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
@@ -24,6 +24,7 @@ from vllm.config import ModelConfig, SpeechToTextConfig
from vllm.inputs import TokensPrompt
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_utils import MambaStateCopyFunc
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.utils.collection_utils import common_prefix
from vllm.utils.func_utils import supports_kw
@@ -776,6 +777,19 @@ class IsHybrid(Protocol):
        """
        ...

    @classmethod
    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, ...]:
        """Calculate copy-function callables for each Mamba state.

        Returns:
            A tuple of MambaStateCopyFunc callables that correspond, in order,
            to the Mamba states produced by the model. Each callable accepts
            (state, block_ids, cur_block_idx, num_accepted_tokens) and returns
            a MambaCopySpec describing the memory-copy parameters for prefix
            caching in align mode.
        """
        ...


@overload
def is_hybrid(model: object) -> TypeIs[IsHybrid]: ...
@@ -24,6 +24,8 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateCopyFunc,
    MambaStateCopyFuncCalculator,
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
@@ -558,6 +560,10 @@ class JambaForCausalLM(
            conv_kernel=hf_config.mamba_d_conv,
        )

    @classmethod
    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
        return MambaStateCopyFuncCalculator.mamba1_state_copy_func()

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
@@ -26,6 +26,8 @@ from vllm.model_executor.layers.linear import (
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateCopyFunc,
    MambaStateCopyFuncCalculator,
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
@@ -544,6 +546,14 @@ class KimiLinearForCausalLM(
            num_spec=num_spec,
        )

    @classmethod
    def get_mamba_state_copy_func(
        cls,
    ) -> tuple[
        MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc, MambaStateCopyFunc
    ]:
        return MambaStateCopyFuncCalculator.kda_state_copy_func()

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
@@ -20,6 +20,8 @@ from vllm.model_executor.layers.linear import (
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateCopyFunc,
    MambaStateCopyFuncCalculator,
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
@@ -459,14 +461,19 @@ class Lfm2ForCausalLM(
            conv_kernel=hf_config.conv_L_cache,
        )

    @classmethod
    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc]:
        return MambaStateCopyFuncCalculator.short_conv_state_copy_func()

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        cache_config = vllm_config.cache_config

        assert not cache_config.enable_prefix_caching, (
            "Lfm2 currently does not support prefix caching"
        )
        if cache_config.mamba_cache_mode == "all":
            raise NotImplementedError(
                "Lfm2 currently does not support 'all' prefix caching, "
                "please use '--mamba-cache-mode=align' instead"
            )

        super().__init__()
        self.config = config
@@ -25,6 +25,8 @@ from vllm.model_executor.layers.linear import (
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateCopyFunc,
    MambaStateCopyFuncCalculator,
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
@@ -640,6 +642,10 @@ class Lfm2MoeForCausalLM(
            conv_kernel=hf_config.conv_L_cache,
        )

    @classmethod
    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc]:
        return MambaStateCopyFuncCalculator.short_conv_state_copy_func()

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
@@ -16,6 +16,8 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateCopyFunc,
    MambaStateCopyFuncCalculator,
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
@@ -261,6 +263,10 @@ class MambaForCausalLM(
            conv_kernel=hf_config.conv_kernel,
        )

    @classmethod
    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
        return MambaStateCopyFuncCalculator.mamba1_state_copy_func()

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        return self.mamba_cache.copy_inputs_before_cuda_graphs(input_buffers, **kwargs)
@@ -15,6 +15,8 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateCopyFunc,
    MambaStateCopyFuncCalculator,
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
@@ -228,6 +230,10 @@ class Mamba2ForCausalLM(
            conv_kernel=hf_config.conv_kernel,
        )

    @classmethod
    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
        return MambaStateCopyFuncCalculator.mamba2_state_copy_func()

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
@@ -35,6 +35,8 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.linear_attn import MiniMaxText01LinearAttention
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateCopyFunc,
    MambaStateCopyFuncCalculator,
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
@@ -1006,3 +1008,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
            tp_size=parallel_config.tensor_parallel_size,
            head_dim=hf_config.head_dim,
        )

    @classmethod
    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc]:
        return MambaStateCopyFuncCalculator.linear_attention_state_copy_func()
@@ -2128,3 +2128,7 @@ class NemotronH_Nano_VL_V2(
        temp_vllm_config = copy.deepcopy(vllm_config)
        temp_vllm_config.model_config.hf_config = text_config
        return NemotronHForCausalLM.get_mamba_state_dtype_from_config(temp_vllm_config)

    @classmethod
    def get_mamba_state_copy_func(cls):
        return NemotronHForCausalLM.get_mamba_state_copy_func()
@@ -45,6 +45,8 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateCopyFunc,
    MambaStateCopyFuncCalculator,
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
@@ -809,6 +811,10 @@ class NemotronHForCausalLM(
            conv_kernel=hf_config.conv_kernel,
        )

    @classmethod
    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
        return MambaStateCopyFuncCalculator.mamba2_state_copy_func()

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        self.vllm_config = vllm_config
@@ -27,6 +27,8 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateCopyFunc,
    MambaStateCopyFuncCalculator,
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
@@ -899,6 +901,10 @@ class Plamo2ForCausalLM(
            conv_kernel=hf_config.mamba_d_conv,
        )

    @classmethod
    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
        return MambaStateCopyFuncCalculator.mamba2_state_copy_func()

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
@@ -48,6 +48,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_mixer2 import mamba_v2_sharded_weight_loader
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateCopyFunc,
    MambaStateCopyFuncCalculator,
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
@@ -1205,9 +1207,11 @@ class Qwen3NextForCausalLM(
        cache_config = vllm_config.cache_config

        scheduler_config = vllm_config.scheduler_config
        assert not cache_config.enable_prefix_caching, (
            "Qwen3Next currently does not support prefix caching"
        )
        if cache_config.mamba_cache_mode == "all":
            raise NotImplementedError(
                "Qwen3Next currently does not support 'all' prefix caching, "
                "please use '--mamba-cache-mode=align' instead"
            )
        self.quant_config = vllm_config.quant_config

        super().__init__()
@@ -1278,6 +1282,10 @@ class Qwen3NextForCausalLM(
            num_spec,
        )

    @classmethod
    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
        return MambaStateCopyFuncCalculator.gated_delta_net_state_copy_func()

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
@@ -234,9 +234,11 @@ class Qwen3NextMTP(nn.Module, QwenNextMixtureOfExperts):
        config = vllm_config.model_config.hf_config
        self.vllm_config = vllm_config
        cache_config = vllm_config.cache_config
        assert not cache_config.enable_prefix_caching, (
            "Qwen3NextMTP currently does not support prefix caching"
        )
        if cache_config.mamba_cache_mode == "all":
            raise NotImplementedError(
                "Qwen3NextMTP currently does not support 'all' prefix caching, "
                "please use '--mamba-cache-mode=align' instead"
            )

        self.quant_config = vllm_config.quant_config
@@ -32,6 +32,8 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateCopyFunc,
    MambaStateCopyFuncCalculator,
    MambaStateDtypeCalculator,
    MambaStateShapeCalculator,
)
@@ -891,6 +893,10 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsMambaPrefixC
            conv_kernel=hf_config.mamba_d_conv,
        )

    @classmethod
    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
        return MambaStateCopyFuncCalculator.mamba2_state_copy_func()

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Initialize the Zamba2 model for causal language modeling.
@@ -16,6 +16,7 @@ from vllm.v1.attention.backend import (
from vllm.v1.attention.backends.utils import (
    PAD_SLOT_ID,
    compute_causal_conv1d_metadata,
    mamba_get_block_table_tensor,
    split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
@@ -158,6 +159,12 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
        query_start_loc_cpu = m.query_start_loc_cpu
        context_lens_tensor = m.compute_num_computed_tokens()
        nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None
        block_table_tensor = mamba_get_block_table_tensor(
            m.block_table_tensor,
            m.seq_lens,
            self.kv_cache_spec,
            self.vllm_config.cache_config.mamba_cache_mode,
        )

        spec_sequence_masks_cpu: torch.Tensor | None = None
        if (
@@ -189,7 +196,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
            spec_token_indx = None
            non_spec_token_indx = None
            spec_state_indices_tensor = None
            non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
            non_spec_state_indices_tensor = block_table_tensor[:, 0]
            spec_query_start_loc = None
            non_spec_query_start_loc = query_start_loc
            non_spec_query_start_loc_cpu = query_start_loc_cpu
@@ -221,7 +228,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
            non_spec_token_indx = torch.empty(
                0, dtype=torch.int32, device=query_start_loc.device
            )
            spec_state_indices_tensor = m.block_table_tensor[:, : self.num_spec + 1]
            spec_state_indices_tensor = block_table_tensor[:, : self.num_spec + 1]
            non_spec_state_indices_tensor = None
            spec_query_start_loc = query_start_loc
            non_spec_query_start_loc = None
@@ -235,10 +242,10 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
            non_spec_token_indx = index[:num_non_spec_tokens]
            spec_token_indx = index[num_non_spec_tokens:]

            spec_state_indices_tensor = m.block_table_tensor[
            spec_state_indices_tensor = block_table_tensor[
                spec_sequence_masks, : self.num_spec + 1
            ]
            non_spec_state_indices_tensor = m.block_table_tensor[
            non_spec_state_indices_tensor = block_table_tensor[
                ~spec_sequence_masks, 0
            ]
@@ -11,7 +11,10 @@ from vllm.v1.attention.backend import (
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
)
from vllm.v1.attention.backends.utils import split_decodes_and_prefills
from vllm.v1.attention.backends.utils import (
    mamba_get_block_table_tensor,
    split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec


@@ -61,7 +64,12 @@ class LinearAttentionMetadataBuilder(AttentionMetadataBuilder[LinearAttentionMet
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens

        state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
        state_indices_tensor = mamba_get_block_table_tensor(
            common_attn_metadata.block_table_tensor,
            common_attn_metadata.seq_lens,
            self.kv_cache_spec,
            self.vllm_config.cache_config.mamba_cache_mode,
        )[:, 0]

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(
@@ -18,6 +18,7 @@ from vllm.v1.attention.backend import (
from vllm.v1.attention.backends.utils import (
    PAD_SLOT_ID,
    compute_causal_conv1d_metadata,
    mamba_get_block_table_tensor,
    split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
@@ -41,11 +42,15 @@ class BaseMambaAttentionMetadata:

    state_indices_tensor: torch.Tensor

    # The following tensors are only used for prefix caching and are None if disabled
    # The following tensors are only used for prefix caching in "all" mode and
    # are None if disabled
    block_idx_last_scheduled_token: torch.Tensor | None
    block_idx_first_scheduled_token_p: torch.Tensor | None
    block_idx_last_computed_token: torch.Tensor | None

    # The following tensor is only used for prefix caching in "align" mode
    seq_lens: torch.Tensor

    # The following attributes are for triton implementation of causal_conv1d
    nums_dict: dict | None = None
    batch_ptr: torch.Tensor | None = None
@@ -78,7 +83,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
            self.compilation_config.max_cudagraph_capture_size,
        )

        if self.vllm_config.cache_config.enable_prefix_caching:
        if self.vllm_config.cache_config.mamba_cache_mode == "all":
            self.state_indices_tensor = torch.empty(
                (
                    self.decode_cudagraph_max_bs,
@@ -198,7 +203,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
        # for causal_conv1d
        nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None

        if self.vllm_config.cache_config.enable_prefix_caching:
        if self.vllm_config.cache_config.mamba_cache_mode == "all":
            num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()

            # Return a tensor of shape (#requests, #max blocks)
@@ -214,7 +219,12 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
            )
        else:
            # Always return just a single block for each request:
            state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
            state_indices_tensor = mamba_get_block_table_tensor(
                common_attn_metadata.block_table_tensor,
                common_attn_metadata.seq_lens,
                self.kv_cache_spec,
                self.vllm_config.cache_config.mamba_cache_mode,
            )[:, 0]

        if num_prefills > 0:
            if num_computed_tokens is None:
@@ -239,7 +249,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
                )
            )

            if self.vllm_config.cache_config.enable_prefix_caching:
            if self.vllm_config.cache_config.mamba_cache_mode == "all":
                assert num_computed_tokens is not None
                num_computed_tokens_p = num_computed_tokens[
                    num_reqs - num_prefills : num_reqs
@@ -258,7 +268,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
            state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
            state_indices_tensor[num_decodes:] = PAD_SLOT_ID

            if self.vllm_config.cache_config.enable_prefix_caching:
            if self.vllm_config.cache_config.mamba_cache_mode == "all":
                self.block_idx_last_scheduled_token[:num_decodes].copy_(
                    block_idx_last_scheduled_token, non_blocking=True
                )
@@ -286,6 +296,7 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
            block_idx_last_computed_token=block_idx_last_computed_token,
            num_computed_tokens_p=num_computed_tokens_p,
            num_reqs=num_reqs,
            seq_lens=common_attn_metadata.seq_lens,
            nums_dict=nums_dict,
            batch_ptr=batch_ptr,
            token_chunk_offset_ptr=token_chunk_offset_ptr,
@@ -298,8 +309,16 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
        slot_mapping: torch.Tensor,
    ) -> M:
        new_metadata = copy.copy(metadata)
        prefix_caching = self.vllm_config.cache_config.enable_prefix_caching
        state_indices_t = blk_table if prefix_caching else blk_table[:, 0]
        state_indices_t = mamba_get_block_table_tensor(
            blk_table,
            metadata.seq_lens,
            self.kv_cache_spec,
            self.vllm_config.cache_config.mamba_cache_mode,
        )
        if self.vllm_config.cache_config.mamba_cache_mode in ("none", "align"):
            # Only need the block that saves the running state
            state_indices_t = state_indices_t[:, 0]

        num_reqs = blk_table.shape[0]

        # For CUDA graphs, copy to persistent buffer
@@ -17,6 +17,7 @@ from typing_extensions import runtime_checkable

from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.utils.math_utils import cdiv
from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput
@@ -854,3 +855,40 @@ def extend_all_queries_by_1(
        slot_mapping=new_slot_mapping,
    )
    return new_cad


def mamba_get_block_table_tensor(
    block_table: torch.Tensor,
    seq_lens: torch.Tensor,
    kv_cache_spec: KVCacheSpec,
    mamba_cache_mode: str,
) -> torch.Tensor:
    """
    Get the block table tensor for mamba kernels from the input
    common_attn_metadata.block_table_tensor given different mamba cache modes.

    - "all": input (#requests, cdiv(max_model_len, block_size));
      output (#requests, cdiv(max_model_len, block_size)).

    - "none": input (#requests, 1 + num_speculative_blocks);
      output (#requests, 1 + num_speculative_blocks).

    - "align": input (#requests, cdiv(max_model_len, block_size));
      output (#requests, 1 + num_speculative_blocks), which are the last
      1 + num_speculative_blocks blocks of each request.
    """
    if mamba_cache_mode in ("all", "none"):
        return block_table
    else:
        assert isinstance(kv_cache_spec, MambaSpec)
        # NOTE: For 0-length requests in CUDA graphs, use a start_index of 0
        # to handle the invalid block table.
        start_indices = torch.clamp(
            (seq_lens - 1) // kv_cache_spec.block_size,
            min=0,
        )
        offsets = torch.arange(
            1 + kv_cache_spec.num_speculative_blocks, device=block_table.device
        )
        indices_to_gather = start_indices.unsqueeze(1) + offsets
        return torch.gather(block_table, 1, indices_to_gather)
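To see the "align" branch with toy numbers (block_size and table contents invented; no speculative blocks): a request with seq_len = 40 and block_size = 16 is inside its third block, since (40 - 1) // 16 == 2, so only column 2 of its row is gathered as the running-state block.

import torch

block_table = torch.tensor([[7, 3, 9, 5]])
seq_lens = torch.tensor([40])
start = torch.clamp((seq_lens - 1) // 16, min=0)          # tensor([2])
print(torch.gather(block_table, 1, start.unsqueeze(1)))   # tensor([[9]])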
@@ -255,7 +255,8 @@ class BlockPool:
        )
        for i, blk in enumerate(new_full_blocks):
            # Some blocks may be null blocks when enabling sparse attention like
            # sliding window attention. We skip null blocks here.
            # sliding window attention, or Mamba models with prefix caching in
            # align mode. We skip null blocks here.
            if blk.is_null:
                continue
            assert blk.block_hash is None
@@ -75,6 +75,7 @@ class KVCacheCoordinator(ABC):
        new_computed_blocks: tuple[Sequence[KVCacheBlock], ...],
        num_encoder_tokens: int,
        total_computed_tokens: int,
        num_tokens_main_model: int,
    ) -> int:
        """
        Get the number of blocks needed to be allocated for the request.
@@ -88,6 +89,9 @@ class KVCacheCoordinator(ABC):
            num_encoder_tokens: The number of encoder tokens for allocating
                blocks for cross-attention.
            total_computed_tokens: Include both local and external tokens.
            num_tokens_main_model: The number of tokens for the main model (aka
                target model in spec decode). Without spec decode, it is
                num_tokens; with spec decode, it is num_tokens - num_lookahead_tokens.

        Returns:
            The number of blocks to allocate.
@@ -98,7 +102,7 @@ class KVCacheCoordinator(ABC):
                # For cross-attention, we issue a single static allocation
                # of blocks based on the number of encoder input tokens.
                num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
                    request_id, num_encoder_tokens, [], 0
                    request_id, num_encoder_tokens, [], 0, num_encoder_tokens
                )
            else:
                num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
@@ -106,6 +110,7 @@ class KVCacheCoordinator(ABC):
                    num_tokens,
                    new_computed_blocks[i],
                    total_computed_tokens,
                    num_tokens_main_model,
                )
        return num_blocks_to_allocate

@@ -139,6 +144,7 @@ class KVCacheCoordinator(ABC):
        self,
        request_id: str,
        num_tokens: int,
        num_tokens_main_model: int,
        num_encoder_tokens: int = 0,
    ) -> tuple[list[KVCacheBlock], ...]:
        """
@@ -149,6 +155,9 @@ class KVCacheCoordinator(ABC):
            request_id: The request ID.
            num_tokens: The total number of tokens that need a slot (including
                tokens that are already allocated).
            num_tokens_main_model: The number of tokens for the main model (aka
                target model in spec decode). Without spec decode, it is
                num_tokens; with spec decode, it is num_tokens - num_lookahead_tokens.
            num_encoder_tokens: The number of encoder tokens for allocating
                blocks for cross-attention.

@@ -161,6 +170,7 @@ class KVCacheCoordinator(ABC):
                num_encoder_tokens
                if isinstance(manager, CrossAttentionManager)
                else num_tokens,
                num_tokens_main_model,
            )
            for manager in self.single_type_managers
        )
@@ -307,8 +307,9 @@ class KVCacheManager:
            num_local_computed_tokens + num_external_computed_tokens,
            self.max_model_len,
        )
        num_tokens_main_model = total_computed_tokens + num_new_tokens
        num_tokens_need_slot = min(
            total_computed_tokens + num_new_tokens + num_lookahead_tokens,
            num_tokens_main_model + num_lookahead_tokens,
            self.max_model_len,
        )
@@ -329,6 +330,7 @@ class KVCacheManager:
            num_encoder_tokens=num_encoder_tokens,
            total_computed_tokens=num_local_computed_tokens
            + num_external_computed_tokens,
            num_tokens_main_model=num_tokens_main_model,
        )

        if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
@@ -349,7 +351,10 @@ class KVCacheManager:
        )

        new_blocks = self.coordinator.allocate_new_blocks(
            request.request_id, num_tokens_need_slot, num_encoder_tokens
            request.request_id,
            num_tokens_need_slot,
            num_tokens_main_model,
            num_encoder_tokens,
        )

        # P/D: delay caching blocks if we have to recv from
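The distinction matters for align mode: lookahead (speculative) slots are still reserved, but the Mamba managers size their block-aligned allocation from the main-model token count alone. With invented numbers:

# 100 computed + 30 newly scheduled tokens, 4 lookahead tokens, max_model_len 2048.
num_tokens_main_model = 100 + 30           # 130: what the target model sees
num_tokens_need_slot = min(130 + 4, 2048)  # 134: slots including lookahead
# Align-mode Mamba allocation works from 130, keeping block alignment intact,
# while attention allocation still reserves slots for all 134 tokens.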
@@ -47,7 +47,7 @@ from vllm.v1.core.sched.output import (
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
from vllm.v1.core.sched.utils import check_stop, remove_all
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
from vllm.v1.metrics.perf import ModelMetrics, PerfStats
from vllm.v1.metrics.stats import (
    PrefixCacheStats,
@@ -226,6 +226,17 @@ class Scheduler(SchedulerInterface):
        )
        self.use_pp = self.parallel_config.pipeline_parallel_size > 1
        self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER

        def has_mamba_layers(kv_cache_config: KVCacheConfig) -> bool:
            return any(
                isinstance(group_spec.kv_cache_spec, MambaSpec)
                for group_spec in kv_cache_config.kv_cache_groups
            )

        self.has_mamba_layers = has_mamba_layers(kv_cache_config)
        self.need_mamba_block_aligned_split = (
            self.has_mamba_layers and self.cache_config.mamba_cache_mode == "align"
        )
        self.perf_metrics: ModelMetrics | None = None
        if self.log_stats and vllm_config.observability_config.enable_mfu_metrics:
            self.perf_metrics = ModelMetrics(vllm_config)
@@ -250,6 +261,53 @@ class Scheduler(SchedulerInterface):
            vllm_config=self.vllm_config,
        )

    def _mamba_block_aligned_split(
        self,
        request: Request,
        num_new_tokens: int,
        num_new_local_computed_tokens: int = 0,
        num_external_computed_tokens: int = 0,
    ) -> int:
        assert num_external_computed_tokens == 0, (
            "External KV connector is not verified yet"
        )
        # TODO: needs a check for resumed requests
        if request.num_output_tokens == 0:  # prefill
            # To enable block-aligned caching of the Mamba state, `num_new_tokens`
            # must be a multiple of `block_size`.
            # As an exception, if `num_new_tokens` is less than `block_size`, the
            # state is simply not cached, requiring no special handling.
            # Additionally, when Eagle mode is enabled, FullAttn prunes the last
            # matching block. To prevent this from causing a Mamba cache miss, the
            # last chunk must be larger than `block_size`.
            block_size = self.cache_config.block_size
            last_cache_position = (
                request.num_prompt_tokens - request.num_prompt_tokens % block_size
            )
            # eagle prune
            if self.use_eagle:
                last_cache_position = max(last_cache_position - block_size, 0)
            num_computed_tokens = (
                request.num_computed_tokens
                + num_new_local_computed_tokens
                + num_external_computed_tokens
            )
            num_computed_tokens_after_sched = num_computed_tokens + num_new_tokens
            if num_computed_tokens_after_sched < last_cache_position:
                # align to block_size
                num_new_tokens = num_new_tokens // block_size * block_size
            elif (
                num_computed_tokens
                < last_cache_position
                < num_computed_tokens_after_sched
            ):
                # force caching of the last chunk
                num_new_tokens = last_cache_position - num_computed_tokens
            else:
                # prefill the last few tokens
                pass
        return num_new_tokens
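A standalone restatement of the prefill branch, with invented numbers (block_size = 16, prompt of 100 tokens, so last_cache_position = 96; no Eagle):

def aligned_chunk(
    num_computed: int,
    num_new: int,
    prompt_len: int,
    block_size: int = 16,
    use_eagle: bool = False,
) -> int:
    """Sketch of _mamba_block_aligned_split's prefill branch (not vLLM code)."""
    last_cache_position = prompt_len - prompt_len % block_size
    if use_eagle:
        last_cache_position = max(last_cache_position - block_size, 0)
    after = num_computed + num_new
    if after < last_cache_position:
        return num_new // block_size * block_size  # round down to a block multiple
    if num_computed < last_cache_position < after:
        return last_cache_position - num_computed  # stop exactly on the boundary
    return num_new                                 # trailing, unaligned remainder

print(aligned_chunk(0, 50, 100))   # 48: stay block-aligned mid-prompt
print(aligned_chunk(48, 50, 100))  # 48: forced to end on the last boundary (96)
print(aligned_chunk(96, 4, 100))   # 4: the unaligned tail is prefilled as-is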
    def schedule(self) -> SchedulerOutput:
        # NOTE(woosuk) on the scheduling algorithm:
        # There's no "decoding phase" nor "prefill phase" in the scheduler.
@@ -340,6 +398,11 @@ class Scheduler(SchedulerInterface):
                shift_computed_tokens=1 if self.use_eagle else 0,
            )

            if self.need_mamba_block_aligned_split:
                num_new_tokens = self._mamba_block_aligned_split(
                    request, num_new_tokens
                )

            if num_new_tokens == 0:
                # The request cannot be scheduled for one of the following
                # reasons:
@@ -350,6 +413,8 @@ class Scheduler(SchedulerInterface):
                # its max_total_tokens or max_model_len.
                # 2. The encoder budget is exhausted.
                # 3. The encoder cache is exhausted.
                # 4. Insufficient budget for a block-aligned chunk in hybrid
                #    models with mamba cache mode "align".
                # NOTE(woosuk): Here, by doing `continue` instead of `break`,
                # we do not strictly follow the FCFS scheduling policy and
                # allow the lower-priority requests to be scheduled.
@@ -608,6 +673,16 @@ class Scheduler(SchedulerInterface):
                # The request cannot be scheduled.
                break

            if self.need_mamba_block_aligned_split:
                num_new_tokens = self._mamba_block_aligned_split(
                    request,
                    num_new_tokens,
                    num_new_local_computed_tokens,
                    num_external_computed_tokens,
                )
                if num_new_tokens == 0:
                    break

            # Handles an edge case when P/D Disaggregation
            # is used with Spec Decoding where an
            # extra block gets allocated which
@@ -66,12 +66,17 @@ class SingleTypeKVCacheManager(ABC):
        self.kv_cache_group_id = kv_cache_group_id
        self._null_block = block_pool.null_block

    @classmethod
    def _get_num_evictable_blocks(cls, blocks: Sequence[KVCacheBlock]):
        return sum(blk.ref_cnt == 0 and not blk.is_null for blk in blocks)

    def get_num_blocks_to_allocate(
        self,
        request_id: str,
        num_tokens: int,
        new_computed_blocks: Sequence[KVCacheBlock],
        total_computed_tokens: int,
        num_tokens_main_model: int,
    ) -> int:
        """
        Get the number of blocks needed to be allocated for the request.
@@ -84,6 +89,9 @@ class SingleTypeKVCacheManager(ABC):
                prefix caching.
            total_computed_tokens: Include both local and external computed
                tokens.
            num_tokens_main_model: The number of tokens for the main model (aka
                target model in spec decode). Without spec decode, it is
                num_tokens; with spec decode, it is num_tokens - num_lookahead_tokens.

        Returns:
            The number of blocks to allocate.
@@ -121,9 +129,8 @@ class SingleTypeKVCacheManager(ABC):
        # If a computed block is an eviction candidate (in the free queue and
        # ref_cnt == 0), it will be removed from the free queue when touched by
        # the allocated request, so we must count it in the free-capacity check.
        num_evictable_blocks = sum(
            blk.ref_cnt == 0 and not blk.is_null
            for blk in new_computed_blocks[num_skipped_new_computed_blocks:]
        num_evictable_blocks = self._get_num_evictable_blocks(
            new_computed_blocks[num_skipped_new_computed_blocks:]
        )
        return num_new_blocks + num_evictable_blocks

@@ -199,7 +206,7 @@ class SingleTypeKVCacheManager(ABC):
        req_blocks.extend(allocated_blocks)

    def allocate_new_blocks(
        self, request_id: str, num_tokens: int
        self, request_id: str, num_tokens: int, num_tokens_main_model: int
    ) -> list[KVCacheBlock]:
        """
        Allocate new blocks for the request to give it at least `num_tokens`
@@ -209,7 +216,9 @@ class SingleTypeKVCacheManager(ABC):
            request_id: The request ID.
            num_tokens: The total number of tokens that need a slot (including
                tokens that are already allocated).
            num_tokens_main_model: The number of tokens for the main model (aka
                target model in spec decode). Without spec decode, it is
                num_tokens; with spec decode, it is num_tokens - num_lookahead_tokens.

        Returns:
            The new allocated blocks.
        """
@@ -450,12 +459,9 @@ class FullAttentionManager(SingleTypeKVCacheManager):


class SlidingWindowManager(SingleTypeKVCacheManager):
    def __init__(
        self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, **kwargs
    ) -> None:
        super().__init__(kv_cache_spec, block_pool, **kwargs)
    def __init__(self, kv_cache_spec: SlidingWindowSpec, **kwargs) -> None:
        super().__init__(kv_cache_spec, **kwargs)
        self.sliding_window = kv_cache_spec.sliding_window
        self._null_block = block_pool.null_block

    @classmethod
    def find_longest_cache_hit(
@@ -586,12 +592,9 @@ class SlidingWindowManager(SingleTypeKVCacheManager):


class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
    def __init__(
        self, kv_cache_spec: ChunkedLocalAttentionSpec, block_pool: BlockPool, **kwargs
    ) -> None:
        super().__init__(kv_cache_spec, block_pool, **kwargs)
    def __init__(self, kv_cache_spec: ChunkedLocalAttentionSpec, **kwargs) -> None:
        super().__init__(kv_cache_spec, **kwargs)
        self.attention_chunk_size = kv_cache_spec.attention_chunk_size
        self._null_block = block_pool.null_block

    @classmethod
    def find_longest_cache_hit(
@@ -739,6 +742,17 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):


class MambaManager(SingleTypeKVCacheManager):
    def __init__(self, kv_cache_spec: MambaSpec, **kwargs) -> None:
        super().__init__(kv_cache_spec, **kwargs)
        self.mamba_cache_mode = kv_cache_spec.mamba_cache_mode
        self.num_speculative_blocks: int = kv_cache_spec.num_speculative_blocks
        if self.mamba_cache_mode == "align":
            # Mapping from request ID to the index of the block
            # allocated in the previous step
            self.last_state_block_idx: dict[str, int] = {}
            # The set of the requests that have been allocated blocks
            self._allocated_block_reqs: set[str] = set()

    @classmethod
    def find_longest_cache_hit(
        cls,
@@ -787,6 +801,28 @@ class MambaManager(SingleTypeKVCacheManager):

        return computed_blocks

    def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
        assert isinstance(self.kv_cache_spec, MambaSpec)
        super().remove_skipped_blocks(request_id, num_computed_tokens)
        if self.mamba_cache_mode == "align":
            # `last_state_block_idx` refers to the block index allocated two
            # steps ago. The block allocated in the previous step is used to
            # copy Mamba states into the block allocated in the current step;
            # the earlier block is no longer needed and should be freed here.
            last_state_block_idx = self.last_state_block_idx.get(request_id)
            # Blocks allocated during prefill may be non-contiguous. Use
            # `last_state_block_idx` to free the appropriate block and replace
            # it with a null block.
            if (
                last_state_block_idx is not None
                and last_state_block_idx
                < cdiv(num_computed_tokens, self.block_size) - 1
            ):
                blocks = self.req_to_blocks[request_id]
                if blocks[last_state_block_idx] != self._null_block:
                    self.block_pool.free_blocks([blocks[last_state_block_idx]])
                    blocks[last_state_block_idx] = self._null_block
|
||||
"""
|
||||
cascade attention is not supported by mamba
|
||||
@@ -799,31 +835,134 @@ class MambaManager(SingleTypeKVCacheManager):
        num_tokens: int,
        new_computed_blocks: Sequence[KVCacheBlock],
        total_computed_tokens: int,
+        num_tokens_main_model: int,
    ) -> int:
-        # Allocate extra `num_speculative_blocks` blocks for
-        # speculative decoding (MTP/EAGLE) with linear attention.
        assert isinstance(self.kv_cache_spec, MambaSpec)
-        if self.kv_cache_spec.num_speculative_blocks > 0:
-            num_tokens += (
-                self.kv_cache_spec.block_size
-                * self.kv_cache_spec.num_speculative_blocks
-            )
-        return super().get_num_blocks_to_allocate(
-            request_id, num_tokens, new_computed_blocks, total_computed_tokens
-        )
+        if self.mamba_cache_mode != "align":
+            # Allocate extra `num_speculative_blocks` blocks for
+            # speculative decoding (MTP/EAGLE) with linear attention.
+            if self.num_speculative_blocks > 0:
+                num_tokens += (
+                    self.kv_cache_spec.block_size * self.num_speculative_blocks
+                )
+            return super().get_num_blocks_to_allocate(
+                request_id,
+                num_tokens,
+                new_computed_blocks,
+                total_computed_tokens,
+                num_tokens_main_model,
+            )
+        else:
+            # We don't allocate blocks for lookahead tokens in align mode, because if
+            # x * block_size tokens are scheduled, num_tokens is
+            # x * block_size + num_lookahead_tokens and breaks the alignment.
+            # We can ignore lookahead tokens because current draft models don't have
+            # mamba layers.
+            num_tokens = num_tokens_main_model
+            num_required_blocks = (
+                cdiv(num_tokens, self.block_size) + self.num_speculative_blocks
+            )
+            num_new_blocks = (
+                num_required_blocks
+                - len(new_computed_blocks)
+                - len(self.req_to_blocks[request_id])
+            )
+            if num_new_blocks > 0:
+                if request_id in self._allocated_block_reqs:
+                    # Old request. Needs at most 1 more block, as we can reuse the
+                    # speculative blocks from the previous step.
+                    num_new_blocks = 1
+                else:
+                    # First prefill. Allocate 1 block for the running state and the
+                    # speculative blocks.
+                    num_new_blocks = 1 + self.num_speculative_blocks
+
+            num_evictable_computed_blocks = self._get_num_evictable_blocks(
+                new_computed_blocks
+            )
+            return num_new_blocks + num_evictable_computed_blocks
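
To make the align-mode accounting concrete, a small standalone sketch of the arithmetic (illustrative values, not part of the diff):

from math import ceil

block_size, num_speculative_blocks = 16, 2
num_tokens_main_model = 48  # already a multiple of block_size
num_required_blocks = (
    ceil(num_tokens_main_model / block_size) + num_speculative_blocks
)  # 3 + 2 = 5
# A request that already holds blocks needs at most one new block per step,
# because last step's speculative blocks are recycled; a first prefill
# allocates the running-state block plus all speculative blocks.
steady_state_new_blocks = 1
first_prefill_new_blocks = 1 + num_speculative_blocks  # 3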

    def allocate_new_blocks(
-        self, request_id: str, num_tokens: int
+        self, request_id: str, num_tokens: int, num_tokens_main_model: int
    ) -> list[KVCacheBlock]:
-        # Allocate extra `num_speculative_blocks` blocks for
-        # speculative decoding (MTP/EAGLE) with linear attention.
        assert isinstance(self.kv_cache_spec, MambaSpec)
-        if self.kv_cache_spec.num_speculative_blocks > 0:
-            num_tokens += (
-                self.kv_cache_spec.block_size
-                * self.kv_cache_spec.num_speculative_blocks
-            )
-        return super().allocate_new_blocks(request_id, num_tokens)
+        if self.mamba_cache_mode != "align":
+            # Allocate extra `num_speculative_blocks` blocks for
+            # speculative decoding (MTP/EAGLE) with linear attention.
+            if self.num_speculative_blocks > 0:
+                num_tokens += self.block_size * self.num_speculative_blocks
+            return super().allocate_new_blocks(
+                request_id, num_tokens, num_tokens_main_model
+            )
+        else:
+            # We don't allocate blocks for lookahead tokens in align mode, because if
+            # x * block_size tokens are scheduled, num_tokens is
+            # x * block_size + num_lookahead_tokens and breaks the alignment.
+            # We can ignore lookahead tokens because current draft models don't have
+            # mamba layers.
+            num_tokens = num_tokens_main_model
+            req_blocks: list[KVCacheBlock] = self.req_to_blocks[request_id]
+            num_required_blocks = (
+                cdiv(num_tokens, self.block_size) + self.num_speculative_blocks
+            )
+            if num_required_blocks == len(req_blocks):
+                return []
+            else:
+                assert num_required_blocks > len(req_blocks), (
+                    "num_required_blocks "
+                    f"{num_required_blocks} < len(req_blocks) {len(req_blocks)}"
+                )
+            prev_block_len = len(req_blocks)
+            blocks_allocated = request_id in self._allocated_block_reqs
+            # Record the last state block.
+            if blocks_allocated:
+                # We always save the running state at the last
+                # (1 + num_speculative_blocks) block.
+                self.last_state_block_idx[request_id] = (
+                    prev_block_len - 1 - self.num_speculative_blocks
+                )
+            elif prev_block_len > 0:
+                # When a new request hits the prefix cache, the last block
+                # saves the hit state.
+                self.last_state_block_idx[request_id] = prev_block_len - 1
+
+            num_skipped_blocks = (
+                num_required_blocks - self.num_speculative_blocks - 1
+            )
+            # Pad the block table with null blocks for the skipped positions.
+            if prev_block_len < num_skipped_blocks:
+                req_blocks.extend(
+                    [
+                        self._null_block
+                        for _ in range(prev_block_len, num_skipped_blocks)
+                    ]
+                )
+
+            if blocks_allocated:
+                # Reuse the previous step's speculative blocks in this step.
+                for block_idx in range(
+                    prev_block_len - self.num_speculative_blocks, prev_block_len
+                ):
+                    if block_idx < num_skipped_blocks:
+                        req_blocks.append(req_blocks[block_idx])
+                        req_blocks[block_idx] = self._null_block
+                    else:
+                        break
+            num_new_blocks = num_required_blocks - len(req_blocks)
+            if blocks_allocated:
+                assert num_new_blocks <= 1
+            else:
+                assert num_new_blocks <= self.num_speculative_blocks + 1
+            new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
+            req_blocks.extend(new_blocks)
+            self._allocated_block_reqs.add(request_id)
+            return req_blocks[prev_block_len:]
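
The block-table shuffle above is easiest to see with concrete values. A standalone trace (illustrative, not part of the diff) with block_size = 4 and one speculative block, for a request growing from 4 to 12 aligned tokens:

null = "NULL"
req_blocks = ["B_state", "B_spec"]  # after the 4-token step: running state + spec
prev_block_len = len(req_blocks)    # 2
num_required_blocks = 12 // 4 + 1   # 4
num_skipped_blocks = num_required_blocks - 1 - 1  # 2
# prev_block_len == num_skipped_blocks, so no null-block padding is needed.
# The old speculative block now falls inside the skipped range and is
# recycled as the new running-state slot:
for idx in range(prev_block_len - 1, prev_block_len):
    if idx < num_skipped_blocks:
        req_blocks.append(req_blocks[idx])
        req_blocks[idx] = null
req_blocks.append("B_new")          # num_new_blocks == 1
# "B_state" now caches the state at the first block boundary, "B_spec" holds
# the running state, and "B_new" is the fresh speculative block.
assert req_blocks == ["B_state", null, "B_spec", "B_new"]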

    def free(self, request_id: str) -> None:
+        if self.mamba_cache_mode == "align":
+            self._allocated_block_reqs.discard(request_id)
+            self.last_state_block_idx.pop(request_id, None)
        super().free(request_id)

    def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
        """
@@ -276,6 +276,7 @@ class MambaSpec(KVCacheSpec):
    dtypes: tuple[torch.dtype]
    page_size_padded: int | None = None
    mamba_type: str = "mamba2"
    mamba_cache_mode: str = "none"
    num_speculative_blocks: int = 0

    @property
@@ -290,8 +291,13 @@ class MambaSpec(KVCacheSpec):
        return page_size

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
-        max_model_len = vllm_config.model_config.max_model_len
-        return cdiv(max_model_len, self.block_size) * self.page_size_bytes
+        if vllm_config.cache_config.mamba_cache_mode == "all":
+            max_model_len = vllm_config.model_config.max_model_len
+            return cdiv(max_model_len, self.block_size) * self.page_size_bytes
+        elif vllm_config.cache_config.mamba_cache_mode == "align":
+            return self.page_size_bytes * (2 + self.num_speculative_blocks)
+        else:
+            return self.page_size_bytes * (1 + self.num_speculative_blocks)
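
As a rough sense of scale, a standalone computation of the per-request bound under each mode (hypothetical sizes, not from the PR):

from math import ceil

page_size_bytes = 1 << 20  # assume a 1 MiB state page
block_size, max_model_len = 16, 4096
num_speculative_blocks = 2
all_mode = ceil(max_model_len / block_size) * page_size_bytes  # 256 MiB
align_mode = page_size_bytes * (2 + num_speculative_blocks)    # 4 MiB
none_mode = page_size_bytes * (1 + num_speculative_blocks)     # 3 MiB
# "align" keeps only the running state plus the newest aligned state (and
# speculative blocks) instead of one state per block position.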


@dataclass(frozen=True)

@@ -8,6 +8,7 @@ from vllm.distributed import get_dcp_group, get_pcp_group
from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.cp_utils import get_total_cp_world_size

logger = init_logger(__name__)
@@ -261,47 +262,45 @@ class MultiGroupBlockTable:
        device: torch.device,
        block_sizes: list[int],
        kernel_block_sizes: list[int],
        num_speculative_tokens: int = 0,
+        max_num_blocks: list[int] | None = None,
        cp_kv_cache_interleave_size: int = 1,
    ) -> None:
-        # Note(hc): each dcp rank only store
-        # (max_model_len//dcp_world_size) tokens in kvcache,
-        # so the block_size which used for calc max_num_blocks_per_req
-        # must be multiplied by dcp_world_size.
-        try:
-            pcp_world_size = get_pcp_group().world_size
-        except AssertionError:
-            # PCP might not be initialized in testing
-            pcp_world_size = 1
-        try:
-            dcp_world_size = get_dcp_group().world_size
-        except AssertionError:
-            # DCP might not be initialized in testing
-            dcp_world_size = 1
-
        if len(kernel_block_sizes) != len(block_sizes):
            raise ValueError(
                f"kernel_block_sizes length ({len(kernel_block_sizes)}) "
                f"must match block_sizes length ({len(block_sizes)})"
            )
+        if max_num_blocks is None:
+            # Note(hc): each dcp rank only stores
+            # (max_model_len // dcp_world_size) tokens in the kv cache,
+            # so the block_size used to calculate max_num_blocks_per_req
+            # must be multiplied by dcp_world_size.
+            total_cp_world_size = get_total_cp_world_size()
+            max_num_blocks = [
+                cdiv(max_model_len, block_size * total_cp_world_size)
+                for block_size in block_sizes
+            ]

-        total_cp_world_size = dcp_world_size * pcp_world_size
+        if len(max_num_blocks) != len(block_sizes):
+            raise ValueError(
+                f"max_num_blocks length ({len(max_num_blocks)}) "
+                f"must match block_sizes length ({len(block_sizes)})"
+            )

        self.block_tables = [
            BlockTable(
                block_size,
                max_num_reqs,
-                max(
-                    cdiv(max_model_len, block_size * total_cp_world_size),
-                    1 + num_speculative_tokens,
-                ),
+                max_num_blocks_per_req,
                max_num_batched_tokens,
                pin_memory,
                device,
                kernel_block_size,
                cp_kv_cache_interleave_size,
            )
-            for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes)
+            for block_size, kernel_block_size, max_num_blocks_per_req in zip(
+                block_sizes, kernel_block_sizes, max_num_blocks
+            )
        ]

    def append_row(self, block_ids: tuple[list[int], ...], row_idx: int) -> None:

@@ -3,6 +3,7 @@
from typing import TYPE_CHECKING, Any, cast

from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.distributed import get_dcp_group, get_pcp_group

if TYPE_CHECKING:
    from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
@@ -40,3 +41,17 @@ def check_attention_cp_compatibility(vllm_config: VllmConfig) -> None:
                f"but the impl {layer_impl.__class__.__name__} "
                "does not support PCP."
            )


def get_total_cp_world_size():
    try:
        pcp_world_size = get_pcp_group().world_size
    except AssertionError:
        # PCP might not be initialized in testing
        pcp_world_size = 1
    try:
        dcp_world_size = get_dcp_group().world_size
    except AssertionError:
        # DCP might not be initialized in testing
        dcp_world_size = 1
    return dcp_world_size * pcp_world_size

@@ -89,11 +89,11 @@ class InputBatch:
        vocab_size: int,
        block_sizes: list[int],  # The block_size of each kv cache group
        kernel_block_sizes: list[int],
        max_num_blocks_per_req: list[int] | None = None,
        logitsprocs: LogitsProcessors | None = None,
        logitsprocs_need_output_token_ids: bool = False,
        is_spec_decode: bool = False,
        is_pooling_model: bool = False,
        num_speculative_tokens: int = 0,
        cp_kv_cache_interleave_size: int = 1,
    ):
        self.is_pooling_model = is_pooling_model
@@ -146,7 +146,7 @@ class InputBatch:
            device=device,
            block_sizes=block_sizes,
            kernel_block_sizes=kernel_block_sizes,
            num_speculative_tokens=num_speculative_tokens,
            max_num_blocks=max_num_blocks_per_req,
            cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
        )

@@ -152,7 +152,11 @@ from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
from vllm.v1.structured_output.utils import apply_grammar_bitmask
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
-from vllm.v1.worker.cp_utils import check_attention_cp_compatibility
+from vllm.v1.worker import mamba_utils
+from vllm.v1.worker.cp_utils import (
+    check_attention_cp_compatibility,
+    get_total_cp_world_size,
+)
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.ec_connector_model_runner_mixin import ECConnectorModelRunnerMixin
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -688,6 +692,7 @@ class GPUModelRunner(
        # Ephemeral state transferred between execute_model() and sample_tokens().
        self.execute_model_state: ExecuteModelState | None = None
        self.kv_connector_output: KVConnectorOutput | None = None
        self.mamba_state_idx: dict[str, int] = {}
        self.layerwise_nvtx_hooks_registered = False

    def update_max_model_len(self, max_model_len: int) -> None:
@@ -1075,7 +1080,7 @@ class GPUModelRunner(
        self.input_batch.refresh_metadata()

    def _update_states_after_model_execute(
-        self, output_token_ids: torch.Tensor
+        self, output_token_ids: torch.Tensor, scheduler_output: "SchedulerOutput"
    ) -> None:
        """Update the cached states after model execution.

@@ -1111,6 +1116,16 @@ class GPUModelRunner(
            )
            for i, num_tokens in enumerate(num_accepted_tokens):
                self.input_batch.num_accepted_tokens_cpu[i] = num_tokens
        if self.cache_config.mamba_cache_mode == "align":
            mamba_utils.postprocess_mamba(
                scheduler_output,
                self.kv_cache_config,
                self.input_batch,
                self.requests,
                self.mamba_state_idx,
                self.compilation_config.static_forward_context,
                self.model.get_mamba_state_copy_func(),
            )

    def _init_mrope_positions(self, req_state: CachedRequestState):
        model = self.get_model()

@@ -2751,7 +2766,6 @@ class GPUModelRunner(
            logits,
            sampling_metadata,
        )
-        self._update_states_after_model_execute(sampler_output.sampled_token_ids)
        return sampler_output

    def _bookkeeping_sync(

@@ -3237,6 +3251,18 @@ class GPUModelRunner(

        pad_attn = cudagraph_mode == CUDAGraphMode.FULL

        if self.cache_config.mamba_cache_mode == "align":
            mamba_utils.preprocess_mamba(
                scheduler_output,
                self.kv_cache_config,
                self.cache_config,
                self.mamba_state_idx,
                self.input_batch,
                self.requests,
                self.compilation_config.static_forward_context,
                self.model.get_mamba_state_copy_func(),
            )

        use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
        ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices

@@ -3423,6 +3449,10 @@ class GPUModelRunner(
        with record_function_or_nullcontext("gpu_model_runner: sample"):
            sampler_output = self._sample(logits, spec_decode_metadata)

        self._update_states_after_model_execute(
            sampler_output.sampled_token_ids, scheduler_output
        )

        self._draft_token_ids = None
        self._draft_token_req_ids = None
        self.input_batch.prev_sampled_token_ids = None

@@ -5322,6 +5352,24 @@ class GPUModelRunner(
            for kv_cache_group in kv_cache_config.kv_cache_groups
            if not isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec)
        ]
        max_num_blocks = []
        max_model_len = max(self.max_model_len, self.max_encoder_len)
        for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
            if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec):
                continue
            max_num_blocks_per_req = cdiv(
                max_model_len, block_sizes[i] * get_total_cp_world_size()
            )
            if isinstance(kv_cache_group.kv_cache_spec, MambaSpec):
                mamba_blocks_per_req = (
                    max_num_blocks_per_req
                    if self.cache_config.enable_prefix_caching
                    else 1
                ) + kv_cache_group.kv_cache_spec.num_speculative_blocks
                max_num_blocks_per_req = max(
                    max_num_blocks_per_req, mamba_blocks_per_req
                )
            max_num_blocks.append(max_num_blocks_per_req)
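
A quick standalone illustration of this sizing (hypothetical numbers, not from the PR): one attention group with block size 16 and one mamba group with block size 64, max_model_len = 1024, no context parallelism, and 2 speculative blocks.

from math import ceil

max_model_len, num_speculative_blocks = 1024, 2
enable_prefix_caching = True
attn_blocks = ceil(max_model_len / 16)   # 64
mamba_blocks = ceil(max_model_len / 64)  # 16
mamba_per_req = (
    mamba_blocks if enable_prefix_caching else 1
) + num_speculative_blocks               # 18 (or 3 with prefix caching off)
mamba_blocks = max(mamba_blocks, mamba_per_req)
max_num_blocks = [attn_blocks, mamba_blocks]  # [64, 18]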

        if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [
            self.cache_config.block_size
@@ -5333,18 +5381,18 @@ class GPUModelRunner(
            )
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
-            max_model_len=max(self.max_model_len, self.max_encoder_len),
+            max_model_len=max_model_len,
            max_num_batched_tokens=self.max_num_tokens,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=self.model_config.get_vocab_size(),
            block_sizes=block_sizes,
            kernel_block_sizes=kernel_block_sizes,
+            max_num_blocks_per_req=max_num_blocks,
            is_spec_decode=bool(self.vllm_config.speculative_config),
            logitsprocs=self.input_batch.logitsprocs,
            logitsprocs_need_output_token_ids=self.input_batch.logitsprocs_need_output_token_ids,
            is_pooling_model=self.is_pooling_model,
            num_speculative_tokens=self.num_spec_tokens,
        )

    def _allocate_kv_cache_tensors(

232  vllm/v1/worker/mamba_utils.py  Normal file
@@ -0,0 +1,232 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from typing import Any

import torch
import triton
import triton.language as tl

from vllm.config import CacheConfig
from vllm.model_executor.layers.mamba.mamba_utils import (
    MambaStateCopyFunc,
)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
from vllm.v1.worker.gpu_input_batch import CachedRequestState
from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch


@triton.jit
def batch_memcpy_kernel(src_ptrs, dst_ptrs, sizes, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)

    src_ptr = tl.load(src_ptrs + pid)
    dst_ptr = tl.load(dst_ptrs + pid)
    size = tl.load(sizes + pid)

    offsets = tl.arange(0, BLOCK_SIZE)
    for i in range(0, size, BLOCK_SIZE):
        mask = (i + offsets) < size

        curr_src_ptr = (src_ptr + i + offsets).to(tl.pointer_type(tl.uint8))
        curr_dst_ptr = (dst_ptr + i + offsets).to(tl.pointer_type(tl.uint8))

        data = tl.load(curr_src_ptr, mask=mask)
        tl.store(curr_dst_ptr, data, mask=mask)


def batch_memcpy(src_ptrs, dst_ptrs, sizes):
    batch = src_ptrs.shape[0]
    assert dst_ptrs.shape[0] == batch
    assert sizes.shape[0] == batch

    grid = (batch,)
    BLOCK_SIZE = 1024
    batch_memcpy_kernel[grid](src_ptrs, dst_ptrs, sizes, BLOCK_SIZE=BLOCK_SIZE)
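
For reference, a minimal usage sketch of this helper (illustrative, not part of the file): pointers are raw device addresses packed into int64 tensors and sizes are byte counts, which is exactly how do_mamba_copy_block below builds its arguments.

src = torch.arange(16, dtype=torch.uint8, device="cuda")
dst = torch.zeros(16, dtype=torch.uint8, device="cuda")
src_ptrs = torch.tensor([src.data_ptr()], device="cuda", dtype=torch.int64)
dst_ptrs = torch.tensor([dst.data_ptr()], device="cuda", dtype=torch.int64)
sizes = torch.tensor([src.numel()], device="cuda", dtype=torch.int32)  # bytes
batch_memcpy(src_ptrs, dst_ptrs, sizes)
assert torch.equal(src, dst)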


def get_mamba_groups(kv_cache_config: KVCacheConfig) -> tuple[list[int], MambaSpec]:
    mamba_group_ids: list[int] = []
    mamba_specs: list[MambaSpec] = []
    for i in range(len(kv_cache_config.kv_cache_groups)):
        kv_cache_spec = kv_cache_config.kv_cache_groups[i].kv_cache_spec
        if isinstance(kv_cache_spec, MambaSpec):
            mamba_group_ids.append(i)
            mamba_specs.append(kv_cache_spec)
    assert len(mamba_group_ids) > 0, "no mamba layers in the model"
    assert all(mamba_specs[0] == spec for spec in mamba_specs)
    return mamba_group_ids, mamba_specs[0]


def collect_mamba_copy_meta(
    src_state_list: list[int],
    dest_state_list: list[int],
    num_elements_list: list[int],
    kv_cache_config: KVCacheConfig,
    mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
    mamba_group_ids: list[int],
    src_block_idx: int,
    dest_block_idx: int,
    accept_token_bias: int,
    req_state: CachedRequestState,
    forward_context: dict[str, Any],
):
    if src_block_idx == dest_block_idx and accept_token_bias == 0:
        return

    for mamba_group_id in mamba_group_ids:
        block_ids = req_state.block_ids[mamba_group_id]
        dest_block_id = block_ids[dest_block_idx]
        layer_names = kv_cache_config.kv_cache_groups[mamba_group_id].layer_names
        for layer_name in layer_names:
            attention = forward_context[layer_name]
            kv_caches: list[torch.Tensor] = attention.kv_cache[0]
            for state, state_copy_func in zip(kv_caches, mamba_state_copy_funcs):
                copy_spec = state_copy_func(
                    state, block_ids, src_block_idx, accept_token_bias + 1
                )

                src_state_list.append(copy_spec.start_addr)
                dest_state_list.append(state[dest_block_id].data_ptr())
                num_elements_list.append(copy_spec.num_elements * state.element_size())


def do_mamba_copy_block(
    src_state_list: list[int],
    dest_state_list: list[int],
    num_elements_list: list[int],
):
    if len(src_state_list) == 0:
        return
    assert len(src_state_list) == len(dest_state_list)
    assert len(src_state_list) == len(num_elements_list)
    src_state_ptrs = torch.tensor(src_state_list, device="cuda", dtype=torch.int64)
    dst_state_ptrs = torch.tensor(dest_state_list, device="cuda", dtype=torch.int64)
    num_elements = torch.tensor(num_elements_list, device="cuda", dtype=torch.int32)

    batch_memcpy(src_state_ptrs, dst_state_ptrs, num_elements)


def preprocess_mamba(
    scheduler_output: SchedulerOutput,
    kv_cache_config: KVCacheConfig,
    cache_config: CacheConfig,
    mamba_state_idx: dict[str, int],
    input_batch: GPUInputBatch,
    requests: dict[str, CachedRequestState],
    forward_context: dict[str, Any],
    mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
):
    """
    Copy the mamba state of the previous step to the last
    (1 + num_speculative_blocks) block.
    """
    mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config)
    num_speculative_blocks = mamba_spec.num_speculative_blocks
    # TODO(Chen): we need to optimize this function a lot
    assert cache_config.enable_prefix_caching
    block_size = mamba_spec.block_size
    finished_req_ids = scheduler_output.finished_req_ids
    preempted_req_ids = scheduler_output.preempted_req_ids or set()
    for req_id in itertools.chain(finished_req_ids, preempted_req_ids):
        mamba_state_idx.pop(req_id, None)

    src_state_list: list[int] = []
    dest_state_list: list[int] = []
    num_elements_list: list[int] = []
    for i, req_id in enumerate(input_batch.req_ids):
        req_state = requests[req_id]
        prev_state_idx = mamba_state_idx.get(req_id)
        if prev_state_idx is None:
            # New / resumed request, no previous state.
            # If num_computed_tokens is 0, prev_state_idx will be -1.
            prev_state_idx = (req_state.num_computed_tokens - 1) // block_size

        num_blocks = len(req_state.block_ids[mamba_group_ids[0]])

        # We always save the current running state at the last
        # (1 + num_speculative_blocks) block.
        # A corner case worth mentioning here: assume we have block_size = 4 and
        # num_speculative_tokens = 2. The request is [A, B, C] and contains 2 draft
        # tokens [draft 1, draft 2]. Then we will have:
        # Block 0: [A, B, C, draft 1]
        # Block 1: [draft 2, TOFILL, TOFILL, TOFILL]
        # Block 2: speculative block
        # Block 3: speculative block
        # And use block 1 to save the running state.
        curr_state_idx = num_blocks - 1 - num_speculative_blocks
        mamba_state_idx[req_id] = curr_state_idx
        if prev_state_idx != -1 and prev_state_idx != curr_state_idx:
            collect_mamba_copy_meta(
                src_state_list,
                dest_state_list,
                num_elements_list,
                kv_cache_config,
                mamba_state_copy_funcs,
                mamba_group_ids,
                prev_state_idx,
                curr_state_idx,
                input_batch.num_accepted_tokens_cpu[i] - 1,
                req_state,
                forward_context,
            )
            input_batch.num_accepted_tokens_cpu[i] = 1
    do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list)
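
A toy check of the index bookkeeping above (a standalone sketch with made-up numbers, not part of the file):

block_size, num_speculative_blocks = 4, 2
num_computed_tokens, num_blocks = 5, 4  # a resumed request holding 4 blocks
prev_state_idx = (num_computed_tokens - 1) // block_size  # 1
curr_state_idx = num_blocks - 1 - num_speculative_blocks  # 1
# prev == curr: no copy is queued. Had the block table grown by one block,
# curr_state_idx would advance to 2 and the state would be copied forward.
assert prev_state_idx == curr_state_idx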


def postprocess_mamba(
    scheduler_output: SchedulerOutput,
    kv_cache_config: KVCacheConfig,
    input_batch: GPUInputBatch,
    requests: dict[str, CachedRequestState],
    mamba_state_idx: dict[str, int],
    forward_context: dict[str, Any],
    mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
):
    """
    If a block is converted from a partial block to a full block in this step,
    copy the state from the running-state block to the new full block.
    """
    num_scheduled_tokens_dict = scheduler_output.num_scheduled_tokens
    scheduled_spec_decode_tokens_dict = scheduler_output.scheduled_spec_decode_tokens
    num_accepted_tokens_cpu = input_batch.num_accepted_tokens_cpu
    # NOTE: can be optimized, as this function always returns the same result.
    mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config)
    src_state_list: list[int] = []
    dest_state_list: list[int] = []
    num_elements_list: list[int] = []
    for i, req_id in enumerate(input_batch.req_ids):
        req_state = requests[req_id]
        num_computed_tokens = req_state.num_computed_tokens
        num_draft_tokens = len(scheduled_spec_decode_tokens_dict.get(req_id, []))
        num_scheduled_tokens = num_scheduled_tokens_dict[req_id]
        num_accepted_tokens = num_accepted_tokens_cpu[i]
        num_tokens_running_state = (
            num_computed_tokens + num_scheduled_tokens - num_draft_tokens
        )
        new_num_computed_tokens = num_tokens_running_state + num_accepted_tokens - 1
        aligned_new_computed_tokens = (
            new_num_computed_tokens // mamba_spec.block_size * mamba_spec.block_size
        )
        # TODO: how do we ensure that all blocks for which cache_blocks was
        # called are cached here?
        if aligned_new_computed_tokens >= num_tokens_running_state:
            accept_token_bias = aligned_new_computed_tokens - num_tokens_running_state
            src_block_idx = mamba_state_idx[req_id]
            dest_block_idx = aligned_new_computed_tokens // mamba_spec.block_size - 1
            collect_mamba_copy_meta(
                src_state_list,
                dest_state_list,
                num_elements_list,
                kv_cache_config,
                mamba_state_copy_funcs,
                mamba_group_ids,
                src_block_idx,
                dest_block_idx,
                accept_token_bias,
                req_state,
                forward_context,
            )
            if src_block_idx == dest_block_idx:
                num_accepted_tokens_cpu[i] = 1
    do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list)
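
To close, a worked example of the alignment arithmetic above (standalone, with made-up numbers, not part of the file):

block_size = 4
num_computed_tokens, num_scheduled_tokens = 6, 3
num_draft_tokens, num_accepted_tokens = 2, 2
num_tokens_running_state = (
    num_computed_tokens + num_scheduled_tokens - num_draft_tokens
)  # 7
new_num_computed_tokens = num_tokens_running_state + num_accepted_tokens - 1  # 8
aligned = new_num_computed_tokens // block_size * block_size  # 8
# 8 >= 7, so block 8 // 4 - 1 = 1 just became full; its state is rebuilt from
# the running state with accept_token_bias = 8 - 7 = 1 extra accepted token.
assert aligned >= num_tokens_running_state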