# ==============================================================================
# Triton Kernels Build (TFA) - vLLM v0.19.0 + triton_kernels
# ==============================================================================
# This branch adds triton_kernels from Triton v3.6.0 for MoE support.
#
# Based on working Build #43 (v0.18.2rc0) with vLLM upgraded to v0.19.0:
#   - vLLM: v0.19.0
#   - flashinfer: v0.6.6
#   - flash-attention: hopper branch
#   - lmcache: dev branch
#   - infinistore: main
#   - triton: 3.6.0 (PyPI wheel)
#   - triton_kernels: v3.6.0 (from Triton repo)
#   - Base: nvcr.io/nvidia/pytorch:26.03-py3 (PyTorch 2.11.0a0, CUDA 13.2.0)
#
# HARD RULES:
#   1. NO DOWNGRADES - CUDA 13+, PyTorch 2.9+, vLLM 0.18.1+
#   2. NO SKIPPING COMPILATION - Build from source
#   3. CLEAR ALL CHANGES WITH MIKE BEFORE MAKING THEM
#   4. ONE BUILD AT A TIME - Mike reports failure → I assess → I report
#
# Image tag: gh200-vllm-tfa:v0.19.0-tfa
# ==============================================================================
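
# Example build invocation (a sketch; the platform flag and context path are
# assumptions for a GH200/aarch64 target):
#   docker buildx build --platform linux/arm64 -t gh200-vllm-tfa:v0.19.0-tfa .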

# ---------- Builder Base ----------
# Using NVIDIA NGC PyTorch container (26.03) with:
# - PyTorch 2.11.0a0 (bleeding edge)
# - CUDA 13.2.0
# - cuDNN 9.20, NCCL 2.29.7, TensorRT 10.16, TransformerEngine 2.13
# - Multi-arch: x86 + ARM SBSA (GH200 support)
FROM nvcr.io/nvidia/pytorch:26.03-py3 AS base

# Set CUDA arch lists for all build targets
# The 'a' (arch-specific) suffix is not forward compatible but enables all
# Hopper-specific optimizations
ARG TORCH_CUDA_ARCH_LIST="9.0a"
ENV TORCH_CUDA_ARCH_LIST=${TORCH_CUDA_ARCH_LIST}
ARG VLLM_FA_CMAKE_GPU_ARCHES="90a-real"
ENV VLLM_FA_CMAKE_GPU_ARCHES=${VLLM_FA_CMAKE_GPU_ARCHES}
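
# Both args can be overridden at build time when targeting a different GPU
# generation (a sketch; the defaults above assume Hopper/GH200):
#   docker build --build-arg TORCH_CUDA_ARCH_LIST="9.0a" \
#                --build-arg VLLM_FA_CMAKE_GPU_ARCHES="90a-real" ...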

# Install additional build dependencies
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && apt-get install -y --no-install-recommends \
        curl \
        git \
        libibverbs-dev \
        zlib1g-dev \
        libnuma-dev \
        wget \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives

# Set compiler paths
ENV CC=/usr/bin/gcc
ENV CXX=/usr/bin/g++

# Install uv for faster package management
RUN curl -LsSf https://astral.sh/uv/install.sh | env UV_INSTALL_DIR=/usr/local/bin sh

# Setup build workspace
WORKDIR /workspace

# Environment setup (PyTorch container already has CUDA paths set)
ENV CUDA_HOME=/usr/local/cuda
ENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
ENV CPLUS_INCLUDE_PATH=${CUDA_HOME}/include/cccl
ENV C_INCLUDE_PATH=${CUDA_HOME}/include/cccl
ENV PATH=${CUDA_HOME}/bin:${PATH}

# Use the Python environment from the container
# The NGC container already has a working Python/PyTorch setup

FROM base AS build-base
RUN mkdir /wheels

# Install build deps that aren't in project requirements files
# Pin setuptools to <81 for LMCache compatibility (needs >=77.0.3,<81.0.0)
# Note: wheel is already installed in NGC container, don't try to upgrade it
RUN pip install -U build cmake ninja pybind11 "setuptools>=77.0.3,<81.0.0"

# Use PyPI triton wheel instead of building (QEMU segfaults during triton build)
FROM build-base AS build-triton
RUN mkdir -p /wheels && \
    pip download triton==3.6.0 --platform manylinux_2_27_aarch64 --only-binary=:all: --no-deps -d /wheels

# Install triton_kernels from Triton repo (v3.6.0) for MoE support
# vLLM v0.19.0 requires this for triton_kernels.matmul_ogs module
FROM build-base AS build-triton-kernels
RUN pip install --target=/wheels git+https://github.com/triton-lang/triton.git@v3.6.0#subdirectory=python/triton_kernels
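
# Optional sanity check that the package is importable (a sketch; assumes the
# base image's bundled Triton satisfies the import):
# RUN PYTHONPATH=/wheels python -c "import triton_kernels.matmul_ogs"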

# Skip xformers - vLLM has built-in FlashAttention kernels
# xformers requires TORCH_STABLE_ONLY which needs PyTorch headers not in 2.9.0
# FROM build-base AS build-xformers
# RUN git clone https://github.com/facebookresearch/xformers.git
# RUN cd xformers && \
#     git submodule sync && \
#     git submodule update --init --recursive -j 8 && \
#     MAX_JOBS=8 python -m build --wheel --no-isolation -o /wheels

FROM build-base AS build-flashinfer
ARG FLASHINFER_ENABLE_AOT=1
# flashinfer version compatibility:
#   - v0.6.7 works with vLLM v0.18.2rc0 (Build #43)
#   - v0.6.6 works with vLLM v0.19.0 (for Gemma 4 support)
# ARG FLASHINFER_REF=v0.6.7  # For vLLM v0.18.2rc0
ARG FLASHINFER_REF=v0.6.6
ARG FLASHINFER_BUILD_SUFFIX=cu132
ENV FLASHINFER_LOCAL_VERSION=${FLASHINFER_BUILD_SUFFIX:-}
RUN git clone https://github.com/flashinfer-ai/flashinfer.git
RUN pip install "apache-tvm-ffi>=0.1.6,<0.2,!=0.1.8,!=0.1.8.post0"
RUN cd flashinfer && \
    git checkout ${FLASHINFER_REF} && \
    git submodule sync && \
    git submodule update --init --recursive -j 8 && \
    python -m build --wheel --no-isolation -o /wheels
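
# FLASHINFER_REF can be switched at build time without editing this file,
# subject to the compatibility notes above (a sketch):
#   docker build --build-arg FLASHINFER_REF=v0.6.7 ...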

FROM build-base AS build-lmcache
# Bleeding edge: build from dev branch (v0.4.2+)
RUN git clone https://github.com/LMCache/LMCache.git && \
    cd LMCache && \
    git checkout dev && \
    echo "\n\n========================================" && \
    echo ">>> BUILDING LMCACHE FROM:" && \
    echo ">>> BRANCH: $(git rev-parse --abbrev-ref HEAD)" && \
    echo ">>> COMMIT: $(git rev-parse HEAD)" && \
    echo ">>> DATE:   $(git log -1 --format=%cd --date=short)" && \
    echo "========================================\n\n" && \
    sed -i '/torch/d' pyproject.toml && \
    pip install setuptools_scm && \
    MAX_JOBS=8 python -m build --wheel --no-isolation && \
    cp dist/*.whl /wheels/


FROM build-base AS build-flash-attention
RUN apt-get update && apt-get install -y build-essential cmake gcc && \
    git clone https://github.com/Dao-AILab/flash-attention flash-attention && \
    cd flash-attention/hopper && \
    mkdir wheels && \
    export MAX_JOBS=8 && \
    export NVCC_THREADS=4 && \
    export CMAKE_BUILD_PARALLEL_LEVEL=$MAX_JOBS && \
    MAX_JOBS=$MAX_JOBS \
    CMAKE_BUILD_PARALLEL_LEVEL=$MAX_JOBS \
    FLASH_ATTENTION_FORCE_BUILD="TRUE" \
    FLASH_ATTENTION_FORCE_CXX11_ABI="FALSE" \
    FLASH_ATTENTION_SKIP_CUDA_BUILD="FALSE" \
    pip wheel . -v --no-deps --no-build-isolation -w ./wheels/ && \
    cp wheels/*.whl /wheels/

# ==============================================================================
# Build vLLM from source
# ==============================================================================
FROM build-base AS build-vllm
# vLLM version to build
ARG VLLM_REF=v0.19.0
# Install ccache for faster compilation
RUN apt-get update && apt-get install -y ccache
RUN git clone https://github.com/vllm-project/vllm.git
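# The sed patches below retarget the pinned vllm_flash_attn commit to main and
# drop the hoist=True kwarg from register_opaque_type, presumably to accommodate
# this container's PyTorch/Triton versions.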
RUN cd vllm && \
    git checkout ${VLLM_REF} && \
    echo "\n\n========================================" && \
    echo ">>> BUILDING VLLM FROM:" && \
    echo ">>> BRANCH: $(git rev-parse --abbrev-ref HEAD)" && \
    echo ">>> COMMIT: $(git rev-parse HEAD)" && \
    echo ">>> DATE:   $(git log -1 --format=%cd --date=short)" && \
    echo ">>> TAG:    $(git describe --tags --always 2>/dev/null || echo 'no tag')" && \
    echo "========================================\n\n" && \
    git submodule sync && \
    git submodule update --init --recursive -j 8 && \
    sed -i 's/GIT_TAG [a-f0-9]\{40\}/GIT_TAG main/' cmake/external_projects/vllm_flash_attn.cmake && \
    sed -i 's/register_opaque_type(ModuleName, typ="value", hoist=True)/register_opaque_type(ModuleName, typ="value")/' vllm/utils/torch_utils.py && \
    export MAX_JOBS=8 && \
    export CMAKE_BUILD_PARALLEL_LEVEL=$MAX_JOBS && \
    python use_existing_torch.py && \
    pip install -r requirements/build.txt && \
    CCACHE_NOHASHDIR="true" python -m build --wheel --no-isolation -o /wheels
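
# VLLM_REF can be overridden at build time (e.g. --build-arg VLLM_REF=<tag>).
# Optional check that the wheel landed where the final stage expects it
# (a sketch; uncomment if useful):
# RUN ls -lh /wheels/vllm-*.whl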

# Build infinistore after vllm to avoid cache invalidation
FROM build-base AS build-infinistore
# Install additional dependencies needed for building infinistore on aarch64
RUN apt-get update && apt-get install -y cmake pybind11-dev python3-dev libuv1-dev libspdlog-dev libboost-dev libboost-all-dev meson

# Build flatbuffers from source with proper CMake version
RUN git clone -b v1.12.0 https://github.com/google/flatbuffers.git && \
  cd flatbuffers && \
  cmake -B build -DFLATBUFFERS_BUILD_TESTS=OFF -DCMAKE_POLICY_VERSION_MINIMUM=3.5 && \
  cmake --build build -j && \
  cmake --install build

# Build InfiniStore from source as a Python package
RUN git clone https://github.com/bytedance/InfiniStore && \
    cd InfiniStore && \
    pip install meson && \
    pip install --no-deps --no-build-isolation -e . && \
    pip uninstall -y infinistore && \
    python -m build --wheel --no-isolation && \
    cp dist/*.whl /wheels/

FROM base AS vllm-openai
COPY --from=build-flash-attention /wheels/* wheels/
COPY --from=build-flashinfer /wheels/* wheels/
COPY --from=build-triton /wheels/* wheels/
COPY --from=build-triton-kernels /wheels/triton_kernels /usr/local/lib/python3.12/dist-packages/triton_kernels
COPY --from=build-vllm /wheels/* wheels/
COPY --from=build-lmcache /wheels/* wheels/
COPY --from=build-infinistore /wheels/* wheels/

# Install wheels (infinistore is now built as a wheel)
RUN pip install wheels/*
RUN rm -r wheels
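
# Optional post-install sanity check (a sketch; adjust to taste):
# RUN python -c "import vllm, flashinfer; print(vllm.__version__, flashinfer.__version__)"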

# Install pynvml and pandas
RUN pip install pynvml pandas

# Add additional packages for vLLM OpenAI
# Bleeding edge: latest transformers
RUN pip install accelerate hf_transfer modelscope bitsandbytes timm boto3 runai-model-streamer runai-model-streamer[s3] tensorizer transformers --upgrade

# Clean pip cache
RUN pip cache purge || true

# Install build tools and dependencies
RUN pip install -U build cmake ninja pybind11 setuptools==79.0.1

# Enable hf-transfer
ENV HF_HUB_ENABLE_HF_TRANSFER=1
RUN pip install datasets aiohttp

# Install nsys for profiling
ARG NSYS_URL=https://developer.nvidia.com/downloads/assets/tools/secure/nsight-systems/2025_5/
ARG NSYS_PKG=nsight-systems-cli-2025.5.1_2025.5.1.121-1_arm64.deb
RUN apt-get update && apt-get install -y wget libglib2.0-0
RUN wget ${NSYS_URL}${NSYS_PKG} && dpkg -i $NSYS_PKG && rm $NSYS_PKG
RUN apt-get install -y --no-install-recommends tmux cmake
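
# Typical profiling invocation inside the container (a sketch; the output path
# and the command under profile are placeholders):
#   nsys profile -o /tmp/vllm-trace vllm serve <model>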

# pynvml is deprecated; replace it with its successor nvidia-ml-py
RUN pip uninstall -y pynvml && pip install nvidia-ml-py

# Copy over the nemotron reasoning parser
COPY ./super_v3_reasoning_parser.py /opt/super_v3_reasoning_parser.py

# Copy vLLM shim that intercepts --model to download custom weights from URLs
COPY vllm_shim_module.py /opt/vllm-shim/vllm_shim_module.py

# Shadow `python -m vllm.*` invocations via PYTHONPATH
# The shim masquerades as the vllm package so `python -m vllm` and
# `python -m vllm.entrypoints.openai.api_server` hit our interceptor first,
# which downloads the weights and then execs the real vLLM
RUN mkdir -p /opt/vllm-shim/vllm/entrypoints/openai \
             /opt/vllm-shim/vllm/entrypoints/cli && \
    cp /opt/vllm-shim/vllm_shim_module.py /opt/vllm-shim/vllm/__main__.py && \
    cp /opt/vllm-shim/vllm_shim_module.py /opt/vllm-shim/vllm/entrypoints/openai/api_server.py && \
    cp /opt/vllm-shim/vllm_shim_module.py /opt/vllm-shim/vllm/entrypoints/cli/main.py && \
    touch /opt/vllm-shim/vllm/__init__.py \
          /opt/vllm-shim/vllm/entrypoints/__init__.py \
          /opt/vllm-shim/vllm/entrypoints/openai/__init__.py \
          /opt/vllm-shim/vllm/entrypoints/cli/__init__.py

ENV PYTHONPATH=/opt/vllm-shim
ENV PYTHONUNBUFFERED=1
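
# With PYTHONPATH pointing at the shim, an invocation like
#   python -m vllm.entrypoints.openai.api_server --model https://example.com/weights.tar.gz
# resolves to the shim copy of api_server.py first (a sketch; the URL is a
# placeholder, and the download-then-exec logic lives in vllm_shim_module.py).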

# API server entrypoint
# ENTRYPOINT ["vllm", "serve"]
# CMD ["/bin/bash"]
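
# Example run (a sketch; GPUs, port, and model are placeholders):
#   docker run --gpus all -p 8000:8000 gh200-vllm-tfa:v0.19.0-tfa \
#       vllm serve meta-llama/Llama-3.1-8B-Instruct --port 8000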