[perf] Use pinned memory for async H2D transfer in do_mamba_copy_block (#35480)

Signed-off-by: Huamin Li <3ericli@gmail.com>
This commit is contained in:
Huamin Li
2026-02-27 09:50:37 -08:00
committed by GitHub
parent 1d897ff04f
commit 157722da75
4 changed files with 85 additions and 44 deletions

View File

@@ -325,6 +325,7 @@ def get_fake_process_mamba_fn(
requests: dict[str, CachedRequestState],
forward_context: dict[str, Any],
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
copy_bufs: mamba_utils.MambaCopyBuffers,
):
nonlocal copy_info
copy_info = None
@@ -337,6 +338,7 @@ def get_fake_process_mamba_fn(
requests,
forward_context,
mamba_state_copy_funcs,
copy_bufs,
)
if cur_step_action is not None:
check_copy_info(
@@ -355,6 +357,7 @@ def get_fake_process_mamba_fn(
mamba_state_idx: dict[str, int],
forward_context: dict[str, Any],
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
copy_bufs: mamba_utils.MambaCopyBuffers,
):
nonlocal copy_info
copy_info = None
@@ -366,6 +369,7 @@ def get_fake_process_mamba_fn(
mamba_state_idx,
forward_context,
mamba_state_copy_funcs,
copy_bufs,
)
if cur_step_action is not None:
check_copy_info(
@@ -376,19 +380,15 @@ def get_fake_process_mamba_fn(
)
return ret
def fake_copy_fn(
src_state_list: list[int],
dest_state_list: list[int],
num_elements_list: list[int],
):
def fake_copy_fn(copy_bufs: mamba_utils.MambaCopyBuffers):
nonlocal copy_info
assert copy_info is None
n = copy_bufs.offset
src_state_list = copy_bufs.src_ptrs.cpu[:n].tolist()
dest_state_list = copy_bufs.dst_ptrs.cpu[:n].tolist()
num_elements_list = copy_bufs.sizes.cpu[:n].tolist()
copy_info = (src_state_list, dest_state_list, num_elements_list)
return original_copy_fn(
src_state_list,
dest_state_list,
num_elements_list,
)
return original_copy_fn(copy_bufs)
return fake_preprocess_mamba_fn, fake_post_process_mamba_fn, fake_copy_fn