Add folder for mhlo::scatter
PiperOrigin-RevId: 337274351
This commit is contained in:
parent
0f36979e2c
commit
05ee41baf8
|
@ -1124,6 +1124,8 @@ def HLO_ScatterOp: HLO_Op<"scatter", [RecursiveSideEffects]>,
|
||||||
let results = (outs HLO_Tensor);
|
let results = (outs HLO_Tensor);
|
||||||
|
|
||||||
let hasCustomHLOConverter = 1;
|
let hasCustomHLOConverter = 1;
|
||||||
|
|
||||||
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(jpienaar): Add broadcastable trait.
|
// TODO(jpienaar): Add broadcastable trait.
|
||||||
|
|
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||||
#include "llvm/ADT/APFloat.h"
|
#include "llvm/ADT/APFloat.h"
|
||||||
#include "llvm/ADT/APInt.h"
|
#include "llvm/ADT/APInt.h"
|
||||||
#include "llvm/ADT/ArrayRef.h"
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
|
@ -2645,6 +2646,145 @@ OpFoldResult CompareOp::fold(ArrayRef<Attribute> operands) {
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ScatterOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
llvm::SmallVector<Attribute, 4> evaluateMhloRegion(Region& region,
|
||||||
|
ArrayRef<Attribute> inputs) {
|
||||||
|
if (region.getNumArguments() != inputs.size()) return {};
|
||||||
|
|
||||||
|
llvm::DenseMap<Value, Attribute> values;
|
||||||
|
values.reserve(region.getNumArguments());
|
||||||
|
for (auto it : llvm::zip(region.getArguments(), inputs)) {
|
||||||
|
values.try_emplace(std::get<0>(it), std::get<1>(it));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto& op : region.getOps()) {
|
||||||
|
llvm::SmallVector<Attribute, 4> inputs;
|
||||||
|
for (auto& operand : op.getOpOperands()) {
|
||||||
|
inputs.push_back(values.lookup(operand.get()));
|
||||||
|
}
|
||||||
|
if (isa<ReturnOp>(op)) return inputs;
|
||||||
|
|
||||||
|
llvm::SmallVector<OpFoldResult, 4> results;
|
||||||
|
if (failed(op.fold(inputs, results))) return {};
|
||||||
|
for (auto it : llvm::zip(op.getResults(), results)) {
|
||||||
|
if (!std::get<1>(it).is<Attribute>()) return {};
|
||||||
|
values.insert({std::get<0>(it), std::get<1>(it).get<Attribute>()});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
OpFoldResult ScatterOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
auto base = operands[0].dyn_cast_or_null<DenseElementsAttr>();
|
||||||
|
auto index = operands[1].dyn_cast_or_null<DenseIntElementsAttr>();
|
||||||
|
auto update = operands[2].dyn_cast_or_null<DenseElementsAttr>();
|
||||||
|
if (!base || !index || !update) return {};
|
||||||
|
|
||||||
|
auto base_type = base.getType().dyn_cast<RankedTensorType>();
|
||||||
|
auto index_type = index.getType().dyn_cast<RankedTensorType>();
|
||||||
|
auto update_type = update.getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!base_type || !index_type || !update_type) return {};
|
||||||
|
|
||||||
|
// Add the virtual trailing dimension of size 1 if index_vector_dim equals to
|
||||||
|
// index_type.rank.
|
||||||
|
const int64_t index_vector_dim =
|
||||||
|
scatter_dimension_numbers().index_vector_dim().getInt();
|
||||||
|
if (index_vector_dim == index_type.getRank()) {
|
||||||
|
auto index_shape = index_type.getShape().vec();
|
||||||
|
index_shape.push_back(1);
|
||||||
|
index_type =
|
||||||
|
RankedTensorType::get(index_shape, index_type.getElementType());
|
||||||
|
index = index.reshape(index_type).cast<DenseIntElementsAttr>();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Increment the multi-dimensional index vector based on the limits for each
|
||||||
|
// dimension specified by shape and returns false if the index rolled around
|
||||||
|
// with true otherwise.
|
||||||
|
auto next_index = [](llvm::SmallVector<uint64_t, 8>& index,
|
||||||
|
llvm::ArrayRef<int64_t> shape) {
|
||||||
|
for (int64_t i = index.size() - 1; i >= 0; --i) {
|
||||||
|
++index[i];
|
||||||
|
if (index[i] < shape[i]) return true;
|
||||||
|
index[i] = 0;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Iterate over all elements of the update tensor, then find the corresponding
|
||||||
|
// value in the indices tensor to determine which location we have to update
|
||||||
|
// in the base/result tensor.
|
||||||
|
llvm::SmallVector<Attribute, 8> results(base.getValues<Attribute>());
|
||||||
|
llvm::SmallVector<uint64_t, 8> update_index(update_type.getRank(), 0);
|
||||||
|
llvm::SmallVector<uint64_t, 8> index_index;
|
||||||
|
index_index.reserve(index_type.getRank());
|
||||||
|
llvm::SmallVector<uint64_t, 8> base_index;
|
||||||
|
base_index.reserve(base_type.getRank());
|
||||||
|
do {
|
||||||
|
// Compute the index for the slice of the indices tensor for this update
|
||||||
|
// value.
|
||||||
|
index_index.clear();
|
||||||
|
if (index_vector_dim == 0) index_index.push_back(0);
|
||||||
|
for (int64_t i = 0; i < update_index.size(); ++i) {
|
||||||
|
if (llvm::count(scatter_dimension_numbers().update_window_dims(), i) == 0)
|
||||||
|
index_index.push_back(update_index[i]);
|
||||||
|
if (index_index.size() == index_vector_dim) index_index.push_back(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the index for the given update value in the base tensor.
|
||||||
|
base_index.assign(base_type.getRank(), 0);
|
||||||
|
uint64_t index_count = index_type.getShape()[index_vector_dim];
|
||||||
|
for (uint64_t i = 0; i < index_count; ++i) {
|
||||||
|
uint64_t operand_dim = scatter_dimension_numbers()
|
||||||
|
.scatter_dims_to_operand_dims()
|
||||||
|
.getValue<APInt>({i})
|
||||||
|
.getSExtValue();
|
||||||
|
index_index[index_vector_dim] = i;
|
||||||
|
base_index[operand_dim] +=
|
||||||
|
index.getValue<APInt>(index_index).getSExtValue();
|
||||||
|
}
|
||||||
|
uint64_t update_window_dim_index = 0;
|
||||||
|
for (uint64_t i = 0; i < base_index.size(); ++i) {
|
||||||
|
if (llvm::count(scatter_dimension_numbers().inserted_window_dims(), i))
|
||||||
|
continue;
|
||||||
|
base_index[i] +=
|
||||||
|
update_index[scatter_dimension_numbers()
|
||||||
|
.update_window_dims()
|
||||||
|
.getValue<APInt>({update_window_dim_index})
|
||||||
|
.getSExtValue()];
|
||||||
|
update_window_dim_index++;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the linear index for the index into the base tensor.
|
||||||
|
int64_t linear_base_index = 0;
|
||||||
|
int64_t linear_base_index_multiplyer = 1;
|
||||||
|
for (int64_t i = base_index.size() - 1; i >= 0; --i) {
|
||||||
|
// Out of bound index have backend specific behaviour so avoid folding it.
|
||||||
|
if (base_index[i] < 0 || base_index[i] >= base_type.getShape()[i])
|
||||||
|
return {};
|
||||||
|
linear_base_index += base_index[i] * linear_base_index_multiplyer;
|
||||||
|
linear_base_index_multiplyer *= base_type.getShape()[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Evaluate update computation and update the value with the newly computed
|
||||||
|
// attribute in the base tensor.
|
||||||
|
auto lhs = DenseElementsAttr::get(
|
||||||
|
RankedTensorType::get({}, base_type.getElementType()),
|
||||||
|
results[linear_base_index]);
|
||||||
|
auto rhs = DenseElementsAttr::get(
|
||||||
|
RankedTensorType::get({}, base_type.getElementType()),
|
||||||
|
update.getValue<Attribute>(update_index));
|
||||||
|
auto new_value = evaluateMhloRegion(update_computation(), {lhs, rhs});
|
||||||
|
if (new_value.size() != 1 || !new_value[0]) return {};
|
||||||
|
results[linear_base_index] =
|
||||||
|
new_value[0].cast<DenseElementsAttr>().getValue<Attribute>({});
|
||||||
|
} while (next_index(update_index, update_type.getShape()));
|
||||||
|
|
||||||
|
return DenseElementsAttr::get(base_type, results);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace mhlo
|
} // namespace mhlo
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
|
|
|
@ -1167,3 +1167,265 @@ func @not_fold_sqrt_neg_constants() -> tensor<4xf32> {
|
||||||
// CHECK: mhlo.sqrt
|
// CHECK: mhlo.sqrt
|
||||||
return %1 : tensor<4xf32>
|
return %1 : tensor<4xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @tensor_flow_scatter_v1_update
|
||||||
|
func @tensor_flow_scatter_v1_update() -> tensor<3x3xi32> {
|
||||||
|
%0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32>
|
||||||
|
%1 = constant dense<[0, 2]> : tensor<2xi32>
|
||||||
|
%2 = constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32>
|
||||||
|
%3 = "mhlo.scatter"(%0, %1, %2) ( {
|
||||||
|
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
|
||||||
|
"mhlo.return"(%arg1) : (tensor<i32>) -> ()
|
||||||
|
}) {indices_are_sorted = false,
|
||||||
|
scatter_dimension_numbers = {
|
||||||
|
index_vector_dim = 1 : i64,
|
||||||
|
inserted_window_dims = dense<0> : tensor<1xi64>,
|
||||||
|
scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
|
||||||
|
update_window_dims = dense<[1]> : tensor<1xi64>
|
||||||
|
},
|
||||||
|
unique_indices = false
|
||||||
|
} : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32>
|
||||||
|
return %3 : tensor<3x3xi32>
|
||||||
|
// CHECK: mhlo.constant dense<[
|
||||||
|
// CHECK-SAME: [10, 20, 30], [4, 5, 6], [70, 80, 90]
|
||||||
|
// CHECK-SAME: ]> : tensor<3x3xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @tensor_flow_scatter_v2_update
|
||||||
|
func @tensor_flow_scatter_v2_update() -> tensor<3x3xi32> {
|
||||||
|
%0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32>
|
||||||
|
%1 = constant dense<[0, 2]> : tensor<2xi32>
|
||||||
|
%2 = constant dense<[[10, 30], [40, 60], [70, 90]]> : tensor<3x2xi32>
|
||||||
|
%3 = "mhlo.scatter"(%0, %1, %2) ( {
|
||||||
|
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
|
||||||
|
"mhlo.return"(%arg1) : (tensor<i32>) -> ()
|
||||||
|
}) {indices_are_sorted = false,
|
||||||
|
scatter_dimension_numbers = {
|
||||||
|
index_vector_dim = 1 : i64,
|
||||||
|
inserted_window_dims = dense<1> : tensor<1xi64>,
|
||||||
|
scatter_dims_to_operand_dims = dense<1> : tensor<1xi64>,
|
||||||
|
update_window_dims = dense<[0]> : tensor<1xi64>
|
||||||
|
},
|
||||||
|
unique_indices = false
|
||||||
|
} : (tensor<3x3xi32>, tensor<2xi32>, tensor<3x2xi32>) -> tensor<3x3xi32>
|
||||||
|
return %3 : tensor<3x3xi32>
|
||||||
|
// CHECK: mhlo.constant dense<[
|
||||||
|
// CHECK-SAME: [10, 2, 30], [40, 5, 60], [70, 8, 90]
|
||||||
|
// CHECK-SAME: ]> : tensor<3x3xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @tensor_flow_scatter_add
|
||||||
|
func @tensor_flow_scatter_add() -> tensor<3x3xi32> {
|
||||||
|
%0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32>
|
||||||
|
%1 = constant dense<[0, 2]> : tensor<2xi32>
|
||||||
|
%2 = constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32>
|
||||||
|
%3 = "mhlo.scatter"(%0, %1, %2) ( {
|
||||||
|
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
|
||||||
|
%4 = "mhlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
|
||||||
|
"mhlo.return"(%4) : (tensor<i32>) -> ()
|
||||||
|
}) {indices_are_sorted = false,
|
||||||
|
scatter_dimension_numbers = {
|
||||||
|
index_vector_dim = 1 : i64,
|
||||||
|
inserted_window_dims = dense<0> : tensor<1xi64>,
|
||||||
|
scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
|
||||||
|
update_window_dims = dense<[1]> : tensor<1xi64>
|
||||||
|
},
|
||||||
|
unique_indices = false
|
||||||
|
} : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32>
|
||||||
|
return %3 : tensor<3x3xi32>
|
||||||
|
// CHECK: mhlo.constant dense<[
|
||||||
|
// CHECK-SAME: [11, 22, 33], [4, 5, 6], [77, 88, 99]
|
||||||
|
// CHECK-SAME: ]> : tensor<3x3xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @tensor_flow_scatter_repeated
|
||||||
|
func @tensor_flow_scatter_repeated() -> tensor<3x3xi32> {
|
||||||
|
%0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32>
|
||||||
|
%1 = constant dense<[1, 1]> : tensor<2xi32>
|
||||||
|
%2 = constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32>
|
||||||
|
%3 = "mhlo.scatter"(%0, %1, %2) ( {
|
||||||
|
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
|
||||||
|
%4 = "mhlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
|
||||||
|
"mhlo.return"(%4) : (tensor<i32>) -> ()
|
||||||
|
}) {indices_are_sorted = false,
|
||||||
|
scatter_dimension_numbers = {
|
||||||
|
index_vector_dim = 1 : i64,
|
||||||
|
inserted_window_dims = dense<0> : tensor<1xi64>,
|
||||||
|
scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
|
||||||
|
update_window_dims = dense<[1]> : tensor<1xi64>
|
||||||
|
},
|
||||||
|
unique_indices = false
|
||||||
|
} : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32>
|
||||||
|
return %3 : tensor<3x3xi32>
|
||||||
|
// CHECK: mhlo.constant dense<[
|
||||||
|
// CHECK-SAME: [1, 2, 3], [84, 105, 126], [7, 8, 9]
|
||||||
|
// CHECK-SAME: ]> : tensor<3x3xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @tensor_flow_scatter_multiple_batch
|
||||||
|
func @tensor_flow_scatter_multiple_batch() -> tensor<3x3xi32> {
|
||||||
|
%0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32>
|
||||||
|
%1 = constant dense<[[0, 2], [2, 1]]> : tensor<2x2xi32>
|
||||||
|
%2 = constant dense<[[[10, 30], [40, 60], [70, 90]], [[5, 5], [5, 5], [5, 5]]]> : tensor<2x3x2xi32>
|
||||||
|
%3 = "mhlo.scatter"(%0, %1, %2) ( {
|
||||||
|
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
|
||||||
|
%4 = "mhlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
|
||||||
|
"mhlo.return"(%4) : (tensor<i32>) -> ()
|
||||||
|
}) {indices_are_sorted = false,
|
||||||
|
scatter_dimension_numbers = {
|
||||||
|
index_vector_dim = 2 : i64,
|
||||||
|
inserted_window_dims = dense<1> : tensor<1xi64>,
|
||||||
|
scatter_dims_to_operand_dims = dense<1> : tensor<1xi64>,
|
||||||
|
update_window_dims = dense<[1]> : tensor<1xi64>
|
||||||
|
},
|
||||||
|
unique_indices = false
|
||||||
|
} : (tensor<3x3xi32>, tensor<2x2xi32>, tensor<2x3x2xi32>) -> tensor<3x3xi32>
|
||||||
|
return %3 : tensor<3x3xi32>
|
||||||
|
// CHECK: mhlo.constant dense<[
|
||||||
|
// CHECK-SAME: [11, 7, 38], [44, 10, 71], [77, 13, 104]
|
||||||
|
// CHECK-SAME: ]> : tensor<3x3xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @tensor_flow_scatter_nd
|
||||||
|
func @tensor_flow_scatter_nd() -> tensor<3x3x2xi32> {
|
||||||
|
%0 = constant dense<[[[-1, 1], [-2, 2], [-3, 3]], [[-4, 4], [-5, 5], [-6, 6]], [[-7, 7], [-8, 8], [-9, 9]]]> : tensor<3x3x2xi32>
|
||||||
|
%1 = constant dense<[[0, 0], [1, 0]]> : tensor<2x2xi32>
|
||||||
|
%2 = constant dense<[[-10, 10], [-40, 40]]> : tensor<2x2xi32>
|
||||||
|
%3 = "mhlo.scatter"(%0, %1, %2) ( {
|
||||||
|
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
|
||||||
|
"mhlo.return"(%arg1) : (tensor<i32>) -> ()
|
||||||
|
}) {indices_are_sorted = false,
|
||||||
|
scatter_dimension_numbers = {
|
||||||
|
index_vector_dim = 1 : i64,
|
||||||
|
inserted_window_dims = dense<[0, 1]> : tensor<2xi64>,
|
||||||
|
scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>,
|
||||||
|
update_window_dims = dense<[1]> : tensor<1xi64>
|
||||||
|
},
|
||||||
|
unique_indices = false
|
||||||
|
} : (tensor<3x3x2xi32>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<3x3x2xi32>
|
||||||
|
return %3 : tensor<3x3x2xi32>
|
||||||
|
// CHECK: mhlo.constant dense<[
|
||||||
|
// CHECK-SAME: [-10, 10], [-2, 2], [-3, 3]
|
||||||
|
// CHECK-SAME: [-40, 40], [-5, 5], [-6, 6]
|
||||||
|
// CHECK-SAME: [-7, 7], [-8, 8], [-9, 9]
|
||||||
|
// CHECK-SAME: ]> : tensor<3x3x2xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @tensor_flow_scatter_nd_index_vector
|
||||||
|
func @tensor_flow_scatter_nd_index_vector() -> tensor<3x3x2xi32> {
|
||||||
|
%0 = constant dense<[[[-1, 1], [-2, 2], [-3, 3]], [[-4, 4], [-5, 5], [-6, 6]], [[-7, 7], [-8, 8], [-9, 9]]]> : tensor<3x3x2xi32>
|
||||||
|
%1 = constant dense<[[0, 0], [1, 0]]> : tensor<2x2xi32>
|
||||||
|
%2 = constant dense<[[-10, 10], [-20, 20]]> : tensor<2x2xi32>
|
||||||
|
%3 = "mhlo.scatter"(%0, %1, %2) ( {
|
||||||
|
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
|
||||||
|
"mhlo.return"(%arg1) : (tensor<i32>) -> ()
|
||||||
|
}) {indices_are_sorted = false,
|
||||||
|
scatter_dimension_numbers = {
|
||||||
|
index_vector_dim = 0 : i64,
|
||||||
|
inserted_window_dims = dense<[0, 1]> : tensor<2xi64>,
|
||||||
|
scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>,
|
||||||
|
update_window_dims = dense<[1]> : tensor<1xi64>
|
||||||
|
},
|
||||||
|
unique_indices = false
|
||||||
|
} : (tensor<3x3x2xi32>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<3x3x2xi32>
|
||||||
|
return %3 : tensor<3x3x2xi32>
|
||||||
|
// CHECK: mhlo.constant dense<[
|
||||||
|
// CHECK-SAME: [-20, 20], [-10, 10], [-3, 3]
|
||||||
|
// CHECK-SAME: [-4, 4], [-5, 5], [-6, 6]
|
||||||
|
// CHECK-SAME: [-7, 7], [-8, 8], [-9, 9]
|
||||||
|
// CHECK-SAME: ]> : tensor<3x3x2xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @scatter_batch_dus
|
||||||
|
func @scatter_batch_dus() -> tensor<3x3xi32> {
|
||||||
|
%0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32>
|
||||||
|
%1 = constant dense<[[2, 1], [1, 1]]> : tensor<2x2xi32>
|
||||||
|
%2 = constant dense<[[[10]], [[20]]]> : tensor<2x1x1xi32>
|
||||||
|
%3 = "mhlo.scatter"(%0, %1, %2) ( {
|
||||||
|
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
|
||||||
|
"mhlo.return"(%arg1) : (tensor<i32>) -> ()
|
||||||
|
}) {indices_are_sorted = false,
|
||||||
|
scatter_dimension_numbers = {
|
||||||
|
index_vector_dim = 0 : i64,
|
||||||
|
inserted_window_dims = dense<> : tensor<0xi64>,
|
||||||
|
scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>,
|
||||||
|
update_window_dims = dense<[1, 2]> : tensor<2xi64>
|
||||||
|
},
|
||||||
|
unique_indices = false
|
||||||
|
} : (tensor<3x3xi32>, tensor<2x2xi32>, tensor<2x1x1xi32>) -> tensor<3x3xi32>
|
||||||
|
return %3 : tensor<3x3xi32>
|
||||||
|
// CHECK: mhlo.constant dense<[
|
||||||
|
// CHECK-SAME: [1, 2, 3], [4, 20, 6], [7, 10, 9]
|
||||||
|
// CHECK-SAME: ]> : tensor<3x3xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @scatter_no_update_window_dim
|
||||||
|
func @scatter_no_update_window_dim() -> tensor<3xi32> {
|
||||||
|
%0 = constant dense<[0, 1, 2]> : tensor<3xi32>
|
||||||
|
%1 = constant dense<[[[0], [1]], [[2], [1]]]> : tensor<2x2x1xi32>
|
||||||
|
%2 = constant dense<[[10, 20], [30, 40]]> : tensor<2x2xi32>
|
||||||
|
%3 = "mhlo.scatter"(%0, %1, %2) ( {
|
||||||
|
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
|
||||||
|
%4 = "mhlo.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> (tensor<i32>)
|
||||||
|
"mhlo.return"(%4) : (tensor<i32>) -> ()
|
||||||
|
}) {indices_are_sorted = false,
|
||||||
|
scatter_dimension_numbers = {
|
||||||
|
index_vector_dim = 2 : i64,
|
||||||
|
inserted_window_dims = dense<0> : tensor<1xi64>,
|
||||||
|
scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
|
||||||
|
update_window_dims = dense<> : tensor<0xi64>
|
||||||
|
},
|
||||||
|
unique_indices = false
|
||||||
|
} : (tensor<3xi32>, tensor<2x2x1xi32>, tensor<2x2xi32>) -> tensor<3xi32>
|
||||||
|
return %3 : tensor<3xi32>
|
||||||
|
// CHECK: mhlo.constant dense<[10, 61, 32]> : tensor<3xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @scatter_negative_index
|
||||||
|
func @scatter_negative_index() -> tensor<3x3xi32> {
|
||||||
|
%0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32>
|
||||||
|
%1 = constant dense<[0, -1]> : tensor<2xi32>
|
||||||
|
%2 = constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32>
|
||||||
|
%3 = "mhlo.scatter"(%0, %1, %2) ( {
|
||||||
|
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
|
||||||
|
"mhlo.return"(%arg1) : (tensor<i32>) -> ()
|
||||||
|
}) {indices_are_sorted = false,
|
||||||
|
scatter_dimension_numbers = {
|
||||||
|
index_vector_dim = 1 : i64,
|
||||||
|
inserted_window_dims = dense<0> : tensor<1xi64>,
|
||||||
|
scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
|
||||||
|
update_window_dims = dense<[1]> : tensor<1xi64>
|
||||||
|
},
|
||||||
|
unique_indices = false
|
||||||
|
} : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32>
|
||||||
|
return %3 : tensor<3x3xi32>
|
||||||
|
// CHECK: constant dense<[
|
||||||
|
// CHECK-SAME: [1, 2, 3], [4, 5, 6], [7, 8, 9]
|
||||||
|
// CHECK-SAME: ]> : tensor<3x3xi32>
|
||||||
|
// CHECK: "mhlo.scatter"
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @scatter_out_of_bound
|
||||||
|
func @scatter_out_of_bound() -> tensor<3x3xi32> {
|
||||||
|
%0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32>
|
||||||
|
%1 = constant dense<[1, 5]> : tensor<2xi32>
|
||||||
|
%2 = constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32>
|
||||||
|
%3 = "mhlo.scatter"(%0, %1, %2) ( {
|
||||||
|
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>):
|
||||||
|
"mhlo.return"(%arg1) : (tensor<i32>) -> ()
|
||||||
|
}) {indices_are_sorted = false,
|
||||||
|
scatter_dimension_numbers = {
|
||||||
|
index_vector_dim = 1 : i64,
|
||||||
|
inserted_window_dims = dense<0> : tensor<1xi64>,
|
||||||
|
scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>,
|
||||||
|
update_window_dims = dense<[1]> : tensor<1xi64>
|
||||||
|
},
|
||||||
|
unique_indices = false
|
||||||
|
} : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32>
|
||||||
|
return %3 : tensor<3x3xi32>
|
||||||
|
// CHECK: constant dense<[
|
||||||
|
// CHECK-SAME: [1, 2, 3], [4, 5, 6], [7, 8, 9]
|
||||||
|
// CHECK-SAME: ]> : tensor<3x3xi32>
|
||||||
|
// CHECK: "mhlo.scatter"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue