full cudagraph for flex-attn (#36298)
Signed-off-by: shunting314 <shunting@meta.com>
This commit is contained in:
@@ -26,6 +26,59 @@ MINIMUM_TORCH_VERSION = version.parse("2.7.0")
|
||||
DIRECT_BUILD_VERSION = version.parse("2.9.dev0")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
    reason="CUDA not available or PyTorch version < 2.7",
)
def test_flex_attention_full_cudagraphs(vllm_runner):
    """Test the numerics for flex attention full cudagraphs support.

    Runs the same greedy-logprob generation twice with the FLEX_ATTENTION
    backend — once eagerly and once compiled (full cudagraphs) — and checks
    that the resulting logprobs agree within tolerance.
    """
    model = "Qwen/Qwen2.5-1.5B-Instruct"
    rng_seed = 42
    gen_tokens = 24
    top_logprobs = 5
    sample_prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
    ]

    # Engine settings shared by both the eager and the compiled run.
    common_kwargs = dict(
        runner="generate",
        tensor_parallel_size=1,
        num_gpu_blocks_override=128,
        attention_config={"backend": "FLEX_ATTENTION"},
    )

    # Reference run: flex attention in eager mode (no cudagraphs).
    set_random_seed(rng_seed)
    with vllm_runner(model, enforce_eager=True, **common_kwargs) as llm_eager:
        eager_results = llm_eager.generate_greedy_logprobs(
            sample_prompts, gen_tokens, top_logprobs
        )

    # Run under test: compiled mode exercising full cudagraph support.
    # NOTE(review): only this run caps gpu_memory_utilization — presumably
    # to leave headroom for compilation; mirrors the original test setup.
    set_random_seed(rng_seed)
    with vllm_runner(
        model,
        enforce_eager=False,
        gpu_memory_utilization=0.85,
        **common_kwargs,
    ) as llm_compiled:
        compiled_results = llm_compiled.generate_greedy_logprobs(
            sample_prompts, gen_tokens, top_logprobs
        )

    check_logprobs_close(
        outputs_0_lst=eager_results,
        outputs_1_lst=compiled_results,
        name_0="eager",
        name_1="compile",
    )
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
|
||||
reason="CUDA not available or PyTorch version < 2.7",
|
||||
|
||||
Reference in New Issue
Block a user