Add full CUDA graph support for FlexAttention (#36298)

Signed-off-by: shunting314 <shunting@meta.com>
This commit is contained in:
shunting314
2026-04-02 21:15:01 -07:00
committed by GitHub
parent 2ad7c0335f
commit 8b141ed8c3
4 changed files with 145 additions and 11 deletions

View File

@@ -170,14 +170,3 @@ class TestFullCUDAGraph:
piecewise_res.outputs[0].text.lower()
== full_res.outputs[0].text.lower()
)
# Negative test: constructing an LLM with cudagraph_mode="FULL" together with
# the FLEX_ATTENTION backend is expected to raise RuntimeError. The test is
# skipped unless current_platform.is_cuda() is true.
# NOTE(review): in the same change, FlexAttentionMetadataBuilder declares
# _cudagraph_support = AttentionCGSupport.ALWAYS, so this expectation
# presumably no longer holds; the hunk header (14 -> 3 lines) suggests this
# test is being removed rather than kept -- confirm.
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
def test_full_cudagraph_with_invalid_backend():
# Flex_Attention is not supported with full cuda graph
# The LLM(...) construction itself must raise inside this context manager.
with pytest.raises(RuntimeError):
LLM(
model="Qwen/Qwen2-1.5B-Instruct",
compilation_config=CompilationConfig(cudagraph_mode="FULL"),
attention_config={"backend": "FLEX_ATTENTION"},
)

View File

@@ -26,6 +26,59 @@ MINIMUM_TORCH_VERSION = version.parse("2.7.0")
DIRECT_BUILD_VERSION = version.parse("2.9.dev0")
@pytest.mark.skipif(
    not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
    reason="CUDA not available or PyTorch version < 2.7",
)
def test_flex_attention_full_cudagraphs(vllm_runner):
    """Check FlexAttention numerics: eager run vs. non-eager run.

    The same prompts are decoded greedily twice with the FLEX_ATTENTION
    backend, once with ``enforce_eager=True`` and once with
    ``enforce_eager=False``, and the resulting logprobs are compared with
    ``check_logprobs_close``.

    NOTE(review): no ``cudagraph_mode`` is set here, so the second run
    presumably relies on the default compilation config to exercise full
    CUDA graphs -- confirm.
    """
    model_name = "Qwen/Qwen2.5-1.5B-Instruct"
    seed = 42
    max_tokens = 24
    num_logprobs = 5
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
    ]

    # Runner options common to both configurations.
    shared_runner_kwargs = {
        "runner": "generate",
        "tensor_parallel_size": 1,
        "num_gpu_blocks_override": 128,
    }

    # Reference run: flex attention with eager execution enforced.
    set_random_seed(seed)
    with vllm_runner(
        model_name,
        **shared_runner_kwargs,
        enforce_eager=True,
        attention_config={"backend": "FLEX_ATTENTION"},
    ) as eager_llm:
        eager_logprobs = eager_llm.generate_greedy_logprobs(
            prompts, max_tokens, num_logprobs
        )

    # Candidate run: same backend, eager execution not enforced.
    set_random_seed(seed)
    with vllm_runner(
        model_name,
        **shared_runner_kwargs,
        enforce_eager=False,
        gpu_memory_utilization=0.85,
        attention_config={"backend": "FLEX_ATTENTION"},
    ) as compiled_llm:
        compiled_logprobs = compiled_llm.generate_greedy_logprobs(
            prompts, max_tokens, num_logprobs
        )

    check_logprobs_close(
        outputs_0_lst=eager_logprobs,
        outputs_1_lst=compiled_logprobs,
        name_0="eager",
        name_1="compile",
    )
@pytest.mark.skipif(
not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
reason="CUDA not available or PyTorch version < 2.7",

View File

@@ -30,6 +30,7 @@ from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import is_quantized_kv_cache, is_torch_equal_or_newer
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
AttentionImpl,
AttentionMetadataBuilder,
AttentionType,
@@ -315,6 +316,18 @@ class BlockSparsityHint(NamedTuple):
hint_fn: _block_sparsity_hint_signature
def copy_to_persistent(dst, src):
    """Copy ``src`` into the memory of ``dst`` and return the resulting view.

    ``dst`` is re-viewed with ``as_strided`` using the shape and strides of
    ``src``; ``src`` is then copied into that view.

    Args:
        dst: Destination tensor whose storage receives the data.
        src: Tensor to copy; its shape and strides define the returned view.

    Returns:
        The strided view of ``dst`` holding a copy of ``src``.

    Raises:
        RuntimeError: If ``dst`` cannot be re-strided to the shape and
            strides of ``src``; the original error is chained as the cause.
    """
    try:
        persistent_view = dst.as_strided(src.shape, src.stride())
    except RuntimeError as e:
        raise RuntimeError(
            f"Fail to re-stride a persistent tensor of shape {dst.shape} "
            f"for a tensor of shape {src.shape}"
        ) from e
    else:
        persistent_view.copy_(src)
        return persistent_view
@dataclass
class FlexAttentionMetadata:
causal: bool
@@ -340,6 +353,9 @@ class FlexAttentionMetadata:
physical_to_logical: torch.Tensor
decode_offset: torch.Tensor
num_blocks_per_seq: torch.Tensor
persistent_kv_indices: torch.Tensor
persistent_kv_num_blocks: torch.Tensor
persistent_doc_ids: torch.Tensor
# For logging.
num_input_tokens: int = 0 # Number of tokens including padding.
@@ -656,8 +672,11 @@ class FlexAttentionMetadata:
kv_indices = unique_static_unsorted(
(used_pages_padded.long()), M=self.num_blocks
).to(torch.int32)
kv_indices = copy_to_persistent(self.persistent_kv_indices, kv_indices)
kv_num_blocks = (kv_indices >= 0).sum(dim=-1).to(torch.int32)
kv_num_blocks = copy_to_persistent(self.persistent_kv_num_blocks, kv_num_blocks)
block_mask_kwargs = {
"seq_lengths": (self.num_actual_tokens, self.total_cache_tokens),
"kv_num_blocks": kv_num_blocks[None, None],
@@ -694,6 +713,7 @@ class FlexAttentionMetadata:
assert self.suffix_kv_lens is None, "Not implemented yet."
# Create a lookup mapping from query indices -> request number
self.doc_ids = _offsets_to_doc_ids_tensor(self.query_start_loc)
self.doc_ids = copy_to_persistent(self.persistent_doc_ids, self.doc_ids)
self.num_blocks = self.total_cache_tokens // self.block_size
self.mask_mod = self.get_mask_mod()
@@ -701,6 +721,8 @@ class FlexAttentionMetadata:
class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadata]):
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
def __init__(
self,
kv_cache_spec: AttentionSpec,
@@ -726,6 +748,38 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
self.q_block_size: int = 16 if supports_small_blocks else 128
self.kv_block_size: int = self.block_size if supports_small_blocks else 128
self.max_model_len = self.model_config.max_model_len
max_num_seqs = vllm_config.scheduler_config.max_num_seqs
max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
self.max_num_q_block = (
self.max_model_len + self.q_block_size - 1
) // self.q_block_size
self.persistent_kv_num_blocks = torch.empty(
self.max_num_q_block, dtype=torch.int32, device=device
)
self.persistent_offset_tensor = torch.empty(
max_num_seqs, dtype=torch.int32, device=device
)
self.persistent_doc_ids = torch.empty(
max_num_batched_tokens, dtype=torch.int32, device=device
)
# initialize later when we can access block_table
self.persistent_physical_to_logical = None
self.persistent_kv_indices = None
def build_for_cudagraph_capture(
    self, common_attn_metadata: CommonAttentionMetadata
) -> FlexAttentionMetadata:
    """Build attention metadata for a CUDA graph capture pass.

    Overwrites ``common_attn_metadata.max_seq_len`` in place with the
    largest value in ``seq_lens_cpu`` and then delegates to ``build`` with
    ``common_prefix_len=0``.

    Args:
        common_attn_metadata: Shared attention metadata for the batch;
            its ``max_seq_len`` field is mutated by this call.

    Returns:
        The metadata object produced by ``build``.
    """
    # Use actual max_seq_len instead of max_model_len to avoid
    # torch.compile recompilation during CUDA graph capture.
    longest_seq_len = common_attn_metadata.seq_lens_cpu.max().item()
    common_attn_metadata.max_seq_len = longest_seq_len
    return self.build(
        common_prefix_len=0,
        common_attn_metadata=common_attn_metadata,
    )
def build(
self,
common_prefix_len: int,
@@ -765,8 +819,32 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
inverse_block_table = physical_to_logical_mapping(
block_table_tensor, seq_lens, block_size, num_gpu_blocks
)
if self.persistent_physical_to_logical is None:
max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
self.persistent_physical_to_logical = torch.empty(
max_num_seqs,
num_gpu_blocks,
dtype=torch.long,
device=self.device,
)
if self.persistent_kv_indices is None:
max_num_kv_block = (
self.max_model_len + self.kv_block_size - 1
) // self.kv_block_size
self.persistent_kv_indices = torch.empty(
self.max_model_len,
max_num_kv_block,
dtype=torch.int32,
device=self.device,
)
inverse_block_table = copy_to_persistent(
self.persistent_physical_to_logical, inverse_block_table
)
offset_tensor = common_attn_metadata.compute_num_computed_tokens()
offset_tensor = copy_to_persistent(self.persistent_offset_tensor, offset_tensor)
out = FlexAttentionMetadata(
causal=common_attn_metadata.causal,
@@ -795,7 +873,20 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
direct_build=(self.direct_build and common_attn_metadata.causal),
q_block_size=self.q_block_size,
kv_block_size=self.kv_block_size,
persistent_kv_indices=self.persistent_kv_indices,
persistent_kv_num_blocks=self.persistent_kv_num_blocks,
persistent_doc_ids=self.persistent_doc_ids,
)
# Pre-build block_mask so it is ready before CUDA graph capture.
# Without this, the lazy build in forward() would run non-graph-safe
# ops (e.g. torch.nonzero) inside capture.
if out.block_mask is None:
if out.direct_build:
out.block_mask = out._build_block_mask_direct()
else:
out.block_mask = out.build_block_mask()
return out
def use_cascade_attention(self, *args, **kwargs) -> bool:

View File

@@ -6077,6 +6077,7 @@ class GPUModelRunner(
skip_eplb=True,
remove_lora=False,
num_active_loras=desc.num_active_loras,
profile_seq_lens=profile_seq_lens,
)
self._dummy_run(
desc.num_tokens,