[ROCm] Apply FP8 weights padding to values not divisible by 512 bytes on ROCm (#13231)

This commit is contained in:
Gregory Shtrasberg
2025-02-22 08:54:38 -05:00
committed by GitHub
parent 558db8083c
commit c904fdddf6
3 changed files with 20 additions and 1 deletion

View File

@@ -494,7 +494,7 @@ def w8a8_block_fp8_matmul(
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
M = A.numel() // A.shape[-1]
-assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
+assert B.ndim == 2 and Bs.ndim == 2
N, K = B.shape
assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1]