[Bugfix]: Make chat content text allow type content (#9358)

Signed-off-by: Vinay Damodaran <vrdn@hey.com>
This commit is contained in:
Vinay R Damodaran
2024-10-24 01:05:49 -04:00
committed by GitHub
parent b7df53cd42
commit 33bab41060
8 changed files with 107 additions and 12 deletions

View File

@@ -121,7 +121,7 @@ class ConversationMessage(TypedDict, total=False):
role: Required[str]
"""The role of the message's author."""
content: Optional[str]
content: Union[Optional[str], List[Dict[str, str]]]
"""The contents of the message"""
tool_call_id: Optional[str]
@@ -431,7 +431,7 @@ MM_PARSER_MAP: Dict[str, Callable[[ChatCompletionContentPartParam], str]] = {
def _parse_chat_message_content_mm_part(
part: ChatCompletionContentPartParam) -> Tuple[str, str]:
"""
Parses a given multi modal content part based on its type.
Parses a given multi-modal content part based on its type.
Args:
part: A dict containing the content part, with a potential 'type' field.
@@ -485,21 +485,26 @@ def _parse_chat_message_content_parts(
role: str,
parts: Iterable[ChatCompletionContentPartParam],
mm_tracker: BaseMultiModalItemTracker,
chat_template_text_format: str,
) -> List[ConversationMessage]:
content: List[Union[str, Dict[str, str]]] = []
mm_parser = mm_tracker.create_parser()
keep_multimodal_content = \
wrap_dicts = \
mm_tracker._model_config.hf_config.model_type in \
MODEL_KEEP_MULTI_MODAL_CONTENT
MODEL_KEEP_MULTI_MODAL_CONTENT or \
(chat_template_text_format == "openai")
for part in parts:
parse_res = _parse_chat_message_content_part(
part, mm_parser, wrap_dicts=keep_multimodal_content)
part,
mm_parser,
wrap_dicts=wrap_dicts,
)
if parse_res:
content.append(parse_res)
if keep_multimodal_content:
if wrap_dicts:
# Parsing wraps images and texts as interleaved dictionaries
return [ConversationMessage(role=role,
content=content)] # type: ignore
@@ -560,6 +565,7 @@ _ToolParser = partial(cast, ChatCompletionToolMessageParam)
def _parse_chat_message_content(
message: ChatCompletionMessageParam,
mm_tracker: BaseMultiModalItemTracker,
chat_template_text_format: str,
) -> List[ConversationMessage]:
role = message["role"]
content = message.get("content")
@@ -575,6 +581,7 @@ def _parse_chat_message_content(
role,
content, # type: ignore
mm_tracker,
chat_template_text_format,
)
for result_msg in result:
@@ -618,7 +625,11 @@ def parse_chat_messages(
mm_tracker = MultiModalItemTracker(model_config, tokenizer)
for msg in messages:
sub_messages = _parse_chat_message_content(msg, mm_tracker)
sub_messages = _parse_chat_message_content(
msg,
mm_tracker,
model_config.chat_template_text_format,
)
conversation.extend(sub_messages)
@@ -636,7 +647,11 @@ def parse_chat_messages_futures(
mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)
for msg in messages:
sub_messages = _parse_chat_message_content(msg, mm_tracker)
sub_messages = _parse_chat_message_content(
msg,
mm_tracker,
model_config.chat_template_text_format,
)
conversation.extend(sub_messages)