Fix: convert PyTorch tensors to CuTe tensors for fused router kernel
- Added cutlass_torch.from_dlpack() + mark_layout_dynamic() conversions
- quantize_activation_nvfp4 returns (fp4_packed, fp8_scales), which are converted to CuTe tensors before being passed to the kernel
- Same pattern as gemm_runner.py
This commit is contained in:
@@ -818,13 +818,31 @@ def run_nvfp4_fused_router(
|
||||
|
||||
# Quantize activation to NVFP4
|
||||
from dsv4.ops.quantize import quantize_activation_nvfp4
|
||||
mat_a, scale_a = quantize_activation_nvfp4(hidden_states, gsa)
|
||||
mat_a_bf16_packed, scale_a_fp8 = quantize_activation_nvfp4(hidden_states, gsa)
|
||||
|
||||
# Output tensor: FP32 activated scores [N, E]
|
||||
# We write sqrt(softplus(logits)) + e_bias here,
|
||||
# then top-k reads from it
|
||||
activated_scores = torch.empty(N, E, dtype=torch.float32, device=device)
|
||||
|
||||
# Convert PyTorch tensors to CuTe tensors (same as gemm_runner.py pattern)
|
||||
import cutlass.torch as cutlass_torch
|
||||
|
||||
def _to_cute(t, leading_dim=None):
    """Wrap a PyTorch tensor as a CuTe tensor with a dynamic layout.

    ``t`` is passed to ``cutlass_torch.from_dlpack``. When ``leading_dim``
    is omitted, it is looked up via ``cutlass_torch.get_leading_dim(t)``;
    the result is then marked dynamic along that leading dimension.
    """
    cute_tensor = cutlass_torch.from_dlpack(t)
    if leading_dim is None:
        # No explicit override: let the helper pick the leading dimension.
        leading_dim = cutlass_torch.get_leading_dim(t)
    return cute_tensor.mark_layout_dynamic(leading_dim=leading_dim)
|
||||
|
||||
# Determine leading dimensions from tensor shapes
|
||||
# mat_a_bf16_packed: [N, K_packed] — K-major (row-major for GEMM A)
|
||||
# mat_b: [E, K_packed] — K-major (col-major for GEMM B, i.e. N-major)
|
||||
# NOTE(unverified): the NVFP4 GEMM may instead expect A as M-major and B as N-major,
# which would contradict the K-major layouts described above.
|
||||
# TODO: confirm the expected operand layouts against the existing Nvfp4Linear implementation.
|
||||
cute_a = _to_cute(mat_a_bf16_packed)
|
||||
cute_b = _to_cute(mat_b)
|
||||
cute_sfa = _to_cute(scale_a_fp8)
|
||||
cute_sfb = _to_cute(scale_b)
|
||||
cute_c = _to_cute(activated_scores)
|
||||
|
||||
# Run the CuTeDSL kernel: NVFP4 GEMM + sqrt(softplus) epilogue
|
||||
kernel = Nvfp4FusedRouterKernel(
|
||||
sf_vec_size=16,
|
||||
@@ -832,11 +850,11 @@ def run_nvfp4_fused_router(
|
||||
cluster_shape_mnk=(1, 1, 1),
|
||||
)
|
||||
kernel.run(
|
||||
mat_a=mat_a,
|
||||
mat_b=mat_b,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
mat_c=activated_scores,
|
||||
mat_a=cute_a,
|
||||
mat_b=cute_b,
|
||||
scale_a=cute_sfa,
|
||||
scale_b=cute_sfb,
|
||||
mat_c=cute_c,
|
||||
M=N, N=E, K=K,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user