[compile] Add patched_fused_scaled_matmul_reduce_scatter (#26604)
Signed-off-by: angelayi <yiangela7@gmail.com>
This commit is contained in:
@@ -172,7 +172,7 @@ class ScaledMMReduceScatterPattern(BasePattern):
|
||||
# Calculate output shape: input @ mat2 with scatter_dim reduced
|
||||
output_shape = [*input.shape[:-1], mat2.shape[1]]
|
||||
scatter_dim = 0
|
||||
gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
|
||||
gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
|
||||
input,
|
||||
mat2,
|
||||
scale_a,
|
||||
@@ -307,7 +307,7 @@ class CutlassScaledMMReduceScatterPattern(BasePattern):
|
||||
# Calculate output shape: input @ mat2 with scatter_dim reduced
|
||||
output_shape = [*input.shape[:-1], mat2.shape[1]]
|
||||
scatter_dim = 0
|
||||
gemm_rs = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
|
||||
gemm_rs = torch.ops.vllm.patched_fused_scaled_matmul_reduce_scatter(
|
||||
input,
|
||||
mat2,
|
||||
scale_a,
|
||||
|
||||
Reference in New Issue
Block a user