diff --git a/tests/v1/e2e/test_mamba_prefix_cache.py b/tests/v1/e2e/test_mamba_prefix_cache.py index 38cfdcdb3..5aa72ccb3 100644 --- a/tests/v1/e2e/test_mamba_prefix_cache.py +++ b/tests/v1/e2e/test_mamba_prefix_cache.py @@ -325,6 +325,7 @@ def get_fake_process_mamba_fn( requests: dict[str, CachedRequestState], forward_context: dict[str, Any], mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...], + copy_bufs: mamba_utils.MambaCopyBuffers, ): nonlocal copy_info copy_info = None @@ -337,6 +338,7 @@ def get_fake_process_mamba_fn( requests, forward_context, mamba_state_copy_funcs, + copy_bufs, ) if cur_step_action is not None: check_copy_info( @@ -355,6 +357,7 @@ def get_fake_process_mamba_fn( mamba_state_idx: dict[str, int], forward_context: dict[str, Any], mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...], + copy_bufs: mamba_utils.MambaCopyBuffers, ): nonlocal copy_info copy_info = None @@ -366,6 +369,7 @@ def get_fake_process_mamba_fn( mamba_state_idx, forward_context, mamba_state_copy_funcs, + copy_bufs, ) if cur_step_action is not None: check_copy_info( @@ -376,19 +380,15 @@ def get_fake_process_mamba_fn( ) return ret - def fake_copy_fn( - src_state_list: list[int], - dest_state_list: list[int], - num_elements_list: list[int], - ): + def fake_copy_fn(copy_bufs: mamba_utils.MambaCopyBuffers): nonlocal copy_info assert copy_info is None + n = copy_bufs.offset + src_state_list = copy_bufs.src_ptrs.cpu[:n].tolist() + dest_state_list = copy_bufs.dst_ptrs.cpu[:n].tolist() + num_elements_list = copy_bufs.sizes.cpu[:n].tolist() copy_info = (src_state_list, dest_state_list, num_elements_list) - return original_copy_fn( - src_state_list, - dest_state_list, - num_elements_list, - ) + return original_copy_fn(copy_bufs) return fake_preprocess_mamba_fn, fake_post_process_mamba_fn, fake_copy_fn diff --git a/tests/v1/worker/test_mamba_utils.py b/tests/v1/worker/test_mamba_utils.py index 38eb250fb..df3b7de9b 100644 --- a/tests/v1/worker/test_mamba_utils.py +++ b/tests/v1/worker/test_mamba_utils.py @@ -62,6 +62,7 @@ def test_resumed_req_ids_cleared_from_mamba_state_idx(): {}, {}, (), + MagicMock(), ) assert mamba_state_idx == {"keep": 99} diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a3e0adfae..768a7ee4b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -755,6 +755,7 @@ class GPUModelRunner( self.execute_model_state: ExecuteModelState | None = None self.kv_connector_output: KVConnectorOutput | None = None self.mamba_state_idx: dict[str, int] = {} + self._mamba_copy_bufs: mamba_utils.MambaCopyBuffers | None = None self.layerwise_nvtx_hooks_registered = False def update_max_model_len(self, max_model_len: int) -> None: @@ -849,6 +850,16 @@ class GPUModelRunner( with_numpy=numpy, ) + def _get_mamba_copy_bufs(self) -> mamba_utils.MambaCopyBuffers: + if self._mamba_copy_bufs is None: + self._mamba_copy_bufs = mamba_utils.MambaCopyBuffers.create( + self.max_num_reqs, + self.kv_cache_config, + self.model.get_mamba_state_copy_func(), + self._make_buffer, + ) + return self._mamba_copy_bufs + def _init_model_kwargs(self): model_kwargs = dict[str, Any]() @@ -1211,6 +1222,7 @@ class GPUModelRunner( self.mamba_state_idx, self.compilation_config.static_forward_context, self.model.get_mamba_state_copy_func(), + self._get_mamba_copy_bufs(), ) def _update_streaming_request( @@ -3505,6 +3517,7 @@ class GPUModelRunner( self.requests, self.compilation_config.static_forward_context, self.model.get_mamba_state_copy_func(), + self._get_mamba_copy_bufs(), ) use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 @@ -5997,6 +6010,7 @@ class GPUModelRunner( """ kv_cache_config = deepcopy(kv_cache_config) self.kv_cache_config = kv_cache_config + self._mamba_copy_bufs = None self.may_add_encoder_only_layers_to_kv_cache_config() self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config) self.initialize_attn_backend(kv_cache_config) diff --git a/vllm/v1/worker/mamba_utils.py b/vllm/v1/worker/mamba_utils.py index 4f8a3bd05..2bd5d2b3f 100644 --- a/vllm/v1/worker/mamba_utils.py +++ b/vllm/v1/worker/mamba_utils.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import dataclasses import itertools +from collections.abc import Callable from typing import Any import torch @@ -13,6 +15,7 @@ from vllm.triton_utils import tl, triton from vllm.utils.math_utils import cdiv from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec +from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.gpu_input_batch import CachedRequestState from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch @@ -59,10 +62,36 @@ def get_mamba_groups(kv_cache_config: KVCacheConfig) -> tuple[list[int], MambaSp return mamba_group_ids, mamba_specs[0] +@dataclasses.dataclass +class MambaCopyBuffers: + src_ptrs: CpuGpuBuffer + dst_ptrs: CpuGpuBuffer + sizes: CpuGpuBuffer + offset: int = 0 + + @classmethod + def create( + cls, + max_num_reqs: int, + kv_cache_config: KVCacheConfig, + copy_funcs: tuple[MambaStateCopyFunc, ...], + make_buffer: Callable[..., CpuGpuBuffer], + ) -> "MambaCopyBuffers": + mamba_group_ids, _ = get_mamba_groups(kv_cache_config) + entries_per_req = sum( + len(kv_cache_config.kv_cache_groups[gid].layer_names) + for gid in mamba_group_ids + ) * len(copy_funcs) + n = max_num_reqs * entries_per_req + return cls( + src_ptrs=make_buffer(n, dtype=torch.int64), + dst_ptrs=make_buffer(n, dtype=torch.int64), + sizes=make_buffer(n, dtype=torch.int32), + ) + + def collect_mamba_copy_meta( - src_state_list: list[int], - dest_state_list: list[int], - num_elements_list: list[int], + copy_bufs: MambaCopyBuffers, kv_cache_config: KVCacheConfig, mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...], mamba_group_ids: list[int], @@ -71,10 +100,15 @@ def collect_mamba_copy_meta( accept_token_bias: int, req_state: CachedRequestState, forward_context: dict[str, Any], -): +) -> None: if src_block_idx == dest_block_idx and accept_token_bias == 0: return + src_ptrs_np = copy_bufs.src_ptrs.np + dst_ptrs_np = copy_bufs.dst_ptrs.np + sizes_np = copy_bufs.sizes.np + offset = copy_bufs.offset + for mamba_group_id in mamba_group_ids: block_ids = req_state.block_ids[mamba_group_id] dest_block_id = block_ids[dest_block_idx] @@ -87,25 +121,23 @@ def collect_mamba_copy_meta( state, block_ids, src_block_idx, accept_token_bias + 1 ) - src_state_list.append(copy_spec.start_addr) - dest_state_list.append(state[dest_block_id].data_ptr()) - num_elements_list.append(copy_spec.num_elements * state.element_size()) + src_ptrs_np[offset] = copy_spec.start_addr + dst_ptrs_np[offset] = state[dest_block_id].data_ptr() + sizes_np[offset] = copy_spec.num_elements * state.element_size() + offset += 1 + + copy_bufs.offset = offset -def do_mamba_copy_block( - src_state_list: list[int], - dest_state_list: list[int], - num_elements_list: list[int], -): - if len(src_state_list) == 0: +def do_mamba_copy_block(copy_bufs: MambaCopyBuffers): + n = copy_bufs.offset + if n == 0: return - assert len(src_state_list) == len(dest_state_list) - assert len(src_state_list) == len(num_elements_list) - src_state_ptrs = torch.tensor(src_state_list, device="cuda", dtype=torch.int64) - dst_state_ptrs = torch.tensor(dest_state_list, device="cuda", dtype=torch.int64) - num_elements = torch.tensor(num_elements_list, device="cuda", dtype=torch.int32) - - batch_memcpy(src_state_ptrs, dst_state_ptrs, num_elements) + batch_memcpy( + copy_bufs.src_ptrs.copy_to_gpu(n), + copy_bufs.dst_ptrs.copy_to_gpu(n), + copy_bufs.sizes.copy_to_gpu(n), + ) def preprocess_mamba( @@ -117,6 +149,7 @@ def preprocess_mamba( requests: dict[str, CachedRequestState], forward_context: dict[str, Any], mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...], + copy_bufs: MambaCopyBuffers, ): """ Copy the mamba state of previous step to the last @@ -138,9 +171,7 @@ def preprocess_mamba( for req_id in itertools.chain(finished_req_ids, preempted_req_ids, resumed_req_ids): mamba_state_idx.pop(req_id, None) - src_state_list: list[int] = [] - dest_state_list: list[int] = [] - num_elements_list: list[int] = [] + copy_bufs.offset = 0 for i, req_id in enumerate(input_batch.req_ids): req_state = requests[req_id] prev_state_idx = mamba_state_idx.get(req_id) @@ -169,9 +200,7 @@ def preprocess_mamba( mamba_state_idx[req_id] = curr_state_idx if prev_state_idx != -1 and prev_state_idx != curr_state_idx: collect_mamba_copy_meta( - src_state_list, - dest_state_list, - num_elements_list, + copy_bufs, kv_cache_config, mamba_state_copy_funcs, mamba_group_ids, @@ -182,7 +211,7 @@ def preprocess_mamba( forward_context, ) input_batch.num_accepted_tokens_cpu[i] = 1 - do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list) + do_mamba_copy_block(copy_bufs) def postprocess_mamba( @@ -193,6 +222,7 @@ def postprocess_mamba( mamba_state_idx: dict[str, int], forward_context: dict[str, Any], mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...], + copy_bufs: MambaCopyBuffers, ): """ If a blocks is converted from partial block to full block in this step, copy the @@ -203,9 +233,7 @@ def postprocess_mamba( num_accepted_tokens_cpu = input_batch.num_accepted_tokens_cpu # NOTE: can be optimized as this function always returns the same result mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config) - src_state_list: list[int] = [] - dest_state_list: list[int] = [] - num_elements_list: list[int] = [] + copy_bufs.offset = 0 for i, req_id in enumerate(input_batch.req_ids): req_state = requests[req_id] num_computed_tokens = req_state.num_computed_tokens @@ -225,9 +253,7 @@ def postprocess_mamba( src_block_idx = mamba_state_idx[req_id] dest_block_idx = aligned_new_computed_tokens // mamba_spec.block_size - 1 collect_mamba_copy_meta( - src_state_list, - dest_state_list, - num_elements_list, + copy_bufs, kv_cache_config, mamba_state_copy_funcs, mamba_group_ids, @@ -239,4 +265,4 @@ def postprocess_mamba( ) if src_block_idx == dest_block_idx: num_accepted_tokens_cpu[i] = 1 - do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list) + do_mamba_copy_block(copy_bufs)