[torch.compile] directly register custom op (#9896)
Signed-off-by: youkaichao <youkaichao@gmail.com>
@@ -7,6 +7,7 @@ import torch
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
 from vllm.forward_context import get_forward_context
+from vllm.utils import direct_register_custom_op
 from vllm.vllm_flash_attn import flash_attn_varlen_func
 
 
@@ -152,8 +153,6 @@ class FlashAttentionImpl(AttentionImpl):
         return output
 
 
-@torch.library.custom_op("vllm::unified_flash_attention",
-                         mutates_args=["kv_cache"])
 def unified_flash_attention(
     query: torch.Tensor,
     key: torch.Tensor,
@@ -217,8 +216,7 @@ def unified_flash_attention(
     return output.view(num_tokens, hidden_size)
 
 
-@unified_flash_attention.register_fake
-def _(
+def unified_flash_attention_fake(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
@@ -235,3 +233,11 @@ def _(
     logits_soft_cap: Optional[float] = None,
 ) -> torch.Tensor:
     return torch.empty_like(query)
+
+
+direct_register_custom_op(
+    op_name="unified_flash_attention",
+    op_func=unified_flash_attention,
+    mutates_args=["kv_cache"],
+    fake_impl=unified_flash_attention_fake,
+)
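
Net effect of the diff: the `@torch.library.custom_op` decorator and its `@unified_flash_attention.register_fake` companion are replaced by a single `direct_register_custom_op(...)` call from `vllm.utils`, which registers the real kernel and its fake (meta) implementation in one step. As a rough illustration of what such a direct-registration helper can look like, here is a minimal sketch built on `torch.library.Library`. It is not vLLM's exact implementation; it assumes PyTorch >= 2.5 (where `torch.library.infer_schema` is public), and the `vllm_demo` namespace and `scaled_add` toy op are hypothetical.

import torch
from torch.library import Library

# Hypothetical namespace for this sketch; vLLM registers its ops under "vllm".
demo_lib = Library("vllm_demo", "FRAGMENT")

def direct_register_custom_op(op_name, op_func, mutates_args,
                              fake_impl=None, dispatch_key="CUDA"):
    # Derive a schema string such as "(Tensor x, Tensor(a!) buf) -> Tensor"
    # from op_func's type annotations; mutates_args marks in-place arguments.
    schema = torch.library.infer_schema(op_func, mutates_args=mutates_args)
    demo_lib.define(op_name + schema)
    # Bind the real kernel to one dispatch key (CUDA here, as in the diff).
    demo_lib.impl(op_name, op_func, dispatch_key)
    if fake_impl is not None:
        # The fake impl lets torch.compile trace shapes/dtypes without
        # running the real kernel.
        torch.library.register_fake(f"vllm_demo::{op_name}", fake_impl)

# Toy op: mutates `buf` in place and returns a fresh tensor, mirroring how
# unified_flash_attention mutates kv_cache but returns a new output.
def scaled_add(x: torch.Tensor, buf: torch.Tensor) -> torch.Tensor:
    buf.add_(1.0)
    return x * 2.0

def scaled_add_fake(x: torch.Tensor, buf: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(x)

direct_register_custom_op("scaled_add", scaled_add,
                          mutates_args=["buf"],
                          fake_impl=scaled_add_fake)
# Afterwards the op is callable as torch.ops.vllm_demo.scaled_add(x, buf),
# just as the diff's op is invoked as torch.ops.vllm.unified_flash_attention.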