[perf] Use pinned memory for async H2D transfer in do_mamba_copy_block (#35480)
Signed-off-by: Huamin Li <3ericli@gmail.com>
This commit is contained in:
@@ -325,6 +325,7 @@ def get_fake_process_mamba_fn(
|
||||
requests: dict[str, CachedRequestState],
|
||||
forward_context: dict[str, Any],
|
||||
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
|
||||
copy_bufs: mamba_utils.MambaCopyBuffers,
|
||||
):
|
||||
nonlocal copy_info
|
||||
copy_info = None
|
||||
@@ -337,6 +338,7 @@ def get_fake_process_mamba_fn(
|
||||
requests,
|
||||
forward_context,
|
||||
mamba_state_copy_funcs,
|
||||
copy_bufs,
|
||||
)
|
||||
if cur_step_action is not None:
|
||||
check_copy_info(
|
||||
@@ -355,6 +357,7 @@ def get_fake_process_mamba_fn(
|
||||
mamba_state_idx: dict[str, int],
|
||||
forward_context: dict[str, Any],
|
||||
mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
|
||||
copy_bufs: mamba_utils.MambaCopyBuffers,
|
||||
):
|
||||
nonlocal copy_info
|
||||
copy_info = None
|
||||
@@ -366,6 +369,7 @@ def get_fake_process_mamba_fn(
|
||||
mamba_state_idx,
|
||||
forward_context,
|
||||
mamba_state_copy_funcs,
|
||||
copy_bufs,
|
||||
)
|
||||
if cur_step_action is not None:
|
||||
check_copy_info(
|
||||
@@ -376,19 +380,15 @@ def get_fake_process_mamba_fn(
|
||||
)
|
||||
return ret
|
||||
|
||||
def fake_copy_fn(
|
||||
src_state_list: list[int],
|
||||
dest_state_list: list[int],
|
||||
num_elements_list: list[int],
|
||||
):
|
||||
def fake_copy_fn(copy_bufs: mamba_utils.MambaCopyBuffers):
|
||||
nonlocal copy_info
|
||||
assert copy_info is None
|
||||
n = copy_bufs.offset
|
||||
src_state_list = copy_bufs.src_ptrs.cpu[:n].tolist()
|
||||
dest_state_list = copy_bufs.dst_ptrs.cpu[:n].tolist()
|
||||
num_elements_list = copy_bufs.sizes.cpu[:n].tolist()
|
||||
copy_info = (src_state_list, dest_state_list, num_elements_list)
|
||||
return original_copy_fn(
|
||||
src_state_list,
|
||||
dest_state_list,
|
||||
num_elements_list,
|
||||
)
|
||||
return original_copy_fn(copy_bufs)
|
||||
|
||||
return fake_preprocess_mamba_fn, fake_post_process_mamba_fn, fake_copy_fn
|
||||
|
||||
|
||||
Reference in New Issue
Block a user