[Hybrid] calling get_mamba_groups() once at MambaCopyBuffers.create() (#37318)
Signed-off-by: Francesco Fusco <ffu@zurich.ibm.com>
This commit is contained in:
@@ -36,6 +36,7 @@ def test_resumed_req_ids_cleared_from_mamba_state_idx():
|
||||
spec = MagicMock(block_size=64, num_speculative_blocks=0)
|
||||
cache_config = MagicMock(enable_prefix_caching=True)
|
||||
input_batch = MagicMock(req_ids=[])
|
||||
copy_bufs = MagicMock(mamba_group_ids=[0], mamba_spec=spec)
|
||||
|
||||
mamba_state_idx = {
|
||||
"finished": 1,
|
||||
@@ -62,7 +63,7 @@ def test_resumed_req_ids_cleared_from_mamba_state_idx():
|
||||
{},
|
||||
{},
|
||||
(),
|
||||
MagicMock(),
|
||||
copy_bufs,
|
||||
)
|
||||
|
||||
assert mamba_state_idx == {"keep": 99}
|
||||
|
||||
Reference in New Issue
Block a user