[TPU] support attention head dim smaller than 128 (#19620)
Signed-off-by: Chengji Yao <chengjiyao@google.com> Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -67,6 +67,43 @@ def test_basic(
|
||||
assert "1024" in output or "0, 1" in output
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This is a basic test for TPU only")
@pytest.mark.parametrize("max_tokens", [8])
@pytest.mark.parametrize("max_num_seqs", [16])
def test_phi3(
    vllm_runner: type[VllmRunner],
    monkeypatch: pytest.MonkeyPatch,
    max_tokens: int,
    max_num_seqs: int,
) -> None:
    """Greedy-decode a few prompts on TPU with a head-dim-96 model.

    Phi-3-mini has an attention head dim of 96 (< 128), so this exercises
    the TPU backend's support for head dims smaller than 128.
    Each expected answer must appear in the corresponding generated text.
    """
    test_prompts = [
        "A robot may not injure a human being",
        "It is only with the heart that one can see rightly;",
        "The greatest glory in living lies not in never falling,",
    ]
    expected_answers = [
        " or, by violating privacy",
        " what is essential is love.",
        " but in rising every time we fall.",
    ]
    # test head dim = 96
    model = "microsoft/Phi-3-mini-128k-instruct"

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        with vllm_runner(model,
                         max_num_batched_tokens=256,
                         max_num_seqs=max_num_seqs) as vllm_model:
            vllm_outputs = vllm_model.generate_greedy(test_prompts,
                                                      max_tokens)
        # Each element of vllm_outputs is a (token_ids, text) tuple, where
        # the text includes the prompt; check the expected continuation
        # appears in it.
        for (_, generated_text), answer in zip(vllm_outputs,
                                               expected_answers):
            assert answer in generated_text
|
||||
|
||||
|
||||
# NOTE(review): presumably the tensor-parallel size for multi-chip TPU tests
# defined outside this view — confirm against the rest of the file.
TP_SIZE_8 = 8
|
||||
|
||||
|
||||
|
||||