# Base image and pinned source revisions for every component built below.
# Each *_BRANCH value is passed to `git checkout` after cloning the matching
# *_REPO, so it may be a branch name, a tag (e.g. v0.24.1) or a commit hash.
ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:7.0-complete
ARG TRITON_BRANCH="57c693b6"
ARG TRITON_REPO="https://github.com/ROCm/triton.git"
ARG PYTORCH_BRANCH="89075173"
ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"
ARG PYTORCH_VISION_BRANCH="v0.24.1"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG PYTORCH_AUDIO_BRANCH="v2.9.0"
ARG PYTORCH_AUDIO_REPO="https://github.com/pytorch/audio.git"
ARG FA_BRANCH="0e60e394"
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
ARG AITER_BRANCH="6af8b687"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
ARG MORI_BRANCH="2d02c6a9"
ARG MORI_REPO="https://github.com/ROCm/mori.git"

# TODO: When the patch has been upstreamed, switch to the main repo/branch:
# ARG RIXL_BRANCH=""
# ARG RIXL_REPO="https://github.com/ROCm/RIXL.git"
ARG RIXL_BRANCH="50d63d94"
ARG RIXL_REPO="https://github.com/vcave/RIXL.git"
# Needed by RIXL
ARG ETCD_BRANCH="7c6e714f"
ARG ETCD_REPO="https://github.com/etcd-cpp-apiv3/etcd-cpp-apiv3.git"
ARG UCX_BRANCH="da3fac2a"
ARG UCX_REPO="https://github.com/ROCm/ucx.git"

###
### Base: ROCm image + Python toolchain shared by every build stage
###
FROM ${BASE_IMAGE} AS base

ENV PATH=/opt/rocm/llvm/bin:/opt/rocm/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
ENV ROCM_PATH=/opt/rocm
# No trailing ':' — an empty LD_LIBRARY_PATH entry makes the dynamic loader
# search the current working directory.
ENV LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib
# GPU targets for PyTorch and friends (';'-separated list).
ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942;gfx950;gfx1100;gfx1101;gfx1200;gfx1201;gfx1150;gfx1151
ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
# AITER and MORI are built only for this narrower set of targets.
ENV AITER_ROCM_ARCH=gfx942;gfx950
ENV MORI_GPU_ARCHS=gfx942;gfx950

# Required for RCCL in ROCm 7.1.
# NOTE(review): BASE_IMAGE defaults to ROCm 7.0 — confirm this is intended there.
ENV HSA_NO_SCRATCH_RECLAIM=1

ARG PYTHON_VERSION=3.12
ENV PYTHON_VERSION=${PYTHON_VERSION}

RUN mkdir -p /app
WORKDIR /app
ENV DEBIAN_FRONTEND=noninteractive

# Install Python (from the deadsnakes PPA) and other system dependencies.
# Adding the PPA is retried up to 3 times; after the last failed attempt the
# build aborts here instead of failing later with a confusing
# "Unable to locate package python${PYTHON_VERSION}" error.
RUN apt-get update -y \
    && apt-get install -y software-properties-common git curl sudo vim less libgfortran5 libopenmpi-dev libpci-dev \
    && for i in 1 2 3; do \
        if add-apt-repository -y ppa:deadsnakes/ppa; then break; fi; \
        if [ "$i" -eq 3 ]; then echo "Failed to add ppa:deadsnakes/ppa after $i attempts" >&2; exit 1; fi; \
        echo "Attempt $i failed, retrying in 5s..."; \
        sleep 5; \
    done \
    && apt-get update -y \
    && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
        python${PYTHON_VERSION}-lib2to3 python-is-python3 \
    && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
    && update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
    && ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
    && curl -fsS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \
    && python3 --version && python3 -m pip --version

RUN pip install -U packaging 'cmake<4' ninja wheel 'setuptools<80' pybind11 Cython
RUN apt-get update && apt-get install -y libjpeg-dev libsox-dev libsox-fmt-all sox && rm -rf /var/lib/apt/lists/*

###
### Triton Build
###
FROM base AS build_triton
ARG TRITON_BRANCH
ARG TRITON_REPO
RUN git clone ${TRITON_REPO}
# Depending on the Triton revision, setup.py is either at the repo root or
# under python/.
RUN cd triton \
    && git checkout ${TRITON_BRANCH} \
    && if [ ! -f setup.py ]; then cd python; fi \
    && python3 setup.py bdist_wheel --dist-dir=dist \
    && mkdir -p /app/install && cp dist/*.whl /app/install
# triton_kernels is only packaged when the checked-out revision ships it.
RUN if [ -d triton/python/triton_kernels ]; then pip install build && cd triton/python/triton_kernels \
    && python3 -m build --wheel && cp dist/*.whl /app/install; fi

###
### AMD SMI Build
###
FROM base AS build_amdsmi
RUN cd /opt/rocm/share/amd_smi \
    && pip wheel . --wheel-dir=dist
RUN mkdir -p /app/install && cp /opt/rocm/share/amd_smi/dist/*.whl /app/install

###
### Pytorch build
###
FROM base AS build_pytorch
ARG PYTORCH_BRANCH
ARG PYTORCH_VISION_BRANCH
ARG PYTORCH_AUDIO_BRANCH
ARG PYTORCH_REPO
ARG PYTORCH_VISION_REPO
ARG PYTORCH_AUDIO_REPO

# torch is built first and installed so vision/audio build against it.
RUN git clone ${PYTORCH_REPO} pytorch
RUN cd pytorch && git checkout ${PYTORCH_BRANCH} \
    && pip install -r requirements.txt && git submodule update --init --recursive \
    && python3 tools/amd_build/build_amd.py \
    && CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist \
    && pip install dist/*.whl
RUN git clone ${PYTORCH_VISION_REPO} vision
RUN cd vision && git checkout ${PYTORCH_VISION_BRANCH} \
    && python3 setup.py bdist_wheel --dist-dir=dist \
    && pip install dist/*.whl
RUN git clone ${PYTORCH_AUDIO_REPO} audio
RUN cd audio && git checkout ${PYTORCH_AUDIO_BRANCH} \
    && git submodule update --init --recursive \
    && pip install -r requirements.txt \
    && python3 setup.py bdist_wheel --dist-dir=dist \
    && pip install dist/*.whl
RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \
    && cp /app/vision/dist/*.whl /app/install \
    && cp /app/audio/dist/*.whl /app/install

###
### MORI Build
###
FROM base AS build_mori
ARG MORI_BRANCH
ARG MORI_REPO
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
    pip install /install/*.whl
RUN git clone ${MORI_REPO}
RUN cd mori \
    && git checkout ${MORI_BRANCH} \
    && git submodule update --init --recursive \
    && python3 setup.py bdist_wheel --dist-dir=dist && ls /app/mori/dist/*.whl
RUN mkdir -p /app/install && cp /app/mori/dist/*.whl /app/install

###
### RIXL Build
###
# Built on top of build_pytorch, so /app/install already holds the PyTorch
# wheels; the RIXL wheel is added next to them.
FROM build_pytorch AS build_rixl
ARG RIXL_BRANCH
ARG RIXL_REPO
ARG ETCD_BRANCH
ARG ETCD_REPO
ARG UCX_BRANCH
ARG UCX_REPO

ENV ROCM_PATH=/opt/rocm
ENV UCX_HOME=/usr/local/ucx
ENV RIXL_HOME=/usr/local/rixl
ENV RIXL_BENCH_HOME=/usr/local/rixl_bench

# RIXL build system dependencies and RDMA support
RUN apt-get -y update && apt-get -y install autoconf libtool pkg-config \
    libgrpc-dev \
    libgrpc++-dev \
    libprotobuf-dev \
    protobuf-compiler-grpc \
    libcpprest-dev \
    libaio-dev \
    librdmacm1 \
    librdmacm-dev \
    libibverbs1 \
    libibverbs-dev \
    ibverbs-utils \
    rdmacm-utils \
    ibverbs-providers

RUN pip install meson auditwheel patchelf tomlkit

WORKDIR /workspace

# etcd C++ client library, installed to the default CMake prefix.
RUN git clone ${ETCD_REPO} && \
    cd etcd-cpp-apiv3 && \
    git checkout ${ETCD_BRANCH} && \
    mkdir build && cd build && \
    cmake .. -DCMAKE_POLICY_VERSION_MINIMUM=3.5 && \
    make -j$(nproc) && \
    make install

# UCX with ROCm and verbs support, installed under ${UCX_HOME}.
# Parallelism is bounded by the CPU count: a bare `make -j` spawns unlimited
# jobs and can exhaust memory on large machines.
RUN cd /usr/local/src && \
    git clone ${UCX_REPO} && \
    cd ucx && \
    git checkout ${UCX_BRANCH} && \
    ./autogen.sh && \
    mkdir build && cd build && \
    ../configure \
        --prefix=${UCX_HOME} \
        --enable-shared \
        --disable-static \
        --disable-doxygen-doc \
        --enable-optimizations \
        --enable-devel-headers \
        --with-rocm=/opt/rocm \
        --with-verbs \
        --with-dm \
        --enable-mt && \
    make -j$(nproc) && \
    make -j$(nproc) install

ENV PATH=${UCX_HOME}/bin:$PATH
ENV LD_LIBRARY_PATH=${UCX_HOME}/lib:${LD_LIBRARY_PATH}

RUN git clone ${RIXL_REPO} /opt/rixl && \
    cd /opt/rixl && \
    git checkout ${RIXL_BRANCH} && \
    meson setup build --prefix=${RIXL_HOME} \
        -Ducx_path=${UCX_HOME} \
        -Drocm_path=${ROCM_PATH} && \
    cd build && \
    ninja && \
    ninja install

# Generate RIXL wheel
RUN cd /opt/rixl && mkdir -p /app/install && \
    ./contrib/build-wheel.sh \
        --output-dir /app/install \
        --rocm-dir ${ROCM_PATH} \
        --ucx-plugins-dir ${UCX_HOME}/lib/ucx \
        --nixl-plugins-dir ${RIXL_HOME}/lib/x86_64-linux-gnu/plugins

###
### FlashAttention Build
###
FROM base AS build_fa
ARG FA_BRANCH
ARG FA_REPO
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
    pip install /install/*.whl
RUN git clone ${FA_REPO}
# FlashAttention is built for every PYTORCH_ROCM_ARCH target except gfx1xxx
# ones. The list is filtered token by token, so a gfx1xxx entry is dropped
# wherever it appears (including first in the list).
RUN cd flash-attention \
    && git checkout ${FA_BRANCH} \
    && git submodule update --init \
    && GPU_ARCHS=$(printf '%s\n' "${PYTORCH_ROCM_ARCH}" | tr ';' '\n' | grep -v '^gfx1[0-9]\{3\}' | paste -sd ';') python3 setup.py bdist_wheel --dist-dir=dist
RUN mkdir -p /app/install && cp /app/flash-attention/dist/*.whl /app/install
###
### AITER Build
###
FROM base AS build_aiter
ARG AITER_BRANCH
ARG AITER_REPO
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
    pip install /install/*.whl
# A plain clone is enough: submodules are initialised recursively after
# AITER_BRANCH is checked out. A recursive clone would first fetch the
# default branch's submodules, which is redundant.
RUN git clone ${AITER_REPO}
RUN cd aiter \
    && git checkout ${AITER_BRANCH} \
    && git submodule update --init --recursive \
    && pip install -r requirements.txt
RUN pip install pyyaml && cd aiter && PREBUILD_KERNELS=1 GPU_ARCHS=${AITER_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist && ls /app/aiter/dist/*.whl
RUN mkdir -p /app/install && cp /app/aiter/dist/*.whl /app/install

###
### Final Build
###
# Gathers every stage's /app/install contents into /app/debs. Despite the
# stage/directory name, these are Python wheels (*.whl), not .deb packages.
# The name is kept so existing `--target debs` invocations keep working.
# Copies are done one stage per RUN: build_rixl also carries the PyTorch
# wheels, and the later copy simply overwrites the identical files.
FROM base AS debs
RUN mkdir /app/debs
RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \
    cp /install/*.whl /app/debs
RUN --mount=type=bind,from=build_fa,src=/app/install/,target=/install \
    cp /install/*.whl /app/debs
RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
    cp /install/*.whl /app/debs
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
    cp /install/*.whl /app/debs
RUN --mount=type=bind,from=build_aiter,src=/app/install/,target=/install \
    cp /install/*.whl /app/debs
RUN --mount=type=bind,from=build_mori,src=/app/install/,target=/install \
    cp /install/*.whl /app/debs
RUN --mount=type=bind,from=build_rixl,src=/app/install/,target=/install \
    cp /install/*.whl /app/debs

# Final image: the base toolchain plus all wheels built above.
FROM base AS final
RUN --mount=type=bind,from=debs,src=/app/debs,target=/install \
    pip install /install/*.whl

# ARGs declared before the first FROM are not visible inside a stage unless
# re-declared; they are re-declared here so their values can be recorded.
ARG BASE_IMAGE
ARG TRITON_BRANCH
ARG TRITON_REPO
ARG PYTORCH_BRANCH
ARG PYTORCH_VISION_BRANCH
ARG PYTORCH_REPO
ARG PYTORCH_VISION_REPO
ARG PYTORCH_AUDIO_BRANCH
ARG PYTORCH_AUDIO_REPO
ARG FA_BRANCH
ARG FA_REPO
ARG AITER_BRANCH
ARG AITER_REPO
ARG RIXL_BRANCH
ARG RIXL_REPO
ARG ETCD_BRANCH
ARG ETCD_REPO
ARG UCX_BRANCH
ARG UCX_REPO
ARG MORI_BRANCH
ARG MORI_REPO

# Record the component sources/revisions baked into this image, one
# "NAME: value" line each. printf '%s\n' writes values verbatim, whereas
# /bin/sh's echo may interpret backslash escapes.
RUN printf '%s\n' \
        "BASE_IMAGE: ${BASE_IMAGE}" \
        "TRITON_BRANCH: ${TRITON_BRANCH}" \
        "TRITON_REPO: ${TRITON_REPO}" \
        "PYTORCH_BRANCH: ${PYTORCH_BRANCH}" \
        "PYTORCH_VISION_BRANCH: ${PYTORCH_VISION_BRANCH}" \
        "PYTORCH_REPO: ${PYTORCH_REPO}" \
        "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" \
        "PYTORCH_AUDIO_BRANCH: ${PYTORCH_AUDIO_BRANCH}" \
        "PYTORCH_AUDIO_REPO: ${PYTORCH_AUDIO_REPO}" \
        "FA_BRANCH: ${FA_BRANCH}" \
        "FA_REPO: ${FA_REPO}" \
        "AITER_BRANCH: ${AITER_BRANCH}" \
        "AITER_REPO: ${AITER_REPO}" \
        "RIXL_BRANCH: ${RIXL_BRANCH}" \
        "RIXL_REPO: ${RIXL_REPO}" \
        "ETCD_BRANCH: ${ETCD_BRANCH}" \
        "ETCD_REPO: ${ETCD_REPO}" \
        "UCX_BRANCH: ${UCX_BRANCH}" \
        "UCX_REPO: ${UCX_REPO}" \
        "MORI_BRANCH: ${MORI_BRANCH}" \
        "MORI_REPO: ${MORI_REPO}" \
        > /app/versions.txt