From 5e6d45914522a5c93a2fd596629f6e35dc7fc61b Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 19 May 2026 05:19:48 +0000 Subject: [PATCH] Fix MHC custom op registration The previous approach used @CustomOp.register, which does not create torch.ops.vllm.mhc_pre. The model code calls torch.ops.vllm.mhc_pre() directly, which requires registration through direct_register_custom_op. Use direct_register_custom_op to register mhc_pre, mhc_post, mhc_fused_post_pre, and hc_head_fused_kernel as PyTorch custom ops backed by eager-mode PyTorch implementations. Patch kernels/mhc/__init__.py to import from both .torch (original) and .mhc_torch_ops (our replacements), skipping the TileLang import. --- Dockerfile | 7 +- vllm/patches/kernels/mhc/__init__.py | 12 -- vllm/patches/kernels/mhc/torch.py | 122 ----------- .../mhc.py => kernels/mhc_torch_ops.py} | 203 ++++++++++-------- 4 files changed, 112 insertions(+), 232 deletions(-) delete mode 100644 vllm/patches/kernels/mhc/__init__.py delete mode 100644 vllm/patches/kernels/mhc/torch.py rename vllm/patches/{layers/mhc.py => kernels/mhc_torch_ops.py} (51%) diff --git a/Dockerfile b/Dockerfile index f29064ce..23905ed0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -40,10 +40,11 @@ COPY vllm/patches/deepseek_v4_attention.py ${VLLM_LAYERS_DIR}/deepseek_v4_attent COPY vllm/patches/layers/deepseek_compressor.py ${VLLM_LAYERS_DIR}/deepseek_compressor.py # Replace MHC TileLang kernels with pure PyTorch (avoids TileLang JIT on Blackwell) -COPY vllm/patches/layers/mhc.py ${VLLM_LAYERS_DIR}/mhc.py +# Install our torch op implementations and patch the MHC kernels __init__ ARG VLLM_MHC_KERNELS_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/kernels/mhc -COPY vllm/patches/kernels/mhc/torch.py ${VLLM_MHC_KERNELS_DIR}/torch.py -COPY vllm/patches/kernels/mhc/__init__.py ${VLLM_MHC_KERNELS_DIR}/__init__.py +COPY vllm/patches/kernels/mhc_torch_ops.py ${VLLM_MHC_KERNELS_DIR}/mhc_torch_ops.py +RUN echo 'from .torch import *' > ${VLLM_MHC_KERNELS_DIR}/__init__.py && \ + 
echo 'from .mhc_torch_ops import *' >> ${VLLM_MHC_KERNELS_DIR}/__init__.py # CuTeDSL NVFP4 linear kernel (registered as NvFp4LinearKernel) ARG VLLM_NVFP4_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/kernels/linear/nvfp4 diff --git a/vllm/patches/kernels/mhc/__init__.py b/vllm/patches/kernels/mhc/__init__.py deleted file mode 100644 index a0882d6e..00000000 --- a/vllm/patches/kernels/mhc/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# Patched MHC kernels — pure PyTorch only, no TileLang/Triton. -# Avoids TileLang JIT compilation on Blackwell (SM100). - -from .torch import * - -__all__ = [ - "mhc_pre_torch", - "mhc_post_torch", - "mhc_fused_post_pre_torch", - "hc_head_fused_torch", -] diff --git a/vllm/patches/kernels/mhc/torch.py b/vllm/patches/kernels/mhc/torch.py deleted file mode 100644 index cfdb1c53..00000000 --- a/vllm/patches/kernels/mhc/torch.py +++ /dev/null @@ -1,122 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# Patched MHC torch implementations — adds missing fused and hc_head ops. -# Original vllm torch.py only has mhc_pre_torch and mhc_post_torch. -# We add mhc_fused_post_pre_torch and hc_head_fused_torch. 
- -import torch - - -def mhc_pre_torch( - residual: torch.Tensor, - fn: torch.Tensor, - hc_scale: torch.Tensor, - hc_base: torch.Tensor, - rms_eps: float, - hc_pre_eps: float, - hc_sinkhorn_eps: float, - hc_post_mult_value: float, - sinkhorn_repeat: int, - n_splits: int = 1, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - hc_mult = residual.shape[-2] - hidden_size = residual.shape[-1] - hc_mult2 = hc_mult * hc_mult - hc_mult3 = hc_mult * 2 + hc_mult2 - hc_hidden_size = hc_mult * hidden_size - outer_shape = residual.shape[:-2] - - residual_flat = residual.view(-1, hc_mult, hidden_size) - num_tokens = residual_flat.shape[0] - - x = residual_flat.view(num_tokens, hc_hidden_size).to(torch.float32) - mixes = torch.matmul(x, fn.t()) - sqrsum = x.square().sum(dim=-1, keepdim=True) - mixes = mixes * torch.rsqrt(sqrsum / hc_hidden_size + rms_eps) - - pre_logits = mixes[:, :hc_mult] * hc_scale[0] + hc_base[:hc_mult] - pre_mix = torch.sigmoid(pre_logits) + hc_pre_eps - - post_logits = mixes[:, hc_mult:2 * hc_mult] * hc_scale[1] + hc_base[hc_mult:2 * hc_mult] - post_mix = torch.sigmoid(post_logits) * hc_post_mult_value - - comb_logits = (mixes[:, 2 * hc_mult:] - .view(num_tokens, hc_mult, hc_mult) - * hc_scale[2] - + hc_base[2 * hc_mult:].view(1, hc_mult, hc_mult)) - comb_mix = torch.softmax(comb_logits, dim=-1) + hc_sinkhorn_eps - comb_mix = comb_mix / (comb_mix.sum(dim=-2, keepdim=True) + hc_sinkhorn_eps) - for _ in range(sinkhorn_repeat - 1): - comb_mix = comb_mix / (comb_mix.sum(dim=-1, keepdim=True) + hc_sinkhorn_eps) - comb_mix = comb_mix / (comb_mix.sum(dim=-2, keepdim=True) + hc_sinkhorn_eps) - - layer_input = torch.sum( - pre_mix.unsqueeze(-1) * residual_flat.to(torch.float32), dim=1 - ).to(torch.bfloat16) - - return ( - post_mix.view(*outer_shape, hc_mult, 1), - comb_mix.view(*outer_shape, hc_mult, hc_mult), - layer_input.view(*outer_shape, hidden_size), - ) - - -def mhc_post_torch( - x: torch.Tensor, - residual: torch.Tensor, - post_layer_mix: torch.Tensor, 
- comb_res_mix: torch.Tensor, -) -> torch.Tensor: - mixed_residual = torch.einsum( - "...ij,...ih->...jh", - comb_res_mix.to(torch.float32), - residual.to(torch.float32), - ) - post_term = post_layer_mix.to(torch.float32) * x.unsqueeze(-2).to(torch.float32) - return (mixed_residual + post_term).to(residual.dtype) - - -def mhc_fused_post_pre_torch( - x: torch.Tensor, - residual: torch.Tensor, - post_layer_mix: torch.Tensor, - comb_res_mix: torch.Tensor, - fn: torch.Tensor, - hc_scale: torch.Tensor, - hc_base: torch.Tensor, - rms_eps: float, - hc_pre_eps: float, - hc_sinkhorn_eps: float, - hc_post_mult_value: float, - sinkhorn_repeat: int, - n_splits: int = 1, - tile_n: int = 1, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - new_residual = mhc_post_torch(x, residual, post_layer_mix, comb_res_mix) - post_mix, res_mix, layer_input = mhc_pre_torch( - new_residual, fn, hc_scale, hc_base, - rms_eps, hc_pre_eps, hc_sinkhorn_eps, - hc_post_mult_value, sinkhorn_repeat, n_splits, - ) - return new_residual, post_mix, res_mix, layer_input - - -def hc_head_fused_torch( - hs_flat: torch.Tensor, - fn: torch.Tensor, - hc_scale: torch.Tensor, - hc_base: torch.Tensor, - out: torch.Tensor, - hidden_size: int, - rms_eps: float, - hc_eps: float, - hc_mult: int, -) -> None: - x_flat = hs_flat.flatten(-2) - sqrsum = x_flat.to(torch.float32).square().sum(dim=-1, keepdim=True) - x_normed = x_flat * torch.rsqrt(sqrsum / x_flat.shape[-1] + rms_eps) - mixes = torch.nn.functional.linear(x_normed.to(torch.float32), fn) - pre = torch.sigmoid(mixes * hc_scale + hc_base) + hc_eps - result = torch.sum( - pre.unsqueeze(-1) * hs_flat.to(torch.float32), dim=1 - ).to(torch.bfloat16) - out.copy_(result) diff --git a/vllm/patches/layers/mhc.py b/vllm/patches/kernels/mhc_torch_ops.py similarity index 51% rename from vllm/patches/layers/mhc.py rename to vllm/patches/kernels/mhc_torch_ops.py index 3b304ffc..eee5aba4 100644 --- a/vllm/patches/layers/mhc.py +++ 
b/vllm/patches/kernels/mhc_torch_ops.py @@ -1,15 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 -# Patched MHC layer — replaces TileLang kernels with pure PyTorch. -# This avoids TileLang JIT compilation on Blackwell (SM100). +# Pure PyTorch MHC kernels for DeepSeek V4. +# Replaces TileLang kernels to avoid TileLang JIT compilation on Blackwell (SM100). +# Registers torch.ops.vllm.mhc_pre, mhc_post, mhc_fused_post_pre, hc_head_fused_kernel. import torch - -from vllm.model_executor.custom_op import CustomOp +from vllm.utils.torch_utils import direct_register_custom_op -# ── Pure PyTorch MHC implementations ────────────────────────────────── +# ── Pure PyTorch implementations ────────────────────────────────────── -def _mhc_pre_torch( +def mhc_pre( residual: torch.Tensor, fn: torch.Tensor, hc_scale: torch.Tensor, @@ -63,7 +63,29 @@ def _mhc_pre_torch( ) -def _mhc_post_torch( +def _mhc_pre_fake( + residual: torch.Tensor, + fn: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + rms_eps: float, + hc_pre_eps: float, + hc_sinkhorn_eps: float, + hc_post_mult_value: float, + sinkhorn_repeat: int, + n_splits: int = 1, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + hc_mult = residual.shape[-2] + hidden_size = residual.shape[-1] + outer_shape = residual.shape[:-2] + return ( + torch.empty(*outer_shape, hc_mult, 1, dtype=torch.float32, device=residual.device), + torch.empty(*outer_shape, hc_mult, hc_mult, dtype=torch.float32, device=residual.device), + torch.empty(*outer_shape, hidden_size, dtype=torch.bfloat16, device=residual.device), + ) + + +def mhc_post( x: torch.Tensor, residual: torch.Tensor, post_layer_mix: torch.Tensor, @@ -78,7 +100,16 @@ def _mhc_post_torch( return (mixed_residual + post_term).to(residual.dtype) -def _mhc_fused_post_pre_torch( +def _mhc_post_fake( + x: torch.Tensor, + residual: torch.Tensor, + post_layer_mix: torch.Tensor, + comb_res_mix: torch.Tensor, +) -> torch.Tensor: + return torch.empty_like(residual) + + +def 
mhc_fused_post_pre( x: torch.Tensor, residual: torch.Tensor, post_layer_mix: torch.Tensor, @@ -94,8 +125,8 @@ def _mhc_fused_post_pre_torch( n_splits: int = 1, tile_n: int = 1, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - new_residual = _mhc_post_torch(x, residual, post_layer_mix, comb_res_mix) - post_mix, res_mix, layer_input = _mhc_pre_torch( + new_residual = mhc_post(x, residual, post_layer_mix, comb_res_mix) + post_mix, res_mix, layer_input = mhc_pre( new_residual, fn, hc_scale, hc_base, rms_eps, hc_pre_eps, hc_sinkhorn_eps, hc_post_mult_value, sinkhorn_repeat, n_splits, @@ -103,7 +134,34 @@ def _mhc_fused_post_pre_torch( return new_residual, post_mix, res_mix, layer_input -def _hc_head_fused_torch( +def _mhc_fused_post_pre_fake( + x: torch.Tensor, + residual: torch.Tensor, + post_layer_mix: torch.Tensor, + comb_res_mix: torch.Tensor, + fn: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + rms_eps: float, + hc_pre_eps: float, + hc_sinkhorn_eps: float, + hc_post_mult_value: float, + sinkhorn_repeat: int, + n_splits: int = 1, + tile_n: int = 1, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + hc_mult = residual.shape[-2] + hidden_size = residual.shape[-1] + outer_shape = residual.shape[:-2] + return ( + torch.empty_like(residual), + torch.empty(*outer_shape, hc_mult, 1, dtype=torch.float32, device=residual.device), + torch.empty(*outer_shape, hc_mult, hc_mult, dtype=torch.float32, device=residual.device), + torch.empty(*outer_shape, hidden_size, dtype=torch.bfloat16, device=residual.device), + ) + + +def hc_head_fused_kernel( hs_flat: torch.Tensor, fn: torch.Tensor, hc_scale: torch.Tensor, @@ -125,93 +183,48 @@ def _hc_head_fused_torch( out.copy_(result) -# ── CustomOp registrations ──────────────────────────────────────────── - -@CustomOp.register("mhc_pre") -class MHCPreOp(CustomOp): - @classmethod - def enabled(cls) -> bool: - return True - - def forward_cuda(self, residual, fn, hc_scale, hc_base, 
- rms_eps, hc_pre_eps, hc_sinkhorn_eps, - hc_post_mult_value, sinkhorn_repeat, n_splits=1): - return _mhc_pre_torch( - residual, fn, hc_scale, hc_base, - rms_eps, hc_pre_eps, hc_sinkhorn_eps, - hc_post_mult_value, sinkhorn_repeat, n_splits, - ) - - def forward_hip(self, *args, **kwargs): - return self.forward_cuda(*args, **kwargs) - - def forward_native(self, *args, **kwargs): - return self.forward_cuda(*args, **kwargs) +def _hc_head_fused_kernel_fake( + hs_flat: torch.Tensor, + fn: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + out: torch.Tensor, + hidden_size: int, + rms_eps: float, + hc_eps: float, + hc_mult: int, +) -> None: + return None -@CustomOp.register("mhc_post") -class MHCPostOp(CustomOp): - @classmethod - def enabled(cls) -> bool: - return True +# ── Register as torch custom ops ────────────────────────────────────── +# These replace the TileLang-registered ops that the model code calls via +# torch.ops.vllm.mhc_pre, torch.ops.vllm.mhc_post, etc. - def forward_cuda(self, x, residual, post_layer_mix, comb_res_mix): - return _mhc_post_torch(x, residual, post_layer_mix, comb_res_mix) +direct_register_custom_op( + op_name="mhc_pre", + op_func=mhc_pre, + mutates_args=[], + fake_impl=_mhc_pre_fake, +) - def forward_hip(self, *args, **kwargs): - return self.forward_cuda(*args, **kwargs) +direct_register_custom_op( + op_name="mhc_post", + op_func=mhc_post, + mutates_args=[], + fake_impl=_mhc_post_fake, +) - def forward_native(self, *args, **kwargs): - return self.forward_cuda(*args, **kwargs) +direct_register_custom_op( + op_name="mhc_fused_post_pre", + op_func=mhc_fused_post_pre, + mutates_args=[], + fake_impl=_mhc_fused_post_pre_fake, +) - -@CustomOp.register("hc_head") -class HCHeadOp(CustomOp): - @classmethod - def enabled(cls) -> bool: - return True - - def forward_cuda(self, hidden_states, hc_fn, hc_scale, hc_base, - rms_norm_eps, hc_eps): - hc_mult, hidden_size = hidden_states.shape[-2:] - outer_shape = hidden_states.shape[:-2] - 
hs_flat = hidden_states.view(-1, hc_mult, hidden_size) - out = torch.empty( - hs_flat.shape[0], hidden_size, - dtype=torch.bfloat16, device=hidden_states.device, - ) - _hc_head_fused_torch( - hs_flat, hc_fn, hc_scale, hc_base, - out, hidden_size, rms_norm_eps, hc_eps, hc_mult, - ) - return out.view(*outer_shape, hidden_size) - - def forward_hip(self, *args, **kwargs): - return self.forward_cuda(*args, **kwargs) - - def forward_native(self, *args, **kwargs): - return self.forward_cuda(*args, **kwargs) - - -@CustomOp.register("mhc_fused_post_pre") -class MHCFusedPostPreOp(CustomOp): - @classmethod - def enabled(cls) -> bool: - return True - - def forward_cuda(self, x, residual, post_layer_mix, comb_res_mix, - fn, hc_scale, hc_base, rms_eps, hc_pre_eps, - hc_sinkhorn_eps, hc_post_mult_value, sinkhorn_repeat, - n_splits=1, tile_n=1): - return _mhc_fused_post_pre_torch( - x, residual, post_layer_mix, comb_res_mix, - fn, hc_scale, hc_base, - rms_eps, hc_pre_eps, hc_sinkhorn_eps, - hc_post_mult_value, sinkhorn_repeat, n_splits, tile_n, - ) - - def forward_hip(self, *args, **kwargs): - return self.forward_cuda(*args, **kwargs) - - def forward_native(self, *args, **kwargs): - return self.forward_cuda(*args, **kwargs) +direct_register_custom_op( + op_name="hc_head_fused_kernel", + op_func=hc_head_fused_kernel, + mutates_args=["out"], + fake_impl=_hc_head_fused_kernel_fake, +)