Handle operands with zero elements in HLO PadOp folder
PiperOrigin-RevId: 348034821
This commit is contained in:
parent
8134bff98d
commit
737d15ded5
|
@ -1879,28 +1879,29 @@ OpFoldResult PadOp::fold(ArrayRef<Attribute> operands) {
|
|||
llvm::ArrayRef<int64_t> shape) {
|
||||
for (int64_t i = index.size() - 1; i >= 0; --i) {
|
||||
++index[i];
|
||||
if (index[i] < shape[i]) return true;
|
||||
if (index[i] < shape[i]) return;
|
||||
index[i] = 0;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
// Iterate over all elements of the input tensor and copy it to the correct
|
||||
// location in the output tensor.
|
||||
llvm::SmallVector<uint64_t, 8> index(input.getType().getRank(), 0);
|
||||
do {
|
||||
uint64_t linear_index = 0;
|
||||
uint64_t linear_index_multiplyer = 1;
|
||||
uint64_t num_elements = input.getNumElements();
|
||||
for (uint64_t operand_idx = 0; operand_idx < num_elements; operand_idx++) {
|
||||
uint64_t result_idx = 0;
|
||||
uint64_t idx_multiplyer = 1;
|
||||
for (int64_t i = index.size() - 1; i >= 0; --i) {
|
||||
linear_index +=
|
||||
result_idx +=
|
||||
(edge_padding_low().getValue<int64_t>({uint64_t(i)}) +
|
||||
index[i] *
|
||||
(interior_padding().getValue<int64_t>({uint64_t(i)}) + 1)) *
|
||||
linear_index_multiplyer;
|
||||
linear_index_multiplyer *= return_type.getShape()[i];
|
||||
idx_multiplyer;
|
||||
idx_multiplyer *= return_type.getDimSize(i);
|
||||
}
|
||||
result[linear_index] = input.getValue(index);
|
||||
} while (next_index(index, input.getType().getShape()));
|
||||
result[result_idx] = input.getValue(index);
|
||||
next_index(index, input.getType().getShape());
|
||||
}
|
||||
return DenseElementsAttr::get(return_type, result);
|
||||
}
|
||||
|
||||
|
|
|
@ -1515,6 +1515,14 @@ func @pad_fold() -> tensor<4x5xi32> {
|
|||
// CHECK-SAME: ]> : tensor<4x5xi32>
|
||||
}
|
||||
|
||||
func @pad_fold_zero_elements() -> tensor<3xi32> {
|
||||
%0 = mhlo.constant dense<> : tensor<0xi32>
|
||||
%1 = mhlo.constant dense<7> : tensor<i32>
|
||||
%2 = "mhlo.pad"(%0, %1) {edge_padding_high = dense<3> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<0xi32>, tensor<i32>) -> tensor<3xi32>
|
||||
return %2 : tensor<3xi32>
|
||||
// CHECK: mhlo.constant dense<7> : tensor<3xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @identity_broadcast_reshape
|
||||
func @identity_broadcast_reshape(%arg0: tensor<128xf32>) -> tensor<128xf32> {
|
||||
%0 = "mhlo.broadcast"(%arg0) {
|
||||
|
|
Loading…
Reference in New Issue