Handle operands with zero elements in HLO PadOp folder
PiperOrigin-RevId: 348034821
This commit is contained in:
parent
8134bff98d
commit
737d15ded5
|
@ -1879,28 +1879,29 @@ OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
|
||||||
llvm::ArrayRef<int64_t> shape) {
|
llvm::ArrayRef<int64_t> shape) {
|
||||||
for (int64_t i = index.size() - 1; i >= 0; --i) {
|
for (int64_t i = index.size() - 1; i >= 0; --i) {
|
||||||
++index[i];
|
++index[i];
|
||||||
if (index[i] < shape[i]) return true;
|
if (index[i] < shape[i]) return;
|
||||||
index[i] = 0;
|
index[i] = 0;
|
||||||
}
|
}
|
||||||
return false;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Iterate over all elements of the input tensor and copy it to the correct
|
// Iterate over all elements of the input tensor and copy it to the correct
|
||||||
// location in the output tensor.
|
// location in the output tensor.
|
||||||
llvm::SmallVector<uint64_t, 8> index(input.getType().getRank(), 0);
|
llvm::SmallVector<uint64_t, 8> index(input.getType().getRank(), 0);
|
||||||
do {
|
uint64_t num_elements = input.getNumElements();
|
||||||
uint64_t linear_index = 0;
|
for (uint64_t operand_idx = 0; operand_idx < num_elements; operand_idx++) {
|
||||||
uint64_t linear_index_multiplyer = 1;
|
uint64_t result_idx = 0;
|
||||||
|
uint64_t idx_multiplyer = 1;
|
||||||
for (int64_t i = index.size() - 1; i >= 0; --i) {
|
for (int64_t i = index.size() - 1; i >= 0; --i) {
|
||||||
linear_index +=
|
result_idx +=
|
||||||
(edge_padding_low().getValue<int64_t>({uint64_t(i)}) +
|
(edge_padding_low().getValue<int64_t>({uint64_t(i)}) +
|
||||||
index[i] *
|
index[i] *
|
||||||
(interior_padding().getValue<int64_t>({uint64_t(i)}) + 1)) *
|
(interior_padding().getValue<int64_t>({uint64_t(i)}) + 1)) *
|
||||||
linear_index_multiplyer;
|
idx_multiplyer;
|
||||||
linear_index_multiplyer *= return_type.getShape()[i];
|
idx_multiplyer *= return_type.getDimSize(i);
|
||||||
|
}
|
||||||
|
result[result_idx] = input.getValue(index);
|
||||||
|
next_index(index, input.getType().getShape());
|
||||||
}
|
}
|
||||||
result[linear_index] = input.getValue(index);
|
|
||||||
} while (next_index(index, input.getType().getShape()));
|
|
||||||
return DenseElementsAttr::get(return_type, result);
|
return DenseElementsAttr::get(return_type, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1515,6 +1515,14 @@ func @pad_fold() -> tensor<4x5xi32> {
|
||||||
// CHECK-SAME: ]> : tensor<4x5xi32>
|
// CHECK-SAME: ]> : tensor<4x5xi32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func @pad_fold_zero_elements() -> tensor<3xi32> {
|
||||||
|
%0 = mhlo.constant dense<> : tensor<0xi32>
|
||||||
|
%1 = mhlo.constant dense<7> : tensor<i32>
|
||||||
|
%2 = "mhlo.pad"(%0, %1) {edge_padding_high = dense<3> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<0xi32>, tensor<i32>) -> tensor<3xi32>
|
||||||
|
return %2 : tensor<3xi32>
|
||||||
|
// CHECK: mhlo.constant dense<7> : tensor<3xi32>
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @identity_broadcast_reshape
|
// CHECK-LABEL: @identity_broadcast_reshape
|
||||||
func @identity_broadcast_reshape(%arg0: tensor<128xf32>) -> tensor<128xf32> {
|
func @identity_broadcast_reshape(%arg0: tensor<128xf32>) -> tensor<128xf32> {
|
||||||
%0 = "mhlo.broadcast"(%arg0) {
|
%0 = "mhlo.broadcast"(%arg0) {
|
||||||
|
|
Loading…
Reference in New Issue