diff --git a/cutedsl/bridge.py b/cutedsl/bridge.py index edfe3e4f..e79cc5b7 100644 --- a/cutedsl/bridge.py +++ b/cutedsl/bridge.py @@ -230,12 +230,12 @@ def quantize_weight_to_nvfp4(w_bf16, block_size=SF_VEC_SIZE): def make_b_k_major(b_tensor): """Convert B tensor from N-major to K-major (required by kernel). - Input: (E, N, K_packed) or (E, K_packed, N) - Output: (E, K_packed, N) contiguous in K-major order + Input: (E, K_packed, N) — may be N-major or K-major + Output: (E, K_packed, N) contiguous in K-major order (stride of K_packed dim == 1) - If already K-major (stride[2] == 1), returns as-is. + For shape (E, K, N): K-major means stride[1]==1, N-major means stride[2]==1. """ - if b_tensor.stride(2) == 1: + if b_tensor.dim() == 3 and b_tensor.stride(1) == 1: return b_tensor.contiguous() return b_tensor.permute(0, 2, 1).contiguous()