Fix MHC import: don't import .torch from layers/mhc.py

The layers/mhc.py was trying to import kernels.mhc.torch which
failed because our __init__.py was breaking the package. Instead,
just import our mhc_torch_ops which has everything we need.

Also fix __init__.py to explicitly import mhc_pre_torch and
mhc_post_torch from .torch instead of using import *.
This commit is contained in:
2026-05-19 05:36:35 +00:00
parent e404e18efb
commit dfd9c10ae9
2 changed files with 3 additions and 5 deletions

View File

@@ -42,12 +42,12 @@ COPY vllm/patches/layers/deepseek_compressor.py ${VLLM_LAYERS_DIR}/deepseek_comp
# Replace MHC TileLang kernels with pure PyTorch (avoids TileLang JIT on Blackwell)
# 1. Patch layers/mhc.py — CustomOp dispatch uses torch impls instead of tilelang
# 2. Install our torch op registrations (mhc_torch_ops.py)
# 3. Patch kernels/mhc/__init__.py to not import tilelang
# 3. Patch kernels/mhc/__init__.py to not import tilelang/aiter
ARG VLLM_MHC_KERNELS_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/kernels/mhc
COPY vllm/patches/layers/mhc.py ${VLLM_LAYERS_DIR}/mhc.py
COPY vllm/patches/kernels/mhc_torch_ops.py ${VLLM_MHC_KERNELS_DIR}/mhc_torch_ops.py
RUN echo 'from .torch import *' > ${VLLM_MHC_KERNELS_DIR}/__init__.py && \
echo 'from .mhc_torch_ops import *' >> ${VLLM_MHC_KERNELS_DIR}/__init__.py
# Rewrite __init__.py: import torch impls + our custom ops, skip tilelang/aiter
RUN printf 'from .torch import mhc_pre_torch, mhc_post_torch\nfrom .mhc_torch_ops import *\n' > ${VLLM_MHC_KERNELS_DIR}/__init__.py
# CuTeDSL NVFP4 linear kernel (registered as NvFp4LinearKernel)
ARG VLLM_NVFP4_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/kernels/linear/nvfp4

View File

@@ -8,8 +8,6 @@ from vllm.model_executor.custom_op import CustomOp
# Import our torch implementations (registers torch.ops.vllm.mhc_pre, etc.)
import vllm.model_executor.kernels.mhc.mhc_torch_ops as _mhc_torch # noqa: F401
# Also import the original torch impls (mhc_pre_torch, mhc_post_torch)
import vllm.model_executor.kernels.mhc.torch as mhc_kernels # noqa: F401
@CustomOp.register("mhc_pre")