Add tree attention backend for v1 (part 1) (#20401)

Signed-off-by: Giancarlo Delfin <gdelfin@meta.com>
This commit is contained in:
Giancarlo Delfin
2025-08-03 22:13:26 -07:00
committed by GitHub
parent c2e75b3c11
commit aa7012eb6d
12 changed files with 1098 additions and 25 deletions

View File

@@ -202,7 +202,9 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
def test_propose(num_speculative_tokens):
@pytest.mark.parametrize("backend",
[_Backend.FLASH_ATTN_VLLM_V1, _Backend.TREE_ATTN])
def test_propose(num_speculative_tokens, backend):
# Use GPU device
device = torch.device(current_platform.device_type)
@@ -301,8 +303,7 @@ def test_propose(num_speculative_tokens):
device=device)
sampling_metadata = mock.MagicMock()
attn_metadata_builder_cls, _ = get_attention_backend(
_Backend.FLASH_ATTN_VLLM_V1)
attn_metadata_builder_cls, _ = get_attention_backend(backend)
attn_metadata_builder = attn_metadata_builder_cls(
kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config),
layer_names=proposer.attn_layer_names,