From fcd76805833892edf1599a6b4a53160164a0e3ca Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 1 Jun 2026 07:12:52 +0000 Subject: [PATCH] Fix CuTe tensor creation: use from_dlpack + mark_layout_dynamic --- .../router/nvfp4_fused_router_kernel.py | 27 +++++++++++-------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/dsv4/kernels/router/nvfp4_fused_router_kernel.py b/dsv4/kernels/router/nvfp4_fused_router_kernel.py index 175fd377..b2aec19d 100644 --- a/dsv4/kernels/router/nvfp4_fused_router_kernel.py +++ b/dsv4/kernels/router/nvfp4_fused_router_kernel.py @@ -728,17 +728,22 @@ def run_nvfp4_fused_router( from dsv4.ops.quantize import quantize_activation_nvfp4 x_fp4, x_sf = quantize_activation_nvfp4(hidden_states, gsa) - # Create CuTe tensors - a_tensor = cutlass_torch.as_tensor(x_fp4, shape=x_fp4.shape) - b_tensor = cutlass_torch.as_tensor(mat_b, shape=mat_b.shape) - sfa_tensor = cutlass_torch.as_tensor(x_sf, shape=x_sf.shape) - sfb_tensor = cutlass_torch.as_tensor(scale_b, shape=scale_b.shape) - e_bias_ct = cutlass_torch.as_tensor(e_bias, shape=e_bias.shape) - out_w_ct = cutlass_torch.as_tensor(out_weights, shape=out_weights.shape) - out_id_ct = cutlass_torch.as_tensor(out_ids, shape=out_ids.shape) - eo_ct = cutlass_torch.as_tensor(expert_offsets, shape=expert_offsets.shape) - gsa_ct = cutlass_torch.as_tensor(gsa_t, shape=gsa_t.shape) - gsb_ct = cutlass_torch.as_tensor(gsb_t, shape=gsb_t.shape) + # Create CuTe tensors (from_dlpack + mark_layout_dynamic for proper TMA) + def _to_cute(t): + ct = cutlass_torch.from_dlpack(t) + ct = ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t)) + return ct + + a_tensor = _to_cute(x_fp4) + b_tensor = _to_cute(mat_b) + sfa_tensor = _to_cute(x_sf) + sfb_tensor = _to_cute(scale_b) + e_bias_ct = _to_cute(e_bias) + out_w_ct = _to_cute(out_weights) + out_id_ct = _to_cute(out_ids) + eo_ct = _to_cute(expert_offsets) + gsa_ct = _to_cute(gsa_t) + gsb_ct = _to_cute(gsb_t) kernel = Nvfp4FusedRouterKernel(top_k=top_k) kernel.run(