Fix custom_op registration: use as decorator with proper type hints

This commit is contained in:
2026-05-19 00:54:30 +00:00
parent c609e9ba3c
commit e0eb436914

View File

@@ -21,79 +21,73 @@ from .base import NvFp4LinearKernel, NvFp4LinearLayerConfig
# Module-level logger, created by init_logger (imported outside this view)
# and keyed by this module's name.
logger = init_logger(__name__)
def _cutedsl_nvfp4_gemm_impl(
    x: torch.Tensor,
    mat_b: torch.Tensor,
    scale_b: torch.Tensor,
    global_scale_b: torch.Tensor,
    in_features: int,
    out_features: int,
    activation_global_scale: float,
) -> torch.Tensor:
    """Execute one NVFP4 GEMM through the CuTeDSL grouped-GEMM entry point.

    The activation is quantized, zero-padded up to a multiple of 128 rows,
    and submitted as a single group; the padding rows are dropped from the
    result before returning.

    Args:
        x: Activation tensor. Its first dimension is used as the token count.
        mat_b: B-side operand, forwarded unchanged to the grouped GEMM.
        scale_b: B-side scales, forwarded unchanged to the grouped GEMM.
        global_scale_b: B-side global scale, forwarded unchanged.
        in_features: Not read by this function body.
        out_features: Not read by this function body.
        activation_global_scale: Scalar passed to the activation quantizer
            and also wrapped into a one-element float32 tensor for the GEMM.

    Returns:
        The first ``x.shape[0]`` rows of the grouped-GEMM output.
    """
    # Imported lazily so that merely importing this module does not require
    # the cutedsl package.
    from cutedsl.bridge import (
        pad_and_swizzle_single,
        quantize_activation_nvfp4,
        run_nvfp4_grouped_gemm,
    )
    from cutedsl.nvfp4_linear import cutedsl_ceil_div

    device = x.device
    row_alignment = 128
    row_count = x.shape[0]
    aligned_rows = cutedsl_ceil_div(row_count, row_alignment) * row_alignment

    # Quantize the activation; yields the packed values and their scales.
    act_fp4, act_sf = quantize_activation_nvfp4(x, activation_global_scale)

    # TMA needs the row count aligned to 128, so pad with zero rows if short.
    if row_count >= aligned_rows:
        mat_a = act_fp4
    else:
        mat_a = torch.zeros(
            aligned_rows,
            act_fp4.shape[1],
            dtype=act_fp4.dtype,
            device=device,
        )
        mat_a[:row_count] = act_fp4

    # Lay the A-side scales out the way CuTeDSL expects them.
    act_scales = pad_and_swizzle_single(
        act_sf, row_count, act_sf.shape[1], device
    )

    # A single group that owns every (padded) row.
    group_offsets = torch.tensor(
        [aligned_rows], dtype=torch.int64, device=device
    )

    # The activation global scale, as a one-element tensor.
    act_global_scale = torch.tensor(
        [activation_global_scale], dtype=torch.float32, device=device
    )

    gemm_out = run_nvfp4_grouped_gemm(
        mat_a=mat_a,
        mat_b=mat_b,
        scale_a=act_scales,
        scale_b=scale_b,
        expert_offsets=group_offsets,
        global_scale_a=act_global_scale,
        global_scale_b=global_scale_b,
    )

    # Strip the alignment padding.
    return gemm_out[:row_count]
def _cutedsl_nvfp4_gemm_fake(
x: torch.Tensor,
mat_b: torch.Tensor,
scale_b: torch.Tensor,
global_scale_b: torch.Tensor,
in_features: int,
out_features: int,
activation_global_scale: float,
) -> torch.Tensor:
return torch.empty((*x.shape[:-1], out_features), dtype=torch.bfloat16,
device=x.device)
# Register the custom op once. The guard keeps a repeated import from
# attempting a second registration of "cutedsl::nvfp4_gemm".
if not hasattr(torch.ops, 'cutedsl') or not hasattr(torch.ops.cutedsl, 'nvfp4_gemm'):
    # When the implementation is passed positionally, custom_op returns the
    # registered op object directly. Calling that object again with the
    # function (as the previous code did) would invoke the op itself with a
    # function as its argument and fail at import time.
    _CUTEDSL_NVFP4_GEMM = torch.library.custom_op(
        "cutedsl::nvfp4_gemm",
        _cutedsl_nvfp4_gemm_impl,
        mutates_args=(),
    )
    # Attach the shape-only implementation used when tracing with fake tensors.
    _CUTEDSL_NVFP4_GEMM.register_fake(_cutedsl_nvfp4_gemm_fake)
@torch.library.custom_op("cutedsl::nvfp4_gemm", mutates_args=())
def _cutedsl_nvfp4_gemm(
    x: torch.Tensor,
    mat_b: torch.Tensor,
    scale_b: torch.Tensor,
    global_scale_b: torch.Tensor,
    in_features: int,
    out_features: int,
    activation_global_scale: float,
) -> torch.Tensor:
    """Single-group NVFP4 GEMM via CuTeDSL, registered as a custom op.

    Registered under the name ``cutedsl::nvfp4_gemm`` with no mutated
    arguments. The parameter annotations must stay as they are, since
    ``custom_op`` infers the operator schema from them.

    Args:
        x: Activation tensor. Its first dimension is used as the token count.
        mat_b: B-side operand, handed to the grouped GEMM as-is.
        scale_b: B-side scales, handed to the grouped GEMM as-is.
        global_scale_b: B-side global scale, handed to the GEMM as-is.
        in_features: Not read by this function body.
        out_features: Not read by this function body.
        activation_global_scale: Scalar used to quantize the activation and
            also passed to the GEMM as a one-element float32 tensor.

    Returns:
        The GEMM output restricted to its first ``x.shape[0]`` rows.
    """
    # Deferred imports: cutedsl is only needed when the op actually runs.
    from cutedsl.bridge import (
        pad_and_swizzle_single,
        quantize_activation_nvfp4,
        run_nvfp4_grouped_gemm,
    )
    from cutedsl.nvfp4_linear import cutedsl_ceil_div

    target_device = x.device
    tile_rows = 128
    n_tokens = x.shape[0]
    n_padded = cutedsl_ceil_div(n_tokens, tile_rows) * tile_rows

    # Quantize the activation into packed values plus scale factors.
    quantized, scale_factors = quantize_activation_nvfp4(
        x, activation_global_scale
    )

    # Row count must be 128-aligned for TMA; fill the tail with zeros.
    if n_tokens >= n_padded:
        a_operand = quantized
    else:
        a_operand = torch.zeros(
            n_padded,
            quantized.shape[1],
            dtype=quantized.dtype,
            device=target_device,
        )
        a_operand[:n_tokens] = quantized

    # A-side scales rearranged into the CuTeDSL layout.
    a_scales = pad_and_swizzle_single(
        scale_factors, n_tokens, scale_factors.shape[1], target_device
    )

    # One group containing all padded rows.
    offsets = torch.tensor(
        [n_padded], dtype=torch.int64, device=target_device
    )

    # Activation global scale as a one-element tensor.
    a_global_scale = torch.tensor(
        [activation_global_scale], dtype=torch.float32, device=target_device
    )

    padded_out = run_nvfp4_grouped_gemm(
        mat_a=a_operand,
        mat_b=mat_b,
        scale_a=a_scales,
        scale_b=scale_b,
        expert_offsets=offsets,
        global_scale_a=a_global_scale,
        global_scale_b=global_scale_b,
    )

    # Discard the rows that were added only for alignment.
    return padded_out[:n_tokens]
@_cutedsl_nvfp4_gemm.register_fake
def _cutedsl_nvfp4_gemm_fake(
    x: torch.Tensor,
    mat_b: torch.Tensor,
    scale_b: torch.Tensor,
    global_scale_b: torch.Tensor,
    in_features: int,
    out_features: int,
    activation_global_scale: float,
) -> torch.Tensor:
    """Fake implementation of ``_cutedsl_nvfp4_gemm``: allocate only.

    Only ``x`` (shape and device) and ``out_features`` are read; the other
    parameters exist to mirror the real op's signature.

    Returns:
        An uninitialized bfloat16 tensor on ``x``'s device, shaped like ``x``
        with the last dimension replaced by ``out_features``.
    """
    output_shape = tuple(x.shape[:-1]) + (out_features,)
    return torch.empty(
        output_shape,
        dtype=torch.bfloat16,
        device=x.device,
    )
class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel):