From 24fed15ed63b52f0e45fe3242859979eb5b9c3ff Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 1 Jun 2026 10:02:40 +0000 Subject: [PATCH] Fix: convert PyTorch tensors to CuTe tensors for fused router kernel - Added cutlass_torch.from_dlpack() + mark_layout_dynamic() conversions - quantize_activation_nvfp4 returns (fp4_packed, fp8_scales) which are converted to CuTe tensors before passing to the kernel - Same pattern as gemm_runner.py --- .../router/nvfp4_fused_router_kernel.py | 34 ++++++++++++++----- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/dsv4/kernels/router/nvfp4_fused_router_kernel.py b/dsv4/kernels/router/nvfp4_fused_router_kernel.py index d693abe2..224cd52f 100644 --- a/dsv4/kernels/router/nvfp4_fused_router_kernel.py +++ b/dsv4/kernels/router/nvfp4_fused_router_kernel.py @@ -818,13 +818,31 @@ def run_nvfp4_fused_router( # Quantize activation to NVFP4 from dsv4.ops.quantize import quantize_activation_nvfp4 - mat_a, scale_a = quantize_activation_nvfp4(hidden_states, gsa) + mat_a_bf16_packed, scale_a_fp8 = quantize_activation_nvfp4(hidden_states, gsa) # Output tensor: FP32 activated scores [N, E] - # We write sqrt(softplus(logits)) + e_bias here, - # then top-k reads from it activated_scores = torch.empty(N, E, dtype=torch.float32, device=device) + # Convert PyTorch tensors to CuTe tensors (same as gemm_runner.py pattern) + import cutlass.torch as cutlass_torch + + def _to_cute(t, leading_dim=None): + ct = cutlass_torch.from_dlpack(t) + if leading_dim is not None: + return ct.mark_layout_dynamic(leading_dim=leading_dim) + return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t)) + + # Determine leading dimensions from tensor shapes + # mat_a_bf16_packed: [N, K_packed] — K-major (row-major for GEMM A) + # mat_b: [E, K_packed] — K-major (col-major for GEMM B, i.e. 
N-major) + # NOTE(review): layout above is unverified; NVFP4 GEMM may instead expect A M-major and B N-major + # TODO: confirm operand layouts against the existing Nvfp4Linear before relying on them + cute_a = _to_cute(mat_a_bf16_packed) + cute_b = _to_cute(mat_b) + cute_sfa = _to_cute(scale_a_fp8) + cute_sfb = _to_cute(scale_b) + cute_c = _to_cute(activated_scores) + # Run the CuTeDSL kernel: NVFP4 GEMM + sqrt(softplus) epilogue kernel = Nvfp4FusedRouterKernel( sf_vec_size=16, @@ -832,11 +850,11 @@ def run_nvfp4_fused_router( cluster_shape_mnk=(1, 1, 1), ) kernel.run( - mat_a=mat_a, - mat_b=mat_b, - scale_a=scale_a, - scale_b=scale_b, - mat_c=activated_scores, + mat_a=cute_a, + mat_b=cute_b, + scale_a=cute_sfa, + scale_b=cute_sfb, + mat_c=cute_c, M=N, N=E, K=K, )