[Hybrid] Enable speculative decoding in Mamba cache align mode (#33705)
Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
This commit is contained in:
@@ -11,8 +11,10 @@ import datasets
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.utils import create_new_process_for_each_test
|
||||
from vllm import LLM, SamplingParams, TokensPrompt
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import MambaStateCopyFunc
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
@@ -103,6 +105,7 @@ def get_fake_propose_draft_token_ids_fn():
|
||||
aux_hidden_states: list[torch.Tensor] | None,
|
||||
spec_decode_metadata: SpecDecodeMetadata | None,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None,
|
||||
) -> list[list[int]]:
|
||||
num_computed_tokens_cpu_tensor = self.input_batch.num_computed_tokens_cpu_tensor
|
||||
num_computed_tokens = num_computed_tokens_cpu_tensor[0].item()
|
||||
@@ -401,6 +404,9 @@ def _run_ref_mamba_state_worker():
|
||||
}
|
||||
torch.save(cpu_state_ref, "mamba_kv_cache_dict_ref.pth")
|
||||
mamba_kv_cache_dict.clear()
|
||||
del engine
|
||||
torch.cuda.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
raise
|
||||
@@ -473,10 +479,7 @@ def apply_patch(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(mamba_utils, "do_mamba_copy_block", fake_copy_fn)
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Skipping test_mamba_prefix_cache because it is based on spec "
|
||||
"decode which is not allowed now."
|
||||
)
|
||||
@create_new_process_for_each_test()
|
||||
def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
|
||||
run_ref_mamba_state_in_subprocess()
|
||||
apply_patch(monkeypatch)
|
||||
@@ -762,3 +765,6 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
|
||||
mamba_state_ref = torch.load("mamba_kv_cache_dict_ref.pth")
|
||||
check_mamba_state_equal(mamba_state_ref, mamba_kv_cache_dict, keys_to_check)
|
||||
mamba_kv_cache_dict.clear()
|
||||
del engine
|
||||
torch.cuda.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
Reference in New Issue
Block a user