Fix MHC custom op registration
Previous approach used @CustomOp.register which doesn't create torch.ops.vllm.mhc_pre. The model code calls torch.ops.vllm.mhc_pre() directly, which requires direct_register_custom_op. Use direct_register_custom_op to register mhc_pre, mhc_post, mhc_fused_post_pre, and hc_head_fused_kernel as PyTorch custom ops with torch (eager) implementations. Patch kernels/mhc/__init__.py to import from both .torch (original) and .mhc_torch_ops (our replacements), skipping tilelang import.
This commit is contained in:
@@ -40,10 +40,11 @@ COPY vllm/patches/deepseek_v4_attention.py ${VLLM_LAYERS_DIR}/deepseek_v4_attent
|
||||
COPY vllm/patches/layers/deepseek_compressor.py ${VLLM_LAYERS_DIR}/deepseek_compressor.py
|
||||
|
||||
# Replace MHC TileLang kernels with pure PyTorch (avoids TileLang JIT on Blackwell)
|
||||
COPY vllm/patches/layers/mhc.py ${VLLM_LAYERS_DIR}/mhc.py
|
||||
# Install our torch op implementations and patch the MHC kernels __init__
|
||||
ARG VLLM_MHC_KERNELS_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/kernels/mhc
|
||||
COPY vllm/patches/kernels/mhc/torch.py ${VLLM_MHC_KERNELS_DIR}/torch.py
|
||||
COPY vllm/patches/kernels/mhc/__init__.py ${VLLM_MHC_KERNELS_DIR}/__init__.py
|
||||
COPY vllm/patches/kernels/mhc_torch_ops.py ${VLLM_MHC_KERNELS_DIR}/mhc_torch_ops.py
|
||||
RUN echo 'from .torch import *' > ${VLLM_MHC_KERNELS_DIR}/__init__.py && \
|
||||
echo 'from .mhc_torch_ops import *' >> ${VLLM_MHC_KERNELS_DIR}/__init__.py
|
||||
|
||||
# CuTeDSL NVFP4 linear kernel (registered as NvFp4LinearKernel)
|
||||
ARG VLLM_NVFP4_DIR=/usr/local/lib/python3.12/dist-packages/vllm/model_executor/kernels/linear/nvfp4
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Patched MHC kernels — pure PyTorch only, no TileLang/Triton.
|
||||
# Avoids TileLang JIT compilation on Blackwell (SM100).
|
||||
|
||||
from .torch import *
|
||||
|
||||
__all__ = [
|
||||
"mhc_pre_torch",
|
||||
"mhc_post_torch",
|
||||
"mhc_fused_post_pre_torch",
|
||||
"hc_head_fused_torch",
|
||||
]
|
||||
@@ -1,122 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Patched MHC torch implementations — adds missing fused and hc_head ops.
|
||||
# Original vllm torch.py only has mhc_pre_torch and mhc_post_torch.
|
||||
# We add mhc_fused_post_pre_torch and hc_head_fused_torch.
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def mhc_pre_torch(
|
||||
residual: torch.Tensor,
|
||||
fn: torch.Tensor,
|
||||
hc_scale: torch.Tensor,
|
||||
hc_base: torch.Tensor,
|
||||
rms_eps: float,
|
||||
hc_pre_eps: float,
|
||||
hc_sinkhorn_eps: float,
|
||||
hc_post_mult_value: float,
|
||||
sinkhorn_repeat: int,
|
||||
n_splits: int = 1,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
hc_mult = residual.shape[-2]
|
||||
hidden_size = residual.shape[-1]
|
||||
hc_mult2 = hc_mult * hc_mult
|
||||
hc_mult3 = hc_mult * 2 + hc_mult2
|
||||
hc_hidden_size = hc_mult * hidden_size
|
||||
outer_shape = residual.shape[:-2]
|
||||
|
||||
residual_flat = residual.view(-1, hc_mult, hidden_size)
|
||||
num_tokens = residual_flat.shape[0]
|
||||
|
||||
x = residual_flat.view(num_tokens, hc_hidden_size).to(torch.float32)
|
||||
mixes = torch.matmul(x, fn.t())
|
||||
sqrsum = x.square().sum(dim=-1, keepdim=True)
|
||||
mixes = mixes * torch.rsqrt(sqrsum / hc_hidden_size + rms_eps)
|
||||
|
||||
pre_logits = mixes[:, :hc_mult] * hc_scale[0] + hc_base[:hc_mult]
|
||||
pre_mix = torch.sigmoid(pre_logits) + hc_pre_eps
|
||||
|
||||
post_logits = mixes[:, hc_mult:2 * hc_mult] * hc_scale[1] + hc_base[hc_mult:2 * hc_mult]
|
||||
post_mix = torch.sigmoid(post_logits) * hc_post_mult_value
|
||||
|
||||
comb_logits = (mixes[:, 2 * hc_mult:]
|
||||
.view(num_tokens, hc_mult, hc_mult)
|
||||
* hc_scale[2]
|
||||
+ hc_base[2 * hc_mult:].view(1, hc_mult, hc_mult))
|
||||
comb_mix = torch.softmax(comb_logits, dim=-1) + hc_sinkhorn_eps
|
||||
comb_mix = comb_mix / (comb_mix.sum(dim=-2, keepdim=True) + hc_sinkhorn_eps)
|
||||
for _ in range(sinkhorn_repeat - 1):
|
||||
comb_mix = comb_mix / (comb_mix.sum(dim=-1, keepdim=True) + hc_sinkhorn_eps)
|
||||
comb_mix = comb_mix / (comb_mix.sum(dim=-2, keepdim=True) + hc_sinkhorn_eps)
|
||||
|
||||
layer_input = torch.sum(
|
||||
pre_mix.unsqueeze(-1) * residual_flat.to(torch.float32), dim=1
|
||||
).to(torch.bfloat16)
|
||||
|
||||
return (
|
||||
post_mix.view(*outer_shape, hc_mult, 1),
|
||||
comb_mix.view(*outer_shape, hc_mult, hc_mult),
|
||||
layer_input.view(*outer_shape, hidden_size),
|
||||
)
|
||||
|
||||
|
||||
def mhc_post_torch(
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
post_layer_mix: torch.Tensor,
|
||||
comb_res_mix: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
mixed_residual = torch.einsum(
|
||||
"...ij,...ih->...jh",
|
||||
comb_res_mix.to(torch.float32),
|
||||
residual.to(torch.float32),
|
||||
)
|
||||
post_term = post_layer_mix.to(torch.float32) * x.unsqueeze(-2).to(torch.float32)
|
||||
return (mixed_residual + post_term).to(residual.dtype)
|
||||
|
||||
|
||||
def mhc_fused_post_pre_torch(
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
post_layer_mix: torch.Tensor,
|
||||
comb_res_mix: torch.Tensor,
|
||||
fn: torch.Tensor,
|
||||
hc_scale: torch.Tensor,
|
||||
hc_base: torch.Tensor,
|
||||
rms_eps: float,
|
||||
hc_pre_eps: float,
|
||||
hc_sinkhorn_eps: float,
|
||||
hc_post_mult_value: float,
|
||||
sinkhorn_repeat: int,
|
||||
n_splits: int = 1,
|
||||
tile_n: int = 1,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
new_residual = mhc_post_torch(x, residual, post_layer_mix, comb_res_mix)
|
||||
post_mix, res_mix, layer_input = mhc_pre_torch(
|
||||
new_residual, fn, hc_scale, hc_base,
|
||||
rms_eps, hc_pre_eps, hc_sinkhorn_eps,
|
||||
hc_post_mult_value, sinkhorn_repeat, n_splits,
|
||||
)
|
||||
return new_residual, post_mix, res_mix, layer_input
|
||||
|
||||
|
||||
def hc_head_fused_torch(
|
||||
hs_flat: torch.Tensor,
|
||||
fn: torch.Tensor,
|
||||
hc_scale: torch.Tensor,
|
||||
hc_base: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
hidden_size: int,
|
||||
rms_eps: float,
|
||||
hc_eps: float,
|
||||
hc_mult: int,
|
||||
) -> None:
|
||||
x_flat = hs_flat.flatten(-2)
|
||||
sqrsum = x_flat.to(torch.float32).square().sum(dim=-1, keepdim=True)
|
||||
x_normed = x_flat * torch.rsqrt(sqrsum / x_flat.shape[-1] + rms_eps)
|
||||
mixes = torch.nn.functional.linear(x_normed.to(torch.float32), fn)
|
||||
pre = torch.sigmoid(mixes * hc_scale + hc_base) + hc_eps
|
||||
result = torch.sum(
|
||||
pre.unsqueeze(-1) * hs_flat.to(torch.float32), dim=1
|
||||
).to(torch.bfloat16)
|
||||
out.copy_(result)
|
||||
@@ -1,15 +1,15 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Patched MHC layer — replaces TileLang kernels with pure PyTorch.
|
||||
# This avoids TileLang JIT compilation on Blackwell (SM100).
|
||||
# Pure PyTorch MHC kernels for DeepSeek V4.
|
||||
# Replaces TileLang kernels to avoid TileLang JIT compilation on Blackwell (SM100).
|
||||
# Registers torch.ops.vllm.mhc_pre, mhc_post, mhc_fused_post_pre, hc_head_fused_kernel.
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
|
||||
# ── Pure PyTorch MHC implementations ──────────────────────────────────
|
||||
# ── Pure PyTorch implementations ──────────────────────────────────────
|
||||
|
||||
def _mhc_pre_torch(
|
||||
def mhc_pre(
|
||||
residual: torch.Tensor,
|
||||
fn: torch.Tensor,
|
||||
hc_scale: torch.Tensor,
|
||||
@@ -63,7 +63,29 @@ def _mhc_pre_torch(
|
||||
)
|
||||
|
||||
|
||||
def _mhc_post_torch(
|
||||
def _mhc_pre_fake(
|
||||
residual: torch.Tensor,
|
||||
fn: torch.Tensor,
|
||||
hc_scale: torch.Tensor,
|
||||
hc_base: torch.Tensor,
|
||||
rms_eps: float,
|
||||
hc_pre_eps: float,
|
||||
hc_sinkhorn_eps: float,
|
||||
hc_post_mult_value: float,
|
||||
sinkhorn_repeat: int,
|
||||
n_splits: int = 1,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
hc_mult = residual.shape[-2]
|
||||
hidden_size = residual.shape[-1]
|
||||
outer_shape = residual.shape[:-2]
|
||||
return (
|
||||
torch.empty(*outer_shape, hc_mult, 1, dtype=torch.float32, device=residual.device),
|
||||
torch.empty(*outer_shape, hc_mult, hc_mult, dtype=torch.float32, device=residual.device),
|
||||
torch.empty(*outer_shape, hidden_size, dtype=torch.bfloat16, device=residual.device),
|
||||
)
|
||||
|
||||
|
||||
def mhc_post(
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
post_layer_mix: torch.Tensor,
|
||||
@@ -78,7 +100,16 @@ def _mhc_post_torch(
|
||||
return (mixed_residual + post_term).to(residual.dtype)
|
||||
|
||||
|
||||
def _mhc_fused_post_pre_torch(
|
||||
def _mhc_post_fake(
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
post_layer_mix: torch.Tensor,
|
||||
comb_res_mix: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(residual)
|
||||
|
||||
|
||||
def mhc_fused_post_pre(
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
post_layer_mix: torch.Tensor,
|
||||
@@ -94,8 +125,8 @@ def _mhc_fused_post_pre_torch(
|
||||
n_splits: int = 1,
|
||||
tile_n: int = 1,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
new_residual = _mhc_post_torch(x, residual, post_layer_mix, comb_res_mix)
|
||||
post_mix, res_mix, layer_input = _mhc_pre_torch(
|
||||
new_residual = mhc_post(x, residual, post_layer_mix, comb_res_mix)
|
||||
post_mix, res_mix, layer_input = mhc_pre(
|
||||
new_residual, fn, hc_scale, hc_base,
|
||||
rms_eps, hc_pre_eps, hc_sinkhorn_eps,
|
||||
hc_post_mult_value, sinkhorn_repeat, n_splits,
|
||||
@@ -103,7 +134,34 @@ def _mhc_fused_post_pre_torch(
|
||||
return new_residual, post_mix, res_mix, layer_input
|
||||
|
||||
|
||||
def _hc_head_fused_torch(
|
||||
def _mhc_fused_post_pre_fake(
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
post_layer_mix: torch.Tensor,
|
||||
comb_res_mix: torch.Tensor,
|
||||
fn: torch.Tensor,
|
||||
hc_scale: torch.Tensor,
|
||||
hc_base: torch.Tensor,
|
||||
rms_eps: float,
|
||||
hc_pre_eps: float,
|
||||
hc_sinkhorn_eps: float,
|
||||
hc_post_mult_value: float,
|
||||
sinkhorn_repeat: int,
|
||||
n_splits: int = 1,
|
||||
tile_n: int = 1,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
hc_mult = residual.shape[-2]
|
||||
hidden_size = residual.shape[-1]
|
||||
outer_shape = residual.shape[:-2]
|
||||
return (
|
||||
torch.empty_like(residual),
|
||||
torch.empty(*outer_shape, hc_mult, 1, dtype=torch.float32, device=residual.device),
|
||||
torch.empty(*outer_shape, hc_mult, hc_mult, dtype=torch.float32, device=residual.device),
|
||||
torch.empty(*outer_shape, hidden_size, dtype=torch.bfloat16, device=residual.device),
|
||||
)
|
||||
|
||||
|
||||
def hc_head_fused_kernel(
|
||||
hs_flat: torch.Tensor,
|
||||
fn: torch.Tensor,
|
||||
hc_scale: torch.Tensor,
|
||||
@@ -125,93 +183,48 @@ def _hc_head_fused_torch(
|
||||
out.copy_(result)
|
||||
|
||||
|
||||
# ── CustomOp registrations ────────────────────────────────────────────
|
||||
|
||||
@CustomOp.register("mhc_pre")
|
||||
class MHCPreOp(CustomOp):
|
||||
@classmethod
|
||||
def enabled(cls) -> bool:
|
||||
return True
|
||||
|
||||
def forward_cuda(self, residual, fn, hc_scale, hc_base,
|
||||
rms_eps, hc_pre_eps, hc_sinkhorn_eps,
|
||||
hc_post_mult_value, sinkhorn_repeat, n_splits=1):
|
||||
return _mhc_pre_torch(
|
||||
residual, fn, hc_scale, hc_base,
|
||||
rms_eps, hc_pre_eps, hc_sinkhorn_eps,
|
||||
hc_post_mult_value, sinkhorn_repeat, n_splits,
|
||||
)
|
||||
|
||||
def forward_hip(self, *args, **kwargs):
|
||||
return self.forward_cuda(*args, **kwargs)
|
||||
|
||||
def forward_native(self, *args, **kwargs):
|
||||
return self.forward_cuda(*args, **kwargs)
|
||||
def _hc_head_fused_kernel_fake(
|
||||
hs_flat: torch.Tensor,
|
||||
fn: torch.Tensor,
|
||||
hc_scale: torch.Tensor,
|
||||
hc_base: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
hidden_size: int,
|
||||
rms_eps: float,
|
||||
hc_eps: float,
|
||||
hc_mult: int,
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
|
||||
@CustomOp.register("mhc_post")
|
||||
class MHCPostOp(CustomOp):
|
||||
@classmethod
|
||||
def enabled(cls) -> bool:
|
||||
return True
|
||||
# ── Register as torch custom ops ──────────────────────────────────────
|
||||
# These replace the TileLang-registered ops that the model code calls via
|
||||
# torch.ops.vllm.mhc_pre, torch.ops.vllm.mhc_post, etc.
|
||||
|
||||
def forward_cuda(self, x, residual, post_layer_mix, comb_res_mix):
|
||||
return _mhc_post_torch(x, residual, post_layer_mix, comb_res_mix)
|
||||
direct_register_custom_op(
|
||||
op_name="mhc_pre",
|
||||
op_func=mhc_pre,
|
||||
mutates_args=[],
|
||||
fake_impl=_mhc_pre_fake,
|
||||
)
|
||||
|
||||
def forward_hip(self, *args, **kwargs):
|
||||
return self.forward_cuda(*args, **kwargs)
|
||||
direct_register_custom_op(
|
||||
op_name="mhc_post",
|
||||
op_func=mhc_post,
|
||||
mutates_args=[],
|
||||
fake_impl=_mhc_post_fake,
|
||||
)
|
||||
|
||||
def forward_native(self, *args, **kwargs):
|
||||
return self.forward_cuda(*args, **kwargs)
|
||||
direct_register_custom_op(
|
||||
op_name="mhc_fused_post_pre",
|
||||
op_func=mhc_fused_post_pre,
|
||||
mutates_args=[],
|
||||
fake_impl=_mhc_fused_post_pre_fake,
|
||||
)
|
||||
|
||||
|
||||
@CustomOp.register("hc_head")
|
||||
class HCHeadOp(CustomOp):
|
||||
@classmethod
|
||||
def enabled(cls) -> bool:
|
||||
return True
|
||||
|
||||
def forward_cuda(self, hidden_states, hc_fn, hc_scale, hc_base,
|
||||
rms_norm_eps, hc_eps):
|
||||
hc_mult, hidden_size = hidden_states.shape[-2:]
|
||||
outer_shape = hidden_states.shape[:-2]
|
||||
hs_flat = hidden_states.view(-1, hc_mult, hidden_size)
|
||||
out = torch.empty(
|
||||
hs_flat.shape[0], hidden_size,
|
||||
dtype=torch.bfloat16, device=hidden_states.device,
|
||||
)
|
||||
_hc_head_fused_torch(
|
||||
hs_flat, hc_fn, hc_scale, hc_base,
|
||||
out, hidden_size, rms_norm_eps, hc_eps, hc_mult,
|
||||
)
|
||||
return out.view(*outer_shape, hidden_size)
|
||||
|
||||
def forward_hip(self, *args, **kwargs):
|
||||
return self.forward_cuda(*args, **kwargs)
|
||||
|
||||
def forward_native(self, *args, **kwargs):
|
||||
return self.forward_cuda(*args, **kwargs)
|
||||
|
||||
|
||||
@CustomOp.register("mhc_fused_post_pre")
|
||||
class MHCFusedPostPreOp(CustomOp):
|
||||
@classmethod
|
||||
def enabled(cls) -> bool:
|
||||
return True
|
||||
|
||||
def forward_cuda(self, x, residual, post_layer_mix, comb_res_mix,
|
||||
fn, hc_scale, hc_base, rms_eps, hc_pre_eps,
|
||||
hc_sinkhorn_eps, hc_post_mult_value, sinkhorn_repeat,
|
||||
n_splits=1, tile_n=1):
|
||||
return _mhc_fused_post_pre_torch(
|
||||
x, residual, post_layer_mix, comb_res_mix,
|
||||
fn, hc_scale, hc_base,
|
||||
rms_eps, hc_pre_eps, hc_sinkhorn_eps,
|
||||
hc_post_mult_value, sinkhorn_repeat, n_splits, tile_n,
|
||||
)
|
||||
|
||||
def forward_hip(self, *args, **kwargs):
|
||||
return self.forward_cuda(*args, **kwargs)
|
||||
|
||||
def forward_native(self, *args, **kwargs):
|
||||
return self.forward_cuda(*args, **kwargs)
|
||||
direct_register_custom_op(
|
||||
op_name="hc_head_fused_kernel",
|
||||
op_func=hc_head_fused_kernel,
|
||||
mutates_args=["out"],
|
||||
fake_impl=_hc_head_fused_kernel_fake,
|
||||
)
|
||||
Reference in New Issue
Block a user