Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -6,11 +6,13 @@ import torch
|
||||
|
||||
from tests.v1.attention.test_attention_backends import BATCH_SPECS
|
||||
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
|
||||
from vllm.v1.attention.backends.utils import (UBatchSlice,
|
||||
_make_metadata_with_slice,
|
||||
slice_query_start_locs,
|
||||
split_attn_metadata,
|
||||
split_decodes_and_prefills)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
UBatchSlice,
|
||||
_make_metadata_with_slice,
|
||||
slice_query_start_locs,
|
||||
split_attn_metadata,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.worker.ubatch_splitting import create_ubatch_slices
|
||||
|
||||
|
||||
@@ -79,9 +81,7 @@ def small_decode_metadata():
|
||||
"""Create metadata for small decode batch"""
|
||||
batch_spec = BATCH_SPECS["small_decode"]
|
||||
device = torch.device("cpu")
|
||||
return create_common_attn_metadata(batch_spec,
|
||||
block_size=16,
|
||||
device=device)
|
||||
return create_common_attn_metadata(batch_spec, block_size=16, device=device)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -89,9 +89,7 @@ def large_decode_metadata():
|
||||
"""Create metadata for small decode batch"""
|
||||
batch_spec = BATCH_SPECS["large_decode"]
|
||||
device = torch.device("cpu")
|
||||
return create_common_attn_metadata(batch_spec,
|
||||
block_size=16,
|
||||
device=device)
|
||||
return create_common_attn_metadata(batch_spec, block_size=16, device=device)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -99,9 +97,7 @@ def mixed_small_metadata():
|
||||
"""Create metadata for mixed small batch"""
|
||||
batch_spec = BATCH_SPECS["mixed_small"]
|
||||
device = torch.device("cpu")
|
||||
return create_common_attn_metadata(batch_spec,
|
||||
block_size=16,
|
||||
device=device)
|
||||
return create_common_attn_metadata(batch_spec, block_size=16, device=device)
|
||||
|
||||
|
||||
# Tests for _make_metadata_with_slice
|
||||
@@ -122,8 +118,7 @@ def test_make_metadata_with_slice_decode_batch(small_decode_metadata):
|
||||
|
||||
def test_make_metadata_with_slice_mixed_batch(mixed_small_metadata):
|
||||
"""Test slicing mixed batch metadata"""
|
||||
ubatch_slice = UBatchSlice(slice(1, 3),
|
||||
slice(1, 7)) # Requests 1-3, tokens 1-7
|
||||
ubatch_slice = UBatchSlice(slice(1, 3), slice(1, 7)) # Requests 1-3, tokens 1-7
|
||||
|
||||
result = _make_metadata_with_slice(ubatch_slice, mixed_small_metadata)
|
||||
|
||||
@@ -140,8 +135,7 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata):
|
||||
mid_point = num_tokens // 2
|
||||
ubatch_slices = [
|
||||
UBatchSlice(slice(0, mid_point), slice(0, mid_point)),
|
||||
UBatchSlice(slice(mid_point, num_tokens), slice(mid_point,
|
||||
num_tokens)),
|
||||
UBatchSlice(slice(mid_point, num_tokens), slice(mid_point, num_tokens)),
|
||||
]
|
||||
|
||||
results = split_attn_metadata(ubatch_slices, large_decode_metadata)
|
||||
@@ -159,26 +153,30 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata):
|
||||
assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point))
|
||||
|
||||
|
||||
def apply_split_decodes_and_prefills(query_lens: list[int],
|
||||
decode_threshold: int,
|
||||
require_uniform: bool):
|
||||
def apply_split_decodes_and_prefills(
|
||||
query_lens: list[int], decode_threshold: int, require_uniform: bool
|
||||
):
|
||||
"""Helper function to apply split_decodes_and_prefills and return
|
||||
the results."""
|
||||
device = torch.device("cpu")
|
||||
seq_lens = [10 * (i + 1) for i in range(len(query_lens))]
|
||||
common_metadata = create_common_attn_metadata(BatchSpec(
|
||||
seq_lens=seq_lens, query_lens=query_lens),
|
||||
block_size=16,
|
||||
device=device)
|
||||
return split_decodes_and_prefills(common_metadata,
|
||||
decode_threshold=decode_threshold,
|
||||
require_uniform=require_uniform)
|
||||
common_metadata = create_common_attn_metadata(
|
||||
BatchSpec(seq_lens=seq_lens, query_lens=query_lens),
|
||||
block_size=16,
|
||||
device=device,
|
||||
)
|
||||
return split_decodes_and_prefills(
|
||||
common_metadata,
|
||||
decode_threshold=decode_threshold,
|
||||
require_uniform=require_uniform,
|
||||
)
|
||||
|
||||
|
||||
def test_split_decodes_and_prefills_nonuniform_all_ones():
|
||||
query_lens = [1, 1, 1]
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
apply_split_decodes_and_prefills(query_lens, 1, False))
|
||||
apply_split_decodes_and_prefills(query_lens, 1, False)
|
||||
)
|
||||
assert num_decodes == 3
|
||||
assert num_prefills == 0
|
||||
assert num_decode_tokens == 3
|
||||
@@ -188,7 +186,8 @@ def test_split_decodes_and_prefills_nonuniform_all_ones():
|
||||
def test_split_decodes_and_prefills_nonuniform_all_short_decodes():
|
||||
query_lens = [1, 2, 1, 3, 2, 1, 2]
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
apply_split_decodes_and_prefills(query_lens, 3, False))
|
||||
apply_split_decodes_and_prefills(query_lens, 3, False)
|
||||
)
|
||||
assert num_decodes == 7
|
||||
assert num_prefills == 0
|
||||
assert num_decode_tokens == sum(query_lens)
|
||||
@@ -198,7 +197,8 @@ def test_split_decodes_and_prefills_nonuniform_all_short_decodes():
|
||||
def test_split_decodes_and_prefills_nonuniform_all_prefills():
|
||||
query_lens = [4, 5, 6, 7]
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
apply_split_decodes_and_prefills(query_lens, 3, False))
|
||||
apply_split_decodes_and_prefills(query_lens, 3, False)
|
||||
)
|
||||
assert num_decodes == 0
|
||||
assert num_prefills == 4
|
||||
assert num_decode_tokens == 0
|
||||
@@ -208,7 +208,8 @@ def test_split_decodes_and_prefills_nonuniform_all_prefills():
|
||||
def test_split_decodes_and_prefills_nonuniform_mixed_batch():
|
||||
query_lens = [2, 1, 3, 4, 5, 6, 7, 8]
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
apply_split_decodes_and_prefills(query_lens, 4, False))
|
||||
apply_split_decodes_and_prefills(query_lens, 4, False)
|
||||
)
|
||||
assert num_decodes == 4 # 2, 1, 3, 4 are all <= 4
|
||||
assert num_prefills == 4 # 5, 6, 7, 8 are all > 4
|
||||
assert num_decode_tokens == 10 # 2 + 1 + 3 + 4
|
||||
@@ -218,7 +219,8 @@ def test_split_decodes_and_prefills_nonuniform_mixed_batch():
|
||||
def test_split_decodes_and_prefills_uniform_all_ones():
|
||||
query_lens = [1, 1, 1]
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
apply_split_decodes_and_prefills(query_lens, 1, True))
|
||||
apply_split_decodes_and_prefills(query_lens, 1, True)
|
||||
)
|
||||
assert num_decodes == 3
|
||||
assert num_prefills == 0
|
||||
assert num_decode_tokens == 3
|
||||
@@ -228,7 +230,8 @@ def test_split_decodes_and_prefills_uniform_all_ones():
|
||||
def test_split_decodes_and_prefills_uniform_all_short_decodes():
|
||||
query_lens = [2, 2, 1, 3, 2, 1, 2]
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
apply_split_decodes_and_prefills(query_lens, 3, True))
|
||||
apply_split_decodes_and_prefills(query_lens, 3, True)
|
||||
)
|
||||
assert num_decodes == 2
|
||||
assert num_prefills == 5
|
||||
assert num_decode_tokens == 4
|
||||
@@ -238,7 +241,8 @@ def test_split_decodes_and_prefills_uniform_all_short_decodes():
|
||||
def test_split_decodes_and_prefills_uniform_all_prefills():
|
||||
query_lens = [4, 5, 6, 7]
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
apply_split_decodes_and_prefills(query_lens, 3, True))
|
||||
apply_split_decodes_and_prefills(query_lens, 3, True)
|
||||
)
|
||||
assert num_decodes == 0
|
||||
assert num_prefills == 4
|
||||
assert num_decode_tokens == 0
|
||||
@@ -248,7 +252,8 @@ def test_split_decodes_and_prefills_uniform_all_prefills():
|
||||
def test_split_decodes_and_prefills_uniform_mixed_batch_all_uniform_decodes():
|
||||
query_lens = [2, 2, 2, 4, 5, 6, 7, 8]
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
apply_split_decodes_and_prefills(query_lens, 4, True))
|
||||
apply_split_decodes_and_prefills(query_lens, 4, True)
|
||||
)
|
||||
assert num_decodes == 3 # 2, 2, 2 are all <= 4 and uniform
|
||||
assert num_prefills == 5 # 4, 5, 6, 7, 8 are all > 4
|
||||
assert num_decode_tokens == 6 # 2 + 2 + 2
|
||||
@@ -258,7 +263,8 @@ def test_split_decodes_and_prefills_uniform_mixed_batch_all_uniform_decodes():
|
||||
def test_split_decodes_and_prefills_uniform_mixed_batch_non_uniform_decodes():
|
||||
query_lens = [2, 1, 2, 4, 5, 6, 7, 8]
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
apply_split_decodes_and_prefills(query_lens, 4, True))
|
||||
apply_split_decodes_and_prefills(query_lens, 4, True)
|
||||
)
|
||||
assert num_decodes == 1 # only the first 2 is taken as decode
|
||||
assert num_prefills == 7 # 1, 2, 4, 5, 6, 7, 8 are all > 4 or non-uniform
|
||||
assert num_decode_tokens == 2 # only the first 2
|
||||
@@ -274,17 +280,15 @@ def test_split_decodes_and_prefills_uniform_mixed_batch_non_uniform_decodes():
|
||||
([32, 40], [8, 8], 4, 1, 2),
|
||||
],
|
||||
)
|
||||
def test_prefill_split_across_ubatches(seq_lens, query_lens, split_point,
|
||||
expected_first_reqs,
|
||||
expected_second_reqs):
|
||||
def test_prefill_split_across_ubatches(
|
||||
seq_lens, query_lens, split_point, expected_first_reqs, expected_second_reqs
|
||||
):
|
||||
"""Test splitting a prefill across ubatches"""
|
||||
import numpy as np
|
||||
|
||||
device = torch.device("cpu")
|
||||
batch_spec = BatchSpec(seq_lens=seq_lens, query_lens=query_lens)
|
||||
common = create_common_attn_metadata(batch_spec,
|
||||
block_size=16,
|
||||
device=device)
|
||||
common = create_common_attn_metadata(batch_spec, block_size=16, device=device)
|
||||
|
||||
num_scheduled_tokens = np.array(query_lens, dtype=np.int32)
|
||||
qsl_np = common.query_start_loc_cpu.numpy()
|
||||
@@ -307,19 +311,19 @@ def test_prefill_split_across_ubatches(seq_lens, query_lens, split_point,
|
||||
# Identify which request is split and how many tokens are in the first chunk
|
||||
split_req_idx = int(np.searchsorted(qsl_np, split_point, side="right") - 1)
|
||||
tokens_in_first_chunk = split_point - int(qsl_np[split_req_idx])
|
||||
orig_q_lens = (common.query_start_loc_cpu[1:] -
|
||||
common.query_start_loc_cpu[:-1])
|
||||
orig_q_lens = common.query_start_loc_cpu[1:] - common.query_start_loc_cpu[:-1]
|
||||
|
||||
# Check query length continuity: first-chunk + second-chunk == original qlen
|
||||
# First ubatch last request query length
|
||||
qlen_first_last = int(first_meta.query_start_loc_cpu[-1] -
|
||||
first_meta.query_start_loc_cpu[-2])
|
||||
qlen_first_last = int(
|
||||
first_meta.query_start_loc_cpu[-1] - first_meta.query_start_loc_cpu[-2]
|
||||
)
|
||||
# Second ubatch first request query length
|
||||
qlen_second_first = int(second_meta.query_start_loc_cpu[1] -
|
||||
second_meta.query_start_loc_cpu[0])
|
||||
qlen_second_first = int(
|
||||
second_meta.query_start_loc_cpu[1] - second_meta.query_start_loc_cpu[0]
|
||||
)
|
||||
assert qlen_first_last == tokens_in_first_chunk
|
||||
assert qlen_first_last + qlen_second_first == int(
|
||||
orig_q_lens[split_req_idx])
|
||||
assert qlen_first_last + qlen_second_first == int(orig_q_lens[split_req_idx])
|
||||
|
||||
# Check seq_lens adjustments
|
||||
# Context lengths per original request
|
||||
|
||||
Reference in New Issue
Block a user