[Bugfix] Fix FA3 full cuda graph correctness (#19106)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Author: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Date:   2025-06-03 23:10:15 -07:00
Committed by: GitHub
Commit: b124e1085b (parent 41aa578428)
4 changed files with 32 additions and 10 deletions

@@ -7,6 +7,7 @@ import pytest

 from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig
+from vllm.platforms import current_platform

 MODEL = "Qwen/Qwen2-1.5B-Instruct"
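
The newly imported `current_platform` backs the Hopper-only gate added further down in this diff. As a standalone illustration, here is a minimal sketch of the same guard, assuming (as the diff itself does) that `current_platform.get_device_capability()` compares equal to a `(major, minor)` tuple:

    import pytest

    from vllm.platforms import current_platform

    # FlashAttention 3 only targets compute capability 9.0 (Hopper), so on
    # any other device the FA3 full-cuda-graph tests skip instead of failing.
    requires_hopper = pytest.mark.skipif(
        current_platform.get_device_capability() != (9, 0),
        reason="Only Hopper GPUs support FlashAttention 3")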
@@ -37,7 +38,7 @@ def full_cudagraph_llm():
             "VLLM_FLASH_ATTN_VERSION": "3"
     }):
         return LLM(model=MODEL,
-                   gpu_memory_utilization=0.2,
+                   gpu_memory_utilization=0.3,
                    compilation_config=CompilationConfig(full_cuda_graph=True))
@@ -48,7 +49,7 @@ def piecewise_llm():
             "VLLM_FLASH_ATTN_VERSION": "3"
     }):
         return LLM(model=MODEL,
-                   gpu_memory_utilization=0.5,
+                   gpu_memory_utilization=0.6,
                    compilation_config=CompilationConfig())
@@ -61,6 +62,8 @@ def generate_text(llm: LLM, batch_size: int, max_tokens: int):
     return llm.generate(prompts, sampling_params)

+@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
+                    reason="Only Hopper GPUs support FlashAttention 3")
 @pytest.mark.parametrize(("batch_size", "max_tokens"), [(1, 10), (7, 10),
                                                         (16, 10), (25, 10),
                                                         (32, 10), (45, 10),
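
The parametrized cases above (the list continues past this excerpt) drive a correctness check that runs identical prompts through both fixtures. The test body sits outside these hunks, so the following is only a sketch of its likely shape, with the function name and assertion style assumed rather than taken from the diff:

    def test_full_cudagraph(batch_size, max_tokens, full_cudagraph_llm,
                            piecewise_llm):
        # Assumed body: both engines see the same prompts, and the FA3
        # full-cuda-graph output must match the piecewise baseline exactly,
        # so any correctness regression in the captured graph surfaces here.
        piecewise_out = generate_text(piecewise_llm, batch_size, max_tokens)
        full_out = generate_text(full_cudagraph_llm, batch_size, max_tokens)
        for expected, actual in zip(piecewise_out, full_out):
            assert expected.outputs[0].text == actual.outputs[0].text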