[compile] Add patched_fused_scaled_matmul_reduce_scatter (#26604)
Signed-off-by: angelayi <yiangela7@gmail.com>
This commit is contained in:
@@ -142,7 +142,7 @@ class TestScaledMMRSModel(_BaseScaledMMModel):
|
||||
return [torch.ops.vllm.reduce_scatter.default]
|
||||
|
||||
def ops_in_model_after(self):
|
||||
- return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default]
|
||||
+ return [torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter.default]
|
||||
|
||||
|
||||
class TestAGScaledMMModel(_BaseScaledMMModel):
|
||||
@@ -195,7 +195,7 @@ class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
|
||||
return [torch.ops.vllm.reduce_scatter.default]
|
||||
|
||||
def ops_in_model_after(self):
|
||||
- return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default]
|
||||
+ return [torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter.default]
|
||||
|
||||
|
||||
class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
|
||||
@@ -243,9 +243,15 @@ class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
|
||||
@pytest.mark.parametrize("seq_len", [16])
|
||||
@pytest.mark.parametrize("hidden_size", [16])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
+ @pytest.mark.parametrize("dynamic", [True, False])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
|
||||
def test_async_tp_pass_replace(
|
||||
- test_model: str, batch_size: int, seq_len: int, hidden_size: int, dtype: torch.dtype
|
||||
+ test_model: str,
|
||||
+ batch_size: int,
|
||||
+ seq_len: int,
|
||||
+ hidden_size: int,
|
||||
+ dtype: torch.dtype,
|
||||
+ dynamic: bool,
|
||||
):
|
||||
if (
|
||||
test_model
|
||||
@@ -269,7 +275,15 @@ def test_async_tp_pass_replace(
|
||||
# torch.distributed and cuda
|
||||
torch.multiprocessing.spawn(
|
||||
fn,
|
||||
- args=(num_processes, test_model, batch_size, seq_len, hidden_size, dtype),
|
||||
+ args=(
|
||||
+ num_processes,
|
||||
+ test_model,
|
||||
+ batch_size,
|
||||
+ seq_len,
|
||||
+ hidden_size,
|
||||
+ dtype,
|
||||
+ dynamic,
|
||||
+ ),
|
||||
nprocs=nprocs,
|
||||
)
|
||||
|
||||
@@ -284,6 +298,7 @@ def async_tp_pass_on_test_model(
|
||||
seq_len: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
+ dynamic: bool,
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
@@ -331,6 +346,9 @@ def async_tp_pass_on_test_model(
|
||||
(batch_size * seq_len, hidden_size), dtype=dtype, requires_grad=False
|
||||
)
|
||||
|
||||
+ if dynamic:
|
||||
+ torch._dynamo.mark_dynamic(hidden_states, 0)
|
||||
|
||||
compiled_model = torch.compile(model, backend=backend)
|
||||
compiled_model(hidden_states)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user