Fix CuTe tensor creation: use from_dlpack + mark_layout_dynamic
This commit is contained in:
@@ -728,17 +728,22 @@ def run_nvfp4_fused_router(
|
||||
from dsv4.ops.quantize import quantize_activation_nvfp4
|
||||
x_fp4, x_sf = quantize_activation_nvfp4(hidden_states, gsa)
|
||||
|
||||
# Create CuTe tensors
|
||||
a_tensor = cutlass_torch.as_tensor(x_fp4, shape=x_fp4.shape)
|
||||
b_tensor = cutlass_torch.as_tensor(mat_b, shape=mat_b.shape)
|
||||
sfa_tensor = cutlass_torch.as_tensor(x_sf, shape=x_sf.shape)
|
||||
sfb_tensor = cutlass_torch.as_tensor(scale_b, shape=scale_b.shape)
|
||||
e_bias_ct = cutlass_torch.as_tensor(e_bias, shape=e_bias.shape)
|
||||
out_w_ct = cutlass_torch.as_tensor(out_weights, shape=out_weights.shape)
|
||||
out_id_ct = cutlass_torch.as_tensor(out_ids, shape=out_ids.shape)
|
||||
eo_ct = cutlass_torch.as_tensor(expert_offsets, shape=expert_offsets.shape)
|
||||
gsa_ct = cutlass_torch.as_tensor(gsa_t, shape=gsa_t.shape)
|
||||
gsb_ct = cutlass_torch.as_tensor(gsb_t, shape=gsb_t.shape)
|
||||
# Create CuTe tensors (from_dlpack + mark_layout_dynamic for proper TMA)
|
||||
def _to_cute(t):
    """Convert ``t`` into a CuTe tensor whose layout is marked dynamic.

    The tensor is imported with ``cutlass_torch.from_dlpack`` and then
    passed through ``mark_layout_dynamic``, using the leading dimension
    reported by ``cutlass_torch.get_leading_dim(t)``.

    NOTE(review): ``t`` is presumably a torch tensor -- confirm against
    the callers in the enclosing function.
    """
    cute_tensor = cutlass_torch.from_dlpack(t)
    # Look up the leading dim only after the DLPack import, so the calls
    # happen in the same order as before.
    leading = cutlass_torch.get_leading_dim(t)
    return cute_tensor.mark_layout_dynamic(leading_dim=leading)
|
||||
|
||||
a_tensor = _to_cute(x_fp4)
|
||||
b_tensor = _to_cute(mat_b)
|
||||
sfa_tensor = _to_cute(x_sf)
|
||||
sfb_tensor = _to_cute(scale_b)
|
||||
e_bias_ct = _to_cute(e_bias)
|
||||
out_w_ct = _to_cute(out_weights)
|
||||
out_id_ct = _to_cute(out_ids)
|
||||
eo_ct = _to_cute(expert_offsets)
|
||||
gsa_ct = _to_cute(gsa_t)
|
||||
gsb_ct = _to_cute(gsb_t)
|
||||
|
||||
kernel = Nvfp4FusedRouterKernel(top_k=top_k)
|
||||
kernel.run(
|
||||
|
||||
Reference in New Issue
Block a user