Run ruff format on a few files. (#24075)

Signed-off-by: Chenheli Hua <huachenheli@outlook.com>
This commit is contained in:
Chenheli Hua
2025-09-02 10:55:32 -07:00
committed by GitHub
parent 1c41310584
commit f399182e8c
3 changed files with 948 additions and 691 deletions

View File

@@ -103,6 +103,7 @@ class PILImage(BaseModel):
"""
A PIL.Image.Image object.
"""
image_pil: Image.Image
model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -115,6 +116,7 @@ class CustomChatCompletionContentPILImageParam(TypedDict, total=False):
"image_pil": ImageAsset('cherry_blossom').pil_image
}
"""
image_pil: Required[PILImage]
@@ -127,6 +129,7 @@ class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
"image_url": "https://example.com/image.jpg"
}
"""
image_url: Required[str]
@@ -138,6 +141,7 @@ class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
"audio_url": "https://example.com/audio.mp3"
}
"""
audio_url: Required[str]
@@ -149,6 +153,7 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
"video_url": "https://example.com/video.mp4"
}
"""
video_url: Required[str]
@@ -174,19 +179,24 @@ class CustomThinkCompletionContentParam(TypedDict, total=False):
# Union of every content-part type accepted in a chat message: the
# OpenAI-defined part, the audio/video/refusal/image-embeds parts, the
# Custom* shorthand parts, and plain strings.
ChatCompletionContentPartParam: TypeAlias = Union[
    OpenAIChatCompletionContentPartParam,
    ChatCompletionContentPartAudioParam,
    ChatCompletionContentPartInputAudioParam,
    ChatCompletionContentPartVideoParam,
    ChatCompletionContentPartRefusalParam,
    CustomChatCompletionContentPILImageParam,
    CustomChatCompletionContentSimpleImageParam,
    ChatCompletionContentPartImageEmbedsParam,
    CustomChatCompletionContentSimpleAudioParam,
    CustomChatCompletionContentSimpleVideoParam,
    str,
    CustomThinkCompletionContentParam,
]
class CustomChatCompletionMessageParam(TypedDict, total=False):
"""Enables custom roles in the Chat Completion API."""
role: Required[str]
"""The role of the message's author."""
@@ -207,9 +217,11 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):
"""The tool calls generated by the model, such as function calls."""
# A chat message may be an OpenAI-defined message, a message with a custom
# role, or an OpenAI Harmony message.
ChatCompletionMessageParam = Union[
    OpenAIChatCompletionMessageParam,
    CustomChatCompletionMessageParam,
    OpenAIHarmonyMessage,
]
# TODO: Make fields ReadOnly once mypy supports it
@@ -262,13 +274,13 @@ def _is_var_or_elems_access(
key: Optional[str] = None,
) -> bool:
if isinstance(node, jinja2.nodes.Filter):
return (node.node is not None
and _is_var_or_elems_access(node.node, varname, key))
return node.node is not None and _is_var_or_elems_access(
node.node, varname, key)
if isinstance(node, jinja2.nodes.Test):
return _is_var_or_elems_access(node.node, varname, key)
if (isinstance(node, jinja2.nodes.Getitem)
and isinstance(node.arg, jinja2.nodes.Slice)):
if isinstance(node, jinja2.nodes.Getitem) and isinstance(
node.arg, jinja2.nodes.Slice):
return _is_var_or_elems_access(node.node, varname, key)
# yapf: disable
@@ -373,15 +385,18 @@ def resolve_mistral_chat_template(
) -> Optional[str]:
if chat_template is not None:
logger.warning_once(
"'chat_template' cannot be overridden for mistral tokenizer.")
"'chat_template' cannot be overridden for mistral tokenizer."
)
if "add_generation_prompt" in kwargs:
logger.warning_once(
"'add_generation_prompt' is not supported for mistral tokenizer, "
"so it will be ignored.")
"so it will be ignored."
)
if "continue_final_message" in kwargs:
logger.warning_once(
"'continue_final_message' is not supported for mistral tokenizer, "
"so it will be ignored.")
"so it will be ignored."
)
return None
@@ -401,23 +416,35 @@ def resolve_hf_chat_template(
try:
processor = cached_get_processor(
tokenizer.name_or_path,
processor_cls=(PreTrainedTokenizer, PreTrainedTokenizerFast,
ProcessorMixin),
processor_cls=(
PreTrainedTokenizer,
PreTrainedTokenizerFast,
ProcessorMixin,
),
trust_remote_code=model_config.trust_remote_code,
)
if isinstance(processor, ProcessorMixin) and \
hasattr(processor, 'chat_template') and \
processor.chat_template is not None:
if (
isinstance(processor, ProcessorMixin)
and hasattr(processor, "chat_template")
and processor.chat_template is not None
):
return processor.chat_template
except Exception:
logger.debug("Failed to load AutoProcessor chat template for %s", tokenizer.name_or_path, exc_info=True) # noqa: E501
logger.debug(
"Failed to load AutoProcessor chat template for %s",
tokenizer.name_or_path,
exc_info=True,
) # noqa: E501
# 3rd priority: AutoTokenizer chat template
try:
return tokenizer.get_chat_template(chat_template, tools=tools)
except Exception:
logger.debug("Failed to load AutoTokenizer chat template for %s",
tokenizer.name_or_path, exc_info=True)
logger.debug(
"Failed to load AutoTokenizer chat template for %s",
tokenizer.name_or_path,
exc_info=True,
)
# 4th priority: Predefined fallbacks
path = get_chat_template_fallback_path(
@@ -425,12 +452,16 @@ def resolve_hf_chat_template(
tokenizer_name_or_path=model_config.tokenizer,
)
if path is not None:
logger.info("Loading chat template fallback for %s as there isn't one "
"defined on HF Hub.", tokenizer.name_or_path)
logger.info(
"Loading chat template fallback for %s as there isn't one "
"defined on HF Hub.",
tokenizer.name_or_path,
)
chat_template = load_chat_template(path)
else:
logger.debug("There is no chat template fallback for %s",
tokenizer.name_or_path)
logger.debug(
"There is no chat template fallback for %s", tokenizer.name_or_path
)
return chat_template
@@ -452,11 +483,17 @@ def _resolve_chat_template_content_format(
else:
hf_chat_template = None
jinja_text = (hf_chat_template if isinstance(hf_chat_template, str)
else load_chat_template(chat_template, is_literal=True))
jinja_text = (
hf_chat_template
if isinstance(hf_chat_template, str)
else load_chat_template(chat_template, is_literal=True)
)
detected_format = ("string" if jinja_text is None else
_detect_content_format(jinja_text, default="string"))
detected_format = (
"string"
if jinja_text is None
else _detect_content_format(jinja_text, default="string")
)
return detected_format
@@ -512,7 +549,6 @@ def resolve_chat_template_content_format(
return detected_format
ModalityStr = Literal["image", "audio", "video", "image_embeds"]
_T = TypeVar("_T")
@@ -539,6 +575,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
@cached_property
def model_cls(self) -> type[SupportsMultiModal]:
from vllm.model_executor.model_loader import get_model_cls
model_cls = get_model_cls(self.model_config)
return cast(type[SupportsMultiModal], model_cls)
@@ -574,28 +611,29 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
def all_mm_data(self) -> Optional[MultiModalDataDict]:
    """Collect the tracked multimodal items into a dict keyed by modality.

    Returns:
        ``None`` when no items have been tracked; otherwise a dict holding
        an ``"image"``, ``"audio"`` and/or ``"video"`` entry for every
        modality that has tracked items.

    Raises:
        ValueError: If raw images and image embeddings were both tracked,
            or if more than one ``image_embeds`` item was tracked.
    """
    if not self._items_by_modality:
        return None

    items_by_modality = dict(self._items_by_modality)
    has_images = "image" in items_by_modality
    has_image_embeds = "image_embeds" in items_by_modality

    if has_images and has_image_embeds:
        raise ValueError(
            "Mixing raw image and embedding inputs is not allowed"
        )

    mm_inputs = {}

    if has_image_embeds:
        image_embeds_lst = items_by_modality["image_embeds"]
        if len(image_embeds_lst) > 1:
            raise ValueError(
                "Only one message can have {'type': 'image_embeds'}"
            )
        # The single embeds item is passed on under the "image" key.
        mm_inputs["image"] = image_embeds_lst[0]

    if has_images:
        mm_inputs["image"] = items_by_modality["image"]  # A list of images

    # Audio and video items are forwarded unchanged as lists.
    for modality in ("audio", "video"):
        if modality in items_by_modality:
            mm_inputs[modality] = items_by_modality[modality]

    return mm_inputs
def create_parser(self) -> "BaseMultiModalContentParser":
@@ -603,32 +641,33 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
async def all_mm_data(self) -> Optional[MultiModalDataDict]:
    """Await the tracked multimodal items and collect them by modality.

    Returns:
        ``None`` when no items have been tracked; otherwise a dict holding
        an ``"image"``, ``"audio"`` and/or ``"video"`` entry for every
        modality that has tracked items.

    Raises:
        ValueError: If raw images and image embeddings were both tracked,
            or if more than one ``image_embeds`` item was tracked.
    """
    if not self._items_by_modality:
        return None

    # Resolve the awaitables of each modality into a list of results.
    items_by_modality = {
        modality: await asyncio.gather(*items)
        for modality, items in self._items_by_modality.items()
    }
    has_images = "image" in items_by_modality
    has_image_embeds = "image_embeds" in items_by_modality

    if has_images and has_image_embeds:
        raise ValueError(
            "Mixing raw image and embedding inputs is not allowed"
        )

    mm_inputs = {}

    if has_image_embeds:
        image_embeds_lst = items_by_modality["image_embeds"]
        if len(image_embeds_lst) > 1:
            raise ValueError(
                "Only one message can have {'type': 'image_embeds'}"
            )
        # The single embeds item is passed on under the "image" key.
        mm_inputs["image"] = image_embeds_lst[0]

    if has_images:
        mm_inputs["image"] = items_by_modality["image"]  # A list of images

    # Audio and video items are forwarded unchanged as lists.
    for modality in ("audio", "video"):
        if modality in items_by_modality:
            mm_inputs[modality] = items_by_modality[modality]

    return mm_inputs
def create_parser(self) -> "BaseMultiModalContentParser":
@@ -636,7 +675,6 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
class BaseMultiModalContentParser(ABC):
def __init__(self) -> None:
super().__init__()
@@ -648,8 +686,9 @@ class BaseMultiModalContentParser(ABC):
# }
self._placeholder_storage: dict[str, list] = defaultdict(list)
def _add_placeholder(self, modality: ModalityStr,
placeholder: Optional[str]):
def _add_placeholder(
self, modality: ModalityStr, placeholder: Optional[str]
):
mod_placeholder = MODALITY_PLACEHOLDERS_MAP[modality]
if placeholder:
self._placeholder_storage[mod_placeholder].append(placeholder)
@@ -662,8 +701,9 @@ class BaseMultiModalContentParser(ABC):
raise NotImplementedError
@abstractmethod
def parse_image_embeds(self,
image_embeds: Union[str, dict[str, str]]) -> None:
def parse_image_embeds(
self, image_embeds: Union[str, dict[str, str]]
) -> None:
raise NotImplementedError
@abstractmethod
@@ -684,7 +724,6 @@ class BaseMultiModalContentParser(ABC):
class MultiModalContentParser(BaseMultiModalContentParser):
def __init__(self, tracker: MultiModalItemTracker) -> None:
super().__init__()
@@ -701,8 +740,9 @@ class MultiModalContentParser(BaseMultiModalContentParser):
placeholder = self._tracker.add("image", image)
self._add_placeholder("image", placeholder)
def parse_image_embeds(self,
image_embeds: Union[str, dict[str, str]]) -> None:
def parse_image_embeds(
self, image_embeds: Union[str, dict[str, str]]
) -> None:
if isinstance(image_embeds, dict):
embeds = {
k: self._connector.fetch_image_embedding(v)
@@ -741,14 +781,13 @@ class MultiModalContentParser(BaseMultiModalContentParser):
class AsyncMultiModalContentParser(BaseMultiModalContentParser):
def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
super().__init__()
self._tracker = tracker
self._connector = MediaConnector(
media_io_kwargs=self._tracker._model_config.media_io_kwargs,
allowed_local_media_path=tracker.allowed_local_media_path
allowed_local_media_path=tracker.allowed_local_media_path,
)
def parse_image(self, image_url: str) -> None:
@@ -757,8 +796,9 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
placeholder = self._tracker.add("image", image_coro)
self._add_placeholder("image", placeholder)
def parse_image_embeds(self,
image_embeds: Union[str, dict[str, str]]) -> None:
def parse_image_embeds(
self, image_embeds: Union[str, dict[str, str]]
) -> None:
future: asyncio.Future[Union[str, dict[str, str]]] = asyncio.Future()
if isinstance(image_embeds, dict):
@@ -769,8 +809,7 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
future.set_result(embeds)
if isinstance(image_embeds, str):
embedding = self._connector.\
fetch_image_embedding(image_embeds)
embedding = self._connector.fetch_image_embedding(image_embeds)
future.set_result(embedding)
placeholder = self._tracker.add("image_embeds", future)
@@ -809,20 +848,23 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]):
return
elif isinstance(chat_template, Path) and not chat_template.exists():
raise FileNotFoundError(
"the supplied chat template path doesn't exist")
raise FileNotFoundError("the supplied chat template path doesn't exist")
elif isinstance(chat_template, str):
JINJA_CHARS = "{}\n"
if not any(c in chat_template
for c in JINJA_CHARS) and not Path(chat_template).exists():
if (
not any(c in chat_template for c in JINJA_CHARS)
and not Path(chat_template).exists()
):
raise ValueError(
f"The supplied chat template string ({chat_template}) "
f"appears path-like, but doesn't exist!")
f"appears path-like, but doesn't exist!"
)
else:
raise TypeError(
f"{type(chat_template)} is not a valid chat template type")
f"{type(chat_template)} is not a valid chat template type"
)
def _load_chat_template(
@@ -835,8 +877,9 @@ def _load_chat_template(
if is_literal:
if isinstance(chat_template, Path):
raise TypeError("chat_template is expected to be read directly "
"from its value")
raise TypeError(
"chat_template is expected to be read directly from its value"
)
return chat_template
@@ -849,9 +892,11 @@ def _load_chat_template(
JINJA_CHARS = "{}\n"
if not any(c in chat_template for c in JINJA_CHARS):
msg = (f"The supplied chat template ({chat_template}) "
f"looks like a file path, but it failed to be "
f"opened. Reason: {e}")
msg = (
f"The supplied chat template ({chat_template}) "
f"looks like a file path, but it failed to be "
f"opened. Reason: {e}"
)
raise ValueError(msg) from e
# If opening a file fails, set chat template to be args to
@@ -870,8 +915,9 @@ def load_chat_template(
return _cached_load_chat_template(chat_template, is_literal=is_literal)
def _get_interleaved_text_prompt(placeholder_storage: dict[str, list],
texts: list[str]) -> str:
def _get_interleaved_text_prompt(
placeholder_storage: dict[str, list], texts: list[str]
) -> str:
for idx, elem in enumerate(texts):
if elem in placeholder_storage:
texts[idx] = placeholder_storage[elem].pop(0)
@@ -881,10 +927,11 @@ def _get_interleaved_text_prompt(placeholder_storage: dict[str, list],
# TODO: Let user specify how to insert multimodal tokens into prompt
# (similar to chat template)
def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list],
texts: list[str],
interleave_strings: bool
) -> str:
def _get_full_multimodal_text_prompt(
placeholder_storage: dict[str, list],
texts: list[str],
interleave_strings: bool,
) -> str:
"""Combine multimodal prompts for a multimodal language model."""
# flatten storage to make it looks like
@@ -907,7 +954,6 @@ def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list],
# Look through the text prompt to check for missing placeholders
missing_placeholders: list[str] = []
for placeholder in placeholder_counts:
# For any existing placeholder in the text prompt, we leave it as is
placeholder_counts[placeholder] -= text_prompt.count(placeholder)
@@ -916,15 +962,18 @@ def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list],
"Placeholder count is negative! "
"Ensure that the 'interleave_strings' flag is disabled "
"(current value: %s) "
"when manually placing image placeholders.", interleave_strings
"when manually placing image placeholders.",
interleave_strings,
)
logger.debug("Input prompt: %s", text_prompt)
raise ValueError(
f"Found more '{placeholder}' placeholders in input prompt than "
"actual multimodal data items.")
"actual multimodal data items."
)
missing_placeholders.extend([placeholder] *
placeholder_counts[placeholder])
missing_placeholders.extend(
[placeholder] * placeholder_counts[placeholder]
)
# NOTE: Default behaviour: we always add missing placeholders
# at the front of the prompt, if interleave_strings=False
@@ -944,7 +993,8 @@ _AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python
_ResponsesInputImageParser = TypeAdapter(
ResponseInputImageParam).validate_python
ResponseInputImageParam
).validate_python
_ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio, PILImage]
# Define a mapping from part types to their corresponding parsing functions.
@@ -952,32 +1002,35 @@ MM_PARSER_MAP: dict[
str,
Callable[[ChatCompletionContentPartParam], _ContentPart],
] = {
"text":
lambda part: _TextParser(part).get("text", None),
"thinking":
lambda part: _ThinkParser(part).get("thinking", None),
"input_text":
lambda part: _TextParser(part).get("text", None),
"input_image":
lambda part: _ResponsesInputImageParser(part).get("image_url", None),
"image_url":
lambda part: _ImageParser(part).get("image_url", {}).get("url", None),
"image_embeds":
lambda part: _ImageEmbedsParser(part).get("image_embeds", None),
"text": lambda part: _TextParser(part).get("text", None),
"thinking": lambda part: _ThinkParser(part).get("thinking", None),
"input_text": lambda part: _TextParser(part).get("text", None),
"input_image": lambda part: _ResponsesInputImageParser(part).get(
"image_url", None
),
"image_url": lambda part: _ImageParser(part)
.get("image_url", {})
.get("url", None),
"image_embeds": lambda part: _ImageEmbedsParser(part).get(
"image_embeds", None
),
"image_pil": lambda part: _PILImageParser(part).get("image_pil", None),
"audio_url":
lambda part: _AudioParser(part).get("audio_url", {}).get("url", None),
"input_audio":
lambda part: _InputAudioParser(part).get("input_audio", None),
"refusal":
lambda part: _RefusalParser(part).get("refusal", None),
"video_url":
lambda part: _VideoParser(part).get("video_url", {}).get("url", None),
"audio_url": lambda part: _AudioParser(part)
.get("audio_url", {})
.get("url", None),
"input_audio": lambda part: _InputAudioParser(part).get(
"input_audio", None
),
"refusal": lambda part: _RefusalParser(part).get("refusal", None),
"video_url": lambda part: _VideoParser(part)
.get("video_url", {})
.get("url", None),
}
def _parse_chat_message_content_mm_part(
part: ChatCompletionContentPartParam) -> tuple[str, _ContentPart]:
part: ChatCompletionContentPartParam,
) -> tuple[str, _ContentPart]:
"""
Parses a given multi-modal content part based on its type.
@@ -993,7 +1046,8 @@ def _parse_chat_message_content_mm_part(
ValueError: If the 'type' field is missing and no direct URL is found.
"""
assert isinstance(
part, dict) # This is needed to avoid mypy errors: part.get() from str
part, dict
) # This is needed to avoid mypy errors: part.get() from str
part_type = part.get("type", None)
if isinstance(part_type, str) and part_type in MM_PARSER_MAP:
@@ -1002,8 +1056,10 @@ def _parse_chat_message_content_mm_part(
# Special case for 'image_url.detail'
# We only support 'auto', which is the default
if part_type == "image_url" and part.get("detail", "auto") != "auto":
logger.warning("'image_url.detail' is currently not supported "
"and will be ignored.")
logger.warning(
"'image_url.detail' is currently not supported "
"and will be ignored."
)
return part_type, content
@@ -1011,19 +1067,22 @@ def _parse_chat_message_content_mm_part(
# 'type' is required field by pydantic
if part_type is None:
if part.get("image_url") is not None:
image_params = cast(CustomChatCompletionContentSimpleImageParam,
part)
image_params = cast(
CustomChatCompletionContentSimpleImageParam, part
)
return "image_url", image_params.get("image_url", "")
if part.get("audio_url") is not None:
audio_params = cast(CustomChatCompletionContentSimpleAudioParam,
part)
audio_params = cast(
CustomChatCompletionContentSimpleAudioParam, part
)
return "audio_url", audio_params.get("audio_url", "")
if part.get("input_audio") is not None:
input_audio_params = cast(dict[str, str], part)
return "input_audio", input_audio_params
if part.get("video_url") is not None:
video_params = cast(CustomChatCompletionContentSimpleVideoParam,
part)
video_params = cast(
CustomChatCompletionContentSimpleVideoParam, part
)
return "video_url", video_params.get("video_url", "")
# Raise an error if no 'type' or direct URL is found.
raise ValueError("Missing 'type' field in multimodal part.")
@@ -1033,9 +1092,16 @@ def _parse_chat_message_content_mm_part(
return part_type, "unknown part_type content"
# Content-part types that are skipped with a warning when their parsed
# content turns out to be None.
VALID_MESSAGE_CONTENT_MM_PART_TYPES = (
    "text",
    "refusal",
    "image_url",
    "image_embeds",
    "image_pil",
    "audio_url",
    "input_audio",
    "video_url",
)
def _parse_chat_message_content_parts(
@@ -1055,21 +1121,20 @@ def _parse_chat_message_content_parts(
part,
mm_parser,
wrap_dicts=wrap_dicts,
interleave_strings=interleave_strings
interleave_strings=interleave_strings,
)
if parse_res:
content.append(parse_res)
if wrap_dicts:
# Parsing wraps images and texts as interleaved dictionaries
return [ConversationMessage(role=role,
content=content)] # type: ignore
return [ConversationMessage(role=role, content=content)] # type: ignore
texts = cast(list[str], content)
mm_placeholder_storage = mm_parser.mm_placeholder_storage()
if mm_placeholder_storage:
text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_storage,
texts,
interleave_strings)
text_prompt = _get_full_multimodal_text_prompt(
mm_placeholder_storage, texts, interleave_strings
)
else:
text_prompt = "\n".join(texts)
@@ -1099,13 +1164,16 @@ def _parse_chat_message_content_part(
if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and content is None:
logger.warning(
"Skipping multimodal part '%s' (type: '%s') "
"with empty / unparsable content.", part, part_type)
"with empty / unparsable content.",
part,
part_type,
)
return None
if part_type in ("text", "input_text", "refusal", "thinking"):
str_content = cast(str, content)
if wrap_dicts:
return {'type': 'text', 'text': str_content}
return {"type": "text", "text": str_content}
else:
return str_content
@@ -1137,8 +1205,12 @@ def _parse_chat_message_content_part(
else:
raise NotImplementedError(f"Unknown part type: {part_type}")
return {'type': modality} if wrap_dicts else (
MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None
return (
{"type": modality}
if wrap_dicts
else (
MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None
)
)
@@ -1171,14 +1243,16 @@ def _parse_chat_message_content(
)
for result_msg in result:
if role == 'assistant':
if role == "assistant":
parsed_msg = _AssistantParser(message)
# The 'tool_calls' is not None check ensures compatibility.
# It's needed only if downstream code doesn't strictly
# follow the OpenAI spec.
if ("tool_calls" in parsed_msg
and parsed_msg["tool_calls"] is not None):
if (
"tool_calls" in parsed_msg
and parsed_msg["tool_calls"] is not None
):
result_msg["tool_calls"] = list(parsed_msg["tool_calls"])
elif role == "tool":
parsed_msg = _ToolParser(message)
@@ -1198,12 +1272,15 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None:
# so, for messages that have tool_calls, parse the string (which we get
# from openAI format) to dict
for message in messages:
if (message["role"] == "assistant" and "tool_calls" in message
and isinstance(message["tool_calls"], list)):
if (
message["role"] == "assistant"
and "tool_calls" in message
and isinstance(message["tool_calls"], list)
):
for item in message["tool_calls"]:
item["function"]["arguments"] = json.loads(
item["function"]["arguments"])
item["function"]["arguments"]
)
def parse_chat_messages(
@@ -1224,7 +1301,7 @@ def parse_chat_messages(
content_format == "string"
and model_config.multimodal_config is not None
and model_config.multimodal_config.interleave_mm_strings
)
),
)
conversation.extend(sub_messages)
@@ -1252,7 +1329,7 @@ def parse_chat_messages_futures(
content_format == "string"
and model_config.multimodal_config is not None
and model_config.multimodal_config.interleave_mm_strings
)
),
)
conversation.extend(sub_messages)
@@ -1283,10 +1360,10 @@ def apply_hf_chat_template(
raise ValueError(
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one.")
"does not define one."
)
try:
return tokenizer.apply_chat_template(
conversation=conversation, # type: ignore[arg-type]
tools=tools, # type: ignore[arg-type]
@@ -1298,13 +1375,14 @@ def apply_hf_chat_template(
# External library exceptions can sometimes occur despite the framework's
# internal exception management capabilities.
except Exception as e:
# Log and report any library-related exceptions for further
# investigation.
logger.exception(
"An error occurred in `transformers` while applying chat template")
"An error occurred in `transformers` while applying chat template"
)
raise ValueError(str(e)) from e
def apply_mistral_chat_template(
tokenizer: MistralTokenizer,
messages: list[ChatCompletionMessageParam],
@@ -1337,26 +1415,26 @@ def apply_mistral_chat_template(
# External library exceptions can sometimes occur despite the framework's
# internal exception management capabilities.
except Exception as e:
# Log and report any library-related exceptions for further
# investigation.
logger.exception(
"An error occurred in `mistral_common` while applying chat "
"template")
"An error occurred in `mistral_common` while applying chat template"
)
raise ValueError(str(e)) from e
def get_history_tool_calls_cnt(conversation: list[ConversationMessage]) -> int:
    """Count the tool calls carried by assistant messages in a conversation.

    Args:
        conversation: The messages to scan. Only messages whose ``role`` is
            ``"assistant"`` contribute to the count.

    Returns:
        The total number of entries across the ``tool_calls`` fields of the
        assistant messages; a missing or ``None`` field counts as zero.
    """
    total = 0
    for msg in conversation:
        if msg["role"] != "assistant":
            continue
        tool_calls = msg.get("tool_calls")
        if tool_calls is not None:
            # Materialize first so iterables without __len__ can be counted.
            total += len(list(tool_calls))
    return total
def make_tool_call_id(
    id_type: str = "random", func_name=None, idx=None
) -> str:
    """Build the identifier string for a tool call.

    Args:
        id_type: ``"kimi_k2"`` selects the ``functions.<func_name>:<idx>``
            format; any other value produces a random id.
        func_name: Function name embedded in a ``"kimi_k2"`` id.
        idx: Index embedded in a ``"kimi_k2"`` id.

    Returns:
        The tool call id.
    """
    if id_type == "kimi_k2":
        return f"functions.{func_name}:{idx}"
    # by default return random
    return f"chatcmpl-tool-{random_uuid()}"