[Misc] Support attention logits soft-capping with flash-attn (#7022)
This commit is contained in:
@@ -20,6 +20,7 @@ def ref_paged_attn(
|
||||
block_tables: torch.Tensor,
|
||||
scale: float,
|
||||
sliding_window: Optional[int] = None,
|
||||
soft_cap: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
num_seqs = len(query_lens)
|
||||
block_tables = block_tables.cpu().numpy()
|
||||
@@ -53,6 +54,8 @@ def ref_paged_attn(
|
||||
(query_len + sliding_window) +
|
||||
1).bool().logical_not()
|
||||
mask |= sliding_window_mask
|
||||
if soft_cap is not None:
|
||||
attn = soft_cap * torch.tanh(attn / soft_cap)
|
||||
attn.masked_fill_(mask, float("-inf"))
|
||||
attn = torch.softmax(attn, dim=-1).to(v.dtype)
|
||||
out = torch.einsum("hqk,khd->qhd", attn, v)
|
||||
@@ -68,13 +71,15 @@ def ref_paged_attn(
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@torch.inference_mode
|
||||
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
|
||||
@torch.inference_mode()
|
||||
def test_flash_attn_with_paged_kv(
|
||||
kv_lens: List[int],
|
||||
num_heads: Tuple[int, int],
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: Optional[float],
|
||||
) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
torch.cuda.manual_seed_all(0)
|
||||
@@ -108,6 +113,7 @@ def test_flash_attn_with_paged_kv(
|
||||
causal=True,
|
||||
block_table=block_tables,
|
||||
cache_seqlens=kv_lens_tensor,
|
||||
softcap=soft_cap if soft_cap is not None else 0,
|
||||
).squeeze(1)
|
||||
|
||||
ref_output = ref_paged_attn(
|
||||
@@ -118,6 +124,7 @@ def test_flash_attn_with_paged_kv(
|
||||
kv_lens=kv_lens,
|
||||
block_tables=block_tables,
|
||||
scale=scale,
|
||||
soft_cap=soft_cap,
|
||||
)
|
||||
assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
|
||||
f"{torch.max(torch.abs(output - ref_output))}"
|
||||
@@ -129,7 +136,8 @@ def test_flash_attn_with_paged_kv(
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("sliding_window", [None])
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@torch.inference_mode
|
||||
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
|
||||
@torch.inference_mode()
|
||||
def test_varlen_with_paged_kv(
|
||||
seq_lens: List[Tuple[int, int]],
|
||||
num_heads: Tuple[int, int],
|
||||
@@ -137,6 +145,7 @@ def test_varlen_with_paged_kv(
|
||||
sliding_window: Optional[int],
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: Optional[float],
|
||||
) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
torch.cuda.manual_seed_all(0)
|
||||
@@ -163,10 +172,6 @@ def test_varlen_with_paged_kv(
|
||||
head_size,
|
||||
dtype=dtype)
|
||||
value_cache = torch.randn_like(key_cache)
|
||||
# Normalize the scale of the key and value caches to mitigate
|
||||
# numerical instability.
|
||||
key_cache /= head_size**0.5
|
||||
value_cache /= head_size**0.5
|
||||
cu_query_lens = torch.tensor([0] + query_lens,
|
||||
dtype=torch.int32).cumsum(dim=0,
|
||||
dtype=torch.int32)
|
||||
@@ -192,6 +197,7 @@ def test_varlen_with_paged_kv(
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
block_table=block_tables,
|
||||
softcap=soft_cap if soft_cap is not None else 0,
|
||||
)
|
||||
|
||||
ref_output = ref_paged_attn(
|
||||
@@ -203,6 +209,7 @@ def test_varlen_with_paged_kv(
|
||||
block_tables=block_tables,
|
||||
scale=scale,
|
||||
sliding_window=sliding_window,
|
||||
soft_cap=soft_cap,
|
||||
)
|
||||
assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
|
||||
f"{torch.max(torch.abs(output - ref_output))}"
|
||||
|
||||
Reference in New Issue
Block a user