[Misc] Support attention logits soft-capping with flash-attn (#7022)

Woosuk Kwon
2024-08-01 13:14:37 -07:00
committed by GitHub
parent 562e580abc
commit 805a8a75f2
14 changed files with 71 additions and 47 deletions

@@ -20,6 +20,7 @@ def ref_paged_attn(
block_tables: torch.Tensor,
scale: float,
sliding_window: Optional[int] = None,
soft_cap: Optional[float] = None,
) -> torch.Tensor:
num_seqs = len(query_lens)
block_tables = block_tables.cpu().numpy()
@@ -53,6 +54,8 @@ def ref_paged_attn(
(query_len + sliding_window) +
1).bool().logical_not()
mask |= sliding_window_mask
if soft_cap is not None:
attn = soft_cap * torch.tanh(attn / soft_cap)
attn.masked_fill_(mask, float("-inf"))
attn = torch.softmax(attn, dim=-1).to(v.dtype)
out = torch.einsum("hqk,khd->qhd", attn, v)
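
Not part of the diff: a minimal standalone sketch of the soft-cap applied in ref_paged_attn above. The logits are passed through tanh so every value lands strictly inside (-soft_cap, soft_cap) before masking and softmax; the shapes and cap value below are illustrative only.

import torch

soft_cap = 50.0
logits = torch.randn(4, 16, 16) * 100.0            # raw attention logits, many far outside the cap
capped = soft_cap * torch.tanh(logits / soft_cap)   # same formula as the line added above
assert capped.abs().max() < soft_cap                 # tanh keeps every logit inside (-soft_cap, soft_cap)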
@@ -68,13 +71,15 @@ def ref_paged_attn(
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@torch.inference_mode()
def test_flash_attn_with_paged_kv(
kv_lens: List[int],
num_heads: Tuple[int, int],
head_size: int,
dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
@@ -108,6 +113,7 @@ def test_flash_attn_with_paged_kv(
causal=True,
block_table=block_tables,
cache_seqlens=kv_lens_tensor,
softcap=soft_cap if soft_cap is not None else 0,
).squeeze(1)
ref_output = ref_paged_attn(
@@ -118,6 +124,7 @@ def test_flash_attn_with_paged_kv(
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap,
)
assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"
@@ -129,7 +136,8 @@ def test_flash_attn_with_paged_kv(
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("sliding_window", [None])
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@torch.inference_mode()
def test_varlen_with_paged_kv(
seq_lens: List[Tuple[int, int]],
num_heads: Tuple[int, int],
@@ -137,6 +145,7 @@ def test_varlen_with_paged_kv(
sliding_window: Optional[int],
dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
@@ -163,10 +172,6 @@ def test_varlen_with_paged_kv(
head_size,
dtype=dtype)
value_cache = torch.randn_like(key_cache)
# Normalize the scale of the key and value caches to mitigate
# numerical instability.
key_cache /= head_size**0.5
value_cache /= head_size**0.5
cu_query_lens = torch.tensor([0] + query_lens,
dtype=torch.int32).cumsum(dim=0,
dtype=torch.int32)
@@ -192,6 +197,7 @@ def test_varlen_with_paged_kv(
causal=True,
window_size=window_size,
block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0,
)
ref_output = ref_paged_attn(
@@ -203,6 +209,7 @@ def test_varlen_with_paged_kv(
block_tables=block_tables,
scale=scale,
sliding_window=sliding_window,
soft_cap=soft_cap,
)
assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"