[Fix] uniform decode batch check (#30747)

Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com>
This commit is contained in:
Jialin Ouyang
2025-12-17 03:58:43 -08:00
committed by GitHub
parent 6482e3895b
commit 6e9dbcc50e
2 changed files with 121 additions and 8 deletions

View File

@@ -1110,3 +1110,87 @@ def test_hybrid_cache_integration(model_runner, dist_init):
runner._update_states(scheduler_output)
assert _is_req_scheduled(runner, req_id)
assert _is_req_state_block_table_match(runner, req_id)
def test_is_uniform_decode() -> None:
# Normal
assert GPUModelRunner._is_uniform_decode(
max_num_scheduled_tokens=1,
uniform_decode_query_len=1,
num_tokens=16,
num_reqs=16,
)
assert not GPUModelRunner._is_uniform_decode(
max_num_scheduled_tokens=2,
uniform_decode_query_len=1,
num_tokens=16,
num_reqs=16,
)
assert not GPUModelRunner._is_uniform_decode(
max_num_scheduled_tokens=1,
uniform_decode_query_len=1,
num_tokens=16,
num_reqs=15,
)
# Spec decoding
assert GPUModelRunner._is_uniform_decode(
max_num_scheduled_tokens=5,
uniform_decode_query_len=5,
num_tokens=30,
num_reqs=6,
)
assert not GPUModelRunner._is_uniform_decode(
max_num_scheduled_tokens=5,
uniform_decode_query_len=4,
num_tokens=30,
num_reqs=6,
)
assert not GPUModelRunner._is_uniform_decode(
max_num_scheduled_tokens=5,
uniform_decode_query_len=5,
num_tokens=30,
num_reqs=7,
)
# Force uniform decode
assert GPUModelRunner._is_uniform_decode(
max_num_scheduled_tokens=1,
uniform_decode_query_len=1,
num_tokens=16,
num_reqs=16,
force_uniform_decode=True,
)
assert GPUModelRunner._is_uniform_decode(
max_num_scheduled_tokens=2,
uniform_decode_query_len=1,
num_tokens=16,
num_reqs=16,
force_uniform_decode=True,
)
assert GPUModelRunner._is_uniform_decode(
max_num_scheduled_tokens=1,
uniform_decode_query_len=1,
num_tokens=16,
num_reqs=15,
force_uniform_decode=True,
)
assert not GPUModelRunner._is_uniform_decode(
max_num_scheduled_tokens=1,
uniform_decode_query_len=1,
num_tokens=16,
num_reqs=16,
force_uniform_decode=False,
)
assert not GPUModelRunner._is_uniform_decode(
max_num_scheduled_tokens=2,
uniform_decode_query_len=1,
num_tokens=16,
num_reqs=16,
force_uniform_decode=False,
)
assert not GPUModelRunner._is_uniform_decode(
max_num_scheduled_tokens=1,
uniform_decode_query_len=1,
num_tokens=16,
num_reqs=15,
force_uniform_decode=False,
)