Also replace the layers/mhc.py CustomOp dispatch
The original layers/mhc.py forward_cuda calls torch.ops.vllm.mhc_pre_tilelang, which triggers TileLang JIT compilation. Replace those calls with our torch implementations in forward_cuda, since that is the path the CustomOp dispatch routes through.
This commit is contained in:
@@ -40,8 +40,11 @@ COPY vllm/patches/deepseek_v4_attention.py ${VLLM_LAYERS_DIR}/deepseek_v4_attent
|
||||
COPY vllm/patches/layers/deepseek_compressor.py ${VLLM_LAYERS_DIR}/deepseek_compressor.py
|
||||
|
||||
# Replace MHC TileLang kernels with pure PyTorch (avoids TileLang JIT on Blackwell)
|
||||
# Install our torch op implementations and patch the MHC kernels __init__
|
||||
# 1. Patch layers/mhc.py — CustomOp dispatch uses torch impls instead of tilelang
|
||||
# 2. Install our torch op registrations (mhc_torch_ops.py)
|
||||
# 3. Patch kernels/mhc/__init__.py to not import tilelang
|
||||
ARG VLLM_MHC_KERNELS_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/kernels/mhc
COPY vllm/patches/layers/mhc.py ${VLLM_LAYERS_DIR}/mhc.py
COPY vllm/patches/kernels/mhc_torch_ops.py ${VLLM_MHC_KERNELS_DIR}/mhc_torch_ops.py
# Overwrite the package __init__.py so that it only star-imports the
# `torch` and `mhc_torch_ops` submodules.
RUN { \
      echo 'from .torch import *' && \
      echo 'from .mhc_torch_ops import *'; \
    } > ${VLLM_MHC_KERNELS_DIR}/__init__.py
|
||||
|
||||
102
vllm/patches/layers/mhc.py
Normal file
102
vllm/patches/layers/mhc.py
Normal file
@@ -0,0 +1,102 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Patched MHC layer — replaces TileLang kernels with pure PyTorch.
|
||||
# This avoids TileLang JIT compilation on Blackwell (SM100).
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
|
||||
# Import our torch implementations (registers torch.ops.vllm.mhc_pre, etc.)
|
||||
import vllm.model_executor.kernels.mhc.mhc_torch_ops as _mhc_torch # noqa: F401
|
||||
# Also import the original torch impls (mhc_pre_torch, mhc_post_torch)
|
||||
import vllm.model_executor.kernels.mhc.torch as mhc_kernels # noqa: F401
|
||||
|
||||
|
||||
@CustomOp.register("mhc_pre")
class MHCPreOp(CustomOp):
    """Custom op registered as ``mhc_pre``.

    All three backend entry points (CUDA, HIP, native) end up in
    :meth:`forward_cuda`, which delegates to ``_mhc_torch.mhc_pre``.
    """

    @classmethod
    def enabled(cls) -> bool:
        """Report the op as enabled unconditionally."""
        return True

    def forward_cuda(self, residual, fn, hc_scale, hc_base,
                     rms_eps, hc_pre_eps, hc_sinkhorn_eps,
                     hc_post_mult_value, sinkhorn_repeat, n_splits=1):
        """Pass every argument, in order, to ``_mhc_torch.mhc_pre``.

        Returns whatever ``_mhc_torch.mhc_pre`` returns, unchanged.
        """
        call_args = (
            residual, fn, hc_scale, hc_base,
            rms_eps, hc_pre_eps, hc_sinkhorn_eps,
            hc_post_mult_value, sinkhorn_repeat, n_splits,
        )
        return _mhc_torch.mhc_pre(*call_args)

    def forward_hip(self, *args, **kwargs):
        """HIP entry point; reuses the CUDA path."""
        return self.forward_cuda(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        """Native entry point; reuses the CUDA path."""
        return self.forward_cuda(*args, **kwargs)
|
||||
|
||||
|
||||
@CustomOp.register("mhc_post")
class MHCPostOp(CustomOp):
    """Custom op registered as ``mhc_post``.

    Every backend entry point is served by :meth:`forward_cuda`, which
    delegates to ``_mhc_torch.mhc_post``.
    """

    @classmethod
    def enabled(cls) -> bool:
        """Report the op as enabled unconditionally."""
        return True

    def forward_cuda(self, x, residual, post_layer_mix, comb_res_mix):
        """Call ``_mhc_torch.mhc_post`` with the four inputs and return its result."""
        result = _mhc_torch.mhc_post(
            x,
            residual,
            post_layer_mix,
            comb_res_mix,
        )
        return result

    def forward_hip(self, *args, **kwargs):
        """HIP entry point; reuses the CUDA path."""
        return self.forward_cuda(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        """Native entry point; reuses the CUDA path."""
        return self.forward_cuda(*args, **kwargs)
|
||||
|
||||
|
||||
@CustomOp.register("hc_head")
class HCHeadOp(CustomOp):
    """Custom op registered as ``hc_head``.

    Every backend entry point is served by :meth:`forward_cuda`, which
    flattens the input, calls ``_mhc_torch.hc_head_fused_kernel`` and
    restores the leading dimensions on the output.
    """

    @classmethod
    def enabled(cls) -> bool:
        """Report the op as enabled unconditionally."""
        return True

    def forward_cuda(self, hidden_states, hc_fn, hc_scale, hc_base,
                     rms_norm_eps, hc_eps):
        """Run the fused head kernel over ``hidden_states``.

        The last two dimensions of ``hidden_states`` are read as
        ``(hc_mult, hidden_size)``; all leading dimensions are collapsed
        into one for the kernel call and restored on the returned tensor,
        whose final dimension is ``hidden_size``.
        """
        in_shape = hidden_states.shape
        outer_shape = in_shape[:-2]
        hc_mult, hidden_size = in_shape[-2:]

        # Collapse every leading dimension into a single row axis.
        flat_states = hidden_states.view(-1, hc_mult, hidden_size)
        num_rows = flat_states.shape[0]

        # The output buffer is always bfloat16, independent of the input dtype.
        out = torch.empty(
            (num_rows, hidden_size),
            dtype=torch.bfloat16,
            device=hidden_states.device,
        )

        # The kernel's return value is ignored; presumably it fills `out`
        # in place (verify against mhc_torch_ops).
        _mhc_torch.hc_head_fused_kernel(
            flat_states, hc_fn, hc_scale, hc_base,
            out, hidden_size, rms_norm_eps, hc_eps, hc_mult,
        )

        return out.view(*outer_shape, hidden_size)

    def forward_hip(self, *args, **kwargs):
        """HIP entry point; reuses the CUDA path."""
        return self.forward_cuda(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        """Native entry point; reuses the CUDA path."""
        return self.forward_cuda(*args, **kwargs)
|
||||
|
||||
|
||||
@CustomOp.register("mhc_fused_post_pre")
class MHCFusedPostPreOp(CustomOp):
    """Custom op registered as ``mhc_fused_post_pre``.

    Every backend entry point is served by :meth:`forward_cuda`, which
    delegates to ``_mhc_torch.mhc_fused_post_pre``.
    """

    @classmethod
    def enabled(cls) -> bool:
        """Report the op as enabled unconditionally."""
        return True

    def forward_cuda(self, x, residual, post_layer_mix, comb_res_mix,
                     fn, hc_scale, hc_base, rms_eps, hc_pre_eps,
                     hc_sinkhorn_eps, hc_post_mult_value, sinkhorn_repeat,
                     n_splits=1, tile_n=1):
        """Pass every argument, in order, to ``_mhc_torch.mhc_fused_post_pre``.

        Returns whatever ``_mhc_torch.mhc_fused_post_pre`` returns, unchanged.
        """
        call_args = (
            x, residual, post_layer_mix, comb_res_mix,
            fn, hc_scale, hc_base,
            rms_eps, hc_pre_eps, hc_sinkhorn_eps,
            hc_post_mult_value, sinkhorn_repeat, n_splits, tile_n,
        )
        return _mhc_torch.mhc_fused_post_pre(*call_args)

    def forward_hip(self, *args, **kwargs):
        """HIP entry point; reuses the CUDA path."""
        return self.forward_cuda(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        """Native entry point; reuses the CUDA path."""
        return self.forward_cuda(*args, **kwargs)
|
||||
Reference in New Issue
Block a user