[V1] Avoid redundant input processing in n>1 case (#14985)
Signed-off-by: Nick Hill <nhill@redhat.com>
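Every test update below makes the same mechanical change: the `request_id` keyword is dropped from `encode`/`encode_async` calls on the tokenizer group. A minimal before/after sketch of the call-site change (names and values are taken directly from the diff; the connection to the PR title is that, without a per-request key, the V1 engine can presumably process an input once and reuse it for all n > 1 samples of a request rather than redundantly per sequence):

```python
# Before this PR: every tokenization call carried a request_id.
tokenizer_group.encode(request_id="request_id",
                       prompt="prompt",
                       lora_request=None)

# After this PR: tokenization depends only on the prompt and the
# LoRA request, so identical inputs need not be reprocessed per request.
tokenizer_group.encode(prompt="prompt", lora_request=None)
```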
@@ -41,10 +41,10 @@ async def test_tokenizer_group(tokenizer_group_type):
         max_input_length=None,
     )
     assert reference_tokenizer.encode("prompt") == tokenizer_group.encode(
-        request_id="request_id", prompt="prompt", lora_request=None)
+        prompt="prompt", lora_request=None)
     assert reference_tokenizer.encode(
-        "prompt") == await tokenizer_group.encode_async(
-            request_id="request_id", prompt="prompt", lora_request=None)
+        "prompt") == await tokenizer_group.encode_async(prompt="prompt",
+                                                        lora_request=None)
     assert isinstance(tokenizer_group.get_lora_tokenizer(None),
                       PreTrainedTokenizerBase)
     assert tokenizer_group.get_lora_tokenizer(
@@ -69,8 +69,7 @@ async def test_tokenizer_group_pool(tokenizer_group_type):
     # and check that all requests are processed correctly.
     num_requests = tokenizer_group_pool.pool_size * 5
     requests = [
-        tokenizer_group_pool.encode_async(request_id=str(i),
-                                          prompt=f"prompt {i}",
+        tokenizer_group_pool.encode_async(prompt=f"prompt {i}",
                                           lora_request=None)
         for i in range(num_requests)
     ]
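This hunk shows only how the batch of coroutines is built; how the test resolves them lies outside the hunk. A sketch of the likely pattern, assuming the surrounding test awaits the batch with `asyncio.gather` (an assumption, since the awaiting code is not part of this diff):

```python
import asyncio

# One encode coroutine per request; pool_size * 5 requests ensures
# every actor in the pool handles several of them.
requests = [
    tokenizer_group_pool.encode_async(prompt=f"prompt {i}",
                                      lora_request=None)
    for i in range(num_requests)
]

# Resolve them concurrently, then check each result afterwards.
results = await asyncio.gather(*requests)
```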
@@ -161,12 +160,8 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
     fail_at[0] = 1000
 
     # We should recover successfully.
-    await tokenizer_group_pool.encode_async(request_id="1",
-                                            prompt="prompt",
-                                            lora_request=None)
-    await tokenizer_group_pool.encode_async(request_id="1",
-                                            prompt="prompt",
-                                            lora_request=None)
+    await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None)
+    await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None)
 
     # Check that we have a new actor
     assert len(tokenizer_group_pool.tokenizer_actors) == len(tokenizer_actors)
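Note the two different actor assertions this test uses; both lines appear verbatim in the diff, and the comments here restate the intent the test's own comments give them:

```python
# After a forced failure the pool respawns a worker: the actor count
# is preserved even though the failed actor was replaced.
assert len(tokenizer_group_pool.tokenizer_actors) == len(tokenizer_actors)

# In the no-failure path at the end of the test, actor identity is
# also preserved, so the stronger equality check is used instead.
assert tokenizer_group_pool.tokenizer_actors == tokenizer_actors
```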
@@ -184,8 +179,7 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
 
     # We should fail after re-initialization.
     with pytest.raises(RuntimeError):
-        await tokenizer_group_pool.encode_async(request_id="1",
-                                                prompt="prompt",
+        await tokenizer_group_pool.encode_async(prompt="prompt",
                                                 lora_request=None)
 
     # check_health should raise the same thing
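Here the failure is permanent, so the call is wrapped in `pytest.raises`. A sketch of the guard as it reads after this change, plus the `check_health` expectation the trailing comment refers to (that call sits outside the hunk, so its exact form below is an assumption):

```python
# Actors are configured to keep failing, so encoding must surface
# the actor error rather than hang or silently retry.
with pytest.raises(RuntimeError):
    await tokenizer_group_pool.encode_async(prompt="prompt",
                                            lora_request=None)

# Per the comment in the diff, health checking is expected to raise
# the same error (assumed call form; not shown in this hunk).
with pytest.raises(RuntimeError):
    tokenizer_group_pool.check_health()
```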
@@ -206,11 +200,8 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
 
     # Prompt too long error
     with pytest.raises(ValueError):
-        await tokenizer_group_pool.encode_async(request_id="1",
-                                                prompt="prompt" * 100,
+        await tokenizer_group_pool.encode_async(prompt="prompt" * 100,
                                                 lora_request=None)
-    await tokenizer_group_pool.encode_async(request_id="1",
-                                            prompt="prompt",
-                                            lora_request=None)
+    await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None)
     # Actors should stay the same.
     assert tokenizer_group_pool.tokenizer_actors == tokenizer_actors