Fix: use tensor.mark_layout_dynamic() method (not cute.mark_layout_dynamic)

This commit is contained in:
2026-06-01 09:16:33 +00:00
parent 2412745b21
commit 483e759d53
2 changed files with 10 additions and 10 deletions

View File

@@ -1056,14 +1056,14 @@ def run_nvfp4_fused_router(
# A tensor: [K_packed, M, L] where K_packed = K/2 (2 elements per byte for FP4)
K_packed = K // 2
mat_a = cutlass_torch.from_dlpack(act_nvfp4)
mat_a = cute.mark_layout_dynamic(mat_a)
mat_a = mat_a.mark_layout_dynamic()
# SFA tensor: [K_sf, M, L]
scale_a = cutlass_torch.from_dlpack(act_sf)
scale_a = cute.mark_layout_dynamic(scale_a)
scale_a = scale_a.mark_layout_dynamic()
# e_bias must be a CuTe tensor
e_bias_cute = cutlass_torch.from_dlpack(e_bias)
e_bias_cute = cute.mark_layout_dynamic(e_bias_cute)
e_bias_cute = e_bias_cute.mark_layout_dynamic()
# Number of experts from e_bias
E = e_bias.shape[0]
@@ -1072,9 +1072,9 @@ def run_nvfp4_fused_router(
out_weights = torch.zeros(M, top_k, dtype=torch.float32, device=device)
out_ids = torch.zeros(M, top_k, dtype=torch.int32, device=device)
out_w_cute = cutlass_torch.from_dlpack(out_weights)
out_w_cute = cute.mark_layout_dynamic(out_w_cute)
out_w_cute = out_w_cute.mark_layout_dynamic()
out_id_cute = cutlass_torch.from_dlpack(out_ids)
out_id_cute = cute.mark_layout_dynamic(out_id_cute)
out_id_cute = out_id_cute.mark_layout_dynamic()
# MMA tiler: (128, 128, 64) for decode
mma_tiler_mnk = (128, 128, 64)

View File

@@ -102,9 +102,9 @@ def test_fused_router():
# CuTe tensors for A (activation)
mat_a = cutlass_torch.from_dlpack(act_nvfp4)
mat_a = cute.mark_layout_dynamic(mat_a)
mat_a = mat_a.mark_layout_dynamic()
scale_a = cutlass_torch.from_dlpack(act_sf)
scale_a = cute.mark_layout_dynamic(scale_a)
scale_a = scale_a.mark_layout_dynamic()
# CuTe tensors for B (weight) — from gate_lin
mat_b = gate_lin._mat_b
@@ -112,15 +112,15 @@ def test_fused_router():
# e_bias CuTe tensor
e_bias_cute = cutlass_torch.from_dlpack(e_bias)
e_bias_cute = cute.mark_layout_dynamic(e_bias_cute)
e_bias_cute = e_bias_cute.mark_layout_dynamic()
# Output buffers
out_weights = torch.zeros(M, top_k, dtype=torch.float32, device=device)
out_ids = torch.zeros(M, top_k, dtype=torch.int32, device=device)
out_w_cute = cutlass_torch.from_dlpack(out_weights)
out_w_cute = cute.mark_layout_dynamic(out_w_cute)
out_w_cute = out_w_cute.mark_layout_dynamic()
out_id_cute = cutlass_torch.from_dlpack(out_ids)
out_id_cute = cute.mark_layout_dynamic(out_id_cute)
out_id_cute = out_id_cute.mark_layout_dynamic()
kernel = Nvfp4FusedRouterKernel(
sf_vec_size=sf_vec_size,