Compare commits
9 Commits
v0.11.0rc4
...
v0.11.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b8b302cde4 | ||
|
|
f71952c1c4 | ||
|
|
d1007767c5 | ||
|
|
c75c2e70d6 | ||
|
|
9d9a2b77f1 | ||
|
|
6040e0b6c0 | ||
|
|
05bf0c52a1 | ||
|
|
c536881a7c | ||
|
|
ebce361c07 |
@@ -48,7 +48,7 @@ steps:
|
|||||||
agents:
|
agents:
|
||||||
queue: cpu_queue_postmerge
|
queue: cpu_queue_postmerge
|
||||||
commands:
|
commands:
|
||||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --build-arg torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0+PTX' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
|
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.9.1 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ."
|
||||||
- "mkdir artifacts"
|
- "mkdir artifacts"
|
||||||
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
||||||
- "bash .buildkite/scripts/upload-wheels.sh"
|
- "bash .buildkite/scripts/upload-wheels.sh"
|
||||||
|
|||||||
@@ -580,22 +580,22 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
|||||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||||
auto blk_coord = tile_scheduler.get_block_coord();
|
auto blk_coord = tile_scheduler.get_block_coord();
|
||||||
auto problem_shape = params.problem_shape;
|
auto problem_shape = params.problem_shape;
|
||||||
auto local_split_kv = params.split_kv;
|
auto local_split_kv = params.split_kv;
|
||||||
if (params.mainloop.ptr_seq != nullptr) {
|
if (params.mainloop.ptr_seq != nullptr) {
|
||||||
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
|
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
|
||||||
if (params.ptr_split_kv != nullptr) {
|
if (params.ptr_split_kv != nullptr) {
|
||||||
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
|
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (local_split_kv <= get<3>(blk_coord))
|
if (local_split_kv <= get<3>(blk_coord))
|
||||||
continue;
|
continue;
|
||||||
load_page_table(
|
load_page_table(
|
||||||
blk_coord,
|
blk_coord,
|
||||||
problem_shape,
|
problem_shape,
|
||||||
params.mainloop,
|
params.mainloop,
|
||||||
shared_storage.tensors,
|
shared_storage.tensors,
|
||||||
pipeline_page_table, pipeline_pt_producer_state,
|
pipeline_page_table, pipeline_pt_producer_state,
|
||||||
local_split_kv
|
local_split_kv
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -604,15 +604,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
|||||||
CUTLASS_PRAGMA_NO_UNROLL
|
CUTLASS_PRAGMA_NO_UNROLL
|
||||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||||
auto blk_coord = tile_scheduler.get_block_coord();
|
auto blk_coord = tile_scheduler.get_block_coord();
|
||||||
auto problem_shape = params.problem_shape;
|
auto problem_shape = params.problem_shape;
|
||||||
auto local_split_kv = params.split_kv;
|
auto local_split_kv = params.split_kv;
|
||||||
if (params.mainloop.ptr_seq != nullptr) {
|
if (params.mainloop.ptr_seq != nullptr) {
|
||||||
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
|
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
|
||||||
if (params.ptr_split_kv != nullptr) {
|
if (params.ptr_split_kv != nullptr) {
|
||||||
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
|
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (local_split_kv <= get<3>(blk_coord))
|
if (local_split_kv <= get<3>(blk_coord))
|
||||||
continue;
|
continue;
|
||||||
load_cpasync(
|
load_cpasync(
|
||||||
blk_coord,
|
blk_coord,
|
||||||
@@ -621,7 +621,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
|||||||
params.mainloop_params,
|
params.mainloop_params,
|
||||||
shared_storage.tensors,
|
shared_storage.tensors,
|
||||||
pipeline_load_qk, pipeline_load_qk_producer_state,
|
pipeline_load_qk, pipeline_load_qk_producer_state,
|
||||||
local_split_kv,
|
local_split_kv,
|
||||||
/* must be shared pipe */
|
/* must be shared pipe */
|
||||||
pipeline_page_table, pipeline_pt_consumer_state
|
pipeline_page_table, pipeline_pt_consumer_state
|
||||||
);
|
);
|
||||||
@@ -633,15 +633,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
|||||||
CUTLASS_PRAGMA_NO_UNROLL
|
CUTLASS_PRAGMA_NO_UNROLL
|
||||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||||
auto blk_coord = tile_scheduler.get_block_coord();
|
auto blk_coord = tile_scheduler.get_block_coord();
|
||||||
auto problem_shape = params.problem_shape;
|
auto problem_shape = params.problem_shape;
|
||||||
auto local_split_kv = params.split_kv;
|
auto local_split_kv = params.split_kv;
|
||||||
if (params.mainloop.ptr_seq != nullptr) {
|
if (params.mainloop.ptr_seq != nullptr) {
|
||||||
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
|
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
|
||||||
if (params.ptr_split_kv != nullptr) {
|
if (params.ptr_split_kv != nullptr) {
|
||||||
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
|
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (local_split_kv <= get<3>(blk_coord))
|
if (local_split_kv <= get<3>(blk_coord))
|
||||||
continue;
|
continue;
|
||||||
load_tma</* paged= */ true>(
|
load_tma</* paged= */ true>(
|
||||||
blk_coord,
|
blk_coord,
|
||||||
@@ -651,7 +651,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
|||||||
shared_storage.tensors,
|
shared_storage.tensors,
|
||||||
pipeline_load_qk, pipeline_load_qk_producer_state,
|
pipeline_load_qk, pipeline_load_qk_producer_state,
|
||||||
pipeline_load_qk, pipeline_load_qk_producer_state,
|
pipeline_load_qk, pipeline_load_qk_producer_state,
|
||||||
local_split_kv
|
local_split_kv
|
||||||
);
|
);
|
||||||
cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait();
|
cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait();
|
||||||
}
|
}
|
||||||
@@ -660,15 +660,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
|||||||
CUTLASS_PRAGMA_NO_UNROLL
|
CUTLASS_PRAGMA_NO_UNROLL
|
||||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||||
auto blk_coord = tile_scheduler.get_block_coord();
|
auto blk_coord = tile_scheduler.get_block_coord();
|
||||||
auto problem_shape = params.problem_shape;
|
auto problem_shape = params.problem_shape;
|
||||||
auto local_split_kv = params.split_kv;
|
auto local_split_kv = params.split_kv;
|
||||||
if (params.mainloop.ptr_seq != nullptr) {
|
if (params.mainloop.ptr_seq != nullptr) {
|
||||||
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
|
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
|
||||||
if (params.ptr_split_kv != nullptr) {
|
if (params.ptr_split_kv != nullptr) {
|
||||||
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
|
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (local_split_kv <= get<3>(blk_coord))
|
if (local_split_kv <= get<3>(blk_coord))
|
||||||
continue;
|
continue;
|
||||||
load_tma<false>(
|
load_tma<false>(
|
||||||
blk_coord,
|
blk_coord,
|
||||||
@@ -678,7 +678,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
|||||||
shared_storage.tensors,
|
shared_storage.tensors,
|
||||||
pipeline_load_qk, pipeline_load_qk_producer_state,
|
pipeline_load_qk, pipeline_load_qk_producer_state,
|
||||||
pipeline_load_qk, pipeline_load_qk_producer_state,
|
pipeline_load_qk, pipeline_load_qk_producer_state,
|
||||||
local_split_kv
|
local_split_kv
|
||||||
);
|
);
|
||||||
cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait();
|
cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait();
|
||||||
}
|
}
|
||||||
@@ -694,14 +694,14 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
|||||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||||
auto blk_coord = tile_scheduler.get_block_coord();
|
auto blk_coord = tile_scheduler.get_block_coord();
|
||||||
auto problem_shape = params.problem_shape;
|
auto problem_shape = params.problem_shape;
|
||||||
auto local_split_kv = params.split_kv;
|
auto local_split_kv = params.split_kv;
|
||||||
if (params.mainloop.ptr_seq != nullptr) {
|
if (params.mainloop.ptr_seq != nullptr) {
|
||||||
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
|
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
|
||||||
if (params.ptr_split_kv != nullptr) {
|
if (params.ptr_split_kv != nullptr) {
|
||||||
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
|
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (local_split_kv <= get<3>(blk_coord))
|
if (local_split_kv <= get<3>(blk_coord))
|
||||||
continue;
|
continue;
|
||||||
mma(blk_coord,
|
mma(blk_coord,
|
||||||
problem_shape,
|
problem_shape,
|
||||||
@@ -711,7 +711,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
|||||||
pipeline_mma_s, pipeline_mma_s_producer_state,
|
pipeline_mma_s, pipeline_mma_s_producer_state,
|
||||||
pipeline_p_mma, pipeline_p_mma_consumer_state,
|
pipeline_p_mma, pipeline_p_mma_consumer_state,
|
||||||
pipeline_mma_o, pipeline_mma_o_producer_state,
|
pipeline_mma_o, pipeline_mma_o_producer_state,
|
||||||
local_split_kv
|
local_split_kv
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -726,15 +726,15 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
|||||||
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
|
||||||
auto blk_coord = tile_scheduler.get_block_coord();
|
auto blk_coord = tile_scheduler.get_block_coord();
|
||||||
auto problem_shape = params.problem_shape;
|
auto problem_shape = params.problem_shape;
|
||||||
auto split_kv = params.split_kv;
|
auto split_kv = params.split_kv;
|
||||||
auto local_split_kv = split_kv;
|
auto local_split_kv = split_kv;
|
||||||
if (params.mainloop.ptr_seq != nullptr) {
|
if (params.mainloop.ptr_seq != nullptr) {
|
||||||
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
|
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
|
||||||
if (params.ptr_split_kv != nullptr) {
|
if (params.ptr_split_kv != nullptr) {
|
||||||
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
|
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (local_split_kv <= get<3>(blk_coord))
|
if (local_split_kv <= get<3>(blk_coord))
|
||||||
continue;
|
continue;
|
||||||
compute(
|
compute(
|
||||||
blk_coord,
|
blk_coord,
|
||||||
@@ -745,7 +745,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
|||||||
pipeline_mma_s, pipeline_mma_s_consumer_state,
|
pipeline_mma_s, pipeline_mma_s_consumer_state,
|
||||||
pipeline_p_mma, pipeline_p_mma_producer_state,
|
pipeline_p_mma, pipeline_p_mma_producer_state,
|
||||||
pipeline_mma_o, pipeline_mma_o_consumer_state,
|
pipeline_mma_o, pipeline_mma_o_consumer_state,
|
||||||
local_split_kv
|
local_split_kv
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1900,7 +1900,7 @@ struct Sm100FmhaMlaKernelTmaWarpspecialized {
|
|||||||
cutlass::arch::NamedBarrier(
|
cutlass::arch::NamedBarrier(
|
||||||
(kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp,
|
(kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp,
|
||||||
kNamedBarrierEpilogue
|
kNamedBarrierEpilogue
|
||||||
).arrive();
|
).arrive_and_wait();
|
||||||
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,6 +14,11 @@ ARG PYTHON_VERSION=3.12
|
|||||||
#
|
#
|
||||||
# Example:
|
# Example:
|
||||||
# docker build --build-arg BUILD_BASE_IMAGE=registry.acme.org/mirror/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04
|
# docker build --build-arg BUILD_BASE_IMAGE=registry.acme.org/mirror/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04
|
||||||
|
|
||||||
|
# Important: We build with an old version of Ubuntu to maintain broad
|
||||||
|
# compatibility with other Linux OSes. The main reason for this is that the
|
||||||
|
# glibc version is baked into the distro, and binaries built with one glibc
|
||||||
|
# version are not backwards compatible with OSes that use an earlier version.
|
||||||
ARG BUILD_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04
|
ARG BUILD_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04
|
||||||
# TODO: Restore to base image after FlashInfer AOT wheel fixed
|
# TODO: Restore to base image after FlashInfer AOT wheel fixed
|
||||||
ARG FINAL_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04
|
ARG FINAL_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04
|
||||||
@@ -75,34 +80,19 @@ ARG TARGETPLATFORM
|
|||||||
ARG INSTALL_KV_CONNECTORS=false
|
ARG INSTALL_KV_CONNECTORS=false
|
||||||
ENV DEBIAN_FRONTEND=noninteractive
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
ARG DEADSNAKES_MIRROR_URL
|
|
||||||
ARG DEADSNAKES_GPGKEY_URL
|
|
||||||
ARG GET_PIP_URL
|
ARG GET_PIP_URL
|
||||||
|
|
||||||
# Install Python and other dependencies
|
# Install system dependencies and uv, then create Python virtual environment
|
||||||
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
|
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
|
||||||
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
|
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
|
||||||
&& apt-get update -y \
|
&& apt-get update -y \
|
||||||
&& apt-get install -y ccache software-properties-common git curl sudo \
|
&& apt-get install -y ccache software-properties-common git curl sudo python3-pip \
|
||||||
&& if [ ! -z ${DEADSNAKES_MIRROR_URL} ] ; then \
|
&& curl -LsSf https://astral.sh/uv/install.sh | sh \
|
||||||
if [ ! -z "${DEADSNAKES_GPGKEY_URL}" ] ; then \
|
&& $HOME/.local/bin/uv venv /opt/venv --python ${PYTHON_VERSION} \
|
||||||
mkdir -p -m 0755 /etc/apt/keyrings ; \
|
&& rm -f /usr/bin/python3 /usr/bin/python3-config /usr/bin/pip \
|
||||||
curl -L ${DEADSNAKES_GPGKEY_URL} | gpg --dearmor > /etc/apt/keyrings/deadsnakes.gpg ; \
|
&& ln -s /opt/venv/bin/python3 /usr/bin/python3 \
|
||||||
sudo chmod 644 /etc/apt/keyrings/deadsnakes.gpg ; \
|
&& ln -s /opt/venv/bin/python3-config /usr/bin/python3-config \
|
||||||
echo "deb [signed-by=/etc/apt/keyrings/deadsnakes.gpg] ${DEADSNAKES_MIRROR_URL} $(lsb_release -cs) main" > /etc/apt/sources.list.d/deadsnakes.list ; \
|
&& ln -s /opt/venv/bin/pip /usr/bin/pip \
|
||||||
fi ; \
|
|
||||||
else \
|
|
||||||
for i in 1 2 3; do \
|
|
||||||
add-apt-repository -y ppa:deadsnakes/ppa && break || \
|
|
||||||
{ echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
|
|
||||||
done ; \
|
|
||||||
fi \
|
|
||||||
&& apt-get update -y \
|
|
||||||
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
|
|
||||||
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
|
|
||||||
&& update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
|
|
||||||
&& ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
|
|
||||||
&& curl -sS ${GET_PIP_URL} | python${PYTHON_VERSION} \
|
|
||||||
&& python3 --version && python3 -m pip --version
|
&& python3 --version && python3 -m pip --version
|
||||||
|
|
||||||
ARG PIP_INDEX_URL UV_INDEX_URL
|
ARG PIP_INDEX_URL UV_INDEX_URL
|
||||||
@@ -111,9 +101,9 @@ ARG PYTORCH_CUDA_INDEX_BASE_URL
|
|||||||
ARG PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL
|
ARG PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL
|
||||||
ARG PIP_KEYRING_PROVIDER UV_KEYRING_PROVIDER
|
ARG PIP_KEYRING_PROVIDER UV_KEYRING_PROVIDER
|
||||||
|
|
||||||
# Install uv for faster pip installs
|
# Activate virtual environment and add uv to PATH
|
||||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
ENV PATH="/opt/venv/bin:/root/.local/bin:$PATH"
|
||||||
python3 -m pip install uv
|
ENV VIRTUAL_ENV="/opt/venv"
|
||||||
|
|
||||||
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
|
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
|
||||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||||
@@ -142,7 +132,7 @@ WORKDIR /workspace
|
|||||||
COPY requirements/common.txt requirements/common.txt
|
COPY requirements/common.txt requirements/common.txt
|
||||||
COPY requirements/cuda.txt requirements/cuda.txt
|
COPY requirements/cuda.txt requirements/cuda.txt
|
||||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||||
uv pip install --system -r requirements/cuda.txt \
|
uv pip install --python /opt/venv/bin/python3 -r requirements/cuda.txt \
|
||||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||||
|
|
||||||
# cuda arch list used by torch
|
# cuda arch list used by torch
|
||||||
@@ -172,7 +162,7 @@ ENV UV_INDEX_STRATEGY="unsafe-best-match"
|
|||||||
ENV UV_LINK_MODE=copy
|
ENV UV_LINK_MODE=copy
|
||||||
|
|
||||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||||
uv pip install --system -r requirements/build.txt \
|
uv pip install --python /opt/venv/bin/python3 -r requirements/build.txt \
|
||||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||||
|
|
||||||
COPY . .
|
COPY . .
|
||||||
@@ -269,7 +259,7 @@ COPY requirements/lint.txt requirements/lint.txt
|
|||||||
COPY requirements/test.txt requirements/test.txt
|
COPY requirements/test.txt requirements/test.txt
|
||||||
COPY requirements/dev.txt requirements/dev.txt
|
COPY requirements/dev.txt requirements/dev.txt
|
||||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||||
uv pip install --system -r requirements/dev.txt \
|
uv pip install --python /opt/venv/bin/python3 -r requirements/dev.txt \
|
||||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||||
#################### DEV IMAGE ####################
|
#################### DEV IMAGE ####################
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ ARG CUDA_VERSION=12.8.0
|
|||||||
#
|
#
|
||||||
#################### BASE BUILD IMAGE ####################
|
#################### BASE BUILD IMAGE ####################
|
||||||
# prepare basic build environment
|
# prepare basic build environment
|
||||||
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 AS base
|
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS base
|
||||||
ARG CUDA_VERSION=12.8.0
|
ARG CUDA_VERSION=12.8.0
|
||||||
ARG PYTHON_VERSION=3.12
|
ARG PYTHON_VERSION=3.12
|
||||||
ARG TARGETPLATFORM
|
ARG TARGETPLATFORM
|
||||||
|
|||||||
@@ -8,6 +8,9 @@ This page teaches you how to pass multi-modal inputs to [multi-modal models][sup
|
|||||||
|
|
||||||
!!! tip
|
!!! tip
|
||||||
When serving multi-modal models, consider setting `--allowed-media-domains` to restrict domain that vLLM can access to prevent it from accessing arbitrary endpoints that can potentially be vulnerable to Server-Side Request Forgery (SSRF) attacks. You can provide a list of domains for this arg. For example: `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`
|
When serving multi-modal models, consider setting `--allowed-media-domains` to restrict domain that vLLM can access to prevent it from accessing arbitrary endpoints that can potentially be vulnerable to Server-Side Request Forgery (SSRF) attacks. You can provide a list of domains for this arg. For example: `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`
|
||||||
|
|
||||||
|
Also, consider setting `VLLM_MEDIA_URL_ALLOW_REDIRECTS=0` to prevent HTTP redirects from being followed to bypass domain restrictions.
|
||||||
|
|
||||||
This restriction is especially important if you run vLLM in a containerized environment where the vLLM pods may have unrestricted access to internal networks.
|
This restriction is especially important if you run vLLM in a containerized environment where the vLLM pods may have unrestricted access to internal networks.
|
||||||
|
|
||||||
## Offline Inference
|
## Offline Inference
|
||||||
|
|||||||
@@ -66,6 +66,9 @@ Restrict domains that vLLM can access for media URLs by setting
|
|||||||
`--allowed-media-domains` to prevent Server-Side Request Forgery (SSRF) attacks.
|
`--allowed-media-domains` to prevent Server-Side Request Forgery (SSRF) attacks.
|
||||||
(e.g. `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`)
|
(e.g. `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`)
|
||||||
|
|
||||||
|
Also, consider setting `VLLM_MEDIA_URL_ALLOW_REDIRECTS=0` to prevent HTTP
|
||||||
|
redirects from being followed to bypass domain restrictions.
|
||||||
|
|
||||||
## Security and Firewalls: Protecting Exposed vLLM Systems
|
## Security and Firewalls: Protecting Exposed vLLM Systems
|
||||||
|
|
||||||
While vLLM is designed to allow unsafe network services to be isolated to
|
While vLLM is designed to allow unsafe network services to be isolated to
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from vllm.utils import cdiv
|
|||||||
from vllm.v1.attention.backends.mla.flashmla_sparse import (
|
from vllm.v1.attention.backends.mla.flashmla_sparse import (
|
||||||
FlashMLASparseBackend, FlashMLASparseDecodeAndContextMetadata,
|
FlashMLASparseBackend, FlashMLASparseDecodeAndContextMetadata,
|
||||||
FlashMLASparseImpl, FlashMLASparseMetadata)
|
FlashMLASparseImpl, FlashMLASparseMetadata)
|
||||||
|
from vllm.v1.attention.backends.mla.indexer import split_prefill_chunks
|
||||||
|
|
||||||
SPARSE_BACKEND_BATCH_SPECS = {
|
SPARSE_BACKEND_BATCH_SPECS = {
|
||||||
name: BATCH_SPECS[name]
|
name: BATCH_SPECS[name]
|
||||||
@@ -424,3 +425,24 @@ def test_sparse_backend_decode_correctness(dist_init, batch_name,
|
|||||||
sdpa_reference,
|
sdpa_reference,
|
||||||
rtol=0.5,
|
rtol=0.5,
|
||||||
atol=0.5)
|
atol=0.5)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"seq_lens,max_buf,start,expected",
|
||||||
|
[
|
||||||
|
# Basic split: totals per chunk ≤ max_buf
|
||||||
|
(torch.tensor([2, 3, 4, 2]), 5, 0, [(0, 2), (2, 3), (3, 4)]),
|
||||||
|
# Non-zero start index
|
||||||
|
(torch.tensor([2, 3, 4, 2]), 5, 1, [(1, 2), (2, 3), (3, 4)]),
|
||||||
|
# Exact fits should split between items when adding the next would
|
||||||
|
# overflow
|
||||||
|
(torch.tensor([5, 5, 5]), 5, 0, [(0, 1), (1, 2), (2, 3)]),
|
||||||
|
# All requests fit in a single chunk
|
||||||
|
(torch.tensor([1, 1, 1]), 10, 0, [(0, 3)]),
|
||||||
|
# Large buffer with non-zero start
|
||||||
|
(torch.tensor([4, 4, 4]), 100, 1, [(1, 3)]),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_split_prefill_chunks(seq_lens, max_buf, start, expected):
|
||||||
|
out = split_prefill_chunks(seq_lens, max_buf, start)
|
||||||
|
assert out == expected
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import functools
|
import functools
|
||||||
from typing import List, Optional
|
from typing import ClassVar, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -11,8 +11,8 @@ from vllm.attention.backends.abstract import (AttentionBackend,
|
|||||||
from vllm.attention.selector import get_attn_backend
|
from vllm.attention.selector import get_attn_backend
|
||||||
from vllm.config import CacheConfig, QuantizationConfig
|
from vllm.config import CacheConfig, QuantizationConfig
|
||||||
from vllm.v1.attention.backends.utils import (
|
from vllm.v1.attention.backends.utils import (
|
||||||
CommonAttentionMetadata, make_local_attention_virtual_batches,
|
AttentionCGSupport, CommonAttentionMetadata,
|
||||||
subclass_attention_backend)
|
make_local_attention_virtual_batches, subclass_attention_backend)
|
||||||
|
|
||||||
from ..layer import Attention
|
from ..layer import Attention
|
||||||
|
|
||||||
@@ -28,6 +28,8 @@ def create_chunked_local_attention_backend(
|
|||||||
underlying_builder = underlying_attn_backend.get_builder_cls()
|
underlying_builder = underlying_attn_backend.get_builder_cls()
|
||||||
|
|
||||||
class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore
|
class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore
|
||||||
|
cudagraph_support: ClassVar[AttentionCGSupport] = \
|
||||||
|
AttentionCGSupport.NEVER
|
||||||
|
|
||||||
def build(self,
|
def build(self,
|
||||||
common_prefix_len: int,
|
common_prefix_len: int,
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ class HTTPConnection:
|
|||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
timeout: Optional[float] = None,
|
timeout: Optional[float] = None,
|
||||||
extra_headers: Optional[Mapping[str, str]] = None,
|
extra_headers: Optional[Mapping[str, str]] = None,
|
||||||
|
allow_redirects: bool = True,
|
||||||
):
|
):
|
||||||
self._validate_http_url(url)
|
self._validate_http_url(url)
|
||||||
|
|
||||||
@@ -63,7 +64,8 @@ class HTTPConnection:
|
|||||||
return client.get(url,
|
return client.get(url,
|
||||||
headers=self._headers(**extra_headers),
|
headers=self._headers(**extra_headers),
|
||||||
stream=stream,
|
stream=stream,
|
||||||
timeout=timeout)
|
timeout=timeout,
|
||||||
|
allow_redirects=allow_redirects)
|
||||||
|
|
||||||
async def get_async_response(
|
async def get_async_response(
|
||||||
self,
|
self,
|
||||||
@@ -71,6 +73,7 @@ class HTTPConnection:
|
|||||||
*,
|
*,
|
||||||
timeout: Optional[float] = None,
|
timeout: Optional[float] = None,
|
||||||
extra_headers: Optional[Mapping[str, str]] = None,
|
extra_headers: Optional[Mapping[str, str]] = None,
|
||||||
|
allow_redirects: bool = True,
|
||||||
):
|
):
|
||||||
self._validate_http_url(url)
|
self._validate_http_url(url)
|
||||||
|
|
||||||
@@ -79,10 +82,17 @@ class HTTPConnection:
|
|||||||
|
|
||||||
return client.get(url,
|
return client.get(url,
|
||||||
headers=self._headers(**extra_headers),
|
headers=self._headers(**extra_headers),
|
||||||
timeout=timeout)
|
timeout=timeout,
|
||||||
|
allow_redirects=allow_redirects)
|
||||||
|
|
||||||
def get_bytes(self, url: str, *, timeout: Optional[float] = None) -> bytes:
|
def get_bytes(self,
|
||||||
with self.get_response(url, timeout=timeout) as r:
|
url: str,
|
||||||
|
*,
|
||||||
|
timeout: Optional[float] = None,
|
||||||
|
allow_redirects: bool = True) -> bytes:
|
||||||
|
with self.get_response(url,
|
||||||
|
timeout=timeout,
|
||||||
|
allow_redirects=allow_redirects) as r:
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
|
|
||||||
return r.content
|
return r.content
|
||||||
@@ -92,8 +102,10 @@ class HTTPConnection:
|
|||||||
url: str,
|
url: str,
|
||||||
*,
|
*,
|
||||||
timeout: Optional[float] = None,
|
timeout: Optional[float] = None,
|
||||||
|
allow_redirects: bool = True,
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
async with await self.get_async_response(url, timeout=timeout) as r:
|
async with await self.get_async_response(
|
||||||
|
url, timeout=timeout, allow_redirects=allow_redirects) as r:
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
|
|
||||||
return await r.read()
|
return await r.read()
|
||||||
|
|||||||
@@ -68,6 +68,7 @@ if TYPE_CHECKING:
|
|||||||
VLLM_IMAGE_FETCH_TIMEOUT: int = 5
|
VLLM_IMAGE_FETCH_TIMEOUT: int = 5
|
||||||
VLLM_VIDEO_FETCH_TIMEOUT: int = 30
|
VLLM_VIDEO_FETCH_TIMEOUT: int = 30
|
||||||
VLLM_AUDIO_FETCH_TIMEOUT: int = 10
|
VLLM_AUDIO_FETCH_TIMEOUT: int = 10
|
||||||
|
VLLM_MEDIA_URL_ALLOW_REDIRECTS: bool = True
|
||||||
VLLM_MEDIA_LOADING_THREAD_COUNT: int = 8
|
VLLM_MEDIA_LOADING_THREAD_COUNT: int = 8
|
||||||
VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25
|
VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25
|
||||||
VLLM_VIDEO_LOADER_BACKEND: str = "opencv"
|
VLLM_VIDEO_LOADER_BACKEND: str = "opencv"
|
||||||
@@ -725,6 +726,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
"VLLM_AUDIO_FETCH_TIMEOUT":
|
"VLLM_AUDIO_FETCH_TIMEOUT":
|
||||||
lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")),
|
lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")),
|
||||||
|
|
||||||
|
# Whether to allow HTTP redirects when fetching from media URLs.
|
||||||
|
# Default to True
|
||||||
|
"VLLM_MEDIA_URL_ALLOW_REDIRECTS":
|
||||||
|
lambda: bool(int(os.getenv("VLLM_MEDIA_URL_ALLOW_REDIRECTS", "1"))),
|
||||||
|
|
||||||
# Max number of workers for the thread pool handling
|
# Max number of workers for the thread pool handling
|
||||||
# media bytes loading. Set to 1 to disable parallel processing.
|
# media bytes loading. Set to 1 to disable parallel processing.
|
||||||
# Default is 8
|
# Default is 8
|
||||||
|
|||||||
@@ -583,44 +583,43 @@ def sparse_attn_indexer(
|
|||||||
topk_indices_buffer[:hidden_states.shape[0]] = -1
|
topk_indices_buffer[:hidden_states.shape[0]] = -1
|
||||||
if has_prefill:
|
if has_prefill:
|
||||||
prefill_metadata = attn_metadata.prefill
|
prefill_metadata = attn_metadata.prefill
|
||||||
num_prefills = attn_metadata.num_prefills
|
for chunk in prefill_metadata.chunks:
|
||||||
k_fp8 = torch.empty([prefill_metadata.total_seq_lens, head_dim],
|
k_fp8 = torch.empty([chunk.total_seq_lens, head_dim],
|
||||||
device=k.device,
|
device=k.device,
|
||||||
dtype=torch.float8_e4m3fn)
|
dtype=torch.float8_e4m3fn)
|
||||||
k_scale = torch.empty([prefill_metadata.total_seq_lens, 1],
|
k_scale = torch.empty([chunk.total_seq_lens, 1],
|
||||||
device=k.device,
|
device=k.device,
|
||||||
dtype=torch.float32)
|
dtype=torch.float32)
|
||||||
cp_gather_indexer_k_quant_cache(
|
cp_gather_indexer_k_quant_cache(
|
||||||
kv_cache,
|
kv_cache,
|
||||||
k_fp8,
|
k_fp8,
|
||||||
k_scale,
|
k_scale,
|
||||||
prefill_metadata.block_table,
|
chunk.block_table,
|
||||||
prefill_metadata.cu_seq_lens,
|
chunk.cu_seq_lens,
|
||||||
num_prefills,
|
chunk.num_reqs,
|
||||||
)
|
)
|
||||||
cu_seqlen_ks = prefill_metadata.cu_seqlen_ks
|
logits = fp8_mqa_logits(
|
||||||
cu_seqlen_ke = prefill_metadata.cu_seqlen_ke
|
q_fp8[chunk.token_start:chunk.token_end],
|
||||||
num_tokens = attn_metadata.num_actual_tokens
|
(k_fp8, k_scale),
|
||||||
logits = fp8_mqa_logits(
|
weights[chunk.token_start:chunk.token_end],
|
||||||
q_fp8[num_decode_tokens:num_tokens],
|
chunk.cu_seqlen_ks,
|
||||||
(k_fp8, k_scale),
|
chunk.cu_seqlen_ke,
|
||||||
weights[num_decode_tokens:num_tokens],
|
)
|
||||||
cu_seqlen_ks,
|
topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]),
|
||||||
cu_seqlen_ke,
|
dim=-1)[1]
|
||||||
)
|
topk_indices -= chunk.cu_seqlen_ks[:, None]
|
||||||
topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]),
|
mask_lo = topk_indices >= 0
|
||||||
dim=-1)[1]
|
mask_hi = topk_indices - (chunk.cu_seqlen_ke -
|
||||||
topk_indices -= cu_seqlen_ks[:, None]
|
chunk.cu_seqlen_ks)[:, None] < 0
|
||||||
mask_lo = topk_indices >= 0
|
mask = torch.full_like(topk_indices,
|
||||||
mask_hi = topk_indices - (cu_seqlen_ke - cu_seqlen_ks)[:, None] < 0
|
False,
|
||||||
mask = torch.full_like(topk_indices,
|
dtype=torch.bool,
|
||||||
False,
|
device=topk_indices.device)
|
||||||
dtype=torch.bool,
|
mask = mask_lo & mask_hi
|
||||||
device=topk_indices.device)
|
topk_indices = topk_indices.masked_fill(~mask, -1)
|
||||||
mask = mask_lo & mask_hi
|
topk_indices_buffer[
|
||||||
topk_indices = topk_indices.masked_fill(~mask, -1)
|
chunk.token_start:chunk.token_end, :topk_indices.
|
||||||
topk_indices_buffer[num_decode_tokens:num_tokens, :topk_indices.
|
shape[-1]] = topk_indices.to(dtype=torch.int32)
|
||||||
shape[-1]] = topk_indices.to(dtype=torch.int32)
|
|
||||||
|
|
||||||
if has_decode:
|
if has_decode:
|
||||||
decode_metadata = attn_metadata.decode
|
decode_metadata = attn_metadata.decode
|
||||||
|
|||||||
@@ -140,7 +140,11 @@ class MediaConnector:
|
|||||||
self._assert_url_in_allowed_media_domains(url_spec)
|
self._assert_url_in_allowed_media_domains(url_spec)
|
||||||
|
|
||||||
connection = self.connection
|
connection = self.connection
|
||||||
data = connection.get_bytes(url, timeout=fetch_timeout)
|
data = connection.get_bytes(
|
||||||
|
url,
|
||||||
|
timeout=fetch_timeout,
|
||||||
|
allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
|
||||||
|
)
|
||||||
|
|
||||||
return media_io.load_bytes(data)
|
return media_io.load_bytes(data)
|
||||||
|
|
||||||
@@ -167,7 +171,11 @@ class MediaConnector:
|
|||||||
self._assert_url_in_allowed_media_domains(url_spec)
|
self._assert_url_in_allowed_media_domains(url_spec)
|
||||||
|
|
||||||
connection = self.connection
|
connection = self.connection
|
||||||
data = await connection.async_get_bytes(url, timeout=fetch_timeout)
|
data = await connection.async_get_bytes(
|
||||||
|
url,
|
||||||
|
timeout=fetch_timeout,
|
||||||
|
allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS,
|
||||||
|
)
|
||||||
future = loop.run_in_executor(global_thread_pool,
|
future = loop.run_in_executor(global_thread_pool,
|
||||||
media_io.load_bytes, data)
|
media_io.load_bytes, data)
|
||||||
return await future
|
return await future
|
||||||
|
|||||||
@@ -29,7 +29,6 @@ from vllm.utils.flashinfer import (can_use_trtllm_attention,
|
|||||||
flashinfer_disable_q_quantization,
|
flashinfer_disable_q_quantization,
|
||||||
supports_trtllm_attention,
|
supports_trtllm_attention,
|
||||||
use_trtllm_attention)
|
use_trtllm_attention)
|
||||||
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
|
|
||||||
# yapf conflicts with isort for this block
|
# yapf conflicts with isort for this block
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
|
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
|
||||||
@@ -677,7 +676,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
# TODO: The cascade wrapper currently does not support setting
|
# TODO: The cascade wrapper currently does not support setting
|
||||||
# kv cache dtype to something different from query dtype.
|
# kv cache dtype to something different from query dtype.
|
||||||
return False
|
return False
|
||||||
return use_cascade_attention(*args, **kwargs)
|
# TODO: Cascade attention doesn't work, disable it for now
|
||||||
|
# return use_cascade_attention(*args, **kwargs)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class FlashInferImpl(AttentionImpl):
|
class FlashInferImpl(AttentionImpl):
|
||||||
|
|||||||
@@ -1211,13 +1211,18 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
k, v, return_softmax_lse):
|
k, v, return_softmax_lse):
|
||||||
assert isinstance(prefill, FlashInferPrefillMetadata)
|
assert isinstance(prefill, FlashInferPrefillMetadata)
|
||||||
assert prefill.prefill_main is not None
|
assert prefill.prefill_main is not None
|
||||||
return prefill.prefill_main.run(
|
ret = prefill.prefill_main.run(
|
||||||
q=q,
|
q=q,
|
||||||
k=k,
|
k=k,
|
||||||
v=v,
|
v=v,
|
||||||
return_lse=return_softmax_lse,
|
return_lse=return_softmax_lse,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if isinstance(ret, tuple):
|
||||||
|
# Convert from (q_len, num_heads) to (num_heads, q_len)
|
||||||
|
return ret[0], ret[1].transpose(0, 1).contiguous()
|
||||||
|
return ret
|
||||||
|
|
||||||
def _run_prefill_new_tokens_cudnn(self, prefill: MLACommonPrefillMetadata,
|
def _run_prefill_new_tokens_cudnn(self, prefill: MLACommonPrefillMetadata,
|
||||||
q, k, v, return_softmax_lse):
|
q, k, v, return_softmax_lse):
|
||||||
assert isinstance(prefill, CudnnPrefillMetadata)
|
assert isinstance(prefill, CudnnPrefillMetadata)
|
||||||
@@ -1260,12 +1265,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
def _run_prefill_context_chunk_fi(self, prefill: MLACommonPrefillMetadata,
|
def _run_prefill_context_chunk_fi(self, prefill: MLACommonPrefillMetadata,
|
||||||
chunk_idx: int, q, k, v):
|
chunk_idx: int, q, k, v):
|
||||||
assert isinstance(prefill, FlashInferPrefillMetadata)
|
assert isinstance(prefill, FlashInferPrefillMetadata)
|
||||||
return prefill.prefill_chunks[chunk_idx].run(
|
attn_out, lse = prefill.prefill_chunks[chunk_idx].run(
|
||||||
q=q,
|
q=q,
|
||||||
k=k,
|
k=k,
|
||||||
v=v,
|
v=v,
|
||||||
return_lse=True,
|
return_lse=True,
|
||||||
)
|
)
|
||||||
|
# Convert from (q_len, num_heads) to (num_heads, q_len)
|
||||||
|
return attn_out, lse.transpose(0, 1).contiguous()
|
||||||
|
|
||||||
def _run_prefill_context_chunk_cudnn(self,
|
def _run_prefill_context_chunk_cudnn(self,
|
||||||
prefill: MLACommonPrefillMetadata,
|
prefill: MLACommonPrefillMetadata,
|
||||||
|
|||||||
@@ -49,14 +49,20 @@ class DeepseekV32IndexerBackend(AttentionBackend):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DeepseekV32IndexerPrefillMetadata:
|
class DeepseekV32IndexerPrefillChunkMetadata:
|
||||||
block_table: torch.Tensor
|
block_table: torch.Tensor
|
||||||
query_start_loc: torch.Tensor
|
|
||||||
max_query_len: int
|
|
||||||
cu_seqlen_ks: torch.Tensor
|
cu_seqlen_ks: torch.Tensor
|
||||||
cu_seqlen_ke: torch.Tensor
|
cu_seqlen_ke: torch.Tensor
|
||||||
cu_seq_lens: torch.Tensor
|
cu_seq_lens: torch.Tensor
|
||||||
total_seq_lens: int
|
total_seq_lens: int
|
||||||
|
token_start: int
|
||||||
|
token_end: int
|
||||||
|
num_reqs: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DeepseekV32IndexerPrefillMetadata:
|
||||||
|
chunks: list[DeepseekV32IndexerPrefillChunkMetadata]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -98,8 +104,8 @@ class DeepseekV32IndexerMetadata:
|
|||||||
|
|
||||||
# TODO (zyongye) optimize this, this is now vibe coded
|
# TODO (zyongye) optimize this, this is now vibe coded
|
||||||
def kv_spans_from_batches(
|
def kv_spans_from_batches(
|
||||||
start_seq_loc: torch.Tensor,
|
start_seq_loc: torch.Tensor, seq_len_per_batch: torch.Tensor,
|
||||||
seq_len_per_batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
start_seq_loc: 1D long tensor [B+1], cumulative counts of
|
start_seq_loc: 1D long tensor [B+1], cumulative counts of
|
||||||
@@ -122,7 +128,7 @@ def kv_spans_from_batches(
|
|||||||
are the **last** `counts[i]` positions of that sequence.
|
are the **last** `counts[i]` positions of that sequence.
|
||||||
"""
|
"""
|
||||||
q = start_seq_loc.to(dtype=torch.long)
|
q = start_seq_loc.to(dtype=torch.long)
|
||||||
L = seq_len_per_batch.to(dtype=torch.long, device=q.device)
|
L = seq_len_per_batch.to(dtype=torch.long)
|
||||||
assert q.dim() == 1 and L.dim() == 1
|
assert q.dim() == 1 and L.dim() == 1
|
||||||
assert q.numel() == L.numel() + 1, "start_seq_loc must have length B+1"
|
assert q.numel() == L.numel() + 1, "start_seq_loc must have length B+1"
|
||||||
|
|
||||||
@@ -130,7 +136,6 @@ def kv_spans_from_batches(
|
|||||||
counts = q[1:] - q[:-1] # [B]
|
counts = q[1:] - q[:-1] # [B]
|
||||||
N = int(q[-1].item()) # total selected tokens
|
N = int(q[-1].item()) # total selected tokens
|
||||||
B = L.numel()
|
B = L.numel()
|
||||||
device = L.device
|
|
||||||
|
|
||||||
if N == 0:
|
if N == 0:
|
||||||
return (torch.empty(0, dtype=torch.long, device=device),
|
return (torch.empty(0, dtype=torch.long, device=device),
|
||||||
@@ -140,8 +145,7 @@ def kv_spans_from_batches(
|
|||||||
kv_starts_per_batch = torch.cumsum(L, dim=0) - L # [B]
|
kv_starts_per_batch = torch.cumsum(L, dim=0) - L # [B]
|
||||||
|
|
||||||
# For each selected token, which batch does it belong to?
|
# For each selected token, which batch does it belong to?
|
||||||
batch_id = torch.repeat_interleave(torch.arange(B, device=device),
|
batch_id = torch.repeat_interleave(torch.arange(B), counts) # [N]
|
||||||
counts) # [N]
|
|
||||||
|
|
||||||
# Map batch KV start to each token
|
# Map batch KV start to each token
|
||||||
start_tensor = kv_starts_per_batch[batch_id] # [N]
|
start_tensor = kv_starts_per_batch[batch_id] # [N]
|
||||||
@@ -151,22 +155,51 @@ def kv_spans_from_batches(
|
|||||||
L_expand = torch.repeat_interleave(L, counts) # [N]
|
L_expand = torch.repeat_interleave(L, counts) # [N]
|
||||||
m_expand = torch.repeat_interleave(counts, counts) # [N]
|
m_expand = torch.repeat_interleave(counts, counts) # [N]
|
||||||
# position within the selected block: 1..counts[b]
|
# position within the selected block: 1..counts[b]
|
||||||
pos_within = (torch.arange(N, device=device, dtype=torch.long) -
|
pos_within = (torch.arange(N, dtype=torch.long) -
|
||||||
torch.repeat_interleave(q[:-1], counts) + 1)
|
torch.repeat_interleave(q[:-1], counts) + 1)
|
||||||
|
|
||||||
local_pos = L_expand - m_expand + pos_within # [N], 1-based
|
local_pos = L_expand - m_expand + pos_within # [N], 1-based
|
||||||
end_location = start_tensor + local_pos # exclusive end
|
end_location = start_tensor + local_pos # exclusive end
|
||||||
|
|
||||||
return start_tensor.int(), end_location.int()
|
return start_tensor.int().to(device), end_location.int().to(device)
|
||||||
|
|
||||||
|
|
||||||
def get_max_prefill_buffer_size(vllm_config: VllmConfig):
|
def get_max_prefill_buffer_size(vllm_config: VllmConfig):
|
||||||
max_model_len = vllm_config.model_config.max_model_len
|
max_model_len = vllm_config.model_config.max_model_len
|
||||||
# max_num_batched_tokens = \
|
# NOTE(Chen): 2 is a magic number for controlling the prefill buffer size.
|
||||||
# vllm_config.scheduler_config.max_num_batched_tokens
|
# May be tuned later.
|
||||||
max_num_seq = vllm_config.scheduler_config.max_num_seqs
|
return max_model_len * 2
|
||||||
# NOTE(Chen): an estimated max size of flattened_kv. Need to double check.
|
|
||||||
return max_model_len * max_num_seq
|
|
||||||
|
def split_prefill_chunks(seq_lens_cpu: torch.Tensor,
|
||||||
|
max_prefill_buffer_size: int,
|
||||||
|
reqs_start: int) -> list[tuple[int, int]]:
|
||||||
|
"""
|
||||||
|
Split the prefill chunks into a list of tuples of (reqs_start, reqs_end)
|
||||||
|
such that the total sequence length of each chunk is less than the
|
||||||
|
maximum prefill buffer size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seq_lens_cpu: The sequence lengths of the prefill requests.
|
||||||
|
max_prefill_buffer_size: The maximum prefill buffer size.
|
||||||
|
reqs_start: The start index of the prefill requests.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of tuples of (reqs_start, reqs_end).
|
||||||
|
"""
|
||||||
|
chunk_seq_ids = []
|
||||||
|
total_seq_lens = 0
|
||||||
|
for i in range(reqs_start, len(seq_lens_cpu)):
|
||||||
|
cur_seq_len = seq_lens_cpu[i].item()
|
||||||
|
assert cur_seq_len <= max_prefill_buffer_size
|
||||||
|
total_seq_lens += cur_seq_len
|
||||||
|
if total_seq_lens > max_prefill_buffer_size:
|
||||||
|
chunk_seq_ids.append((reqs_start, i))
|
||||||
|
reqs_start = i
|
||||||
|
total_seq_lens = cur_seq_len
|
||||||
|
if total_seq_lens > 0:
|
||||||
|
chunk_seq_ids.append((reqs_start, len(seq_lens_cpu)))
|
||||||
|
return chunk_seq_ids
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
||||||
@@ -201,6 +234,33 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
|||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=self.device)
|
device=self.device)
|
||||||
|
|
||||||
|
def build_one_prefill_chunk(self, reqs_start, reqs_end,
|
||||||
|
query_start_loc_cpu, seq_lens_cpu,
|
||||||
|
block_table):
|
||||||
|
prefill_query_start_loc = query_start_loc_cpu[
|
||||||
|
reqs_start:reqs_end + 1] - query_start_loc_cpu[reqs_start]
|
||||||
|
cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
|
||||||
|
prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end],
|
||||||
|
self.device)
|
||||||
|
token_start = query_start_loc_cpu[reqs_start].item()
|
||||||
|
token_end = query_start_loc_cpu[reqs_end].item()
|
||||||
|
total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
|
||||||
|
assert total_seq_lens <= self.max_prefill_buffer_size
|
||||||
|
cu_seq_lens = torch.cat([
|
||||||
|
torch.zeros(1, dtype=torch.int32),
|
||||||
|
seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0)
|
||||||
|
]).to(torch.int32).to(self.device)
|
||||||
|
return DeepseekV32IndexerPrefillChunkMetadata(
|
||||||
|
cu_seqlen_ks=cu_seqlen_ks,
|
||||||
|
cu_seqlen_ke=cu_seqlen_ke,
|
||||||
|
cu_seq_lens=cu_seq_lens,
|
||||||
|
total_seq_lens=total_seq_lens,
|
||||||
|
block_table=block_table[reqs_start:reqs_end],
|
||||||
|
token_start=token_start,
|
||||||
|
token_end=token_end,
|
||||||
|
num_reqs=reqs_end - reqs_start,
|
||||||
|
)
|
||||||
|
|
||||||
def build(self,
|
def build(self,
|
||||||
common_prefix_len: int,
|
common_prefix_len: int,
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
@@ -209,11 +269,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
|||||||
num_reqs = common_attn_metadata.num_reqs
|
num_reqs = common_attn_metadata.num_reqs
|
||||||
num_tokens = common_attn_metadata.num_actual_tokens
|
num_tokens = common_attn_metadata.num_actual_tokens
|
||||||
|
|
||||||
device = self.device
|
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
|
||||||
|
|
||||||
query_start_loc = common_attn_metadata.query_start_loc
|
|
||||||
|
|
||||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
|
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
|
||||||
split_decodes_and_prefills(
|
split_decodes_and_prefills(
|
||||||
common_attn_metadata,
|
common_attn_metadata,
|
||||||
@@ -224,27 +280,20 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
|
|||||||
|
|
||||||
prefill_metadata = None
|
prefill_metadata = None
|
||||||
if num_prefills > 0:
|
if num_prefills > 0:
|
||||||
reqs_start = num_decodes
|
chunk_seq_ids = split_prefill_chunks(
|
||||||
prefill_query_start_loc = query_start_loc[
|
common_attn_metadata.seq_lens_cpu,
|
||||||
reqs_start:] - query_start_loc[reqs_start]
|
self.max_prefill_buffer_size,
|
||||||
cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
|
num_decodes,
|
||||||
prefill_query_start_loc,
|
|
||||||
common_attn_metadata.seq_lens[reqs_start:])
|
|
||||||
total_seq_lens = common_attn_metadata.seq_lens[reqs_start:].sum()
|
|
||||||
assert total_seq_lens < self.max_prefill_buffer_size
|
|
||||||
cu_seq_lens = torch.cat([
|
|
||||||
torch.zeros(1, dtype=torch.int32, device=device),
|
|
||||||
common_attn_metadata.seq_lens[reqs_start:].cumsum(dim=0)
|
|
||||||
]).to(torch.int32).cuda()
|
|
||||||
prefill_metadata = DeepseekV32IndexerPrefillMetadata(
|
|
||||||
block_table=block_table_tensor[reqs_start:, ...],
|
|
||||||
query_start_loc=prefill_query_start_loc,
|
|
||||||
max_query_len=common_attn_metadata.max_query_len,
|
|
||||||
cu_seqlen_ks=cu_seqlen_ks,
|
|
||||||
cu_seqlen_ke=cu_seqlen_ke,
|
|
||||||
cu_seq_lens=cu_seq_lens,
|
|
||||||
total_seq_lens=total_seq_lens,
|
|
||||||
)
|
)
|
||||||
|
chunks = [
|
||||||
|
self.build_one_prefill_chunk(
|
||||||
|
reqs_start, reqs_end, query_start_loc_cpu,
|
||||||
|
common_attn_metadata.seq_lens_cpu,
|
||||||
|
common_attn_metadata.block_table_tensor)
|
||||||
|
for reqs_start, reqs_end in chunk_seq_ids
|
||||||
|
]
|
||||||
|
prefill_metadata = DeepseekV32IndexerPrefillMetadata(
|
||||||
|
chunks=chunks, )
|
||||||
|
|
||||||
decode_metadata = None
|
decode_metadata = None
|
||||||
if num_decodes > 0:
|
if num_decodes > 0:
|
||||||
|
|||||||
Reference in New Issue
Block a user