Fix CuTe tensor creation: use from_dlpack + mark_layout_dynamic

This commit is contained in:
2026-06-01 07:12:52 +00:00
parent 3a8c6daeb3
commit fcd7680583

View File

@@ -728,17 +728,22 @@ def run_nvfp4_fused_router(
from dsv4.ops.quantize import quantize_activation_nvfp4
x_fp4, x_sf = quantize_activation_nvfp4(hidden_states, gsa)
# Create CuTe tensors
a_tensor = cutlass_torch.as_tensor(x_fp4, shape=x_fp4.shape)
b_tensor = cutlass_torch.as_tensor(mat_b, shape=mat_b.shape)
sfa_tensor = cutlass_torch.as_tensor(x_sf, shape=x_sf.shape)
sfb_tensor = cutlass_torch.as_tensor(scale_b, shape=scale_b.shape)
e_bias_ct = cutlass_torch.as_tensor(e_bias, shape=e_bias.shape)
out_w_ct = cutlass_torch.as_tensor(out_weights, shape=out_weights.shape)
out_id_ct = cutlass_torch.as_tensor(out_ids, shape=out_ids.shape)
eo_ct = cutlass_torch.as_tensor(expert_offsets, shape=expert_offsets.shape)
gsa_ct = cutlass_torch.as_tensor(gsa_t, shape=gsa_t.shape)
gsb_ct = cutlass_torch.as_tensor(gsb_t, shape=gsb_t.shape)
# Create CuTe tensors (from_dlpack + mark_layout_dynamic for proper TMA)
def _to_cute(t):
    """Wrap ``t`` as a CuTe tensor with a dynamic layout.

    The tensor is imported with ``cutlass_torch.from_dlpack`` and its layout
    is then marked dynamic, using the leading dimension that
    ``cutlass_torch.get_leading_dim`` reports for ``t``.
    """
    wrapped = cutlass_torch.from_dlpack(t)
    return wrapped.mark_layout_dynamic(
        leading_dim=cutlass_torch.get_leading_dim(t)
    )
a_tensor = _to_cute(x_fp4)
b_tensor = _to_cute(mat_b)
sfa_tensor = _to_cute(x_sf)
sfb_tensor = _to_cute(scale_b)
e_bias_ct = _to_cute(e_bias)
out_w_ct = _to_cute(out_weights)
out_id_ct = _to_cute(out_ids)
eo_ct = _to_cute(expert_offsets)
gsa_ct = _to_cute(gsa_t)
gsb_ct = _to_cute(gsb_t)
kernel = Nvfp4FusedRouterKernel(top_k=top_k)
kernel.run(