[Core] Allow full cudagraph with separate attention routines and orthogonal to compilation, add support for FA2 and FlashInfer (#20059)
Signed-off-by: fhl <2410591650@qq.com> Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com> Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
This commit is contained in:
@@ -3,7 +3,8 @@
|
||||
import contextlib
|
||||
import os
|
||||
import weakref
|
||||
from contextlib import ExitStack
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -32,27 +33,130 @@ def temporary_environ(env_vars):
|
||||
os.environ[k] = v
|
||||
|
||||
|
||||
@dataclass
|
||||
class BackendConfig:
|
||||
name: str
|
||||
env_vars: dict
|
||||
comp_config: dict
|
||||
specific_gpu_arch: Optional[tuple] = None
|
||||
|
||||
|
||||
# Define all backend configurations of full cudagraph to be tested
|
||||
backend_configs = {
|
||||
# FA3 on Hopper
|
||||
"FA3":
|
||||
BackendConfig(name="FA3",
|
||||
env_vars={"VLLM_FLASH_ATTN_VERSION": "3"},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL",
|
||||
},
|
||||
specific_gpu_arch=(9, 0)),
|
||||
# FlashMLA on Hopper
|
||||
"FlashMLA":
|
||||
BackendConfig(name="FlashMLA",
|
||||
env_vars={
|
||||
"VLLM_ATTENTION_BACKEND": "FLASHMLA",
|
||||
},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||
},
|
||||
specific_gpu_arch=(9, 0)),
|
||||
# Cutlass MLA on Blackwell
|
||||
"CutlassMLA":
|
||||
BackendConfig(
|
||||
name="CutlassMLA",
|
||||
env_vars={
|
||||
"VLLM_USE_V1": "1",
|
||||
"VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
|
||||
"FORCE_NUM_KV_SPLITS":
|
||||
"1", # TODO: remove this when hang issue is fixed
|
||||
},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||
"cudagraph_capture_sizes": [16, 32, 64, 128, 256, 512],
|
||||
},
|
||||
specific_gpu_arch=(10, 0)),
|
||||
# FA2
|
||||
"FA2":
|
||||
BackendConfig(name="FA2",
|
||||
env_vars={"VLLM_FLASH_ATTN_VERSION": "2"},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL",
|
||||
}),
|
||||
# Triton Attention
|
||||
"TritonAttn":
|
||||
BackendConfig(name="TritonAttn",
|
||||
env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN_VLLM_V1"},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL",
|
||||
}),
|
||||
# FlashInfer
|
||||
"FlashInfer":
|
||||
BackendConfig(name="FlashInfer",
|
||||
env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
|
||||
comp_config={
|
||||
"cudagraph_mode": "FULL_AND_PIECEWISE",
|
||||
}),
|
||||
}
|
||||
|
||||
test_params_full_cudagraph = []
|
||||
|
||||
# deepseek-ai/DeepSeek-V2-Lite with MLA
|
||||
MLA_backends = ["FlashMLA", "CutlassMLA"]
|
||||
for mla_backend in MLA_backends:
|
||||
test_params_full_cudagraph.append(
|
||||
pytest.param(
|
||||
("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend])))
|
||||
|
||||
# Qwen/Qwen2-1.5B-Instruct with other backends
|
||||
other_backend_configs = [
|
||||
backend_configs[c] for c in backend_configs if c not in MLA_backends
|
||||
]
|
||||
for backend_config in other_backend_configs:
|
||||
test_params_full_cudagraph.append(
|
||||
pytest.param(("Qwen/Qwen2-1.5B-Instruct", backend_config)))
|
||||
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def llm_pair(request):
|
||||
model = request.param
|
||||
model, backend_config = request.param
|
||||
|
||||
with temporary_environ({
|
||||
"VLLM_USE_V1": "1",
|
||||
"VLLM_FLASH_ATTN_VERSION": "3"
|
||||
}):
|
||||
# Dynamically skip test if GPU capability is not met
|
||||
if backend_config.specific_gpu_arch and backend_config.specific_gpu_arch\
|
||||
!= current_platform.get_device_capability():
|
||||
if backend_config.specific_gpu_arch == (9, 0):
|
||||
pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")
|
||||
elif backend_config.specific_gpu_arch == (10, 0):
|
||||
pytest.skip("Only Blackwell GPUs support Cutlass MLA")
|
||||
|
||||
env_vars = {
|
||||
"VLLM_USE_V1": "1",
|
||||
# Force native sampler to avoid potential nondeterminism in FlashInfer
|
||||
# when per-request generators are not used in V1.
|
||||
"VLLM_USE_FLASHINFER_SAMPLER": "0",
|
||||
**backend_config.env_vars,
|
||||
}
|
||||
with temporary_environ(env_vars):
|
||||
full = LLM(
|
||||
model=model,
|
||||
gpu_memory_utilization=0.45,
|
||||
gpu_memory_utilization=0.43,
|
||||
trust_remote_code=True,
|
||||
max_model_len=1024,
|
||||
compilation_config=CompilationConfig(full_cuda_graph=True),
|
||||
max_num_seqs=128,
|
||||
compilation_config=\
|
||||
CompilationConfig(**backend_config.comp_config),
|
||||
generation_config="vllm",
|
||||
seed=42,
|
||||
)
|
||||
piecewise = LLM(
|
||||
model=model,
|
||||
gpu_memory_utilization=0.45,
|
||||
gpu_memory_utilization=0.43,
|
||||
trust_remote_code=True,
|
||||
max_model_len=1024,
|
||||
compilation_config=CompilationConfig(),
|
||||
max_num_seqs=128,
|
||||
compilation_config=CompilationConfig(cudagraph_mode="PIECEWISE"),
|
||||
generation_config="vllm",
|
||||
seed=42,
|
||||
)
|
||||
|
||||
# PyTest caches the fixture values so we use weakref.proxy to enable GC
|
||||
@@ -66,90 +170,7 @@ def llm_pair(request):
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def cutlass_mla_llm_pair(request):
|
||||
model = request.param
|
||||
|
||||
# force V1 engine and Cutlass MLA backend
|
||||
with temporary_environ({
|
||||
"VLLM_USE_V1": "1",
|
||||
"VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
|
||||
"FORCE_NUM_KV_SPLITS":
|
||||
"1", # TODO: remove this when hang issue is fixed
|
||||
}):
|
||||
full = LLM(
|
||||
model=model,
|
||||
gpu_memory_utilization=0.45,
|
||||
trust_remote_code=True,
|
||||
max_model_len=1024,
|
||||
compilation_config=CompilationConfig(
|
||||
full_cuda_graph=True,
|
||||
cudagraph_capture_sizes=[16, 32, 64, 128, 256, 512],
|
||||
),
|
||||
)
|
||||
piecewise = LLM(
|
||||
model=model,
|
||||
gpu_memory_utilization=0.45,
|
||||
trust_remote_code=True,
|
||||
max_model_len=1024,
|
||||
compilation_config=CompilationConfig(),
|
||||
)
|
||||
|
||||
yield weakref.proxy(full), weakref.proxy(piecewise)
|
||||
del full
|
||||
del piecewise
|
||||
|
||||
wait_for_gpu_memory_to_clear(
|
||||
devices=[0],
|
||||
threshold_ratio=0.1,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cutlass_mla_llm_pair",
|
||||
[
|
||||
# use an MLA model
|
||||
"deepseek-ai/DeepSeek-V2-Lite",
|
||||
],
|
||||
indirect=True)
|
||||
@pytest.mark.skipif(current_platform.get_device_capability() != (10, 0),
|
||||
reason="Only Blackwell GPUs support Cutlass MLA")
|
||||
class TestFullCUDAGraphCutlassMLA:
|
||||
"""
|
||||
Validate full CUDA Graph with Cutlass MLA (decode-only capture).
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(("batch_size", "max_tokens"), [
|
||||
(8, 8),
|
||||
])
|
||||
def test_full_cudagraph_sm100_cutlass_mla(
|
||||
self, batch_size, max_tokens, cutlass_mla_llm_pair: tuple[LLM,
|
||||
LLM]):
|
||||
piecewise_llm, full_cudagraph_llm = cutlass_mla_llm_pair
|
||||
|
||||
prompts = ["Hello, my name is"] * batch_size
|
||||
sampling_params = SamplingParams(temperature=0.0,
|
||||
max_tokens=max_tokens,
|
||||
top_p=0.95)
|
||||
|
||||
piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
|
||||
full_responses = full_cudagraph_llm.generate(prompts, sampling_params)
|
||||
|
||||
for piecewise_res, full_res in zip(piecewise_responses,
|
||||
full_responses):
|
||||
assert piecewise_res.outputs[0].text == full_res.outputs[0].text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"llm_pair",
|
||||
[
|
||||
# Model names for the llm_pair fixture
|
||||
"deepseek-ai/DeepSeek-V2-Lite",
|
||||
"Qwen/Qwen2-1.5B-Instruct"
|
||||
],
|
||||
indirect=True)
|
||||
@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
|
||||
reason="Only Hopper GPUs support FA3 and FlashMLA")
|
||||
@pytest.mark.parametrize("llm_pair", test_params_full_cudagraph, indirect=True)
|
||||
class TestFullCUDAGraph:
|
||||
"""
|
||||
Use a class such that an llm pair is constructed once for all
|
||||
@@ -178,12 +199,14 @@ class TestFullCUDAGraph:
|
||||
full cudagraph compilation works for padded cases too.
|
||||
"""
|
||||
|
||||
piecewise_llm, full_cudagraph_llm = llm_pair
|
||||
full_cudagraph_llm, piecewise_llm = llm_pair
|
||||
|
||||
prompts = ["Hello, my name is"] * batch_size
|
||||
prompts = ["the quick brown fox"] * batch_size
|
||||
# Use purely greedy decoding to avoid top-p truncation sensitivity
|
||||
# that can amplify tiny numeric differences across runtimes.
|
||||
sampling_params = SamplingParams(temperature=0.0,
|
||||
max_tokens=max_tokens,
|
||||
top_p=0.95)
|
||||
top_p=1.0)
|
||||
|
||||
piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
|
||||
full_responses = full_cudagraph_llm.generate(prompts, sampling_params)
|
||||
@@ -191,42 +214,16 @@ class TestFullCUDAGraph:
|
||||
# Check that all responses are the same
|
||||
for piecewise_res, full_res in zip(piecewise_responses,
|
||||
full_responses):
|
||||
assert piecewise_res.outputs[0].text == full_res.outputs[0].text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model, supported",
|
||||
[
|
||||
("Qwen/Qwen2-1.5B-Instruct", True),
|
||||
# MLA does not support capturing CUDA Graphs with size > max_num_seqs
|
||||
("deepseek-ai/DeepSeek-V2-Lite", False),
|
||||
])
|
||||
@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
|
||||
reason="Only Hopper GPUs support FA3 and FlashMLA")
|
||||
def test_lower_max_num_seqs(model, supported):
|
||||
with temporary_environ({
|
||||
"VLLM_USE_V1": "1",
|
||||
"VLLM_FLASH_ATTN_VERSION": "3"
|
||||
}), ExitStack() as stack:
|
||||
if not supported:
|
||||
stack.enter_context(pytest.raises(RuntimeError))
|
||||
|
||||
llm = LLM(model=model,
|
||||
max_num_seqs=256,
|
||||
trust_remote_code=True,
|
||||
max_model_len=1024,
|
||||
compilation_config=CompilationConfig(
|
||||
full_cuda_graph=True,
|
||||
cudagraph_capture_sizes=[64, 256, 512]))
|
||||
llm.generate(["Hello, my name is"] * 10)
|
||||
assert piecewise_res.outputs[0].text.lower() == \
|
||||
full_res.outputs[0].text.lower()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
|
||||
def test_full_cudagraph_with_invalid_backend():
|
||||
with temporary_environ({
|
||||
"VLLM_USE_V1": "1",
|
||||
"VLLM_FLASH_ATTN_VERSION":
|
||||
"2" #FA2 not supported with full_cuda_graph
|
||||
"VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION"
|
||||
# Flex_Attention is not supported with full cuda graph
|
||||
}), pytest.raises(RuntimeError):
|
||||
LLM(model="Qwen/Qwen2-1.5B-Instruct",
|
||||
compilation_config=CompilationConfig(full_cuda_graph=True))
|
||||
compilation_config=CompilationConfig(cudagraph_mode="FULL"))
|
||||
|
||||
Reference in New Issue
Block a user