Add folder for mhlo::pad

PiperOrigin-RevId: 337827560
This commit is contained in:
A. Unique TensorFlower 2020-10-19 04:20:57 -07:00 committed by TensorFlow MLIR Team
parent 4a18aa41ee
commit 7a983ea389
3 changed files with 83 additions and 0 deletions

View File

@ -1201,6 +1201,8 @@ def HLO_PadOp: HLO_Op<"pad",
// TODO(b/129422361): PadOp has a custom constructor for HLO.
let hasCustomHLOConverter = 1;
let hasFolder = 1;
}
def HLO_TraceOp: HLO_Op<"trace", []>, BASE_HLO_TraceOp {

View File

@ -1787,6 +1787,61 @@ static LogicalResult Verify(PadOp op) {
return success();
}
OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
  // Folds mhlo.pad in two cases:
  //  1. identity pad (all padding amounts zero) -> the input operand;
  //  2. constant input + constant padding value with a static result shape
  //     -> the materialized padded constant.
  // Returns {} when neither case applies.

  // If all padding is zero then it is an identity pad.
  auto is_zero = [](const APInt& i) { return i == 0; };
  if (llvm::all_of(edge_padding_low().getIntValues(), is_zero) &&
      llvm::all_of(edge_padding_high().getIntValues(), is_zero) &&
      llvm::all_of(interior_padding().getIntValues(), is_zero))
    return operand();

  // If any padding is negative then it isn't supported by the folder (yet).
  // BUGFIX: this previously used llvm::all_of(...) joined with &&, which only
  // bailed out when *every* entry of *every* padding attribute was negative;
  // mixed-sign padding then fell through and was folded incorrectly.
  auto is_negative = [](const APInt& i) { return i.slt(0); };
  if (llvm::any_of(edge_padding_low().getIntValues(), is_negative) ||
      llvm::any_of(edge_padding_high().getIntValues(), is_negative) ||
      llvm::any_of(interior_padding().getIntValues(), is_negative))
    return {};

  // Constant folding needs a constant ranked input, a constant padding value,
  // and a fully static result shape to size the output buffer.
  DenseElementsAttr input = operands[0].dyn_cast_or_null<DenseElementsAttr>();
  DenseElementsAttr padding = operands[1].dyn_cast_or_null<DenseElementsAttr>();
  RankedTensorType return_type = getType().dyn_cast_or_null<RankedTensorType>();
  if (!input || !input.getType().hasRank() || !padding || !return_type ||
      !return_type.hasStaticShape())
    return {};

  // Fill the full result tensor with the padding value; input elements are
  // then scattered into their padded positions below.
  llvm::SmallVector<Attribute, 4> result(return_type.getNumElements(),
                                         padding.getValue({}));

  // Advances a multi-dimensional index in row-major order; returns false once
  // every position in `shape` has been visited.
  auto next_index = [](llvm::SmallVector<uint64_t, 8>& index,
                       llvm::ArrayRef<int64_t> shape) {
    for (int64_t i = index.size() - 1; i >= 0; --i) {
      ++index[i];
      if (index[i] < shape[i]) return true;
      index[i] = 0;
    }
    return false;
  };

  // Iterate over all elements of the input tensor and copy it to the correct
  // location in the output tensor.
  llvm::SmallVector<uint64_t, 8> index(input.getType().getRank(), 0);
  do {
    // Linearize the destination coordinate: each input element at index[i]
    // lands at low_padding[i] + index[i] * (interior_padding[i] + 1) along
    // dimension i of the (row-major) result.
    uint64_t linear_index = 0;
    uint64_t linear_index_multiplier = 1;
    for (int64_t i = index.size() - 1; i >= 0; --i) {
      linear_index +=
          (edge_padding_low().getValue<int64_t>({uint64_t(i)}) +
           index[i] *
               (interior_padding().getValue<int64_t>({uint64_t(i)}) + 1)) *
          linear_index_multiplier;
      linear_index_multiplier *= return_type.getShape()[i];
    }
    result[linear_index] = input.getValue(index);
  } while (next_index(index, input.getType().getShape()));
  return DenseElementsAttr::get(return_type, result);
}
//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//

View File

@ -1437,3 +1437,29 @@ func @scatter_out_of_bound() -> tensor<3x3xi32> {
// CHECK: "mhlo.scatter"
}
// An all-zero pad (no low, high, or interior padding) must fold away,
// returning the original operand untouched.
// CHECK-LABEL: @pad_identity_fold
func @pad_identity_fold(%arg0: tensor<5x7xf32>) -> tensor<5x7xf32> {
  %zero = constant dense<0.0> : tensor<f32>
  %padded = "mhlo.pad"(%arg0, %zero) {
    edge_padding_high = dense<0> : tensor<2xi64>,
    edge_padding_low = dense<0> : tensor<2xi64>,
    interior_padding = dense<0> : tensor<2xi64>
  } : (tensor<5x7xf32>, tensor<f32>) -> tensor<5x7xf32>
  return %padded : tensor<5x7xf32>
  // CHECK: return %arg0 : tensor<5x7xf32>
}
// A pad of a constant input with a constant padding value folds to the
// fully materialized padded constant (here with interior padding on dim 1).
// CHECK-LABEL: @pad_fold
func @pad_fold() -> tensor<4x5xi32> {
  %input = constant dense<[[2, 3], [4, 5]]> : tensor<2x2xi32>
  %pad_value = constant dense<1> : tensor<i32>
  %padded = "mhlo.pad"(%input, %pad_value) {
    edge_padding_high = dense<[1, 2]> : tensor<2xi64>,
    edge_padding_low = dense<[1, 0]> : tensor<2xi64>,
    interior_padding = dense<[0, 1]> : tensor<2xi64>
  } : (tensor<2x2xi32>, tensor<i32>) -> tensor<4x5xi32>
  return %padded : tensor<4x5xi32>
  // CHECK: constant dense<[
  // CHECK-SAME: [1, 1, 1, 1, 1], [2, 1, 3, 1, 1], [4, 1, 5, 1, 1], [1, 1, 1, 1, 1]
  // CHECK-SAME: ]> : tensor<4x5xi32>
}