[Hybrid] Enable mamba prefix cache "align" mode with async scheduling (#33997)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
@@ -76,14 +76,11 @@ def get_fake_sample_fn() -> SamplerOutput:
|
||||
),
|
||||
logprobs_tensors=None,
|
||||
)
|
||||
num_sampled_tokens = spec_decode_metadata.cu_num_sampled_tokens[0].item() + 1
|
||||
accpeted_tokens = prompt_token_ids[
|
||||
first_token_id_index : first_token_id_index
|
||||
+ min(num_accepted_tokens, logits.shape[0])
|
||||
]
|
||||
sampled_token_ids = accpeted_tokens + [-1] * (
|
||||
num_sampled_tokens - len(accpeted_tokens)
|
||||
)
|
||||
sampled_token_ids = accpeted_tokens
|
||||
return SamplerOutput(
|
||||
sampled_token_ids=torch.tensor(
|
||||
[sampled_token_ids], device="cuda", dtype=torch.int32
|
||||
@@ -124,7 +121,24 @@ def get_fake_propose_draft_token_ids_fn():
|
||||
first_token_id_index : first_token_id_index + num_speculative_tokens
|
||||
]
|
||||
]
|
||||
return proposed_draft_token_ids
|
||||
|
||||
next_token_ids = torch.tensor(
|
||||
prompt_token_ids[
|
||||
first_token_id_index - 1 : first_token_id_index
|
||||
- 1
|
||||
+ num_accepted_tokens
|
||||
],
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
valid_sampled_tokens_count = torch.tensor(
|
||||
[num_accepted_tokens], device="cuda", dtype=torch.int32
|
||||
)
|
||||
|
||||
self._copy_valid_sampled_token_count(next_token_ids, valid_sampled_tokens_count)
|
||||
|
||||
return torch.tensor(proposed_draft_token_ids, device="cuda", dtype=torch.int32)
|
||||
|
||||
return fake_propose_draft_token_ids_fn
|
||||
|
||||
@@ -184,6 +198,7 @@ mamba_kv_cache_dict = {}
|
||||
|
||||
def get_fake_execute_model_fn(original_execute_model_fn: Callable):
|
||||
last_num_computed_tokens = 0
|
||||
num_prompt_tokens = None
|
||||
|
||||
def fake_execute_model_fn(
|
||||
self: GPUModelRunner,
|
||||
@@ -201,10 +216,30 @@ def get_fake_execute_model_fn(original_execute_model_fn: Callable):
|
||||
mamba_group_id
|
||||
].layer_names[0]
|
||||
nonlocal last_num_computed_tokens
|
||||
nonlocal num_prompt_tokens
|
||||
|
||||
if (
|
||||
len(scheduler_output.scheduled_new_reqs) > 0
|
||||
and scheduler_output.scheduled_new_reqs[0].prompt_token_ids is not None
|
||||
):
|
||||
# record number of prompt tokens
|
||||
num_prompt_tokens = len(
|
||||
scheduler_output.scheduled_new_reqs[0].prompt_token_ids
|
||||
)
|
||||
|
||||
if len(scheduler_output.scheduled_cached_reqs.req_ids) > 0:
|
||||
num_computed_tokens = (
|
||||
scheduler_output.scheduled_cached_reqs.num_computed_tokens[0]
|
||||
)
|
||||
if (
|
||||
self.num_spec_tokens
|
||||
and num_prompt_tokens is not None
|
||||
and num_computed_tokens > num_prompt_tokens
|
||||
):
|
||||
# NOTE (tdoublep) with async scheduling, the scheduler does not have an
|
||||
# accurate measure of the number of computed tokens; we need to subtract
|
||||
# the number of reject tokens from the previous timestep.
|
||||
num_computed_tokens -= num_speculative_tokens + 1 - num_accepted_tokens
|
||||
if (
|
||||
num_computed_tokens // BLOCK_SIZE
|
||||
> last_num_computed_tokens // BLOCK_SIZE
|
||||
@@ -493,9 +528,9 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
|
||||
step_actions=[
|
||||
StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
StepAction(554, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(555, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(555, 4, [1, 1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
StepAction(556, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(557, 4, [1, 1, 1, 1, 1], (0, 1), (-1, -1)),
|
||||
StepAction(557, 4, [], (0, 1), (-1, -1)),
|
||||
StepAction(558, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(559, 4, [], (-1, -1), (1, 0)),
|
||||
StepAction(560, 4, [], (-1, -1), (-1, -1)),
|
||||
@@ -510,8 +545,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
|
||||
step_actions=[
|
||||
StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
StepAction(554, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(556, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(558, 4, [1, 1, 1, 1, 1], (1, 1), (2, 0)),
|
||||
StepAction(556, 4, [1, 1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
StepAction(558, 4, [], (1, 1), (2, 0)),
|
||||
StepAction(560, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
],
|
||||
@@ -526,7 +561,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
|
||||
StepAction(555, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(557, 4, [1, 1, 1, 1, 1], (1, 1), (-1, -1)),
|
||||
StepAction(559, 4, [], (-1, -1), (1, 0)),
|
||||
StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
StepAction(561, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(563, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
],
|
||||
),
|
||||
"accept_3_1": TestConfig(
|
||||
@@ -536,9 +572,10 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
|
||||
step_actions=[
|
||||
StepAction(0, 553, [1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
StepAction(553, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(556, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(559, 4, [1, 1, 1, 1, 1], (2, 1), (1, 0)),
|
||||
StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
StepAction(556, 4, [1, 1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
StepAction(559, 4, [], (2, 1), (1, 0)),
|
||||
StepAction(562, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(565, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
],
|
||||
),
|
||||
"accept_3_2": TestConfig(
|
||||
@@ -561,7 +598,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
|
||||
StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
StepAction(555, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(558, 4, [1, 1, 1, 1, 1], (2, 1), (2, 0)),
|
||||
StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
StepAction(561, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(564, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
],
|
||||
),
|
||||
"accept_4_1": TestConfig(
|
||||
@@ -572,8 +610,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
|
||||
StepAction(0, 553, [1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
StepAction(553, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(557, 4, [1, 1, 1, 1, 1], (3, 1), (3, 0)),
|
||||
StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
StepAction(565, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(561, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(565, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
],
|
||||
),
|
||||
"accept_4_2": TestConfig(
|
||||
@@ -584,8 +622,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
|
||||
StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
StepAction(554, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(558, 4, [1, 1, 1, 1, 1], (3, 1), (2, 0)),
|
||||
StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
StepAction(566, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(562, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(566, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
],
|
||||
),
|
||||
"accept_4_3": TestConfig(
|
||||
@@ -596,7 +634,8 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
|
||||
StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
StepAction(555, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(559, 4, [1, 1, 1, 1, 1], (3, 1), (1, 0)),
|
||||
StepAction(563, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
StepAction(563, 4, [], (-1, -1), (-1, -1)),
|
||||
StepAction(567, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
|
||||
],
|
||||
),
|
||||
"accept_4_4": TestConfig(
|
||||
|
||||
Reference in New Issue
Block a user