[torch.compile] directly register custom op (#9896)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao
Date: 2024-10-31 21:56:09 -07:00
Committed by: GitHub
Parent: 031a7995f3
Commit: 96e0c9cbbd

9 changed files with 192 additions and 67 deletions


@@ -7,6 +7,7 @@ import torch
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
 from vllm.forward_context import get_forward_context
+from vllm.utils import direct_register_custom_op
 from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -152,8 +153,6 @@ class FlashAttentionImpl(AttentionImpl):
         return output
-@torch.library.custom_op("vllm::unified_flash_attention",
-                         mutates_args=["kv_cache"])
 def unified_flash_attention(
     query: torch.Tensor,
     key: torch.Tensor,
@@ -217,8 +216,7 @@ def unified_flash_attention(
     return output.view(num_tokens, hidden_size)
-@unified_flash_attention.register_fake
-def _(
+def unified_flash_attention_fake(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
@@ -235,3 +233,11 @@ def _(
     logits_soft_cap: Optional[float] = None,
 ) -> torch.Tensor:
     return torch.empty_like(query)
+direct_register_custom_op(
+    op_name="unified_flash_attention",
+    op_func=unified_flash_attention,
+    mutates_args=["kv_cache"],
+    fake_impl=unified_flash_attention_fake,
+)
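
The change above swaps the @torch.library.custom_op decorator for a direct_register_custom_op helper imported from vllm.utils. The helper's body is not part of this hunk; the following is only a minimal sketch of what direct registration on top of torch.library.Library generally looks like, where the vllm_lib handle, the schema inference, and the CUDA dispatch key are assumptions rather than vLLM's exact implementation:

from typing import Callable, List, Optional

import torch
from torch.library import Library

# Hypothetical library handle: "FRAGMENT" adds ops to the "vllm" namespace
# without claiming exclusive ownership of it.
vllm_lib = Library("vllm", "FRAGMENT")


def direct_register_custom_op(
    op_name: str,
    op_func: Callable,
    mutates_args: List[str],
    fake_impl: Optional[Callable] = None,
) -> None:
    # Build the op schema, e.g. "(Tensor query, Tensor key, ...) -> Tensor",
    # from op_func's type annotations (requires a recent PyTorch).
    schema = torch.library.infer_schema(op_func, mutates_args=mutates_args)
    vllm_lib.define(op_name + schema)
    # Register the eager implementation for CUDA tensors.
    vllm_lib.impl(op_name, op_func, "CUDA")
    if fake_impl is not None:
        # The fake (meta) implementation gives torch.compile a shape-only
        # version of the op, so graphs can be traced without running the kernel.
        torch.library.register_fake(f"vllm::{op_name}", fake_impl)

Either way the op stays reachable through the dispatcher, so call sites keep invoking torch.ops.vllm.unified_flash_attention(...) unchanged.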