Updates to Flex + VLLm integration (#21416)
Signed-off-by: drisspg <drisspguessous@gmail.com>
@@ -9,12 +9,17 @@ import pytest
 import torch
 from packaging import version
 
-from vllm import SamplingParams
+from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
+                                      create_standard_kv_cache_spec,
+                                      create_vllm_config)
+from vllm.v1.attention.backends.flex_attention import (
+    FlexAttentionMetadataBuilder)
 
-from ..models.utils import check_embeddings_close
+from ..models.utils import check_embeddings_close, check_logprobs_close
 
 TORCH_VERSION = version.parse(torch.__version__)
 MINIMUM_TORCH_VERSION = version.parse("2.7.0")
+DIRECT_BUILD_VERSION = version.parse("2.9.dev0")
 
 
 def set_seed(seed):
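
A note on the version gating added here: DIRECT_BUILD_VERSION is parsed from "2.9.dev0" rather than "2.9.0", which matters because packaging orders development releases before the corresponding final release, so PyTorch 2.9 nightly builds also clear the gate. A small self-contained illustration (the version strings are made up for the example):

from packaging import version

# Dev releases sort below the matching final release, so a 2.9 nightly
# (which carries a .devNNNNNNNN suffix) already satisfies the threshold.
DIRECT_BUILD_VERSION = version.parse("2.9.dev0")

assert version.parse("2.9.0.dev20250101") >= DIRECT_BUILD_VERSION  # nightly passes
assert version.parse("2.9.0") >= DIRECT_BUILD_VERSION              # release passes
assert version.parse("2.8.0") < DIRECT_BUILD_VERSION               # 2.8 is gated out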
@@ -34,22 +39,18 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
     """Test that FlexAttention produces the same outputs as the default backend.
 
     This test compares the outputs from the FlexAttention backend with
-    the default backend, ensuring they are identical when using the same seed.
+    the default backend, ensuring they are similar when using the same seed.
     """
     model_name = "Qwen/Qwen2.5-1.5B-Instruct"
     seed = 42
     max_tokens = 24
+    num_logprobs = 5
     prompts = [
         "Hello, my name is",
         "The president of the United States is",
         "The capital of France is",
     ]
 
-    sampling_params = SamplingParams(temperature=0.0,
-                                     top_p=1.0,
-                                     seed=seed,
-                                     max_tokens=max_tokens)
-
     # Run with flex attention
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
@@ -61,7 +62,8 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
                          tensor_parallel_size=1,
                          num_gpu_blocks_override=128,
                          enforce_eager=True) as llm_flex:
-            output_flex = llm_flex.generate(prompts, sampling_params)
+            output_flex = llm_flex.generate_greedy_logprobs(
+                prompts, max_tokens, num_logprobs)
 
     # Run with default backend
     with monkeypatch.context() as m:
@@ -71,20 +73,17 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
                          runner="generate",
                          tensor_parallel_size=1,
                          num_gpu_blocks_override=128,
-                         enforce_eager=True) as llm_default:
-            output_default = llm_default.generate(prompts, sampling_params)
+                         enforce_eager=True,
+                         gpu_memory_utilization=0.85) as llm_default:
+            output_default = llm_default.generate_greedy_logprobs(
+                prompts, max_tokens, num_logprobs)
 
     # Compare outputs from both backends
-    for i, (flex_result,
-            default_result) in enumerate(zip(output_flex, output_default)):
-        prompt = prompts[i]
-        flex_text = flex_result[1][0]
-        default_text = default_result[1][0]
-
-        assert flex_text == default_text, (
-            f"FlexAttention output doesn't match default for: {prompt!r}\n"
-            f"FlexAttention: {flex_text!r}\n"
-            f"Default: {default_text!r}")
+    check_logprobs_close(
+        outputs_0_lst=output_flex,
+        outputs_1_lst=output_default,
+        name_0="flex",
+        name_1="default",
+    )
-
 
 @pytest.mark.skipif(
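
For context on the comparison change in the hunk above: instead of asserting byte-identical generations, the test now relies on check_logprobs_close (imported above alongside check_embeddings_close), which tolerates benign divergence between backends. Roughly, a position where the two backends pick different tokens is accepted as long as each chosen token still appears in the other backend's reported top-k logprobs. A self-contained sketch of that acceptance rule, illustrative only and not vLLM's actual helper; it assumes each per-prompt result is a (token_ids, text, logprobs) tuple where logprobs[i] maps candidate token id to logprob at step i:

def logprobs_close(result_a, result_b) -> bool:
    # Illustrative only: an acceptance rule in the spirit of check_logprobs_close.
    ids_a, _, logprobs_a = result_a
    ids_b, _, logprobs_b = result_b
    for step, (tok_a, tok_b) in enumerate(zip(ids_a, ids_b)):
        if tok_a == tok_b:
            continue  # both backends made the same greedy choice
        # A divergent choice is tolerated only if each token is still among the
        # other backend's top-k candidates for this step.
        if tok_a not in logprobs_b[step] or tok_b not in logprobs_a[step]:
            return False
    return True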
@@ -136,5 +135,70 @@ def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
     )
 
 
+@pytest.mark.skipif(
+    not torch.cuda.is_available() or TORCH_VERSION < DIRECT_BUILD_VERSION,
+    reason="CUDA not available or PyTorch version < 2.9",
+)
+def test_block_mask_direct_vs_slow_path():
+    """Test that the direct-path block mask is a superset of the slow path.
+
+    The direct path may include extra blocks for performance (over-estimation),
+    but must include all blocks that the slow path determines are necessary.
+    """
+    device = torch.device("cuda")
+
+    vllm_config = create_vllm_config(model_name="meta-llama/Meta-Llama-3-8B",
+                                     block_size=16,
+                                     max_model_len=1024)
+    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
+
+    # Use a mixed batch that will create groups spanning multiple sequences
+    batch_spec = BatchSpec(seq_lens=[35, 64, 128, 256],
+                           query_lens=[33, 5, 32, 64],
+                           name="test_mixed_batch")
+
+    common_attn_metadata = create_common_attn_metadata(
+        batch_spec, vllm_config.cache_config.block_size, device)
+
+    builder = FlexAttentionMetadataBuilder(kv_cache_spec, [], vllm_config,
+                                           device)
+
+    metadata_direct = builder.build(common_prefix_len=0,
+                                    common_attn_metadata=common_attn_metadata)
+    builder.direct_build = False
+    metadata_slow = builder.build(common_prefix_len=0,
+                                  common_attn_metadata=common_attn_metadata)
+
+    assert metadata_direct.block_mask is not None
+    assert metadata_slow.block_mask is not None
+
+    # Extract block indices for comparison; the batch and head dims are the same
+    direct_indices = metadata_direct.block_mask.kv_indices[0, 0]
+    slow_indices = metadata_slow.block_mask.kv_indices[0, 0]
+    direct_num = metadata_direct.block_mask.kv_num_blocks[0, 0]
+    slow_num = metadata_slow.block_mask.kv_num_blocks[0, 0]
+
+    # Main check: every block needed by the slow path must be in the direct path
+    num_groups = direct_num.shape[0]
+    all_contained = True
+    missing_details = []
+
+    for group_idx in range(num_groups):
+        direct_blocks = set(
+            direct_indices[group_idx, :direct_num[group_idx]].tolist())
+        slow_blocks = set(
+            slow_indices[group_idx, :slow_num[group_idx]].tolist())
+
+        missing_blocks = slow_blocks - direct_blocks
+        if missing_blocks:
+            all_contained = False
+            missing_details.append(
+                f"Group {group_idx}: missing {sorted(missing_blocks)}")
+
+    assert all_contained, (
+        "Direct path is missing blocks required by slow path:\n" +
+        "\n".join(missing_details))
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
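
The new test_block_mask_direct_vs_slow_path compares kv_num_blocks / kv_indices between two BlockMask builds. For readers unfamiliar with those fields: on a PyTorch FlexAttention BlockMask, kv_num_blocks[b, h, q_block] counts the KV blocks in that query-block row that need the mask_mod applied, and kv_indices lists which blocks those are (fully dense blocks are tracked separately in full_kv_num_blocks / full_kv_indices). A minimal sketch of reading these fields, assuming torch >= 2.5 so torch.nn.attention.flex_attention is available; it is illustrative and not tied to the vLLM builder:

import torch
from torch.nn.attention.flex_attention import create_block_mask


def causal(b, h, q_idx, kv_idx):
    # Standard causal rule: a query position may attend only to earlier or
    # equal KV positions.
    return q_idx >= kv_idx


# 1 batch, 1 head, 256 query and 256 KV tokens, default 128-token mask blocks.
block_mask = create_block_mask(causal, B=1, H=1, Q_LEN=256, KV_LEN=256,
                               device="cpu")

num = block_mask.kv_num_blocks[0, 0]  # per query block: how many partial KV blocks
idx = block_mask.kv_indices[0, 0]     # per query block: which KV blocks those are
for q_block in range(num.shape[0]):
    live = idx[q_block, :int(num[q_block])].tolist()
    print(f"query block {q_block}: partially-masked KV blocks {live}")

The test's superset assertion is then just per-group set containment over these indices: every block the slow path reports must also appear in the direct build's indices.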