Use w8a8 quantized matmul Pallas kernel (#19170)
Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
This commit is contained in:
@@ -145,3 +145,35 @@ def test_gemma3_27b_with_text_input_and_tp(
|
||||
for output, answer in zip(vllm_outputs, answers):
|
||||
generated_text = output[1]
|
||||
assert answer in generated_text
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_tpu(),
|
||||
reason="This is a basic test for TPU only")
|
||||
def test_w8a8_quantization(
|
||||
vllm_runner: type[VllmRunner],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
model = "neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8"
|
||||
max_tokens = 5
|
||||
tensor_parallel_size = 1
|
||||
max_num_seqs = 4
|
||||
|
||||
prompt = "The next numbers of the sequence " + ", ".join(
|
||||
str(i) for i in range(1024)) + " are:"
|
||||
example_prompts = [prompt]
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
max_num_batched_tokens=64,
|
||||
max_model_len=4096,
|
||||
gpu_memory_utilization=0.7,
|
||||
max_num_seqs=max_num_seqs,
|
||||
tensor_parallel_size=tensor_parallel_size) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts,
|
||||
max_tokens)
|
||||
output = vllm_outputs[0][1]
|
||||
|
||||
assert "1024" in output or "0, 1" in output
|
||||
|
||||
Reference in New Issue
Block a user