[Perf] Disable clean_logits in deepgemm fp8_mqa_logits kernel (#33568)
This commit is contained in:
@@ -95,7 +95,8 @@ def _ref_fp8_mqa_logits(
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
|
||||
)
|
||||
def test_deepgemm_fp8_mqa_logits():
|
||||
@pytest.mark.parametrize("clean_logits", [True, False])
|
||||
def test_deepgemm_fp8_mqa_logits(clean_logits: bool):
|
||||
torch.manual_seed(0)
|
||||
random.seed(0)
|
||||
num_heads, head_dim = 32, 128
|
||||
@@ -126,7 +127,9 @@ def test_deepgemm_fp8_mqa_logits():
|
||||
|
||||
q_fp8 = q.to(torch.float8_e4m3fn)
|
||||
kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0,), False)
|
||||
logits = fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke)
|
||||
logits = fp8_mqa_logits(
|
||||
q_fp8, kv_fp8, weights, ks, ke, clean_logits=clean_logits
|
||||
)
|
||||
|
||||
ref_logits = _ref_fp8_mqa_logits(
|
||||
q=q,
|
||||
@@ -135,13 +138,14 @@ def test_deepgemm_fp8_mqa_logits():
|
||||
cu_seqlen_ks=ks,
|
||||
cu_seqlen_ke=ke,
|
||||
)
|
||||
|
||||
ref_neginf_mask = ref_logits == float("-inf")
|
||||
neginf_mask = logits == float("-inf")
|
||||
assert torch.equal(neginf_mask, ref_neginf_mask)
|
||||
|
||||
if clean_logits:
|
||||
neginf_mask = logits == float("-inf")
|
||||
assert torch.equal(neginf_mask, ref_neginf_mask)
|
||||
|
||||
ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0)
|
||||
logits = logits.masked_fill(neginf_mask, 0)
|
||||
logits = logits.masked_fill(ref_neginf_mask, 0)
|
||||
diff = calc_diff(logits, ref_logits)
|
||||
assert diff < 1e-3, f"{diff=}"
|
||||
|
||||
@@ -201,7 +205,8 @@ def _ref_fp8_paged_mqa_logits(
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
|
||||
)
|
||||
def test_deepgemm_fp8_paged_mqa_logits():
|
||||
@pytest.mark.parametrize("clean_logits", [True, False])
|
||||
def test_deepgemm_fp8_paged_mqa_logits(clean_logits: bool):
|
||||
torch.manual_seed(0)
|
||||
random.seed(0)
|
||||
|
||||
@@ -264,6 +269,7 @@ def test_deepgemm_fp8_paged_mqa_logits():
|
||||
block_tables,
|
||||
schedule_metadata,
|
||||
max_model_len,
|
||||
clean_logits=clean_logits,
|
||||
)
|
||||
|
||||
ref_logits = _ref_fp8_paged_mqa_logits(
|
||||
|
||||
Reference in New Issue
Block a user