[perf] Use pinned memory for async H2D transfer in do_mamba_copy_block (#35480)
Signed-off-by: Huamin Li <3ericli@gmail.com>
This commit is contained in:
@@ -325,6 +325,7 @@ def get_fake_process_mamba_fn(
|
||||
requests: dict[str, CachedRequestState],
|
||||
forward_context: dict[str, Any],
|
||||
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
|
||||
copy_bufs: mamba_utils.MambaCopyBuffers,
|
||||
):
|
||||
nonlocal copy_info
|
||||
copy_info = None
|
||||
@@ -337,6 +338,7 @@ def get_fake_process_mamba_fn(
|
||||
requests,
|
||||
forward_context,
|
||||
mamba_state_copy_funcs,
|
||||
copy_bufs,
|
||||
)
|
||||
if cur_step_action is not None:
|
||||
check_copy_info(
|
||||
@@ -355,6 +357,7 @@ def get_fake_process_mamba_fn(
|
||||
mamba_state_idx: dict[str, int],
|
||||
forward_context: dict[str, Any],
|
||||
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
|
||||
copy_bufs: mamba_utils.MambaCopyBuffers,
|
||||
):
|
||||
nonlocal copy_info
|
||||
copy_info = None
|
||||
@@ -366,6 +369,7 @@ def get_fake_process_mamba_fn(
|
||||
mamba_state_idx,
|
||||
forward_context,
|
||||
mamba_state_copy_funcs,
|
||||
copy_bufs,
|
||||
)
|
||||
if cur_step_action is not None:
|
||||
check_copy_info(
|
||||
@@ -376,19 +380,15 @@ def get_fake_process_mamba_fn(
|
||||
)
|
||||
return ret
|
||||
|
||||
def fake_copy_fn(
|
||||
src_state_list: list[int],
|
||||
dest_state_list: list[int],
|
||||
num_elements_list: list[int],
|
||||
):
|
||||
def fake_copy_fn(copy_bufs: mamba_utils.MambaCopyBuffers):
|
||||
nonlocal copy_info
|
||||
assert copy_info is None
|
||||
n = copy_bufs.offset
|
||||
src_state_list = copy_bufs.src_ptrs.cpu[:n].tolist()
|
||||
dest_state_list = copy_bufs.dst_ptrs.cpu[:n].tolist()
|
||||
num_elements_list = copy_bufs.sizes.cpu[:n].tolist()
|
||||
copy_info = (src_state_list, dest_state_list, num_elements_list)
|
||||
return original_copy_fn(
|
||||
src_state_list,
|
||||
dest_state_list,
|
||||
num_elements_list,
|
||||
)
|
||||
return original_copy_fn(copy_bufs)
|
||||
|
||||
return fake_preprocess_mamba_fn, fake_post_process_mamba_fn, fake_copy_fn
|
||||
|
||||
|
||||
@@ -62,6 +62,7 @@ def test_resumed_req_ids_cleared_from_mamba_state_idx():
|
||||
{},
|
||||
{},
|
||||
(),
|
||||
MagicMock(),
|
||||
)
|
||||
|
||||
assert mamba_state_idx == {"keep": 99}
|
||||
|
||||
@@ -755,6 +755,7 @@ class GPUModelRunner(
|
||||
self.execute_model_state: ExecuteModelState | None = None
|
||||
self.kv_connector_output: KVConnectorOutput | None = None
|
||||
self.mamba_state_idx: dict[str, int] = {}
|
||||
self._mamba_copy_bufs: mamba_utils.MambaCopyBuffers | None = None
|
||||
self.layerwise_nvtx_hooks_registered = False
|
||||
|
||||
def update_max_model_len(self, max_model_len: int) -> None:
|
||||
@@ -849,6 +850,16 @@ class GPUModelRunner(
|
||||
with_numpy=numpy,
|
||||
)
|
||||
|
||||
def _get_mamba_copy_bufs(self) -> mamba_utils.MambaCopyBuffers:
|
||||
if self._mamba_copy_bufs is None:
|
||||
self._mamba_copy_bufs = mamba_utils.MambaCopyBuffers.create(
|
||||
self.max_num_reqs,
|
||||
self.kv_cache_config,
|
||||
self.model.get_mamba_state_copy_func(),
|
||||
self._make_buffer,
|
||||
)
|
||||
return self._mamba_copy_bufs
|
||||
|
||||
def _init_model_kwargs(self):
|
||||
model_kwargs = dict[str, Any]()
|
||||
|
||||
@@ -1211,6 +1222,7 @@ class GPUModelRunner(
|
||||
self.mamba_state_idx,
|
||||
self.compilation_config.static_forward_context,
|
||||
self.model.get_mamba_state_copy_func(),
|
||||
self._get_mamba_copy_bufs(),
|
||||
)
|
||||
|
||||
def _update_streaming_request(
|
||||
@@ -3505,6 +3517,7 @@ class GPUModelRunner(
|
||||
self.requests,
|
||||
self.compilation_config.static_forward_context,
|
||||
self.model.get_mamba_state_copy_func(),
|
||||
self._get_mamba_copy_bufs(),
|
||||
)
|
||||
|
||||
use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
|
||||
@@ -5997,6 +6010,7 @@ class GPUModelRunner(
|
||||
"""
|
||||
kv_cache_config = deepcopy(kv_cache_config)
|
||||
self.kv_cache_config = kv_cache_config
|
||||
self._mamba_copy_bufs = None
|
||||
self.may_add_encoder_only_layers_to_kv_cache_config()
|
||||
self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
|
||||
self.initialize_attn_backend(kv_cache_config)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import dataclasses
|
||||
import itertools
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
@@ -13,6 +15,7 @@ from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig, MambaSpec
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState
|
||||
from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch
|
||||
|
||||
@@ -59,10 +62,36 @@ def get_mamba_groups(kv_cache_config: KVCacheConfig) -> tuple[list[int], MambaSp
|
||||
return mamba_group_ids, mamba_specs[0]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MambaCopyBuffers:
|
||||
src_ptrs: CpuGpuBuffer
|
||||
dst_ptrs: CpuGpuBuffer
|
||||
sizes: CpuGpuBuffer
|
||||
offset: int = 0
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
max_num_reqs: int,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
copy_funcs: tuple[MambaStateCopyFunc, ...],
|
||||
make_buffer: Callable[..., CpuGpuBuffer],
|
||||
) -> "MambaCopyBuffers":
|
||||
mamba_group_ids, _ = get_mamba_groups(kv_cache_config)
|
||||
entries_per_req = sum(
|
||||
len(kv_cache_config.kv_cache_groups[gid].layer_names)
|
||||
for gid in mamba_group_ids
|
||||
) * len(copy_funcs)
|
||||
n = max_num_reqs * entries_per_req
|
||||
return cls(
|
||||
src_ptrs=make_buffer(n, dtype=torch.int64),
|
||||
dst_ptrs=make_buffer(n, dtype=torch.int64),
|
||||
sizes=make_buffer(n, dtype=torch.int32),
|
||||
)
|
||||
|
||||
|
||||
def collect_mamba_copy_meta(
|
||||
src_state_list: list[int],
|
||||
dest_state_list: list[int],
|
||||
num_elements_list: list[int],
|
||||
copy_bufs: MambaCopyBuffers,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
|
||||
mamba_group_ids: list[int],
|
||||
@@ -71,10 +100,15 @@ def collect_mamba_copy_meta(
|
||||
accept_token_bias: int,
|
||||
req_state: CachedRequestState,
|
||||
forward_context: dict[str, Any],
|
||||
):
|
||||
) -> None:
|
||||
if src_block_idx == dest_block_idx and accept_token_bias == 0:
|
||||
return
|
||||
|
||||
src_ptrs_np = copy_bufs.src_ptrs.np
|
||||
dst_ptrs_np = copy_bufs.dst_ptrs.np
|
||||
sizes_np = copy_bufs.sizes.np
|
||||
offset = copy_bufs.offset
|
||||
|
||||
for mamba_group_id in mamba_group_ids:
|
||||
block_ids = req_state.block_ids[mamba_group_id]
|
||||
dest_block_id = block_ids[dest_block_idx]
|
||||
@@ -87,25 +121,23 @@ def collect_mamba_copy_meta(
|
||||
state, block_ids, src_block_idx, accept_token_bias + 1
|
||||
)
|
||||
|
||||
src_state_list.append(copy_spec.start_addr)
|
||||
dest_state_list.append(state[dest_block_id].data_ptr())
|
||||
num_elements_list.append(copy_spec.num_elements * state.element_size())
|
||||
src_ptrs_np[offset] = copy_spec.start_addr
|
||||
dst_ptrs_np[offset] = state[dest_block_id].data_ptr()
|
||||
sizes_np[offset] = copy_spec.num_elements * state.element_size()
|
||||
offset += 1
|
||||
|
||||
copy_bufs.offset = offset
|
||||
|
||||
|
||||
def do_mamba_copy_block(
|
||||
src_state_list: list[int],
|
||||
dest_state_list: list[int],
|
||||
num_elements_list: list[int],
|
||||
):
|
||||
if len(src_state_list) == 0:
|
||||
def do_mamba_copy_block(copy_bufs: MambaCopyBuffers):
|
||||
n = copy_bufs.offset
|
||||
if n == 0:
|
||||
return
|
||||
assert len(src_state_list) == len(dest_state_list)
|
||||
assert len(src_state_list) == len(num_elements_list)
|
||||
src_state_ptrs = torch.tensor(src_state_list, device="cuda", dtype=torch.int64)
|
||||
dst_state_ptrs = torch.tensor(dest_state_list, device="cuda", dtype=torch.int64)
|
||||
num_elements = torch.tensor(num_elements_list, device="cuda", dtype=torch.int32)
|
||||
|
||||
batch_memcpy(src_state_ptrs, dst_state_ptrs, num_elements)
|
||||
batch_memcpy(
|
||||
copy_bufs.src_ptrs.copy_to_gpu(n),
|
||||
copy_bufs.dst_ptrs.copy_to_gpu(n),
|
||||
copy_bufs.sizes.copy_to_gpu(n),
|
||||
)
|
||||
|
||||
|
||||
def preprocess_mamba(
|
||||
@@ -117,6 +149,7 @@ def preprocess_mamba(
|
||||
requests: dict[str, CachedRequestState],
|
||||
forward_context: dict[str, Any],
|
||||
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
|
||||
copy_bufs: MambaCopyBuffers,
|
||||
):
|
||||
"""
|
||||
Copy the mamba state of previous step to the last
|
||||
@@ -138,9 +171,7 @@ def preprocess_mamba(
|
||||
for req_id in itertools.chain(finished_req_ids, preempted_req_ids, resumed_req_ids):
|
||||
mamba_state_idx.pop(req_id, None)
|
||||
|
||||
src_state_list: list[int] = []
|
||||
dest_state_list: list[int] = []
|
||||
num_elements_list: list[int] = []
|
||||
copy_bufs.offset = 0
|
||||
for i, req_id in enumerate(input_batch.req_ids):
|
||||
req_state = requests[req_id]
|
||||
prev_state_idx = mamba_state_idx.get(req_id)
|
||||
@@ -169,9 +200,7 @@ def preprocess_mamba(
|
||||
mamba_state_idx[req_id] = curr_state_idx
|
||||
if prev_state_idx != -1 and prev_state_idx != curr_state_idx:
|
||||
collect_mamba_copy_meta(
|
||||
src_state_list,
|
||||
dest_state_list,
|
||||
num_elements_list,
|
||||
copy_bufs,
|
||||
kv_cache_config,
|
||||
mamba_state_copy_funcs,
|
||||
mamba_group_ids,
|
||||
@@ -182,7 +211,7 @@ def preprocess_mamba(
|
||||
forward_context,
|
||||
)
|
||||
input_batch.num_accepted_tokens_cpu[i] = 1
|
||||
do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list)
|
||||
do_mamba_copy_block(copy_bufs)
|
||||
|
||||
|
||||
def postprocess_mamba(
|
||||
@@ -193,6 +222,7 @@ def postprocess_mamba(
|
||||
mamba_state_idx: dict[str, int],
|
||||
forward_context: dict[str, Any],
|
||||
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
|
||||
copy_bufs: MambaCopyBuffers,
|
||||
):
|
||||
"""
|
||||
If a blocks is converted from partial block to full block in this step, copy the
|
||||
@@ -203,9 +233,7 @@ def postprocess_mamba(
|
||||
num_accepted_tokens_cpu = input_batch.num_accepted_tokens_cpu
|
||||
# NOTE: can be optimized as this function always returns the same result
|
||||
mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config)
|
||||
src_state_list: list[int] = []
|
||||
dest_state_list: list[int] = []
|
||||
num_elements_list: list[int] = []
|
||||
copy_bufs.offset = 0
|
||||
for i, req_id in enumerate(input_batch.req_ids):
|
||||
req_state = requests[req_id]
|
||||
num_computed_tokens = req_state.num_computed_tokens
|
||||
@@ -225,9 +253,7 @@ def postprocess_mamba(
|
||||
src_block_idx = mamba_state_idx[req_id]
|
||||
dest_block_idx = aligned_new_computed_tokens // mamba_spec.block_size - 1
|
||||
collect_mamba_copy_meta(
|
||||
src_state_list,
|
||||
dest_state_list,
|
||||
num_elements_list,
|
||||
copy_bufs,
|
||||
kv_cache_config,
|
||||
mamba_state_copy_funcs,
|
||||
mamba_group_ids,
|
||||
@@ -239,4 +265,4 @@ def postprocess_mamba(
|
||||
)
|
||||
if src_block_idx == dest_block_idx:
|
||||
num_accepted_tokens_cpu[i] = 1
|
||||
do_mamba_copy_block(src_state_list, dest_state_list, num_elements_list)
|
||||
do_mamba_copy_block(copy_bufs)
|
||||
|
||||
Reference in New Issue
Block a user