[V1] Make v1 more testable (#9888)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
@@ -136,7 +136,7 @@ class FlashAttentionImpl(AttentionImpl):
             "key/v_scale is not supported in FlashAttention.")

         output = torch.empty_like(query)
-        torch.ops.vllm.unified_flash_attention(
+        torch.ops.vllm.unified_v1_flash_attention(
             output,
             query,
             key,
@@ -156,7 +156,7 @@ class FlashAttentionImpl(AttentionImpl):
         return output


-def unified_flash_attention(
+def unified_v1_flash_attention(
     output: torch.Tensor,
     query: torch.Tensor,
     key: torch.Tensor,
@@ -222,7 +222,7 @@ def unified_flash_attention(
     output[:num_actual_tokens].copy_(attn_output)


-def unified_flash_attention_fake(
+def unified_v1_flash_attention_fake(
     output: torch.Tensor,
     query: torch.Tensor,
     key: torch.Tensor,
@@ -243,8 +243,8 @@ def unified_flash_attention_fake(


 direct_register_custom_op(
-    op_name="unified_flash_attention",
-    op_func=unified_flash_attention,
+    op_name="unified_v1_flash_attention",
+    op_func=unified_v1_flash_attention,
     mutates_args=["kv_cache", "output"],
-    fake_impl=unified_flash_attention_fake,
+    fake_impl=unified_v1_flash_attention_fake,
)
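For context on the last hunk: the V1 attention kernel is exposed as a PyTorch custom op, registered together with a fake implementation so tracing and torch.compile can handle it without running the real kernel, and this commit gives the V1 op its own name (unified_v1_flash_attention) instead of reusing the V0 op name. Below is a minimal, self-contained sketch of that register-real-impl-plus-fake-impl pattern using plain torch.library.custom_op (PyTorch 2.4+) rather than vLLM's direct_register_custom_op helper; the "demo" namespace, op name, and toy body are illustrative assumptions, not vLLM code.

import torch

# Hypothetical toy op illustrating the pattern from the diff; the "demo"
# namespace and the op body are placeholders, not vLLM code.
@torch.library.custom_op("demo::unified_v1_toy_attention", mutates_args=("output",))
def unified_v1_toy_attention(output: torch.Tensor, query: torch.Tensor,
                             key: torch.Tensor) -> None:
    # Real implementation: write the result into the caller-provided buffer,
    # mirroring how the vLLM op fills `output` in place.
    output.copy_(query + key)

@unified_v1_toy_attention.register_fake
def _(output: torch.Tensor, query: torch.Tensor, key: torch.Tensor) -> None:
    # Fake implementation: no computation, only the shape/dtype contract,
    # which is all that tracing and torch.compile need.
    return None

q = torch.randn(4, 8)
k = torch.randn(4, 8)
out = torch.empty_like(q)
# The op is now reachable under its own name, analogous to
# torch.ops.vllm.unified_v1_flash_attention after this commit.
torch.ops.demo.unified_v1_toy_attention(out, q, k)
assert torch.allclose(out, q + k)

Giving the V1 backend a distinct op name means its registration and tests no longer collide with the V0 unified_flash_attention op, which is the testability improvement the commit title refers to.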