[Bugfix] Fix SHM cache initialization (#26427)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -245,17 +245,13 @@ class MockModelConfig:
|
||||
return self.diff_sampling_param or {}
|
||||
|
||||
|
||||
def _build_serving_chat(
|
||||
engine: AsyncLLM, model_config: MockModelConfig
|
||||
) -> OpenAIServingChat:
|
||||
def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
|
||||
models = OpenAIServingModels(
|
||||
engine_client=engine,
|
||||
base_model_paths=BASE_MODEL_PATHS,
|
||||
model_config=model_config,
|
||||
)
|
||||
serving_chat = OpenAIServingChat(
|
||||
engine,
|
||||
model_config,
|
||||
models,
|
||||
response_role="assistant",
|
||||
chat_template=CHAT_TEMPLATE,
|
||||
@@ -280,18 +276,17 @@ def _build_serving_chat(
|
||||
|
||||
@dataclass
|
||||
class MockEngine:
|
||||
async def get_model_config(self):
|
||||
return MockModelConfig()
|
||||
model_config: MockModelConfig = field(default_factory=MockModelConfig)
|
||||
processor: MagicMock = field(default_factory=MagicMock)
|
||||
io_processor: MagicMock = field(default_factory=MagicMock)
|
||||
|
||||
|
||||
async def _async_serving_chat_init():
|
||||
engine = MockEngine()
|
||||
model_config = await engine.get_model_config()
|
||||
|
||||
models = OpenAIServingModels(engine, model_config, BASE_MODEL_PATHS)
|
||||
models = OpenAIServingModels(engine, BASE_MODEL_PATHS)
|
||||
serving_completion = OpenAIServingChat(
|
||||
engine,
|
||||
model_config,
|
||||
models,
|
||||
response_role="assistant",
|
||||
chat_template=CHAT_TEMPLATE,
|
||||
@@ -311,8 +306,11 @@ async def test_serving_chat_returns_correct_model_name():
|
||||
mock_engine = MagicMock(spec=AsyncLLM)
|
||||
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||
mock_engine.errored = False
|
||||
mock_engine.model_config = MockModelConfig()
|
||||
mock_engine.processor = MagicMock()
|
||||
mock_engine.io_processor = MagicMock()
|
||||
|
||||
serving_chat = _build_serving_chat(mock_engine, MockModelConfig())
|
||||
serving_chat = _build_serving_chat(mock_engine)
|
||||
messages = [{"role": "user", "content": "what is 1+1?"}]
|
||||
|
||||
async def return_model_name(*args):
|
||||
@@ -338,8 +336,11 @@ async def test_serving_chat_should_set_correct_max_tokens():
|
||||
mock_engine = MagicMock(spec=AsyncLLM)
|
||||
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||
mock_engine.errored = False
|
||||
mock_engine.model_config = MockModelConfig()
|
||||
mock_engine.processor = MagicMock()
|
||||
mock_engine.io_processor = MagicMock()
|
||||
|
||||
serving_chat = _build_serving_chat(mock_engine, MockModelConfig())
|
||||
serving_chat = _build_serving_chat(mock_engine)
|
||||
|
||||
req = ChatCompletionRequest(
|
||||
model=MODEL_NAME,
|
||||
@@ -368,9 +369,12 @@ async def test_serving_chat_should_set_correct_max_tokens():
|
||||
mock_engine = MagicMock(spec=AsyncLLM)
|
||||
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||
mock_engine.errored = False
|
||||
mock_engine.model_config = mock_model_config
|
||||
mock_engine.processor = MagicMock()
|
||||
mock_engine.io_processor = MagicMock()
|
||||
|
||||
# Initialize the serving chat
|
||||
serving_chat = _build_serving_chat(mock_engine, mock_model_config)
|
||||
serving_chat = _build_serving_chat(mock_engine)
|
||||
|
||||
# Test Case 1: No max_tokens specified in request
|
||||
req = ChatCompletionRequest(
|
||||
@@ -410,9 +414,12 @@ async def test_serving_chat_should_set_correct_max_tokens():
|
||||
mock_engine = MagicMock(spec=AsyncLLM)
|
||||
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||
mock_engine.errored = False
|
||||
mock_engine.model_config = mock_model_config
|
||||
mock_engine.processor = MagicMock()
|
||||
mock_engine.io_processor = MagicMock()
|
||||
|
||||
# Initialize the serving chat
|
||||
serving_chat = _build_serving_chat(mock_engine, mock_model_config)
|
||||
serving_chat = _build_serving_chat(mock_engine)
|
||||
|
||||
# Test case 1: No max_tokens specified, defaults to context_window
|
||||
req = ChatCompletionRequest(
|
||||
@@ -453,9 +460,12 @@ async def test_serving_chat_could_load_correct_generation_config():
|
||||
mock_engine = MagicMock(spec=AsyncLLM)
|
||||
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||
mock_engine.errored = False
|
||||
mock_engine.model_config = mock_model_config
|
||||
mock_engine.processor = MagicMock()
|
||||
mock_engine.io_processor = MagicMock()
|
||||
|
||||
# Initialize the serving chat
|
||||
serving_chat = _build_serving_chat(mock_engine, mock_model_config)
|
||||
serving_chat = _build_serving_chat(mock_engine)
|
||||
|
||||
req = ChatCompletionRequest(
|
||||
model=MODEL_NAME,
|
||||
@@ -496,8 +506,11 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
|
||||
mock_engine = MagicMock(spec=AsyncLLM)
|
||||
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
|
||||
mock_engine.errored = False
|
||||
mock_engine.model_config = mock_model_config
|
||||
mock_engine.processor = MagicMock()
|
||||
mock_engine.io_processor = MagicMock()
|
||||
|
||||
serving_chat = _build_serving_chat(mock_engine, mock_model_config)
|
||||
serving_chat = _build_serving_chat(mock_engine)
|
||||
|
||||
# Test cache_salt
|
||||
req = ChatCompletionRequest(
|
||||
|
||||
Reference in New Issue
Block a user