[ROCm][CI] Fix spec decode logprobs flakiness and parametrize tree attention backends (#34599)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -52,7 +52,7 @@ def vllm_model(vllm_runner, request) -> Generator[VllmRunner, None, None]:
|
||||
# TODO: enable this once we support it for
|
||||
# prompt logprobs.
|
||||
enable_prefix_caching=request.param,
|
||||
gpu_memory_utilization=0.4, # up to 2 alive concurrently
|
||||
gpu_memory_utilization=0.4,
|
||||
) as vllm_model:
|
||||
yield vllm_model
|
||||
|
||||
@@ -366,21 +366,20 @@ def test_max_logprobs():
|
||||
Should also fail for `prompt_logprobs > max_logprobs`
|
||||
APC should not matter as this test checks basic request validation.
|
||||
"""
|
||||
runner = VllmRunner(
|
||||
with VllmRunner(
|
||||
"facebook/opt-125m",
|
||||
max_logprobs=1,
|
||||
enable_prefix_caching=False,
|
||||
# 2 other llms alive during whole session
|
||||
gpu_memory_utilization=0.15,
|
||||
max_model_len=256,
|
||||
)
|
||||
vllm_sampling_params = SamplingParams(logprobs=1)
|
||||
# should pass
|
||||
runner.generate(["Hello world"], sampling_params=vllm_sampling_params)
|
||||
) as runner:
|
||||
vllm_sampling_params = SamplingParams(logprobs=1)
|
||||
# should pass
|
||||
runner.generate(["Hello world"], sampling_params=vllm_sampling_params)
|
||||
|
||||
bad_sampling_params = SamplingParams(logprobs=2)
|
||||
with pytest.raises(ValueError):
|
||||
runner.generate(["Hello world"], sampling_params=bad_sampling_params)
|
||||
bad_sampling_params = SamplingParams(logprobs=2)
|
||||
with pytest.raises(ValueError):
|
||||
runner.generate(["Hello world"], sampling_params=bad_sampling_params)
|
||||
|
||||
|
||||
def test_none_logprobs(vllm_model, example_prompts):
|
||||
@@ -449,33 +448,31 @@ def test_all_logprobs(example_prompts):
|
||||
Args:
|
||||
example_prompts: list of example prompts (test fixture)
|
||||
"""
|
||||
runner = VllmRunner(
|
||||
with VllmRunner(
|
||||
"facebook/opt-125m",
|
||||
max_logprobs=-1,
|
||||
enable_prefix_caching=False,
|
||||
# 2 other llms alive during whole session
|
||||
gpu_memory_utilization=0.15,
|
||||
max_model_len=256,
|
||||
)
|
||||
) as runner:
|
||||
sampling_params_logprobs_all = SamplingParams(
|
||||
max_tokens=5, logprobs=-1, prompt_logprobs=-1
|
||||
)
|
||||
results_logprobs_all = runner.llm.generate(
|
||||
example_prompts, sampling_params=sampling_params_logprobs_all
|
||||
)
|
||||
vocab_size = runner.llm.llm_engine.model_config.get_vocab_size()
|
||||
|
||||
sampling_params_logprobs_all = SamplingParams(
|
||||
max_tokens=5, logprobs=-1, prompt_logprobs=-1
|
||||
)
|
||||
results_logprobs_all = runner.llm.generate(
|
||||
example_prompts, sampling_params=sampling_params_logprobs_all
|
||||
)
|
||||
vocab_size = runner.llm.llm_engine.model_config.get_vocab_size()
|
||||
|
||||
for i in range(len(results_logprobs_all)):
|
||||
logprobs = results_logprobs_all[i].outputs[0].logprobs
|
||||
prompt_logprobs = results_logprobs_all[i].prompt_logprobs
|
||||
assert logprobs is not None
|
||||
for logprob in logprobs:
|
||||
assert len(logprob) == vocab_size
|
||||
assert prompt_logprobs is not None
|
||||
assert prompt_logprobs[0] is None
|
||||
for prompt_logprob in prompt_logprobs[1:]:
|
||||
assert len(prompt_logprob) == vocab_size
|
||||
for i in range(len(results_logprobs_all)):
|
||||
logprobs = results_logprobs_all[i].outputs[0].logprobs
|
||||
prompt_logprobs = results_logprobs_all[i].prompt_logprobs
|
||||
assert logprobs is not None
|
||||
for logprob in logprobs:
|
||||
assert len(logprob) == vocab_size
|
||||
assert prompt_logprobs is not None
|
||||
assert prompt_logprobs[0] is None
|
||||
for prompt_logprob in prompt_logprobs[1:]:
|
||||
assert len(prompt_logprob) == vocab_size
|
||||
|
||||
|
||||
@pytest.mark.parametrize("logprobs_mode", get_args(LogprobsMode))
|
||||
@@ -495,24 +492,28 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
|
||||
max_model_len=16,
|
||||
logprobs_mode=logprobs_mode,
|
||||
)
|
||||
vllm_sampling_params = SamplingParams(logprobs=1)
|
||||
results = llm.generate(["Hello world"], sampling_params=vllm_sampling_params)
|
||||
try:
|
||||
vllm_sampling_params = SamplingParams(logprobs=1)
|
||||
results = llm.generate(["Hello world"], sampling_params=vllm_sampling_params)
|
||||
|
||||
total_token_with_logprobs = 0
|
||||
positive_values = 0
|
||||
for output in results[0].outputs:
|
||||
for logprobs in output.logprobs:
|
||||
for token_id in logprobs:
|
||||
logprob = logprobs[token_id]
|
||||
if logprobs_mode in ("raw_logprobs", "processed_logprobs"):
|
||||
assert logprob.logprob <= 0
|
||||
if logprob.logprob > 0:
|
||||
positive_values = positive_values + 1
|
||||
total_token_with_logprobs = total_token_with_logprobs + 1
|
||||
assert total_token_with_logprobs >= len(results[0].outputs)
|
||||
if logprobs_mode in ("raw_logits", "processed_logits"):
|
||||
assert positive_values > 0
|
||||
del llm
|
||||
total_token_with_logprobs = 0
|
||||
positive_values = 0
|
||||
for output in results[0].outputs:
|
||||
for logprobs in output.logprobs:
|
||||
for token_id in logprobs:
|
||||
logprob = logprobs[token_id]
|
||||
if logprobs_mode in ("raw_logprobs", "processed_logprobs"):
|
||||
assert logprob.logprob <= 0
|
||||
if logprob.logprob > 0:
|
||||
positive_values = positive_values + 1
|
||||
total_token_with_logprobs = total_token_with_logprobs + 1
|
||||
assert total_token_with_logprobs >= len(results[0].outputs)
|
||||
if logprobs_mode in ("raw_logits", "processed_logits"):
|
||||
assert positive_values > 0
|
||||
finally:
|
||||
del llm
|
||||
torch.cuda.empty_cache()
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
class TestCorrectDecodedToken:
|
||||
@@ -767,7 +768,7 @@ class TestCorrectDecodedToken:
|
||||
# Simulate cases where individual tokens decode to "�"
|
||||
# but combinations decode correctly
|
||||
if len(ids) == 1:
|
||||
if ids[0] == 3 or ids[0] == 4 or ids[0] == 8 or ids[0] == 9:
|
||||
if ids[0] in (3, 4, 8, 9):
|
||||
return "<EFBFBD>"
|
||||
elif len(ids) == 2:
|
||||
if ids == [2, 3]:
|
||||
@@ -809,42 +810,41 @@ def test_verify_tokens_integration():
|
||||
corrects tokens ending with the replacement character "�".
|
||||
Uses facebook/opt-125m which is known to produce these issues.
|
||||
"""
|
||||
runner = VllmRunner(
|
||||
with VllmRunner(
|
||||
"facebook/opt-125m",
|
||||
max_logprobs=0,
|
||||
enable_prefix_caching=False,
|
||||
gpu_memory_utilization=0.15,
|
||||
max_model_len=256,
|
||||
)
|
||||
) as runner:
|
||||
# Use a prompt that triggers multi-byte UTF-8 issues
|
||||
# Based on user's example: "In this example,"
|
||||
test_prompts = ["In this example,"]
|
||||
|
||||
# Use a prompt that triggers multi-byte UTF-8 issues
|
||||
# Based on user's example: "In this example,"
|
||||
test_prompts = ["In this example,"]
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=16,
|
||||
temperature=0,
|
||||
logprobs=0,
|
||||
)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=16,
|
||||
temperature=0,
|
||||
logprobs=0,
|
||||
)
|
||||
results = runner.llm.generate(test_prompts, sampling_params=sampling_params)
|
||||
|
||||
results = runner.llm.generate(test_prompts, sampling_params=sampling_params)
|
||||
|
||||
# Verify that decoded tokens don't contain replacement characters
|
||||
for result in results:
|
||||
assert result.outputs[0].logprobs is not None
|
||||
for logprob_dict in result.outputs[0].logprobs:
|
||||
for token_id, logprob_info in logprob_dict.items():
|
||||
decoded_token = logprob_info.decoded_token
|
||||
# Decoded tokens should not end with replacement character
|
||||
# They should either be corrected or empty string
|
||||
assert not decoded_token.endswith("<EFBFBD>"), (
|
||||
f"Token {token_id} decoded to '{decoded_token}' which "
|
||||
f"ends with replacement character"
|
||||
)
|
||||
# Decoded tokens should not contain lone replacement characters
|
||||
assert decoded_token != "<EFBFBD>", (
|
||||
f"Token {token_id} is a lone replacement character"
|
||||
)
|
||||
# Verify that decoded tokens don't contain replacement characters
|
||||
for result in results:
|
||||
assert result.outputs[0].logprobs is not None
|
||||
for logprob_dict in result.outputs[0].logprobs:
|
||||
for token_id, logprob_info in logprob_dict.items():
|
||||
decoded_token = logprob_info.decoded_token
|
||||
# Decoded tokens should not end with replacement character
|
||||
# They should either be corrected or empty string
|
||||
assert not decoded_token.endswith("<EFBFBD>"), (
|
||||
f"Token {token_id} decoded to '{decoded_token}' which "
|
||||
f"ends with replacement character"
|
||||
)
|
||||
# Decoded tokens should not contain lone replacement characters
|
||||
assert decoded_token != "<EFBFBD>", (
|
||||
f"Token {token_id} is a lone replacement character"
|
||||
)
|
||||
|
||||
|
||||
def test_utf8_edge_cases_with_real_model():
|
||||
@@ -853,45 +853,44 @@ def test_utf8_edge_cases_with_real_model():
|
||||
Tests prompts that are likely to trigger byte-fallback tokenization
|
||||
and multi-byte UTF-8 splitting.
|
||||
"""
|
||||
runner = VllmRunner(
|
||||
with VllmRunner(
|
||||
"facebook/opt-125m",
|
||||
max_logprobs=1,
|
||||
enable_prefix_caching=False,
|
||||
gpu_memory_utilization=0.15,
|
||||
max_model_len=256,
|
||||
)
|
||||
) as runner:
|
||||
# Prompts with various multi-byte UTF-8 characters
|
||||
test_prompts = [
|
||||
'Smart quotes: "Hello"', # Curly quotes
|
||||
"Em dash — test", # Em dash
|
||||
"Ellipsis… continues", # Ellipsis
|
||||
"Chinese: 你好", # Chinese characters
|
||||
"Emoji: 😀 🎉", # Emojis
|
||||
'Mixed: "quoted" — with symbols', # Mixed
|
||||
]
|
||||
|
||||
# Prompts with various multi-byte UTF-8 characters
|
||||
test_prompts = [
|
||||
'Smart quotes: "Hello"', # Curly quotes
|
||||
"Em dash — test", # Em dash
|
||||
"Ellipsis… continues", # Ellipsis
|
||||
"Chinese: 你好", # Chinese characters
|
||||
"Emoji: 😀 🎉", # Emojis
|
||||
'Mixed: "quoted" — with symbols', # Mixed
|
||||
]
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=10,
|
||||
temperature=0,
|
||||
logprobs=1,
|
||||
)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=10,
|
||||
temperature=0,
|
||||
logprobs=1,
|
||||
)
|
||||
results = runner.llm.generate(test_prompts, sampling_params=sampling_params)
|
||||
|
||||
results = runner.llm.generate(test_prompts, sampling_params=sampling_params)
|
||||
for i, result in enumerate(results):
|
||||
prompt = test_prompts[i]
|
||||
assert result.outputs[0].logprobs is not None
|
||||
|
||||
for i, result in enumerate(results):
|
||||
prompt = test_prompts[i]
|
||||
assert result.outputs[0].logprobs is not None
|
||||
|
||||
# Check that no decoded tokens end with replacement character
|
||||
for logprob_dict in result.outputs[0].logprobs:
|
||||
for token_id, logprob_info in logprob_dict.items():
|
||||
decoded_token = logprob_info.decoded_token
|
||||
assert not decoded_token.endswith("<EFBFBD>"), (
|
||||
f"Prompt: '{prompt}'\n"
|
||||
f"Token {token_id} decoded to '{decoded_token}' which "
|
||||
f"ends with replacement character"
|
||||
)
|
||||
# Check that no decoded tokens end with replacement character
|
||||
for logprob_dict in result.outputs[0].logprobs:
|
||||
for token_id, logprob_info in logprob_dict.items():
|
||||
decoded_token = logprob_info.decoded_token
|
||||
assert not decoded_token.endswith("<EFBFBD>"), (
|
||||
f"Prompt: '{prompt}'\n"
|
||||
f"Token {token_id} decoded to '{decoded_token}' which "
|
||||
f"ends with replacement character"
|
||||
)
|
||||
|
||||
|
||||
def test_correct_decoded_token_preserves_valid_tokens():
|
||||
@@ -901,36 +900,35 @@ def test_correct_decoded_token_preserves_valid_tokens():
|
||||
ending with "�", but this test verifies the broader _verify_tokens
|
||||
logic doesn't affect valid tokens.
|
||||
"""
|
||||
runner = VllmRunner(
|
||||
with VllmRunner(
|
||||
"facebook/opt-125m",
|
||||
max_logprobs=2,
|
||||
enable_prefix_caching=False,
|
||||
gpu_memory_utilization=0.15,
|
||||
max_model_len=256,
|
||||
)
|
||||
) as runner:
|
||||
# Simple prompt with standard ASCII characters
|
||||
test_prompts = ["Hello world, this is a test."]
|
||||
|
||||
# Simple prompt with standard ASCII characters
|
||||
test_prompts = ["Hello world, this is a test."]
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=10,
|
||||
temperature=0,
|
||||
logprobs=2,
|
||||
)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=10,
|
||||
temperature=0,
|
||||
logprobs=2,
|
||||
)
|
||||
results = runner.llm.generate(test_prompts, sampling_params=sampling_params)
|
||||
|
||||
results = runner.llm.generate(test_prompts, sampling_params=sampling_params)
|
||||
for result in results:
|
||||
assert result.outputs[0].logprobs is not None
|
||||
|
||||
for result in results:
|
||||
assert result.outputs[0].logprobs is not None
|
||||
|
||||
# All decoded tokens should be valid strings
|
||||
for logprob_dict in result.outputs[0].logprobs:
|
||||
for token_id, logprob_info in logprob_dict.items():
|
||||
decoded_token = logprob_info.decoded_token
|
||||
# Valid tokens should be non-empty strings (or empty if corrected)
|
||||
assert isinstance(decoded_token, str)
|
||||
# Should not contain replacement character
|
||||
assert "<EFBFBD>" not in decoded_token
|
||||
# All decoded tokens should be valid strings
|
||||
for logprob_dict in result.outputs[0].logprobs:
|
||||
for token_id, logprob_info in logprob_dict.items():
|
||||
decoded_token = logprob_info.decoded_token
|
||||
# Valid tokens should be non-empty strings (or empty if corrected)
|
||||
assert isinstance(decoded_token, str)
|
||||
# Should not contain replacement character
|
||||
assert "<EFBFBD>" not in decoded_token
|
||||
|
||||
|
||||
@pytest.mark.parametrize("logprobs_mode", get_args(LogprobsMode))
|
||||
@@ -985,16 +983,33 @@ def test_correct_decoded_token_preserves_valid_tokens():
|
||||
def test_spec_decode_logprobs(
|
||||
logprobs_mode: LogprobsMode,
|
||||
model_setup: tuple[str, str, dict, int],
|
||||
monkeypatch,
|
||||
):
|
||||
"""Spec decode logprobs should match those of the base model.
|
||||
|
||||
Runs the base model and spec decode model sequentially, ensuring
|
||||
only one LLM instance is alive at a time to avoid GPU memory
|
||||
contention. Both use identical chunked prefill settings and eager
|
||||
mode to control for infrastructure differences.
|
||||
|
||||
Args:
|
||||
logprobs_mode: logprobs mode.
|
||||
model_setup: Tuple of (method, base model name,
|
||||
speculative_config dict, top_logprobs).
|
||||
monkeypatch: pytest fixture for setting env vars.
|
||||
"""
|
||||
from vllm import LLM
|
||||
|
||||
# The ROCm skinny GEMM kernels (gemm_kernels.cu) are
|
||||
# non-deterministic across LLM instantiations due to persistent
|
||||
# workgroup scheduling and wave-level shuffle reductions, which
|
||||
# causes logprob differences that get misattributed to spec decode.
|
||||
# Disable them so this test isolates spec decode correctness only.
|
||||
# TODO(akaratza): Remove this workaround once the follow-up to
|
||||
# https://github.com/vllm-project/vllm/pull/33493#issuecomment-3906083975
|
||||
# lands with a determinism fix for wvSplitK kernels.
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_SKINNY_GEMM", "0")
|
||||
|
||||
method, model_name, spec_config, top_logprobs = model_setup
|
||||
|
||||
prompt = "Hello world " * 50
|
||||
@@ -1068,8 +1083,17 @@ def test_spec_decode_logprobs(
|
||||
for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs):
|
||||
assert math.isclose(
|
||||
ref_logprob.logprob, spec_logprob.logprob, rel_tol=5e-2, abs_tol=1e-1
|
||||
), (
|
||||
f"Logprob mismatch: ref={ref_logprob.logprob} "
|
||||
f"spec={spec_logprob.logprob} "
|
||||
f"diff={abs(ref_logprob.logprob - spec_logprob.logprob)} "
|
||||
f"(token={ref_logprob.decoded_token!r})"
|
||||
)
|
||||
assert ref_logprob.rank == spec_logprob.rank, (
|
||||
f"Rank mismatch: ref={ref_logprob.rank} "
|
||||
f"spec={spec_logprob.rank} "
|
||||
f"(token={ref_logprob.decoded_token!r})"
|
||||
)
|
||||
assert ref_logprob.rank == spec_logprob.rank
|
||||
assert ref_logprob.decoded_token == spec_logprob.decoded_token
|
||||
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ from tests.v1.attention.utils import (
|
||||
try_get_attention_backend,
|
||||
)
|
||||
from vllm.config import ParallelConfig, SpeculativeConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backend import CommonAttentionMetadata
|
||||
from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
@@ -23,11 +24,156 @@ if not is_flash_attn_varlen_func_available():
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# KV cache layout adaptation
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Two KV cache layouts exist across backends:
|
||||
#
|
||||
# Flash layout: (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||
# - dim 0 separates key (index 0) and value (index 1)
|
||||
# - Used by: FLASH_ATTN, TREE_ATTN, ROCM_AITER_FA, ROCM_ATTN
|
||||
#
|
||||
# Block layout: (num_blocks, 2, block_size, num_kv_heads, head_size)
|
||||
# - dim 1 separates key (index 0) and value (index 1)
|
||||
# - Used by: TRITON_ATTN
|
||||
#
|
||||
# The test creates KV caches in flash layout (the canonical format used by
|
||||
# tree attention). When a reference backend needs block layout we transpose
|
||||
# dims 0 and 1.
|
||||
#
|
||||
# Note: ROCM_ATTN uses flash layout for storage but its forward path calls
|
||||
# PagedAttention.split_kv_cache which reinterprets the raw memory as paged
|
||||
# layout (num_blocks, num_kv_heads, head_size//x, block_size, x). This is
|
||||
# a view-level incompatibility, not a transpose - see the TODO in
|
||||
# _get_available_reference_backends for details.
|
||||
#
|
||||
# TODO: Replace this mapping with a `KV_CACHE_LAYOUT` class attribute on each
|
||||
# AttentionImpl so the layout is self-documented by the backend itself, e.g.:
|
||||
# class TritonAttentionImpl(AttentionImpl):
|
||||
# KV_CACHE_LAYOUT = "block"
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
# Backends that read the KV cache in block layout
# ``(num_blocks, 2, block_size, num_kv_heads, head_size)``; the cache is
# transposed on dims 0 and 1 before being handed to any of these.
_BLOCK_KV_LAYOUT_BACKENDS = frozenset([AttentionBackendEnum.TRITON_ATTN])

# Backends whose do_kv_cache_update requires engine-level state (e.g.
# ForwardContext) that is not available in this test harness, but whose
# KV cache is flash layout and can be written with reshape_and_cache_flash.
# When a backend is listed here, forward_attention() bypasses
# do_kv_cache_update and writes directly to the cache.
_NEEDS_DIRECT_CACHE_UPDATE = frozenset([AttentionBackendEnum.ROCM_AITER_FA])

# Backends with known test-harness incompatibilities - see the TODOs
# inside _get_available_reference_backends for details.
_INCOMPATIBLE_REFERENCE_BACKENDS = frozenset(
    [
        AttentionBackendEnum.ROCM_AITER_FA,
        AttentionBackendEnum.ROCM_ATTN,
    ]
)
|
||||
|
||||
|
||||
def _adapt_kv_cache_for_backend(
    kv_cache: torch.Tensor,
    backend: AttentionBackendEnum,
) -> torch.Tensor:
    """Return ``kv_cache`` in the layout that ``backend`` consumes.

    The input is taken to be in flash layout ``(2, num_blocks, ...)``.
    For a backend listed in ``_BLOCK_KV_LAYOUT_BACKENDS`` the first two
    dims are swapped and the result is made contiguous, giving block
    layout ``(num_blocks, 2, ...)``. For every other backend the same
    tensor object is handed back untouched.

    Args:
        kv_cache: KV cache tensor in flash layout.
        backend: Attention backend that will consume the cache.

    Returns:
        The tensor to pass to ``backend``.
    """
    needs_block_layout = backend in _BLOCK_KV_LAYOUT_BACKENDS
    if not needs_block_layout:
        return kv_cache
    swapped = kv_cache.transpose(0, 1)
    return swapped.contiguous()
|
||||
|
||||
|
||||
def _get_platform_default_backend() -> AttentionBackendEnum:
    """Resolve the backend the current platform selects when none is forced.

    Builds a fixed selector config (block size 32, head size 128, bfloat16,
    no MLA, no sparse attention), asks ``current_platform`` for the backend
    class path, and maps that path back onto an ``AttentionBackendEnum``
    member.

    Returns:
        The enum member whose ``get_path()`` equals the platform's choice.

    Raises:
        RuntimeError: If no enum member matches the returned path.
    """
    from vllm.v1.attention.selector import AttentionSelectorConfig

    selector_config = AttentionSelectorConfig(
        block_size=32,
        kv_cache_dtype="auto",
        use_mla=False,
        use_sparse=False,
        head_size=128,
        dtype=torch.bfloat16,
    )
    backend_path = current_platform.get_attn_backend_cls(
        selected_backend=None,
        attn_selector_config=selector_config,
    )

    for candidate in AttentionBackendEnum:
        try:
            candidate_path = candidate.get_path()
        except ValueError:
            # Members whose path lookup raises ValueError cannot be the
            # platform's choice; move on to the next one.
            continue
        if candidate_path == backend_path:
            return candidate

    raise RuntimeError(
        f"Platform returned backend path '{backend_path}' "
        f"that doesn't match any AttentionBackendEnum member."
    )
|
||||
|
||||
|
||||
def _get_available_reference_backends() -> list[AttentionBackendEnum]:
    """List the reference backends to compare tree attention against.

    On ROCm the list holds the platform's auto-selected backend (unless it
    is in ``_INCOMPATIBLE_REFERENCE_BACKENDS``) followed by TRITON_ATTN,
    without duplicates. On every other platform it is just FLASH_ATTN.

    Returns:
        Reference backends, in the order they should be exercised.
    """
    if not current_platform.is_rocm():
        # CUDA: flash attention.
        return [AttentionBackendEnum.FLASH_ATTN]

    candidates: list[AttentionBackendEnum] = []

    # 1. Whatever the platform would auto-select at runtime.
    platform_default = _get_platform_default_backend()
    if platform_default not in _INCOMPATIBLE_REFERENCE_BACKENDS:
        candidates.append(platform_default)

    # 2. TRITON_ATTN - always available on ROCm.
    triton = AttentionBackendEnum.TRITON_ATTN
    if triton not in candidates:
        candidates.append(triton)

    # TODO: Enable ROCM_ATTN. Its forward path uses
    # PagedAttention.split_kv_cache which reinterprets the raw
    # cache memory as paged layout:
    #   key:   (num_blocks, num_kv_heads, head_size//x, block_size, x)
    #   value: (num_blocks, num_kv_heads, head_size, block_size)
    # Tree attention writes prefix data in NHD flash layout, so the
    # same bytes produce completely different values when read in
    # paged format. Supporting ROCM_ATTN would require writing
    # prefix data via PagedAttention.write_to_paged_cache into a
    # separate paged-format KV cache.

    # TODO: Enable ROCM_AITER_FA. Its metadata builder reads head
    # counts from the model config at construction time and
    # allocates extend_workspace with those dimensions. The test
    # uses independent head count parameters (num_heads=2/4,
    # num_kv_heads=2) that don't match the model config
    # (Llama-3-8B: 32 q heads, 8 kv heads), causing a head count
    # mismatch in flash_attn_varlen_func during extend_forward.
    # Fixing this requires either matching test head counts to the
    # model config or decoupling the builder from model config
    # head geometry. The direct cache update path
    # (_NEEDS_DIRECT_CACHE_UPDATE) is already in place for when
    # this is resolved.

    return candidates
|
||||
|
||||
|
||||
class MockAttentionLayer(torch.nn.Module):
|
||||
_q_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
|
||||
_k_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
|
||||
_v_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
|
||||
layer_name = "mock_layer"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -48,6 +194,13 @@ def forward_attention(
|
||||
spec_token_tree: str | None = None,
|
||||
num_spec_tokens: int = 0,
|
||||
) -> torch.Tensor:
|
||||
"""Run a single attention forward pass through the given backend.
|
||||
|
||||
``kv_cache`` is expected in **flash layout**
|
||||
``(2, num_blocks, block_size, num_kv_heads, head_size)``.
|
||||
It is automatically converted when the target backend needs a
|
||||
different layout.
|
||||
"""
|
||||
batch_size, q_len, num_heads, dim_per_head = q.shape
|
||||
num_kv_heads = k.shape[-2]
|
||||
# Initialize the query and KV sequence lengths.
|
||||
@@ -116,31 +269,58 @@ def forward_attention(
|
||||
kv_cache_dtype="auto",
|
||||
)
|
||||
|
||||
# Adapt KV cache layout for this backend.
|
||||
adapted_kv_cache = _adapt_kv_cache_for_backend(kv_cache, backend)
|
||||
|
||||
# Run forward pass and return output.
|
||||
query = q.view(-1, num_heads, dim_per_head)
|
||||
key = k.view(-1, num_kv_heads, dim_per_head)
|
||||
value = v.view(-1, num_kv_heads, dim_per_head)
|
||||
output = torch.empty_like(query)
|
||||
if not try_backend_includes_kv_cache_update(backend):
|
||||
instance.do_kv_cache_update(
|
||||
layer=layer,
|
||||
key=key,
|
||||
value=value,
|
||||
kv_cache=kv_cache,
|
||||
slot_mapping=attn_metadata.slot_mapping,
|
||||
)
|
||||
if backend in _NEEDS_DIRECT_CACHE_UPDATE:
|
||||
# This backend's do_kv_cache_update requires engine-level
|
||||
# ForwardContext that isn't available in this test harness.
|
||||
# Write directly using reshape_and_cache_flash since the
|
||||
# KV cache layout is identical (flash layout, unbind on dim 0).
|
||||
key_cache, value_cache = adapted_kv_cache.unbind(0)
|
||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
"auto",
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
else:
|
||||
instance.do_kv_cache_update(
|
||||
layer=layer,
|
||||
key=key,
|
||||
value=value,
|
||||
kv_cache=adapted_kv_cache,
|
||||
slot_mapping=attn_metadata.slot_mapping,
|
||||
)
|
||||
return instance.forward(
|
||||
layer=layer,
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
kv_cache=kv_cache.clone(),
|
||||
kv_cache=adapted_kv_cache.clone(),
|
||||
attn_metadata=attn_metadata,
|
||||
output=output,
|
||||
)
|
||||
|
||||
|
||||
def test_tree_attn_correctness() -> None:
|
||||
@pytest.mark.parametrize(
|
||||
"reference_backend",
|
||||
_get_available_reference_backends(),
|
||||
ids=lambda b: b.name,
|
||||
)
|
||||
def test_tree_attn_correctness(
|
||||
reference_backend: AttentionBackendEnum,
|
||||
) -> None:
|
||||
torch.manual_seed(42)
|
||||
torch.cuda.manual_seed_all(42)
|
||||
|
||||
@@ -205,7 +385,9 @@ def test_tree_attn_correctness() -> None:
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
# Set up the block table and KV cache for paged KV.
|
||||
# KV cache in flash layout - the canonical format for
|
||||
# tree attention. forward_attention() handles conversion
|
||||
# when needed.
|
||||
assert max_sequence_length % block_size == 0
|
||||
max_blocks_per_batch = max_sequence_length // block_size
|
||||
kv_cache = torch.randn(
|
||||
@@ -263,9 +445,7 @@ def test_tree_attn_correctness() -> None:
|
||||
num_spec_tokens=tree_size_q - 1,
|
||||
).view(batch_size, -1, num_heads, dim_per_head)
|
||||
|
||||
# Verify that the chain attention output for each
|
||||
# branch of the tree (computed using FA3) matches
|
||||
# the tree attention output.
|
||||
# Verify each branch against the reference backend.
|
||||
for q_index in range(tree_size_q):
|
||||
# Get the q, k, and v for the branch.
|
||||
branch_mask = tree_attn_mask[q_index, :]
|
||||
@@ -286,8 +466,8 @@ def test_tree_attn_correctness() -> None:
|
||||
branch_positions, block_table, block_size
|
||||
)
|
||||
|
||||
# Compute flash attention for the branch.
|
||||
flash_attn_output = forward_attention(
|
||||
# Reference attention for this branch.
|
||||
ref_output = forward_attention(
|
||||
q=q_branch,
|
||||
k=k_branch,
|
||||
v=v_branch,
|
||||
@@ -295,16 +475,17 @@ def test_tree_attn_correctness() -> None:
|
||||
block_table=block_table,
|
||||
slot_mapping=branch_slot_mapping,
|
||||
seqlen_k=sequence_position + q_len,
|
||||
backend=AttentionBackendEnum.FLASH_ATTN,
|
||||
backend=reference_backend,
|
||||
).view(batch_size, -1, num_heads, dim_per_head)
|
||||
|
||||
# Compare the outputs.
|
||||
assert torch.allclose(
|
||||
tree_attn_output[:, branch_indices],
|
||||
flash_attn_output,
|
||||
ref_output,
|
||||
atol=7.81e-3,
|
||||
), (
|
||||
f"outputs are not close for "
|
||||
f"reference_backend: {reference_backend.name}, "
|
||||
f"batch_size: {batch_size}, "
|
||||
f"num_heads: {num_heads}, "
|
||||
f"sequence_position: {sequence_position}, "
|
||||
|
||||
Reference in New Issue
Block a user