[Frontend] Factor out chat message parsing (#7055)

This commit is contained in:
Cyrus Leung
2024-08-03 12:31:27 +08:00
committed by GitHub
parent 69ea15e5cc
commit 8c025fa703
3 changed files with 39 additions and 27 deletions

View File

@@ -1,7 +1,8 @@
import codecs
from dataclasses import dataclass, field
from dataclasses import dataclass
from functools import lru_cache
from typing import Awaitable, Iterable, List, Optional, Union, cast, final
from typing import (Awaitable, Iterable, List, Optional, Tuple, Union, cast,
final)
# yapf conflicts with isort for this block
# yapf: disable
@@ -65,8 +66,7 @@ class ConversationMessage(TypedDict):
# Immutable container returned by the chat-message parsing helpers: the
# conversation messages that were produced, together with the awaitables that
# will yield multi-modal data for them.
@dataclass(frozen=True)
class ChatMessageParseResult:
# Parsed conversation messages.
messages: List[ConversationMessage]
# Awaitables resolving to multi-modal data dicts.
# NOTE(review): `mm_futures` is declared twice below, once with a
# `field(default_factory=list)` default and once without. This looks like the
# removed and added lines of a diff shown together; confirm which one applies.
mm_futures: List[Awaitable[MultiModalDataDict]] = field(
default_factory=list)
mm_futures: List[Awaitable[MultiModalDataDict]]
def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
@@ -174,7 +174,7 @@ def _parse_chat_message_content_parts(
return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
def parse_chat_message_content(
def _parse_chat_message_content(
message: ChatCompletionMessageParam,
model_config: ModelConfig,
tokenizer: PreTrainedTokenizer,
@@ -190,3 +190,21 @@ def parse_chat_message_content(
return _parse_chat_message_content_parts(role, content, model_config,
tokenizer)
def parse_chat_messages(
    messages: List[ChatCompletionMessageParam],
    model_config: ModelConfig,
    tokenizer: PreTrainedTokenizer,
) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]:
    """Parse a list of chat messages into a single flat conversation.

    Every input message is handed to ``_parse_chat_message_content`` and the
    per-message results are concatenated in input order.

    Args:
        messages: The chat messages to parse.
        model_config: Passed through unchanged to the per-message parser.
        tokenizer: Passed through unchanged to the per-message parser.

    Returns:
        A ``(conversation, mm_futures)`` tuple holding, respectively, all
        parsed conversation messages and all multi-modal data awaitables
        collected from the individual parse results.
    """
    all_messages: List[ConversationMessage] = []
    all_futures: List[Awaitable[MultiModalDataDict]] = []

    for message in messages:
        parsed = _parse_chat_message_content(message, model_config, tokenizer)
        # In-place concatenation keeps the accumulated order identical to
        # the order of the input messages.
        all_messages += parsed.messages
        all_futures += parsed.mm_futures

    return all_messages, all_futures