[ROCm] AITER fused RoPE+KVCache (#33443)
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com> Signed-off-by: charlifu <charlifu@amd.com> Signed-off-by: Rohan Potdar <66227218+Rohan138@users.noreply.github.com> Co-authored-by: charlifu <charlifu@amd.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Douglas Lehr <91553416+dllehr-amd@users.noreply.github.com>
This commit is contained in:
@@ -1,10 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import copy
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from tests.compile.backend import TestBackend
|
||||
from tests.utils import TestFP8Layer
|
||||
from vllm.compilation.passes.fusion.act_quant_fusion import (
|
||||
@@ -31,6 +32,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
TEST_FP8 = current_platform.supports_fp8()
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
@@ -198,23 +200,82 @@ class TestRotaryEmbeddingSliceScatter(torch.nn.Module):
|
||||
return [torch.ops.aten.slice_scatter.default]
|
||||
|
||||
|
||||
MODELS = [
|
||||
TestSiluMul,
|
||||
TestFusedAddRMSNorm,
|
||||
TestRotaryEmbedding,
|
||||
TestRotaryEmbeddingSliceScatter,
|
||||
]
|
||||
class TestFunctionWithMutatedArgsAndReturn(torch.nn.Module):
|
||||
OP_REGISTERED = False
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.register_test_custom_op()
|
||||
|
||||
@classmethod
|
||||
def register_test_custom_op(cls):
|
||||
if not cls.OP_REGISTERED:
|
||||
|
||||
def function_with_mutated_args_and_return_impl(
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
ret = x + 1
|
||||
x.add_(2)
|
||||
return ret
|
||||
|
||||
def function_with_mutated_args_and_return_fake(
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(x)
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="function_with_mutated_args_and_return",
|
||||
op_func=function_with_mutated_args_and_return_impl,
|
||||
mutates_args=["x"],
|
||||
fake_impl=function_with_mutated_args_and_return_fake,
|
||||
)
|
||||
|
||||
cls.OP_REGISTERED = True
|
||||
|
||||
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Clone x to avoid mutating the original tensor
|
||||
ret = torch.ops.vllm.function_with_mutated_args_and_return(x)
|
||||
return x, ret
|
||||
|
||||
def example_inputs(self, num_tokens=32):
|
||||
hidden_states = torch.randn(num_tokens)
|
||||
return (hidden_states,)
|
||||
|
||||
def ops_in_model(self, do_fusion):
|
||||
return [torch.ops.vllm.function_with_mutated_args_and_return.default]
|
||||
|
||||
def ops_not_in_model(self):
|
||||
return []
|
||||
|
||||
|
||||
MODELS_AND_DO_FUSION = {
|
||||
TestSiluMul: [True, False],
|
||||
TestFusedAddRMSNorm: [True, False],
|
||||
TestRotaryEmbedding: [False],
|
||||
TestRotaryEmbeddingSliceScatter: [False],
|
||||
TestFunctionWithMutatedArgsAndReturn: [False],
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("model_class", MODELS)
|
||||
@pytest.mark.parametrize("do_fusion", [True, False])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA")
|
||||
@pytest.mark.parametrize(
|
||||
"model_class, do_fusion",
|
||||
[
|
||||
(model_class, do_fusion)
|
||||
for model_class, fusions in MODELS_AND_DO_FUSION.items()
|
||||
for do_fusion in fusions
|
||||
],
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda_alike(),
|
||||
reason="Only test on cuda and rocm platform",
|
||||
)
|
||||
def test_fix_functionalization(
|
||||
model_class: torch.nn.Module, do_fusion: bool, dtype: torch.dtype
|
||||
):
|
||||
torch.set_default_device("cuda")
|
||||
torch.set_default_dtype(dtype)
|
||||
torch.manual_seed(0)
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
model_config=ModelConfig(dtype=dtype),
|
||||
@@ -246,8 +307,14 @@ def test_fix_functionalization(
|
||||
backend_no_func = TestBackend(*passes)
|
||||
|
||||
model = model_class()
|
||||
torch.compile(model, backend=backend_func)(*model.example_inputs())
|
||||
torch.compile(model, backend=backend_no_func)(*model.example_inputs())
|
||||
inputs_func = model.example_inputs()
|
||||
inputs_no_func = copy.deepcopy(inputs_func)
|
||||
model_func = model_class()
|
||||
model_no_func = copy.deepcopy(model_func)
|
||||
model_func = torch.compile(model_func, backend=backend_func)
|
||||
model_no_func = torch.compile(model_no_func, backend=backend_no_func)
|
||||
model_func(*inputs_func)
|
||||
model_no_func(*inputs_no_func)
|
||||
|
||||
# check if the functionalization pass is applied
|
||||
for op in model.ops_in_model(do_fusion):
|
||||
@@ -265,3 +332,8 @@ def test_fix_functionalization(
|
||||
found[op] = True
|
||||
assert all(found[op] for op in model.ops_in_model(do_fusion))
|
||||
assert all(not found.get(op) for op in model.ops_not_in_model())
|
||||
|
||||
# TODO (Rohan138): compare the outputs from model_func and model_no_func
|
||||
# currently runs into errors while comparing `TestFusedAddRMSNorm`
|
||||
# Linked issue: https://github.com/vllm-project/vllm/issues/34996
|
||||
# torch.testing.assert_close(outputs_func, outputs_no_func)
|
||||
|
||||
Reference in New Issue
Block a user