Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
This commit is contained in:
committed by
GitHub
parent
67ebaff528
commit
31aedfe7d6
@@ -87,6 +87,13 @@ NKM_FACTORS_WVSPLITK_FP8 = [
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
def pad_weights_fp8(weight):
|
||||
num_pad = 256 // weight.element_size()
|
||||
import torch.nn.functional as F
|
||||
|
||||
return F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITKRC)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@@ -191,11 +198,12 @@ def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
|
||||
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("padded", [False, True])
|
||||
@pytest.mark.skipif(
|
||||
not (current_platform.is_rocm() and current_platform.supports_fp8()),
|
||||
reason="only test for rocm fp8",
|
||||
)
|
||||
def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
|
||||
def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed, padded):
|
||||
torch.manual_seed(seed)
|
||||
|
||||
A = torch.rand(n, k, device="cuda") - 0.5
|
||||
@@ -203,6 +211,8 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
|
||||
|
||||
A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
|
||||
B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
|
||||
if padded:
|
||||
B = pad_weights_fp8(B)
|
||||
|
||||
ref_out = torch._scaled_mm(
|
||||
A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b
|
||||
@@ -222,11 +232,12 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
|
||||
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("padded", [False, True])
|
||||
@pytest.mark.skipif(
|
||||
not (current_platform.is_rocm() and current_platform.supports_fp8()),
|
||||
reason="only test for rocm fp8",
|
||||
)
|
||||
def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed):
|
||||
def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed, padded):
|
||||
torch.manual_seed(seed)
|
||||
|
||||
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
|
||||
@@ -236,6 +247,8 @@ def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed):
|
||||
|
||||
A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
|
||||
B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
|
||||
if padded:
|
||||
B = pad_weights_fp8(B)
|
||||
|
||||
ref_out = torch._scaled_mm(
|
||||
A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b, bias=BIAS
|
||||
|
||||
Reference in New Issue
Block a user