Also replace layers/mhc.py CustomOp dispatch

The original layers/mhc.py forward_cuda calls
torch.ops.vllm.mhc_pre_tilelang which triggers TileLang JIT.
Replace with our torch implementations in forward_cuda.
This is what the CustomOp dispatch routes through.
This commit is contained in:
2026-05-19 05:31:05 +00:00
parent 5e6d459145
commit e404e18efb
2 changed files with 106 additions and 1 deletions

View File

@@ -40,8 +40,11 @@ COPY vllm/patches/deepseek_v4_attention.py ${VLLM_LAYERS_DIR}/deepseek_v4_attent
COPY vllm/patches/layers/deepseek_compressor.py ${VLLM_LAYERS_DIR}/deepseek_compressor.py
# Replace MHC TileLang kernels with pure PyTorch (avoids TileLang JIT on Blackwell)
# Install our torch op implementations and patch the MHC kernels __init__
# 1. Patch layers/mhc.py — CustomOp dispatch uses torch impls instead of tilelang
# 2. Install our torch op registrations (mhc_torch_ops.py)
# 3. Patch kernels/mhc/__init__.py to not import tilelang
ARG VLLM_MHC_KERNELS_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/kernels/mhc
COPY vllm/patches/layers/mhc.py ${VLLM_LAYERS_DIR}/mhc.py
COPY vllm/patches/kernels/mhc_torch_ops.py ${VLLM_MHC_KERNELS_DIR}/mhc_torch_ops.py
RUN echo 'from .torch import *' > ${VLLM_MHC_KERNELS_DIR}/__init__.py && \
echo 'from .mhc_torch_ops import *' >> ${VLLM_MHC_KERNELS_DIR}/__init__.py

102
vllm/patches/layers/mhc.py Normal file
View File

@@ -0,0 +1,102 @@
# SPDX-License-Identifier: Apache-2.0
# Patched MHC layer — replaces TileLang kernels with pure PyTorch.
# This avoids TileLang JIT compilation on Blackwell (SM100).
import torch
from vllm.model_executor.custom_op import CustomOp
# Import our torch implementations (registers torch.ops.vllm.mhc_pre, etc.)
import vllm.model_executor.kernels.mhc.mhc_torch_ops as _mhc_torch # noqa: F401
# Also import the original torch impls (mhc_pre_torch, mhc_post_torch)
import vllm.model_executor.kernels.mhc.torch as mhc_kernels # noqa: F401
@CustomOp.register("mhc_pre")
class MHCPreOp(CustomOp):
@classmethod
def enabled(cls) -> bool:
return True
def forward_cuda(self, residual, fn, hc_scale, hc_base,
rms_eps, hc_pre_eps, hc_sinkhorn_eps,
hc_post_mult_value, sinkhorn_repeat, n_splits=1):
return _mhc_torch.mhc_pre(
residual, fn, hc_scale, hc_base,
rms_eps, hc_pre_eps, hc_sinkhorn_eps,
hc_post_mult_value, sinkhorn_repeat, n_splits,
)
def forward_hip(self, *args, **kwargs):
return self.forward_cuda(*args, **kwargs)
def forward_native(self, *args, **kwargs):
return self.forward_cuda(*args, **kwargs)
@CustomOp.register("mhc_post")
class MHCPostOp(CustomOp):
@classmethod
def enabled(cls) -> bool:
return True
def forward_cuda(self, x, residual, post_layer_mix, comb_res_mix):
return _mhc_torch.mhc_post(x, residual, post_layer_mix, comb_res_mix)
def forward_hip(self, *args, **kwargs):
return self.forward_cuda(*args, **kwargs)
def forward_native(self, *args, **kwargs):
return self.forward_cuda(*args, **kwargs)
@CustomOp.register("hc_head")
class HCHeadOp(CustomOp):
@classmethod
def enabled(cls) -> bool:
return True
def forward_cuda(self, hidden_states, hc_fn, hc_scale, hc_base,
rms_norm_eps, hc_eps):
hc_mult, hidden_size = hidden_states.shape[-2:]
outer_shape = hidden_states.shape[:-2]
hs_flat = hidden_states.view(-1, hc_mult, hidden_size)
out = torch.empty(
hs_flat.shape[0], hidden_size,
dtype=torch.bfloat16, device=hidden_states.device,
)
_mhc_torch.hc_head_fused_kernel(
hs_flat, hc_fn, hc_scale, hc_base,
out, hidden_size, rms_norm_eps, hc_eps, hc_mult,
)
return out.view(*outer_shape, hidden_size)
def forward_hip(self, *args, **kwargs):
return self.forward_cuda(*args, **kwargs)
def forward_native(self, *args, **kwargs):
return self.forward_cuda(*args, **kwargs)
@CustomOp.register("mhc_fused_post_pre")
class MHCFusedPostPreOp(CustomOp):
@classmethod
def enabled(cls) -> bool:
return True
def forward_cuda(self, x, residual, post_layer_mix, comb_res_mix,
fn, hc_scale, hc_base, rms_eps, hc_pre_eps,
hc_sinkhorn_eps, hc_post_mult_value, sinkhorn_repeat,
n_splits=1, tile_n=1):
return _mhc_torch.mhc_fused_post_pre(
x, residual, post_layer_mix, comb_res_mix,
fn, hc_scale, hc_base,
rms_eps, hc_pre_eps, hc_sinkhorn_eps,
hc_post_mult_value, sinkhorn_repeat, n_splits, tile_n,
)
def forward_hip(self, *args, **kwargs):
return self.forward_cuda(*args, **kwargs)
def forward_native(self, *args, **kwargs):
return self.forward_cuda(*args, **kwargs)