[ROCm][CI] Fix spec decode logprobs flakiness and parametrize tree attention backends (#34599)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
Andreas Karatzas
2026-02-20 22:25:23 -06:00
committed by GitHub
parent cf93c1a128
commit 54254f7a61
2 changed files with 351 additions and 146 deletions

View File

@@ -13,6 +13,7 @@ from tests.v1.attention.utils import (
try_get_attention_backend,
)
from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.platforms import current_platform
from vllm.v1.attention.backend import CommonAttentionMetadata
from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available
from vllm.v1.attention.backends.registry import AttentionBackendEnum
@@ -23,11 +24,156 @@ if not is_flash_attn_varlen_func_available():
allow_module_level=True,
)
# --------------------------------------------------------------------------- #
# KV cache layout adaptation
# --------------------------------------------------------------------------- #
# Two KV cache layouts exist across backends:
#
# Flash layout: (2, num_blocks, block_size, num_kv_heads, head_size)
# - dim 0 separates key (index 0) and value (index 1)
# - Used by: FLASH_ATTN, TREE_ATTN, ROCM_AITER_FA, ROCM_ATTN
#
# Block layout: (num_blocks, 2, block_size, num_kv_heads, head_size)
# - dim 1 separates key (index 0) and value (index 1)
# - Used by: TRITON_ATTN
#
# The test creates KV caches in flash layout (the canonical format used by
# tree attention). When a reference backend needs block layout we transpose
# dims 0 and 1.
#
# Note: ROCM_ATTN uses flash layout for storage but its forward path calls
# PagedAttention.split_kv_cache which reinterprets the raw memory as paged
# layout (num_blocks, num_kv_heads, head_size//x, block_size, x). This is
# a view-level incompatibility, not a transpose - see the TODO in
# _get_available_reference_backends for details.
#
# TODO: Replace this mapping with a `KV_CACHE_LAYOUT` class attribute on each
# AttentionImpl so the layout is self-documented by the backend itself, e.g.:
# class TritonAttentionImpl(AttentionImpl):
# KV_CACHE_LAYOUT = "block"
# --------------------------------------------------------------------------- #
# Backends that store the KV cache in block layout
# (num_blocks, 2, block_size, num_kv_heads, head_size); see the layout
# discussion above. _adapt_kv_cache_for_backend transposes for these.
_BLOCK_KV_LAYOUT_BACKENDS = frozenset(
    {
        AttentionBackendEnum.TRITON_ATTN,
    }
)
# Backends whose do_kv_cache_update requires engine-level state (e.g.
# ForwardContext) that is not available in this test harness, but whose
# KV cache is flash layout and can be written with reshape_and_cache_flash.
# When a backend is listed here, forward_attention() bypasses
# do_kv_cache_update and writes directly to the cache.
_NEEDS_DIRECT_CACHE_UPDATE = frozenset(
    {
        AttentionBackendEnum.ROCM_AITER_FA,
    }
)
# Backends with known test-harness incompatibilities - see the TODOs
# inside _get_available_reference_backends for details.
_INCOMPATIBLE_REFERENCE_BACKENDS = frozenset(
    {
        AttentionBackendEnum.ROCM_AITER_FA,
        AttentionBackendEnum.ROCM_ATTN,
    }
)
def _adapt_kv_cache_for_backend(
    kv_cache: torch.Tensor,
    backend: AttentionBackendEnum,
) -> torch.Tensor:
    """Return ``kv_cache`` in the layout ``backend`` expects.

    The input is assumed to be in flash layout ``(2, num_blocks, ...)``.
    Backends listed in ``_BLOCK_KV_LAYOUT_BACKENDS`` get a contiguous copy
    in block layout ``(num_blocks, 2, ...)``; every other backend gets the
    original tensor back untouched.
    """
    if backend not in _BLOCK_KV_LAYOUT_BACKENDS:
        # Flash layout is already what this backend wants.
        return kv_cache
    # Swap the (k/v, num_blocks) dims and materialize the new layout.
    return kv_cache.transpose(0, 1).contiguous()
def _get_platform_default_backend() -> AttentionBackendEnum:
    """Ask the platform what backend it would auto-select at runtime."""
    from vllm.v1.attention.selector import AttentionSelectorConfig

    # Geometry/dtype here mirrors the test's configuration so the platform
    # resolves the same backend it would pick for the real run.
    selector_config = AttentionSelectorConfig(
        block_size=32,
        kv_cache_dtype="auto",
        use_mla=False,
        use_sparse=False,
        head_size=128,
        dtype=torch.bfloat16,
    )
    backend_path = current_platform.get_attn_backend_cls(
        selected_backend=None,
        attn_selector_config=selector_config,
    )
    # Map the returned class-path string back to its enum member. Members
    # whose get_path() raises ValueError are skipped rather than treated
    # as a mismatch.
    for candidate in AttentionBackendEnum:
        try:
            candidate_path = candidate.get_path()
        except ValueError:
            continue
        if candidate_path == backend_path:
            return candidate
    raise RuntimeError(
        f"Platform returned backend path '{backend_path}' "
        f"that doesn't match any AttentionBackendEnum member."
    )
def _get_available_reference_backends() -> list[AttentionBackendEnum]:
    """Collect all reference backends the current platform can run.

    On CUDA this is just FLASH_ATTN. On ROCm this includes the platform
    default plus every backend the hardware supports, so the test validates
    tree attention against all of them.
    """
    if not current_platform.is_rocm():
        # CUDA: flash attention.
        return [AttentionBackendEnum.FLASH_ATTN]

    backends: list[AttentionBackendEnum] = []
    # 1. Whatever the platform would auto-select at runtime, unless it is
    #    known to be incompatible with this test harness.
    default_backend = _get_platform_default_backend()
    if default_backend not in _INCOMPATIBLE_REFERENCE_BACKENDS:
        backends.append(default_backend)
    # 2. TRITON_ATTN - always available on ROCm (dedupe against the default).
    if AttentionBackendEnum.TRITON_ATTN not in backends:
        backends.append(AttentionBackendEnum.TRITON_ATTN)
    # TODO: Enable ROCM_ATTN. Its forward path uses
    # PagedAttention.split_kv_cache, which reinterprets the raw cache
    # memory as paged layout:
    #   key:   (num_blocks, num_kv_heads, head_size//x, block_size, x)
    #   value: (num_blocks, num_kv_heads, head_size, block_size)
    # Tree attention writes prefix data in NHD flash layout, so the same
    # bytes produce completely different values when read in paged format.
    # Supporting ROCM_ATTN would require writing prefix data via
    # PagedAttention.write_to_paged_cache into a separate paged-format
    # KV cache.
    # TODO: Enable ROCM_AITER_FA. Its metadata builder reads head counts
    # from the model config at construction time and allocates
    # extend_workspace with those dimensions. The test uses independent
    # head count parameters (num_heads=2/4, num_kv_heads=2) that don't
    # match the model config (Llama-3-8B: 32 q heads, 8 kv heads), causing
    # a head count mismatch in flash_attn_varlen_func during
    # extend_forward. Fixing this requires either matching test head
    # counts to the model config or decoupling the builder from model
    # config head geometry. The direct cache update path
    # (_NEEDS_DIRECT_CACHE_UPDATE) is already in place for when this is
    # resolved.
    return backends
class MockAttentionLayer(torch.nn.Module):
    """Minimal stand-in for an attention layer in the test harness.

    Exposes only the attributes the backend forward paths read: the
    quantization scale tensors and a layer name.
    """

    # Identity scales - the harness runs with kv_cache_dtype="auto", so no
    # quantization is applied. NOTE(review): allocated on "cuda" at class
    # definition time, which assumes a GPU is present when this module is
    # imported.
    _q_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
    _k_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
    _v_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
    layer_name = "mock_layer"

    def __init__(self):
        super().__init__()
@@ -48,6 +194,13 @@ def forward_attention(
spec_token_tree: str | None = None,
num_spec_tokens: int = 0,
) -> torch.Tensor:
"""Run a single attention forward pass through the given backend.
``kv_cache`` is expected in **flash layout**
``(2, num_blocks, block_size, num_kv_heads, head_size)``.
It is automatically converted when the target backend needs a
different layout.
"""
batch_size, q_len, num_heads, dim_per_head = q.shape
num_kv_heads = k.shape[-2]
# Initialize the query and KV sequence lengths.
@@ -116,31 +269,58 @@ def forward_attention(
kv_cache_dtype="auto",
)
# Adapt KV cache layout for this backend.
adapted_kv_cache = _adapt_kv_cache_for_backend(kv_cache, backend)
# Run forward pass and return output.
query = q.view(-1, num_heads, dim_per_head)
key = k.view(-1, num_kv_heads, dim_per_head)
value = v.view(-1, num_kv_heads, dim_per_head)
output = torch.empty_like(query)
if not try_backend_includes_kv_cache_update(backend):
instance.do_kv_cache_update(
layer=layer,
key=key,
value=value,
kv_cache=kv_cache,
slot_mapping=attn_metadata.slot_mapping,
)
if backend in _NEEDS_DIRECT_CACHE_UPDATE:
# This backend's do_kv_cache_update requires engine-level
# ForwardContext that isn't available in this test harness.
# Write directly using reshape_and_cache_flash since the
# KV cache layout is identical (flash layout, unbind on dim 0).
key_cache, value_cache = adapted_kv_cache.unbind(0)
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
"auto",
layer._k_scale,
layer._v_scale,
)
else:
instance.do_kv_cache_update(
layer=layer,
key=key,
value=value,
kv_cache=adapted_kv_cache,
slot_mapping=attn_metadata.slot_mapping,
)
return instance.forward(
layer=layer,
query=query,
key=key,
value=value,
kv_cache=kv_cache.clone(),
kv_cache=adapted_kv_cache.clone(),
attn_metadata=attn_metadata,
output=output,
)
def test_tree_attn_correctness() -> None:
@pytest.mark.parametrize(
"reference_backend",
_get_available_reference_backends(),
ids=lambda b: b.name,
)
def test_tree_attn_correctness(
reference_backend: AttentionBackendEnum,
) -> None:
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
@@ -205,7 +385,9 @@ def test_tree_attn_correctness() -> None:
dtype=torch.bfloat16,
)
# Set up the block table and KV cache for paged KV.
# KV cache in flash layout - the canonical format for
# tree attention. forward_attention() handles conversion
# when needed.
assert max_sequence_length % block_size == 0
max_blocks_per_batch = max_sequence_length // block_size
kv_cache = torch.randn(
@@ -263,9 +445,7 @@ def test_tree_attn_correctness() -> None:
num_spec_tokens=tree_size_q - 1,
).view(batch_size, -1, num_heads, dim_per_head)
# Verify that the chain attention output for each
# branch of the tree (computed using FA3) matches
# the tree attention output.
# Verify each branch against the reference backend.
for q_index in range(tree_size_q):
# Get the q, k, and v for the branch.
branch_mask = tree_attn_mask[q_index, :]
@@ -286,8 +466,8 @@ def test_tree_attn_correctness() -> None:
branch_positions, block_table, block_size
)
# Compute flash attention for the branch.
flash_attn_output = forward_attention(
# Reference attention for this branch.
ref_output = forward_attention(
q=q_branch,
k=k_branch,
v=v_branch,
@@ -295,16 +475,17 @@ def test_tree_attn_correctness() -> None:
block_table=block_table,
slot_mapping=branch_slot_mapping,
seqlen_k=sequence_position + q_len,
backend=AttentionBackendEnum.FLASH_ATTN,
backend=reference_backend,
).view(batch_size, -1, num_heads, dim_per_head)
# Compare the outputs.
assert torch.allclose(
tree_attn_output[:, branch_indices],
flash_attn_output,
ref_output,
atol=7.81e-3,
), (
f"outputs are not close for "
f"reference_backend: {reference_backend.name}, "
f"batch_size: {batch_size}, "
f"num_heads: {num_heads}, "
f"sequence_position: {sequence_position}, "