[compile] Add patched_fused_scaled_matmul_reduce_scatter (#26604)

Signed-off-by: angelayi <yiangela7@gmail.com>
This commit is contained in:
Angela Yi
2025-10-11 05:44:43 -07:00
committed by GitHub
parent d0bed837ac
commit a25f2adee9
3 changed files with 119 additions and 6 deletions

View File

@@ -142,7 +142,7 @@ class TestScaledMMRSModel(_BaseScaledMMModel):
return [torch.ops.vllm.reduce_scatter.default]
def ops_in_model_after(self):
return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default]
return [torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter.default]
class TestAGScaledMMModel(_BaseScaledMMModel):
@@ -195,7 +195,7 @@ class TestCutlassScaledMMRSModel(_BaseScaledMMModel):
return [torch.ops.vllm.reduce_scatter.default]
def ops_in_model_after(self):
return [torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter.default]
return [torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter.default]
class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
@@ -243,9 +243,15 @@ class TestAGCutlassScaledMMModel(_BaseScaledMMModel):
@pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("dynamic", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
def test_async_tp_pass_replace(
test_model: str, batch_size: int, seq_len: int, hidden_size: int, dtype: torch.dtype
test_model: str,
batch_size: int,
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
dynamic: bool,
):
if (
test_model
@@ -269,7 +275,15 @@ def test_async_tp_pass_replace(
# torch.distributed and cuda
torch.multiprocessing.spawn(
fn,
args=(num_processes, test_model, batch_size, seq_len, hidden_size, dtype),
args=(
num_processes,
test_model,
batch_size,
seq_len,
hidden_size,
dtype,
dynamic,
),
nprocs=nprocs,
)
@@ -284,6 +298,7 @@ def async_tp_pass_on_test_model(
seq_len: int,
hidden_size: int,
dtype: torch.dtype,
dynamic: bool,
):
current_platform.seed_everything(0)
@@ -331,6 +346,9 @@ def async_tp_pass_on_test_model(
(batch_size * seq_len, hidden_size), dtype=dtype, requires_grad=False
)
if dynamic:
torch._dynamo.mark_dynamic(hidden_states, 0)
compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states)