[Kernel] W8A16 Int8 inside FusedMoE (#7415)

This commit is contained in:
Mor Zusman
2024-08-16 20:06:51 +03:00
committed by GitHub
parent e837b624f2
commit 7fc23be81c
15 changed files with 412 additions and 136 deletions

View File

@@ -6,9 +6,12 @@ from vllm.worker.model_runner import _get_graph_batch_size
MODELS = ["ai21labs/Jamba-tiny-random"]
# Fails due to usage of MoE as MLP(E=1_, which is different than the HF impl
# TODO: Fix this with trained model
@pytest.mark.skip()
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [20])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [10])
def test_models(
hf_runner,
vllm_runner,
@@ -17,8 +20,6 @@ def test_models(
dtype: str,
max_tokens: int,
) -> None:
# To pass the small model tests, we need full precision.
assert dtype == "float"
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
@@ -36,8 +37,8 @@ def test_models(
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [20])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
def test_batching(
vllm_runner,
example_prompts,