Fix: call mark_layout_dynamic() as a method on the tensor (tensor.mark_layout_dynamic()), not as cute.mark_layout_dynamic(tensor)
This commit is contained in:
@@ -1056,14 +1056,14 @@ def run_nvfp4_fused_router(
|
||||
# A tensor: [K_packed, M, L] where K_packed = K/2 (2 elements per byte for FP4)
|
||||
K_packed = K // 2
|
||||
mat_a = cutlass_torch.from_dlpack(act_nvfp4)
|
||||
- mat_a = cute.mark_layout_dynamic(mat_a)
|
||||
+ mat_a = mat_a.mark_layout_dynamic()
|
||||
# SFA tensor: [K_sf, M, L]
|
||||
scale_a = cutlass_torch.from_dlpack(act_sf)
|
||||
- scale_a = cute.mark_layout_dynamic(scale_a)
|
||||
+ scale_a = scale_a.mark_layout_dynamic()
|
||||
|
||||
# e_bias must be a CuTe tensor
|
||||
e_bias_cute = cutlass_torch.from_dlpack(e_bias)
|
||||
- e_bias_cute = cute.mark_layout_dynamic(e_bias_cute)
|
||||
+ e_bias_cute = e_bias_cute.mark_layout_dynamic()
|
||||
|
||||
# Number of experts from e_bias
|
||||
E = e_bias.shape[0]
|
||||
@@ -1072,9 +1072,9 @@ def run_nvfp4_fused_router(
|
||||
out_weights = torch.zeros(M, top_k, dtype=torch.float32, device=device)
|
||||
out_ids = torch.zeros(M, top_k, dtype=torch.int32, device=device)
|
||||
out_w_cute = cutlass_torch.from_dlpack(out_weights)
|
||||
- out_w_cute = cute.mark_layout_dynamic(out_w_cute)
|
||||
+ out_w_cute = out_w_cute.mark_layout_dynamic()
|
||||
out_id_cute = cutlass_torch.from_dlpack(out_ids)
|
||||
- out_id_cute = cute.mark_layout_dynamic(out_id_cute)
|
||||
+ out_id_cute = out_id_cute.mark_layout_dynamic()
|
||||
|
||||
# MMA tiler: (128, 128, 64) for decode
|
||||
mma_tiler_mnk = (128, 128, 64)
|
||||
|
||||
@@ -102,9 +102,9 @@ def test_fused_router():
|
||||
|
||||
# CuTe tensors for A (activation)
|
||||
mat_a = cutlass_torch.from_dlpack(act_nvfp4)
|
||||
- mat_a = cute.mark_layout_dynamic(mat_a)
|
||||
+ mat_a = mat_a.mark_layout_dynamic()
|
||||
scale_a = cutlass_torch.from_dlpack(act_sf)
|
||||
- scale_a = cute.mark_layout_dynamic(scale_a)
|
||||
+ scale_a = scale_a.mark_layout_dynamic()
|
||||
|
||||
# CuTe tensors for B (weight) — from gate_lin
|
||||
mat_b = gate_lin._mat_b
|
||||
@@ -112,15 +112,15 @@ def test_fused_router():
|
||||
|
||||
# e_bias CuTe tensor
|
||||
e_bias_cute = cutlass_torch.from_dlpack(e_bias)
|
||||
- e_bias_cute = cute.mark_layout_dynamic(e_bias_cute)
|
||||
+ e_bias_cute = e_bias_cute.mark_layout_dynamic()
|
||||
|
||||
# Output buffers
|
||||
out_weights = torch.zeros(M, top_k, dtype=torch.float32, device=device)
|
||||
out_ids = torch.zeros(M, top_k, dtype=torch.int32, device=device)
|
||||
out_w_cute = cutlass_torch.from_dlpack(out_weights)
|
||||
- out_w_cute = cute.mark_layout_dynamic(out_w_cute)
|
||||
+ out_w_cute = out_w_cute.mark_layout_dynamic()
|
||||
out_id_cute = cutlass_torch.from_dlpack(out_ids)
|
||||
- out_id_cute = cute.mark_layout_dynamic(out_id_cute)
|
||||
+ out_id_cute = out_id_cute.mark_layout_dynamic()
|
||||
|
||||
kernel = Nvfp4FusedRouterKernel(
|
||||
sf_vec_size=sf_vec_size,
|
||||
|
||||
Reference in New Issue
Block a user