Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
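The rewrites below are purely mechanical: ruff format (a black-style formatter) takes over from yapf, and ruff's isort-compatible lint rules take over import sorting from isort, so one tool now covers both jobs. As a rough sketch of what changes on the ground, the following made-up snippet (hypothetical names, not drawn from this diff) shows the two layout patterns that account for nearly every hunk:

# Illustrative only: the mechanical rewrites ruff's formatter applies here.

# yapf aligned continuation arguments under the opening parenthesis:
request = dict(num_requests=5,
               num_tokens=100,
               block_size=16)

# ruff (black-style) instead breaks after "(", indents one level, and the
# trailing "magic" comma keeps one argument per line:
request = dict(
    num_requests=5,
    num_tokens=100,
    block_size=16,
)

# Comprehensions and expressions that fit within the line limit are
# collapsed onto one line, and redundant wrapping parentheses are dropped:
ids = ["a", "b", "c"]
index = {req_id: i for i, req_id in enumerate(ids)}

Both forms are equivalent Python; only the layout differs, which is why this commit can touch so many call sites without changing behavior.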
@@ -15,14 +15,12 @@ pytestmark = pytest.mark.cpu_test
 
 
 def _make_model_runner_output(
-        scheduler_output: SchedulerOutput, ) -> ModelRunnerOutput:
+    scheduler_output: SchedulerOutput,
+) -> ModelRunnerOutput:
     req_ids = list(scheduler_output.num_scheduled_tokens.keys())
     return ModelRunnerOutput(
         req_ids=req_ids,
-        req_id_to_index={
-            req_id: i
-            for i, req_id in enumerate(req_ids)
-        },
+        req_id_to_index={req_id: i for i, req_id in enumerate(req_ids)},
         sampled_token_ids=[[i] for i in range(len(req_ids))],
         logprobs=None,
         prompt_logprobs_dict={},
@@ -75,8 +73,7 @@ def test_abort():
         if not abort_order:
             return
         req = requests[abort_order.pop(0)]
-        scheduler.finish_requests(req.request_id,
-                                  RequestStatus.FINISHED_ABORTED)
+        scheduler.finish_requests(req.request_id, RequestStatus.FINISHED_ABORTED)
 
     while sched_outputs:
         # Abort a scheduled request.
@@ -112,8 +109,7 @@ def test_preempt():
         if not abort_order:
             return
         req = requests[abort_order.pop(0)]
-        scheduler.finish_requests(req.request_id,
-                                  RequestStatus.FINISHED_ABORTED)
+        scheduler.finish_requests(req.request_id, RequestStatus.FINISHED_ABORTED)
 
     while sched_outputs:
         # Abort a scheduled request.
@@ -135,15 +131,19 @@ def test_prefix_caching_for_prefill_dedup():
     CHUNK_SIZE = 1000
     BLOCK_SIZE = 16
     num_prompt_tokens = 100
-    scheduler = create_scheduler(async_scheduling=True,
-                                 max_num_batched_tokens=CHUNK_SIZE,
-                                 enable_prefix_caching=True,
-                                 block_size=BLOCK_SIZE)
-    requests = create_requests(num_requests=5,
-                               num_tokens=num_prompt_tokens,
-                               max_tokens=3,
-                               same_prompt=True,
-                               block_size=BLOCK_SIZE)
+    scheduler = create_scheduler(
+        async_scheduling=True,
+        max_num_batched_tokens=CHUNK_SIZE,
+        enable_prefix_caching=True,
+        block_size=BLOCK_SIZE,
+    )
+    requests = create_requests(
+        num_requests=5,
+        num_tokens=num_prompt_tokens,
+        max_tokens=3,
+        same_prompt=True,
+        block_size=BLOCK_SIZE,
+    )
     requests_copy = requests.copy()
 
     # Two requests with the same prompt.
@@ -185,14 +185,18 @@ def test_prefix_caching_for_multi_turn():
     BLOCK_SIZE = 16
     num_prompt_tokens = 100
    num_output_tokens = 200
-    scheduler = create_scheduler(async_scheduling=True,
-                                 max_num_batched_tokens=CHUNK_SIZE,
-                                 enable_prefix_caching=True,
-                                 block_size=BLOCK_SIZE)
-    requests = create_requests(num_requests=5,
-                               num_tokens=num_prompt_tokens,
-                               max_tokens=num_output_tokens,
-                               block_size=BLOCK_SIZE)
+    scheduler = create_scheduler(
+        async_scheduling=True,
+        max_num_batched_tokens=CHUNK_SIZE,
+        enable_prefix_caching=True,
+        block_size=BLOCK_SIZE,
+    )
+    requests = create_requests(
+        num_requests=5,
+        num_tokens=num_prompt_tokens,
+        max_tokens=num_output_tokens,
+        block_size=BLOCK_SIZE,
+    )
 
     for req in requests:
         scheduler.add_request(req)
@@ -212,14 +216,16 @@ def test_prefix_caching_for_multi_turn():
 
     # Create next-turn requests whose prompts are the full output of the
     # previous turn.
-    next_turn_requests = create_requests(num_requests=5,
-                                         num_tokens=num_prompt_tokens +
-                                         num_output_tokens,
-                                         max_tokens=num_output_tokens,
-                                         block_size=BLOCK_SIZE)
+    next_turn_requests = create_requests(
+        num_requests=5,
+        num_tokens=num_prompt_tokens + num_output_tokens,
+        max_tokens=num_output_tokens,
+        block_size=BLOCK_SIZE,
+    )
     for i, req in enumerate(next_turn_requests):
-        req.prompt_token_ids = (requests[i].prompt_token_ids +
-                                list(requests[i].output_token_ids))
+        req.prompt_token_ids = requests[i].prompt_token_ids + list(
+            requests[i].output_token_ids
+        )
         req._all_token_ids = req.prompt_token_ids.copy()
         req.all_token_ids = ConstantList(req._all_token_ids)
         req.block_hashes = []
@@ -233,5 +239,4 @@ def test_prefix_caching_for_multi_turn():
     # Make sure the next-turn requests get prefix cache hit by the previous
     # requests.
     for req in next_turn_requests:
-        assert (req.num_cached_tokens == req.num_prompt_tokens // BLOCK_SIZE *
-                BLOCK_SIZE)
+        assert req.num_cached_tokens == req.num_prompt_tokens // BLOCK_SIZE * BLOCK_SIZE