Compare commits

..

9 Commits

Author SHA1 Message Date
Will Eaton
b8b302cde4 Update CUDA architecture list in build pipeline for 12.9.1 wheels (#26592)
Signed-off-by: Will Eaton <wseaton@users.noreply.github.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-10-10 11:15:45 -07:00
Tyler Michael Smith
f71952c1c4 [Build/CI] Revert back to Ubuntu 20.04, install python 3.12 with uv (#26103)
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-10-02 22:22:31 -07:00
Michael Goin
d1007767c5 [Bugfix] Disable cascade attention with FlashInfer (#26130)
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-10-02 22:22:22 -07:00
Chen Zhang
c75c2e70d6 [Deepseek v3.2] Support indexer prefill chunking (#25999)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-10-02 10:35:51 -07:00
Chenheli Hua
9d9a2b77f1 [Small] Prevent bypassing media domain restriction via HTTP redirects (#26035)
Signed-off-by: Chenheli Hua <huachenheli@outlook.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-10-02 10:35:51 -07:00
Lucas Wilkinson
6040e0b6c0 [BugFix] Fix FI accuracy issue when used for MLA prefill (#26063)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-10-02 10:35:51 -07:00
Huy Do
05bf0c52a1 Update base image to 22.04 (jammy) (#26065)
Signed-off-by: Huy Do <huydhn@gmail.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-10-02 10:35:51 -07:00
Lucas Wilkinson
c536881a7c [BugFix] ChunkedLocalAttention is currently not CG compatible (#26034)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-10-02 10:35:51 -07:00
Lucas Wilkinson
ebce361c07 [BugFix][DP/EP] Fix CUTLASS MLA hang under load (#26026)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-10-02 10:35:50 -07:00
15 changed files with 258 additions and 156 deletions

View File

@@ -48,7 +48,7 @@ steps:
agents: agents:
queue: cpu_queue_postmerge queue: cpu_queue_postmerge
commands: commands:
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0+PTX' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
- "mkdir artifacts" - "mkdir artifacts"
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
- "bash .buildkite/scripts/upload-wheels.sh" - "bash .buildkite/scripts/upload-wheels.sh"

View File

@@ -580,22 +580,22 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
for (; tile_scheduler.is_valid(); ++tile_scheduler) { for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord(); auto blk_coord = tile_scheduler.get_block_coord();
auto problem_shape = params.problem_shape; auto problem_shape = params.problem_shape;
auto local_split_kv = params.split_kv; auto local_split_kv = params.split_kv;
if (params.mainloop.ptr_seq != nullptr) { if (params.mainloop.ptr_seq != nullptr) {
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
if (params.ptr_split_kv != nullptr) { if (params.ptr_split_kv != nullptr) {
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
} }
} }
if (local_split_kv <= get<3>(blk_coord)) if (local_split_kv <= get<3>(blk_coord))
continue; continue;
load_page_table( load_page_table(
blk_coord, blk_coord,
problem_shape, problem_shape,
params.mainloop, params.mainloop,
shared_storage.tensors, shared_storage.tensors,
pipeline_page_table, pipeline_pt_producer_state, pipeline_page_table, pipeline_pt_producer_state,
local_split_kv local_split_kv
); );
} }
} }
@@ -604,15 +604,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
CUTLASS_PRAGMA_NO_UNROLL CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) { for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord(); auto blk_coord = tile_scheduler.get_block_coord();
auto problem_shape = params.problem_shape; auto problem_shape = params.problem_shape;
auto local_split_kv = params.split_kv; auto local_split_kv = params.split_kv;
if (params.mainloop.ptr_seq != nullptr) { if (params.mainloop.ptr_seq != nullptr) {
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
if (params.ptr_split_kv != nullptr) { if (params.ptr_split_kv != nullptr) {
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
} }
} }
if (local_split_kv <= get<3>(blk_coord)) if (local_split_kv <= get<3>(blk_coord))
continue; continue;
load_cpasync( load_cpasync(
blk_coord, blk_coord,
@@ -621,7 +621,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
params.mainloop_params, params.mainloop_params,
shared_storage.tensors, shared_storage.tensors,
pipeline_load_qk, pipeline_load_qk_producer_state, pipeline_load_qk, pipeline_load_qk_producer_state,
local_split_kv, local_split_kv,
/* must be shared pipe */ /* must be shared pipe */
pipeline_page_table, pipeline_pt_consumer_state pipeline_page_table, pipeline_pt_consumer_state
); );
@@ -633,15 +633,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
CUTLASS_PRAGMA_NO_UNROLL CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) { for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord(); auto blk_coord = tile_scheduler.get_block_coord();
auto problem_shape = params.problem_shape; auto problem_shape = params.problem_shape;
auto local_split_kv = params.split_kv; auto local_split_kv = params.split_kv;
if (params.mainloop.ptr_seq != nullptr) { if (params.mainloop.ptr_seq != nullptr) {
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
if (params.ptr_split_kv != nullptr) { if (params.ptr_split_kv != nullptr) {
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
} }
} }
if (local_split_kv <= get<3>(blk_coord)) if (local_split_kv <= get<3>(blk_coord))
continue; continue;
load_tma</* paged= */ true>( load_tma</* paged= */ true>(
blk_coord, blk_coord,
@@ -651,7 +651,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
shared_storage.tensors, shared_storage.tensors,
pipeline_load_qk, pipeline_load_qk_producer_state, pipeline_load_qk, pipeline_load_qk_producer_state,
pipeline_load_qk, pipeline_load_qk_producer_state, pipeline_load_qk, pipeline_load_qk_producer_state,
local_split_kv local_split_kv
); );
cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait(); cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait();
} }
@@ -660,15 +660,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
CUTLASS_PRAGMA_NO_UNROLL CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) { for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord(); auto blk_coord = tile_scheduler.get_block_coord();
auto problem_shape = params.problem_shape; auto problem_shape = params.problem_shape;
auto local_split_kv = params.split_kv; auto local_split_kv = params.split_kv;
if (params.mainloop.ptr_seq != nullptr) { if (params.mainloop.ptr_seq != nullptr) {
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
if (params.ptr_split_kv != nullptr) { if (params.ptr_split_kv != nullptr) {
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
} }
} }
if (local_split_kv <= get<3>(blk_coord)) if (local_split_kv <= get<3>(blk_coord))
continue; continue;
load_tma<false>( load_tma<false>(
blk_coord, blk_coord,
@@ -678,7 +678,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
shared_storage.tensors, shared_storage.tensors,
pipeline_load_qk, pipeline_load_qk_producer_state, pipeline_load_qk, pipeline_load_qk_producer_state,
pipeline_load_qk, pipeline_load_qk_producer_state, pipeline_load_qk, pipeline_load_qk_producer_state,
local_split_kv local_split_kv
); );
cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait(); cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait();
} }
@@ -694,14 +694,14 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
for (; tile_scheduler.is_valid(); ++tile_scheduler) { for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord(); auto blk_coord = tile_scheduler.get_block_coord();
auto problem_shape = params.problem_shape; auto problem_shape = params.problem_shape;
auto local_split_kv = params.split_kv; auto local_split_kv = params.split_kv;
if (params.mainloop.ptr_seq != nullptr) { if (params.mainloop.ptr_seq != nullptr) {
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
if (params.ptr_split_kv != nullptr) { if (params.ptr_split_kv != nullptr) {
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
} }
} }
if (local_split_kv <= get<3>(blk_coord)) if (local_split_kv <= get<3>(blk_coord))
continue; continue;
mma(blk_coord, mma(blk_coord,
problem_shape, problem_shape,
@@ -711,7 +711,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
pipeline_mma_s, pipeline_mma_s_producer_state, pipeline_mma_s, pipeline_mma_s_producer_state,
pipeline_p_mma, pipeline_p_mma_consumer_state, pipeline_p_mma, pipeline_p_mma_consumer_state,
pipeline_mma_o, pipeline_mma_o_producer_state, pipeline_mma_o, pipeline_mma_o_producer_state,
local_split_kv local_split_kv
); );
} }
} }
@@ -726,15 +726,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
for (; tile_scheduler.is_valid(); ++tile_scheduler) { for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord(); auto blk_coord = tile_scheduler.get_block_coord();
auto problem_shape = params.problem_shape; auto problem_shape = params.problem_shape;
auto split_kv = params.split_kv; auto split_kv = params.split_kv;
auto local_split_kv = split_kv; auto local_split_kv = split_kv;
if (params.mainloop.ptr_seq != nullptr) { if (params.mainloop.ptr_seq != nullptr) {
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
if (params.ptr_split_kv != nullptr) { if (params.ptr_split_kv != nullptr) {
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
} }
} }
if (local_split_kv <= get<3>(blk_coord)) if (local_split_kv <= get<3>(blk_coord))
continue; continue;
compute( compute(
blk_coord, blk_coord,
@@ -745,7 +745,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
pipeline_mma_s, pipeline_mma_s_consumer_state, pipeline_mma_s, pipeline_mma_s_consumer_state,
pipeline_p_mma, pipeline_p_mma_producer_state, pipeline_p_mma, pipeline_p_mma_producer_state,
pipeline_mma_o, pipeline_mma_o_consumer_state, pipeline_mma_o, pipeline_mma_o_consumer_state,
local_split_kv local_split_kv
); );
} }
@@ -1900,7 +1900,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
cutlass::arch::NamedBarrier( cutlass::arch::NamedBarrier(
(kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, (kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp,
kNamedBarrierEpilogue kNamedBarrierEpilogue
).arrive(); ).arrive_and_wait();
return; return;
} }

View File

@@ -14,6 +14,11 @@ ARG PYTHON_VERSION=3.12
# #
# Example: # Example:
# docker build --build-arg BUILD_BASE_IMAGE=registry.acme.org/mirror/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 # docker build --build-arg BUILD_BASE_IMAGE=registry.acme.org/mirror/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04
# Important: We build with an old version of Ubuntu to maintain broad
# compatibility with other Linux OSes. The main reason for this is that the
# glibc version is baked into the distro, and binaries built with one glibc
# version are not backwards compatible with OSes that use an earlier version.
ARG BUILD_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 ARG BUILD_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04
# TODO: Restore to base image after FlashInfer AOT wheel fixed # TODO: Restore to base image after FlashInfer AOT wheel fixed
ARG FINAL_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 ARG FINAL_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04
@@ -75,34 +80,19 @@ ARG TARGETPLATFORM
ARG INSTALL_KV_CONNECTORS=false ARG INSTALL_KV_CONNECTORS=false
ENV DEBIAN_FRONTEND=noninteractive ENV DEBIAN_FRONTEND=noninteractive
ARG DEADSNAKES_MIRROR_URL
ARG DEADSNAKES_GPGKEY_URL
ARG GET_PIP_URL ARG GET_PIP_URL
# Install Python and other dependencies # Install system dependencies and uv, then create Python virtual environment
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \ && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
&& apt-get update -y \ && apt-get update -y \
&& apt-get install -y ccache software-properties-common git curl sudo \ && apt-get install -y ccache software-properties-common git curl sudo python3-pip \
&& if [ ! -z ${DEADSNAKES_MIRROR_URL} ] ; then \ && curl -LsSf https://astral.sh/uv/install.sh | sh \
if [ ! -z "${DEADSNAKES_GPGKEY_URL}" ] ; then \ && $HOME/.local/bin/uv venv /opt/venv --python ${PYTHON_VERSION} \
mkdir -p -m 0755 /etc/apt/keyrings ; \ && rm -f /usr/bin/python3 /usr/bin/python3-config /usr/bin/pip \
curl -L ${DEADSNAKES_GPGKEY_URL} | gpg --dearmor > /etc/apt/keyrings/deadsnakes.gpg ; \ && ln -s /opt/venv/bin/python3 /usr/bin/python3 \
sudo chmod 644 /etc/apt/keyrings/deadsnakes.gpg ; \ && ln -s /opt/venv/bin/python3-config /usr/bin/python3-config \
echo "deb [signed-by=/etc/apt/keyrings/deadsnakes.gpg] ${DEADSNAKES_MIRROR_URL} $(lsb_release -cs) main" > /etc/apt/sources.list.d/deadsnakes.list ; \ && ln -s /opt/venv/bin/pip /usr/bin/pip \
fi ; \
else \
for i in 1 2 3; do \
add-apt-repository -y ppa:deadsnakes/ppa && break || \
{ echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
done ; \
fi \
&& apt-get update -y \
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
&& update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
&& ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
&& curl -sS ${GET_PIP_URL} | python${PYTHON_VERSION} \
&& python3 --version && python3 -m pip --version && python3 --version && python3 -m pip --version
ARG PIP_INDEX_URL UV_INDEX_URL ARG PIP_INDEX_URL UV_INDEX_URL
@@ -111,9 +101,9 @@ ARG PYTORCH_CUDA_INDEX_BASE_URL
ARG PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL ARG PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL
ARG PIP_KEYRING_PROVIDER UV_KEYRING_PROVIDER ARG PIP_KEYRING_PROVIDER UV_KEYRING_PROVIDER
# Install uv for faster pip installs # Activate virtual environment and add uv to PATH
RUN --mount=type=cache,target=/root/.cache/uv \ ENV PATH="/opt/venv/bin:/root/.local/bin:$PATH"
python3 -m pip install uv ENV VIRTUAL_ENV="/opt/venv"
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out # This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
# Reference: https://github.com/astral-sh/uv/pull/1694 # Reference: https://github.com/astral-sh/uv/pull/1694
@@ -142,7 +132,7 @@ WORKDIR /workspace
COPY requirements/common.txt requirements/common.txt COPY requirements/common.txt requirements/common.txt
COPY requirements/cuda.txt requirements/cuda.txt COPY requirements/cuda.txt requirements/cuda.txt
RUN --mount=type=cache,target=/root/.cache/uv \ RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/cuda.txt \ uv pip install --python /opt/venv/bin/python3 -r requirements/cuda.txt \
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
# cuda arch list used by torch # cuda arch list used by torch
@@ -172,7 +162,7 @@ ENV UV_INDEX_STRATEGY="unsafe-best-match"
ENV UV_LINK_MODE=copy ENV UV_LINK_MODE=copy
RUN --mount=type=cache,target=/root/.cache/uv \ RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/build.txt \ uv pip install --python /opt/venv/bin/python3 -r requirements/build.txt \
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
COPY . . COPY . .
@@ -269,7 +259,7 @@ COPY requirements/lint.txt requirements/lint.txt
COPY requirements/test.txt requirements/test.txt COPY requirements/test.txt requirements/test.txt
COPY requirements/dev.txt requirements/dev.txt COPY requirements/dev.txt requirements/dev.txt
RUN --mount=type=cache,target=/root/.cache/uv \ RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/dev.txt \ uv pip install --python /opt/venv/bin/python3 -r requirements/dev.txt \
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
#################### DEV IMAGE #################### #################### DEV IMAGE ####################

View File

@@ -6,7 +6,7 @@ ARG CUDA_VERSION=12.8.0
# #
#################### BASE BUILD IMAGE #################### #################### BASE BUILD IMAGE ####################
# prepare basic build environment # prepare basic build environment
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 AS base FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS base
ARG CUDA_VERSION=12.8.0 ARG CUDA_VERSION=12.8.0
ARG PYTHON_VERSION=3.12 ARG PYTHON_VERSION=3.12
ARG TARGETPLATFORM ARG TARGETPLATFORM

View File

@@ -8,6 +8,9 @@ This page teaches you how to pass multi-modal inputs to [multi-modal models][sup
!!! tip !!! tip
When serving multi-modal models, consider setting `--allowed-media-domains` to restrict domain that vLLM can access to prevent it from accessing arbitrary endpoints that can potentially be vulnerable to Server-Side Request Forgery (SSRF) attacks. You can provide a list of domains for this arg. For example: `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com` When serving multi-modal models, consider setting `--allowed-media-domains` to restrict domain that vLLM can access to prevent it from accessing arbitrary endpoints that can potentially be vulnerable to Server-Side Request Forgery (SSRF) attacks. You can provide a list of domains for this arg. For example: `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`
Also, consider setting `VLLM_MEDIA_URL_ALLOW_REDIRECTS=0` to prevent HTTP redirects from being followed to bypass domain restrictions.
This restriction is especially important if you run vLLM in a containerized environment where the vLLM pods may have unrestricted access to internal networks. This restriction is especially important if you run vLLM in a containerized environment where the vLLM pods may have unrestricted access to internal networks.
## Offline Inference ## Offline Inference

View File

@@ -66,6 +66,9 @@ Restrict domains that vLLM can access for media URLs by setting
`--allowed-media-domains` to prevent Server-Side Request Forgery (SSRF) attacks. `--allowed-media-domains` to prevent Server-Side Request Forgery (SSRF) attacks.
(e.g. `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`) (e.g. `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`)
Also, consider setting `VLLM_MEDIA_URL_ALLOW_REDIRECTS=0` to prevent HTTP
redirects from being followed to bypass domain restrictions.
## Security and Firewalls: Protecting Exposed vLLM Systems ## Security and Firewalls: Protecting Exposed vLLM Systems
While vLLM is designed to allow unsafe network services to be isolated to While vLLM is designed to allow unsafe network services to be isolated to

View File

@@ -22,6 +22,7 @@ from vllm.utils import cdiv
from vllm.v1.attention.backends.mla.flashmla_sparse import ( from vllm.v1.attention.backends.mla.flashmla_sparse import (
FlashMLASparseBackend, FlashMLASparseDecodeAndContextMetadata, FlashMLASparseBackend, FlashMLASparseDecodeAndContextMetadata,
FlashMLASparseImpl, FlashMLASparseMetadata) FlashMLASparseImpl, FlashMLASparseMetadata)
from vllm.v1.attention.backends.mla.indexer import split_prefill_chunks
SPARSE_BACKEND_BATCH_SPECS = { SPARSE_BACKEND_BATCH_SPECS = {
name: BATCH_SPECS[name] name: BATCH_SPECS[name]
@@ -424,3 +425,24 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name,
sdpa_reference, sdpa_reference,
rtol=0.5, rtol=0.5,
atol=0.5) atol=0.5)
@pytest.mark.parametrize(
"seq_lens,max_buf,start,expected",
[
# Basic split: totals per chunk ≤ max_buf
(torch.tensor([2, 3, 4, 2]), 5, 0, [(0, 2), (2, 3), (3, 4)]),
# Non-zero start index
(torch.tensor([2, 3, 4, 2]), 5, 1, [(1, 2), (2, 3), (3, 4)]),
# Exact fits should split between items when adding the next would
# overflow
(torch.tensor([5, 5, 5]), 5, 0, [(0, 1), (1, 2), (2, 3)]),
# All requests fit in a single chunk
(torch.tensor([1, 1, 1]), 10, 0, [(0, 3)]),
# Large buffer with non-zero start
(torch.tensor([4, 4, 4]), 100, 1, [(1, 3)]),
],
)
def test_split_prefill_chunks(seq_lens, max_buf, start, expected):
out = split_prefill_chunks(seq_lens, max_buf, start)
assert out == expected

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools import functools
from typing import List, Optional from typing import ClassVar, List, Optional
import torch import torch
@@ -11,8 +11,8 @@ from vllm.attention.backends.abstract import (AttentionBackend,
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig, QuantizationConfig from vllm.config import CacheConfig, QuantizationConfig
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata, make_local_attention_virtual_batches, AttentionCGSupport, CommonAttentionMetadata,
subclass_attention_backend) make_local_attention_virtual_batches, subclass_attention_backend)
from ..layer import Attention from ..layer import Attention
@@ -28,6 +28,8 @@ def create_chunked_local_attention_backend(
underlying_builder = underlying_attn_backend.get_builder_cls() underlying_builder = underlying_attn_backend.get_builder_cls()
class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.NEVER
def build(self, def build(self,
common_prefix_len: int, common_prefix_len: int,

View File

@@ -54,6 +54,7 @@ class HTTPConnection:
stream: bool = False, stream: bool = False,
timeout: Optional[float] = None, timeout: Optional[float] = None,
extra_headers: Optional[Mapping[str, str]] = None, extra_headers: Optional[Mapping[str, str]] = None,
allow_redirects: bool = True,
): ):
self._validate_http_url(url) self._validate_http_url(url)
@@ -63,7 +64,8 @@ class HTTPConnection:
return client.get(url, return client.get(url,
headers=self._headers(**extra_headers), headers=self._headers(**extra_headers),
stream=stream, stream=stream,
timeout=timeout) timeout=timeout,
allow_redirects=allow_redirects)
async def get_async_response( async def get_async_response(
self, self,
@@ -71,6 +73,7 @@ class HTTPConnection:
*, *,
timeout: Optional[float] = None, timeout: Optional[float] = None,
extra_headers: Optional[Mapping[str, str]] = None, extra_headers: Optional[Mapping[str, str]] = None,
allow_redirects: bool = True,
): ):
self._validate_http_url(url) self._validate_http_url(url)
@@ -79,10 +82,17 @@ class HTTPConnection:
return client.get(url, return client.get(url,
headers=self._headers(**extra_headers), headers=self._headers(**extra_headers),
timeout=timeout) timeout=timeout,
allow_redirects=allow_redirects)
def get_bytes(self, url: str, *, timeout: Optional[float] = None) -> bytes: def get_bytes(self,
with self.get_response(url, timeout=timeout) as r: url: str,
*,
timeout: Optional[float] = None,
allow_redirects: bool = True) -> bytes:
with self.get_response(url,
timeout=timeout,
allow_redirects=allow_redirects) as r:
r.raise_for_status() r.raise_for_status()
return r.content return r.content
@@ -92,8 +102,10 @@ class HTTPConnection:
url: str, url: str,
*, *,
timeout: Optional[float] = None, timeout: Optional[float] = None,
allow_redirects: bool = True,
) -> bytes: ) -> bytes:
async with await self.get_async_response(url, timeout=timeout) as r: async with await self.get_async_response(
url, timeout=timeout, allow_redirects=allow_redirects) as r:
r.raise_for_status() r.raise_for_status()
return await r.read() return await r.read()

View File

@@ -68,6 +68,7 @@ if TYPE_CHECKING:
VLLM_IMAGE_FETCH_TIMEOUT: int = 5 VLLM_IMAGE_FETCH_TIMEOUT: int = 5
VLLM_VIDEO_FETCH_TIMEOUT: int = 30 VLLM_VIDEO_FETCH_TIMEOUT: int = 30
VLLM_AUDIO_FETCH_TIMEOUT: int = 10 VLLM_AUDIO_FETCH_TIMEOUT: int = 10
VLLM_MEDIA_URL_ALLOW_REDIRECTS: bool = True
VLLM_MEDIA_LOADING_THREAD_COUNT: int = 8 VLLM_MEDIA_LOADING_THREAD_COUNT: int = 8
VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25 VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25
VLLM_VIDEO_LOADER_BACKEND: str = "opencv" VLLM_VIDEO_LOADER_BACKEND: str = "opencv"
@@ -725,6 +726,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_AUDIO_FETCH_TIMEOUT": "VLLM_AUDIO_FETCH_TIMEOUT":
lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")), lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")),
# Whether to allow HTTP redirects when fetching from media URLs.
# Default to True
"VLLM_MEDIA_URL_ALLOW_REDIRECTS":
lambda: bool(int(os.getenv("VLLM_MEDIA_URL_ALLOW_REDIRECTS", "1"))),
# Max number of workers for the thread pool handling # Max number of workers for the thread pool handling
# media bytes loading. Set to 1 to disable parallel processing. # media bytes loading. Set to 1 to disable parallel processing.
# Default is 8 # Default is 8

View File

@@ -583,44 +583,43 @@ def sparse_attn_indexer(
topk_indices_buffer[:hidden_states.shape[0]] = -1 topk_indices_buffer[:hidden_states.shape[0]] = -1
if has_prefill: if has_prefill:
prefill_metadata = attn_metadata.prefill prefill_metadata = attn_metadata.prefill
num_prefills = attn_metadata.num_prefills for chunk in prefill_metadata.chunks:
k_fp8 = torch.empty([prefill_metadata.total_seq_lens, head_dim], k_fp8 = torch.empty([chunk.total_seq_lens, head_dim],
device=k.device, device=k.device,
dtype=torch.float8_e4m3fn) dtype=torch.float8_e4m3fn)
k_scale = torch.empty([prefill_metadata.total_seq_lens, 1], k_scale = torch.empty([chunk.total_seq_lens, 1],
device=k.device, device=k.device,
dtype=torch.float32) dtype=torch.float32)
cp_gather_indexer_k_quant_cache( cp_gather_indexer_k_quant_cache(
kv_cache, kv_cache,
k_fp8, k_fp8,
k_scale, k_scale,
prefill_metadata.block_table, chunk.block_table,
prefill_metadata.cu_seq_lens, chunk.cu_seq_lens,
num_prefills, chunk.num_reqs,
) )
cu_seqlen_ks = prefill_metadata.cu_seqlen_ks logits = fp8_mqa_logits(
cu_seqlen_ke = prefill_metadata.cu_seqlen_ke q_fp8[chunk.token_start:chunk.token_end],
num_tokens = attn_metadata.num_actual_tokens (k_fp8, k_scale),
logits = fp8_mqa_logits( weights[chunk.token_start:chunk.token_end],
q_fp8[num_decode_tokens:num_tokens], chunk.cu_seqlen_ks,
(k_fp8, k_scale), chunk.cu_seqlen_ke,
weights[num_decode_tokens:num_tokens], )
cu_seqlen_ks, topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]),
cu_seqlen_ke, dim=-1)[1]
) topk_indices -= chunk.cu_seqlen_ks[:, None]
topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]), mask_lo = topk_indices >= 0
dim=-1)[1] mask_hi = topk_indices - (chunk.cu_seqlen_ke -
topk_indices -= cu_seqlen_ks[:, None] chunk.cu_seqlen_ks)[:, None] < 0
mask_lo = topk_indices >= 0 mask = torch.full_like(topk_indices,
mask_hi = topk_indices - (cu_seqlen_ke - cu_seqlen_ks)[:, None] < 0 False,
mask = torch.full_like(topk_indices, dtype=torch.bool,
False, device=topk_indices.device)
dtype=torch.bool, mask = mask_lo & mask_hi
device=topk_indices.device) topk_indices = topk_indices.masked_fill(~mask, -1)
mask = mask_lo & mask_hi topk_indices_buffer[
topk_indices = topk_indices.masked_fill(~mask, -1) chunk.token_start:chunk.token_end, :topk_indices.
topk_indices_buffer[num_decode_tokens:num_tokens, :topk_indices. shape[-1]] = topk_indices.to(dtype=torch.int32)
shape[-1]] = topk_indices.to(dtype=torch.int32)
if has_decode: if has_decode:
decode_metadata = attn_metadata.decode decode_metadata = attn_metadata.decode

View File

@@ -140,7 +140,11 @@ class MediaConnector:
self._assert_url_in_allowed_media_domains(url_spec) self._assert_url_in_allowed_media_domains(url_spec)
connection = self.connection connection = self.connection
data = connection.get_bytes(url, timeout=fetch_timeout) data = connection.get_bytes(
url,
timeout=fetch_timeout,
allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
)
return media_io.load_bytes(data) return media_io.load_bytes(data)
@@ -167,7 +171,11 @@ class MediaConnector:
self._assert_url_in_allowed_media_domains(url_spec) self._assert_url_in_allowed_media_domains(url_spec)
connection = self.connection connection = self.connection
data = await connection.async_get_bytes(url, timeout=fetch_timeout) data = await connection.async_get_bytes(
url,
timeout=fetch_timeout,
allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
)
future = loop.run_in_executor(global_thread_pool, future = loop.run_in_executor(global_thread_pool,
media_io.load_bytes, data) media_io.load_bytes, data)
return await future return await future

View File

@@ -29,7 +29,6 @@ from vllm.utils.flashinfer import (can_use_trtllm_attention,
flashinfer_disable_q_quantization, flashinfer_disable_q_quantization,
supports_trtllm_attention, supports_trtllm_attention,
use_trtllm_attention) use_trtllm_attention)
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
from vllm.v1.attention.backends.utils import (AttentionCGSupport, from vllm.v1.attention.backends.utils import (AttentionCGSupport,
@@ -677,7 +676,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# TODO: The cascade wrapper currently does not support setting # TODO: The cascade wrapper currently does not support setting
# kv cache dtype to something different from query dtype. # kv cache dtype to something different from query dtype.
return False return False
return use_cascade_attention(*args, **kwargs) # TODO: Cascade attention doesn't work, disable it for now
# return use_cascade_attention(*args, **kwargs)
return False
class FlashInferImpl(AttentionImpl): class FlashInferImpl(AttentionImpl):

View File

@@ -1211,13 +1211,18 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
k, v, return_softmax_lse): k, v, return_softmax_lse):
assert isinstance(prefill, FlashInferPrefillMetadata) assert isinstance(prefill, FlashInferPrefillMetadata)
assert prefill.prefill_main is not None assert prefill.prefill_main is not None
return prefill.prefill_main.run( ret = prefill.prefill_main.run(
q=q, q=q,
k=k, k=k,
v=v, v=v,
return_lse=return_softmax_lse, return_lse=return_softmax_lse,
) )
if isinstance(ret, tuple):
# Convert from (q_len, num_heads) to (num_heads, q_len)
return ret[0], ret[1].transpose(0, 1).contiguous()
return ret
def _run_prefill_new_tokens_cudnn(self, prefill: MLACommonPrefillMetadata, def _run_prefill_new_tokens_cudnn(self, prefill: MLACommonPrefillMetadata,
q, k, v, return_softmax_lse): q, k, v, return_softmax_lse):
assert isinstance(prefill, CudnnPrefillMetadata) assert isinstance(prefill, CudnnPrefillMetadata)
@@ -1260,12 +1265,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
def _run_prefill_context_chunk_fi(self, prefill: MLACommonPrefillMetadata, def _run_prefill_context_chunk_fi(self, prefill: MLACommonPrefillMetadata,
chunk_idx: int, q, k, v): chunk_idx: int, q, k, v):
assert isinstance(prefill, FlashInferPrefillMetadata) assert isinstance(prefill, FlashInferPrefillMetadata)
return prefill.prefill_chunks[chunk_idx].run( attn_out, lse = prefill.prefill_chunks[chunk_idx].run(
q=q, q=q,
k=k, k=k,
v=v, v=v,
return_lse=True, return_lse=True,
) )
# Convert from (q_len, num_heads) to (num_heads, q_len)
return attn_out, lse.transpose(0, 1).contiguous()
def _run_prefill_context_chunk_cudnn(self, def _run_prefill_context_chunk_cudnn(self,
prefill: MLACommonPrefillMetadata, prefill: MLACommonPrefillMetadata,

View File

@@ -49,14 +49,20 @@ class DeepseekV32IndexerBackend(AttentionBackend):
@dataclass @dataclass
class DeepseekV32IndexerPrefillMetadata: class DeepseekV32IndexerPrefillChunkMetadata:
block_table: torch.Tensor block_table: torch.Tensor
query_start_loc: torch.Tensor
max_query_len: int
cu_seqlen_ks: torch.Tensor cu_seqlen_ks: torch.Tensor
cu_seqlen_ke: torch.Tensor cu_seqlen_ke: torch.Tensor
cu_seq_lens: torch.Tensor cu_seq_lens: torch.Tensor
total_seq_lens: int total_seq_lens: int
token_start: int
token_end: int
num_reqs: int
@dataclass
class DeepseekV32IndexerPrefillMetadata:
chunks: list[DeepseekV32IndexerPrefillChunkMetadata]
@dataclass @dataclass
@@ -98,8 +104,8 @@ class DeepseekV32IndexerMetadata:
# TODO (zyongye) optimize this, this is now vibe coded # TODO (zyongye) optimize this, this is now vibe coded
def kv_spans_from_batches( def kv_spans_from_batches(
start_seq_loc: torch.Tensor, start_seq_loc: torch.Tensor, seq_len_per_batch: torch.Tensor,
seq_len_per_batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
""" """
Args: Args:
start_seq_loc: 1D long tensor [B+1], cumulative counts of start_seq_loc: 1D long tensor [B+1], cumulative counts of
@@ -122,7 +128,7 @@ def kv_spans_from_batches(
are the **last** `counts[i]` positions of that sequence. are the **last** `counts[i]` positions of that sequence.
""" """
q = start_seq_loc.to(dtype=torch.long) q = start_seq_loc.to(dtype=torch.long)
L = seq_len_per_batch.to(dtype=torch.long, device=q.device) L = seq_len_per_batch.to(dtype=torch.long)
assert q.dim() == 1 and L.dim() == 1 assert q.dim() == 1 and L.dim() == 1
assert q.numel() == L.numel() + 1, "start_seq_loc must have length B+1" assert q.numel() == L.numel() + 1, "start_seq_loc must have length B+1"
@@ -130,7 +136,6 @@ def kv_spans_from_batches(
counts = q[1:] - q[:-1] # [B] counts = q[1:] - q[:-1] # [B]
N = int(q[-1].item()) # total selected tokens N = int(q[-1].item()) # total selected tokens
B = L.numel() B = L.numel()
device = L.device
if N == 0: if N == 0:
return (torch.empty(0, dtype=torch.long, device=device), return (torch.empty(0, dtype=torch.long, device=device),
@@ -140,8 +145,7 @@ def kv_spans_from_batches(
kv_starts_per_batch = torch.cumsum(L, dim=0) - L # [B] kv_starts_per_batch = torch.cumsum(L, dim=0) - L # [B]
# For each selected token, which batch does it belong to? # For each selected token, which batch does it belong to?
batch_id = torch.repeat_interleave(torch.arange(B, device=device), batch_id = torch.repeat_interleave(torch.arange(B), counts) # [N]
counts) # [N]
# Map batch KV start to each token # Map batch KV start to each token
start_tensor = kv_starts_per_batch[batch_id] # [N] start_tensor = kv_starts_per_batch[batch_id] # [N]
@@ -151,22 +155,51 @@ def kv_spans_from_batches(
L_expand = torch.repeat_interleave(L, counts) # [N] L_expand = torch.repeat_interleave(L, counts) # [N]
m_expand = torch.repeat_interleave(counts, counts) # [N] m_expand = torch.repeat_interleave(counts, counts) # [N]
# position within the selected block: 1..counts[b] # position within the selected block: 1..counts[b]
pos_within = (torch.arange(N, device=device, dtype=torch.long) - pos_within = (torch.arange(N, dtype=torch.long) -
torch.repeat_interleave(q[:-1], counts) + 1) torch.repeat_interleave(q[:-1], counts) + 1)
local_pos = L_expand - m_expand + pos_within # [N], 1-based local_pos = L_expand - m_expand + pos_within # [N], 1-based
end_location = start_tensor + local_pos # exclusive end end_location = start_tensor + local_pos # exclusive end
return start_tensor.int(), end_location.int() return start_tensor.int().to(device), end_location.int().to(device)
def get_max_prefill_buffer_size(vllm_config: VllmConfig): def get_max_prefill_buffer_size(vllm_config: VllmConfig):
max_model_len = vllm_config.model_config.max_model_len max_model_len = vllm_config.model_config.max_model_len
# max_num_batched_tokens = \ # NOTE(Chen): 2 is a magic number for controlling the prefill buffer size.
# vllm_config.scheduler_config.max_num_batched_tokens # May be tuned later.
max_num_seq = vllm_config.scheduler_config.max_num_seqs return max_model_len * 2
# NOTE(Chen): an estimated max size of flattened_kv. Need to double check.
return max_model_len * max_num_seq
def split_prefill_chunks(seq_lens_cpu: torch.Tensor,
                         max_prefill_buffer_size: int,
                         reqs_start: int) -> list[tuple[int, int]]:
    """
    Greedily group consecutive prefill requests into chunks whose total
    sequence length is at most the maximum prefill buffer size.

    Args:
        seq_lens_cpu: The sequence lengths of the requests; entries before
            `reqs_start` are ignored.
        max_prefill_buffer_size: The maximum (inclusive) total sequence
            length allowed in one chunk.
        reqs_start: The start index of the prefill requests.

    Returns:
        A list of half-open request ranges (reqs_start, reqs_end), in
        request order. A trailing group whose lengths sum to 0 is not
        emitted, so the list is empty when there is nothing to prefill.

    Raises:
        AssertionError: If a single request is longer than
            `max_prefill_buffer_size` and so cannot fit in any chunk.
    """
    # Convert once instead of calling `.item()` for every request.
    seq_lens = seq_lens_cpu.tolist()
    num_reqs = len(seq_lens)

    chunk_seq_ids = []
    total_seq_lens = 0
    for i in range(reqs_start, num_reqs):
        cur_seq_len = seq_lens[i]
        assert cur_seq_len <= max_prefill_buffer_size, (
            f"request {i} has seq_len {cur_seq_len}, which exceeds the "
            f"prefill buffer size {max_prefill_buffer_size}")
        total_seq_lens += cur_seq_len
        if total_seq_lens > max_prefill_buffer_size:
            # Request i overflows the current chunk: close the chunk just
            # before it and start a new chunk with request i.
            chunk_seq_ids.append((reqs_start, i))
            reqs_start = i
            total_seq_lens = cur_seq_len
    if total_seq_lens > 0:
        chunk_seq_ids.append((reqs_start, num_reqs))
    return chunk_seq_ids
class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
@@ -201,6 +234,33 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
dtype=torch.int32, dtype=torch.int32,
device=self.device) device=self.device)
def build_one_prefill_chunk(self, reqs_start, reqs_end,
                            query_start_loc_cpu, seq_lens_cpu,
                            block_table):
    """
    Build the indexer prefill metadata for requests [reqs_start, reqs_end).

    Args:
        reqs_start: Index of the first request in the chunk.
        reqs_end: Index one past the last request in the chunk.
        query_start_loc_cpu: Cumulative query token offsets, indexed by
            request (presumably a CPU tensor of length num_reqs + 1 --
            confirm against the caller).
        seq_lens_cpu: Per-request sequence lengths (presumably a CPU
            tensor -- confirm against the caller).
        block_table: Block table indexed by request along dim 0.

    Returns:
        A DeepseekV32IndexerPrefillChunkMetadata for this chunk.
    """
    chunk_seq_lens = seq_lens_cpu[reqs_start:reqs_end]

    # Re-base the query offsets so the chunk's first token is at offset 0.
    prefill_query_start_loc = query_start_loc_cpu[
        reqs_start:reqs_end + 1] - query_start_loc_cpu[reqs_start]
    cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
        prefill_query_start_loc, chunk_seq_lens, self.device)

    # Token range of this chunk within the whole batch.
    token_start = query_start_loc_cpu[reqs_start].item()
    token_end = query_start_loc_cpu[reqs_end].item()

    # The metadata field is declared as `int`, so store a Python int
    # rather than a 0-dim tensor.
    total_seq_lens = int(chunk_seq_lens.sum().item())
    assert total_seq_lens <= self.max_prefill_buffer_size

    # Cumulative sequence lengths with a leading 0, moved to self.device
    # as int32.
    cu_seq_lens = torch.cat([
        torch.zeros(1, dtype=torch.int32),
        chunk_seq_lens.cumsum(dim=0)
    ]).to(torch.int32).to(self.device)

    return DeepseekV32IndexerPrefillChunkMetadata(
        cu_seqlen_ks=cu_seqlen_ks,
        cu_seqlen_ke=cu_seqlen_ke,
        cu_seq_lens=cu_seq_lens,
        total_seq_lens=total_seq_lens,
        block_table=block_table[reqs_start:reqs_end],
        token_start=token_start,
        token_end=token_end,
        num_reqs=reqs_end - reqs_start,
    )
def build(self, def build(self,
common_prefix_len: int, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
@@ -209,11 +269,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
num_reqs = common_attn_metadata.num_reqs num_reqs = common_attn_metadata.num_reqs
num_tokens = common_attn_metadata.num_actual_tokens num_tokens = common_attn_metadata.num_actual_tokens
device = self.device query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
block_table_tensor = common_attn_metadata.block_table_tensor
query_start_loc = common_attn_metadata.query_start_loc
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
split_decodes_and_prefills( split_decodes_and_prefills(
common_attn_metadata, common_attn_metadata,
@@ -224,27 +280,20 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
prefill_metadata = None prefill_metadata = None
if num_prefills > 0: if num_prefills > 0:
reqs_start = num_decodes chunk_seq_ids = split_prefill_chunks(
prefill_query_start_loc = query_start_loc[ common_attn_metadata.seq_lens_cpu,
reqs_start:] - query_start_loc[reqs_start] self.max_prefill_buffer_size,
cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches( num_decodes,
prefill_query_start_loc,
common_attn_metadata.seq_lens[reqs_start:])
total_seq_lens = common_attn_metadata.seq_lens[reqs_start:].sum()
assert total_seq_lens < self.max_prefill_buffer_size
cu_seq_lens = torch.cat([
torch.zeros(1, dtype=torch.int32, device=device),
common_attn_metadata.seq_lens[reqs_start:].cumsum(dim=0)
]).to(torch.int32).cuda()
prefill_metadata = DeepseekV32IndexerPrefillMetadata(
block_table=block_table_tensor[reqs_start:, ...],
query_start_loc=prefill_query_start_loc,
max_query_len=common_attn_metadata.max_query_len,
cu_seqlen_ks=cu_seqlen_ks,
cu_seqlen_ke=cu_seqlen_ke,
cu_seq_lens=cu_seq_lens,
total_seq_lens=total_seq_lens,
) )
chunks = [
self.build_one_prefill_chunk(
reqs_start, reqs_end, query_start_loc_cpu,
common_attn_metadata.seq_lens_cpu,
common_attn_metadata.block_table_tensor)
for reqs_start, reqs_end in chunk_seq_ids
]
prefill_metadata = DeepseekV32IndexerPrefillMetadata(
chunks=chunks, )
decode_metadata = None decode_metadata = None
if num_decodes > 0: if num_decodes > 0: