Fix custom_op registration: use as decorator with proper type hints
This commit is contained in:
@@ -21,79 +21,73 @@ from .base import NvFp4LinearKernel, NvFp4LinearLayerConfig
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _cutedsl_nvfp4_gemm_impl(
|
||||
x: torch.Tensor,
|
||||
mat_b: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
global_scale_b: torch.Tensor,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
activation_global_scale: float,
|
||||
) -> torch.Tensor:
|
||||
"""Run a single-group NVFP4 GEMM via CuTeDSL."""
|
||||
from cutedsl.bridge import (pad_and_swizzle_single,
|
||||
quantize_activation_nvfp4,
|
||||
run_nvfp4_grouped_gemm)
|
||||
from cutedsl.nvfp4_linear import cutedsl_ceil_div
|
||||
|
||||
num_tokens = x.shape[0]
|
||||
padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128
|
||||
|
||||
# Quantize activation
|
||||
x_fp4, x_sf = quantize_activation_nvfp4(x, activation_global_scale)
|
||||
|
||||
# Pad activation to 128-row alignment for TMA
|
||||
if num_tokens < padded_rows:
|
||||
x_fp4_padded = torch.zeros(padded_rows, x_fp4.shape[1],
|
||||
dtype=x_fp4.dtype, device=x.device)
|
||||
x_fp4_padded[:num_tokens] = x_fp4
|
||||
else:
|
||||
x_fp4_padded = x_fp4
|
||||
|
||||
# Assemble A-side scales in CuTeDSL layout
|
||||
scale_a = pad_and_swizzle_single(x_sf, num_tokens, x_sf.shape[1], x.device)
|
||||
|
||||
# Expert offsets for 1 group (all tokens in one group)
|
||||
expert_offsets = torch.tensor([padded_rows], dtype=torch.int64, device=x.device)
|
||||
|
||||
# Global scale for activation
|
||||
global_scale_a = torch.tensor([activation_global_scale], dtype=torch.float32, device=x.device)
|
||||
|
||||
# Run the CuTeDSL grouped GEMM (1 group)
|
||||
out = run_nvfp4_grouped_gemm(
|
||||
mat_a=x_fp4_padded,
|
||||
mat_b=mat_b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
expert_offsets=expert_offsets,
|
||||
global_scale_a=global_scale_a,
|
||||
global_scale_b=global_scale_b,
|
||||
)
|
||||
|
||||
return out[:num_tokens]
|
||||
|
||||
|
||||
def _cutedsl_nvfp4_gemm_fake(
|
||||
x: torch.Tensor,
|
||||
mat_b: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
global_scale_b: torch.Tensor,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
activation_global_scale: float,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty((*x.shape[:-1], out_features), dtype=torch.bfloat16,
|
||||
device=x.device)
|
||||
|
||||
|
||||
# Register custom op (idempotent — safe to import multiple times)
|
||||
if not hasattr(torch.ops, 'cutedsl') or not hasattr(torch.ops.cutedsl, 'nvfp4_gemm'):
|
||||
_CUTEDSL_NVFP4_GEMM = torch.library.custom_op(
|
||||
"cutedsl::nvfp4_gemm",
|
||||
_cutedsl_nvfp4_gemm_impl,
|
||||
mutates_args=(),
|
||||
)(_cutedsl_nvfp4_gemm_impl)
|
||||
_CUTEDSL_NVFP4_GEMM.register_fake(_cutedsl_nvfp4_gemm_fake)
|
||||
|
||||
@torch.library.custom_op("cutedsl::nvfp4_gemm", mutates_args=())
|
||||
def _cutedsl_nvfp4_gemm(
|
||||
x: torch.Tensor,
|
||||
mat_b: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
global_scale_b: torch.Tensor,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
activation_global_scale: float,
|
||||
) -> torch.Tensor:
|
||||
"""Run a single-group NVFP4 GEMM via CuTeDSL."""
|
||||
from cutedsl.bridge import (pad_and_swizzle_single,
|
||||
quantize_activation_nvfp4,
|
||||
run_nvfp4_grouped_gemm)
|
||||
from cutedsl.nvfp4_linear import cutedsl_ceil_div
|
||||
|
||||
num_tokens = x.shape[0]
|
||||
padded_rows = cutedsl_ceil_div(num_tokens, 128) * 128
|
||||
|
||||
# Quantize activation
|
||||
x_fp4, x_sf = quantize_activation_nvfp4(x, activation_global_scale)
|
||||
|
||||
# Pad activation to 128-row alignment for TMA
|
||||
if num_tokens < padded_rows:
|
||||
x_fp4_padded = torch.zeros(padded_rows, x_fp4.shape[1],
|
||||
dtype=x_fp4.dtype, device=x.device)
|
||||
x_fp4_padded[:num_tokens] = x_fp4
|
||||
else:
|
||||
x_fp4_padded = x_fp4
|
||||
|
||||
# Assemble A-side scales in CuTeDSL layout
|
||||
scale_a = pad_and_swizzle_single(x_sf, num_tokens, x_sf.shape[1], x.device)
|
||||
|
||||
# Expert offsets for 1 group (all tokens in one group)
|
||||
expert_offsets = torch.tensor([padded_rows], dtype=torch.int64, device=x.device)
|
||||
|
||||
# Global scale for activation
|
||||
global_scale_a = torch.tensor([activation_global_scale], dtype=torch.float32, device=x.device)
|
||||
|
||||
# Run the CuTeDSL grouped GEMM (1 group)
|
||||
out = run_nvfp4_grouped_gemm(
|
||||
mat_a=x_fp4_padded,
|
||||
mat_b=mat_b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
expert_offsets=expert_offsets,
|
||||
global_scale_a=global_scale_a,
|
||||
global_scale_b=global_scale_b,
|
||||
)
|
||||
|
||||
return out[:num_tokens]
|
||||
|
||||
@_cutedsl_nvfp4_gemm.register_fake
|
||||
def _cutedsl_nvfp4_gemm_fake(
|
||||
x: torch.Tensor,
|
||||
mat_b: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
global_scale_b: torch.Tensor,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
activation_global_scale: float,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty((*x.shape[:-1], out_features), dtype=torch.bfloat16,
|
||||
device=x.device)
|
||||
|
||||
|
||||
class CuTeDSLNvFp4LinearKernel(NvFp4LinearKernel):
|
||||
|
||||
Reference in New Issue
Block a user