63 lines
2.1 KiB
Python
63 lines
2.1 KiB
Python
|
|
# SPDX-License-Identifier: Apache-2.0
|
||
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
import torch
|
||
|
|
|
||
|
|
import vllm
|
||
|
|
from tests.compile.backend import TestBackend
|
||
|
|
from vllm.compilation.passes.utility.split_coalescing import SplitCoalescingPass
|
||
|
|
from vllm.config import CompilationConfig, CompilationMode, PassConfig, VllmConfig
|
||
|
|
|
||
|
|
|
||
|
|
class SplitCoalescingModel(torch.nn.Module):
|
||
|
|
"""Model with 3 separate split_with_sizes calls on the same input,
|
||
|
|
simulating the B200+FP8 graph where CSE fails to merge them."""
|
||
|
|
|
||
|
|
def __init__(self, q_size: int, kv_size: int) -> None:
|
||
|
|
super().__init__()
|
||
|
|
self.q_size = q_size
|
||
|
|
self.kv_size = kv_size
|
||
|
|
|
||
|
|
def forward(self, qkv: torch.Tensor):
|
||
|
|
q, _, _ = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||
|
|
_, k, _ = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||
|
|
_, _, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||
|
|
return q + 1, k + 2, v + 3
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||
|
|
def test_split_coalescing(dtype):
|
||
|
|
torch.set_default_device("cuda")
|
||
|
|
torch.set_default_dtype(dtype)
|
||
|
|
torch.manual_seed(0)
|
||
|
|
|
||
|
|
q_size, kv_size = 2048, 512
|
||
|
|
|
||
|
|
vllm_config = VllmConfig(
|
||
|
|
compilation_config=CompilationConfig(
|
||
|
|
mode=CompilationMode.VLLM_COMPILE,
|
||
|
|
pass_config=PassConfig(),
|
||
|
|
)
|
||
|
|
)
|
||
|
|
with vllm.config.set_current_vllm_config(vllm_config):
|
||
|
|
coalesce_pass = SplitCoalescingPass(vllm_config)
|
||
|
|
backend = TestBackend(coalesce_pass)
|
||
|
|
|
||
|
|
model = SplitCoalescingModel(q_size, kv_size)
|
||
|
|
|
||
|
|
T = 5
|
||
|
|
qkv = torch.randn(T, q_size + 2 * kv_size)
|
||
|
|
torch._dynamo.mark_dynamic(qkv, 0)
|
||
|
|
|
||
|
|
result_eager = model(qkv)
|
||
|
|
|
||
|
|
model_compiled = torch.compile(model, backend=backend)
|
||
|
|
result_compiled = model_compiled(qkv)
|
||
|
|
|
||
|
|
ATOL, RTOL = (2e-3, 2e-3)
|
||
|
|
for eager, compiled in zip(result_eager, result_compiled):
|
||
|
|
torch.testing.assert_close(eager, compiled, atol=ATOL, rtol=RTOL)
|
||
|
|
|
||
|
|
assert backend.op_count(torch.ops.aten.split_with_sizes.default) == 1
|