Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for v1 attention backends without GPUModelRunner dependency."""
|
||||
|
||||
from functools import partial
|
||||
from typing import Optional, Union
|
||||
|
||||
@@ -8,21 +9,30 @@ import pytest
|
||||
import torch
|
||||
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
|
||||
|
||||
from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata,
|
||||
create_standard_kv_cache_spec,
|
||||
create_vllm_config,
|
||||
get_attention_backend)
|
||||
from tests.v1.attention.utils import (
|
||||
BatchSpec,
|
||||
create_common_attn_metadata,
|
||||
create_standard_kv_cache_spec,
|
||||
create_vllm_config,
|
||||
get_attention_backend,
|
||||
)
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer
|
||||
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
|
||||
set_kv_cache_layout)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
CommonAttentionMetadata,
|
||||
set_kv_cache_layout,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
|
||||
BACKENDS_TO_TEST = [
|
||||
_Backend.FLASH_ATTN, _Backend.FLASHINFER, _Backend.FLEX_ATTENTION,
|
||||
_Backend.TRITON_ATTN, _Backend.TREE_ATTN, "FLEX_ATTENTION_SLOW"
|
||||
_Backend.FLASH_ATTN,
|
||||
_Backend.FLASHINFER,
|
||||
_Backend.FLEX_ATTENTION,
|
||||
_Backend.TRITON_ATTN,
|
||||
_Backend.TREE_ATTN,
|
||||
"FLEX_ATTENTION_SLOW",
|
||||
]
|
||||
|
||||
# Remove flashinfer from the list if it's not available
|
||||
@@ -49,42 +59,38 @@ def _convert_dtype_to_torch(dtype):
|
||||
|
||||
# Define common batch configurations
|
||||
BATCH_SPECS = {
|
||||
"small_decode":
|
||||
BatchSpec(seq_lens=[32, 40], query_lens=[1, 1]),
|
||||
"small_prefill":
|
||||
BatchSpec(seq_lens=[32, 40], query_lens=[8, 8]),
|
||||
"mixed_small":
|
||||
BatchSpec(seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5, 5]),
|
||||
"medium_decode":
|
||||
BatchSpec(seq_lens=[128, 256, 512, 1024, 128, 256, 512, 1024],
|
||||
query_lens=[1, 1, 1, 1, 1, 1, 1, 1]),
|
||||
"medium_prefill":
|
||||
BatchSpec(seq_lens=[256, 512, 1024, 2048], query_lens=[16, 16, 16, 16]),
|
||||
"mixed_medium":
|
||||
BatchSpec(seq_lens=[512, 1024, 2048, 512, 1024, 2048],
|
||||
query_lens=[1, 1, 1, 7, 7, 7]),
|
||||
"large_decode":
|
||||
BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32),
|
||||
"large_prefill":
|
||||
BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8),
|
||||
"single_decode":
|
||||
BatchSpec(seq_lens=[1024], query_lens=[1]),
|
||||
"single_prefill":
|
||||
BatchSpec(seq_lens=[1024], query_lens=[64]),
|
||||
"small_decode": BatchSpec(seq_lens=[32, 40], query_lens=[1, 1]),
|
||||
"small_prefill": BatchSpec(seq_lens=[32, 40], query_lens=[8, 8]),
|
||||
"mixed_small": BatchSpec(seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5, 5]),
|
||||
"medium_decode": BatchSpec(
|
||||
seq_lens=[128, 256, 512, 1024, 128, 256, 512, 1024],
|
||||
query_lens=[1, 1, 1, 1, 1, 1, 1, 1],
|
||||
),
|
||||
"medium_prefill": BatchSpec(
|
||||
seq_lens=[256, 512, 1024, 2048], query_lens=[16, 16, 16, 16]
|
||||
),
|
||||
"mixed_medium": BatchSpec(
|
||||
seq_lens=[512, 1024, 2048, 512, 1024, 2048], query_lens=[1, 1, 1, 7, 7, 7]
|
||||
),
|
||||
"large_decode": BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32),
|
||||
"large_prefill": BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8),
|
||||
"single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]),
|
||||
"single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]),
|
||||
}
|
||||
|
||||
|
||||
def create_and_prepopulate_kv_cache(
|
||||
k_contexts: list[torch.Tensor],
|
||||
v_contexts: list[torch.Tensor],
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
num_blocks: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
randomize_blocks: bool = True) -> torch.Tensor:
|
||||
k_contexts: list[torch.Tensor],
|
||||
v_contexts: list[torch.Tensor],
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
num_blocks: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
randomize_blocks: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""Create and prepopulate a KV cache with context data.
|
||||
|
||||
Args:
|
||||
@@ -106,20 +112,18 @@ def create_and_prepopulate_kv_cache(
|
||||
"""
|
||||
batch_size = len(k_contexts)
|
||||
seq_lens = common_attn_metadata.seq_lens_cpu
|
||||
query_lens = common_attn_metadata.query_start_loc_cpu[
|
||||
1:] - common_attn_metadata.query_start_loc_cpu[:-1]
|
||||
query_lens = (
|
||||
common_attn_metadata.query_start_loc_cpu[1:]
|
||||
- common_attn_metadata.query_start_loc_cpu[:-1]
|
||||
)
|
||||
context_lens = common_attn_metadata.num_computed_tokens_cpu
|
||||
block_table = common_attn_metadata.block_table_tensor
|
||||
slot_mapping = common_attn_metadata.slot_mapping
|
||||
|
||||
# Create KV cache
|
||||
kv_cache = torch.empty(2,
|
||||
num_blocks,
|
||||
block_size,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
kv_cache = torch.empty(
|
||||
2, num_blocks, block_size, num_kv_heads, head_size, dtype=dtype, device=device
|
||||
)
|
||||
kv_cache_flat = kv_cache.view(2, -1, num_kv_heads, head_size)
|
||||
|
||||
# Populate the cache with the context tokens
|
||||
@@ -168,8 +172,8 @@ def create_and_prepopulate_kv_cache(
|
||||
start = common_attn_metadata.query_start_loc_cpu[i]
|
||||
end = common_attn_metadata.query_start_loc_cpu[i + 1]
|
||||
slot_mapping[start:end] = block_table[
|
||||
i,
|
||||
block_indices] * block_size + token_inter_block_offsets.to(device)
|
||||
i, block_indices
|
||||
] * block_size + token_inter_block_offsets.to(device)
|
||||
|
||||
return kv_cache
|
||||
|
||||
@@ -222,20 +226,19 @@ def run_attention_backend(
|
||||
# Return mock parameters for a single layer
|
||||
head_size = vllm_config.model_config.get_head_size()
|
||||
return {
|
||||
layer_name:
|
||||
PerLayerParameters(
|
||||
layer_name: PerLayerParameters(
|
||||
window_left=-1, # No sliding window
|
||||
logits_soft_cap=0.0, # No soft cap
|
||||
sm_scale=1.0 / (head_size**0.5) # Standard scale
|
||||
sm_scale=1.0 / (head_size**0.5), # Standard scale
|
||||
)
|
||||
for layer_name in layer_names
|
||||
}
|
||||
|
||||
with unittest.mock.patch(
|
||||
'vllm.v1.attention.backends.flashinfer.get_per_layer_parameters',
|
||||
mock_get_per_layer_parameters):
|
||||
builder = builder_cls(kv_cache_spec, layer_names, vllm_config,
|
||||
device)
|
||||
"vllm.v1.attention.backends.flashinfer.get_per_layer_parameters",
|
||||
mock_get_per_layer_parameters,
|
||||
):
|
||||
builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
|
||||
attn_metadata = builder.build(
|
||||
common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
@@ -252,9 +255,11 @@ def run_attention_backend(
|
||||
|
||||
# Instantiate implementation
|
||||
num_heads = vllm_config.model_config.get_num_attention_heads(
|
||||
vllm_config.parallel_config)
|
||||
vllm_config.parallel_config
|
||||
)
|
||||
num_kv_heads = vllm_config.model_config.get_num_kv_heads(
|
||||
vllm_config.parallel_config)
|
||||
vllm_config.parallel_config
|
||||
)
|
||||
head_size = vllm_config.model_config.get_head_size()
|
||||
scale = 1.0 / (head_size**0.5)
|
||||
impl = impl_cls(
|
||||
@@ -274,13 +279,9 @@ def run_attention_backend(
|
||||
# Run forward pass
|
||||
# NOTE: The query, key, and value are already shaped correctly
|
||||
# in the calling test function.
|
||||
output = impl.forward(mock_layer,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
output=output)
|
||||
output = impl.forward(
|
||||
mock_layer, query, key, value, kv_cache, attn_metadata, output=output
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
@@ -311,10 +312,12 @@ def _test_backend_correctness(
|
||||
5. Comparing the vLLM backend's output to the ground-truth SDPA output.
|
||||
"""
|
||||
current_platform.seed_everything(42)
|
||||
vllm_config = create_vllm_config(model_name=model,
|
||||
max_model_len=max(batch_spec.seq_lens),
|
||||
block_size=block_size,
|
||||
num_gpu_blocks=8192)
|
||||
vllm_config = create_vllm_config(
|
||||
model_name=model,
|
||||
max_model_len=max(batch_spec.seq_lens),
|
||||
block_size=block_size,
|
||||
num_gpu_blocks=8192,
|
||||
)
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
|
||||
@@ -324,9 +327,11 @@ def _test_backend_correctness(
|
||||
seq_lens = batch_spec.seq_lens
|
||||
query_lens = batch_spec.query_lens
|
||||
num_q_heads = vllm_config.model_config.get_num_attention_heads(
|
||||
vllm_config.parallel_config)
|
||||
vllm_config.parallel_config
|
||||
)
|
||||
num_kv_heads = vllm_config.model_config.get_num_kv_heads(
|
||||
vllm_config.parallel_config)
|
||||
vllm_config.parallel_config
|
||||
)
|
||||
head_size = vllm_config.model_config.get_head_size()
|
||||
sliding_window = vllm_config.model_config.get_sliding_window()
|
||||
dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
|
||||
@@ -344,21 +349,9 @@ def _test_backend_correctness(
|
||||
context_len = s_len - q_len
|
||||
|
||||
# Generate Q, K, V for the whole sequence to be used in SDPA
|
||||
q = torch.randn(q_len,
|
||||
num_q_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
k_full = torch.randn(s_len,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
v_full = torch.randn(s_len,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
q = torch.randn(q_len, num_q_heads, head_size, dtype=dtype, device=device)
|
||||
k_full = torch.randn(s_len, num_kv_heads, head_size, dtype=dtype, device=device)
|
||||
v_full = torch.randn(s_len, num_kv_heads, head_size, dtype=dtype, device=device)
|
||||
|
||||
# SDPA expects (N, H, L, D), so unsqueeze batch and permute
|
||||
q_sdpa_in = q.unsqueeze(0).transpose(1, 2)
|
||||
@@ -368,7 +361,8 @@ def _test_backend_correctness(
|
||||
if num_q_heads != num_kv_heads:
|
||||
assert num_q_heads % num_kv_heads == 0, (
|
||||
f"num_q_heads ({num_q_heads}) must be divisible by "
|
||||
f"num_kv_heads ({num_kv_heads})")
|
||||
f"num_kv_heads ({num_kv_heads})"
|
||||
)
|
||||
repeats = num_q_heads // num_kv_heads
|
||||
k_sdpa_in = k_sdpa_in.repeat_interleave(repeats, dim=1)
|
||||
v_sdpa_in = v_sdpa_in.repeat_interleave(repeats, dim=1)
|
||||
@@ -378,18 +372,17 @@ def _test_backend_correctness(
|
||||
kv_len = s_len
|
||||
|
||||
final_mask_mod = partial(mask_mod, context_len=context_len)
|
||||
block_mask = create_block_mask(final_mask_mod,
|
||||
B=None,
|
||||
H=None,
|
||||
Q_LEN=q_len,
|
||||
KV_LEN=kv_len,
|
||||
device=device)
|
||||
sdpa_out_i = flex_attention(q_sdpa_in,
|
||||
k_sdpa_in,
|
||||
v_sdpa_in,
|
||||
block_mask=block_mask,
|
||||
scale=scale,
|
||||
enable_gqa=True)
|
||||
block_mask = create_block_mask(
|
||||
final_mask_mod, B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len, device=device
|
||||
)
|
||||
sdpa_out_i = flex_attention(
|
||||
q_sdpa_in,
|
||||
k_sdpa_in,
|
||||
v_sdpa_in,
|
||||
block_mask=block_mask,
|
||||
scale=scale,
|
||||
enable_gqa=True,
|
||||
)
|
||||
|
||||
all_sdpa_outputs.append(sdpa_out_i.transpose(1, 2).squeeze(0))
|
||||
|
||||
@@ -408,7 +401,8 @@ def _test_backend_correctness(
|
||||
sdpa_output = torch.cat(all_sdpa_outputs, dim=0)
|
||||
|
||||
common_attn_metadata = create_common_attn_metadata(
|
||||
batch_spec, vllm_config.cache_config.block_size, device)
|
||||
batch_spec, vllm_config.cache_config.block_size, device
|
||||
)
|
||||
|
||||
# 3. Simulate Paged KV Cache and a realistic slot_mapping
|
||||
kv_cache = create_and_prepopulate_kv_cache(
|
||||
@@ -421,7 +415,8 @@ def _test_backend_correctness(
|
||||
device=device,
|
||||
num_blocks=vllm_config.cache_config.num_gpu_blocks or 1000,
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
randomize_blocks=True)
|
||||
randomize_blocks=True,
|
||||
)
|
||||
|
||||
# 4. Run vLLM backends and compare
|
||||
# Note: flex_attention has known Triton kernel compatibility issues
|
||||
@@ -437,8 +432,9 @@ def _test_backend_correctness(
|
||||
kv_cache_for_backend = kv_cache.transpose(0, 1)
|
||||
|
||||
# For FlashInfer default to HND layout and
|
||||
kv_cache_for_backend = kv_cache_for_backend.transpose(
|
||||
2, 3).contiguous().transpose(2, 3)
|
||||
kv_cache_for_backend = (
|
||||
kv_cache_for_backend.transpose(2, 3).contiguous().transpose(2, 3)
|
||||
)
|
||||
set_kv_cache_layout("HND")
|
||||
|
||||
backend_output = run_attention_backend(
|
||||
@@ -458,32 +454,45 @@ def _test_backend_correctness(
|
||||
# Check shape and dtype consistency
|
||||
assert backend_output.shape == sdpa_output.shape, (
|
||||
f"[{backend_name}] shape {backend_output.shape} != "
|
||||
f"SDPA shape {sdpa_output.shape}")
|
||||
f"SDPA shape {sdpa_output.shape}"
|
||||
)
|
||||
assert backend_output.dtype == sdpa_output.dtype, (
|
||||
f"[{backend_name}] dtype {backend_output.dtype} != "
|
||||
f"SDPA dtype {sdpa_output.dtype}")
|
||||
f"SDPA dtype {sdpa_output.dtype}"
|
||||
)
|
||||
|
||||
assert torch.isfinite(backend_output).all(), (
|
||||
f"[{backend_name}] produced non-finite values")
|
||||
f"[{backend_name}] produced non-finite values"
|
||||
)
|
||||
|
||||
# Check numerical similarity
|
||||
def error_msg(msg: str, backend_name: str):
|
||||
return (f"[{backend_name}] output differs from SDPA baseline. "
|
||||
f"{msg}")
|
||||
return f"[{backend_name}] output differs from SDPA baseline. {msg}"
|
||||
|
||||
torch.testing.assert_close(backend_output,
|
||||
sdpa_output,
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
msg=partial(error_msg,
|
||||
backend_name=backend_name))
|
||||
torch.testing.assert_close(
|
||||
backend_output,
|
||||
sdpa_output,
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
msg=partial(error_msg, backend_name=backend_name),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_spec_name", [
|
||||
"small_decode", "small_prefill", "mixed_small", "medium_decode",
|
||||
"medium_prefill", "mixed_medium", "large_decode", "large_prefill",
|
||||
"single_decode", "single_prefill"
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"batch_spec_name",
|
||||
[
|
||||
"small_decode",
|
||||
"small_prefill",
|
||||
"mixed_small",
|
||||
"medium_decode",
|
||||
"medium_prefill",
|
||||
"mixed_medium",
|
||||
"large_decode",
|
||||
"large_prefill",
|
||||
"single_decode",
|
||||
"single_prefill",
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
|
||||
def test_causal_backend_correctness(batch_spec_name: str, model: str):
|
||||
"""Test backend's correctness with causal attention."""
|
||||
@@ -499,33 +508,33 @@ def test_causal_backend_correctness(batch_spec_name: str, model: str):
|
||||
return (q_idx + context_len) >= kv_idx
|
||||
|
||||
batch_spec = BATCH_SPECS[batch_spec_name]
|
||||
LARGE_BLOCK_BACKENDS = ([_Backend.FLEX_ATTENTION]
|
||||
if is_torch_equal_or_newer("2.9.0.dev0") else [])
|
||||
LARGE_BLOCK_BACKENDS = (
|
||||
[_Backend.FLEX_ATTENTION] if is_torch_equal_or_newer("2.9.0.dev0") else []
|
||||
)
|
||||
SMALL_BLOCK_BACKENDS = [
|
||||
x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
|
||||
]
|
||||
_test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS,
|
||||
causal_mask_mod)
|
||||
_test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS, causal_mask_mod)
|
||||
|
||||
# Fast FlexAttention needs to run with block_size=128
|
||||
if LARGE_BLOCK_BACKENDS:
|
||||
_test_backend_correctness(batch_spec,
|
||||
model,
|
||||
LARGE_BLOCK_BACKENDS,
|
||||
causal_mask_mod,
|
||||
block_size=128)
|
||||
_test_backend_correctness(
|
||||
batch_spec, model, LARGE_BLOCK_BACKENDS, causal_mask_mod, block_size=128
|
||||
)
|
||||
|
||||
|
||||
SLIDING_WINDOW_BACKENDS_TO_TEST = [
|
||||
_Backend.FLASH_ATTN, _Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN,
|
||||
"FLEX_ATTENTION_SLOW"
|
||||
_Backend.FLASH_ATTN,
|
||||
_Backend.FLEX_ATTENTION,
|
||||
_Backend.TRITON_ATTN,
|
||||
"FLEX_ATTENTION_SLOW",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_spec_name", [
|
||||
"small_decode", "small_prefill", "mixed_medium", "large_decode",
|
||||
"large_prefill"
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"batch_spec_name",
|
||||
["small_decode", "small_prefill", "mixed_medium", "large_decode", "large_prefill"],
|
||||
)
|
||||
@pytest.mark.parametrize("model", ["microsoft/Phi-tiny-MoE-instruct"])
|
||||
def test_sliding_window_backend_correctness(batch_spec_name: str, model: str):
|
||||
"""Test backend's correctness with sliding window attention."""
|
||||
@@ -544,25 +553,28 @@ def test_sliding_window_backend_correctness(batch_spec_name: str, model: str):
|
||||
return causal_mask & window_mask
|
||||
|
||||
batch_spec = BATCH_SPECS[batch_spec_name]
|
||||
model_config = ModelConfig(model=model,
|
||||
max_model_len=max(batch_spec.seq_lens))
|
||||
model_config = ModelConfig(model=model, max_model_len=max(batch_spec.seq_lens))
|
||||
sliding_window = model_config.get_sliding_window()
|
||||
sliding_window_mask_mod_fn = partial(sliding_window_mask_mod,
|
||||
sliding_window=sliding_window)
|
||||
sliding_window_mask_mod_fn = partial(
|
||||
sliding_window_mask_mod, sliding_window=sliding_window
|
||||
)
|
||||
|
||||
LARGE_BLOCK_BACKENDS = ([_Backend.FLEX_ATTENTION]
|
||||
if is_torch_equal_or_newer("2.9.0.dev0") else [])
|
||||
LARGE_BLOCK_BACKENDS = (
|
||||
[_Backend.FLEX_ATTENTION] if is_torch_equal_or_newer("2.9.0.dev0") else []
|
||||
)
|
||||
SMALL_BLOCK_BACKENDS = [
|
||||
x for x in SLIDING_WINDOW_BACKENDS_TO_TEST
|
||||
if x not in LARGE_BLOCK_BACKENDS
|
||||
x for x in SLIDING_WINDOW_BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
|
||||
]
|
||||
_test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS,
|
||||
sliding_window_mask_mod_fn)
|
||||
_test_backend_correctness(
|
||||
batch_spec, model, SMALL_BLOCK_BACKENDS, sliding_window_mask_mod_fn
|
||||
)
|
||||
|
||||
# Fast FlexAttention needs to run with block_size=128
|
||||
if LARGE_BLOCK_BACKENDS:
|
||||
_test_backend_correctness(batch_spec,
|
||||
model,
|
||||
LARGE_BLOCK_BACKENDS,
|
||||
sliding_window_mask_mod_fn,
|
||||
block_size=128)
|
||||
_test_backend_correctness(
|
||||
batch_spec,
|
||||
model,
|
||||
LARGE_BLOCK_BACKENDS,
|
||||
sliding_window_mask_mod_fn,
|
||||
block_size=128,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user