[TPU] support attention head dim smaller than 128 (#19620)

Signed-off-by: Chengji Yao <chengjiyao@google.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Chengji Yao
2025-06-15 23:40:53 -07:00
committed by GitHub
parent b692e9cd07
commit a77aea59fd
2 changed files with 65 additions and 7 deletions

View File

@@ -67,6 +67,43 @@ def test_basic(
assert "1024" in output or "0, 1" in output
@pytest.mark.skipif(not current_platform.is_tpu(),
reason="This is a basic test for TPU only")
@pytest.mark.parametrize("max_tokens", [8])
@pytest.mark.parametrize("max_num_seqs", [16])
def test_phi3(
vllm_runner: type[VllmRunner],
monkeypatch: pytest.MonkeyPatch,
max_tokens: int,
max_num_seqs: int,
) -> None:
prompts = [
"A robot may not injure a human being",
"It is only with the heart that one can see rightly;",
"The greatest glory in living lies not in never falling,",
]
answers = [
" or, by violating privacy",
" what is essential is love.",
" but in rising every time we fall.",
]
# test head dim = 96
model = "microsoft/Phi-3-mini-128k-instruct"
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
with vllm_runner(model,
max_num_batched_tokens=256,
max_num_seqs=max_num_seqs) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens)
# vllm_outputs is a list of tuples whose first element is the token id
# and the second element is the output (including the prompt).
for output, answer in zip(vllm_outputs, answers):
generated_text = output[1]
assert answer in generated_text
TP_SIZE_8 = 8