Adds padding and perf improvements to wvSplitK_fp8 (#33527)
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
This commit is contained in:
@@ -25,10 +25,10 @@ def rocm_per_tensor_float_w8a8_scaled_mm_impl(
|
||||
bias: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
if (
|
||||
- A.shape[0] == 1
|
||||
- and B.shape[1] % 16 == 0
|
||||
+ A.shape[0] <= 4
|
||||
+ and B.shape[0] % 16 == 0 # M TODO: needed?
|
||||
+ and B.shape[1] % 16 == 0 # K
|
||||
and ((bias is None) or (bias.dtype == out_dtype))
|
||||
and A.is_contiguous()
|
||||
):
|
||||
output = ops.wvSplitKQ(
|
||||
B.t(),
|
||||
|
||||
Reference in New Issue
Block a user