[Performance] Improve Triton prefill attention kernel's performance (#32403)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py
2026-01-18 12:19:59 +08:00
committed by GitHub
parent 4a6af8813f
commit 8cc26acd8b
3 changed files with 32 additions and 47 deletions

View File

@@ -46,7 +46,7 @@ def test_bert_models(
for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
hf_output = hf_output.detach().clone().cpu().float()
vllm_output = vllm_output.detach().clone().cpu().float()
torch.testing.assert_close(hf_output, vllm_output, atol=1.2e-2, rtol=1e-3)
torch.testing.assert_close(hf_output, vllm_output, atol=3.2e-2, rtol=1e-3)
@pytest.mark.parametrize("model", ["disham993/electrical-ner-ModernBERT-base"])
@@ -86,7 +86,7 @@ def test_modernbert_models(
for hf_output, vllm_output in zip(hf_outputs, vllm_outputs):
hf_output = hf_output.detach().clone().cpu().float()
vllm_output = vllm_output.detach().clone().cpu().float()
torch.testing.assert_close(hf_output, vllm_output, atol=1.2e-2, rtol=1e-3)
torch.testing.assert_close(hf_output, vllm_output, atol=3.2e-2, rtol=1e-3)
@pytest.mark.parametrize("model", ["bd2lcco/Qwen3-0.6B-finetuned"])