[V1] Integrate Piecewise CUDA graphs (#10058)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Author: Woosuk Kwon
Date: 2024-11-05 22:16:04 -08:00
Committed by: GitHub
Parent: 9d59b75593
Commit: 4089985552
3 changed files with 133 additions and 36 deletions
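
The heart of this change is the calling convention of `unified_flash_attention`: instead of allocating and returning its result, the op now receives a pre-allocated `output` buffer and writes into it in place. That is what makes it safe to build (piecewise) CUDA graphs around the op, since a captured graph replays fixed kernel launches against fixed memory addresses and cannot depend on a fresh allocation per call. Below is a minimal sketch of the static-buffer pattern with `torch.cuda.CUDAGraph`; it is illustrative only, not vLLM code, and it assumes a CUDA device is available.

```python
import torch

def attention_like_op(output: torch.Tensor, query: torch.Tensor) -> None:
    # Stand-in for unified_flash_attention: writes into a caller-provided
    # buffer instead of returning a freshly allocated tensor.
    output.copy_(torch.nn.functional.relu(query) * 2.0)

assert torch.cuda.is_available(), "CUDA graph capture requires a GPU"
static_query = torch.zeros(16, 128, device="cuda")
static_output = torch.empty_like(static_query)

# Warm up on a side stream, as recommended before graph capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    attention_like_op(static_output, static_query)
torch.cuda.current_stream().wait_stream(s)

# Capture once; replay later by copying new data into the static input buffer.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    attention_like_op(static_output, static_query)

static_query.copy_(torch.randn(16, 128, device="cuda"))
graph.replay()
print(static_output[:2, :4])
```

If the op allocated its own result (as the old `return torch.empty_like(query)` path did), the tensor baked into the capture would be whatever was allocated at capture time, and replays would not land in a buffer the caller controls.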


@@ -51,6 +51,7 @@ class FlashAttentionMetadata:
     # |-------------------- seq_len ---------------------|
     #                                   |-- query_len ---|
 
+    num_actual_tokens: int  # Number of tokens excluding padding.
     max_query_len: int
     query_start_loc: torch.Tensor
     max_seq_len: int
@@ -134,7 +135,9 @@ class FlashAttentionImpl(AttentionImpl):
         assert k_scale == 1.0 and v_scale == 1.0, (
             "key/v_scale is not supported in FlashAttention.")
-        output = torch.ops.vllm.unified_flash_attention(
+        output = torch.empty_like(query)
+        torch.ops.vllm.unified_flash_attention(
+            output,
             query,
             key,
             value,
@@ -154,6 +157,7 @@ class FlashAttentionImpl(AttentionImpl):
 def unified_flash_attention(
+    output: torch.Tensor,
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
@@ -168,17 +172,17 @@ def unified_flash_attention(
     window_size: Optional[List[int]] = None,
     alibi_slopes: Optional[torch.Tensor] = None,
     logits_soft_cap: Optional[float] = None,
-) -> torch.Tensor:
+) -> None:
     current_metadata = get_forward_context()
     if current_metadata is None:
         # Profiling run.
-        return torch.empty_like(query)
+        return
 
     assert current_metadata is not None
     assert isinstance(current_metadata, FlashAttentionMetadata)
     attn_metadata: FlashAttentionMetadata = current_metadata
+    num_actual_tokens = attn_metadata.num_actual_tokens
 
-    num_tokens, hidden_size = query.shape
     # Reshape the query, key, and value tensors.
     query = query.view(-1, num_heads, head_size)
     key = key.view(-1, num_kv_heads, head_size)
@@ -188,18 +192,18 @@ def unified_flash_attention(
     key_cache = kv_cache[0]
     value_cache = kv_cache[1]
     torch.ops._C_cache_ops.reshape_and_cache_flash(
-        key,
-        value,
-        kv_cache[0],
-        kv_cache[1],
+        key[:num_actual_tokens],
+        value[:num_actual_tokens],
+        key_cache,
+        value_cache,
         attn_metadata.slot_mapping,
         kv_cache_dtype,
         k_scale,
         v_scale,
     )
 
-    output = flash_attn_varlen_func(
-        q=query,
+    attn_output = flash_attn_varlen_func(
+        q=query[:num_actual_tokens],
         k=key_cache,
         v=value_cache,
         cu_seqlens_q=attn_metadata.query_start_loc,
@@ -213,10 +217,13 @@ def unified_flash_attention(
         block_table=attn_metadata.block_table,
         softcap=logits_soft_cap,
     )
-    return output.view(num_tokens, hidden_size)
+    attn_output = attn_output.view(num_actual_tokens, -1)
+    # TODO(woosuk): Optimize this.
+    output[:num_actual_tokens].copy_(attn_output)
 
 
 def unified_flash_attention_fake(
+    output: torch.Tensor,
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
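
Because the batch is padded to a fixed, graph-friendly size, only the first `num_actual_tokens` rows of the buffers hold real tokens: the kernel consumes `query[:num_actual_tokens]`, `key[:num_actual_tokens]`, and `value[:num_actual_tokens]`, and the result is copied back into `output[:num_actual_tokens]`, leaving the padded tail untouched. A small CPU-only sketch of the same pattern (illustrative, not vLLM code):

```python
import torch

num_padded_tokens = 16   # fixed buffer size (e.g. a CUDA graph capture size)
num_actual_tokens = 11   # varies per batch, always <= num_padded_tokens
hidden_size = 32

query = torch.randn(num_padded_tokens, hidden_size)
output = torch.empty_like(query)   # pre-allocated, padded output buffer

# Stand-in for flash_attn_varlen_func: computed only over the valid rows.
attn_output = torch.tanh(query[:num_actual_tokens])

# Write results back into the padded buffer; rows past num_actual_tokens keep
# whatever stale values they had and are ignored downstream.
output[:num_actual_tokens].copy_(attn_output)
```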
@@ -231,13 +238,13 @@ def unified_flash_attention_fake(
     window_size: Optional[List[int]] = None,
     alibi_slopes: Optional[torch.Tensor] = None,
     logits_soft_cap: Optional[float] = None,
-) -> torch.Tensor:
-    return torch.empty_like(query)
+) -> None:
+    return
 
 
 direct_register_custom_op(
     op_name="unified_flash_attention",
     op_func=unified_flash_attention,
-    mutates_args=["kv_cache"],
+    mutates_args=["kv_cache", "output"],
     fake_impl=unified_flash_attention_fake,
 )
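
`direct_register_custom_op` is vLLM's wrapper around PyTorch's custom-op registration; listing `output` in `mutates_args` declares that the op writes into that argument, and the fake implementation lets `torch.compile` trace the call without running the real kernel. A rough stand-alone equivalent using plain `torch.library` is sketched below (assumes PyTorch >= 2.4; the op name and body are made up for illustration):

```python
import torch
from torch import Tensor

@torch.library.custom_op("toy::scaled_copy", mutates_args=("output",))
def scaled_copy(output: Tensor, x: Tensor, scale: float) -> None:
    # Real implementation: fill the caller-provided buffer in place.
    output.copy_(x * scale)

@scaled_copy.register_fake
def _(output: Tensor, x: Tensor, scale: float) -> None:
    # Fake (meta) implementation: nothing to allocate or return, mirroring
    # unified_flash_attention_fake above.
    return None

x = torch.randn(4, 8)
out = torch.empty_like(x)
torch.ops.toy.scaled_copy(out, x, 2.0)
assert torch.allclose(out, 2.0 * x)
```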