Handle operands with zero elements in HLO PadOp folder

PiperOrigin-RevId: 348034821
Smit Hinsu 2020-12-17 09:25:26 -08:00 committed by TensorFlow MLIR Team
parent 8134bff98d
commit 737d15ded5
2 changed files with 19 additions and 10 deletions
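In short: the old folder walked the operand with a do/while loop, whose body executes at least once, so a zero-element operand still attempted to read and scatter element 0. The new code iterates exactly input.getNumElements() times and therefore skips the copy entirely for empty operands. A minimal sketch of the idea in plain C++, reduced to one dimension and with illustrative names only (pad_1d is not the actual mhlo folder code):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Pads a 1-D "tensor" the way the folder does: the result starts out filled
// with the padding value, then each input element is scattered to its slot.
std::vector<int> pad_1d(const std::vector<int>& input, int pad_value,
                        int64_t edge_low, int64_t edge_high,
                        int64_t interior) {
  int64_t in_size = static_cast<int64_t>(input.size());
  // Output size rule for one dimension of an HLO pad.
  int64_t out_size = edge_low + edge_high + in_size +
                     std::max<int64_t>(in_size - 1, 0) * interior;
  std::vector<int> result(out_size, pad_value);

  // A do/while here would run its body once even when in_size == 0 and read
  // input[0] out of bounds; the counted loop simply does nothing in that case.
  for (int64_t i = 0; i < in_size; ++i) {
    int64_t result_idx = edge_low + i * (interior + 1);
    result[result_idx] = input[i];
  }
  return result;
}

int main() {
  // Mirrors the new test below: pad a zero-element operand with value 7 and
  // high padding of 3 -> three copies of the padding value.
  for (int v : pad_1d({}, /*pad_value=*/7, /*edge_low=*/0, /*edge_high=*/3,
                      /*interior=*/0))
    std::cout << v << ' ';  // prints: 7 7 7
  std::cout << '\n';
}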

View File

@@ -1879,28 +1879,29 @@ OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
                        llvm::ArrayRef<int64_t> shape) {
     for (int64_t i = index.size() - 1; i >= 0; --i) {
       ++index[i];
-      if (index[i] < shape[i]) return true;
+      if (index[i] < shape[i]) return;
       index[i] = 0;
     }
-    return false;
   };
 
   // Iterate over all elements of the input tensor and copy it to the correct
   // location in the output tensor.
   llvm::SmallVector<uint64_t, 8> index(input.getType().getRank(), 0);
-  do {
-    uint64_t linear_index = 0;
-    uint64_t linear_index_multiplyer = 1;
+  uint64_t num_elements = input.getNumElements();
+  for (uint64_t operand_idx = 0; operand_idx < num_elements; operand_idx++) {
+    uint64_t result_idx = 0;
+    uint64_t idx_multiplyer = 1;
     for (int64_t i = index.size() - 1; i >= 0; --i) {
-      linear_index +=
+      result_idx +=
           (edge_padding_low().getValue<int64_t>({uint64_t(i)}) +
            index[i] *
                (interior_padding().getValue<int64_t>({uint64_t(i)}) + 1)) *
-          linear_index_multiplyer;
-      linear_index_multiplyer *= return_type.getShape()[i];
+          idx_multiplyer;
+      idx_multiplyer *= return_type.getDimSize(i);
     }
-    result[linear_index] = input.getValue(index);
-  } while (next_index(index, input.getType().getShape()));
+    result[result_idx] = input.getValue(index);
+    next_index(index, input.getType().getShape());
+  }
 
   return DenseElementsAttr::get(return_type, result);
 }
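With the counted loop, next_index no longer needs to signal wrap-around, so its bool return value goes away; termination is now driven by input.getNumElements(), which is zero for an empty operand. In that case the copy loop never runs and the result buffer, presumably pre-filled with the padding value earlier in fold(), is returned as-is, which is what the new test below checks.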

View File

@@ -1515,6 +1515,14 @@ func @pad_fold() -> tensor<4x5xi32> {
   // CHECK-SAME: ]> : tensor<4x5xi32>
 }
 
+func @pad_fold_zero_elements() -> tensor<3xi32> {
+  %0 = mhlo.constant dense<> : tensor<0xi32>
+  %1 = mhlo.constant dense<7> : tensor<i32>
+  %2 = "mhlo.pad"(%0, %1) {edge_padding_high = dense<3> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<0xi32>, tensor<i32>) -> tensor<3xi32>
+  return %2 : tensor<3xi32>
+  // CHECK: mhlo.constant dense<7> : tensor<3xi32>
+}
+
 // CHECK-LABEL: @identity_broadcast_reshape
 func @identity_broadcast_reshape(%arg0: tensor<128xf32>) -> tensor<128xf32> {
   %0 = "mhlo.broadcast"(%arg0) {