[Perf] Optimize batch invariant BMM, 18.1% Throughput improvement, 10.7% TTFT improvement (#29345)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Wentao Ye
2025-11-26 11:38:52 -05:00
committed by GitHub
parent 70d5953f82
commit 0b0aa874e8
3 changed files with 217 additions and 16 deletions

View File

@@ -159,7 +159,6 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
"backend",
BACKENDS,
)
-@pytest.mark.forked
def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
backend, monkeypatch: pytest.MonkeyPatch
):
@@ -429,7 +428,6 @@ def test_simple_generation(backend, monkeypatch: pytest.MonkeyPatch):
"backend",
BACKENDS,
)
-@pytest.mark.forked
def test_logprobs_without_batch_invariance_should_fail(
backend, monkeypatch: pytest.MonkeyPatch
):
@@ -646,7 +644,6 @@ def test_logprobs_without_batch_invariance_should_fail(
@skip_unsupported
@pytest.mark.parametrize("backend", ["FLASH_ATTN"])
-@pytest.mark.forked
def test_decode_logprobs_match_prefill_logprobs(
backend, monkeypatch: pytest.MonkeyPatch
):