[Core] Allow full cudagraph with separate attention routines and orthogonal to compilation, add support for FA2 and FlashInfer (#20059)
Signed-off-by: fhl <2410591650@qq.com>
Signed-off-by: fhl2000 <63384265+fhl2000@users.noreply.github.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
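This commit makes full CUDA-graph capture a runtime mode selected independently of torch.compile, and extends it to FA2 and FlashInfer. As a hedged usage sketch only, assembled from the options the tests below exercise rather than from any documented API guarantee, the mode is chosen via CompilationConfig and the attention backend via environment variable:

import os

from vllm import LLM
from vllm.config import CompilationConfig

# Backend selection is env-driven in the tests below; FLASHINFER is one of
# the backends this commit enables for full cudagraphs.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

llm = LLM(
    model="Qwen/Qwen2-1.5B-Instruct",  # example model used by the tests
    compilation_config=CompilationConfig(
        cudagraph_mode="FULL_AND_PIECEWISE"),  # or "FULL", "PIECEWISE"
)
print(llm.generate(["Hello, my name is"])[0].outputs[0].text)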
@@ -3,7 +3,8 @@
 import contextlib
 import os
 import weakref
 from contextlib import ExitStack
+from dataclasses import dataclass
 from typing import Optional
 
 import pytest
@@ -32,27 +33,130 @@ def temporary_environ(env_vars):
                 os.environ[k] = v
 
 
+@dataclass
+class BackendConfig:
+    name: str
+    env_vars: dict
+    comp_config: dict
+    specific_gpu_arch: Optional[tuple] = None
+
+
+# Define all backend configurations of full cudagraph to be tested
+backend_configs = {
+    # FA3 on Hopper
+    "FA3":
+    BackendConfig(name="FA3",
+                  env_vars={"VLLM_FLASH_ATTN_VERSION": "3"},
+                  comp_config={
+                      "cudagraph_mode": "FULL",
+                  },
+                  specific_gpu_arch=(9, 0)),
+    # FlashMLA on Hopper
+    "FlashMLA":
+    BackendConfig(name="FlashMLA",
+                  env_vars={
+                      "VLLM_ATTENTION_BACKEND": "FLASHMLA",
+                  },
+                  comp_config={
+                      "cudagraph_mode": "FULL_AND_PIECEWISE",
+                  },
+                  specific_gpu_arch=(9, 0)),
+    # Cutlass MLA on Blackwell
+    "CutlassMLA":
+    BackendConfig(
+        name="CutlassMLA",
+        env_vars={
+            "VLLM_USE_V1": "1",
+            "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
+            "FORCE_NUM_KV_SPLITS":
+            "1",  # TODO: remove this when hang issue is fixed
+        },
+        comp_config={
+            "cudagraph_mode": "FULL_AND_PIECEWISE",
+            "cudagraph_capture_sizes": [16, 32, 64, 128, 256, 512],
+        },
+        specific_gpu_arch=(10, 0)),
+    # FA2
+    "FA2":
+    BackendConfig(name="FA2",
+                  env_vars={"VLLM_FLASH_ATTN_VERSION": "2"},
+                  comp_config={
+                      "cudagraph_mode": "FULL",
+                  }),
+    # Triton Attention
+    "TritonAttn":
+    BackendConfig(name="TritonAttn",
+                  env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN_VLLM_V1"},
+                  comp_config={
+                      "cudagraph_mode": "FULL",
+                  }),
+    # FlashInfer
+    "FlashInfer":
+    BackendConfig(name="FlashInfer",
+                  env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
+                  comp_config={
+                      "cudagraph_mode": "FULL_AND_PIECEWISE",
+                  }),
+}
+
+test_params_full_cudagraph = []
+
+# deepseek-ai/DeepSeek-V2-Lite with MLA
+MLA_backends = ["FlashMLA", "CutlassMLA"]
+for mla_backend in MLA_backends:
+    test_params_full_cudagraph.append(
+        pytest.param(
+            ("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend])))
+
+# Qwen/Qwen2-1.5B-Instruct with other backends
+other_backend_configs = [
+    backend_configs[c] for c in backend_configs if c not in MLA_backends
+]
+for backend_config in other_backend_configs:
+    test_params_full_cudagraph.append(
+        pytest.param(("Qwen/Qwen2-1.5B-Instruct", backend_config)))
+
+
 @pytest.fixture(scope="class")
 def llm_pair(request):
-    model = request.param
+    model, backend_config = request.param
 
-    with temporary_environ({
-            "VLLM_USE_V1": "1",
-            "VLLM_FLASH_ATTN_VERSION": "3"
-    }):
+    # Dynamically skip test if GPU capability is not met
+    if backend_config.specific_gpu_arch and backend_config.specific_gpu_arch \
+            != current_platform.get_device_capability():
+        if backend_config.specific_gpu_arch == (9, 0):
+            pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")
+        elif backend_config.specific_gpu_arch == (10, 0):
+            pytest.skip("Only Blackwell GPUs support Cutlass MLA")
+
+    env_vars = {
+        "VLLM_USE_V1": "1",
+        # Force native sampler to avoid potential nondeterminism in FlashInfer
+        # when per-request generators are not used in V1.
+        "VLLM_USE_FLASHINFER_SAMPLER": "0",
+        **backend_config.env_vars,
+    }
+    with temporary_environ(env_vars):
         full = LLM(
             model=model,
-            gpu_memory_utilization=0.45,
+            gpu_memory_utilization=0.43,
             trust_remote_code=True,
             max_model_len=1024,
-            compilation_config=CompilationConfig(full_cuda_graph=True),
+            max_num_seqs=128,
+            compilation_config=\
+                CompilationConfig(**backend_config.comp_config),
+            generation_config="vllm",
+            seed=42,
         )
         piecewise = LLM(
             model=model,
-            gpu_memory_utilization=0.45,
+            gpu_memory_utilization=0.43,
             trust_remote_code=True,
             max_model_len=1024,
-            compilation_config=CompilationConfig(),
+            max_num_seqs=128,
+            compilation_config=CompilationConfig(cudagraph_mode="PIECEWISE"),
+            generation_config="vllm",
+            seed=42,
         )
 
     # PyTest caches the fixture values so we use weakref.proxy to enable GC
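The hunk above patches the tail of a temporary_environ helper; only its restore loop (os.environ[k] = v) is visible as context. A minimal sketch of what such a helper plausibly looks like, assuming it is an ordinary contextlib context manager — an illustration, not the file's actual definition:

import contextlib
import os


@contextlib.contextmanager
def temporary_environ(env_vars):
    # Remember prior values, apply the overrides, restore everything on exit.
    saved = {k: os.environ.get(k) for k in env_vars}
    try:
        for k, v in env_vars.items():
            os.environ[k] = v
        yield
    finally:
        for k, v in saved.items():
            if v is None:
                os.environ.pop(k, None)
            else:
                os.environ[k] = v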
@@ -66,90 +170,7 @@ def llm_pair(request):
     )
 
 
-@pytest.fixture(scope="class")
-def cutlass_mla_llm_pair(request):
-    model = request.param
-
-    # force V1 engine and Cutlass MLA backend
-    with temporary_environ({
-            "VLLM_USE_V1": "1",
-            "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
-            "FORCE_NUM_KV_SPLITS":
-            "1",  # TODO: remove this when hang issue is fixed
-    }):
-        full = LLM(
-            model=model,
-            gpu_memory_utilization=0.45,
-            trust_remote_code=True,
-            max_model_len=1024,
-            compilation_config=CompilationConfig(
-                full_cuda_graph=True,
-                cudagraph_capture_sizes=[16, 32, 64, 128, 256, 512],
-            ),
-        )
-        piecewise = LLM(
-            model=model,
-            gpu_memory_utilization=0.45,
-            trust_remote_code=True,
-            max_model_len=1024,
-            compilation_config=CompilationConfig(),
-        )
-
-    yield weakref.proxy(full), weakref.proxy(piecewise)
-    del full
-    del piecewise
-
-    wait_for_gpu_memory_to_clear(
-        devices=[0],
-        threshold_ratio=0.1,
-    )
-
-
-@pytest.mark.parametrize(
-    "cutlass_mla_llm_pair",
-    [
-        # use an MLA model
-        "deepseek-ai/DeepSeek-V2-Lite",
-    ],
-    indirect=True)
-@pytest.mark.skipif(current_platform.get_device_capability() != (10, 0),
-                    reason="Only Blackwell GPUs support Cutlass MLA")
-class TestFullCUDAGraphCutlassMLA:
-    """
-    Validate full CUDA Graph with Cutlass MLA (decode-only capture).
-    """
-
-    @pytest.mark.parametrize(("batch_size", "max_tokens"), [
-        (8, 8),
-    ])
-    def test_full_cudagraph_sm100_cutlass_mla(
-            self, batch_size, max_tokens, cutlass_mla_llm_pair: tuple[LLM,
-                                                                      LLM]):
-        piecewise_llm, full_cudagraph_llm = cutlass_mla_llm_pair
-
-        prompts = ["Hello, my name is"] * batch_size
-        sampling_params = SamplingParams(temperature=0.0,
-                                         max_tokens=max_tokens,
-                                         top_p=0.95)
-
-        piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
-        full_responses = full_cudagraph_llm.generate(prompts, sampling_params)
-
-        for piecewise_res, full_res in zip(piecewise_responses,
-                                           full_responses):
-            assert piecewise_res.outputs[0].text == full_res.outputs[0].text
-
-
-@pytest.mark.parametrize(
-    "llm_pair",
-    [
-        # Model names for the llm_pair fixture
-        "deepseek-ai/DeepSeek-V2-Lite",
-        "Qwen/Qwen2-1.5B-Instruct"
-    ],
-    indirect=True)
-@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
-                    reason="Only Hopper GPUs support FA3 and FlashMLA")
+@pytest.mark.parametrize("llm_pair", test_params_full_cudagraph, indirect=True)
 class TestFullCUDAGraph:
     """
     Use a class such that an llm pair is constructed once for all
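The replacement parametrization above feeds (model, BackendConfig) tuples into the llm_pair fixture through pytest's indirect mechanism. A self-contained illustration of that mechanism, with toy names unrelated to vLLM:

import pytest


@pytest.fixture
def pair(request):
    # With indirect=True, each parametrize entry arrives as request.param
    # before the test ever sees the fixture's return value.
    model, backend = request.param
    return f"{model}+{backend}"


@pytest.mark.parametrize("pair", [("m1", "FA2"), ("m2", "FlashInfer")],
                         indirect=True)
def test_pair(pair):
    assert "+" in pair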
@@ -178,12 +199,14 @@ class TestFullCUDAGraph:
         full cudagraph compilation works for padded cases too.
         """
 
-        piecewise_llm, full_cudagraph_llm = llm_pair
+        full_cudagraph_llm, piecewise_llm = llm_pair
 
-        prompts = ["Hello, my name is"] * batch_size
+        prompts = ["the quick brown fox"] * batch_size
+        # Use purely greedy decoding to avoid top-p truncation sensitivity
+        # that can amplify tiny numeric differences across runtimes.
         sampling_params = SamplingParams(temperature=0.0,
                                          max_tokens=max_tokens,
-                                         top_p=0.95)
+                                         top_p=1.0)
 
         piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
         full_responses = full_cudagraph_llm.generate(prompts, sampling_params)
@@ -191,42 +214,16 @@ class TestFullCUDAGraph:
         # Check that all responses are the same
         for piecewise_res, full_res in zip(piecewise_responses,
                                            full_responses):
-            assert piecewise_res.outputs[0].text == full_res.outputs[0].text
-
-
-@pytest.mark.parametrize(
-    "model, supported",
-    [
-        ("Qwen/Qwen2-1.5B-Instruct", True),
-        # MLA does not support capturing CUDA Graphs with size > max_num_seqs
-        ("deepseek-ai/DeepSeek-V2-Lite", False),
-    ])
-@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
-                    reason="Only Hopper GPUs support FA3 and FlashMLA")
-def test_lower_max_num_seqs(model, supported):
-    with temporary_environ({
-            "VLLM_USE_V1": "1",
-            "VLLM_FLASH_ATTN_VERSION": "3"
-    }), ExitStack() as stack:
-        if not supported:
-            stack.enter_context(pytest.raises(RuntimeError))
-
-        llm = LLM(model=model,
-                  max_num_seqs=256,
-                  trust_remote_code=True,
-                  max_model_len=1024,
-                  compilation_config=CompilationConfig(
-                      full_cuda_graph=True,
-                      cudagraph_capture_sizes=[64, 256, 512]))
-        llm.generate(["Hello, my name is"] * 10)
+            assert piecewise_res.outputs[0].text.lower() == \
+                full_res.outputs[0].text.lower()
 
 
 @pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
 def test_full_cudagraph_with_invalid_backend():
     with temporary_environ({
             "VLLM_USE_V1": "1",
-            "VLLM_FLASH_ATTN_VERSION":
-            "2"  # FA2 not supported with full_cuda_graph
+            "VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION"
+            # Flex_Attention is not supported with full cuda graph
     }), pytest.raises(RuntimeError):
         LLM(model="Qwen/Qwen2-1.5B-Instruct",
-            compilation_config=CompilationConfig(full_cuda_graph=True))
+            compilation_config=CompilationConfig(cudagraph_mode="FULL"))
@@ -11,10 +11,10 @@ from torch.library import Library
 
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
-                         set_current_vllm_config)
+from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
+                         VllmConfig, set_current_vllm_config)
 from vllm.envs import VLLM_USE_V1
-from vllm.forward_context import set_forward_context
+from vllm.forward_context import BatchDescriptor, set_forward_context
 from vllm.utils import direct_register_custom_op
 
 global_counter = 0
@@ -101,16 +101,33 @@ def test_simple_piecewise_compile(use_inductor):
             num_backend_compilations=3,  # num_piecewise_capturable_graphs_seen
             num_cudagraph_captured=
             6,  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
-    ), set_forward_context({}, vllm_config=vllm_config):
-
+    ), set_forward_context(None,
+                           vllm_config=vllm_config):  # background context
+        # warm up with background context
         model(inputs)
 
-        model(torch.randn(2).cuda())
-        model(torch.randn(1).cuda())
+        # capturing/replaying should happen under the cudagraph dispatching context
+        with set_forward_context(
+                None,
+                vllm_config=vllm_config,
+                cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
+                batch_descriptor=BatchDescriptor(num_tokens=2, )):
+            model(torch.randn(2).cuda())
+        with set_forward_context(
+                None,
+                vllm_config=vllm_config,
+                cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
+                batch_descriptor=BatchDescriptor(num_tokens=1, )):
+            model(torch.randn(1).cuda())
 
         input = torch.zeros(2).cuda()
         global global_counter
         global_counter = 0
-        output = model(input)
+        with set_forward_context(
+                None,
+                vllm_config=vllm_config,
+                cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
+                batch_descriptor=BatchDescriptor(num_tokens=2, )):
+            output = model(input)
         assert global_counter == 2
         assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
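The three phases this test walks through (eager warmup in the background context, capture under a dispatching context, then replay) mirror how raw CUDA graphs behave in PyTorch. A standalone sketch using plain torch APIs, independent of vLLM's wrappers and assuming a CUDA device is available:

import torch

static_in = torch.randn(2, device="cuda")

# Warmup: run the kernels eagerly on a side stream before capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_in * 2
torch.cuda.current_stream().wait_stream(s)

# Capture: record the kernels once into a graph object.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out = static_in * 2

# Replay: refill the captured input buffer in place, then replay.
static_in.copy_(torch.tensor([1.0, 3.0], device="cuda"))
g.replay()
print(static_out)  # tensor([2., 6.], device='cuda:0')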
@@ -18,9 +18,9 @@ from torch.library import Library
 
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
-                         set_current_vllm_config)
-from vllm.forward_context import set_forward_context
+from vllm.config import (CompilationConfig, CompilationLevel, CUDAGraphMode,
+                         VllmConfig, set_current_vllm_config)
+from vllm.forward_context import BatchDescriptor, set_forward_context
 from vllm.utils import direct_register_custom_op
 
 # create a library to hold the custom op
@@ -276,9 +276,11 @@ def run_model(llama_config,
         )
         if split_attn:
             compilation_config.splitting_ops = ["silly.attention"]
+        cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE
     else:
         compilation_config = CompilationConfig(
             level=CompilationLevel.NO_COMPILATION, )
+        cudagraph_runtime_mode = CUDAGraphMode.NONE
 
     vllm_config = VllmConfig(compilation_config=compilation_config,
                              additional_config=llama_config)
@@ -287,17 +289,37 @@ def run_model(llama_config,
                              vllm_config=vllm_config,
                              prefix="").eval().cuda()
 
-    with set_forward_context({}, vllm_config=vllm_config):
+    with set_forward_context({},
+                             vllm_config=vllm_config):  # background context
         B = 16  # max batch size
         input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
         positions = torch.arange(B).cuda()
 
+        # warmup for the model with cudagraph_mode NONE
         model(input_ids, positions)
-        model(input_ids[:2], positions[:2])
-        model(input_ids[:1], positions[:1])
+
+        # simulate cudagraphs capturing
+        with set_forward_context({},
+                                 vllm_config=vllm_config,
+                                 cudagraph_runtime_mode=cudagraph_runtime_mode,
+                                 batch_descriptor=BatchDescriptor(
+                                     num_tokens=2, )):
+            model(input_ids[:2], positions[:2])
+        with set_forward_context({},
+                                 vllm_config=vllm_config,
+                                 cudagraph_runtime_mode=cudagraph_runtime_mode,
+                                 batch_descriptor=BatchDescriptor(
+                                     num_tokens=1, )):
+            model(input_ids[:1], positions[:1])
 
         input_ids[:2].zero_()
-        output = model(input_ids[:2], positions[:2])
+        # simulate cudagraphs replay
+        with set_forward_context({},
+                                 vllm_config=vllm_config,
+                                 cudagraph_runtime_mode=cudagraph_runtime_mode,
+                                 batch_descriptor=BatchDescriptor(
+                                     num_tokens=2, )):
+            output = model(input_ids[:2], positions[:2])
 
         output = output.cpu()
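Throughout these tests, every capture and replay is keyed on a BatchDescriptor so the runtime can pick the matching graph. A toy sketch of that dispatch idea, with all names (Descriptor, GraphStore) hypothetical and plain Python closures standing in for captured graphs:

from dataclasses import dataclass, field
from typing import Callable, Dict


@dataclass(frozen=True)
class Descriptor:
    num_tokens: int  # hypothetical stand-in for vLLM's BatchDescriptor


@dataclass
class GraphStore:
    graphs: Dict[Descriptor, Callable] = field(default_factory=dict)

    def capture(self, desc: Descriptor, fn: Callable) -> None:
        # One "captured graph" per descriptor.
        self.graphs[desc] = fn

    def run(self, desc: Descriptor, fn: Callable, x: list) -> list:
        # Replay on an exact descriptor match, otherwise fall back to eager.
        replay = self.graphs.get(desc)
        return replay(x) if replay is not None else fn(x)


store = GraphStore()
store.capture(Descriptor(num_tokens=2), lambda x: [v * 2 for v in x])
print(store.run(Descriptor(num_tokens=2), lambda x: x, [1, 3]))  # replayed
print(store.run(Descriptor(num_tokens=4), lambda x: x, [1, 2]))  # eager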