[Hybrid] calling get_mamba_groups() once at MambaCopyBuffers.create() (#37318)
Signed-off-by: Francesco Fusco <ffu@zurich.ibm.com>
This commit is contained in:
@@ -36,6 +36,7 @@ def test_resumed_req_ids_cleared_from_mamba_state_idx():
|
||||
spec = MagicMock(block_size=64, num_speculative_blocks=0)
|
||||
cache_config = MagicMock(enable_prefix_caching=True)
|
||||
input_batch = MagicMock(req_ids=[])
|
||||
+ copy_bufs = MagicMock(mamba_group_ids=[0], mamba_spec=spec)
|
||||
|
||||
mamba_state_idx = {
|
||||
"finished": 1,
|
||||
@@ -62,7 +63,7 @@ def test_resumed_req_ids_cleared_from_mamba_state_idx():
|
||||
{},
|
||||
{},
|
||||
(),
|
||||
- MagicMock(),
|
||||
+ copy_bufs,
|
||||
)
|
||||
|
||||
assert mamba_state_idx == {"keep": 99}
|
||||
|
||||
@@ -67,6 +67,8 @@ class MambaCopyBuffers:
|
||||
src_ptrs: CpuGpuBuffer
|
||||
dst_ptrs: CpuGpuBuffer
|
||||
sizes: CpuGpuBuffer
|
||||
+ mamba_group_ids: list[int]
|
||||
+ mamba_spec: MambaSpec
|
||||
offset: int = 0
|
||||
|
||||
@classmethod
|
||||
@@ -77,7 +79,7 @@ class MambaCopyBuffers:
|
||||
copy_funcs: tuple[MambaStateCopyFunc, ...],
|
||||
make_buffer: Callable[..., CpuGpuBuffer],
|
||||
) -> "MambaCopyBuffers":
|
||||
- mamba_group_ids, _ = get_mamba_groups(kv_cache_config)
|
||||
+ mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config)
|
||||
entries_per_req = sum(
|
||||
len(kv_cache_config.kv_cache_groups[gid].layer_names)
|
||||
for gid in mamba_group_ids
|
||||
@@ -87,6 +89,8 @@ class MambaCopyBuffers:
|
||||
src_ptrs=make_buffer(n, dtype=torch.int64),
|
||||
dst_ptrs=make_buffer(n, dtype=torch.int64),
|
||||
sizes=make_buffer(n, dtype=torch.int32),
|
||||
+ mamba_group_ids=mamba_group_ids,
|
||||
+ mamba_spec=mamba_spec,
|
||||
)
|
||||
|
||||
|
||||
@@ -155,7 +159,8 @@ def preprocess_mamba(
|
||||
Copy the mamba state of previous step to the last
|
||||
(1 + num_speculative_blocks) block.
|
||||
"""
|
||||
- mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config)
|
||||
+ mamba_group_ids = copy_bufs.mamba_group_ids
|
||||
+ mamba_spec = copy_bufs.mamba_spec
|
||||
num_speculative_blocks = mamba_spec.num_speculative_blocks
|
||||
# TODO(Chen): we need to optimize this function a lot
|
||||
assert cache_config.enable_prefix_caching
|
||||
@@ -231,8 +236,8 @@ def postprocess_mamba(
|
||||
num_scheduled_tokens_dict = scheduler_output.num_scheduled_tokens
|
||||
scheduled_spec_decode_tokens_dict = scheduler_output.scheduled_spec_decode_tokens
|
||||
num_accepted_tokens_cpu = input_batch.num_accepted_tokens_cpu
|
||||
- # NOTE: can be optimized as this function always returns the same result
|
||||
- mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config)
|
||||
+ mamba_group_ids = copy_bufs.mamba_group_ids
|
||||
+ mamba_spec = copy_bufs.mamba_spec
|
||||
copy_bufs.offset = 0
|
||||
for i, req_id in enumerate(input_batch.req_ids):
|
||||
req_state = requests[req_id]
|
||||
|
||||
Reference in New Issue
Block a user