full cudagraph for flex-attn (#36298)
Signed-off-by: shunting314 <shunting@meta.com>
This commit is contained in:
@@ -170,14 +170,3 @@ class TestFullCUDAGraph:
|
||||
piecewise_res.outputs[0].text.lower()
|
||||
== full_res.outputs[0].text.lower()
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
def test_full_cudagraph_with_invalid_backend():
    """Expect engine construction to fail for FLEX_ATTENTION + FULL cudagraph mode.

    The test passes only if building the ``LLM`` (or its compilation config)
    raises ``RuntimeError``.
    """
    # Flex_Attention is not supported with full cuda graph
    with pytest.raises(RuntimeError):
        # Everything is constructed inside the ``raises`` scope so that an
        # error from either the config or the engine satisfies the check.
        full_graph_config = CompilationConfig(cudagraph_mode="FULL")
        flex_backend = {"backend": "FLEX_ATTENTION"}
        LLM(
            model="Qwen/Qwen2-1.5B-Instruct",
            compilation_config=full_graph_config,
            attention_config=flex_backend,
        )
|
||||
|
||||
@@ -26,6 +26,59 @@ MINIMUM_TORCH_VERSION = version.parse("2.7.0")
|
||||
DIRECT_BUILD_VERSION = version.parse("2.9.dev0")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
    reason="CUDA not available or PyTorch version < 2.7",
)
def test_flex_attention_full_cudagraphs(vllm_runner):
    """Test the numerics for flex attention full cudagraphs support.

    Generates greedy logprobs for the same prompts twice with the
    FLEX_ATTENTION backend -- once with ``enforce_eager=True`` and once with
    ``enforce_eager=False`` -- and checks that the two results are close.
    """
    checkpoint = "Qwen/Qwen2.5-1.5B-Instruct"
    rng_seed, gen_len, top_k = 42, 24, 5
    queries = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
    ]

    def greedy_logprobs(**mode_kwargs):
        # Re-seed before every engine start so both runs begin from the
        # same random state.
        set_random_seed(rng_seed)
        with vllm_runner(
            checkpoint,
            runner="generate",
            tensor_parallel_size=1,
            num_gpu_blocks_override=128,
            attention_config={"backend": "FLEX_ATTENTION"},
            **mode_kwargs,
        ) as engine:
            return engine.generate_greedy_logprobs(queries, gen_len, top_k)

    # Run with flex attention eager
    eager_out = greedy_logprobs(enforce_eager=True)

    # Run with flex attention compiled
    graph_out = greedy_logprobs(enforce_eager=False, gpu_memory_utilization=0.85)

    check_logprobs_close(
        outputs_0_lst=eager_out,
        outputs_1_lst=graph_out,
        name_0="eager",
        name_1="compile",
    )
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
|
||||
reason="CUDA not available or PyTorch version < 2.7",
|
||||
|
||||
@@ -30,6 +30,7 @@ from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.torch_utils import is_quantized_kv_cache, is_torch_equal_or_newer
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionBackend,
|
||||
AttentionCGSupport,
|
||||
AttentionImpl,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionType,
|
||||
@@ -315,6 +316,18 @@ class BlockSparsityHint(NamedTuple):
|
||||
hint_fn: _block_sparsity_hint_signature
|
||||
|
||||
|
||||
def copy_to_persistent(dst, src):
    """Copy ``src`` into ``dst``'s storage and return the re-strided view.

    ``dst`` is viewed with ``src``'s shape and strides via ``as_strided``,
    then ``src`` is copied into that view in place.

    Args:
        dst: Pre-allocated tensor whose storage receives the data.
        src: Tensor providing the values, shape and strides.

    Returns:
        The view of ``dst`` that has ``src``'s shape/strides and now holds
        ``src``'s values.

    Raises:
        RuntimeError: If ``dst`` cannot be re-strided to ``src``'s layout;
            the original error is chained as the cause.
    """
    try:
        persistent_view = dst.as_strided(src.shape, src.stride())
    except RuntimeError as e:
        raise RuntimeError(
            f"Fail to re-stride a persistent tensor of shape {dst.shape} "
            f"for a tensor of shape {src.shape}"
        ) from e
    else:
        # In-place copy: the data lands in dst's existing storage.
        persistent_view.copy_(src)
        return persistent_view
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlexAttentionMetadata:
|
||||
causal: bool
|
||||
@@ -340,6 +353,9 @@ class FlexAttentionMetadata:
|
||||
physical_to_logical: torch.Tensor
|
||||
decode_offset: torch.Tensor
|
||||
num_blocks_per_seq: torch.Tensor
|
||||
persistent_kv_indices: torch.Tensor
|
||||
persistent_kv_num_blocks: torch.Tensor
|
||||
persistent_doc_ids: torch.Tensor
|
||||
|
||||
# For logging.
|
||||
num_input_tokens: int = 0 # Number of tokens including padding.
|
||||
@@ -656,8 +672,11 @@ class FlexAttentionMetadata:
|
||||
kv_indices = unique_static_unsorted(
|
||||
(used_pages_padded.long()), M=self.num_blocks
|
||||
).to(torch.int32)
|
||||
kv_indices = copy_to_persistent(self.persistent_kv_indices, kv_indices)
|
||||
|
||||
kv_num_blocks = (kv_indices >= 0).sum(dim=-1).to(torch.int32)
|
||||
kv_num_blocks = copy_to_persistent(self.persistent_kv_num_blocks, kv_num_blocks)
|
||||
|
||||
block_mask_kwargs = {
|
||||
"seq_lengths": (self.num_actual_tokens, self.total_cache_tokens),
|
||||
"kv_num_blocks": kv_num_blocks[None, None],
|
||||
@@ -694,6 +713,7 @@ class FlexAttentionMetadata:
|
||||
assert self.suffix_kv_lens is None, "Not implemented yet."
|
||||
# Create a lookup mapping from query indices -> request number
|
||||
self.doc_ids = _offsets_to_doc_ids_tensor(self.query_start_loc)
|
||||
self.doc_ids = copy_to_persistent(self.persistent_doc_ids, self.doc_ids)
|
||||
self.num_blocks = self.total_cache_tokens // self.block_size
|
||||
|
||||
self.mask_mod = self.get_mask_mod()
|
||||
@@ -701,6 +721,8 @@ class FlexAttentionMetadata:
|
||||
|
||||
|
||||
class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadata]):
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
@@ -726,6 +748,38 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
|
||||
self.q_block_size: int = 16 if supports_small_blocks else 128
|
||||
self.kv_block_size: int = self.block_size if supports_small_blocks else 128
|
||||
|
||||
self.max_model_len = self.model_config.max_model_len
|
||||
max_num_seqs = vllm_config.scheduler_config.max_num_seqs
|
||||
max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
|
||||
self.max_num_q_block = (
|
||||
self.max_model_len + self.q_block_size - 1
|
||||
) // self.q_block_size
|
||||
self.persistent_kv_num_blocks = torch.empty(
|
||||
self.max_num_q_block, dtype=torch.int32, device=device
|
||||
)
|
||||
self.persistent_offset_tensor = torch.empty(
|
||||
max_num_seqs, dtype=torch.int32, device=device
|
||||
)
|
||||
self.persistent_doc_ids = torch.empty(
|
||||
max_num_batched_tokens, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
# initialize later when we can access block_table
|
||||
self.persistent_physical_to_logical = None
|
||||
self.persistent_kv_indices = None
|
||||
|
||||
def build_for_cudagraph_capture(
    self, common_attn_metadata: CommonAttentionMetadata
) -> FlexAttentionMetadata:
    """Build attention metadata for CUDA graph capture.

    Overwrites ``common_attn_metadata.max_seq_len`` in place with the
    largest value in ``seq_lens_cpu`` and then delegates to ``build`` with
    ``common_prefix_len=0``.
    """
    # Use actual max_seq_len instead of max_model_len to avoid
    # torch.compile recompilation during CUDA graph capture.
    longest_seq_len = common_attn_metadata.seq_lens_cpu.max().item()
    common_attn_metadata.max_seq_len = longest_seq_len

    capture_metadata = self.build(
        common_prefix_len=0,
        common_attn_metadata=common_attn_metadata,
    )
    return capture_metadata
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
@@ -765,8 +819,32 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
|
||||
inverse_block_table = physical_to_logical_mapping(
|
||||
block_table_tensor, seq_lens, block_size, num_gpu_blocks
|
||||
)
|
||||
if self.persistent_physical_to_logical is None:
|
||||
max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
|
||||
self.persistent_physical_to_logical = torch.empty(
|
||||
max_num_seqs,
|
||||
num_gpu_blocks,
|
||||
dtype=torch.long,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
if self.persistent_kv_indices is None:
|
||||
max_num_kv_block = (
|
||||
self.max_model_len + self.kv_block_size - 1
|
||||
) // self.kv_block_size
|
||||
self.persistent_kv_indices = torch.empty(
|
||||
self.max_model_len,
|
||||
max_num_kv_block,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
inverse_block_table = copy_to_persistent(
|
||||
self.persistent_physical_to_logical, inverse_block_table
|
||||
)
|
||||
|
||||
offset_tensor = common_attn_metadata.compute_num_computed_tokens()
|
||||
offset_tensor = copy_to_persistent(self.persistent_offset_tensor, offset_tensor)
|
||||
|
||||
out = FlexAttentionMetadata(
|
||||
causal=common_attn_metadata.causal,
|
||||
@@ -795,7 +873,20 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
|
||||
direct_build=(self.direct_build and common_attn_metadata.causal),
|
||||
q_block_size=self.q_block_size,
|
||||
kv_block_size=self.kv_block_size,
|
||||
persistent_kv_indices=self.persistent_kv_indices,
|
||||
persistent_kv_num_blocks=self.persistent_kv_num_blocks,
|
||||
persistent_doc_ids=self.persistent_doc_ids,
|
||||
)
|
||||
|
||||
# Pre-build block_mask so it is ready before CUDA graph capture.
|
||||
# Without this, the lazy build in forward() would run non-graph-safe
|
||||
# ops (e.g. torch.nonzero) inside capture.
|
||||
if out.block_mask is None:
|
||||
if out.direct_build:
|
||||
out.block_mask = out._build_block_mask_direct()
|
||||
else:
|
||||
out.block_mask = out.build_block_mask()
|
||||
|
||||
return out
|
||||
|
||||
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
||||
|
||||
@@ -6077,6 +6077,7 @@ class GPUModelRunner(
|
||||
skip_eplb=True,
|
||||
remove_lora=False,
|
||||
num_active_loras=desc.num_active_loras,
|
||||
profile_seq_lens=profile_seq_lens,
|
||||
)
|
||||
self._dummy_run(
|
||||
desc.num_tokens,
|
||||
|
||||
Reference in New Issue
Block a user