[Bugfix] Fix FA3 full cuda graph correctness (#19106)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
@@ -7,6 +7,7 @@ import pytest
 
 from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig
+from vllm.platforms import current_platform
 
 MODEL = "Qwen/Qwen2-1.5B-Instruct"
 
@@ -37,7 +38,7 @@ def full_cudagraph_llm():
             "VLLM_FLASH_ATTN_VERSION": "3"
     }):
         return LLM(model=MODEL,
-                   gpu_memory_utilization=0.2,
+                   gpu_memory_utilization=0.3,
                    compilation_config=CompilationConfig(full_cuda_graph=True))
 
 
@@ -48,7 +49,7 @@ def piecewise_llm():
             "VLLM_FLASH_ATTN_VERSION": "3"
     }):
         return LLM(model=MODEL,
-                   gpu_memory_utilization=0.5,
+                   gpu_memory_utilization=0.6,
                    compilation_config=CompilationConfig())
 
 
@@ -61,6 +62,8 @@ def generate_text(llm: LLM, batch_size: int, max_tokens: int):
     return llm.generate(prompts, sampling_params)
 
 
+@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
+                    reason="Only Hopper GPUs support FlashAttention 3")
 @pytest.mark.parametrize(("batch_size", "max_tokens"), [(1, 10), (7, 10),
                                                         (16, 10), (25, 10),
                                                         (32, 10), (45, 10),
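For readers who want to reproduce the configuration this test exercises outside of pytest, the standalone sketch below mirrors the fixtures touched by this diff: FlashAttention 3 selected via VLLM_FLASH_ATTN_VERSION, a full CUDA graph enabled through CompilationConfig(full_cuda_graph=True), and the Hopper-only guard from the new skipif marker. The prompt and sampling values are illustrative only and not part of the commit.

# Minimal sketch mirroring the full-cudagraph fixture in this diff.
# Assumes a Hopper (compute capability 9.0) GPU; prompt and sampling
# values are illustrative, not taken from the test file.
import os

# Force FlashAttention 3, as the fixtures do, before vLLM picks a backend.
os.environ["VLLM_FLASH_ATTN_VERSION"] = "3"

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig
from vllm.platforms import current_platform

# Same guard as the new skipif marker: FA3 requires a Hopper GPU.
assert current_platform.get_device_capability() == (9, 0), \
    "Only Hopper GPUs support FlashAttention 3"

llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
          gpu_memory_utilization=0.3,
          compilation_config=CompilationConfig(full_cuda_graph=True))

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=10))
print(outputs[0].outputs[0].text)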