From 05ee41baf85181d66d3e8bc68cfe70306d558fa7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 15 Oct 2020 03:25:34 -0700 Subject: [PATCH] Add folder for mhlo::scatter PiperOrigin-RevId: 337274351 --- include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td | 2 + lib/Dialect/mhlo/IR/hlo_ops.cc | 140 +++++++++++ tests/canonicalize.mlir | 262 ++++++++++++++++++++ 3 files changed, 404 insertions(+) diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index 507f7c1..cb431bd 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -1124,6 +1124,8 @@ def HLO_ScatterOp: HLO_Op<"scatter", [RecursiveSideEffects]>, let results = (outs HLO_Tensor); let hasCustomHLOConverter = 1; + + let hasFolder = 1; } // TODO(jpienaar): Add broadcastable trait. diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index 0a5bb0e..86f048f 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -27,6 +27,7 @@ limitations under the License. #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" @@ -2645,6 +2646,145 @@ OpFoldResult CompareOp::fold(ArrayRef operands) { return {}; } +//===----------------------------------------------------------------------===// +// ScatterOp +//===----------------------------------------------------------------------===// + +llvm::SmallVector evaluateMhloRegion(Region& region, + ArrayRef inputs) { + if (region.getNumArguments() != inputs.size()) return {}; + + llvm::DenseMap values; + values.reserve(region.getNumArguments()); + for (auto it : llvm::zip(region.getArguments(), inputs)) { + values.try_emplace(std::get<0>(it), std::get<1>(it)); + } + + for (auto& op : region.getOps()) { + llvm::SmallVector inputs; + for (auto& operand : op.getOpOperands()) { + inputs.push_back(values.lookup(operand.get())); + } + if (isa(op)) return inputs; + + llvm::SmallVector results; + if (failed(op.fold(inputs, results))) return {}; + for (auto it : llvm::zip(op.getResults(), results)) { + if (!std::get<1>(it).is()) return {}; + values.insert({std::get<0>(it), std::get<1>(it).get()}); + } + } + return {}; +} + +OpFoldResult ScatterOp::fold(ArrayRef operands) { + auto base = operands[0].dyn_cast_or_null(); + auto index = operands[1].dyn_cast_or_null(); + auto update = operands[2].dyn_cast_or_null(); + if (!base || !index || !update) return {}; + + auto base_type = base.getType().dyn_cast(); + auto index_type = index.getType().dyn_cast(); + auto update_type = update.getType().dyn_cast(); + if (!base_type || !index_type || !update_type) return {}; + + // Add the virtual trailing dimension of size 1 if index_vector_dim equals to + // index_type.rank. + const int64_t index_vector_dim = + scatter_dimension_numbers().index_vector_dim().getInt(); + if (index_vector_dim == index_type.getRank()) { + auto index_shape = index_type.getShape().vec(); + index_shape.push_back(1); + index_type = + RankedTensorType::get(index_shape, index_type.getElementType()); + index = index.reshape(index_type).cast(); + } + + // Increment the multi-dimensional index vector based on the limits for each + // dimension specified by shape and returns false if the index rolled around + // with true otherwise. + auto next_index = [](llvm::SmallVector& index, + llvm::ArrayRef shape) { + for (int64_t i = index.size() - 1; i >= 0; --i) { + ++index[i]; + if (index[i] < shape[i]) return true; + index[i] = 0; + } + return false; + }; + + // Iterate over all elements of the update tensor, then find the corresponding + // value in the indices tensor to determine which location we have to update + // in the base/result tensor. + llvm::SmallVector results(base.getValues()); + llvm::SmallVector update_index(update_type.getRank(), 0); + llvm::SmallVector index_index; + index_index.reserve(index_type.getRank()); + llvm::SmallVector base_index; + base_index.reserve(base_type.getRank()); + do { + // Compute the index for the slice of the indices tensor for this update + // value. + index_index.clear(); + if (index_vector_dim == 0) index_index.push_back(0); + for (int64_t i = 0; i < update_index.size(); ++i) { + if (llvm::count(scatter_dimension_numbers().update_window_dims(), i) == 0) + index_index.push_back(update_index[i]); + if (index_index.size() == index_vector_dim) index_index.push_back(0); + } + + // Compute the index for the given update value in the base tensor. + base_index.assign(base_type.getRank(), 0); + uint64_t index_count = index_type.getShape()[index_vector_dim]; + for (uint64_t i = 0; i < index_count; ++i) { + uint64_t operand_dim = scatter_dimension_numbers() + .scatter_dims_to_operand_dims() + .getValue({i}) + .getSExtValue(); + index_index[index_vector_dim] = i; + base_index[operand_dim] += + index.getValue(index_index).getSExtValue(); + } + uint64_t update_window_dim_index = 0; + for (uint64_t i = 0; i < base_index.size(); ++i) { + if (llvm::count(scatter_dimension_numbers().inserted_window_dims(), i)) + continue; + base_index[i] += + update_index[scatter_dimension_numbers() + .update_window_dims() + .getValue({update_window_dim_index}) + .getSExtValue()]; + update_window_dim_index++; + } + + // Compute the linear index for the index into the base tensor. + int64_t linear_base_index = 0; + int64_t linear_base_index_multiplyer = 1; + for (int64_t i = base_index.size() - 1; i >= 0; --i) { + // Out of bound index have backend specific behaviour so avoid folding it. + if (base_index[i] < 0 || base_index[i] >= base_type.getShape()[i]) + return {}; + linear_base_index += base_index[i] * linear_base_index_multiplyer; + linear_base_index_multiplyer *= base_type.getShape()[i]; + } + + // Evaluate update computation and update the value with the newly computed + // attribute in the base tensor. + auto lhs = DenseElementsAttr::get( + RankedTensorType::get({}, base_type.getElementType()), + results[linear_base_index]); + auto rhs = DenseElementsAttr::get( + RankedTensorType::get({}, base_type.getElementType()), + update.getValue(update_index)); + auto new_value = evaluateMhloRegion(update_computation(), {lhs, rhs}); + if (new_value.size() != 1 || !new_value[0]) return {}; + results[linear_base_index] = + new_value[0].cast().getValue({}); + } while (next_index(update_index, update_type.getShape())); + + return DenseElementsAttr::get(base_type, results); +} + } // namespace mhlo } // namespace mlir diff --git a/tests/canonicalize.mlir b/tests/canonicalize.mlir index 974585b..4effdc1 100644 --- a/tests/canonicalize.mlir +++ b/tests/canonicalize.mlir @@ -1167,3 +1167,265 @@ func @not_fold_sqrt_neg_constants() -> tensor<4xf32> { // CHECK: mhlo.sqrt return %1 : tensor<4xf32> } + +// CHECK-LABEL: @tensor_flow_scatter_v1_update +func @tensor_flow_scatter_v1_update() -> tensor<3x3xi32> { + %0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> + %1 = constant dense<[0, 2]> : tensor<2xi32> + %2 = constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + "mhlo.return"(%arg1) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 1 : i64, + inserted_window_dims = dense<0> : tensor<1xi64>, + scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, + update_window_dims = dense<[1]> : tensor<1xi64> + }, + unique_indices = false + } : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32> + return %3 : tensor<3x3xi32> + // CHECK: mhlo.constant dense<[ + // CHECK-SAME: [10, 20, 30], [4, 5, 6], [70, 80, 90] + // CHECK-SAME: ]> : tensor<3x3xi32> +} + +// CHECK-LABEL: @tensor_flow_scatter_v2_update +func @tensor_flow_scatter_v2_update() -> tensor<3x3xi32> { + %0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> + %1 = constant dense<[0, 2]> : tensor<2xi32> + %2 = constant dense<[[10, 30], [40, 60], [70, 90]]> : tensor<3x2xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + "mhlo.return"(%arg1) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 1 : i64, + inserted_window_dims = dense<1> : tensor<1xi64>, + scatter_dims_to_operand_dims = dense<1> : tensor<1xi64>, + update_window_dims = dense<[0]> : tensor<1xi64> + }, + unique_indices = false + } : (tensor<3x3xi32>, tensor<2xi32>, tensor<3x2xi32>) -> tensor<3x3xi32> + return %3 : tensor<3x3xi32> + // CHECK: mhlo.constant dense<[ + // CHECK-SAME: [10, 2, 30], [40, 5, 60], [70, 8, 90] + // CHECK-SAME: ]> : tensor<3x3xi32> +} + +// CHECK-LABEL: @tensor_flow_scatter_add +func @tensor_flow_scatter_add() -> tensor<3x3xi32> { + %0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> + %1 = constant dense<[0, 2]> : tensor<2xi32> + %2 = constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + %4 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> (tensor) + "mhlo.return"(%4) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 1 : i64, + inserted_window_dims = dense<0> : tensor<1xi64>, + scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, + update_window_dims = dense<[1]> : tensor<1xi64> + }, + unique_indices = false + } : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32> + return %3 : tensor<3x3xi32> + // CHECK: mhlo.constant dense<[ + // CHECK-SAME: [11, 22, 33], [4, 5, 6], [77, 88, 99] + // CHECK-SAME: ]> : tensor<3x3xi32> +} + +// CHECK-LABEL: @tensor_flow_scatter_repeated +func @tensor_flow_scatter_repeated() -> tensor<3x3xi32> { + %0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> + %1 = constant dense<[1, 1]> : tensor<2xi32> + %2 = constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + %4 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> (tensor) + "mhlo.return"(%4) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 1 : i64, + inserted_window_dims = dense<0> : tensor<1xi64>, + scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, + update_window_dims = dense<[1]> : tensor<1xi64> + }, + unique_indices = false + } : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32> + return %3 : tensor<3x3xi32> + // CHECK: mhlo.constant dense<[ + // CHECK-SAME: [1, 2, 3], [84, 105, 126], [7, 8, 9] + // CHECK-SAME: ]> : tensor<3x3xi32> +} + +// CHECK-LABEL: @tensor_flow_scatter_multiple_batch +func @tensor_flow_scatter_multiple_batch() -> tensor<3x3xi32> { + %0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> + %1 = constant dense<[[0, 2], [2, 1]]> : tensor<2x2xi32> + %2 = constant dense<[[[10, 30], [40, 60], [70, 90]], [[5, 5], [5, 5], [5, 5]]]> : tensor<2x3x2xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + %4 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> (tensor) + "mhlo.return"(%4) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 2 : i64, + inserted_window_dims = dense<1> : tensor<1xi64>, + scatter_dims_to_operand_dims = dense<1> : tensor<1xi64>, + update_window_dims = dense<[1]> : tensor<1xi64> + }, + unique_indices = false + } : (tensor<3x3xi32>, tensor<2x2xi32>, tensor<2x3x2xi32>) -> tensor<3x3xi32> + return %3 : tensor<3x3xi32> + // CHECK: mhlo.constant dense<[ + // CHECK-SAME: [11, 7, 38], [44, 10, 71], [77, 13, 104] + // CHECK-SAME: ]> : tensor<3x3xi32> +} + +// CHECK-LABEL: @tensor_flow_scatter_nd +func @tensor_flow_scatter_nd() -> tensor<3x3x2xi32> { + %0 = constant dense<[[[-1, 1], [-2, 2], [-3, 3]], [[-4, 4], [-5, 5], [-6, 6]], [[-7, 7], [-8, 8], [-9, 9]]]> : tensor<3x3x2xi32> + %1 = constant dense<[[0, 0], [1, 0]]> : tensor<2x2xi32> + %2 = constant dense<[[-10, 10], [-40, 40]]> : tensor<2x2xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + "mhlo.return"(%arg1) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 1 : i64, + inserted_window_dims = dense<[0, 1]> : tensor<2xi64>, + scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>, + update_window_dims = dense<[1]> : tensor<1xi64> + }, + unique_indices = false + } : (tensor<3x3x2xi32>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<3x3x2xi32> + return %3 : tensor<3x3x2xi32> + // CHECK: mhlo.constant dense<[ + // CHECK-SAME: [-10, 10], [-2, 2], [-3, 3] + // CHECK-SAME: [-40, 40], [-5, 5], [-6, 6] + // CHECK-SAME: [-7, 7], [-8, 8], [-9, 9] + // CHECK-SAME: ]> : tensor<3x3x2xi32> +} + +// CHECK-LABEL: @tensor_flow_scatter_nd_index_vector +func @tensor_flow_scatter_nd_index_vector() -> tensor<3x3x2xi32> { + %0 = constant dense<[[[-1, 1], [-2, 2], [-3, 3]], [[-4, 4], [-5, 5], [-6, 6]], [[-7, 7], [-8, 8], [-9, 9]]]> : tensor<3x3x2xi32> + %1 = constant dense<[[0, 0], [1, 0]]> : tensor<2x2xi32> + %2 = constant dense<[[-10, 10], [-20, 20]]> : tensor<2x2xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + "mhlo.return"(%arg1) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 0 : i64, + inserted_window_dims = dense<[0, 1]> : tensor<2xi64>, + scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>, + update_window_dims = dense<[1]> : tensor<1xi64> + }, + unique_indices = false + } : (tensor<3x3x2xi32>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<3x3x2xi32> + return %3 : tensor<3x3x2xi32> + // CHECK: mhlo.constant dense<[ + // CHECK-SAME: [-20, 20], [-10, 10], [-3, 3] + // CHECK-SAME: [-4, 4], [-5, 5], [-6, 6] + // CHECK-SAME: [-7, 7], [-8, 8], [-9, 9] + // CHECK-SAME: ]> : tensor<3x3x2xi32> +} + +// CHECK-LABEL: @scatter_batch_dus +func @scatter_batch_dus() -> tensor<3x3xi32> { + %0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> + %1 = constant dense<[[2, 1], [1, 1]]> : tensor<2x2xi32> + %2 = constant dense<[[[10]], [[20]]]> : tensor<2x1x1xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + "mhlo.return"(%arg1) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 0 : i64, + inserted_window_dims = dense<> : tensor<0xi64>, + scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>, + update_window_dims = dense<[1, 2]> : tensor<2xi64> + }, + unique_indices = false + } : (tensor<3x3xi32>, tensor<2x2xi32>, tensor<2x1x1xi32>) -> tensor<3x3xi32> + return %3 : tensor<3x3xi32> + // CHECK: mhlo.constant dense<[ + // CHECK-SAME: [1, 2, 3], [4, 20, 6], [7, 10, 9] + // CHECK-SAME: ]> : tensor<3x3xi32> +} + +// CHECK-LABEL: @scatter_no_update_window_dim +func @scatter_no_update_window_dim() -> tensor<3xi32> { + %0 = constant dense<[0, 1, 2]> : tensor<3xi32> + %1 = constant dense<[[[0], [1]], [[2], [1]]]> : tensor<2x2x1xi32> + %2 = constant dense<[[10, 20], [30, 40]]> : tensor<2x2xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + %4 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> (tensor) + "mhlo.return"(%4) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 2 : i64, + inserted_window_dims = dense<0> : tensor<1xi64>, + scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, + update_window_dims = dense<> : tensor<0xi64> + }, + unique_indices = false + } : (tensor<3xi32>, tensor<2x2x1xi32>, tensor<2x2xi32>) -> tensor<3xi32> + return %3 : tensor<3xi32> + // CHECK: mhlo.constant dense<[10, 61, 32]> : tensor<3xi32> +} + +// CHECK-LABEL: @scatter_negative_index +func @scatter_negative_index() -> tensor<3x3xi32> { + %0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> + %1 = constant dense<[0, -1]> : tensor<2xi32> + %2 = constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + "mhlo.return"(%arg1) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 1 : i64, + inserted_window_dims = dense<0> : tensor<1xi64>, + scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, + update_window_dims = dense<[1]> : tensor<1xi64> + }, + unique_indices = false + } : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32> + return %3 : tensor<3x3xi32> + // CHECK: constant dense<[ + // CHECK-SAME: [1, 2, 3], [4, 5, 6], [7, 8, 9] + // CHECK-SAME: ]> : tensor<3x3xi32> + // CHECK: "mhlo.scatter" +} + +// CHECK-LABEL: @scatter_out_of_bound +func @scatter_out_of_bound() -> tensor<3x3xi32> { + %0 = constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> + %1 = constant dense<[1, 5]> : tensor<2xi32> + %2 = constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ( { + ^bb0(%arg0: tensor, %arg1: tensor): + "mhlo.return"(%arg1) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = { + index_vector_dim = 1 : i64, + inserted_window_dims = dense<0> : tensor<1xi64>, + scatter_dims_to_operand_dims = dense<0> : tensor<1xi64>, + update_window_dims = dense<[1]> : tensor<1xi64> + }, + unique_indices = false + } : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32> + return %3 : tensor<3x3xi32> + // CHECK: constant dense<[ + // CHECK-SAME: [1, 2, 3], [4, 5, 6], [7, 8, 9] + // CHECK-SAME: ]> : tensor<3x3xi32> + // CHECK: "mhlo.scatter" +} +