Adds padding and perf improvements to wvSplitK_fp8 (#33527)

Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com>
This commit is contained in:
Hashem Hashemi
2026-02-05 14:16:02 -08:00
committed by GitHub
parent 42d5d705f9
commit d5c4800112
3 changed files with 169 additions and 229 deletions

View File

@@ -25,10 +25,10 @@ def rocm_per_tensor_float_w8a8_scaled_mm_impl(
bias: torch.Tensor,
) -> torch.Tensor:
if (
-A.shape[0] == 1
-and B.shape[1] % 16 == 0
+A.shape[0] <= 4
+and B.shape[0] % 16 == 0 # M TODO: needed?
+and B.shape[1] % 16 == 0 # K
and ((bias is None) or (bias.dtype == out_dtype))
and A.is_contiguous()
):
output = ops.wvSplitKQ(
B.t(),