[Bugfix] awq_gemm: fix argument order swap (#30364)

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
Matthias Gehre
2025-12-14 11:15:37 +01:00
committed by GitHub
parent 3224ea9915
commit e9add129ad
2 changed files with 7 additions and 7 deletions

View File

@@ -41,9 +41,9 @@ def test_awq_gemm_opcheck(monkeypatch: pytest.MonkeyPatch):
qweight = torch.randint(
-2000000000, 2000000000, (8192, 256), device="cuda", dtype=torch.int32
)
scales = torch.randint(
scales = torch.empty((64, 2048), device="cuda", dtype=torch.float16)
qzeros = torch.randint(
-2000000000, 2000000000, (64, 256), device="cuda", dtype=torch.int32
)
qzeros = torch.empty((64, 2048), device="cuda", dtype=torch.float16)
split_k_iters = 8
opcheck(torch.ops._C.awq_gemm, (input, qweight, qzeros, scales, split_k_iters))
opcheck(torch.ops._C.awq_gemm, (input, qweight, scales, qzeros, split_k_iters))