[mhlo] Add a folder for mhlo.map which does nothing but return one of the arguments.
Add a folder for maps whose body returns only one of the arguments. When this arises the fold replaces the map output with one of the operand tensors. PiperOrigin-RevId: 369304322
This commit is contained in:
parent
9374a1c0c5
commit
8db96f54d3
|
@ -1068,6 +1068,7 @@ def HLO_MapOp: HLO_Op<"map",
|
|||
);
|
||||
let regions = (region SizedRegion<1>:$computation);
|
||||
let results = (outs HLO_Tensor);
|
||||
let hasFolder = 1;
|
||||
let hasCustomHLOConverter = 1;
|
||||
}
|
||||
|
||||
|
|
|
@ -1775,6 +1775,20 @@ static LogicalResult Verify(MapOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
OpFoldResult MapOp::fold(ArrayRef<Attribute> operands) {
|
||||
mlir::Block& bb = computation().front();
|
||||
mlir::Operation& front_op = bb.front();
|
||||
|
||||
auto ret_op = mlir::dyn_cast<ReturnOp>(front_op);
|
||||
if (!ret_op) return nullptr;
|
||||
if (ret_op.results().size() != 1) return nullptr;
|
||||
|
||||
for (mlir::BlockArgument barg : bb.getArguments()) {
|
||||
if (barg == ret_op.results()[0]) return getOperands()[barg.getArgNumber()];
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RecvOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -1735,3 +1735,15 @@ func @reshape_of_same_shape_op_result(%arg: tensor<?xf32>,
|
|||
// CHECK: mhlo.dynamic_reshape
|
||||
// CHECK-NEXT: mhlo.abs
|
||||
// CHECK-NOT: mhlo.dynamic_reshape
|
||||
|
||||
// CHECK-LABEL: @map_op_fold
|
||||
func @map_op_fold(%arg: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
|
||||
%0 = "mhlo.map"(%arg, %arg1) ( {
|
||||
^bb0(%a: tensor<f32>, %b: tensor<f32>): // no predecessors
|
||||
"mhlo.return"(%b) : (tensor<f32>) -> ()
|
||||
}) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
// CHECK: return %arg1 : tensor<?xf32>
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue