# Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
# Signed-off-by: Chen Zhang <zhangch99@outlook.com>
# Co-authored-by: Chen Zhang <zhangch99@outlook.com>
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import multiprocessing as mp
import os
import traceback
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

import datasets
import pytest
import torch

from vllm import LLM, SamplingParams, TokensPrompt
from vllm.config import CacheConfig
from vllm.model_executor.layers.mamba.mamba_utils import MambaStateCopyFunc
from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.engine.core_client import InprocClient
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import SamplerOutput
from vllm.v1.request import Request
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.worker import mamba_utils
from vllm.v1.worker.gpu_input_batch import CachedRequestState
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch
from vllm.v1.worker.mamba_utils import get_mamba_groups
|
|
|
|
|
@dataclass
class StepAction:
    """Expected scheduler/worker behavior for one engine step."""

    # Tokens already computed for the request when the step starts.
    num_computed_tokens_start: int
    # Tokens the scheduler is expected to schedule in this step.
    num_scheduled_tokens: int
    # Expected per-block flags after allocate_slots (1 = real block,
    # 0 = null placeholder).
    kv_cache_block_ids: list[int]  # [] to follow last step
    # Expected (src_block_idx, dest_block_idx) of the mamba state copy
    # checked after preprocess_mamba.
    preprocess_copy_idx: tuple[int, int]  # -1, -1 for no copy
    # Same as above, checked after postprocess_mamba.
    postprocess_copy_idx: tuple[int, int]  # -1, -1 for no copy
# Number of draft tokens proposed per step by the fake drafter.
num_speculative_tokens = 3

# Number of draft tokens the fake sampler accepts per step
# (overwritten per test case).
num_accepted_tokens = 1
# Token ids of the full reference prompt; populated once the tokenizer
# is available.
prompt_token_ids: list[int] = []
MODEL = "Qwen/Qwen3-Next-80B-A3B-Instruct-FP8"
BLOCK_SIZE = 560
NUM_HIDDEN_LAYERS = 1
# Cursor into `step_actions`, advanced by the fake InprocClient.get_output.
cur_step_action_idx = 0
# The StepAction being checked for the current engine step (None = no check).
cur_step_action: StepAction | None = None
# Per-test-case sequence of expected step actions.
step_actions: list[StepAction] = []
def get_fake_sample_fn() -> Callable:
    """Build a replacement for ``GPUModelRunner._sample``.

    The fake sampler deterministically "samples" the next tokens of the
    global ``prompt_token_ids`` so that generation replays the reference
    prompt, and accepts exactly ``num_accepted_tokens`` draft tokens per
    step (remaining draft slots are filled with -1, i.e. rejected).

    Returns:
        A function with the signature of ``GPUModelRunner._sample``.
    """

    def fake_sample_fn(
        self: GPUModelRunner,
        logits: torch.Tensor | None,
        spec_decode_metadata: SpecDecodeMetadata | None,
    ) -> SamplerOutput:
        assert logits is not None
        num_computed_tokens_cpu_tensor = self.input_batch.num_computed_tokens_cpu_tensor
        num_computed_tokens = num_computed_tokens_cpu_tensor[0].item()
        if num_computed_tokens < self.input_batch.num_prompt_tokens[0].item():
            # Still prefilling: the first generated token follows the prompt.
            first_token_id_index = self.input_batch.num_prompt_tokens[0].item()
        else:
            # The bonus token is not counted as computed, hence the +1.
            first_token_id_index = num_computed_tokens + 1
        if spec_decode_metadata is None:
            # No speculative decoding this step: emit exactly one token.
            return SamplerOutput(
                sampled_token_ids=torch.tensor(
                    [[prompt_token_ids[first_token_id_index]]],
                    device="cuda",
                    dtype=torch.int32,
                ),
                logprobs_tensors=None,
            )
        num_sampled_tokens = spec_decode_metadata.cu_num_sampled_tokens[0].item() + 1
        # NOTE: fixed typo `accpeted_tokens` -> `accepted_tokens`.
        accepted_tokens = prompt_token_ids[
            first_token_id_index : first_token_id_index
            + min(num_accepted_tokens, logits.shape[0])
        ]
        # Pad rejected draft positions with -1.
        sampled_token_ids = accepted_tokens + [-1] * (
            num_sampled_tokens - len(accepted_tokens)
        )
        return SamplerOutput(
            sampled_token_ids=torch.tensor(
                [sampled_token_ids], device="cuda", dtype=torch.int32
            ),
            logprobs_tensors=None,
        )

    return fake_sample_fn
def get_fake_propose_draft_token_ids_fn():
    """Build a replacement for ``GPUModelRunner.propose_draft_token_ids``.

    The fake drafter always proposes the next ``num_speculative_tokens``
    tokens of the global ``prompt_token_ids``, so drafts match the
    reference continuation exactly.
    """

    def fake_propose_draft_token_ids_fn(
        self: GPUModelRunner,
        scheduler_output: SchedulerOutput,
        sampled_token_ids: torch.Tensor | list[list[int]],
        sampling_metadata: SamplingMetadata,
        hidden_states: torch.Tensor,
        sample_hidden_states: torch.Tensor,
        aux_hidden_states: list[torch.Tensor] | None,
        spec_decode_metadata: SpecDecodeMetadata | None,
        common_attn_metadata: CommonAttentionMetadata,
    ) -> list[list[int]]:
        batch = self.input_batch
        computed = batch.num_computed_tokens_cpu_tensor[0].item()
        prompt_len = batch.num_prompt_tokens[0].item()
        if batch.num_tokens_no_spec[0].item() <= prompt_len:
            # Still in prefill: drafting starts right after the prompt.
            start = prompt_len
        else:
            # Bonus token isn't considered as computed, hence the +1.
            start = computed + 1
        # Skip past the tokens accepted earlier in this step.
        start += batch.num_accepted_tokens_cpu[0].item()
        return [prompt_token_ids[start : start + num_speculative_tokens]]

    return fake_propose_draft_token_ids_fn
def get_fake_step_action_fn(original_step_action_fn: Callable):
    """Wrap ``InprocClient.get_output`` to advance the step-action cursor.

    Before delegating, pop the next expected ``StepAction`` from the
    global ``step_actions`` list into ``cur_step_action`` (``None`` once
    the list is exhausted) so the other fakes know what to check.
    """

    def fake_get_output(self: InprocClient):
        global cur_step_action_idx, cur_step_action
        has_more = cur_step_action_idx < len(step_actions)
        cur_step_action = step_actions[cur_step_action_idx] if has_more else None
        if has_more:
            cur_step_action_idx += 1
        print(f"cur_step_action: {cur_step_action_idx=} {cur_step_action=}")
        return original_step_action_fn(self)

    return fake_get_output
def get_fake_allocate_slots_fn(original_allocate_slots_fn: Callable):
    """Wrap ``KVCacheManager.allocate_slots`` to verify block allocation.

    After delegating to the real implementation, compare the request's
    non-null-block pattern against the current ``StepAction``'s expected
    ``kv_cache_block_ids``.
    """

    def fake_allocate_slots_fn(
        self: KVCacheManager,
        request: Request,
        num_new_tokens: int,
        num_new_computed_tokens: int = 0,
        new_computed_blocks: KVCacheBlocks | None = None,
        num_lookahead_tokens: int = 0,
        num_external_computed_tokens: int = 0,
        delay_cache_blocks: bool = False,
        num_encoder_tokens: int = 0,
    ):
        result = original_allocate_slots_fn(
            self,
            request,
            num_new_tokens,
            num_new_computed_tokens,
            new_computed_blocks,
            num_lookahead_tokens,
            num_external_computed_tokens,
            delay_cache_blocks,
            num_encoder_tokens,
        )
        if cur_step_action is not None:
            req_blocks = self.coordinator.single_type_managers[0].req_to_blocks[
                request.request_id
            ]
            # 1 for each real block, 0 for each null placeholder.
            observed = [0 if blk.is_null else 1 for blk in req_blocks]
            assert observed == cur_step_action.kv_cache_block_ids
        return result

    return fake_allocate_slots_fn
# Maps block-aligned num-computed-tokens -> cloned pair of mamba state
# tensors (kv_cache[0][0] and kv_cache[0][1]; presumably conv and temporal
# states — confirm against the model's cache layout) captured right after
# the corresponding block was filled.
mamba_kv_cache_dict = {}
def get_fake_execute_model_fn(original_execute_model_fn: Callable):
    """Wrap ``GPUModelRunner.execute_model`` with step checks and snapshots.

    Before each step: verify the scheduled-token count against the current
    ``StepAction`` and, whenever a new block boundary was crossed since the
    previous step, snapshot the just-completed block's mamba state into the
    global ``mamba_kv_cache_dict``. After the step: verify the request's
    ``num_computed_tokens`` against the expectation.
    """
    # num_computed_tokens of the cached request at the previous step; used
    # to detect when a new BLOCK_SIZE boundary has been crossed.
    last_num_computed_tokens = 0

    def fake_execute_model_fn(
        self: GPUModelRunner,
        scheduler_output: SchedulerOutput,
        intermediate_tensors: IntermediateTensors | None = None,
    ):
        if cur_step_action is not None:
            # Single-request test: the first (only) entry is our request.
            num_scheduled_tokens = next(
                iter(scheduler_output.num_scheduled_tokens.values())
            )
            assert num_scheduled_tokens == cur_step_action.num_scheduled_tokens
        mamba_group_ids, mamba_spec = get_mamba_groups(self.kv_cache_config)
        mamba_group_id = mamba_group_ids[0]
        mamba_layer_name = self.kv_cache_config.kv_cache_groups[
            mamba_group_id
        ].layer_names[0]
        nonlocal last_num_computed_tokens
        if len(scheduler_output.scheduled_cached_reqs.req_ids) > 0:
            num_computed_tokens = (
                scheduler_output.scheduled_cached_reqs.num_computed_tokens[0]
            )
            if (
                num_computed_tokens // BLOCK_SIZE
                > last_num_computed_tokens // BLOCK_SIZE
            ):
                # generated a new aligned block in this step
                block_idx = num_computed_tokens // mamba_spec.block_size - 1
                block_id = (
                    self.input_batch.block_table.block_tables[mamba_group_id]
                    .block_table.cpu[0, block_idx]
                    .item()
                )
                # block_id 0 is skipped — presumably the null block; confirm.
                if block_id != 0:
                    kv_cache = self.compilation_config.static_forward_context[
                        mamba_layer_name
                    ].kv_cache
                    # Clone both state tensors, keyed by the block-aligned
                    # token count, before this step can overwrite them.
                    mamba_kv_cache_dict[
                        num_computed_tokens - num_computed_tokens % BLOCK_SIZE
                    ] = (
                        kv_cache[0][0][block_id].clone(),
                        kv_cache[0][1][block_id].clone(),
                    )

            last_num_computed_tokens = num_computed_tokens
        else:
            # No cached request scheduled (fresh prefill): reset tracking.
            last_num_computed_tokens = 0

        ret = original_execute_model_fn(self, scheduler_output, intermediate_tensors)

        if cur_step_action is not None:
            assert (
                cur_step_action.num_computed_tokens_start
                == self.input_batch.num_computed_tokens_cpu[0].item()
            )

        return ret

    return fake_execute_model_fn
def get_fake_process_mamba_fn(
    original_preprocess_mamba_fn: Callable,
    original_post_process_mamba_fn: Callable,
    original_copy_fn: Callable,
):
    """Wrap mamba pre/post-processing and the block-copy kernel with checks.

    ``fake_copy_fn`` records the argument lists of the block copy issued
    during a pre/post-processing call; the two wrappers then compare the
    recorded copy against the current ``StepAction``'s expected
    (src_block_idx, dest_block_idx).
    """
    # Copy arguments recorded during the current pre/post phase:
    # (src_state_list, dest_state_list, num_elements_list); None until a
    # copy is observed in the phase.
    copy_info: tuple[list[int], list[int], list[int]] | None = None

    def check_copy_info(
        action: tuple[int, int],
        kv_cache_config: KVCacheConfig,
        forward_context: dict[str, Any],
        input_batch: GPUInputBatch,
    ):
        # Verify the recorded copy matches the expected (src, dest) block
        # indices; action == (-1, -1) means "no copy expected".
        assert copy_info is not None
        if action == (-1, -1):
            assert len(copy_info[0]) == len(copy_info[1]) == len(copy_info[2]) == 0
        else:
            assert len(copy_info[0]) == len(copy_info[1]) == len(copy_info[2]) == 2
            mamba_group_ids, mamba_spec = get_mamba_groups(kv_cache_config)
            mamba_group_id = mamba_group_ids[0]
            mamba_layer_name = kv_cache_config.kv_cache_groups[
                mamba_group_id
            ].layer_names[0]
            mamba_kv_cache = forward_context[mamba_layer_name].kv_cache[0][-1]
            mamba_block_table = input_batch.block_table.block_tables[
                mamba_group_id
            ].block_table.cpu[0]
            # Translate the expected logical block indices into the data
            # pointers the copy kernel would have received.
            expected_temporal_src = mamba_kv_cache[
                mamba_block_table[action[0]]
            ].data_ptr()
            expected_temporal_dest = mamba_kv_cache[
                mamba_block_table[action[1]]
            ].data_ptr()
            # -1 is qwen3-next's temporal. We skip checking conv as it is more complex.
            assert copy_info[0][-1] == expected_temporal_src
            assert copy_info[1][-1] == expected_temporal_dest

    def fake_preprocess_mamba_fn(
        scheduler_output: SchedulerOutput,
        kv_cache_config: KVCacheConfig,
        cache_config: CacheConfig,
        mamba_state_idx: dict[str, int],
        input_batch: GPUInputBatch,
        requests: dict[str, CachedRequestState],
        forward_context: dict[str, Any],
        mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
    ):
        # Reset the recording before the original runs (it may invoke the
        # patched copy fn), then validate what was recorded.
        nonlocal copy_info
        copy_info = None
        ret = original_preprocess_mamba_fn(
            scheduler_output,
            kv_cache_config,
            cache_config,
            mamba_state_idx,
            input_batch,
            requests,
            forward_context,
            mamba_state_copy_funcs,
        )
        if cur_step_action is not None:
            check_copy_info(
                cur_step_action.preprocess_copy_idx,
                kv_cache_config,
                forward_context,
                input_batch,
            )
        return ret

    def fake_post_process_mamba_fn(
        scheduler_output: SchedulerOutput,
        kv_cache_config: KVCacheConfig,
        input_batch: GPUInputBatch,
        requests: dict[str, CachedRequestState],
        mamba_state_idx: dict[str, int],
        forward_context: dict[str, Any],
        mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
    ):
        # Same record-then-check protocol as preprocess, but against the
        # postprocess expectation.
        nonlocal copy_info
        copy_info = None
        ret = original_post_process_mamba_fn(
            scheduler_output,
            kv_cache_config,
            input_batch,
            requests,
            mamba_state_idx,
            forward_context,
            mamba_state_copy_funcs,
        )
        if cur_step_action is not None:
            check_copy_info(
                cur_step_action.postprocess_copy_idx,
                kv_cache_config,
                forward_context,
                input_batch,
            )
        return ret

    def fake_copy_fn(
        src_state_list: list[int],
        dest_state_list: list[int],
        num_elements_list: list[int],
    ):
        # Record the copy arguments; at most one copy per phase is expected
        # (copy_info must have been reset to None by the wrapper).
        nonlocal copy_info
        assert copy_info is None
        copy_info = (src_state_list, dest_state_list, num_elements_list)
        return original_copy_fn(
            src_state_list,
            dest_state_list,
            num_elements_list,
        )

    return fake_preprocess_mamba_fn, fake_post_process_mamba_fn, fake_copy_fn
def run_ref_mamba_state_in_subprocess() -> None:
    """Run the reference generation in a spawned subprocess.

    The worker saves the reference mamba states to
    ``mamba_kv_cache_dict_ref.pth``; running it in its own process keeps
    its engine state out of the test process.

    Raises:
        RuntimeError: if the worker times out or exits with a non-zero
            exit code.
    """
    ctx = mp.get_context("spawn")
    proc = ctx.Process(target=_run_ref_mamba_state_worker)
    proc.start()
    proc.join(timeout=600)
    if proc.is_alive():
        # join() timed out: exitcode is still None at this point. Kill the
        # worker instead of leaving it running and reporting a confusing
        # "exited with code None".
        proc.terminate()
        proc.join()
        raise RuntimeError("Ref mamba state process timed out after 600s.")
    if proc.exitcode != 0:
        raise RuntimeError(f"Ref mamba state process exited with code {proc.exitcode}.")
def _run_ref_mamba_state_worker():
    """Generate the reference mamba states (no prefix caching).

    Runs one long generation with the fake sampler and the snapshotting
    ``execute_model`` wrapper installed, then saves the captured per-block
    mamba states to ``mamba_kv_cache_dict_ref.pth``.
    """
    try:
        os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
        num_generated_tokens = 8000
        num_prompt_tokens = 500
        sampling_params = SamplingParams(
            temperature=0.0, max_tokens=num_generated_tokens
        )
        prompt_dataset = datasets.load_dataset("heheda/a_long_article")
        full_prompt = prompt_dataset["train"][0]["text"]
        # Patch the classes directly (no pytest monkeypatch here): the
        # patches must last for the lifetime of this subprocess.
        fake_execute_model_fn = get_fake_execute_model_fn(GPUModelRunner.execute_model)
        GPUModelRunner.execute_model = fake_execute_model_fn
        fake_sample_fn = get_fake_sample_fn()
        GPUModelRunner._sample = fake_sample_fn
        engine = LLM(
            model=MODEL,
            block_size=BLOCK_SIZE,
            hf_overrides={"num_hidden_layers": NUM_HIDDEN_LAYERS},
            seed=42,
        )
        global prompt_token_ids
        prompt_token_ids = engine.get_tokenizer().encode(full_prompt)
        print(f"Token IDs length: {len(prompt_token_ids)}")

        _outputs = engine.generate(
            [TokensPrompt(prompt_token_ids=prompt_token_ids[:num_prompt_tokens])],
            sampling_params,
        )
        # ref_mamba_kv_cache_dict = torch.load("mamba_kv_cache_dict.pth")
        # check_mamba_state_equal(ref_mamba_kv_cache_dict, mamba_kv_cache_dict)
        # torch.save(mamba_kv_cache_dict, "mamba_kv_cache_dict.pth")
        # Move snapshots to CPU so the saved file can be loaded in the
        # parent process regardless of device placement.
        cpu_state_ref = {
            key: tuple(tensor.detach().cpu() for tensor in tensors)
            for key, tensors in mamba_kv_cache_dict.items()
        }
        torch.save(cpu_state_ref, "mamba_kv_cache_dict_ref.pth")
        mamba_kv_cache_dict.clear()
    except Exception:
        # Print the traceback inside the subprocess before propagating, so
        # the parent sees more than a bare non-zero exit code.
        traceback.print_exc()
        raise
def check_mamba_state_equal(
    mamba_state_ref: dict, mamba_state_new: dict, keys_to_check: list[int]
):
    """Compare captured mamba states against the reference snapshot.

    Tensors in ``mamba_state_new`` may be longer along dim 0 than the
    reference; only the reference-sized prefix is compared. A mismatch
    touching fewer than 1% of the elements is tolerated with a warning.

    Raises:
        ValueError: if any tensor pair differs on 1% or more elements.
    """
    atol, rtol = 1e-2, 1e-2
    for key in keys_to_check:
        assert key in mamba_state_new
        assert key in mamba_state_ref
        # mamba state new is a subset of mamba state ref
        for i, (ref, new) in enumerate(zip(mamba_state_ref[key], mamba_state_new[key])):
            new = new.to(ref.device) if ref.device != new.device else new
            new = new[: ref.shape[0]]
            if torch.allclose(ref, new, atol=atol, rtol=rtol):
                continue
            diff_idx = torch.nonzero(~torch.isclose(ref, new, atol=atol, rtol=rtol))
            if diff_idx.shape[0] * 100 >= ref.numel():
                raise ValueError(
                    f"Mamba state is not equal for key: {key} at index {i}"
                )
            # Fewer than 1% of elements differ: tolerate with a warning.
            print(
                f"[WARNING] found {diff_idx.shape[0] * 100 / ref.numel()}% of the elements are different"  # noqa: E501
            )
    return True
@dataclass
class TestConfig:
    """One prefix-cache test scenario."""

    # Length of the prompt slice fed to the engine.
    num_prompt_tokens: int
    # max_tokens passed to SamplingParams.
    num_generated_tokens: int
    # How many draft tokens the fake sampler accepts per step.
    num_accepted_tokens: int
    # Expected per-step scheduler/worker behavior, in order.
    step_actions: list[StepAction]
def apply_patch(monkeypatch: pytest.MonkeyPatch):
    """Install all fakes on the engine classes for the main test process."""
    # Keep the engine in-process so the patched classes are actually used.
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    # Deterministic sampler that replays the reference prompt.
    fake_sample_fn = get_fake_sample_fn()
    monkeypatch.setattr(GPUModelRunner, "_sample", fake_sample_fn)

    # Drafter that proposes the exact reference continuation.
    fake_propose_draft_token_ids_fn = get_fake_propose_draft_token_ids_fn()
    monkeypatch.setattr(
        GPUModelRunner, "propose_draft_token_ids", fake_propose_draft_token_ids_fn
    )

    # Step checks + mamba state snapshotting around execute_model.
    fake_execute_model_fn = get_fake_execute_model_fn(GPUModelRunner.execute_model)
    monkeypatch.setattr(GPUModelRunner, "execute_model", fake_execute_model_fn)

    # Advance the StepAction cursor once per engine step.
    fake_step_action_fn = get_fake_step_action_fn(InprocClient.get_output)
    monkeypatch.setattr(InprocClient, "get_output", fake_step_action_fn)

    # Verify block allocation patterns against expectations.
    fake_allocate_slots_fn = get_fake_allocate_slots_fn(KVCacheManager.allocate_slots)
    monkeypatch.setattr(KVCacheManager, "allocate_slots", fake_allocate_slots_fn)

    # Record and verify mamba block copies in pre/post-processing.
    fake_preprocess_mamba_fn, fake_post_process_mamba_fn, fake_copy_fn = (
        get_fake_process_mamba_fn(
            mamba_utils.preprocess_mamba,
            mamba_utils.postprocess_mamba,
            mamba_utils.do_mamba_copy_block,
        )
    )
    monkeypatch.setattr(mamba_utils, "preprocess_mamba", fake_preprocess_mamba_fn)
    monkeypatch.setattr(mamba_utils, "postprocess_mamba", fake_post_process_mamba_fn)
    monkeypatch.setattr(mamba_utils, "do_mamba_copy_block", fake_copy_fn)
@pytest.mark.skip(
    reason="Skipping test_mamba_prefix_cache because it is based on spec "
    "decode which is not allowed now."
)
def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
    """End-to-end check of mamba state handling under prefix caching.

    First generates reference per-block mamba states in a subprocess, then
    replays several prompt/acceptance scenarios with all fakes installed,
    asserting per-step scheduling, block allocation, and mamba state copies
    against the hand-computed ``StepAction`` tables, and finally compares
    the captured mamba states against the reference snapshot.
    """
    run_ref_mamba_state_in_subprocess()
    apply_patch(monkeypatch)
    prompt_dataset = datasets.load_dataset("heheda/a_long_article")
    full_prompt = prompt_dataset["train"][0]["text"]
    # Each scenario lists the expected StepAction per engine step; the
    # numbers are hand-computed for BLOCK_SIZE = 560.
    tests = {
        "accept_1": TestConfig(
            num_prompt_tokens=554,
            num_generated_tokens=20,
            num_accepted_tokens=1,
            step_actions=[
                StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)),
                StepAction(554, 4, [], (-1, -1), (-1, -1)),
                StepAction(555, 4, [], (-1, -1), (-1, -1)),
                StepAction(556, 4, [], (-1, -1), (-1, -1)),
                StepAction(557, 4, [1, 1, 1, 1, 1], (0, 1), (-1, -1)),
                StepAction(558, 4, [], (-1, -1), (-1, -1)),
                StepAction(559, 4, [], (-1, -1), (1, 0)),
                StepAction(560, 4, [], (-1, -1), (-1, -1)),
                StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
            ],
        ),
        # test case 2.1: no hit, accept 2 tokens
        "accept_2_1": TestConfig(
            num_prompt_tokens=554,
            num_generated_tokens=20,
            num_accepted_tokens=2,
            step_actions=[
                StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)),
                StepAction(554, 4, [], (-1, -1), (-1, -1)),
                StepAction(556, 4, [], (-1, -1), (-1, -1)),
                StepAction(558, 4, [1, 1, 1, 1, 1], (1, 1), (2, 0)),
                StepAction(560, 4, [], (-1, -1), (-1, -1)),
                StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
            ],
        ),
        # test case 2.2: no hit, accept 2 tokens
        "accept_2_2": TestConfig(
            num_prompt_tokens=555,
            num_generated_tokens=20,
            num_accepted_tokens=2,
            step_actions=[
                StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)),
                StepAction(555, 4, [], (-1, -1), (-1, -1)),
                StepAction(557, 4, [1, 1, 1, 1, 1], (1, 1), (-1, -1)),
                StepAction(559, 4, [], (-1, -1), (1, 0)),
                StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
            ],
        ),
        "accept_3_1": TestConfig(
            num_prompt_tokens=553,
            num_generated_tokens=20,
            num_accepted_tokens=3,
            step_actions=[
                StepAction(0, 553, [1, 1, 1, 1], (-1, -1), (-1, -1)),
                StepAction(553, 4, [], (-1, -1), (-1, -1)),
                StepAction(556, 4, [], (-1, -1), (-1, -1)),
                StepAction(559, 4, [1, 1, 1, 1, 1], (2, 1), (1, 0)),
                StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
            ],
        ),
        "accept_3_2": TestConfig(
            num_prompt_tokens=554,
            num_generated_tokens=20,
            num_accepted_tokens=3,
            step_actions=[
                StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)),
                StepAction(554, 4, [], (-1, -1), (-1, -1)),
                StepAction(557, 4, [1, 1, 1, 1, 1], (2, 1), (3, 0)),
                StepAction(560, 4, [], (-1, -1), (-1, -1)),
                StepAction(563, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
            ],
        ),
        "accept_3_3": TestConfig(
            num_prompt_tokens=555,
            num_generated_tokens=20,
            num_accepted_tokens=3,
            step_actions=[
                StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)),
                StepAction(555, 4, [], (-1, -1), (-1, -1)),
                StepAction(558, 4, [1, 1, 1, 1, 1], (2, 1), (2, 0)),
                StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
            ],
        ),
        "accept_4_1": TestConfig(
            num_prompt_tokens=553,
            num_generated_tokens=20,
            num_accepted_tokens=4,
            step_actions=[
                StepAction(0, 553, [1, 1, 1, 1], (-1, -1), (-1, -1)),
                StepAction(553, 4, [], (-1, -1), (-1, -1)),
                StepAction(557, 4, [1, 1, 1, 1, 1], (3, 1), (3, 0)),
                StepAction(561, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
                StepAction(565, 4, [], (-1, -1), (-1, -1)),
            ],
        ),
        "accept_4_2": TestConfig(
            num_prompt_tokens=554,
            num_generated_tokens=25,
            num_accepted_tokens=4,
            step_actions=[
                StepAction(0, 554, [1, 1, 1, 1], (-1, -1), (-1, -1)),
                StepAction(554, 4, [], (-1, -1), (-1, -1)),
                StepAction(558, 4, [1, 1, 1, 1, 1], (3, 1), (2, 0)),
                StepAction(562, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
                StepAction(566, 4, [], (-1, -1), (-1, -1)),
            ],
        ),
        "accept_4_3": TestConfig(
            num_prompt_tokens=555,
            num_generated_tokens=25,
            num_accepted_tokens=4,
            step_actions=[
                StepAction(0, 555, [1, 1, 1, 1], (-1, -1), (-1, -1)),
                StepAction(555, 4, [], (-1, -1), (-1, -1)),
                StepAction(559, 4, [1, 1, 1, 1, 1], (3, 1), (1, 0)),
                StepAction(563, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
            ],
        ),
        "accept_4_4": TestConfig(
            num_prompt_tokens=556,
            num_generated_tokens=25,
            num_accepted_tokens=4,
            step_actions=[
                StepAction(0, 556, [1, 1, 1, 1], (-1, -1), (-1, -1)),
                StepAction(556, 4, [], (-1, -1), (3, 0)),
                StepAction(560, 4, [1, 1, 1, 1, 1], (0, 1), (-1, -1)),
                StepAction(564, 4, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
            ],
        ),
        "prompt_block_size": TestConfig(
            num_prompt_tokens=560,
            num_generated_tokens=10,
            num_accepted_tokens=4,
            step_actions=[
                StepAction(0, 560, [1, 1, 1, 1], (-1, -1), (-1, -1)),
                StepAction(560, 4, [1, 1, 1, 1, 1], (0, 1), (-1, -1)),
            ],
        ),
        "prompt_2_block_size": TestConfig(
            num_prompt_tokens=560 * 2,
            num_generated_tokens=10,
            num_accepted_tokens=4,
            step_actions=[
                StepAction(0, 560, [1, 1, 1, 1], (-1, -1), (-1, -1)),
                StepAction(560, 560, [1, 1, 1, 1, 1], (0, 1), (-1, -1)),
                StepAction(560 * 2, 4, [0, 1, 1, 1, 1, 1], (1, 2), (-1, -1)),
            ],
        ),
        "prompt_2_block_size_10": TestConfig(
            num_prompt_tokens=560 * 2 + 10,
            num_generated_tokens=10,
            num_accepted_tokens=4,
            step_actions=[
                StepAction(0, 560, [1, 1, 1, 1], (-1, -1), (-1, -1)),
                StepAction(560, 570, [1, 0, 1, 1, 1, 1], (0, 2), (-1, -1)),
                StepAction(560 * 2 + 10, 4, [0, 0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
            ],
        ),
        "prompt_3_block_size": TestConfig(
            num_prompt_tokens=560 * 3,
            num_generated_tokens=10,
            num_accepted_tokens=4,
            step_actions=[
                StepAction(0, 560 * 2, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
                StepAction(560 * 2, 560, [0, 1, 1, 1, 1, 1], (1, 2), (-1, -1)),
                StepAction(560 * 3, 4, [0, 0, 1, 1, 1, 1, 1], (2, 3), (-1, -1)),
            ],
        ),
        "prompt_3_block_size_10": TestConfig(
            num_prompt_tokens=560 * 3 + 10,
            num_generated_tokens=10,
            num_accepted_tokens=4,
            step_actions=[
                StepAction(0, 560 * 2, [0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
                StepAction(560 * 2, 570, [0, 1, 0, 1, 1, 1, 1], (1, 3), (-1, -1)),
                StepAction(560 * 3 + 10, 4, [0, 0, 0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
            ],
        ),
        "prompt_10_block_size": TestConfig(
            num_prompt_tokens=560 * 10,
            num_generated_tokens=10,
            num_accepted_tokens=4,
            step_actions=[
                StepAction(0, 560 * 5, [0, 0, 0, 0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
                StepAction(
                    560 * 5,
                    560 * 4,
                    [0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1],
                    (4, 8),
                    (-1, -1),
                ),
                StepAction(
                    560 * 9,
                    560,
                    [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
                    (8, 9),
                    (-1, -1),
                ),
                StepAction(
                    560 * 10,
                    4,
                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
                    (9, 10),
                    (-1, -1),
                ),
            ],
        ),
        "prompt_10_block_size_10": TestConfig(
            num_prompt_tokens=560 * 10 + 10,
            num_generated_tokens=10,
            num_accepted_tokens=4,
            step_actions=[
                StepAction(0, 560 * 5, [0, 0, 0, 0, 1, 1, 1, 1], (-1, -1), (-1, -1)),
                StepAction(
                    560 * 5,
                    560 * 4,
                    [0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1],
                    (4, 8),
                    (-1, -1),
                ),
                StepAction(
                    560 * 9,
                    560 + 10,
                    [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1],
                    (8, 10),
                    (-1, -1),
                ),
            ],
        ),
    }

    engine = LLM(
        model=MODEL,
        enable_prefix_caching=True,
        block_size=BLOCK_SIZE,
        mamba_cache_mode="align",
        speculative_config={
            "method": "qwen3_next_mtp",
            "num_speculative_tokens": num_speculative_tokens,
        },
        max_num_batched_tokens=3072,
        hf_overrides={"num_hidden_layers": NUM_HIDDEN_LAYERS},
        seed=42,
    )
    global prompt_token_ids
    prompt_token_ids = engine.get_tokenizer().encode(full_prompt)
    print(f"Token IDs length: {len(prompt_token_ids)}")
    for test_case_name, test_config in tests.items():
        print(f"Running test case: {test_case_name}")
        num_generated_tokens = test_config.num_generated_tokens
        num_prompt_tokens = test_config.num_prompt_tokens
        global num_accepted_tokens
        num_accepted_tokens = test_config.num_accepted_tokens
        sampling_params = SamplingParams(
            temperature=0.0, max_tokens=num_generated_tokens
        )
        global cur_step_action_idx
        cur_step_action_idx = 0
        # Expand [] (follow-last-step) expectations into concrete copies of
        # the previous step's block flags.
        for step_action_prev, step_action_next in zip(
            test_config.step_actions[:-1], test_config.step_actions[1:]
        ):
            if (
                step_action_next.kv_cache_block_ids is not None
                and len(step_action_next.kv_cache_block_ids) == 0
            ):
                prev_block_ids = step_action_prev.kv_cache_block_ids
                if prev_block_ids is not None:
                    step_action_next.kv_cache_block_ids = prev_block_ids.copy()
        global step_actions
        step_actions = test_config.step_actions
        _ = engine.generate(
            [TokensPrompt(prompt_token_ids=prompt_token_ids[:num_prompt_tokens])],
            sampling_params,
        )
        # Reset the prefix cache so test cases do not share cached blocks.
        assert engine.llm_engine.engine_core.engine_core.scheduler.reset_prefix_cache()
        print(f"End test case: {test_case_name}")
        # Only blocks written back by a postprocess copy are comparable to
        # the reference snapshot.
        keys_to_check = [
            (action.postprocess_copy_idx[1] + 1) * BLOCK_SIZE
            for action in test_config.step_actions
            if action.postprocess_copy_idx and action.postprocess_copy_idx[0] != -1
        ]
        mamba_state_ref = torch.load("mamba_kv_cache_dict_ref.pth")
        check_mamba_state_equal(mamba_state_ref, mamba_kv_cache_dict, keys_to_check)
        mamba_kv_cache_dict.clear()