[Core] Add fault tolerance for RayTokenizerGroupPool (#5748)
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@@ -100,3 +102,100 @@ async def test_tokenizer_group_ray_pool_env_var_propagation(
|
||||
max_num_seqs=1,
|
||||
max_input_length=None)
|
||||
tokenizer_pool.ping()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
|
||||
async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
|
||||
"""Test that Ray tokenizer pool group can recover from failures and
|
||||
if that's not possible, mark itself as unhealthy."""
|
||||
|
||||
class FailingTokenizerGroup(TokenizerGroup):
|
||||
|
||||
def __init__(self,
|
||||
*args,
|
||||
fail_at: Optional[List[int]] = None,
|
||||
**kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.i = 0
|
||||
self.fail_at = fail_at or []
|
||||
|
||||
def encode(self, *args, **kwargs):
|
||||
self.i += 1
|
||||
if self.i in self.fail_at:
|
||||
sys.exit(1)
|
||||
return super().encode(*args, **kwargs)
|
||||
|
||||
class FailingRayTokenizerGroupPool(RayTokenizerGroupPool):
|
||||
_worker_cls = FailingTokenizerGroup
|
||||
|
||||
# Fail at first iteration
|
||||
fail_at = [1]
|
||||
tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)
|
||||
tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
|
||||
tokenizer_pool_config,
|
||||
tokenizer_id="gpt2",
|
||||
enable_lora=False,
|
||||
max_num_seqs=1,
|
||||
max_input_length=None,
|
||||
fail_at=fail_at)
|
||||
tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy()
|
||||
|
||||
# Modify fail at to not fail at all (will be re-read when actor is
|
||||
# re-initialized).
|
||||
fail_at[0] = 1000
|
||||
|
||||
# We should recover successfully.
|
||||
await tokenizer_group_pool.encode_async(request_id="1",
|
||||
prompt="prompt",
|
||||
lora_request=None)
|
||||
await tokenizer_group_pool.encode_async(request_id="1",
|
||||
prompt="prompt",
|
||||
lora_request=None)
|
||||
|
||||
# Check that we have a new actor
|
||||
assert len(tokenizer_group_pool.tokenizer_actors) == len(tokenizer_actors)
|
||||
assert tokenizer_group_pool.tokenizer_actors != tokenizer_actors
|
||||
|
||||
# Fail at first iteration
|
||||
fail_at = [1]
|
||||
tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
|
||||
tokenizer_pool_config,
|
||||
tokenizer_id="gpt2",
|
||||
enable_lora=False,
|
||||
max_num_seqs=1,
|
||||
max_input_length=None,
|
||||
fail_at=fail_at)
|
||||
|
||||
# We should fail after re-initialization.
|
||||
with pytest.raises(RuntimeError):
|
||||
await tokenizer_group_pool.encode_async(request_id="1",
|
||||
prompt="prompt",
|
||||
lora_request=None)
|
||||
|
||||
# check_health should raise the same thing
|
||||
with pytest.raises(RuntimeError):
|
||||
tokenizer_group_pool.check_health()
|
||||
|
||||
# Ensure that non-ActorDiedErrors are still propagated correctly and do not
|
||||
# cause a re-initialization.
|
||||
fail_at = []
|
||||
tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
|
||||
tokenizer_pool_config,
|
||||
tokenizer_id="gpt2",
|
||||
enable_lora=False,
|
||||
max_num_seqs=1,
|
||||
max_input_length=2,
|
||||
fail_at=fail_at)
|
||||
tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy()
|
||||
|
||||
# Prompt too long error
|
||||
with pytest.raises(ValueError):
|
||||
await tokenizer_group_pool.encode_async(request_id="1",
|
||||
prompt="prompt" * 100,
|
||||
lora_request=None)
|
||||
await tokenizer_group_pool.encode_async(request_id="1",
|
||||
prompt="prompt",
|
||||
lora_request=None)
|
||||
# Actors should stay the same.
|
||||
assert tokenizer_group_pool.tokenizer_actors == tokenizer_actors
|
||||
|
||||
Reference in New Issue
Block a user