[Hybrid] Enable spec decoding in mamba cache align mode (#33705)

Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com>
This commit is contained in:
Harry Huang
2026-02-14 05:02:28 +08:00
committed by GitHub
parent fd267bc7b7
commit c027541eaf
2 changed files with 10 additions and 8 deletions

View File

@@ -11,8 +11,10 @@ import datasets
import pytest
import torch
from tests.utils import create_new_process_for_each_test
from vllm import LLM, SamplingParams, TokensPrompt
from vllm.config import CacheConfig
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.model_executor.layers.mamba.mamba_utils import MambaStateCopyFunc
from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
@@ -103,6 +105,7 @@ def get_fake_propose_draft_token_ids_fn():
aux_hidden_states: list[torch.Tensor] | None,
spec_decode_metadata: SpecDecodeMetadata | None,
common_attn_metadata: CommonAttentionMetadata,
slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None,
) -> list[list[int]]:
num_computed_tokens_cpu_tensor = self.input_batch.num_computed_tokens_cpu_tensor
num_computed_tokens = num_computed_tokens_cpu_tensor[0].item()
@@ -401,6 +404,9 @@ def _run_ref_mamba_state_worker():
}
torch.save(cpu_state_ref, "mamba_kv_cache_dict_ref.pth")
mamba_kv_cache_dict.clear()
del engine
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
except Exception:
traceback.print_exc()
raise
@@ -473,10 +479,7 @@ def apply_patch(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(mamba_utils, "do_mamba_copy_block", fake_copy_fn)
@pytest.mark.skip(
reason="Skipping test_mamba_prefix_cache because it is based on spec "
"decode which is not allowed now."
)
@create_new_process_for_each_test()
def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
run_ref_mamba_state_in_subprocess()
apply_patch(monkeypatch)
@@ -762,3 +765,6 @@ def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
mamba_state_ref = torch.load("mamba_kv_cache_dict_ref.pth")
check_mamba_state_equal(mamba_state_ref, mamba_kv_cache_dict, keys_to_check)
mamba_kv_cache_dict.clear()
del engine
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()