[Misc] Pass attention to impl backend (#12218)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
@@ -130,13 +130,12 @@ class FlashAttentionImpl(AttentionImpl):
     def forward(
         self,
+        layer: torch.nn.Module,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: FlashAttentionMetadata,
-        k_scale: float = 1.0,
-        v_scale: float = 1.0,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.
@@ -151,7 +150,7 @@ class FlashAttentionImpl(AttentionImpl):
             shape = [num_tokens, num_heads * head_size]
         """
         # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
-        assert k_scale == 1.0 and v_scale == 1.0, (
+        assert layer._k_scale == 1.0 and layer._v_scale == 1.0, (
             "key/v_scale is not supported in FlashAttention.")

         assert output is not None, "Output tensor must be provided."
@@ -183,8 +182,8 @@ class FlashAttentionImpl(AttentionImpl):
             value_cache,
             attn_metadata.slot_mapping,
             self.kv_cache_dtype,
-            k_scale,
-            v_scale,
+            layer._k_scale,
+            layer._v_scale,
         )

         # Compute attention and update output up to `num_actual_tokens`.
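For context, the sketch below is a minimal, self-contained illustration (not the actual vLLM classes) of the pattern this change introduces: the attention layer module is passed into the backend impl's forward(), so per-layer state such as _k_scale and _v_scale is read from the layer object instead of being threaded through as extra float arguments. ToyAttentionLayer and ToyFlashAttentionImpl are hypothetical stand-ins, and the real FlashAttention kernel call is replaced with a naive softmax-attention placeholder.

    # Minimal sketch of the "pass the attention layer to the impl backend" pattern.
    # Assumes hypothetical Toy* classes; not the real vLLM implementation.
    from typing import Optional

    import torch


    class ToyAttentionLayer(torch.nn.Module):
        """Hypothetical stand-in for the attention layer that owns the KV-cache scales."""

        def __init__(self, k_scale: float = 1.0, v_scale: float = 1.0) -> None:
            super().__init__()
            self._k_scale = k_scale
            self._v_scale = v_scale


    class ToyFlashAttentionImpl:
        """Hypothetical backend impl mirroring the new forward() signature."""

        def forward(
            self,
            layer: torch.nn.Module,
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            output: Optional[torch.Tensor] = None,
        ) -> torch.Tensor:
            # Scales are now read from the layer, not from forward() arguments.
            assert layer._k_scale == 1.0 and layer._v_scale == 1.0, (
                "key/v_scale is not supported in FlashAttention.")
            assert output is not None, "Output tensor must be provided."
            # Placeholder for the real FlashAttention kernel call.
            attn = torch.softmax(query @ key.transpose(-1, -2), dim=-1)
            output.copy_(attn @ value)
            return output


    if __name__ == "__main__":
        layer = ToyAttentionLayer()
        impl = ToyFlashAttentionImpl()
        q = k = v = torch.randn(2, 4, 8)
        out = torch.empty_like(q)
        impl.forward(layer, q, k, v, output=out)
        print(out.shape)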