full cudagraph for flex-attn (#36298)

Signed-off-by: shunting314 <shunting@meta.com>
This commit is contained in:
shunting314
2026-04-02 21:15:01 -07:00
committed by GitHub
parent 2ad7c0335f
commit 8b141ed8c3
4 changed files with 145 additions and 11 deletions

View File

@@ -26,6 +26,59 @@ MINIMUM_TORCH_VERSION = version.parse("2.7.0")
DIRECT_BUILD_VERSION = version.parse("2.9.dev0")
@pytest.mark.skipif(
    not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
    reason="CUDA not available or PyTorch version < 2.7",
)
def test_flex_attention_full_cudagraphs(vllm_runner):
    """Test the numerics for flex attention full cudagraphs support."""
    model_name = "Qwen/Qwen2.5-1.5B-Instruct"
    seed = 42
    max_tokens = 24
    num_logprobs = 5
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
    ]

    def _generate(**extra_kwargs):
        # Reseed before every run so both configurations start from the
        # exact same RNG state and their outputs are directly comparable.
        set_random_seed(seed)
        with vllm_runner(
            model_name,
            runner="generate",
            tensor_parallel_size=1,
            num_gpu_blocks_override=128,
            attention_config={"backend": "FLEX_ATTENTION"},
            **extra_kwargs,
        ) as llm:
            return llm.generate_greedy_logprobs(prompts, max_tokens, num_logprobs)

    # Reference run: flex attention in eager mode (no compilation).
    output_eager = _generate(enforce_eager=True)
    # Run under test: flex attention on the compiled (non-eager) path.
    output_compile = _generate(
        enforce_eager=False,
        gpu_memory_utilization=0.85,
    )

    # The compiled path must produce logprobs matching the eager reference.
    check_logprobs_close(
        outputs_0_lst=output_eager,
        outputs_1_lst=output_compile,
        name_0="eager",
        name_1="compile",
    )
@pytest.mark.skipif(
not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
reason="CUDA not available or PyTorch version < 2.7",