[ROCm] [V1] [SpecDec] Enable Speculative Decoding on ROCm V1 Engine (#21496)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
@@ -6,6 +6,7 @@ from unittest import mock
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.utils import get_attn_backend_list_based_on_platform
|
||||
from tests.v1.attention.utils import (BatchSpec, _Backend,
|
||||
create_common_attn_metadata,
|
||||
create_standard_kv_cache_spec,
|
||||
@@ -120,17 +121,28 @@ def test_prepare_inputs():
|
||||
assert torch.equal(token_indices, expected_token_indices)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("method,proposer_helper", [
|
||||
("eagle", lambda k: _create_proposer("eagle", k)),
|
||||
("eagle3", lambda k: _create_proposer("eagle3", k)),
|
||||
])
|
||||
@pytest.mark.parametrize("method", ["eagle", "eagle3"])
|
||||
@pytest.mark.parametrize("attn_backend",
|
||||
get_attn_backend_list_based_on_platform())
|
||||
@pytest.mark.parametrize("pp_size", [1, 2])
|
||||
@pytest.mark.parametrize("use_distinct_embed_tokens", [True, False])
|
||||
@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
|
||||
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
|
||||
@mock.patch('vllm.v1.spec_decode.eagle.get_model')
|
||||
def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
|
||||
proposer_helper, pp_size, use_distinct_embed_tokens):
|
||||
attn_backend, pp_size, use_distinct_embed_tokens,
|
||||
monkeypatch):
|
||||
|
||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
|
||||
|
||||
if (attn_backend == "TRITON_ATTN_VLLM_V1"
|
||||
and not current_platform.is_rocm()):
|
||||
pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
|
||||
"multi-token eagle spec decode on current platform")
|
||||
|
||||
if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
|
||||
# Setup draft model mock
|
||||
mock_model = mock.MagicMock()
|
||||
if use_distinct_embed_tokens:
|
||||
@@ -177,7 +189,7 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
|
||||
target_model.lm_head = mock.MagicMock()
|
||||
|
||||
# Create proposer using the helper function
|
||||
proposer = proposer_helper(k=8)
|
||||
proposer = _create_proposer(method, k=8)
|
||||
|
||||
# Call the method under test
|
||||
proposer.load_model(target_model)
|
||||
@@ -201,10 +213,22 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
|
||||
target_model.model.embed_tokens
|
||||
|
||||
|
||||
@pytest.mark.parametrize("method", ["eagle", "eagle3"])
|
||||
@pytest.mark.parametrize("attn_backend",
|
||||
get_attn_backend_list_based_on_platform())
|
||||
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
|
||||
@pytest.mark.parametrize("backend",
|
||||
[_Backend.FLASH_ATTN_VLLM_V1, _Backend.TREE_ATTN])
|
||||
def test_propose(num_speculative_tokens, backend):
|
||||
def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
|
||||
|
||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
|
||||
|
||||
if (attn_backend == "TRITON_ATTN_VLLM_V1"
|
||||
and not current_platform.is_rocm()):
|
||||
pytest.skip("TRITON_ATTN_VLLM_V1 does not support "
|
||||
"multi-token eagle spec decode on current platform")
|
||||
|
||||
if attn_backend == "FLASH_ATTN_VLLM_V1" and current_platform.is_rocm():
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
|
||||
# Use GPU device
|
||||
device = torch.device(current_platform.device_type)
|
||||
|
||||
@@ -303,7 +327,18 @@ def test_propose(num_speculative_tokens, backend):
|
||||
device=device)
|
||||
sampling_metadata = mock.MagicMock()
|
||||
|
||||
attn_metadata_builder_cls, _ = get_attention_backend(backend)
|
||||
if attn_backend == "FLASH_ATTN_VLLM_V1":
|
||||
attn_metadata_builder_cls, _ = get_attention_backend(
|
||||
_Backend.FLASH_ATTN_VLLM_V1)
|
||||
elif attn_backend == "TRITON_ATTN_VLLM_V1":
|
||||
attn_metadata_builder_cls, _ = get_attention_backend(
|
||||
_Backend.TRITON_ATTN_VLLM_V1)
|
||||
elif attn_backend == "TREE_ATTN":
|
||||
attn_metadata_builder_cls, _ = get_attention_backend(
|
||||
_Backend.TREE_ATTN)
|
||||
else:
|
||||
raise ValueError(f"Unsupported attention backend: {attn_backend}")
|
||||
|
||||
attn_metadata_builder = attn_metadata_builder_cls(
|
||||
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
|
||||
layer_names=proposer.attn_layer_names,
|
||||
|
||||
Reference in New Issue
Block a user