Fix: convert PyTorch tensors to CuTe tensors for fused router kernel

- Added cutlass_torch.from_dlpack() + mark_layout_dynamic() conversions
- quantize_activation_nvfp4 returns (fp4_packed, fp8_scales) which are
  converted to CuTe tensors before passing to the kernel
- Same pattern as gemm_runner.py
This commit is contained in:
2026-06-01 10:02:40 +00:00
parent bab748763e
commit 24fed15ed6

View File

@@ -818,13 +818,31 @@ def run_nvfp4_fused_router(
# Quantize activation to NVFP4
from dsv4.ops.quantize import quantize_activation_nvfp4
mat_a, scale_a = quantize_activation_nvfp4(hidden_states, gsa)
mat_a_bf16_packed, scale_a_fp8 = quantize_activation_nvfp4(hidden_states, gsa)
# Output tensor: FP32 activated scores [N, E]
# We write sqrt(softplus(logits)) + e_bias here,
# then top-k reads from it
activated_scores = torch.empty(N, E, dtype=torch.float32, device=device)
# Convert PyTorch tensors to CuTe tensors (same as gemm_runner.py pattern)
import cutlass.torch as cutlass_torch
def _to_cute(t, leading_dim=None):
ct = cutlass_torch.from_dlpack(t)
if leading_dim is not None:
return ct.mark_layout_dynamic(leading_dim=leading_dim)
return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))
# Determine leading dimensions from tensor shapes
# mat_a_bf16_packed: [N, K_packed] — K-major (row-major for GEMM A)
# mat_b: [E, K_packed] — K-major (col-major for GEMM B, i.e. N-major)
# Actually, for NVFP4 GEMM: A is M-major, B is N-major
# Check the existing Nvfp4Linear to see how it handles this
cute_a = _to_cute(mat_a_bf16_packed)
cute_b = _to_cute(mat_b)
cute_sfa = _to_cute(scale_a_fp8)
cute_sfb = _to_cute(scale_b)
cute_c = _to_cute(activated_scores)
# Run the CuTeDSL kernel: NVFP4 GEMM + sqrt(softplus) epilogue
kernel = Nvfp4FusedRouterKernel(
sf_vec_size=16,
@@ -832,11 +850,11 @@ def run_nvfp4_fused_router(
cluster_shape_mnk=(1, 1, 1),
)
kernel.run(
mat_a=mat_a,
mat_b=mat_b,
scale_a=scale_a,
scale_b=scale_b,
mat_c=activated_scores,
mat_a=cute_a,
mat_b=cute_b,
scale_a=cute_sfa,
scale_b=cute_sfb,
mat_c=cute_c,
M=N, N=E, K=K,
)