diff --git a/tests/compile/fullgraph/test_full_cudagraph.py b/tests/compile/fullgraph/test_full_cudagraph.py
index c7c737371..95306e206 100644
--- a/tests/compile/fullgraph/test_full_cudagraph.py
+++ b/tests/compile/fullgraph/test_full_cudagraph.py
@@ -170,14 +170,3 @@ class TestFullCUDAGraph:
             piecewise_res.outputs[0].text.lower()
             == full_res.outputs[0].text.lower()
         )
-
-
-@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
-def test_full_cudagraph_with_invalid_backend():
-    # Flex_Attention is not supported with full cuda graph
-    with pytest.raises(RuntimeError):
-        LLM(
-            model="Qwen/Qwen2-1.5B-Instruct",
-            compilation_config=CompilationConfig(cudagraph_mode="FULL"),
-            attention_config={"backend": "FLEX_ATTENTION"},
-        )
diff --git a/tests/kernels/test_flex_attention.py b/tests/kernels/test_flex_attention.py
index 69113b57c..41d298134 100644
--- a/tests/kernels/test_flex_attention.py
+++ b/tests/kernels/test_flex_attention.py
@@ -26,6 +26,59 @@ MINIMUM_TORCH_VERSION = version.parse("2.7.0")
 DIRECT_BUILD_VERSION = version.parse("2.9.dev0")
 
 
+@pytest.mark.skipif(
+    not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
+    reason="CUDA not available or PyTorch version < 2.7",
+)
+def test_flex_attention_full_cudagraphs(vllm_runner):
+    """Test the numerics for flex attention full cudagraphs support."""
+    model_name = "Qwen/Qwen2.5-1.5B-Instruct"
+    seed = 42
+    max_tokens = 24
+    num_logprobs = 5
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+    ]
+
+    # Run with flex attention eager
+    set_random_seed(seed)
+    with vllm_runner(
+        model_name,
+        runner="generate",
+        tensor_parallel_size=1,
+        num_gpu_blocks_override=128,
+        enforce_eager=True,
+        attention_config={"backend": "FLEX_ATTENTION"},
+    ) as llm_flex:
+        output_eager = llm_flex.generate_greedy_logprobs(
+            prompts, max_tokens, num_logprobs
+        )
+
+    # Run with flex attention compiled
+    set_random_seed(seed)
+    with vllm_runner(
+        model_name,
+        runner="generate",
+        tensor_parallel_size=1,
+        num_gpu_blocks_override=128,
+        enforce_eager=False,
+        gpu_memory_utilization=0.85,
+        attention_config={"backend": "FLEX_ATTENTION"},
+    ) as llm_default:
+        output_compile = llm_default.generate_greedy_logprobs(
+            prompts, max_tokens, num_logprobs
+        )
+
+    check_logprobs_close(
+        outputs_0_lst=output_eager,
+        outputs_1_lst=output_compile,
+        name_0="eager",
+        name_1="compile",
+    )
+
+
 @pytest.mark.skipif(
     not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
     reason="CUDA not available or PyTorch version < 2.7",
diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py
index e832f6bdd..5e202e00f 100644
--- a/vllm/v1/attention/backends/flex_attention.py
+++ b/vllm/v1/attention/backends/flex_attention.py
@@ -30,6 +30,7 @@ from vllm.utils.math_utils import cdiv
 from vllm.utils.torch_utils import is_quantized_kv_cache, is_torch_equal_or_newer
 from vllm.v1.attention.backend import (
     AttentionBackend,
+    AttentionCGSupport,
     AttentionImpl,
     AttentionMetadataBuilder,
     AttentionType,
@@ -315,6 +316,18 @@ class BlockSparsityHint(NamedTuple):
     hint_fn: _block_sparsity_hint_signature
 
 
+def copy_to_persistent(dst, src):
+    try:
+        dst = dst.as_strided(src.shape, src.stride())
+    except RuntimeError as e:
+        raise RuntimeError(
+            f"Fail to re-stride a persistent tensor of shape {dst.shape} "
+            f"for a tensor of shape {src.shape}"
+        ) from e
+    dst.copy_(src)
+    return dst
+
+
 @dataclass
 class FlexAttentionMetadata:
     causal: bool
@@ -340,6 +353,9 @@ class FlexAttentionMetadata:
     physical_to_logical: torch.Tensor
     decode_offset: torch.Tensor
     num_blocks_per_seq: torch.Tensor
+    persistent_kv_indices: torch.Tensor
+    persistent_kv_num_blocks: torch.Tensor
+    persistent_doc_ids: torch.Tensor
 
     # For logging.
     num_input_tokens: int = 0  # Number of tokens including padding.
@@ -656,8 +672,11 @@ class FlexAttentionMetadata:
         kv_indices = unique_static_unsorted(
             (used_pages_padded.long()), M=self.num_blocks
         ).to(torch.int32)
+        kv_indices = copy_to_persistent(self.persistent_kv_indices, kv_indices)
 
         kv_num_blocks = (kv_indices >= 0).sum(dim=-1).to(torch.int32)
+        kv_num_blocks = copy_to_persistent(self.persistent_kv_num_blocks, kv_num_blocks)
+
         block_mask_kwargs = {
             "seq_lengths": (self.num_actual_tokens, self.total_cache_tokens),
             "kv_num_blocks": kv_num_blocks[None, None],
@@ -694,6 +713,7 @@ class FlexAttentionMetadata:
         assert self.suffix_kv_lens is None, "Not implemented yet."
 
         # Create a lookup mapping from query indices -> request number
         self.doc_ids = _offsets_to_doc_ids_tensor(self.query_start_loc)
+        self.doc_ids = copy_to_persistent(self.persistent_doc_ids, self.doc_ids)
         self.num_blocks = self.total_cache_tokens // self.block_size
         self.mask_mod = self.get_mask_mod()
@@ -701,6 +721,8 @@
 
 
 class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadata]):
+    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
+
     def __init__(
         self,
         kv_cache_spec: AttentionSpec,
@@ -726,6 +748,38 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
         self.q_block_size: int = 16 if supports_small_blocks else 128
         self.kv_block_size: int = self.block_size if supports_small_blocks else 128
 
+        self.max_model_len = self.model_config.max_model_len
+        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
+        max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
+        self.max_num_q_block = (
+            self.max_model_len + self.q_block_size - 1
+        ) // self.q_block_size
+        self.persistent_kv_num_blocks = torch.empty(
+            self.max_num_q_block, dtype=torch.int32, device=device
+        )
+        self.persistent_offset_tensor = torch.empty(
+            max_num_seqs, dtype=torch.int32, device=device
+        )
+        self.persistent_doc_ids = torch.empty(
+            max_num_batched_tokens, dtype=torch.int32, device=device
+        )
+
+        # initialize later when we can access block_table
+        self.persistent_physical_to_logical = None
+        self.persistent_kv_indices = None
+
+    def build_for_cudagraph_capture(
+        self, common_attn_metadata: CommonAttentionMetadata
+    ) -> FlexAttentionMetadata:
+        # Use actual max_seq_len instead of max_model_len to avoid
+        # torch.compile recompilation during CUDA graph capture.
+        common_attn_metadata.max_seq_len = (
+            common_attn_metadata.seq_lens_cpu.max().item()
+        )
+        return self.build(
+            common_prefix_len=0, common_attn_metadata=common_attn_metadata
+        )
+
     def build(
         self,
         common_prefix_len: int,
@@ -765,8 +819,32 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
         inverse_block_table = physical_to_logical_mapping(
             block_table_tensor, seq_lens, block_size, num_gpu_blocks
         )
+        if self.persistent_physical_to_logical is None:
+            max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
+            self.persistent_physical_to_logical = torch.empty(
+                max_num_seqs,
+                num_gpu_blocks,
+                dtype=torch.long,
+                device=self.device,
+            )
+
+        if self.persistent_kv_indices is None:
+            max_num_kv_block = (
+                self.max_model_len + self.kv_block_size - 1
+            ) // self.kv_block_size
+            self.persistent_kv_indices = torch.empty(
+                self.max_model_len,
+                max_num_kv_block,
+                dtype=torch.int32,
+                device=self.device,
+            )
+
+        inverse_block_table = copy_to_persistent(
+            self.persistent_physical_to_logical, inverse_block_table
+        )
 
         offset_tensor = common_attn_metadata.compute_num_computed_tokens()
+        offset_tensor = copy_to_persistent(self.persistent_offset_tensor, offset_tensor)
 
         out = FlexAttentionMetadata(
             causal=common_attn_metadata.causal,
@@ -795,7 +873,20 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
             direct_build=(self.direct_build and common_attn_metadata.causal),
             q_block_size=self.q_block_size,
             kv_block_size=self.kv_block_size,
+            persistent_kv_indices=self.persistent_kv_indices,
+            persistent_kv_num_blocks=self.persistent_kv_num_blocks,
+            persistent_doc_ids=self.persistent_doc_ids,
         )
+
+        # Pre-build block_mask so it is ready before CUDA graph capture.
+        # Without this, the lazy build in forward() would run non-graph-safe
+        # ops (e.g. torch.nonzero) inside capture.
+        if out.block_mask is None:
+            if out.direct_build:
+                out.block_mask = out._build_block_mask_direct()
+            else:
+                out.block_mask = out.build_block_mask()
+
         return out
 
     def use_cascade_attention(self, *args, **kwargs) -> bool:
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 979ff8d33..bba707df0 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -6077,6 +6077,7 @@ class GPUModelRunner(
                 skip_eplb=True,
                 remove_lora=False,
                 num_active_loras=desc.num_active_loras,
+                profile_seq_lens=profile_seq_lens,
            )
            self._dummy_run(
                desc.num_tokens,
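The core idea behind copy_to_persistent and the preallocated tensors in FlexAttentionMetadataBuilder.__init__ is that CUDA graph replay reuses the device pointers recorded at capture time, so per-step attention metadata has to be written into fixed-address buffers rather than materialized as fresh tensors each step. A minimal standalone sketch of that pattern (illustrative buffer size and CPU tensors, not the vLLM code paths):

import torch


def copy_to_persistent(dst: torch.Tensor, src: torch.Tensor) -> torch.Tensor:
    # View the preallocated buffer with the source's shape/strides, then copy in place.
    # The returned view aliases dst, so its data_ptr() stays stable across steps.
    view = dst.as_strided(src.shape, src.stride())
    view.copy_(src)
    return view


buf = torch.empty(16, dtype=torch.int32)  # persistent buffer, allocated once up front
a = copy_to_persistent(buf, torch.tensor([1, 2, 3], dtype=torch.int32))
b = copy_to_persistent(buf, torch.tensor([7, 8, 9, 10], dtype=torch.int32))
assert a.data_ptr() == b.data_ptr() == buf.data_ptr()  # same storage every step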
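The pre-build step at the end of FlexAttentionMetadataBuilder.build exists because constructing a FlexAttention BlockMask runs data-dependent ops (the patch's comment cites torch.nonzero) that cannot run while a CUDA graph is being recorded. A rough sketch of the same ordering using the public torch.nn.attention.flex_attention API (generic causal mask and shapes, not vLLM's mask_mod or block sizes):

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


if torch.cuda.is_available():
    # Build the BlockMask eagerly, before any CUDA graph capture; this step is
    # data dependent and must not run while a graph is being recorded.
    block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=128, KV_LEN=128, device="cuda")

    q, k, v = (torch.randn(1, 4, 128, 64, device="cuda") for _ in range(3))
    # The attention call only consumes the prebuilt mask, so it is the part
    # that can later be captured and replayed.
    out = flex_attention(q, k, v, block_mask=block_mask)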