From 8db96f54d386c3cbf046af8af3a629fb6260b0f0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 19 Apr 2021 14:35:18 -0700 Subject: [PATCH] [mhlo] Add a folder for mhlo.map which does nothing but return one of the arguments. Add a folder for maps whose body returns only one of the arguments. When this arises the fold replaces the map output with one of the operand tensors. PiperOrigin-RevId: 369304322 --- include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td | 1 + lib/Dialect/mhlo/IR/hlo_ops.cc | 14 ++++++++++++++ tests/canonicalize.mlir | 12 ++++++++++++ 3 files changed, 27 insertions(+) diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index e9286fa..3a915b3 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -1068,6 +1068,7 @@ def HLO_MapOp: HLO_Op<"map", ); let regions = (region SizedRegion<1>:$computation); let results = (outs HLO_Tensor); + let hasFolder = 1; let hasCustomHLOConverter = 1; } diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index f2a1aa6..6100223 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -1775,6 +1775,20 @@ static LogicalResult Verify(MapOp op) { return success(); } +OpFoldResult MapOp::fold(ArrayRef operands) { + mlir::Block& bb = computation().front(); + mlir::Operation& front_op = bb.front(); + + auto ret_op = mlir::dyn_cast(front_op); + if (!ret_op) return nullptr; + if (ret_op.results().size() != 1) return nullptr; + + for (mlir::BlockArgument barg : bb.getArguments()) { + if (barg == ret_op.results()[0]) return getOperands()[barg.getArgNumber()]; + } + return nullptr; +} + //===----------------------------------------------------------------------===// // RecvOp //===----------------------------------------------------------------------===// diff --git a/tests/canonicalize.mlir b/tests/canonicalize.mlir index 70a86e1..85816b0 100644 --- a/tests/canonicalize.mlir +++ b/tests/canonicalize.mlir @@ -1735,3 +1735,15 @@ func @reshape_of_same_shape_op_result(%arg: tensor, // CHECK: mhlo.dynamic_reshape // CHECK-NEXT: mhlo.abs // CHECK-NOT: mhlo.dynamic_reshape + +// CHECK-LABEL: @map_op_fold +func @map_op_fold(%arg: tensor, %arg1: tensor) -> tensor { + %0 = "mhlo.map"(%arg, %arg1) ( { + ^bb0(%a: tensor, %b: tensor): // no predecessors + "mhlo.return"(%b) : (tensor) -> () + }) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor, tensor) -> tensor + return %0 : tensor +} +// CHECK: return %arg1 : tensor + +