[perf] Use pinned memory for async H2D transfer in do_mamba_copy_block (#35480)

Signed-off-by: Huamin Li <3ericli@gmail.com>
This commit is contained in:
Huamin Li
2026-02-27 09:50:37 -08:00
committed by GitHub
parent 1d897ff04f
commit 157722da75
4 changed files with 85 additions and 44 deletions

View File

@@ -325,6 +325,7 @@ def get_fake_process_mamba_fn(
requests: dict[str, CachedRequestState],
forward_context: dict[str, Any],
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
copy_bufs: mamba_utils.MambaCopyBuffers,
):
nonlocal copy_info
copy_info = None
@@ -337,6 +338,7 @@ def get_fake_process_mamba_fn(
requests,
forward_context,
mamba_state_copy_funcs,
copy_bufs,
)
if cur_step_action is not None:
check_copy_info(
@@ -355,6 +357,7 @@ def get_fake_process_mamba_fn(
mamba_state_idx: dict[str, int],
forward_context: dict[str, Any],
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
copy_bufs: mamba_utils.MambaCopyBuffers,
):
nonlocal copy_info
copy_info = None
@@ -366,6 +369,7 @@ def get_fake_process_mamba_fn(
mamba_state_idx,
forward_context,
mamba_state_copy_funcs,
copy_bufs,
)
if cur_step_action is not None:
check_copy_info(
@@ -376,19 +380,15 @@ def get_fake_process_mamba_fn(
)
return ret
def fake_copy_fn(
src_state_list: list[int],
dest_state_list: list[int],
num_elements_list: list[int],
):
def fake_copy_fn(copy_bufs: mamba_utils.MambaCopyBuffers):
nonlocal copy_info
assert copy_info is None
n = copy_bufs.offset
src_state_list = copy_bufs.src_ptrs.cpu[:n].tolist()
dest_state_list = copy_bufs.dst_ptrs.cpu[:n].tolist()
num_elements_list = copy_bufs.sizes.cpu[:n].tolist()
copy_info = (src_state_list, dest_state_list, num_elements_list)
return original_copy_fn(
src_state_list,
dest_state_list,
num_elements_list,
)
return original_copy_fn(copy_bufs)
return fake_preprocess_mamba_fn, fake_post_process_mamba_fn, fake_copy_fn

View File

@@ -62,6 +62,7 @@ def test_resumed_req_ids_cleared_from_mamba_state_idx():
{},
{},
(),
MagicMock(),
)
assert mamba_state_idx == {"keep": 99}

View File

@@ -755,6 +755,7 @@ class GPUModelRunner(
self.execute_model_state: ExecuteModelState | None = None
self.kv_connector_output: KVConnectorOutput | None = None
self.mamba_state_idx: dict[str, int] = {}
self._mamba_copy_bufs: mamba_utils.MambaCopyBuffers | None = None
self.layerwise_nvtx_hooks_registered = False
def update_max_model_len(self, max_model_len: int) -> None:
@@ -849,6 +850,16 @@ class GPUModelRunner(
with_numpy=numpy,
)
def _get_mamba_copy_bufs(self) -> mamba_utils.MambaCopyBuffers:
if self._mamba_copy_bufs is None:
self._mamba_copy_bufs = mamba_utils.MambaCopyBuffers.create(
self.max_num_reqs,
self.kv_cache_config,
self.model.get_mamba_state_copy_func(),
self._make_buffer,
)
return self._mamba_copy_bufs
def _init_model_kwargs(self):
model_kwargs = dict[str, Any]()
@@ -1211,6 +1222,7 @@ class GPUModelRunner(
self.mamba_state_idx,
self.compilation_config.static_forward_context,
self.model.get_mamba_state_copy_func(),
self._get_mamba_copy_bufs(),
)
def _update_streaming_request(
@@ -3505,6 +3517,7 @@ class GPUModelRunner(
self.requests,
self.compilation_config.static_forward_context,
self.model.get_mamba_state_copy_func(),
self._get_mamba_copy_bufs(),
)
use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
@@ -5997,6 +6010,7 @@ class GPUModelRunner(
"""
kv_cache_config = deepcopy(kv_cache_config)
self.kv_cache_config = kv_cache_config
self._mamba_copy_bufs = None
self.may_add_encoder_only_layers_to_kv_cache_config()
self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
self.initialize_attn_backend(kv_cache_config)

View File

@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import itertools
from collections.abc import Callable
from typing import Any
import torch
@@ -13,6 +15,7 @@ from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_input_batch import CachedRequestState
from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch
@@ -59,10 +62,36 @@ def get_mamba_groups(kv_cache_config: KVCacheConfig) -> tuple[list[int], MambaSp
return mamba_group_ids, mamba_specs[0]
@dataclasses.dataclass
class MambaCopyBuffers:
src_ptrs: CpuGpuBuffer
dst_ptrs: CpuGpuBuffer
sizes: CpuGpuBuffer
offset: int = 0
@classmethod
def create(
cls,
max_num_reqs: int,
kv_cache_config: KVCacheConfig,
copy_funcs: tuple[MambaStateCopyFunc, ...],
make_buffer: Callable[..., CpuGpuBuffer],
) -> "MambaCopyBuffers":
mamba_group_ids, _ = get_mamba_groups(kv_cache_config)
entries_per_req = sum(
len(kv_cache_config.kv_cache_groups[gid].layer_names)
for gid in mamba_group_ids
) * len(copy_funcs)
n = max_num_reqs * entries_per_req
return cls(
src_ptrs=make_buffer(n, dtype=torch.int64),
dst_ptrs=make_buffer(n, dtype=torch.int64),
sizes=make_buffer(n, dtype=torch.int32),
)
def collect_mamba_copy_meta(
src_state_list: list[int],
dest_state_list: list[int],
num_elements_list: list[int],
copy_bufs: MambaCopyBuffers,
kv_cache_config: KVCacheConfig,
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
mamba_group_ids: list[int],
@@ -71,10 +100,15 @@ def collect_mamba_copy_meta(
accept_token_bias: int,
req_state: CachedRequestState,
forward_context: dict[str, Any],
):
) -> None:
if src_block_idx == dest_block_idx and accept_token_bias == 0:
return
src_ptrs_np = copy_bufs.src_ptrs.np
dst_ptrs_np = copy_bufs.dst_ptrs.np
sizes_np = copy_bufs.sizes.np
offset = copy_bufs.offset
for mamba_group_id in mamba_group_ids:
block_ids = req_state.block_ids[mamba_group_id]
dest_block_id = block_ids[dest_block_idx]
@@ -87,25 +121,23 @@ def collect_mamba_copy_meta(
state, block_ids, src_block_idx, accept_token_bias + 1
)
src_state_list.append(copy_spec.start_addr)
dest_state_list.append(state[dest_block_id].data_ptr())
num_elements_list.append(copy_spec.num_elements * state.element_size())
src_ptrs_np[offset] = copy_spec.start_addr
dst_ptrs_np[offset] = state[dest_block_id].data_ptr()
sizes_np[offset] = copy_spec.num_elements * state.element_size()
offset += 1
copy_bufs.offset = offset
def do_mamba_copy_block(
src_state_list: list[int],
dest_state_list: list[int],
num_elements_list: list[int],
):
if len(src_state_list) == 0:
def do_mamba_copy_block(copy_bufs: MambaCopyBuffers):
n = copy_bufs.offset
if n == 0:
return
assert len(src_state_list) == len(dest_state_list)
assert len(src_state_list) == len(num_elements_list)
src_state_ptrs = torch.tensor(src_state_list, device="cuda", dtype=torch.int64)
dst_state_ptrs = torch.tensor(dest_state_list, device="cuda", dtype=torch.int64)
num_elements = torch.tensor(num_elements_list, device="cuda", dtype=torch.int32)
batch_memcpy(src_state_ptrs, dst_state_ptrs, num_elements)
batch_memcpy(
copy_bufs.src_ptrs.copy_to_gpu(n),
copy_bufs.dst_ptrs.copy_to_gpu(n),
copy_bufs.sizes.copy_to_gpu(n),
)
def preprocess_mamba(
@@ -117,6 +149,7 @@ def preprocess_mamba(
requests: dict[str, CachedRequestState],
forward_context: dict[str, Any],
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
copy_bufs: MambaCopyBuffers,
):
"""
Copy the mamba state of previous step to the last
@@ -138,9 +171,7 @@ def preprocess_mamba(
for req_id in itertools.chain(finished_req_ids, preempted_req_ids, resumed_req_ids):
mamba_state_idx.pop(req_id, None)
src_state_list: list[int] = []
dest_state_list: list[int] = []
num_elements_list: list[int] = []
copy_bufs.offset = 0
for i, req_id in enumerate(input_batch.req_ids):
req_state = requests[req_id]
prev_state_idx = mamba_state_idx.get(req_id)
@@ -169,9 +200,7 @@ def preprocess_mamba(
mamba_state_idx[req_id] = curr_state_idx
if prev_state_idx != -1 and prev_state_idx != curr_state_idx:
collect_mamba_copy_meta(
src_state_list,
dest_state_list,
num_elements_list,
copy_bufs,
kv_cache_config,
mamba_state_copy_funcs,
mamba_group_ids,
@@ -182,7 +211,7 @@ def preprocess_mamba(
forward_context,
)
input_batch.num_accepted_tokens_cpu[i] = 1
do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list)
do_mamba_copy_block(copy_bufs)
def postprocess_mamba(
@@ -193,6 +222,7 @@ def postprocess_mamba(
mamba_state_idx: dict[str, int],
forward_context: dict[str, Any],
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
copy_bufs: MambaCopyBuffers,
):
"""
If a blocks is converted from partial block to full block in this step, copy the
@@ -203,9 +233,7 @@ def postprocess_mamba(
num_accepted_tokens_cpu = input_batch.num_accepted_tokens_cpu
# NOTE: can be optimized as this function always returns the same result
mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config)
src_state_list: list[int] = []
dest_state_list: list[int] = []
num_elements_list: list[int] = []
copy_bufs.offset = 0
for i, req_id in enumerate(input_batch.req_ids):
req_state = requests[req_id]
num_computed_tokens = req_state.num_computed_tokens
@@ -225,9 +253,7 @@ def postprocess_mamba(
src_block_idx = mamba_state_idx[req_id]
dest_block_idx = aligned_new_computed_tokens // mamba_spec.block_size - 1
collect_mamba_copy_meta(
src_state_list,
dest_state_list,
num_elements_list,
copy_bufs,
kv_cache_config,
mamba_state_copy_funcs,
mamba_group_ids,
@@ -239,4 +265,4 @@ def postprocess_mamba(
)
if src_block_idx == dest_block_idx:
num_accepted_tokens_cpu[i] = 1
do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list)
do_mamba_copy_block(copy_bufs)