Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
This commit is contained in:
committed by
GitHub
parent
67ebaff528
commit
31aedfe7d6
@@ -28,6 +28,7 @@ def rocm_per_tensor_float_w8a8_scaled_mm_impl(
|
||||
A.shape[0] == 1
|
||||
and B.shape[1] % 16 == 0
|
||||
and ((bias is None) or (bias.dtype == out_dtype))
|
||||
and A.is_contiguous()
|
||||
):
|
||||
output = ops.wvSplitKQ(
|
||||
B.t(),
|
||||
|
||||
@@ -165,6 +165,7 @@ def rocm_unquantized_gemm_impl(
|
||||
and n <= 128
|
||||
and k > 512
|
||||
and math.ceil(k / 512) * math.ceil(m / 16) < get_cu_count()
|
||||
and x.is_contiguous()
|
||||
)
|
||||
# k == 2880 and (m == 640 or m == 128))
|
||||
)
|
||||
@@ -179,6 +180,7 @@ def rocm_unquantized_gemm_impl(
|
||||
and on_gfx9()
|
||||
and x.dtype in [torch.float16, torch.bfloat16]
|
||||
and k % 8 == 0
|
||||
and x.is_contiguous()
|
||||
)
|
||||
|
||||
if use_skinny is not True:
|
||||
|
||||
Reference in New Issue
Block a user