[mhlo] Add a folder for mhlo.map which does nothing but return one of the arguments.
Add a folder for maps whose body returns only one of the arguments. When this arises the fold replaces the map output with one of the operand tensors. PiperOrigin-RevId: 369304322
This commit is contained in:
parent
9374a1c0c5
commit
8db96f54d3
|
@ -1068,6 +1068,7 @@ def HLO_MapOp: HLO_Op<"map",
|
||||||
);
|
);
|
||||||
let regions = (region SizedRegion<1>:$computation);
|
let regions = (region SizedRegion<1>:$computation);
|
||||||
let results = (outs HLO_Tensor);
|
let results = (outs HLO_Tensor);
|
||||||
|
let hasFolder = 1;
|
||||||
let hasCustomHLOConverter = 1;
|
let hasCustomHLOConverter = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1775,6 +1775,20 @@ static LogicalResult Verify(MapOp op) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
OpFoldResult MapOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
mlir::Block& bb = computation().front();
|
||||||
|
mlir::Operation& front_op = bb.front();
|
||||||
|
|
||||||
|
auto ret_op = mlir::dyn_cast<ReturnOp>(front_op);
|
||||||
|
if (!ret_op) return nullptr;
|
||||||
|
if (ret_op.results().size() != 1) return nullptr;
|
||||||
|
|
||||||
|
for (mlir::BlockArgument barg : bb.getArguments()) {
|
||||||
|
if (barg == ret_op.results()[0]) return getOperands()[barg.getArgNumber()];
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// RecvOp
|
// RecvOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -1735,3 +1735,15 @@ func @reshape_of_same_shape_op_result(%arg: tensor<?xf32>,
|
||||||
// CHECK: mhlo.dynamic_reshape
|
// CHECK: mhlo.dynamic_reshape
|
||||||
// CHECK-NEXT: mhlo.abs
|
// CHECK-NEXT: mhlo.abs
|
||||||
// CHECK-NOT: mhlo.dynamic_reshape
|
// CHECK-NOT: mhlo.dynamic_reshape
|
||||||
|
|
||||||
|
// CHECK-LABEL: @map_op_fold
|
||||||
|
func @map_op_fold(%arg: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
|
||||||
|
%0 = "mhlo.map"(%arg, %arg1) ( {
|
||||||
|
^bb0(%a: tensor<f32>, %b: tensor<f32>): // no predecessors
|
||||||
|
"mhlo.return"(%b) : (tensor<f32>) -> ()
|
||||||
|
}) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||||
|
return %0 : tensor<?xf32>
|
||||||
|
}
|
||||||
|
// CHECK: return %arg1 : tensor<?xf32>
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue