[compile] Add patched_fused_scaled_matmul_reduce_scatter (#26604)

Signed-off-by: angelayi <yiangela7@gmail.com>
This commit is contained in:
Angela Yi
2025-10-11 05:44:43 -07:00
committed by GitHub
parent d0bed837ac
commit a25f2adee9
3 changed files with 119 additions and 6 deletions

View File

@@ -172,7 +172,7 @@ class ScaledMMReduceScatterPattern(BasePattern):
# Calculate output shape: input @ mat2 with scatter_dim reduced
output_shape = [*input.shape[:-1], mat2.shape[1]]
scatter_dim = 0
gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
input,
mat2,
scale_a,
@@ -307,7 +307,7 @@ class CutlassScaledMMReduceScatterPattern(BasePattern):
# Calculate output shape: input @ mat2 with scatter_dim reduced
output_shape = [*input.shape[:-1], mat2.shape[1]]
scatter_dim = 0
gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
input,
mat2,
scale_a,