[Misc] Pass attention to impl backend (#12218)

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Author: wangxiyuan
Date: 2025-01-20 23:25:28 +08:00
Committed by: GitHub
Parent: 5f0ec3935a
Commit: 86bfb6dba7
12 changed files with 86 additions and 78 deletions
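
This change drops the per-call k_scale / v_scale arguments from the backend forward() methods and instead passes the attention layer itself down to the impl backend, which reads layer._k_scale and layer._v_scale; a sketch of the resulting pattern follows the diff below.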


@@ -130,13 +130,12 @@ class FlashAttentionImpl(AttentionImpl):
     def forward(
         self,
+        layer: torch.nn.Module,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: FlashAttentionMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.
@@ -151,7 +150,7 @@ class FlashAttentionImpl(AttentionImpl):
             shape = [num_tokens, num_heads * head_size]
         """
         # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
-        assert k_scale == 1.0 and v_scale == 1.0, (
+        assert layer._k_scale == 1.0 and layer._v_scale == 1.0, (
             "key/v_scale is not supported in FlashAttention.")
         assert output is not None, "Output tensor must be provided."
@@ -183,8 +182,8 @@ class FlashAttentionImpl(AttentionImpl):
                 value_cache,
                 attn_metadata.slot_mapping,
                 self.kv_cache_dtype,
-                k_scale,
-                v_scale,
+                layer._k_scale,
+                layer._v_scale,
             )
         # Compute attention and update output up to `num_actual_tokens`.
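
Before this commit, every backend implementation's forward() took k_scale / v_scale floats; afterwards the attention layer object is passed down and the backend reads layer._k_scale / layer._v_scale from it. Below is a minimal, self-contained sketch of that pattern with hypothetical simplified names (Attention, FlashAttentionLikeImpl); the real vLLM signatures also carry kv_cache, attn_metadata, and an output buffer.

import torch


class Attention(torch.nn.Module):
    """Hypothetical layer object that owns the KV-cache scaling factors."""

    def __init__(self) -> None:
        super().__init__()
        # The scales live on the layer instead of being threaded
        # through every backend forward() call.
        self._k_scale = 1.0
        self._v_scale = 1.0


class FlashAttentionLikeImpl:
    """Hypothetical backend impl that reads the scales off the layer."""

    def forward(self, layer: torch.nn.Module,
                query: torch.Tensor) -> torch.Tensor:
        # Mirrors the diff: FlashAttention-style backends reject
        # FP8 KV-cache scaling, so both scales must be 1.0.
        assert layer._k_scale == 1.0 and layer._v_scale == 1.0, (
            "key/v_scale is not supported in FlashAttention.")
        return query  # attention math elided in this sketch


layer = Attention()
impl = FlashAttentionLikeImpl()
out = impl.forward(layer, torch.randn(4, 8))

One benefit of this shape of API, suggested by the diff itself: adding a new per-layer quantity (such as another quantization scale) no longer requires touching the signature of every backend's forward().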