[V1] Integrate Piecewise CUDA graphs (#10058)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
@@ -51,6 +51,7 @@ class FlashAttentionMetadata:
     # |-------------------- seq_len ---------------------|
     #                                   |-- query_len ---|
 
+    num_actual_tokens: int  # Number of tokens excluding padding.
     max_query_len: int
     query_start_loc: torch.Tensor
     max_seq_len: int
@@ -134,7 +135,9 @@ class FlashAttentionImpl(AttentionImpl):
         assert k_scale == 1.0 and v_scale == 1.0, (
             "key/v_scale is not supported in FlashAttention.")
 
-        output = torch.ops.vllm.unified_flash_attention(
+        output = torch.empty_like(query)
+        torch.ops.vllm.unified_flash_attention(
+            output,
             query,
             key,
             value,
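With this change the caller allocates the output buffer up front and the custom op writes into it in place instead of returning a new tensor. Keeping the output at a fixed address is what makes the op safe to capture and replay inside a CUDA graph. A minimal sketch of that idea (simplified and hypothetical, not vLLM's actual capture logic):

import torch

def my_inplace_op(out: torch.Tensor, x: torch.Tensor) -> None:
    # Hypothetical op that writes its result into a caller-provided buffer.
    out.copy_(x * 2)

x = torch.zeros(8, device="cuda")
out = torch.empty_like(x)  # fixed-address buffer reused across replays

# Warm up on a side stream before capture (recommended by the PyTorch docs).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    my_inplace_op(out, x)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    my_inplace_op(out, x)   # captured once

x.copy_(torch.arange(8, device="cuda", dtype=x.dtype))
g.replay()                  # replayed with new data in the same buffers
print(out)                  # doubled values, read from the pre-allocated buffer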
@@ -154,6 +157,7 @@ class FlashAttentionImpl(AttentionImpl):
 
 
 def unified_flash_attention(
+    output: torch.Tensor,
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
@@ -168,17 +172,17 @@ def unified_flash_attention(
     window_size: Optional[List[int]] = None,
     alibi_slopes: Optional[torch.Tensor] = None,
     logits_soft_cap: Optional[float] = None,
-) -> torch.Tensor:
+) -> None:
     current_metadata = get_forward_context()
     if current_metadata is None:
         # Profiling run.
-        return torch.empty_like(query)
+        return
 
     assert current_metadata is not None
     assert isinstance(current_metadata, FlashAttentionMetadata)
     attn_metadata: FlashAttentionMetadata = current_metadata
+    num_actual_tokens = attn_metadata.num_actual_tokens
 
-    num_tokens, hidden_size = query.shape
     # Reshape the query, key, and value tensors.
     query = query.view(-1, num_heads, head_size)
     key = key.view(-1, num_kv_heads, head_size)
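num_actual_tokens exists because piecewise CUDA graphs are captured for a fixed set of batch sizes: at run time the token batch is padded up to the nearest captured size, and only the first num_actual_tokens rows of the padded buffers carry real data. A hedged sketch of that padding arithmetic (the bucket sizes below are illustrative, not vLLM's actual capture sizes):

# Illustrative capture sizes; the real list is configuration-dependent.
CAPTURED_SIZES = [8, 16, 32, 64]

def padded_num_tokens(num_actual_tokens: int) -> int:
    # Pick the smallest captured batch size that fits the real tokens,
    # falling back to the actual count for anything larger.
    for size in CAPTURED_SIZES:
        if num_actual_tokens <= size:
            return size
    return num_actual_tokens

num_actual_tokens = 11
num_padded_tokens = padded_num_tokens(num_actual_tokens)  # -> 16

# Inside the attention op the buffers have the padded length, so the kernel
# reads and writes only the leading slice that holds real data, e.g.
#   query[:num_actual_tokens], key[:num_actual_tokens], value[:num_actual_tokens]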
@@ -188,18 +192,18 @@ def unified_flash_attention(
     key_cache = kv_cache[0]
     value_cache = kv_cache[1]
     torch.ops._C_cache_ops.reshape_and_cache_flash(
-        key,
-        value,
-        kv_cache[0],
-        kv_cache[1],
+        key[:num_actual_tokens],
+        value[:num_actual_tokens],
+        key_cache,
+        value_cache,
         attn_metadata.slot_mapping,
         kv_cache_dtype,
         k_scale,
         v_scale,
     )
 
-    output = flash_attn_varlen_func(
-        q=query,
+    attn_output = flash_attn_varlen_func(
+        q=query[:num_actual_tokens],
         k=key_cache,
         v=value_cache,
         cu_seqlens_q=attn_metadata.query_start_loc,
@@ -213,10 +217,13 @@ def unified_flash_attention(
         block_table=attn_metadata.block_table,
         softcap=logits_soft_cap,
     )
-    return output.view(num_tokens, hidden_size)
+    attn_output = attn_output.view(num_actual_tokens, -1)
+    # TODO(woosuk): Optimize this.
+    output[:num_actual_tokens].copy_(attn_output)
 
 
 def unified_flash_attention_fake(
+    output: torch.Tensor,
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
@@ -231,13 +238,13 @@ def unified_flash_attention_fake(
     window_size: Optional[List[int]] = None,
     alibi_slopes: Optional[torch.Tensor] = None,
     logits_soft_cap: Optional[float] = None,
-) -> torch.Tensor:
-    return torch.empty_like(query)
+) -> None:
+    return
 
 
 direct_register_custom_op(
     op_name="unified_flash_attention",
     op_func=unified_flash_attention,
-    mutates_args=["kv_cache"],
+    mutates_args=["kv_cache", "output"],
     fake_impl=unified_flash_attention_fake,
 )
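Listing "output" in mutates_args tells PyTorch that the custom op writes into that argument rather than returning a tensor, so torch.compile and CUDA-graph capture treat the call as an in-place mutation, and the fake (meta) implementation can simply return nothing. direct_register_custom_op is vLLM's own helper; a rough sketch of the same pattern using plain torch.library.custom_op (PyTorch 2.4+), with a made-up op for illustration:

import torch

# Hypothetical op registered as an in-place (destination-passing) custom op.
@torch.library.custom_op("demo::scaled_copy", mutates_args=("output",))
def scaled_copy(output: torch.Tensor, x: torch.Tensor, scale: float) -> None:
    # Real implementation: write the result into the caller-provided buffer.
    output.copy_(x * scale)

@scaled_copy.register_fake
def _(output: torch.Tensor, x: torch.Tensor, scale: float) -> None:
    # Fake/meta implementation: nothing to allocate or return, because the
    # output buffer is supplied (and mutated) by the caller.
    return None

x = torch.randn(4)
out = torch.empty_like(x)
torch.ops.demo.scaled_copy(out, x, 2.0)   # out now holds 2 * x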