Compare commits
16 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 183147b58f | |||
| eba5f3545d | |||
| 4bb1b6ca51 | |||
| c1c5db6568 | |||
| 53f56b17d6 | |||
| b5537b9c52 | |||
| 91fd11cf0b | |||
| c0df9172d9 | |||
| d1d85080e4 | |||
| 17a0eb538b | |||
| 95d3f6df95 | |||
| 51ff6900db | |||
| f3f46c6d27 | |||
| 7dbd4fe7ea | |||
| 6a2e87884c | |||
| 90dfb9c23c |
@@ -23,7 +23,26 @@ ENV PYTORCH_ROCM_ARCH=gfx942 \
|
||||
OMP_NUM_THREADS=1
|
||||
|
||||
# --- Copy and apply DFlash patches ---
|
||||
COPY payload/patch_dflash_rocm.py /tmp/patch_dflash_rocm.py
|
||||
COPY patches/patch_dflash_rocm.py /tmp/patch_dflash_rocm.py
|
||||
RUN python3 /tmp/patch_dflash_rocm.py && rm /tmp/patch_dflash_rocm.py
|
||||
|
||||
# --- Pre-download DFlash draft models ---
|
||||
# These are needed for speculative decoding and must be local paths.
|
||||
# Baking them into the image avoids runtime downloads/mounts.
|
||||
# Pass HF_TOKEN build arg if the models are gated.
|
||||
ARG HF_TOKEN=
|
||||
RUN bash -c 'if [ ! -d "/opt/draft-models/Kimi-K2.5-DFlash" ]; then \
|
||||
pip install --no-cache-dir huggingface_hub && \
|
||||
python3 -c "from huggingface_hub import snapshot_download; snapshot_download(\"z-lab/Kimi-K2.5-DFlash\", local_dir=\"/opt/draft-models/Kimi-K2.5-DFlash\")" && \
|
||||
rm -rf /root/.cache/huggingface; \
|
||||
fi'
|
||||
|
||||
# Patch tool and reasoning parsers for Eagle
|
||||
#COPY kimi_k2_tool_parser.py /usr/local/lib/python3.12/dist-packages/vllm/tool_parsers/kimi_k2_tool_parser.py
|
||||
|
||||
#COPY kimi_k2_reasoning_parser.py /usr/local/lib/python3.12/dist-packages/vllm/reasoning/kimi_k2_reasoning_parser.py
|
||||
|
||||
# Patch serving layer: flush reasoning→content on finish_reason=length
|
||||
#COPY serving.py /usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/chat_completion/serving.py
|
||||
|
||||
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
|
||||
|
||||
58
Jenkinsfile
vendored
Normal file
58
Jenkinsfile
vendored
Normal file
@@ -0,0 +1,58 @@
|
||||
pipeline {
|
||||
agent any
|
||||
|
||||
environment {
|
||||
REGISTRY = 'atl.vultrcr.com/vllm'
|
||||
IMAGE_NAME = 'kimi-k26-dflash-mi300x'
|
||||
}
|
||||
|
||||
parameters {
|
||||
string(name: 'IMAGE_TAG', defaultValue: 'nightly', description: 'Docker image tag')
|
||||
string(name: 'GIT_REPO', defaultValue: 'https://sweetapi.com/biondizzle/kimi-k26-dflash-mi300x.git', description: 'Git repository URL')
|
||||
string(name: 'GIT_BRANCH', defaultValue: 'master', description: 'Git branch to build')
|
||||
}
|
||||
|
||||
stages {
|
||||
stage('Checkout') {
|
||||
steps {
|
||||
script {
|
||||
if (params.GIT_REPO) {
|
||||
git url: params.GIT_REPO, branch: params.GIT_BRANCH
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stage('Build') {
|
||||
steps {
|
||||
script {
|
||||
withCredentials([string(credentialsId: 'HF_TOKEN', variable: 'HF_SECRET')]) {
|
||||
docker.withRegistry("https://${REGISTRY}", 'ATL_VCR_VLLM') {
|
||||
def imageTag = params.IMAGE_TAG
|
||||
sh "docker build -f Dockerfile.kimi26-dflash --build-arg HF_TOKEN=\${HF_SECRET} -t ${REGISTRY}/${IMAGE_NAME}:${imageTag} ."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stage('Push') {
|
||||
steps {
|
||||
script {
|
||||
docker.withRegistry("https://${REGISTRY}", 'ATL_VCR_VLLM') {
|
||||
docker.image("${REGISTRY}/${IMAGE_NAME}:${params.IMAGE_TAG}").push()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
post {
|
||||
success {
|
||||
echo "✅ Image pushed: ${REGISTRY}/${IMAGE_NAME}:${params.IMAGE_TAG}"
|
||||
}
|
||||
failure {
|
||||
echo "❌ Build failed"
|
||||
}
|
||||
}
|
||||
}
|
||||
373
kimi_k2_reasoning_parser.py
Normal file
373
kimi_k2_reasoning_parser.py
Normal file
@@ -0,0 +1,373 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Kimi-K2 Reasoning Parser — MTP-compatible version.
|
||||
|
||||
Fixes applied over the upstream parser:
|
||||
|
||||
1. **<think>/</think> tag suppression no longer requires single-token
|
||||
deltas.** The original used ``len(delta_token_ids) == 1`` to detect
|
||||
and suppress think tags. With MTP speculative decoding, these tokens
|
||||
arrive fused with reasoning text, so the guard fails and raw tags
|
||||
leak into the reasoning or content output.
|
||||
|
||||
2. **Text-based detection replaces token-ID-only detection** in
|
||||
``extract_reasoning_streaming``. Since ``<think>`` and ``</think>``
|
||||
are single tokens, they always appear as complete strings in
|
||||
``delta_text`` (the detokenizer never splits a single token across
|
||||
deltas). Text-based stripping is therefore safe and MTP-agnostic.
|
||||
|
||||
3. **Handles ``</think>`` + ``<|tool_calls_section_begin|>`` arriving
|
||||
in the same delta** — the reasoning portion is correctly terminated
|
||||
and the tool-call content is forwarded so the tool parser can
|
||||
detect it on the same or next call.
|
||||
|
||||
Drop-in replacement: same class name, same interface.
|
||||
"""
|
||||
|
||||
from collections.abc import Iterable, Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
|
||||
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
|
||||
from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
|
||||
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
|
||||
|
||||
|
||||
class KimiK2ReasoningParser(ReasoningParser):
|
||||
"""
|
||||
Reasoning parser for Kimi K2 model — MTP-compatible.
|
||||
|
||||
Uses ``<think>...</think>`` to denote reasoning text. Reasoning
|
||||
may also end implicitly when ``<|tool_calls_section_begin|>``
|
||||
appears.
|
||||
|
||||
All detection uses text-based matching so the parser is agnostic
|
||||
to how many tokens arrive per streaming step (robust against MTP
|
||||
and EAGLE speculative decoding).
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
|
||||
super().__init__(tokenizer, *args, **kwargs)
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ReasoningParser "
|
||||
"constructor during construction."
|
||||
)
|
||||
|
||||
# Check if thinking is disabled via chat_template_kwargs
|
||||
chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {}
|
||||
thinking = bool(chat_kwargs.get("thinking", True))
|
||||
|
||||
# If thinking is not enabled, use identity parser to fall through
|
||||
self._identity_parser: IdentityReasoningParser | None
|
||||
if not thinking:
|
||||
self._identity_parser = IdentityReasoningParser(
|
||||
tokenizer, *args, **kwargs
|
||||
)
|
||||
else:
|
||||
self._identity_parser = None
|
||||
|
||||
# Token definitions
|
||||
self._start_token = "<think>"
|
||||
self._end_token = "</think>"
|
||||
self._tool_section_start_token = "<|tool_calls_section_begin|>"
|
||||
|
||||
# Also support singular variant for tool section
|
||||
self._tool_section_start_variants = [
|
||||
"<|tool_calls_section_begin|>",
|
||||
"<|tool_call_section_begin|>",
|
||||
]
|
||||
|
||||
# Get token IDs (used by is_reasoning_end for non-streaming,
|
||||
# and is_reasoning_end_streaming for delta checks)
|
||||
self._start_token_id = self.vocab.get(self._start_token)
|
||||
self._end_token_id = self.vocab.get(self._end_token)
|
||||
self._tool_section_start_token_id = self.vocab.get(
|
||||
self._tool_section_start_token
|
||||
)
|
||||
|
||||
# Collect all tool section start token IDs (for ID-based checks)
|
||||
self._tool_section_start_token_ids: set[int] = set()
|
||||
for variant in self._tool_section_start_variants:
|
||||
tid = self.vocab.get(variant)
|
||||
if tid is not None:
|
||||
self._tool_section_start_token_ids.add(tid)
|
||||
|
||||
if self._start_token_id is None or self._end_token_id is None:
|
||||
raise RuntimeError(
|
||||
"KimiK2ReasoningParser could not locate think start/end "
|
||||
"tokens in the tokenizer!"
|
||||
)
|
||||
|
||||
# Streaming state — tracks reasoning within the CURRENT
|
||||
# generation only, avoiding false positives from prior turns'
|
||||
# </think> tokens that appear in the prompt token IDs.
|
||||
self._reasoning_ended: bool = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _find_tool_section_start(self, text: str) -> int:
|
||||
"""Return the index of the earliest tool-section-start marker,
|
||||
or -1 if none found."""
|
||||
best = -1
|
||||
for variant in self._tool_section_start_variants:
|
||||
idx = text.find(variant)
|
||||
if idx != -1 and (best == -1 or idx < best):
|
||||
best = idx
|
||||
return best
|
||||
|
||||
def _strip_think_tags(self, text: str) -> str:
|
||||
"""Remove ``<think>`` and ``</think>`` tag text from *text*."""
|
||||
return text.replace(self._start_token, "").replace(self._end_token, "")
|
||||
|
||||
def _strip_tool_section_markers(self, text: str) -> str:
|
||||
"""Remove all tool-section start markers from *text*.
|
||||
|
||||
The tool parser finds these in ``current_text`` independently;
|
||||
forwarding them as content causes double-handling.
|
||||
"""
|
||||
for variant in self._tool_section_start_variants:
|
||||
text = text.replace(variant, "")
|
||||
return text
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Full-sequence methods (these scan all IDs — MTP-safe as-is)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
|
||||
"""Check if reasoning has ended based on the token ID sequence.
|
||||
|
||||
Scans backward to find the last think-start or think-end token.
|
||||
Returns True only if the last relevant token is a think-end or
|
||||
a tool-section-start, AND there is no think-start after it.
|
||||
|
||||
CRITICAL: When called with prompt_token_ids (as the vLLM serving
|
||||
layer does), the input contains the full chat history. On
|
||||
multi-turn conversations, the prompt ends with tokens from the
|
||||
prior assistant message, which may include think-end. However,
|
||||
this think-end belongs to the PRIOR generation — the new
|
||||
generation will start its own reasoning with think-start.
|
||||
|
||||
To handle this correctly, we check whether the input ends with
|
||||
a complete reasoning block (think-start ... think-end). If the
|
||||
last think token is think-end AND it's followed by non-reasoning
|
||||
tokens (like tool_call tokens or end-of-sequence), we return
|
||||
True. But if the input is just the prompt with no generated
|
||||
tokens yet, we return False because the new generation hasn't
|
||||
started reasoning yet.
|
||||
|
||||
The key insight: in the chat template for multi-turn, after the
|
||||
last assistant message's think-end, the template adds
|
||||
<|im_end|> followed by new user/assistant markers. The
|
||||
assistant generation prompt ends with <|im_assistant|> and
|
||||
<|im_middle|> — no think tokens. So if we scan backward and
|
||||
find think-end but then find prompt-end tokens (not think-start)
|
||||
after it, we know reasoning ended in a PRIOR turn, not the
|
||||
current one. We return False to let the new generation start
|
||||
fresh.
|
||||
"""
|
||||
if self._identity_parser is not None:
|
||||
return self._identity_parser.is_reasoning_end(input_ids)
|
||||
|
||||
# Scan backward to find the last think-start or think-end
|
||||
# or tool-section-start token.
|
||||
last_start = -1
|
||||
last_end = -1
|
||||
last_tool_section = -1
|
||||
|
||||
for i in range(len(input_ids) - 1, -1, -1):
|
||||
if input_ids[i] == self._start_token_id and last_start == -1:
|
||||
last_start = i
|
||||
if input_ids[i] == self._end_token_id and last_end == -1:
|
||||
last_end = i
|
||||
if input_ids[i] in self._tool_section_start_token_ids and last_tool_section == -1:
|
||||
last_tool_section = i
|
||||
# Stop early if we found think-start — it's the boundary
|
||||
if last_start != -1:
|
||||
break
|
||||
|
||||
# No think tokens at all — not a reasoning model output
|
||||
if last_start == -1 and last_end == -1 and last_tool_section == -1:
|
||||
return False
|
||||
|
||||
# think-start is the last relevant token — reasoning is in progress
|
||||
if last_start != -1 and (last_end == -1 or last_start > last_end):
|
||||
return False
|
||||
|
||||
# think-end or tool-section is the last relevant token.
|
||||
# This could be from the prompt (prior turn) or from generated
|
||||
# tokens. For prompt tokens on multi-turn, the think-end is
|
||||
# from a prior assistant message and the new generation hasn't
|
||||
# started yet — we should return False.
|
||||
#
|
||||
# Heuristic: if think-end appears but is followed by more tokens
|
||||
# (like <|im_end|>, user markers, etc.), it's from the prompt
|
||||
# and reasoning hasn't started in the current generation yet.
|
||||
# Return False.
|
||||
#
|
||||
# If think-end is the very last token or near the end, it's
|
||||
# from generated tokens and reasoning has ended. Return True.
|
||||
last_relevant = max(last_end, last_tool_section)
|
||||
tokens_after = len(input_ids) - 1 - last_relevant
|
||||
|
||||
# If there are more than a few tokens after the last think-end,
|
||||
# those are prompt tokens (chat template wrapping), meaning
|
||||
# the think-end is from a prior turn. Return False.
|
||||
if tokens_after > 3:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def is_reasoning_end_streaming(
|
||||
self, input_ids: Sequence[int], delta_ids: Iterable[int]
|
||||
) -> bool:
|
||||
"""Check if reasoning ends in this delta."""
|
||||
if self._identity_parser is not None:
|
||||
return self._identity_parser.is_reasoning_end_streaming(
|
||||
input_ids, delta_ids
|
||||
)
|
||||
|
||||
delta_ids_set = set(delta_ids)
|
||||
if self._end_token_id in delta_ids_set:
|
||||
return True
|
||||
return bool(delta_ids_set & self._tool_section_start_token_ids)
|
||||
|
||||
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
|
||||
"""Extract content token IDs (everything after reasoning ends)."""
|
||||
if self._identity_parser is not None:
|
||||
return self._identity_parser.extract_content_ids(input_ids)
|
||||
|
||||
if self._end_token_id in input_ids:
|
||||
end_idx = (
|
||||
len(input_ids) - 1 - input_ids[::-1].index(self._end_token_id)
|
||||
)
|
||||
if end_idx != -1:
|
||||
return input_ids[end_idx + 1:]
|
||||
|
||||
# Check for implicit reasoning end via tool section
|
||||
for tid in self._tool_section_start_token_ids:
|
||||
if tid in input_ids:
|
||||
tool_idx = (
|
||||
len(input_ids) - 1 - input_ids[::-1].index(tid)
|
||||
)
|
||||
if tool_idx != -1:
|
||||
return input_ids[tool_idx:]
|
||||
|
||||
return []
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Non-streaming extraction
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def extract_reasoning(
|
||||
self,
|
||||
model_output: str,
|
||||
request: "ChatCompletionRequest | ResponsesRequest",
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""Extract (reasoning, content) from complete model output."""
|
||||
if self._identity_parser is not None:
|
||||
return self._identity_parser.extract_reasoning(
|
||||
model_output, request
|
||||
)
|
||||
|
||||
# Consume <think> at the start if present
|
||||
start_idx = model_output.find(self._start_token)
|
||||
start_idx = 0 if start_idx != 0 else len(self._start_token)
|
||||
|
||||
# Look for explicit </think>
|
||||
end_idx = model_output.find(self._end_token)
|
||||
if end_idx != -1:
|
||||
reasoning = model_output[start_idx:end_idx]
|
||||
content = model_output[end_idx + len(self._end_token):]
|
||||
return reasoning, content or None
|
||||
|
||||
# Look for implicit reasoning end via tool section
|
||||
tool_idx = self._find_tool_section_start(model_output)
|
||||
if tool_idx != -1:
|
||||
reasoning = model_output[start_idx:tool_idx]
|
||||
content = model_output[tool_idx:]
|
||||
return reasoning, content or None
|
||||
|
||||
# Still reasoning (no content yet)
|
||||
return model_output[start_idx:], None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Streaming extraction — MTP-compatible
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def extract_reasoning_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
) -> DeltaMessage | None:
|
||||
"""Extract reasoning from a streaming delta.
|
||||
|
||||
Uses **text-based** detection to strip ``<think>``/``</think>``
|
||||
tags. This is safe because these are single tokens — the
|
||||
detokenizer always produces them as complete strings, never
|
||||
split across deltas. This makes the method agnostic to how
|
||||
many tokens arrive per step (MTP-compatible).
|
||||
"""
|
||||
if self._identity_parser is not None:
|
||||
return self._identity_parser.extract_reasoning_streaming(
|
||||
previous_text, current_text, delta_text,
|
||||
previous_token_ids, current_token_ids, delta_token_ids,
|
||||
)
|
||||
|
||||
# Reset state on new stream — previous_text is empty on the
|
||||
# first delta of each generation.
|
||||
if not previous_text:
|
||||
self._reasoning_ended = False
|
||||
|
||||
# ── Already past reasoning → everything is content ──
|
||||
# Uses our own _reasoning_ended flag instead of scanning
|
||||
# previous_token_ids, which may contain </think> from prior
|
||||
# assistant turns in the prompt and cause false positives.
|
||||
if self._reasoning_ended:
|
||||
cleaned = self._strip_tool_section_markers(
|
||||
self._strip_think_tags(delta_text)
|
||||
)
|
||||
return DeltaMessage(content=cleaned) if cleaned else None
|
||||
|
||||
# ── Check for </think> in this delta ──
|
||||
if self._end_token in delta_text:
|
||||
self._reasoning_ended = True
|
||||
end_idx = delta_text.find(self._end_token)
|
||||
reasoning = self._strip_think_tags(delta_text[:end_idx])
|
||||
content = self._strip_tool_section_markers(
|
||||
delta_text[end_idx + len(self._end_token):]
|
||||
)
|
||||
|
||||
kwargs: dict = {}
|
||||
if reasoning:
|
||||
kwargs["reasoning"] = reasoning
|
||||
if content:
|
||||
kwargs["content"] = content
|
||||
return DeltaMessage(**kwargs) if kwargs else None
|
||||
|
||||
# ── Check for implicit reasoning end via tool section ──
|
||||
tool_idx = self._find_tool_section_start(delta_text)
|
||||
if tool_idx != -1:
|
||||
self._reasoning_ended = True
|
||||
reasoning = self._strip_think_tags(delta_text[:tool_idx])
|
||||
kwargs = {}
|
||||
if reasoning:
|
||||
kwargs["reasoning"] = reasoning
|
||||
return DeltaMessage(**kwargs) if kwargs else None
|
||||
|
||||
# ── Still in reasoning — strip <think> tag if present ──
|
||||
cleaned = self._strip_think_tags(delta_text)
|
||||
return DeltaMessage(reasoning=cleaned) if cleaned else None
|
||||
590
kimi_k2_tool_parser.py
Normal file
590
kimi_k2_tool_parser.py
Normal file
@@ -0,0 +1,590 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Kimi-K2 Tool Call Parser — re-parse-and-diff version.
|
||||
|
||||
Adapted from the GLM-4/DeepSeek-V3.2 streaming fix to make the
|
||||
streaming path robust against multi-token deltas produced by MTP
|
||||
speculative decoding.
|
||||
|
||||
Instead of counting start/end tokens to maintain an incremental state
|
||||
machine, the streaming path re-parses the *entire* current_text on
|
||||
every call, finds all <|tool_call_begin|> regions (complete and
|
||||
in-progress), extracts the JSON arguments for each, and diffs against
|
||||
what was previously sent.
|
||||
|
||||
Key changes vs. the upstream token-count parser:
|
||||
1. No token-count state machine — the parser is stateless w.r.t.
|
||||
how many tokens arrived per step.
|
||||
2. Content forwarding uses delta_text (not re-parsed current_text)
|
||||
so reasoning text is never re-emitted as content.
|
||||
3. _extract_tool_call_regions() finds both complete and incomplete
|
||||
tool-call blocks, enabling argument streaming.
|
||||
4. _compute_args_diff() emits only newly-added characters.
|
||||
5. Handles singular/plural section marker variants.
|
||||
6. Returns empty deltas inside open sections to keep the stream
|
||||
alive while tool call tokens are still arriving.
|
||||
|
||||
Drop-in replacement: same class name, same interface.
|
||||
|
||||
Example tool call format::
|
||||
|
||||
<|tool_calls_section_begin|>
|
||||
<|tool_call_begin|>
|
||||
functions.get_weather:0
|
||||
<|tool_call_argument_begin|>
|
||||
{"location": "杭州", "date": "2024-01-16"}
|
||||
<|tool_call_end|>
|
||||
<|tool_call_begin|>
|
||||
functions.get_time:1
|
||||
<|tool_call_argument_begin|>
|
||||
{"timezone": "Asia/Shanghai"}
|
||||
<|tool_call_end|>
|
||||
<|tool_calls_section_end|>
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionRequest,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.abstract_tool_parser import (
|
||||
Tool,
|
||||
ToolParser,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Utility — inlined to avoid import issues across vLLM versions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def partial_tag_overlap(text: str, tag: str) -> int:
|
||||
"""Length of the longest prefix of *tag* that matches a suffix of *text*.
|
||||
|
||||
E.g. text ending in ``"<|tool_call"`` returns 11 when tag is
|
||||
``"<|tool_call_begin|>"``. Returns 0 when there is no overlap.
|
||||
"""
|
||||
max_check = min(len(tag) - 1, len(text))
|
||||
for k in range(max_check, 0, -1):
|
||||
if text.endswith(tag[:k]):
|
||||
return k
|
||||
return 0
|
||||
|
||||
|
||||
class KimiK2ToolParser(ToolParser):
|
||||
"""Re-parse-and-diff tool parser for Kimi-K2 format.
|
||||
|
||||
On every streaming call the parser re-parses ``current_text`` to
|
||||
find tool-call regions, extracts the JSON arguments for each, and
|
||||
diffs against what was previously sent. This is robust against
|
||||
multi-token deltas from MTP / EAGLE speculative decoding.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
|
||||
super().__init__(tokenizer, tools)
|
||||
|
||||
# ----- Tag constants -----
|
||||
# Section wrappers (support singular & plural variants)
|
||||
self.tool_calls_section_start_variants: list[str] = [
|
||||
"<|tool_calls_section_begin|>",
|
||||
"<|tool_call_section_begin|>",
|
||||
]
|
||||
self.tool_calls_section_end_variants: list[str] = [
|
||||
"<|tool_calls_section_end|>",
|
||||
"<|tool_call_section_end|>",
|
||||
]
|
||||
# Some model variants omit the section-level marker and go
|
||||
# directly to <|tool_call_begin|>. Treat it as a fallback.
|
||||
self._fallback_section_start: str = "<|tool_call_begin|>"
|
||||
# Primary variant for ToolParser base class / adjust_request
|
||||
self.tool_calls_start_token: str = "<|tool_calls_section_begin|>"
|
||||
self.tool_calls_end_token: str = "<|tool_calls_section_end|>"
|
||||
|
||||
# Individual tool-call markers
|
||||
self.tool_call_start_token: str = "<|tool_call_begin|>"
|
||||
self.tool_call_end_token: str = "<|tool_call_end|>"
|
||||
self.tool_call_arg_begin: str = "<|tool_call_argument_begin|>"
|
||||
|
||||
# ----- Compiled regexes -----
|
||||
# Complete tool call block.
|
||||
self.tool_call_regex = re.compile(
|
||||
r"<\|tool_call_begin\|>\s*"
|
||||
r"(?P<tool_call_id>[^<]+:\d+)\s*"
|
||||
r"<\|tool_call_argument_begin\|>\s*"
|
||||
r"(?P<function_arguments>(?:(?!<\|tool_call_begin\|>).)*?)\s*"
|
||||
r"<\|tool_call_end\|>",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
# For extracting tool ID from the start of a tool-call region.
|
||||
self.tool_id_regex = re.compile(
|
||||
r"\s*(?P<tool_id>[^\s<]+:\d+)\s*", re.DOTALL
|
||||
)
|
||||
|
||||
# ----- Streaming state (reset per request) -----
|
||||
self._tool_call_ids: list[str] = []
|
||||
self.streamed_args_for_tool: list[str] = []
|
||||
self.prev_tool_call_arr: list[dict[str, Any]] = []
|
||||
self.current_tool_id: int = -1
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ToolParser "
|
||||
"constructor during construction."
|
||||
)
|
||||
|
||||
# Validate that the primary section tokens exist in vocab.
|
||||
self.tool_calls_start_token_id = self.vocab.get(
|
||||
self.tool_calls_start_token
|
||||
)
|
||||
self.tool_calls_end_token_id = self.vocab.get(
|
||||
self.tool_calls_end_token
|
||||
)
|
||||
self.tool_call_start_token_id = self.vocab.get(
|
||||
self.tool_call_start_token
|
||||
)
|
||||
self.tool_call_end_token_id = self.vocab.get(
|
||||
self.tool_call_end_token
|
||||
)
|
||||
|
||||
if (
|
||||
self.tool_calls_start_token_id is None
|
||||
or self.tool_calls_end_token_id is None
|
||||
):
|
||||
raise RuntimeError(
|
||||
"Kimi-K2 Tool parser could not locate tool call start/end "
|
||||
"tokens in the tokenizer!"
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Successfully initialized %s", self.__class__.__name__
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Request adjustment
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def adjust_request(
|
||||
self, request: ChatCompletionRequest | ResponsesRequest
|
||||
) -> ChatCompletionRequest | ResponsesRequest:
|
||||
request = super().adjust_request(request)
|
||||
if request.tools and request.tool_choice != "none":
|
||||
# Ensure tool-call tokens (<|tool_calls_section_begin|>,
|
||||
# <|tool_call_begin|>, etc.) are not stripped during decoding.
|
||||
request.skip_special_tokens = False
|
||||
return request
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _tools_enabled(request: ChatCompletionRequest) -> bool:
|
||||
try:
|
||||
tools = getattr(request, "tools", None)
|
||||
tool_choice = getattr(request, "tool_choice", None)
|
||||
return bool(tools) and tool_choice != "none"
|
||||
except Exception:
|
||||
logger.exception("Failed to determine if tools are enabled.")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _parse_tool_id(raw_id: str) -> tuple[str, str]:
|
||||
"""Parse ``'functions.get_weather:0'`` → ``('get_weather', 'functions.get_weather:0')``."""
|
||||
raw_id = raw_id.strip()
|
||||
function_name = raw_id.split(":")[0].split(".")[-1]
|
||||
return function_name, raw_id
|
||||
|
||||
def _find_section_start(self, text: str) -> int:
|
||||
"""Return the index of the first section-start marker, or -1.
|
||||
|
||||
Falls back to <|tool_call_begin|> if no section-level marker
|
||||
is found. Some model variants skip <|tool_calls_section_begin|>
|
||||
and go directly to <|tool_call_begin|>.
|
||||
"""
|
||||
best = -1
|
||||
for variant in self.tool_calls_section_start_variants:
|
||||
idx = text.find(variant)
|
||||
if idx != -1 and (best == -1 or idx < best):
|
||||
best = idx
|
||||
# Fallback: if no section-level marker found, look for
|
||||
# <|tool_call_begin|> directly.
|
||||
if best == -1 and self._fallback_section_start:
|
||||
idx = text.find(self._fallback_section_start)
|
||||
if idx != -1:
|
||||
best = idx
|
||||
return best
|
||||
|
||||
def _find_section_start_end(self, text: str) -> tuple[int, int]:
|
||||
"""Return (start_of_inner, end_of_inner) for the section region.
|
||||
|
||||
*start_of_inner* points just past the section-start marker.
|
||||
*end_of_inner* is the index of the section-end marker, or -1
|
||||
if the section is still open.
|
||||
|
||||
Falls back to <|tool_call_begin|> if no section-level marker
|
||||
is found.
|
||||
"""
|
||||
for variant in self.tool_calls_section_start_variants:
|
||||
idx = text.find(variant)
|
||||
if idx != -1:
|
||||
inner_start = idx + len(variant)
|
||||
# Look for end marker
|
||||
for end_variant in self.tool_calls_section_end_variants:
|
||||
end_idx = text.find(end_variant, inner_start)
|
||||
if end_idx != -1:
|
||||
return inner_start, end_idx
|
||||
return inner_start, -1
|
||||
|
||||
# Fallback: no section-level marker found. Look for
|
||||
# <|tool_call_begin|> directly as the section start.
|
||||
if self._fallback_section_start:
|
||||
idx = text.find(self._fallback_section_start)
|
||||
if idx != -1:
|
||||
inner_start = idx + len(self._fallback_section_start)
|
||||
# Look for <|tool_call_end|> as the section end
|
||||
end_marker = self._fallback_section_start.replace("begin", "end")
|
||||
end_idx = text.find(end_marker, inner_start)
|
||||
if end_idx != -1:
|
||||
return inner_start, end_idx
|
||||
return inner_start, -1
|
||||
|
||||
return -1, -1
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Non-streaming extraction (logic preserved from original)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
if self.tool_calls_start_token not in model_output:
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
|
||||
try:
|
||||
function_call_tuples = self.tool_call_regex.findall(model_output)
|
||||
logger.debug("function_call_tuples: %s", function_call_tuples)
|
||||
|
||||
tool_calls: list[ToolCall] = []
|
||||
for match in function_call_tuples:
|
||||
function_id, function_args = match
|
||||
function_name, full_id = self._parse_tool_id(function_id)
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
id=full_id,
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=function_name,
|
||||
arguments=function_args,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
content_end = self._find_section_start(model_output)
|
||||
content = model_output[:content_end] if content_end > 0 else None
|
||||
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=content if content else None,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error in extracting tool call from response.")
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Streaming helpers — re-parse-and-diff
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _reset_streaming_state(self) -> None:
|
||||
self._tool_call_ids.clear()
|
||||
self.streamed_args_for_tool.clear()
|
||||
self.prev_tool_call_arr.clear()
|
||||
self.current_tool_id = -1
|
||||
|
||||
def _extract_tool_call_regions(
|
||||
self, text: str
|
||||
) -> list[tuple[str, bool]]:
|
||||
"""Find all ``<|tool_call_begin|>`` … ``<|tool_call_end|>``
|
||||
blocks inside the tool-calls section.
|
||||
|
||||
Returns a list of ``(inner_text, is_complete)`` tuples.
|
||||
*inner_text* is everything between the tool-call open and close
|
||||
tags (or end-of-available-text for the last partial block).
|
||||
"""
|
||||
results: list[tuple[str, bool]] = []
|
||||
|
||||
# Find the section region.
|
||||
inner_start, inner_end = self._find_section_start_end(text)
|
||||
if inner_start == -1:
|
||||
return results
|
||||
|
||||
region = text[inner_start:inner_end] if inner_end != -1 else text[inner_start:]
|
||||
|
||||
pos = 0
|
||||
while pos < len(region):
|
||||
tc_start = region.find(self.tool_call_start_token, pos)
|
||||
if tc_start == -1:
|
||||
break
|
||||
|
||||
body_start = tc_start + len(self.tool_call_start_token)
|
||||
tc_end = region.find(self.tool_call_end_token, body_start)
|
||||
|
||||
if tc_end != -1:
|
||||
body = region[body_start:tc_end]
|
||||
results.append((body, True))
|
||||
pos = tc_end + len(self.tool_call_end_token)
|
||||
else:
|
||||
# Incomplete — still being generated.
|
||||
body = region[body_start:]
|
||||
overlap = partial_tag_overlap(body, self.tool_call_end_token)
|
||||
if overlap:
|
||||
body = body[:-overlap]
|
||||
results.append((body, False))
|
||||
break
|
||||
|
||||
return results
|
||||
|
||||
def _parse_tool_call_body(
|
||||
self, body: str, is_complete: bool
|
||||
) -> tuple[str | None, str | None, str]:
|
||||
"""Parse a tool-call body into (func_name, tool_id, args_so_far).
|
||||
|
||||
The body looks like::
|
||||
|
||||
functions.get_weather:0
|
||||
<|tool_call_argument_begin|>
|
||||
{"location": "杭州"}
|
||||
|
||||
Returns ``(None, None, "")`` if the body doesn't contain enough
|
||||
information yet (e.g. the tool ID is still arriving).
|
||||
"""
|
||||
# Extract tool ID (everything before <|tool_call_argument_begin|>
|
||||
# or end of string).
|
||||
arg_begin_idx = body.find(self.tool_call_arg_begin)
|
||||
|
||||
if arg_begin_idx != -1:
|
||||
id_portion = body[:arg_begin_idx]
|
||||
args_portion = body[arg_begin_idx + len(self.tool_call_arg_begin):]
|
||||
else:
|
||||
id_portion = body
|
||||
args_portion = ""
|
||||
|
||||
# Try to extract the tool ID.
|
||||
id_match = self.tool_id_regex.match(id_portion)
|
||||
if not id_match:
|
||||
# Not enough tokens yet to identify the tool.
|
||||
return None, None, ""
|
||||
|
||||
raw_id = id_match.group("tool_id")
|
||||
func_name, full_id = self._parse_tool_id(raw_id)
|
||||
|
||||
# Build args string.
|
||||
args = args_portion.strip()
|
||||
|
||||
if is_complete:
|
||||
# For a complete block, args is the final JSON.
|
||||
return func_name, full_id, args
|
||||
else:
|
||||
# For a partial block, strip any trailing partial-tag overlap
|
||||
# against tool_call_end (already done in caller), but also
|
||||
# check for partial overlap against tool_call_argument_begin
|
||||
# in case it hasn't fully arrived yet.
|
||||
if arg_begin_idx == -1:
|
||||
# No argument section yet.
|
||||
overlap = partial_tag_overlap(
|
||||
id_portion, self.tool_call_arg_begin
|
||||
)
|
||||
if overlap:
|
||||
# The tag is still arriving — we have the name but
|
||||
# no args yet.
|
||||
pass
|
||||
return func_name, full_id, ""
|
||||
|
||||
return func_name, full_id, args
|
||||
|
||||
def _compute_args_diff(self, index: int, args_so_far: str) -> str | None:
|
||||
"""Return only the characters in *args_so_far* that haven't been
|
||||
sent yet, or ``None`` if there's nothing new."""
|
||||
prev = self.streamed_args_for_tool[index]
|
||||
if not args_so_far or len(args_so_far) <= len(prev):
|
||||
return None
|
||||
diff = args_so_far[len(prev):]
|
||||
self.streamed_args_for_tool[index] = args_so_far
|
||||
self.prev_tool_call_arr[index]["arguments"] = args_so_far
|
||||
return diff
|
||||
|
||||
def _ensure_tool_state_for(self, index: int) -> None:
|
||||
"""Grow the streaming-state arrays so *index* is valid."""
|
||||
while len(self._tool_call_ids) <= index:
|
||||
self._tool_call_ids.append("")
|
||||
while len(self.streamed_args_for_tool) <= index:
|
||||
self.streamed_args_for_tool.append("")
|
||||
while len(self.prev_tool_call_arr) <= index:
|
||||
self.prev_tool_call_arr.append({})
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Main streaming entry point
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> DeltaMessage | None:
|
||||
"""Extract tool calls from streaming output.
|
||||
|
||||
Hybrid approach:
|
||||
- **Content forwarding** uses ``delta_text`` (same as the
|
||||
original parser) so we never re-emit text that the reasoning
|
||||
parser already handled.
|
||||
- **Tool call detection** re-parses ``current_text`` on every
|
||||
call (the re-parse-and-diff approach) so it's agnostic to
|
||||
how many tokens arrived per step — robust against MTP.
|
||||
"""
|
||||
logger.debug("delta_text: %s", delta_text)
|
||||
logger.debug("delta_token_ids: %s", delta_token_ids)
|
||||
|
||||
# First chunk of a new stream — reset state.
|
||||
if not previous_text:
|
||||
self._reset_streaming_state()
|
||||
|
||||
# If tools aren't enabled, just forward content.
|
||||
if not self._tools_enabled(request):
|
||||
return DeltaMessage(content=delta_text) if delta_text else None
|
||||
|
||||
# ── Determine section state from full text (MTP-safe) ──
|
||||
inner_start, inner_end = self._find_section_start_end(current_text)
|
||||
in_open_section = inner_start != -1 and inner_end == -1
|
||||
|
||||
# Was the section already open in previous_text?
|
||||
prev_inner_start, _ = self._find_section_start_end(previous_text)
|
||||
section_existed_before = prev_inner_start != -1
|
||||
|
||||
# ── Re-parse tool calls from current_text (MTP-safe) ──
|
||||
regions = self._extract_tool_call_regions(current_text)
|
||||
tool_call_deltas: list[DeltaToolCall] = []
|
||||
|
||||
for i, (body, is_complete) in enumerate(regions):
|
||||
self._ensure_tool_state_for(i)
|
||||
|
||||
func_name, tool_id, args_so_far = self._parse_tool_call_body(
|
||||
body, is_complete
|
||||
)
|
||||
if func_name is None:
|
||||
# Not enough data to identify the tool yet.
|
||||
break
|
||||
|
||||
# Emit the tool name (once per tool call).
|
||||
if "name" not in self.prev_tool_call_arr[i]:
|
||||
self.prev_tool_call_arr[i]["name"] = func_name
|
||||
self._tool_call_ids[i] = tool_id or ""
|
||||
tool_call_deltas.append(
|
||||
DeltaToolCall(
|
||||
index=i,
|
||||
id=tool_id,
|
||||
type="function",
|
||||
function=DeltaFunctionCall(
|
||||
name=func_name,
|
||||
arguments="",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Diff the arguments and emit any new characters.
|
||||
diff = self._compute_args_diff(i, args_so_far)
|
||||
if diff:
|
||||
tool_call_deltas.append(
|
||||
DeltaToolCall(
|
||||
index=i,
|
||||
function=DeltaFunctionCall(arguments=diff),
|
||||
)
|
||||
)
|
||||
|
||||
if regions:
|
||||
self.current_tool_id = len(regions) - 1
|
||||
|
||||
# ── Emit results ──
|
||||
|
||||
# Case 1: We have tool call updates — emit them.
|
||||
if tool_call_deltas:
|
||||
return DeltaMessage(tool_calls=tool_call_deltas)
|
||||
|
||||
# Case 2: No tool section has started yet — forward delta_text
|
||||
# as content. The reasoning parser handles the reasoning/content
|
||||
# split; we just pass through whatever delta the serving layer
|
||||
# gave us.
|
||||
if inner_start == -1:
|
||||
return DeltaMessage(content=delta_text) if delta_text else None
|
||||
|
||||
# Case 3: The section just appeared in this delta. Extract any
|
||||
# content that came before the section marker in this delta
|
||||
# (e.g. "Let me check.<|tool_calls_section_begin|>").
|
||||
if not section_existed_before:
|
||||
section_start_in_text = self._find_section_start(current_text)
|
||||
pre_section = current_text[len(previous_text):section_start_in_text]
|
||||
if pre_section.strip():
|
||||
return DeltaMessage(content=pre_section)
|
||||
# No real content before the section — return None instead of
|
||||
# an empty-string delta. Empty content deltas confuse clients
|
||||
# that distinguish content=null from content="".
|
||||
return None
|
||||
|
||||
# Case 4: Inside an open tool section but tool calls aren't
|
||||
# parseable yet — return None. The serving layer will emit
|
||||
# its own keep-alive if needed; we should not emit empty-string
|
||||
# content deltas that pollute the response.
|
||||
if in_open_section:
|
||||
return None
|
||||
|
||||
# Case 5: Section is closed and we're past it — forward any
|
||||
# new content that appeared after the section end marker.
|
||||
if inner_end != -1:
|
||||
for variant in self.tool_calls_section_end_variants:
|
||||
end_marker_pos = current_text.find(variant, inner_start)
|
||||
if end_marker_pos != -1:
|
||||
after_section = current_text[
|
||||
end_marker_pos + len(variant):
|
||||
]
|
||||
# Only emit what's new (not previously seen)
|
||||
prev_after_len = 0
|
||||
prev_end_pos = previous_text.find(variant)
|
||||
if prev_end_pos != -1:
|
||||
prev_after_len = len(
|
||||
previous_text[prev_end_pos + len(variant):]
|
||||
)
|
||||
new_after = after_section[prev_after_len:]
|
||||
if new_after:
|
||||
return DeltaMessage(content=new_after)
|
||||
break
|
||||
return None
|
||||
|
||||
return None
|
||||
@@ -79,31 +79,72 @@ def patch_rocm_aiter_fa(text: str, path: Path) -> str:
|
||||
' @staticmethod\n def get_name() -> str:\n return "FLASH_ATTN"\n\n @classmethod\n def supports_non_causal(cls) -> bool:\n return True\n\n @staticmethod\n def get_impl_cls() -> type["AiterFlashAttentionImpl"]:\n',
|
||||
path,
|
||||
)
|
||||
text = replace_once(
|
||||
# Ensure AiterFlashAttentionMetadata dataclass has causal: bool field
|
||||
if " causal: bool" not in text.split("class AiterFlashAttentionMetadata")[1].split("def ")[0]:
|
||||
text = replace_once(
|
||||
text,
|
||||
"class AiterFlashAttentionMetadata:\n",
|
||||
"class AiterFlashAttentionMetadata:\n causal: bool\n",
|
||||
path,
|
||||
)
|
||||
|
||||
# Ensure AiterFlashAttentionMetadata() constructors have causal= kwarg.
|
||||
# Use a robust approach: find each constructor call, ensure causal= appears
|
||||
# exactly once (first arg). This handles all upstream variations.
|
||||
def _ensure_causal_in_constructor(m: re.Match) -> str:
|
||||
call_text = m.group(0)
|
||||
if re.search(r'\bcausal=', call_text):
|
||||
return call_text # already has causal=, leave it
|
||||
# Insert causal= as first kwarg after the opening paren
|
||||
return call_text.replace(
|
||||
"AiterFlashAttentionMetadata(\n",
|
||||
"AiterFlashAttentionMetadata(\n causal=common_attn_metadata.causal,\n",
|
||||
)
|
||||
|
||||
text = re.sub(
|
||||
r'(?:attn_metadata = |return )AiterFlashAttentionMetadata\([^)]*?\)',
|
||||
_ensure_causal_in_constructor,
|
||||
text,
|
||||
"class AiterFlashAttentionMetadata:\n",
|
||||
"class AiterFlashAttentionMetadata:\n causal: bool\n",
|
||||
path,
|
||||
flags=re.DOTALL,
|
||||
)
|
||||
text = replace_once(
|
||||
text,
|
||||
" attn_metadata = AiterFlashAttentionMetadata(\n num_actual_tokens=common_attn_metadata.num_actual_tokens,\n",
|
||||
" attn_metadata = AiterFlashAttentionMetadata(\n causal=common_attn_metadata.causal,\n num_actual_tokens=common_attn_metadata.num_actual_tokens,\n",
|
||||
path,
|
||||
)
|
||||
text = replace_once(
|
||||
text,
|
||||
" return AiterFlashAttentionMetadata(\n num_actual_tokens=num_tokens,\n",
|
||||
" return AiterFlashAttentionMetadata(\n causal=common_attn_metadata.causal,\n num_actual_tokens=num_tokens,\n",
|
||||
path,
|
||||
)
|
||||
text = replace_all_regex(
|
||||
text,
|
||||
|
||||
# Replace hardcoded causal=True with dynamic causal=attn_metadata.causal
|
||||
# in flash attention calls. Skip if upstream already uses a dynamic causal.
|
||||
text = re.sub(
|
||||
r"(softmax_scale=self\.scale,\n)(\s*)causal=True,",
|
||||
r"\1\2causal=attn_metadata.causal,",
|
||||
path,
|
||||
min_count=5,
|
||||
text,
|
||||
flags=re.MULTILINE,
|
||||
)
|
||||
|
||||
# Safety: remove any duplicate causal= kwargs that may have been created
|
||||
# by patching an upstream that already had causal= in constructors.
|
||||
# This finds lines like `causal=...,` that appear more than once in the
|
||||
# same parenthesized call and removes the extras (keeps first occurrence).
|
||||
def _dedup_causal(m: re.Match) -> str:
|
||||
block = m.group(0)
|
||||
causal_lines = [i for i, line in enumerate(block.split('\n')) if re.match(r'\s*causal=', line)]
|
||||
if len(causal_lines) <= 1:
|
||||
return block
|
||||
lines = block.split('\n')
|
||||
# Keep only the first causal= line
|
||||
seen = False
|
||||
result = []
|
||||
for line in lines:
|
||||
if re.match(r'\s*causal=', line):
|
||||
if seen:
|
||||
continue # drop duplicate
|
||||
seen = True
|
||||
result.append(line)
|
||||
return '\n'.join(result)
|
||||
|
||||
text = re.sub(
|
||||
r'AiterFlashAttentionMetadata\([^)]*\)',
|
||||
_dedup_causal,
|
||||
text,
|
||||
flags=re.DOTALL,
|
||||
)
|
||||
|
||||
return text
|
||||
|
||||
|
||||
|
||||
1863
serving.py
Normal file
1863
serving.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user