[V1] Avoid redundant input processing in n>1 case (#14985)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill
2025-03-20 22:24:10 -07:00
committed by GitHub
parent 7297941b38
commit da6ea29f7a
13 changed files with 85 additions and 145 deletions

View File

@@ -41,10 +41,10 @@ async def test_tokenizer_group(tokenizer_group_type):
max_input_length=None,
)
assert reference_tokenizer.encode("prompt") == tokenizer_group.encode(
request_id="request_id", prompt="prompt", lora_request=None)
prompt="prompt", lora_request=None)
assert reference_tokenizer.encode(
"prompt") == await tokenizer_group.encode_async(
request_id="request_id", prompt="prompt", lora_request=None)
"prompt") == await tokenizer_group.encode_async(prompt="prompt",
lora_request=None)
assert isinstance(tokenizer_group.get_lora_tokenizer(None),
PreTrainedTokenizerBase)
assert tokenizer_group.get_lora_tokenizer(
@@ -69,8 +69,7 @@ async def test_tokenizer_group_pool(tokenizer_group_type):
# and check that all requests are processed correctly.
num_requests = tokenizer_group_pool.pool_size * 5
requests = [
tokenizer_group_pool.encode_async(request_id=str(i),
prompt=f"prompt {i}",
tokenizer_group_pool.encode_async(prompt=f"prompt {i}",
lora_request=None)
for i in range(num_requests)
]
@@ -161,12 +160,8 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
fail_at[0] = 1000
# We should recover successfully.
await tokenizer_group_pool.encode_async(request_id="1",
prompt="prompt",
lora_request=None)
await tokenizer_group_pool.encode_async(request_id="1",
prompt="prompt",
lora_request=None)
await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None)
await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None)
# Check that we have a new actor
assert len(tokenizer_group_pool.tokenizer_actors) == len(tokenizer_actors)
@@ -184,8 +179,7 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
# We should fail after re-initialization.
with pytest.raises(RuntimeError):
await tokenizer_group_pool.encode_async(request_id="1",
prompt="prompt",
await tokenizer_group_pool.encode_async(prompt="prompt",
lora_request=None)
# check_health should raise the same thing
@@ -206,11 +200,8 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
# Prompt too long error
with pytest.raises(ValueError):
await tokenizer_group_pool.encode_async(request_id="1",
prompt="prompt" * 100,
await tokenizer_group_pool.encode_async(prompt="prompt" * 100,
lora_request=None)
await tokenizer_group_pool.encode_async(request_id="1",
prompt="prompt",
lora_request=None)
await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None)
# Actors should stay the same.
assert tokenizer_group_pool.tokenizer_actors == tokenizer_actors