commit d58268c56a
parent 87bd7e0515
Author: Joe Runde <Joseph.Runde@ibm.com>
Date:   2024-11-06 12:57:35 -07:00
Committer: GitHub

    [V1] Make v1 more testable (#9888)

    Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>

 75 files changed, 243 insertions(+), 165 deletions(-)


@@ -136,7 +136,7 @@ class FlashAttentionImpl(AttentionImpl):
                 "key/v_scale is not supported in FlashAttention.")

         output = torch.empty_like(query)
-        torch.ops.vllm.unified_flash_attention(
+        torch.ops.vllm.unified_v1_flash_attention(
             output,
             query,
             key,
@@ -156,7 +156,7 @@ class FlashAttentionImpl(AttentionImpl):
         return output


-def unified_flash_attention(
+def unified_v1_flash_attention(
     output: torch.Tensor,
     query: torch.Tensor,
     key: torch.Tensor,
@@ -222,7 +222,7 @@ def unified_flash_attention(
     output[:num_actual_tokens].copy_(attn_output)


-def unified_flash_attention_fake(
+def unified_v1_flash_attention_fake(
     output: torch.Tensor,
     query: torch.Tensor,
     key: torch.Tensor,
@@ -243,8 +243,8 @@ def unified_flash_attention_fake(


 direct_register_custom_op(
-    op_name="unified_flash_attention",
-    op_func=unified_flash_attention,
+    op_name="unified_v1_flash_attention",
+    op_func=unified_v1_flash_attention,
     mutates_args=["kv_cache", "output"],
-    fake_impl=unified_flash_attention_fake,
+    fake_impl=unified_v1_flash_attention_fake,
 )
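
Torch custom ops live in a process-wide registry keyed by namespace and op name, so giving the V1 kernel its own name (`unified_v1_flash_attention`) presumably lets V0 and V1 both register their ops when imported into the same test process. Below is a minimal sketch of the collision this avoids; it uses a hypothetical `demo` namespace and a trivial add-one kernel rather than vLLM's `direct_register_custom_op` helper or the real attention op:

import torch
from torch.library import Library

# Hypothetical namespace for illustration; vLLM registers under "vllm".
lib = Library("demo", "FRAGMENT")
lib.define("unified_flash_attention(Tensor x) -> Tensor")


def _add_one(x: torch.Tensor) -> torch.Tensor:
    # Stand-in kernel; the real op runs flash attention and mutates `output`.
    return x + 1


lib.impl("unified_flash_attention", _add_one, "CompositeExplicitAutograd")

# Registered ops are reached the same way the diff above calls the V1 op:
# torch.ops.<namespace>.<op_name>(...)
print(torch.ops.demo.unified_flash_attention(torch.zeros(2)))

# Defining the same schema a second time raises, which is what would happen
# if V0 and V1 both registered an op named "unified_flash_attention".
try:
    lib.define("unified_flash_attention(Tensor x) -> Tensor")
except RuntimeError as err:
    print(f"duplicate op registration rejected: {err}")

The `fake_impl` passed to `direct_register_custom_op` in the last hunk plays the usual role of a fake/meta implementation: it lets tracing tools such as torch.compile reason about the op's output shapes without executing the real kernel, so it is renamed in lockstep with the op itself.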