[Perf] Disable clean_logits in deepgemm fp8_mqa_logits kernel (#33568)

This commit is contained in:
Xin Yang
2026-02-05 17:34:00 -08:00
committed by GitHub
parent 325ab6b0a8
commit 79028d4388
4 changed files with 61 additions and 27 deletions

View File

@@ -95,7 +95,8 @@ def _ref_fp8_mqa_logits(
@pytest.mark.skipif(
not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
)
def test_deepgemm_fp8_mqa_logits():
@pytest.mark.parametrize("clean_logits", [True, False])
def test_deepgemm_fp8_mqa_logits(clean_logits: bool):
torch.manual_seed(0)
random.seed(0)
num_heads, head_dim = 32, 128
@@ -126,7 +127,9 @@ def test_deepgemm_fp8_mqa_logits():
q_fp8 = q.to(torch.float8_e4m3fn)
kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0,), False)
logits = fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke)
logits = fp8_mqa_logits(
q_fp8, kv_fp8, weights, ks, ke, clean_logits=clean_logits
)
ref_logits = _ref_fp8_mqa_logits(
q=q,
@@ -135,13 +138,14 @@ def test_deepgemm_fp8_mqa_logits():
cu_seqlen_ks=ks,
cu_seqlen_ke=ke,
)
ref_neginf_mask = ref_logits == float("-inf")
neginf_mask = logits == float("-inf")
assert torch.equal(neginf_mask, ref_neginf_mask)
if clean_logits:
neginf_mask = logits == float("-inf")
assert torch.equal(neginf_mask, ref_neginf_mask)
ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0)
logits = logits.masked_fill(neginf_mask, 0)
logits = logits.masked_fill(ref_neginf_mask, 0)
diff = calc_diff(logits, ref_logits)
assert diff < 1e-3, f"{diff=}"
@@ -201,7 +205,8 @@ def _ref_fp8_paged_mqa_logits(
@pytest.mark.skipif(
not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
)
def test_deepgemm_fp8_paged_mqa_logits():
@pytest.mark.parametrize("clean_logits", [True, False])
def test_deepgemm_fp8_paged_mqa_logits(clean_logits: bool):
torch.manual_seed(0)
random.seed(0)
@@ -264,6 +269,7 @@ def test_deepgemm_fp8_paged_mqa_logits():
block_tables,
schedule_metadata,
max_model_len,
clean_logits=clean_logits,
)
ref_logits = _ref_fp8_paged_mqa_logits(