diff --git a/Dockerfile b/Dockerfile index 23905ed0..3f05472d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -40,8 +40,11 @@ COPY vllm/patches/deepseek_v4_attention.py ${VLLM_LAYERS_DIR}/deepseek_v4_attent COPY vllm/patches/layers/deepseek_compressor.py ${VLLM_LAYERS_DIR}/deepseek_compressor.py # Replace MHC TileLang kernels with pure PyTorch (avoids TileLang JIT on Blackwell) -# Install our torch op implementations and patch the MHC kernels __init__ +# 1. Patch layers/mhc.py — CustomOp dispatch uses torch impls instead of tilelang +# 2. Install our torch op registrations (mhc_torch_ops.py) +# 3. Patch kernels/mhc/__init__.py to not import tilelang ARG VLLM_MHC_KERNELS_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/kernels/mhc +COPY vllm/patches/layers/mhc.py ${VLLM_LAYERS_DIR}/mhc.py COPY vllm/patches/kernels/mhc_torch_ops.py ${VLLM_MHC_KERNELS_DIR}/mhc_torch_ops.py RUN echo 'from .torch import *' > ${VLLM_MHC_KERNELS_DIR}/__init__.py && \ echo 'from .mhc_torch_ops import *' >> ${VLLM_MHC_KERNELS_DIR}/__init__.py diff --git a/vllm/patches/layers/mhc.py b/vllm/patches/layers/mhc.py new file mode 100644 index 00000000..b89381c9 --- /dev/null +++ b/vllm/patches/layers/mhc.py @@ -0,0 +1,102 @@ +# SPDX-License-Identifier: Apache-2.0 +# Patched MHC layer — replaces TileLang kernels with pure PyTorch. +# This avoids TileLang JIT compilation on Blackwell (SM100). + +import torch + +from vllm.model_executor.custom_op import CustomOp + +# Import our torch implementations (registers torch.ops.vllm.mhc_pre, etc.) +import vllm.model_executor.kernels.mhc.mhc_torch_ops as _mhc_torch # noqa: F401 +# Also import the original torch impls (mhc_pre_torch, mhc_post_torch) +import vllm.model_executor.kernels.mhc.torch as mhc_kernels # noqa: F401 + + +@CustomOp.register("mhc_pre") +class MHCPreOp(CustomOp): + @classmethod + def enabled(cls) -> bool: + return True + + def forward_cuda(self, residual, fn, hc_scale, hc_base, + rms_eps, hc_pre_eps, hc_sinkhorn_eps, + hc_post_mult_value, sinkhorn_repeat, n_splits=1): + return _mhc_torch.mhc_pre( + residual, fn, hc_scale, hc_base, + rms_eps, hc_pre_eps, hc_sinkhorn_eps, + hc_post_mult_value, sinkhorn_repeat, n_splits, + ) + + def forward_hip(self, *args, **kwargs): + return self.forward_cuda(*args, **kwargs) + + def forward_native(self, *args, **kwargs): + return self.forward_cuda(*args, **kwargs) + + +@CustomOp.register("mhc_post") +class MHCPostOp(CustomOp): + @classmethod + def enabled(cls) -> bool: + return True + + def forward_cuda(self, x, residual, post_layer_mix, comb_res_mix): + return _mhc_torch.mhc_post(x, residual, post_layer_mix, comb_res_mix) + + def forward_hip(self, *args, **kwargs): + return self.forward_cuda(*args, **kwargs) + + def forward_native(self, *args, **kwargs): + return self.forward_cuda(*args, **kwargs) + + +@CustomOp.register("hc_head") +class HCHeadOp(CustomOp): + @classmethod + def enabled(cls) -> bool: + return True + + def forward_cuda(self, hidden_states, hc_fn, hc_scale, hc_base, + rms_norm_eps, hc_eps): + hc_mult, hidden_size = hidden_states.shape[-2:] + outer_shape = hidden_states.shape[:-2] + hs_flat = hidden_states.view(-1, hc_mult, hidden_size) + out = torch.empty( + hs_flat.shape[0], hidden_size, + dtype=torch.bfloat16, device=hidden_states.device, + ) + _mhc_torch.hc_head_fused_kernel( + hs_flat, hc_fn, hc_scale, hc_base, + out, hidden_size, rms_norm_eps, hc_eps, hc_mult, + ) + return out.view(*outer_shape, hidden_size) + + def forward_hip(self, *args, **kwargs): + return self.forward_cuda(*args, **kwargs) + + def forward_native(self, *args, **kwargs): + return self.forward_cuda(*args, **kwargs) + + +@CustomOp.register("mhc_fused_post_pre") +class MHCFusedPostPreOp(CustomOp): + @classmethod + def enabled(cls) -> bool: + return True + + def forward_cuda(self, x, residual, post_layer_mix, comb_res_mix, + fn, hc_scale, hc_base, rms_eps, hc_pre_eps, + hc_sinkhorn_eps, hc_post_mult_value, sinkhorn_repeat, + n_splits=1, tile_n=1): + return _mhc_torch.mhc_fused_post_pre( + x, residual, post_layer_mix, comb_res_mix, + fn, hc_scale, hc_base, + rms_eps, hc_pre_eps, hc_sinkhorn_eps, + hc_post_mult_value, sinkhorn_repeat, n_splits, tile_n, + ) + + def forward_hip(self, *args, **kwargs): + return self.forward_cuda(*args, **kwargs) + + def forward_native(self, *args, **kwargs): + return self.forward_cuda(*args, **kwargs)