[Core] Deprecate xformers (#29262)

Signed-off-by: Roger Wang <hey@rogerw.io>

Author: Roger Wang
Date: 2025-11-23 20:18:55 -08:00
Committed by: GitHub
Parent: 5253f4276f
Commit: 0ff70821c9
31 changed files with 77 additions and 963 deletions


@@ -509,43 +509,6 @@ def pack_qkv(qkv: QKVInputs, device: torch.device | str) -> PackedQKVInputs:
     )


-def make_alibi_bias(
-    alibi_slopes: torch.Tensor,
-    num_kv_heads: int,
-    dtype: torch.dtype,
-    seq_lens: list[int],
-) -> list[Any]:
-    """Create ALiBi biases compatible with xFormers attention tests."""
-    from xformers.ops.fmha.attn_bias import LowerTriangularMaskWithTensorBias
-
-    if alibi_slopes is None:
-        return [None for _ in seq_lens]
-
-    attn_biases: list[Any] = []
-    num_heads = alibi_slopes.shape[0]
-    assert num_heads >= num_kv_heads, (
-        "ALiBi slopes expect at least as many heads as KV heads"
-    )
-
-    for seq_len in seq_lens:
-        bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device)
-        bias = bias[None, :] - bias[:, None]
-
-        padded_len = (seq_len + 7) // 8 * 8
-        bias_tensor = torch.empty(
-            1,
-            num_heads,
-            seq_len,
-            padded_len,
-            device=alibi_slopes.device,
-            dtype=dtype,
-        )[:, :, :, :seq_len].copy_(bias)
-        bias_tensor.mul_(alibi_slopes[:, None, None])
-
-        attn_biases.append(LowerTriangularMaskWithTensorBias(bias_tensor))
-
-    return attn_biases
-
-
 def _make_metadata_tensors(
     seq_lens: list[int] | None,
     context_lens: list[int] | None,
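The deleted make_alibi_bias helper built a per-head relative-position bias and handed it to xFormers' LowerTriangularMaskWithTensorBias. For readers porting tests off xformers, below is a minimal sketch of the same ALiBi-plus-causal bias as a dense tensor in plain torch; the function name and defaults are illustrative, not part of this commit or of vLLM.

import torch


def alibi_causal_bias(
    alibi_slopes: torch.Tensor, seq_len: int, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
    """Dense (num_heads, seq_len, seq_len) ALiBi bias with a causal mask."""
    device = alibi_slopes.device
    pos = torch.arange(seq_len, dtype=dtype, device=device)
    # Relative positions: rel[i, j] = j - i, exactly as in the deleted helper.
    rel = pos[None, :] - pos[:, None]
    # One slope per attention head scales the relative-position matrix.
    bias = alibi_slopes.to(dtype)[:, None, None] * rel[None, :, :]
    # Causal mask: positions j > i are disallowed before the softmax.
    causal = torch.triu(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=1
    )
    return bias.masked_fill(causal, float("-inf"))


# Usage: add to raw attention scores of shape (num_heads, seq_len, seq_len)
# before the softmax, e.g. scores = scores + alibi_causal_bias(slopes, seq_len).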
@@ -649,23 +612,12 @@ def make_kv_cache(
     Returns:

-    * kv_cache: 2 x num_blocks x (block_size * num_heads * head_size)
-      * for backend 'XFORMERS'
     * kv_cache: 2 x num_blocks x block_size x num_heads x head_size
       * for backend 'FLASH_ATTN'
     """
-    if backend == "XFORMERS":
-        kv_cache = torch.rand((2, num_blocks, block_size * num_heads * head_size)).to(
-            device
-        )
-    elif backend == "FLASH_ATTN":
-        kv_cache = torch.rand((2, num_blocks, block_size, num_heads, head_size)).to(
-            device
-        )
-    else:
-        raise ValueError(
-            f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or 'FLASH_ATTN'."
-        )
+    if backend != "FLASH_ATTN":
+        raise ValueError(f"Unknown backend value: '{backend}'. Expected 'FLASH_ATTN'.")
+    kv_cache = torch.rand((2, num_blocks, block_size, num_heads, head_size)).to(device)

     if default_val is not None:
         kv_cache[:, :, :] = default_val
     return kv_cache
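With the XFORMERS branch gone, the helper accepts a single cache layout. A small sketch of how a FLASH_ATTN-style cache of that shape is typically consumed in a test; the concrete sizes and variable names below are made up for illustration, not copied from the test suite.

import torch

# Remaining layout: kv_cache is (2, num_blocks, block_size, num_heads, head_size)
num_blocks, block_size, num_heads, head_size = 128, 16, 8, 64
kv_cache = torch.rand((2, num_blocks, block_size, num_heads, head_size))

# Conventionally the key cache sits at index 0 and the value cache at index 1.
key_cache, value_cache = kv_cache.unbind(0)
assert key_cache.shape == (num_blocks, block_size, num_heads, head_size)
assert value_cache.shape == key_cache.shape

# Any other backend string now fails fast instead of selecting a second layout:
# make_kv_cache(..., backend="XFORMERS")  # raises ValueError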
@@ -843,22 +795,14 @@ def assert_actual_matches_ideal(
     * output_under_test: actually observed output value
     """
     ideal_output = test_params.packed_qkvo.ideal_output
-    if backend == "XFORMERS":
-        torch.testing.assert_close(
-            ideal_output, output_under_test.view_as(ideal_output)
-        )
-    elif backend == "FLASH_ATTN":
-        # For FlashAttention override the accuracy thresholds to non default
-        # values since we notice a higher difference between the ideal and
-        # actual output.
-        torch.testing.assert_close(
-            ideal_output, output_under_test.view_as(ideal_output), atol=0.01, rtol=0.016
-        )
-    else:
-        raise ValueError(
-            f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or 'FLASH_ATTN'."
-        )
+    if backend != "FLASH_ATTN":
+        raise ValueError(f"Unknown backend value: '{backend}'. Expected 'FLASH_ATTN'.")
+    # For FlashAttention override the accuracy thresholds to non default
+    # values since we notice a higher difference between the ideal and
+    # actual output.
+    torch.testing.assert_close(
+        ideal_output, output_under_test.view_as(ideal_output), atol=0.01, rtol=0.016
+    )


# Copied/modified from torch._refs.__init__.py
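For context on the hunk above: torch.testing.assert_close treats a pair of elements as equal when |actual - expected| <= atol + rtol * |expected|. A tiny, self-contained illustration of what the relaxed thresholds tolerate; the numbers are invented, not taken from the tests.

import torch

ideal = torch.tensor([1.000, 2.000, 3.000])
actual = torch.tensor([1.008, 2.020, 3.030])  # off by 0.008-0.030 per element

# Passes: every |actual - ideal| <= 0.01 + 0.016 * |ideal|
torch.testing.assert_close(actual, ideal, atol=0.01, rtol=0.016)

# The float32 defaults (rtol=1.3e-6, atol=1e-5) are far tighter, so the same
# comparison would raise AssertionError:
# torch.testing.assert_close(actual, ideal)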