diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index e9286fa..3a915b3 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -1068,6 +1068,7 @@ def HLO_MapOp: HLO_Op<"map", ); let regions = (region SizedRegion<1>:$computation); let results = (outs HLO_Tensor); + let hasFolder = 1; let hasCustomHLOConverter = 1; } diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index f2a1aa6..6100223 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -1775,6 +1775,20 @@ static LogicalResult Verify(MapOp op) { return success(); } +OpFoldResult MapOp::fold(ArrayRef operands) { + mlir::Block& bb = computation().front(); + mlir::Operation& front_op = bb.front(); + + auto ret_op = mlir::dyn_cast(front_op); + if (!ret_op) return nullptr; + if (ret_op.results().size() != 1) return nullptr; + + for (mlir::BlockArgument barg : bb.getArguments()) { + if (barg == ret_op.results()[0]) return getOperands()[barg.getArgNumber()]; + } + return nullptr; +} + //===----------------------------------------------------------------------===// // RecvOp //===----------------------------------------------------------------------===// diff --git a/tests/canonicalize.mlir b/tests/canonicalize.mlir index 70a86e1..85816b0 100644 --- a/tests/canonicalize.mlir +++ b/tests/canonicalize.mlir @@ -1735,3 +1735,15 @@ func @reshape_of_same_shape_op_result(%arg: tensor, // CHECK: mhlo.dynamic_reshape // CHECK-NEXT: mhlo.abs // CHECK-NOT: mhlo.dynamic_reshape + +// CHECK-LABEL: @map_op_fold +func @map_op_fold(%arg: tensor, %arg1: tensor) -> tensor { + %0 = "mhlo.map"(%arg, %arg1) ( { + ^bb0(%a: tensor, %b: tensor): // no predecessors + "mhlo.return"(%b) : (tensor) -> () + }) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor, tensor) -> tensor + return %0 : tensor +} +// CHECK: return %arg1 : tensor + +