Add tree attention backend for v1 (part 1) (#20401)
Signed-off-by: Giancarlo Delfin <gdelfin@meta.com>
This commit is contained in:
@@ -202,7 +202,9 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
|
||||
-def test_propose(num_speculative_tokens):
|
||||
+@pytest.mark.parametrize("backend",
|
||||
+[_Backend.FLASH_ATTN_VLLM_V1, _Backend.TREE_ATTN])
|
||||
+def test_propose(num_speculative_tokens, backend):
|
||||
# Use GPU device
|
||||
device = torch.device(current_platform.device_type)
|
||||
|
||||
@@ -301,8 +303,7 @@ def test_propose(num_speculative_tokens):
|
||||
device=device)
|
||||
sampling_metadata = mock.MagicMock()
|
||||
|
||||
-attn_metadata_builder_cls, _ = get_attention_backend(
|
||||
-_Backend.FLASH_ATTN_VLLM_V1)
|
||||
+attn_metadata_builder_cls, _ = get_attention_backend(backend)
|
||||
attn_metadata_builder = attn_metadata_builder_cls(
|
||||
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
|
||||
layer_names=proposer.attn_layer_names,
|
||||
|
||||
Reference in New Issue
Block a user