Fix MHC custom op registration

The previous approach used @CustomOp.register, which does not create
torch.ops.vllm.mhc_pre. The model code calls torch.ops.vllm.mhc_pre()
directly, so the op must be registered with direct_register_custom_op.

Use direct_register_custom_op to register mhc_pre, mhc_post,
mhc_fused_post_pre, and hc_head_fused_kernel as PyTorch custom ops
with torch (eager) implementations.

Patch kernels/mhc/__init__.py to import from both .torch (original)
and .mhc_torch_ops (our replacements), skipping tilelang import.
This commit is contained in:
2026-05-19 05:19:48 +00:00
parent 9ff1679064
commit 5e6d459145
4 changed files with 112 additions and 232 deletions

View File

@@ -40,10 +40,11 @@ COPY vllm/patches/deepseek_v4_attention.py ${VLLM_LAYERS_DIR}/deepseek_v4_attent
COPY vllm/patches/layers/deepseek_compressor.py ${VLLM_LAYERS_DIR}/deepseek_compressor.py
# Replace MHC TileLang kernels with pure PyTorch (avoids TileLang JIT on Blackwell)
COPY vllm/patches/layers/mhc.py ${VLLM_LAYERS_DIR}/mhc.py
# Install our torch op implementations and patch the MHC kernels __init__
ARG VLLM_MHC_KERNELS_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/kernels/mhc
COPY vllm/patches/kernels/mhc/torch.py ${VLLM_MHC_KERNELS_DIR}/torch.py
COPY vllm/patches/kernels/mhc/__init__.py ${VLLM_MHC_KERNELS_DIR}/__init__.py
COPY vllm/patches/kernels/mhc_torch_ops.py ${VLLM_MHC_KERNELS_DIR}/mhc_torch_ops.py
RUN echo 'from .torch import *' > ${VLLM_MHC_KERNELS_DIR}/__init__.py && \
echo 'from .mhc_torch_ops import *' >> ${VLLM_MHC_KERNELS_DIR}/__init__.py
# CuTeDSL NVFP4 linear kernel (registered as NvFp4LinearKernel)
ARG VLLM_NVFP4_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/kernels/linear/nvfp4

View File

@@ -1,12 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Patched MHC kernels — pure PyTorch only, no TileLang/Triton.
# Avoids TileLang JIT compilation on Blackwell (SM100).
from .torch import *
__all__ = [
"mhc_pre_torch",
"mhc_post_torch",
"mhc_fused_post_pre_torch",
"hc_head_fused_torch",
]

View File

@@ -1,122 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Patched MHC torch implementations — adds missing fused and hc_head ops.
# Original vllm torch.py only has mhc_pre_torch and mhc_post_torch.
# We add mhc_fused_post_pre_torch and hc_head_fused_torch.
import torch
def mhc_pre_torch(
    residual: torch.Tensor,
    fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    rms_eps: float,
    hc_pre_eps: float,
    hc_sinkhorn_eps: float,
    hc_post_mult_value: float,
    sinkhorn_repeat: int,
    n_splits: int = 1,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute the MHC "pre" mixing weights and the mixed layer input.

    The last two dimensions of ``residual`` are read as ``(hc_mult, hidden_size)``;
    all leading dimensions are treated as a batch of tokens and restored on the
    outputs.

    Args:
        residual: Tensor of shape ``(..., hc_mult, hidden_size)``. It may be
            non-contiguous; it is flattened with ``reshape``.
        fn: Projection weight applied as ``x @ fn.T``; it must produce
            ``2 * hc_mult + hc_mult ** 2`` columns from ``hc_mult * hidden_size``
            inputs.
        hc_scale: Indexed at ``[0]``, ``[1]`` and ``[2]`` to scale the pre, post
            and combination logits respectively.
        hc_base: Bias sliced into ``hc_mult`` pre, ``hc_mult`` post and
            ``hc_mult ** 2`` combination entries.
        rms_eps: Epsilon added to the mean square before ``rsqrt``.
        hc_pre_eps: Constant added to the sigmoid pre-mix.
        hc_sinkhorn_eps: Constant added after the softmax and inside every
            normalisation denominator.
        hc_post_mult_value: Multiplier applied to the sigmoid post-mix.
        sinkhorn_repeat: Total number of column normalisations; the first one
            follows the softmax, the remaining ``sinkhorn_repeat - 1`` are each
            preceded by a row normalisation.
        n_splits: Not read by this implementation; accepted for signature
            compatibility.

    Returns:
        ``(post_mix, comb_mix, layer_input)`` with shapes
        ``(..., hc_mult, 1)``, ``(..., hc_mult, hc_mult)`` and
        ``(..., hidden_size)``. ``layer_input`` is cast to ``torch.bfloat16``.
    """
    hc_mult = residual.shape[-2]
    hidden_size = residual.shape[-1]
    hc_hidden_size = hc_mult * hidden_size
    outer_shape = residual.shape[:-2]

    # reshape instead of view: merging (hc_mult, hidden_size) with view raises
    # on inputs whose strides are not compatible (e.g. a transposed tensor).
    residual_flat = residual.reshape(-1, hc_mult, hidden_size)
    num_tokens = residual_flat.shape[0]
    x = residual_flat.reshape(num_tokens, hc_hidden_size).to(torch.float32)

    # Project, then scale by the inverse RMS of the flattened residual.
    mixes = torch.matmul(x, fn.t())
    sqrsum = x.square().sum(dim=-1, keepdim=True)
    mixes = mixes * torch.rsqrt(sqrsum / hc_hidden_size + rms_eps)

    pre_logits = mixes[:, :hc_mult] * hc_scale[0] + hc_base[:hc_mult]
    pre_mix = torch.sigmoid(pre_logits) + hc_pre_eps

    post_logits = (
        mixes[:, hc_mult:2 * hc_mult] * hc_scale[1]
        + hc_base[hc_mult:2 * hc_mult]
    )
    post_mix = torch.sigmoid(post_logits) * hc_post_mult_value

    comb_logits = (
        mixes[:, 2 * hc_mult:].view(num_tokens, hc_mult, hc_mult) * hc_scale[2]
        + hc_base[2 * hc_mult:].view(1, hc_mult, hc_mult)
    )
    # Row softmax, then alternate column / row normalisation.
    comb_mix = torch.softmax(comb_logits, dim=-1) + hc_sinkhorn_eps
    comb_mix = comb_mix / (comb_mix.sum(dim=-2, keepdim=True) + hc_sinkhorn_eps)
    for _ in range(sinkhorn_repeat - 1):
        comb_mix = comb_mix / (comb_mix.sum(dim=-1, keepdim=True) + hc_sinkhorn_eps)
        comb_mix = comb_mix / (comb_mix.sum(dim=-2, keepdim=True) + hc_sinkhorn_eps)

    # Weighted sum of the hc_mult residual streams gives the layer input.
    layer_input = torch.sum(
        pre_mix.unsqueeze(-1) * residual_flat.to(torch.float32), dim=1
    ).to(torch.bfloat16)

    return (
        post_mix.view(*outer_shape, hc_mult, 1),
        comb_mix.view(*outer_shape, hc_mult, hc_mult),
        layer_input.view(*outer_shape, hidden_size),
    )
def mhc_post_torch(
    x: torch.Tensor,
    residual: torch.Tensor,
    post_layer_mix: torch.Tensor,
    comb_res_mix: torch.Tensor,
) -> torch.Tensor:
    """Mix the residual streams and add the scaled layer output.

    The residual is contracted with ``comb_res_mix`` over its second-to-last
    axis, ``x`` is broadcast across that axis and scaled by ``post_layer_mix``,
    and the sum is returned in ``residual``'s dtype. All arithmetic is done in
    float32.
    """
    comb_f32 = comb_res_mix.to(torch.float32)
    residual_f32 = residual.to(torch.float32)
    # out[..., j, h] = sum_i comb[..., i, j] * residual[..., i, h]
    stream_mix = torch.einsum("...ij,...ih->...jh", comb_f32, residual_f32)

    layer_out_f32 = x.unsqueeze(-2).to(torch.float32)
    scaled_layer_out = post_layer_mix.to(torch.float32) * layer_out_f32

    combined = stream_mix + scaled_layer_out
    return combined.to(residual.dtype)
def mhc_fused_post_pre_torch(
    x: torch.Tensor,
    residual: torch.Tensor,
    post_layer_mix: torch.Tensor,
    comb_res_mix: torch.Tensor,
    fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    rms_eps: float,
    hc_pre_eps: float,
    hc_sinkhorn_eps: float,
    hc_post_mult_value: float,
    sinkhorn_repeat: int,
    n_splits: int = 1,
    tile_n: int = 1,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run ``mhc_post_torch`` and feed its result straight into ``mhc_pre_torch``.

    ``tile_n`` is not read by this implementation; ``n_splits`` is only
    forwarded to ``mhc_pre_torch``.

    Returns:
        ``(new_residual, post_mix, res_mix, layer_input)`` where the first item
        is the output of the post step and the remaining three are the outputs
        of the pre step computed from it.
    """
    # Post step: produce the updated residual streams.
    updated_residual = mhc_post_torch(x, residual, post_layer_mix, comb_res_mix)

    # Pre step on the updated residual.
    pre_outputs = mhc_pre_torch(
        updated_residual,
        fn,
        hc_scale,
        hc_base,
        rms_eps,
        hc_pre_eps,
        hc_sinkhorn_eps,
        hc_post_mult_value,
        sinkhorn_repeat,
        n_splits,
    )
    post_mix, res_mix, layer_input = pre_outputs
    return updated_residual, post_mix, res_mix, layer_input
def hc_head_fused_torch(
    hs_flat: torch.Tensor,
    fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    out: torch.Tensor,
    hidden_size: int,
    rms_eps: float,
    hc_eps: float,
    hc_mult: int,
) -> None:
    """Collapse the streams of ``hs_flat`` into ``out`` using sigmoid gates.

    The last two axes of ``hs_flat`` are flattened, RMS-normalised and projected
    through ``fn``; the sigmoid of the scaled, biased projection (plus
    ``hc_eps``) weights each stream, and the weighted sum over axis 1 is written
    into ``out`` in place after a cast to ``torch.bfloat16``.

    ``hidden_size`` and ``hc_mult`` are not read here; the sizes come from
    ``hs_flat`` itself. Presumably ``hs_flat`` is ``(tokens, hc_mult,
    hidden_size)`` -- TODO confirm against the caller.
    """
    merged = hs_flat.flatten(-2)
    width = merged.shape[-1]

    # Inverse RMS over the flattened (stream, hidden) axis, accumulated in fp32.
    sq_total = merged.to(torch.float32).square().sum(dim=-1, keepdim=True)
    inv_rms = torch.rsqrt(sq_total / width + rms_eps)
    normed = merged * inv_rms

    logits = torch.nn.functional.linear(normed.to(torch.float32), fn)
    gate = torch.sigmoid(logits * hc_scale + hc_base) + hc_eps

    weighted = gate.unsqueeze(-1) * hs_flat.to(torch.float32)
    out.copy_(torch.sum(weighted, dim=1).to(torch.bfloat16))

View File

@@ -1,15 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# Patched MHC layer — replaces TileLang kernels with pure PyTorch.
# This avoids TileLang JIT compilation on Blackwell (SM100).
# Pure PyTorch MHC kernels for DeepSeek V4.
# Replaces TileLang kernels to avoid TileLang JIT compilation on Blackwell (SM100).
# Registers torch.ops.vllm.mhc_pre, mhc_post, mhc_fused_post_pre, hc_head_fused_kernel.
import torch
from vllm.model_executor.custom_op import CustomOp
from vllm.utils.torch_utils import direct_register_custom_op
# ── Pure PyTorch MHC implementations ──────────────────────────────────
# ── Pure PyTorch implementations ──────────────────────────────────────
def _mhc_pre_torch(
def mhc_pre(
residual: torch.Tensor,
fn: torch.Tensor,
hc_scale: torch.Tensor,
@@ -63,7 +63,29 @@ def _mhc_pre_torch(
)
def _mhc_post_torch(
def _mhc_pre_fake(
residual: torch.Tensor,
fn: torch.Tensor,
hc_scale: torch.Tensor,
hc_base: torch.Tensor,
rms_eps: float,
hc_pre_eps: float,
hc_sinkhorn_eps: float,
hc_post_mult_value: float,
sinkhorn_repeat: int,
n_splits: int = 1,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
hc_mult = residual.shape[-2]
hidden_size = residual.shape[-1]
outer_shape = residual.shape[:-2]
return (
torch.empty(*outer_shape, hc_mult, 1, dtype=torch.float32, device=residual.device),
torch.empty(*outer_shape, hc_mult, hc_mult, dtype=torch.float32, device=residual.device),
torch.empty(*outer_shape, hidden_size, dtype=torch.bfloat16, device=residual.device),
)
def mhc_post(
x: torch.Tensor,
residual: torch.Tensor,
post_layer_mix: torch.Tensor,
@@ -78,7 +100,16 @@ def _mhc_post_torch(
return (mixed_residual + post_term).to(residual.dtype)
def _mhc_fused_post_pre_torch(
def _mhc_post_fake(
x: torch.Tensor,
residual: torch.Tensor,
post_layer_mix: torch.Tensor,
comb_res_mix: torch.Tensor,
) -> torch.Tensor:
return torch.empty_like(residual)
def mhc_fused_post_pre(
x: torch.Tensor,
residual: torch.Tensor,
post_layer_mix: torch.Tensor,
@@ -94,8 +125,8 @@ def _mhc_fused_post_pre_torch(
n_splits: int = 1,
tile_n: int = 1,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
new_residual = _mhc_post_torch(x, residual, post_layer_mix, comb_res_mix)
post_mix, res_mix, layer_input = _mhc_pre_torch(
new_residual = mhc_post(x, residual, post_layer_mix, comb_res_mix)
post_mix, res_mix, layer_input = mhc_pre(
new_residual, fn, hc_scale, hc_base,
rms_eps, hc_pre_eps, hc_sinkhorn_eps,
hc_post_mult_value, sinkhorn_repeat, n_splits,
@@ -103,7 +134,34 @@ def _mhc_fused_post_pre_torch(
return new_residual, post_mix, res_mix, layer_input
def _hc_head_fused_torch(
def _mhc_fused_post_pre_fake(
x: torch.Tensor,
residual: torch.Tensor,
post_layer_mix: torch.Tensor,
comb_res_mix: torch.Tensor,
fn: torch.Tensor,
hc_scale: torch.Tensor,
hc_base: torch.Tensor,
rms_eps: float,
hc_pre_eps: float,
hc_sinkhorn_eps: float,
hc_post_mult_value: float,
sinkhorn_repeat: int,
n_splits: int = 1,
tile_n: int = 1,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
hc_mult = residual.shape[-2]
hidden_size = residual.shape[-1]
outer_shape = residual.shape[:-2]
return (
torch.empty_like(residual),
torch.empty(*outer_shape, hc_mult, 1, dtype=torch.float32, device=residual.device),
torch.empty(*outer_shape, hc_mult, hc_mult, dtype=torch.float32, device=residual.device),
torch.empty(*outer_shape, hidden_size, dtype=torch.bfloat16, device=residual.device),
)
def hc_head_fused_kernel(
hs_flat: torch.Tensor,
fn: torch.Tensor,
hc_scale: torch.Tensor,
@@ -125,93 +183,48 @@ def _hc_head_fused_torch(
out.copy_(result)
# ── CustomOp registrations ────────────────────────────────────────────
@CustomOp.register("mhc_pre")
class MHCPreOp(CustomOp):
    """``CustomOp`` registered as ``"mhc_pre"`` that runs ``_mhc_pre_torch``.

    Every ``forward_*`` variant ends up in the same pure-PyTorch function.
    """

    @classmethod
    def enabled(cls) -> bool:
        """Always report this op as enabled."""
        return True

    def forward_cuda(
        self,
        residual,
        fn,
        hc_scale,
        hc_base,
        rms_eps,
        hc_pre_eps,
        hc_sinkhorn_eps,
        hc_post_mult_value,
        sinkhorn_repeat,
        n_splits=1,
    ):
        """Forward all arguments, in order, to ``_mhc_pre_torch``."""
        pre_args = (
            residual,
            fn,
            hc_scale,
            hc_base,
            rms_eps,
            hc_pre_eps,
            hc_sinkhorn_eps,
            hc_post_mult_value,
            sinkhorn_repeat,
            n_splits,
        )
        return _mhc_pre_torch(*pre_args)

    def forward_hip(self, *args, **kwargs):
        """Delegate to ``forward_cuda``."""
        return self.forward_cuda(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        """Delegate to ``forward_cuda``."""
        return self.forward_cuda(*args, **kwargs)
def _hc_head_fused_kernel_fake(
hs_flat: torch.Tensor,
fn: torch.Tensor,
hc_scale: torch.Tensor,
hc_base: torch.Tensor,
out: torch.Tensor,
hidden_size: int,
rms_eps: float,
hc_eps: float,
hc_mult: int,
) -> None:
return None
@CustomOp.register("mhc_post")
class MHCPostOp(CustomOp):
@classmethod
def enabled(cls) -> bool:
return True
# ── Register as torch custom ops ──────────────────────────────────────
# These replace the TileLang-registered ops that the model code calls via
# torch.ops.vllm.mhc_pre, torch.ops.vllm.mhc_post, etc.
def forward_cuda(self, x, residual, post_layer_mix, comb_res_mix):
return _mhc_post_torch(x, residual, post_layer_mix, comb_res_mix)
direct_register_custom_op(
op_name="mhc_pre",
op_func=mhc_pre,
mutates_args=[],
fake_impl=_mhc_pre_fake,
)
def forward_hip(self, *args, **kwargs):
return self.forward_cuda(*args, **kwargs)
direct_register_custom_op(
op_name="mhc_post",
op_func=mhc_post,
mutates_args=[],
fake_impl=_mhc_post_fake,
)
def forward_native(self, *args, **kwargs):
return self.forward_cuda(*args, **kwargs)
direct_register_custom_op(
op_name="mhc_fused_post_pre",
op_func=mhc_fused_post_pre,
mutates_args=[],
fake_impl=_mhc_fused_post_pre_fake,
)
@CustomOp.register("hc_head")
class HCHeadOp(CustomOp):
    """``CustomOp`` registered as ``"hc_head"`` that runs ``_hc_head_fused_torch``.

    Every ``forward_*`` variant ends up in the same pure-PyTorch function.
    """

    @classmethod
    def enabled(cls) -> bool:
        """Always report this op as enabled."""
        return True

    def forward_cuda(self, hidden_states, hc_fn, hc_scale, hc_base,
                     rms_norm_eps, hc_eps):
        """Flatten the leading axes, run the head kernel, restore the shape.

        The last two axes of ``hidden_states`` are read as
        ``(hc_mult, hidden_size)``. A bfloat16 output buffer is allocated on the
        same device, filled in place by ``_hc_head_fused_torch``, and returned
        with the original leading axes followed by ``hidden_size``.
        """
        hc_mult, hidden_size = hidden_states.shape[-2:]
        lead = hidden_states.shape[:-2]

        flat_states = hidden_states.view(-1, hc_mult, hidden_size)
        num_rows = flat_states.shape[0]
        buffer = torch.empty(
            num_rows,
            hidden_size,
            dtype=torch.bfloat16,
            device=hidden_states.device,
        )

        # The helper writes its result into `buffer` and returns nothing.
        _hc_head_fused_torch(
            flat_states,
            hc_fn,
            hc_scale,
            hc_base,
            buffer,
            hidden_size,
            rms_norm_eps,
            hc_eps,
            hc_mult,
        )
        return buffer.view(*lead, hidden_size)

    def forward_hip(self, *args, **kwargs):
        """Delegate to ``forward_cuda``."""
        return self.forward_cuda(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        """Delegate to ``forward_cuda``."""
        return self.forward_cuda(*args, **kwargs)
@CustomOp.register("mhc_fused_post_pre")
class MHCFusedPostPreOp(CustomOp):
    """``CustomOp`` registered as ``"mhc_fused_post_pre"``.

    Every ``forward_*`` variant ends up in ``_mhc_fused_post_pre_torch``.
    """

    @classmethod
    def enabled(cls) -> bool:
        """Always report this op as enabled."""
        return True

    def forward_cuda(
        self,
        x,
        residual,
        post_layer_mix,
        comb_res_mix,
        fn,
        hc_scale,
        hc_base,
        rms_eps,
        hc_pre_eps,
        hc_sinkhorn_eps,
        hc_post_mult_value,
        sinkhorn_repeat,
        n_splits=1,
        tile_n=1,
    ):
        """Forward all arguments, in order, to ``_mhc_fused_post_pre_torch``."""
        fused_args = (
            x,
            residual,
            post_layer_mix,
            comb_res_mix,
            fn,
            hc_scale,
            hc_base,
            rms_eps,
            hc_pre_eps,
            hc_sinkhorn_eps,
            hc_post_mult_value,
            sinkhorn_repeat,
            n_splits,
            tile_n,
        )
        return _mhc_fused_post_pre_torch(*fused_args)

    def forward_hip(self, *args, **kwargs):
        """Delegate to ``forward_cuda``."""
        return self.forward_cuda(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        """Delegate to ``forward_cuda``."""
        return self.forward_cuda(*args, **kwargs)
# Register hc_head_fused_kernel as a torch custom op so the model code can
# call it via torch.ops.vllm.hc_head_fused_kernel.
direct_register_custom_op(
op_name="hc_head_fused_kernel",
op_func=hc_head_fused_kernel,
# `out` is declared as mutated because the implementation writes its result
# into it in place (out.copy_(result)) and returns nothing.
mutates_args=["out"],
# The fake implementation returns None without touching any argument.
fake_impl=_hc_head_fused_kernel_fake,
)