[WIP] TPU V1 Support Refactored (#13049)
This commit is contained in:
committed by
GitHub
parent
b0ccfc565a
commit
45f90bcbba
@@ -21,7 +21,7 @@ TASK = "gsm8k"
|
||||
FILTER = "exact_match,strict-match"
|
||||
RTOL = 0.03
|
||||
EXPECTED_VALUE = 0.58
|
||||
DEFAULT_ARGS = ["--max-model-len", "2048", "--disable-log-requests"]
|
||||
DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"]
|
||||
MORE_ARGS_LIST = [
|
||||
[], # Default
|
||||
["--enable-chunked-prefill"], # Chunked
|
||||
@@ -67,14 +67,21 @@ def run_test(more_args):
|
||||
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(),
|
||||
reason="V1 currently only supported on CUDA")
|
||||
@pytest.mark.skipif(not current_platform.is_cuda()
|
||||
and not current_platform.is_tpu(),
|
||||
reason="V1 currently only supported on CUDA and TPU")
|
||||
def test_lm_eval_accuracy_v1_engine(monkeypatch):
|
||||
"""Run with the V1 Engine."""
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
run_test([])
|
||||
more_args = []
|
||||
|
||||
# Limit compilation time for V1
|
||||
if current_platform.is_tpu():
|
||||
more_args = ["--max-num-seqs", "64"]
|
||||
|
||||
run_test(more_args)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("more_args", MORE_ARGS_LIST)
|
||||
|
||||
Reference in New Issue
Block a user