16 Commits
main ... master

Author SHA1 Message Date
183147b58f no patches, just raw 2026-04-23 07:44:54 +00:00
eba5f3545d maybe just tool call parser 2026-04-23 07:40:54 +00:00
4bb1b6ca51 were back at patching parsers god dammit 2026-04-23 06:34:43 +00:00
c1c5db6568 fix: robust causal= patch — dedup + regex-based insertion instead of naive string replace 2026-04-23 05:47:50 +00:00
53f56b17d6 fix: make causal patch idempotent — skip if upstream already has causal= kwarg 2026-04-23 05:12:43 +00:00
b5537b9c52 fix: keep huggingface_hub — vLLM requires it at runtime 2026-04-23 04:43:00 +00:00
91fd11cf0b Fix: only download Kimi-K2.5-DFlash (K2.6-DFlash doesn't exist on HF) 2026-04-23 03:41:07 +00:00
c0df9172d9 Remove explicit huggingface-cli login - HF_TOKEN env var is read automatically by huggingface_hub 2026-04-23 03:37:42 +00:00
d1d85080e4 Fix Dockerfile: wrap RUN in bash -c to fix BuildKit parse error 2026-04-23 03:36:17 +00:00
17a0eb538b Fix Dockerfile syntax error (missing paren) and secure HF_TOKEN passing in Jenkinsfile 2026-04-23 03:35:24 +00:00
95d3f6df95 Guard HF_TOKEN usage in Dockerfile: only download draft models if they are missing 2026-04-23 03:33:36 +00:00
51ff6900db Fix HF_TOKEN masking: use different var name to avoid Jenkins auto-masking 2026-04-23 03:19:51 +00:00
f3f46c6d27 Use Jenkins credentials for HF_TOKEN instead of build parameter 2026-04-23 03:18:20 +00:00
7dbd4fe7ea Fix patch path: patches/ not payload/ 2026-04-23 02:45:59 +00:00
6a2e87884c Add Jenkinsfile for CI/CD pipeline 2026-04-23 02:43:08 +00:00
90dfb9c23c jam the drafter model into the container 2026-04-23 02:35:21 +00:00
6 changed files with 2965 additions and 21 deletions

View File

@@ -23,7 +23,26 @@ ENV PYTORCH_ROCM_ARCH=gfx942 \
OMP_NUM_THREADS=1
# --- Copy and apply DFlash patches ---
COPY payload/patch_dflash_rocm.py /tmp/patch_dflash_rocm.py
COPY patches/patch_dflash_rocm.py /tmp/patch_dflash_rocm.py
RUN python3 /tmp/patch_dflash_rocm.py && rm /tmp/patch_dflash_rocm.py
# --- Pre-download DFlash draft models ---
# These are needed for speculative decoding and must be local paths.
# Baking them into the image avoids runtime downloads/mounts.
# Pass HF_TOKEN build arg if the models are gated.
# The RUN step below never references HF_TOKEN explicitly; it relies on the
# build arg being present in the step's environment and on huggingface_hub
# reading it from there (presumably automatic — confirm against the
# huggingface_hub documentation).
# NOTE(review): values passed via --build-arg can typically be recovered from
# image metadata (e.g. `docker history`); a BuildKit secret mount would avoid
# that — TODO confirm and coordinate with the Jenkinsfile before changing.
ARG HF_TOKEN=
# The directory test makes the step a no-op when the model directory already
# exists in the image; otherwise huggingface_hub is installed, the snapshot is
# downloaded into /opt/draft-models, and the HF cache is removed afterwards.
RUN bash -c 'if [ ! -d "/opt/draft-models/Kimi-K2.5-DFlash" ]; then \
pip install --no-cache-dir huggingface_hub && \
python3 -c "from huggingface_hub import snapshot_download; snapshot_download(\"z-lab/Kimi-K2.5-DFlash\", local_dir=\"/opt/draft-models/Kimi-K2.5-DFlash\")" && \
rm -rf /root/.cache/huggingface; \
fi'
# Patch tool and reasoning parsers for Eagle
#COPY kimi_k2_tool_parser.py /usr/local/lib/python3.12/dist-packages/vllm/tool_parsers/kimi_k2_tool_parser.py
#COPY kimi_k2_reasoning_parser.py /usr/local/lib/python3.12/dist-packages/vllm/reasoning/kimi_k2_reasoning_parser.py
# Patch serving layer: flush reasoning→content on finish_reason=length
#COPY serving.py /usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/chat_completion/serving.py
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]

58
Jenkinsfile vendored Normal file
View File

@@ -0,0 +1,58 @@
// Declarative pipeline that checks out the repository, builds the
// Kimi DFlash image from Dockerfile.kimi26-dflash and pushes it to the
// registry configured below.
pipeline {
    agent any

    environment {
        REGISTRY = 'atl.vultrcr.com/vllm'
        IMAGE_NAME = 'kimi-k26-dflash-mi300x'
    }

    parameters {
        string(name: 'IMAGE_TAG', defaultValue: 'nightly', description: 'Docker image tag')
        string(name: 'GIT_REPO', defaultValue: 'https://sweetapi.com/biondizzle/kimi-k26-dflash-mi300x.git', description: 'Git repository URL')
        string(name: 'GIT_BRANCH', defaultValue: 'master', description: 'Git branch to build')
    }

    stages {
        stage('Checkout') {
            steps {
                script {
                    // An empty GIT_REPO skips the explicit checkout.
                    if (params.GIT_REPO) {
                        git url: params.GIT_REPO, branch: params.GIT_BRANCH
                    }
                }
            }
        }

        stage('Build') {
            steps {
                script {
                    withCredentials([string(credentialsId: 'HF_TOKEN', variable: 'HF_SECRET')]) {
                        docker.withRegistry("https://${REGISTRY}", 'ATL_VCR_VLLM') {
                            def imageRef = "${REGISTRY}/${IMAGE_NAME}:${params.IMAGE_TAG}"
                            // The image reference is handed to the shell as an
                            // environment variable and the command is
                            // single-quoted, so neither the user-supplied tag
                            // nor the secret is spliced into the command text
                            // by Groovy; both are expanded (quoted) by the
                            // shell itself.
                            withEnv(["IMAGE_REF=${imageRef}"]) {
                                sh 'docker build -f Dockerfile.kimi26-dflash --build-arg HF_TOKEN="$HF_SECRET" -t "$IMAGE_REF" .'
                            }
                        }
                    }
                }
            }
        }

        stage('Push') {
            steps {
                script {
                    docker.withRegistry("https://${REGISTRY}", 'ATL_VCR_VLLM') {
                        docker.image("${REGISTRY}/${IMAGE_NAME}:${params.IMAGE_TAG}").push()
                    }
                }
            }
        }
    }

    post {
        success {
            echo "✅ Image pushed: ${REGISTRY}/${IMAGE_NAME}:${params.IMAGE_TAG}"
        }
        failure {
            echo "❌ Build failed"
        }
    }
}

373
kimi_k2_reasoning_parser.py Normal file
View File

@@ -0,0 +1,373 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Kimi-K2 Reasoning Parser — MTP-compatible version.
Fixes applied over the upstream parser:
1. **<think>/</think> tag suppression no longer requires single-token
deltas.** The original used ``len(delta_token_ids) == 1`` to detect
and suppress think tags. With MTP speculative decoding, these tokens
arrive fused with reasoning text, so the guard fails and raw tags
leak into the reasoning or content output.
2. **Text-based detection replaces token-ID-only detection** in
``extract_reasoning_streaming``. Since ``<think>`` and ``</think>``
are single tokens, they always appear as complete strings in
``delta_text`` (the detokenizer never splits a single token across
deltas). Text-based stripping is therefore safe and MTP-agnostic.
3. **Handles ``</think>`` + ``<|tool_calls_section_begin|>`` arriving
in the same delta** — the reasoning portion is correctly terminated
and the tool-call content is forwarded so the tool parser can
detect it on the same or next call.
Drop-in replacement: same class name, same interface.
"""
from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser
if TYPE_CHECKING:
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
class KimiK2ReasoningParser(ReasoningParser):
"""
Reasoning parser for Kimi K2 model — MTP-compatible.
Uses ``<think>...</think>`` to denote reasoning text. Reasoning
may also end implicitly when ``<|tool_calls_section_begin|>``
appears.
All detection uses text-based matching so the parser is agnostic
to how many tokens arrive per streaming step (robust against MTP
and EAGLE speculative decoding).
"""
def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
    """Set up tag strings, their token IDs and the streaming state.

    Raises:
        ValueError: if the base class did not receive a tokenizer.
        RuntimeError: if ``<think>`` or ``</think>`` is missing from the
            vocabulary.
    """
    super().__init__(tokenizer, *args, **kwargs)

    if not self.model_tokenizer:
        raise ValueError(
            "The model tokenizer must be passed to the ReasoningParser "
            "constructor during construction."
        )

    # ``chat_template_kwargs`` may be absent or explicitly None; thinking
    # is on unless the caller switched it off.
    template_options = kwargs.get("chat_template_kwargs", {}) or {}
    thinking_enabled = bool(template_options.get("thinking", True))

    # With thinking disabled every public method delegates to the
    # identity parser.
    self._identity_parser: IdentityReasoningParser | None = (
        None
        if thinking_enabled
        else IdentityReasoningParser(tokenizer, *args, **kwargs)
    )

    # Tag strings.
    self._start_token = "<think>"
    self._end_token = "</think>"
    self._tool_section_start_token = "<|tool_calls_section_begin|>"
    # Plural and singular spellings of the tool-section opener.
    self._tool_section_start_variants = [
        "<|tool_calls_section_begin|>",
        "<|tool_call_section_begin|>",
    ]

    # Token IDs, used by the ID-based end-of-reasoning checks.
    vocab = self.vocab
    self._start_token_id = vocab.get(self._start_token)
    self._end_token_id = vocab.get(self._end_token)
    self._tool_section_start_token_id = vocab.get(
        self._tool_section_start_token
    )
    # Only the variants that actually exist in the vocabulary.
    self._tool_section_start_token_ids: set[int] = {
        token_id
        for token_id in map(vocab.get, self._tool_section_start_variants)
        if token_id is not None
    }

    think_ids_found = (
        self._start_token_id is not None and self._end_token_id is not None
    )
    if not think_ids_found:
        raise RuntimeError(
            "KimiK2ReasoningParser could not locate think start/end "
            "tokens in the tokenizer!"
        )

    # Per-generation flag; it is reset by extract_reasoning_streaming
    # on the first delta of a stream.
    self._reasoning_ended: bool = False
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def _find_tool_section_start(self, text: str) -> int:
"""Return the index of the earliest tool-section-start marker,
or -1 if none found."""
best = -1
for variant in self._tool_section_start_variants:
idx = text.find(variant)
if idx != -1 and (best == -1 or idx < best):
best = idx
return best
def _strip_think_tags(self, text: str) -> str:
"""Remove ``<think>`` and ``</think>`` tag text from *text*."""
return text.replace(self._start_token, "").replace(self._end_token, "")
def _strip_tool_section_markers(self, text: str) -> str:
"""Remove all tool-section start markers from *text*.
The tool parser finds these in ``current_text`` independently;
forwarding them as content causes double-handling.
"""
for variant in self._tool_section_start_variants:
text = text.replace(variant, "")
return text
# ------------------------------------------------------------------
# Full-sequence methods (these scan all IDs — MTP-safe as-is)
# ------------------------------------------------------------------
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
    """Report whether *input_ids* ends with a finished reasoning block.

    The sequence is scanned from the end for the nearest "relevant"
    token:

    * a think-start token means reasoning is still open → ``False``;
    * a think-end token or any tool-section-start token closes the
      reasoning (explicitly or implicitly), provided it is followed by
      at most ``max_trailing_tokens`` further tokens → ``True``.

    A closing token that lies further back is treated as belonging to
    an earlier turn of the prompt rather than to the current
    generation, so the answer is ``False``. Because of that rule only
    the last ``max_trailing_tokens + 1`` tokens can influence the
    result, and the scan stops there instead of walking the whole
    prompt.

    When thinking is disabled the call is delegated to the identity
    parser.

    Args:
        input_ids: Token IDs of the prompt and/or generated output.

    Returns:
        ``True`` if reasoning is considered finished, else ``False``.
    """
    if self._identity_parser is not None:
        return self._identity_parser.is_reasoning_end(input_ids)

    # Heuristic window: a closing token followed by more than this many
    # tokens is assumed to come from a prior turn in the prompt.
    max_trailing_tokens = 3

    for trailing, token_id in enumerate(reversed(input_ids)):
        if trailing > max_trailing_tokens:
            # Anything found further back can no longer yield True.
            return False
        if token_id == self._start_token_id:
            # Nearest relevant token opens reasoning: still in progress.
            return False
        if (
            token_id == self._end_token_id
            or token_id in self._tool_section_start_token_ids
        ):
            # Explicit </think> or implicit end via a tool section.
            return True

    # No think/tool-section token near the end (or empty input).
    return False
def is_reasoning_end_streaming(
    self, input_ids: Sequence[int], delta_ids: Iterable[int]
) -> bool:
    """Return True when the current delta contains the think-end token
    or any tool-section-start token.

    Delegates to the identity parser when thinking is disabled.
    """
    delegate = self._identity_parser
    if delegate is not None:
        return delegate.is_reasoning_end_streaming(input_ids, delta_ids)

    seen = set(delta_ids)
    return self._end_token_id in seen or not seen.isdisjoint(
        self._tool_section_start_token_ids
    )
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
    """Return the token IDs that follow the end of reasoning.

    With a think-end token present, everything after its last
    occurrence is returned. Otherwise, for the first tool-section-start
    ID found in *input_ids*, the slice starting at its last occurrence
    (marker included) is returned. An empty list means reasoning has
    not ended. Delegates to the identity parser when thinking is
    disabled.
    """
    if self._identity_parser is not None:
        return self._identity_parser.extract_content_ids(input_ids)

    def last_index_of(token_id: int) -> int:
        # Position of the final occurrence; caller guarantees presence.
        return len(input_ids) - 1 - input_ids[::-1].index(token_id)

    if self._end_token_id in input_ids:
        return input_ids[last_index_of(self._end_token_id) + 1:]

    # Implicit end of reasoning through a tool section.
    for marker_id in self._tool_section_start_token_ids:
        if marker_id in input_ids:
            return input_ids[last_index_of(marker_id):]

    return []
# ------------------------------------------------------------------
# Non-streaming extraction
# ------------------------------------------------------------------
def extract_reasoning(
self,
model_output: str,
request: "ChatCompletionRequest | ResponsesRequest",
) -> tuple[str | None, str | None]:
"""Extract (reasoning, content) from complete model output."""
if self._identity_parser is not None:
return self._identity_parser.extract_reasoning(
model_output, request
)
# Consume <think> at the start if present
start_idx = model_output.find(self._start_token)
start_idx = 0 if start_idx != 0 else len(self._start_token)
# Look for explicit </think>
end_idx = model_output.find(self._end_token)
if end_idx != -1:
reasoning = model_output[start_idx:end_idx]
content = model_output[end_idx + len(self._end_token):]
return reasoning, content or None
# Look for implicit reasoning end via tool section
tool_idx = self._find_tool_section_start(model_output)
if tool_idx != -1:
reasoning = model_output[start_idx:tool_idx]
content = model_output[tool_idx:]
return reasoning, content or None
# Still reasoning (no content yet)
return model_output[start_idx:], None
# ------------------------------------------------------------------
# Streaming extraction — MTP-compatible
# ------------------------------------------------------------------
def extract_reasoning_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
) -> DeltaMessage | None:
    """Classify one streaming delta as reasoning and/or content.

    Detection is done purely on ``delta_text``, so the number of tokens
    contained in a delta does not matter. Think tags and tool-section
    start markers are stripped from whatever is emitted. Returns
    ``None`` when nothing remains to emit. Delegates to the identity
    parser when thinking is disabled.
    """
    if self._identity_parser is not None:
        return self._identity_parser.extract_reasoning_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
        )

    # An empty previous_text marks the first delta of a generation.
    if not previous_text:
        self._reasoning_ended = False

    # Past the end of reasoning: the whole delta is content. The flag
    # is used instead of token IDs so tags from earlier turns in the
    # prompt cannot trigger this branch.
    if self._reasoning_ended:
        passthrough = self._strip_tool_section_markers(
            self._strip_think_tags(delta_text)
        )
        if not passthrough:
            return None
        return DeltaMessage(content=passthrough)

    # Explicit end: </think> inside this delta splits it in two.
    close_at = delta_text.find(self._end_token)
    if close_at != -1:
        self._reasoning_ended = True
        fields: dict = {}
        thought = self._strip_think_tags(delta_text[:close_at])
        if thought:
            fields["reasoning"] = thought
        answer = self._strip_tool_section_markers(
            delta_text[close_at + len(self._end_token):]
        )
        if answer:
            fields["content"] = answer
        return DeltaMessage(**fields) if fields else None

    # Implicit end: a tool-section marker; only the text before it is
    # emitted (as reasoning).
    marker_at = self._find_tool_section_start(delta_text)
    if marker_at != -1:
        self._reasoning_ended = True
        thought = self._strip_think_tags(delta_text[:marker_at])
        return DeltaMessage(reasoning=thought) if thought else None

    # Still reasoning.
    thought = self._strip_think_tags(delta_text)
    return DeltaMessage(reasoning=thought) if thought else None

590
kimi_k2_tool_parser.py Normal file
View File

@@ -0,0 +1,590 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Kimi-K2 Tool Call Parser — re-parse-and-diff version.
Adapted from the GLM-4/DeepSeek-V3.2 streaming fix to make the
streaming path robust against multi-token deltas produced by MTP
speculative decoding.
Instead of counting start/end tokens to maintain an incremental state
machine, the streaming path re-parses the *entire* current_text on
every call, finds all <|tool_call_begin|> regions (complete and
in-progress), extracts the JSON arguments for each, and diffs against
what was previously sent.
Key changes vs. the upstream token-count parser:
1. No token-count state machine — the parser is stateless w.r.t.
how many tokens arrived per step.
2. Content forwarding uses delta_text (not re-parsed current_text)
so reasoning text is never re-emitted as content.
3. _extract_tool_call_regions() finds both complete and incomplete
tool-call blocks, enabling argument streaming.
4. _compute_args_diff() emits only newly-added characters.
5. Handles singular/plural section marker variants.
6. Returns empty deltas inside open sections to keep the stream
alive while tool call tokens are still arriving.
Drop-in replacement: same class name, same interface.
Example tool call format::
<|tool_calls_section_begin|>
<|tool_call_begin|>
functions.get_weather:0
<|tool_call_argument_begin|>
{"location": "杭州", "date": "2024-01-16"}
<|tool_call_end|>
<|tool_call_begin|>
functions.get_time:1
<|tool_call_argument_begin|>
{"timezone": "Asia/Shanghai"}
<|tool_call_end|>
<|tool_calls_section_end|>
"""
from collections.abc import Sequence
from typing import Any
import regex as re
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
logger = init_logger(__name__)
# ---------------------------------------------------------------------------
# Utility — inlined to avoid import issues across vLLM versions
# ---------------------------------------------------------------------------
def partial_tag_overlap(text: str, tag: str) -> int:
    """Return how many trailing characters of *text* form a proper
    prefix of *tag*.

    The longest such prefix wins; a complete *tag* at the end of *text*
    is not counted as an overlap. Returns 0 when nothing overlaps.
    """
    longest_candidate = min(len(tag) - 1, len(text))
    return next(
        (
            size
            for size in range(longest_candidate, 0, -1)
            if text.endswith(tag[:size])
        ),
        0,
    )
class KimiK2ToolParser(ToolParser):
"""Re-parse-and-diff tool parser for Kimi-K2 format.
On every streaming call the parser re-parses ``current_text`` to
find tool-call regions, extracts the JSON arguments for each, and
diffs against what was previously sent. This is robust against
multi-token deltas from MTP / EAGLE speculative decoding.
"""
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
    """Define tag constants, regexes, streaming state and token IDs.

    Raises:
        ValueError: if no tokenizer is available after base-class init.
        RuntimeError: if the primary section start/end tokens are not
            in the vocabulary.
    """
    super().__init__(tokenizer, tools)

    # --- Tags -------------------------------------------------------
    # Section wrappers, plural and singular spellings.
    self.tool_calls_section_start_variants: list[str] = [
        "<|tool_calls_section_begin|>",
        "<|tool_call_section_begin|>",
    ]
    self.tool_calls_section_end_variants: list[str] = [
        "<|tool_calls_section_end|>",
        "<|tool_call_section_end|>",
    ]
    # Used as section start when no section-level marker is present.
    self._fallback_section_start: str = "<|tool_call_begin|>"

    # Primary section tags.
    self.tool_calls_start_token: str = "<|tool_calls_section_begin|>"
    self.tool_calls_end_token: str = "<|tool_calls_section_end|>"

    # Tags of an individual tool call.
    self.tool_call_start_token: str = "<|tool_call_begin|>"
    self.tool_call_end_token: str = "<|tool_call_end|>"
    self.tool_call_arg_begin: str = "<|tool_call_argument_begin|>"

    # --- Regexes ----------------------------------------------------
    # One complete tool-call block: ID, argument marker, arguments.
    self.tool_call_regex = re.compile(
        r"<\|tool_call_begin\|>\s*"
        r"(?P<tool_call_id>[^<]+:\d+)\s*"
        r"<\|tool_call_argument_begin\|>\s*"
        r"(?P<function_arguments>(?:(?!<\|tool_call_begin\|>).)*?)\s*"
        r"<\|tool_call_end\|>",
        re.DOTALL,
    )
    # Tool ID at the beginning of a tool-call body.
    self.tool_id_regex = re.compile(
        r"\s*(?P<tool_id>[^\s<]+:\d+)\s*", re.DOTALL
    )

    # --- Streaming state (cleared by _reset_streaming_state) ---------
    self._tool_call_ids: list[str] = []
    self.streamed_args_for_tool: list[str] = []
    self.prev_tool_call_arr: list[dict[str, Any]] = []
    self.current_tool_id: int = -1

    if not self.model_tokenizer:
        raise ValueError(
            "The model tokenizer must be passed to the ToolParser "
            "constructor during construction."
        )

    # --- Token IDs --------------------------------------------------
    for attribute, token in (
        ("tool_calls_start_token_id", self.tool_calls_start_token),
        ("tool_calls_end_token_id", self.tool_calls_end_token),
        ("tool_call_start_token_id", self.tool_call_start_token),
        ("tool_call_end_token_id", self.tool_call_end_token),
    ):
        setattr(self, attribute, self.vocab.get(token))

    # Only the section-level tokens are mandatory.
    section_ids_found = (
        self.tool_calls_start_token_id is not None
        and self.tool_calls_end_token_id is not None
    )
    if not section_ids_found:
        raise RuntimeError(
            "Kimi-K2 Tool parser could not locate tool call start/end "
            "tokens in the tokenizer!"
        )

    logger.debug(
        "Successfully initialized %s", self.__class__.__name__
    )
# ------------------------------------------------------------------
# Request adjustment
# ------------------------------------------------------------------
def adjust_request(
    self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
    """Apply the base-class adjustment and, when tools are requested,
    keep special tokens in the decoded text.

    The tool-call tags are matched as text by this parser, so they must
    not be stripped during decoding.
    """
    adjusted = super().adjust_request(request)
    tools_requested = bool(adjusted.tools) and adjusted.tool_choice != "none"
    if tools_requested:
        adjusted.skip_special_tokens = False
    return adjusted
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
@staticmethod
def _tools_enabled(request: ChatCompletionRequest) -> bool:
    """Return True when *request* carries tools and its tool choice is
    not ``"none"``.

    Any exception raised while inspecting the request is logged and
    reported as "tools disabled".
    """
    try:
        requested_tools = getattr(request, "tools", None)
        choice = getattr(request, "tool_choice", None)
        if not requested_tools:
            return False
        return choice != "none"
    except Exception:
        logger.exception("Failed to determine if tools are enabled.")
        return False
@staticmethod
def _parse_tool_id(raw_id: str) -> tuple[str, str]:
"""Parse ``'functions.get_weather:0'`` → ``('get_weather', 'functions.get_weather:0')``."""
raw_id = raw_id.strip()
function_name = raw_id.split(":")[0].split(".")[-1]
return function_name, raw_id
def _find_section_start(self, text: str) -> int:
"""Return the index of the first section-start marker, or -1.
Falls back to <|tool_call_begin|> if no section-level marker
is found. Some model variants skip <|tool_calls_section_begin|>
and go directly to <|tool_call_begin|>.
"""
best = -1
for variant in self.tool_calls_section_start_variants:
idx = text.find(variant)
if idx != -1 and (best == -1 or idx < best):
best = idx
# Fallback: if no section-level marker found, look for
# <|tool_call_begin|> directly.
if best == -1 and self._fallback_section_start:
idx = text.find(self._fallback_section_start)
if idx != -1:
best = idx
return best
def _find_section_start_end(self, text: str) -> tuple[int, int]:
"""Return (start_of_inner, end_of_inner) for the section region.
*start_of_inner* points just past the section-start marker.
*end_of_inner* is the index of the section-end marker, or -1
if the section is still open.
Falls back to <|tool_call_begin|> if no section-level marker
is found.
"""
for variant in self.tool_calls_section_start_variants:
idx = text.find(variant)
if idx != -1:
inner_start = idx + len(variant)
# Look for end marker
for end_variant in self.tool_calls_section_end_variants:
end_idx = text.find(end_variant, inner_start)
if end_idx != -1:
return inner_start, end_idx
return inner_start, -1
# Fallback: no section-level marker found. Look for
# <|tool_call_begin|> directly as the section start.
if self._fallback_section_start:
idx = text.find(self._fallback_section_start)
if idx != -1:
inner_start = idx + len(self._fallback_section_start)
# Look for <|tool_call_end|> as the section end
end_marker = self._fallback_section_start.replace("begin", "end")
end_idx = text.find(end_marker, inner_start)
if end_idx != -1:
return inner_start, end_idx
return inner_start, -1
return -1, -1
# ------------------------------------------------------------------
# Non-streaming extraction (logic preserved from original)
# ------------------------------------------------------------------
def extract_tool_calls(
    self,
    model_output: str,
    request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
    """Extract all complete tool calls from a finished model output.

    The tool-call section is located with ``_find_section_start``, so
    the plural and singular section markers as well as the bare
    ``<|tool_call_begin|>`` fallback are recognised, matching the
    streaming path. Text before the section is returned as content
    (``None`` when empty).

    Args:
        model_output: Complete decoded model output.
        request: The originating request (not inspected here).

    Returns:
        ``tools_called=False`` with the untouched output as content
        when no section is present or when parsing raises; otherwise
        ``tools_called=True`` with every block matched by
        ``tool_call_regex``.
    """
    section_at = self._find_section_start(model_output)
    if section_at == -1:
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )

    try:
        function_call_tuples = self.tool_call_regex.findall(model_output)
        logger.debug("function_call_tuples: %s", function_call_tuples)

        tool_calls: list[ToolCall] = []
        for function_id, function_args in function_call_tuples:
            function_name, full_id = self._parse_tool_id(function_id)
            tool_calls.append(
                ToolCall(
                    id=full_id,
                    type="function",
                    function=FunctionCall(
                        name=function_name,
                        arguments=function_args,
                    ),
                )
            )

        # Everything in front of the section is ordinary content.
        content = model_output[:section_at] or None
        return ExtractedToolCallInformation(
            tools_called=True,
            tool_calls=tool_calls,
            content=content,
        )
    except Exception:
        # Best effort: fall back to returning the raw output.
        logger.exception("Error in extracting tool call from response.")
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )
# ------------------------------------------------------------------
# Streaming helpers — re-parse-and-diff
# ------------------------------------------------------------------
def _reset_streaming_state(self) -> None:
self._tool_call_ids.clear()
self.streamed_args_for_tool.clear()
self.prev_tool_call_arr.clear()
self.current_tool_id = -1
def _extract_tool_call_regions(
self, text: str
) -> list[tuple[str, bool]]:
"""Find all ``<|tool_call_begin|>`` … ``<|tool_call_end|>``
blocks inside the tool-calls section.
Returns a list of ``(inner_text, is_complete)`` tuples.
*inner_text* is everything between the tool-call open and close
tags (or end-of-available-text for the last partial block).
"""
results: list[tuple[str, bool]] = []
# Find the section region.
inner_start, inner_end = self._find_section_start_end(text)
if inner_start == -1:
return results
region = text[inner_start:inner_end] if inner_end != -1 else text[inner_start:]
pos = 0
while pos < len(region):
tc_start = region.find(self.tool_call_start_token, pos)
if tc_start == -1:
break
body_start = tc_start + len(self.tool_call_start_token)
tc_end = region.find(self.tool_call_end_token, body_start)
if tc_end != -1:
body = region[body_start:tc_end]
results.append((body, True))
pos = tc_end + len(self.tool_call_end_token)
else:
# Incomplete — still being generated.
body = region[body_start:]
overlap = partial_tag_overlap(body, self.tool_call_end_token)
if overlap:
body = body[:-overlap]
results.append((body, False))
break
return results
def _parse_tool_call_body(
self, body: str, is_complete: bool
) -> tuple[str | None, str | None, str]:
"""Parse a tool-call body into (func_name, tool_id, args_so_far).
The body looks like::
functions.get_weather:0
<|tool_call_argument_begin|>
{"location": "杭州"}
Returns ``(None, None, "")`` if the body doesn't contain enough
information yet (e.g. the tool ID is still arriving).
"""
# Extract tool ID (everything before <|tool_call_argument_begin|>
# or end of string).
arg_begin_idx = body.find(self.tool_call_arg_begin)
if arg_begin_idx != -1:
id_portion = body[:arg_begin_idx]
args_portion = body[arg_begin_idx + len(self.tool_call_arg_begin):]
else:
id_portion = body
args_portion = ""
# Try to extract the tool ID.
id_match = self.tool_id_regex.match(id_portion)
if not id_match:
# Not enough tokens yet to identify the tool.
return None, None, ""
raw_id = id_match.group("tool_id")
func_name, full_id = self._parse_tool_id(raw_id)
# Build args string.
args = args_portion.strip()
if is_complete:
# For a complete block, args is the final JSON.
return func_name, full_id, args
else:
# For a partial block, strip any trailing partial-tag overlap
# against tool_call_end (already done in caller), but also
# check for partial overlap against tool_call_argument_begin
# in case it hasn't fully arrived yet.
if arg_begin_idx == -1:
# No argument section yet.
overlap = partial_tag_overlap(
id_portion, self.tool_call_arg_begin
)
if overlap:
# The tag is still arriving — we have the name but
# no args yet.
pass
return func_name, full_id, ""
return func_name, full_id, args
def _compute_args_diff(self, index: int, args_so_far: str) -> str | None:
"""Return only the characters in *args_so_far* that haven't been
sent yet, or ``None`` if there's nothing new."""
prev = self.streamed_args_for_tool[index]
if not args_so_far or len(args_so_far) <= len(prev):
return None
diff = args_so_far[len(prev):]
self.streamed_args_for_tool[index] = args_so_far
self.prev_tool_call_arr[index]["arguments"] = args_so_far
return diff
def _ensure_tool_state_for(self, index: int) -> None:
"""Grow the streaming-state arrays so *index* is valid."""
while len(self._tool_call_ids) <= index:
self._tool_call_ids.append("")
while len(self.streamed_args_for_tool) <= index:
self.streamed_args_for_tool.append("")
while len(self.prev_tool_call_arr) <= index:
self.prev_tool_call_arr.append({})
# ------------------------------------------------------------------
# Main streaming entry point
# ------------------------------------------------------------------
def extract_tool_calls_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> DeltaMessage | None:
    """Extract tool calls from streaming output.

    Hybrid approach:

    - **Content forwarding** uses ``delta_text`` (same as the
      original parser) so we never re-emit text that the reasoning
      parser already handled.
    - **Tool call detection** re-parses ``current_text`` on every
      call (the re-parse-and-diff approach) so it's agnostic to
      how many tokens arrived per step — robust against MTP.

    Args:
        previous_text: Text accumulated before this step.  It is
            treated as a prefix of *current_text*.
        current_text: Text accumulated including this step.
        delta_text: Text that arrived in this step.
        previous_token_ids: Unused here.
        current_token_ids: Unused here.
        delta_token_ids: Only logged.
        request: Passed to ``self._tools_enabled`` to decide whether
            tool parsing applies at all.

    Returns:
        A ``DeltaMessage`` carrying new tool-call deltas and/or new
        plain content, or ``None`` when this step produces nothing to
        send.  Empty-string content is never emitted.
    """
    logger.debug("delta_text: %s", delta_text)
    logger.debug("delta_token_ids: %s", delta_token_ids)

    # First chunk of a new stream — reset state.
    if not previous_text:
        self._reset_streaming_state()

    # If tools aren't enabled, just forward content.
    if not self._tools_enabled(request):
        return DeltaMessage(content=delta_text) if delta_text else None

    # ── Determine section state from full text (MTP-safe) ──
    inner_start, inner_end = self._find_section_start_end(current_text)
    in_open_section = inner_start != -1 and inner_end == -1

    # Was the section already open in previous_text?
    prev_inner_start, _ = self._find_section_start_end(previous_text)
    section_existed_before = prev_inner_start != -1

    # ── Re-parse tool calls from current_text (MTP-safe) ──
    regions = self._extract_tool_call_regions(current_text)
    tool_call_deltas: list[DeltaToolCall] = []
    for i, (body, is_complete) in enumerate(regions):
        self._ensure_tool_state_for(i)
        func_name, tool_id, args_so_far = self._parse_tool_call_body(
            body, is_complete
        )
        if func_name is None:
            # Not enough data to identify the tool yet.
            break

        # Emit the tool name (once per tool call).
        if "name" not in self.prev_tool_call_arr[i]:
            self.prev_tool_call_arr[i]["name"] = func_name
            self._tool_call_ids[i] = tool_id or ""
            tool_call_deltas.append(
                DeltaToolCall(
                    index=i,
                    id=tool_id,
                    type="function",
                    function=DeltaFunctionCall(
                        name=func_name,
                        arguments="",
                    ),
                )
            )

        # Diff the arguments and emit any new characters.
        diff = self._compute_args_diff(i, args_so_far)
        if diff:
            tool_call_deltas.append(
                DeltaToolCall(
                    index=i,
                    function=DeltaFunctionCall(arguments=diff),
                )
            )

    if regions:
        self.current_tool_id = len(regions) - 1

    # Text that arrived in this step *before* the section marker
    # (e.g. "Let me check.<|tool_calls_section_begin|>").  It can only
    # exist on the step in which the section first shows up.
    # Whitespace-only text is treated as no content.
    pre_section = ""
    if inner_start != -1 and not section_existed_before:
        section_start_in_text = self._find_section_start(current_text)
        candidate = current_text[len(previous_text):section_start_in_text]
        if candidate.strip():
            pre_section = candidate

    # ── Emit results ──
    # Case 1: We have tool call updates — emit them.  If the section
    # marker and a parseable tool call arrived in the same step, the
    # text preceding the marker rides along as content; otherwise it
    # would be lost, because on the next step the section already
    # exists in previous_text and is never revisited.
    if tool_call_deltas:
        if pre_section:
            return DeltaMessage(
                content=pre_section, tool_calls=tool_call_deltas
            )
        return DeltaMessage(tool_calls=tool_call_deltas)

    # Case 2: No tool section has started yet — forward delta_text
    # as content.  The reasoning parser handles the reasoning/content
    # split; we just pass through whatever delta the serving layer
    # gave us.
    if inner_start == -1:
        return DeltaMessage(content=delta_text) if delta_text else None

    # Case 3: The section just appeared in this delta.  Forward the
    # content that came before the marker, or None when there is
    # none — empty content deltas confuse clients that distinguish
    # content=null from content="".
    if not section_existed_before:
        return DeltaMessage(content=pre_section) if pre_section else None

    # Case 4: Inside an open tool section but tool calls aren't
    # parseable yet — return None rather than an empty-string
    # content delta that would pollute the response.
    if in_open_section:
        return None

    # Case 5: Section is closed (inner_end != -1 at this point) —
    # forward any new content after the section end marker.
    for variant in self.tool_calls_section_end_variants:
        end_marker_pos = current_text.find(variant, inner_start)
        if end_marker_pos == -1:
            continue
        after_start = end_marker_pos + len(variant)
        # previous_text is a prefix of current_text, so everything
        # before len(previous_text) was already seen.  Using the same
        # offset for both avoids searching previous_text separately,
        # which could latch onto a different occurrence of the marker.
        new_after = current_text[max(after_start, len(previous_text)):]
        return DeltaMessage(content=new_after) if new_after else None
    return None

View File

@@ -79,31 +79,72 @@ def patch_rocm_aiter_fa(text: str, path: Path) -> str:
' @staticmethod\n def get_name() -> str:\n return "FLASH_ATTN"\n\n @classmethod\n def supports_non_causal(cls) -> bool:\n return True\n\n @staticmethod\n def get_impl_cls() -> type["AiterFlashAttentionImpl"]:\n',
path,
)
text = replace_once(
# Ensure AiterFlashAttentionMetadata dataclass has causal: bool field
if " causal: bool" not in text.split("class AiterFlashAttentionMetadata")[1].split("def ")[0]:
text = replace_once(
text,
"class AiterFlashAttentionMetadata:\n",
"class AiterFlashAttentionMetadata:\n causal: bool\n",
path,
)
# Ensure AiterFlashAttentionMetadata() constructors have causal= kwarg.
# Use a robust approach: find each constructor call, ensure causal= appears
# exactly once (first arg). This handles all upstream variations.
def _ensure_causal_in_constructor(m: re.Match) -> str:
call_text = m.group(0)
if re.search(r'\bcausal=', call_text):
return call_text # already has causal=, leave it
# Insert causal= as first kwarg after the opening paren
return call_text.replace(
"AiterFlashAttentionMetadata(\n",
"AiterFlashAttentionMetadata(\n causal=common_attn_metadata.causal,\n",
)
text = re.sub(
r'(?:attn_metadata = |return )AiterFlashAttentionMetadata\([^)]*?\)',
_ensure_causal_in_constructor,
text,
"class AiterFlashAttentionMetadata:\n",
"class AiterFlashAttentionMetadata:\n causal: bool\n",
path,
flags=re.DOTALL,
)
text = replace_once(
text,
" attn_metadata = AiterFlashAttentionMetadata(\n num_actual_tokens=common_attn_metadata.num_actual_tokens,\n",
" attn_metadata = AiterFlashAttentionMetadata(\n causal=common_attn_metadata.causal,\n num_actual_tokens=common_attn_metadata.num_actual_tokens,\n",
path,
)
text = replace_once(
text,
" return AiterFlashAttentionMetadata(\n num_actual_tokens=num_tokens,\n",
" return AiterFlashAttentionMetadata(\n causal=common_attn_metadata.causal,\n num_actual_tokens=num_tokens,\n",
path,
)
text = replace_all_regex(
text,
# Replace hardcoded causal=True with dynamic causal=attn_metadata.causal
# in flash attention calls. Skip if upstream already uses a dynamic causal.
text = re.sub(
r"(softmax_scale=self\.scale,\n)(\s*)causal=True,",
r"\1\2causal=attn_metadata.causal,",
path,
min_count=5,
text,
flags=re.MULTILINE,
)
# Safety: remove any duplicate causal= kwargs that may have been created
# by patching an upstream that already had causal= in constructors.
# This finds lines like `causal=...,` that appear more than once in the
# same parenthesized call and removes the extras (keeps first occurrence).
def _dedup_causal(m: re.Match) -> str:
block = m.group(0)
causal_lines = [i for i, line in enumerate(block.split('\n')) if re.match(r'\s*causal=', line)]
if len(causal_lines) <= 1:
return block
lines = block.split('\n')
# Keep only the first causal= line
seen = False
result = []
for line in lines:
if re.match(r'\s*causal=', line):
if seen:
continue # drop duplicate
seen = True
result.append(line)
return '\n'.join(result)
text = re.sub(
r'AiterFlashAttentionMetadata\([^)]*\)',
_dedup_causal,
text,
flags=re.DOTALL,
)
return text

1863
serving.py Normal file

File diff suppressed because it is too large Load Diff