[Bugfix]: Make chat content text allow type content (#9358)
Signed-off-by: Vinay Damodaran <vrdn@hey.com>
This commit is contained in:
committed by
GitHub
parent
b7df53cd42
commit
33bab41060
@@ -121,7 +121,7 @@ class ConversationMessage(TypedDict, total=False):
|
||||
role: Required[str]
|
||||
"""The role of the message's author."""
|
||||
|
||||
content: Optional[str]
|
||||
content: Union[Optional[str], List[Dict[str, str]]]
|
||||
"""The contents of the message"""
|
||||
|
||||
tool_call_id: Optional[str]
|
||||
@@ -431,7 +431,7 @@ MM_PARSER_MAP: Dict[str, Callable[[ChatCompletionContentPartParam], str]] = {
|
||||
def _parse_chat_message_content_mm_part(
|
||||
part: ChatCompletionContentPartParam) -> Tuple[str, str]:
|
||||
"""
|
||||
Parses a given multi modal content part based on its type.
|
||||
Parses a given multi-modal content part based on its type.
|
||||
|
||||
Args:
|
||||
part: A dict containing the content part, with a potential 'type' field.
|
||||
@@ -485,21 +485,26 @@ def _parse_chat_message_content_parts(
|
||||
role: str,
|
||||
parts: Iterable[ChatCompletionContentPartParam],
|
||||
mm_tracker: BaseMultiModalItemTracker,
|
||||
chat_template_text_format: str,
|
||||
) -> List[ConversationMessage]:
|
||||
content: List[Union[str, Dict[str, str]]] = []
|
||||
|
||||
mm_parser = mm_tracker.create_parser()
|
||||
keep_multimodal_content = \
|
||||
wrap_dicts = \
|
||||
mm_tracker._model_config.hf_config.model_type in \
|
||||
MODEL_KEEP_MULTI_MODAL_CONTENT
|
||||
MODEL_KEEP_MULTI_MODAL_CONTENT or \
|
||||
(chat_template_text_format == "openai")
|
||||
|
||||
for part in parts:
|
||||
parse_res = _parse_chat_message_content_part(
|
||||
part, mm_parser, wrap_dicts=keep_multimodal_content)
|
||||
part,
|
||||
mm_parser,
|
||||
wrap_dicts=wrap_dicts,
|
||||
)
|
||||
if parse_res:
|
||||
content.append(parse_res)
|
||||
|
||||
if keep_multimodal_content:
|
||||
if wrap_dicts:
|
||||
# Parsing wraps images and texts as interleaved dictionaries
|
||||
return [ConversationMessage(role=role,
|
||||
content=content)] # type: ignore
|
||||
@@ -560,6 +565,7 @@ _ToolParser = partial(cast, ChatCompletionToolMessageParam)
|
||||
def _parse_chat_message_content(
|
||||
message: ChatCompletionMessageParam,
|
||||
mm_tracker: BaseMultiModalItemTracker,
|
||||
chat_template_text_format: str,
|
||||
) -> List[ConversationMessage]:
|
||||
role = message["role"]
|
||||
content = message.get("content")
|
||||
@@ -575,6 +581,7 @@ def _parse_chat_message_content(
|
||||
role,
|
||||
content, # type: ignore
|
||||
mm_tracker,
|
||||
chat_template_text_format,
|
||||
)
|
||||
|
||||
for result_msg in result:
|
||||
@@ -618,7 +625,11 @@ def parse_chat_messages(
|
||||
mm_tracker = MultiModalItemTracker(model_config, tokenizer)
|
||||
|
||||
for msg in messages:
|
||||
sub_messages = _parse_chat_message_content(msg, mm_tracker)
|
||||
sub_messages = _parse_chat_message_content(
|
||||
msg,
|
||||
mm_tracker,
|
||||
model_config.chat_template_text_format,
|
||||
)
|
||||
|
||||
conversation.extend(sub_messages)
|
||||
|
||||
@@ -636,7 +647,11 @@ def parse_chat_messages_futures(
|
||||
mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)
|
||||
|
||||
for msg in messages:
|
||||
sub_messages = _parse_chat_message_content(msg, mm_tracker)
|
||||
sub_messages = _parse_chat_message_content(
|
||||
msg,
|
||||
mm_tracker,
|
||||
model_config.chat_template_text_format,
|
||||
)
|
||||
|
||||
conversation.extend(sub_messages)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user