fix: make_b_k_major stride check — K-major means stride[1]==1, not stride[2]==1
For a tensor of shape (E, K, N): stride[2]==1 is N-major (the N index varies fastest in memory). K-major requires stride[1]==1 (the K index varies fastest in memory).
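The stride rule in this message can be checked without any tensor library. The following is a minimal sketch, not code from this repository: `strides_for_layout` and `classify_b_layout` are hypothetical helpers introduced only for illustration, and strides are counted in elements, which is also what PyTorch's `Tensor.stride()` reports.

```python
def strides_for_layout(shape, order):
    """Element strides for `shape` when dims are stored in `order`.

    `order` lists dimension indices from slowest-varying to fastest-varying.
    (Hypothetical helper, for illustration only.)
    """
    strides = [0] * len(shape)
    step = 1
    # Walk from the fastest-varying dim outward, accumulating the step size.
    for dim in reversed(order):
        strides[dim] = step
        step *= shape[dim]
    return tuple(strides)


def classify_b_layout(strides):
    """Classify an (E, K, N) layout by which of K or N has unit stride."""
    if len(strides) != 3:
        raise ValueError("expected strides of a 3-D (E, K, N) tensor")
    if strides[1] == 1:
        return "K-major"
    if strides[2] == 1:
        return "N-major"
    return "neither"


E, K, N = 2, 4, 3
n_major = strides_for_layout((E, K, N), order=(0, 1, 2))  # plain row-major storage
k_major = strides_for_layout((E, K, N), order=(0, 2, 1))  # K index varies fastest

print(n_major, classify_b_layout(n_major))
print(k_major, classify_b_layout(k_major))
```

With these example sizes, plain row-major storage gives strides (12, 3, 1), which the old `stride(2) == 1` check would have accepted as K-major. The K-fastest layout gives (12, 1, 4), which is the case the new `stride(1) == 1` check detects.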
@@ -230,12 +230,12 @@ def quantize_weight_to_nvfp4(w_bf16, block_size=SF_VEC_SIZE):
 def make_b_k_major(b_tensor):
     """Convert B tensor from N-major to K-major (required by kernel).

-    Input: (E, N, K_packed) or (E, K_packed, N)
-    Output: (E, K_packed, N) contiguous in K-major order
+    Input: (E, K_packed, N) — may be N-major or K-major
+    Output: (E, K_packed, N) contiguous in K-major order (stride of K_packed dim == 1)

-    If already K-major (stride[2] == 1), returns as-is.
+    For shape (E, K, N): K-major means stride[1]==1, N-major means stride[2]==1.
     """
-    if b_tensor.stride(2) == 1:
+    if b_tensor.dim() == 3 and b_tensor.stride(1) == 1:
         return b_tensor.contiguous()
     return b_tensor.permute(0, 2, 1).contiguous()