From 26ac5baae46c2a86e5edf6940714dc65108230c0 Mon Sep 17 00:00:00 2001 From: Robert Suderman Date: Mon, 28 Sep 2020 13:31:28 -0700 Subject: [PATCH] Make mhlo.sort return variadic results instead of a tuple Tuple is only used on XLA's sort to return multiple inputs. MLIR supports multiple inputs, switch to a tuple return. PiperOrigin-RevId: 334226937 --- include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td | 4 +-- lib/Dialect/mhlo/IR/hlo_ops.cc | 5 +-- tests/ops.mlir | 40 ++++++++++----------- 3 files changed, 23 insertions(+), 26 deletions(-) diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index d545c2a..86d6e34 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -1198,14 +1198,14 @@ def HLO_SetDimensionSizeOp: HLO_Op<"set_dimension_size", [NoSideEffect]>, let results = (outs HLO_Tensor); } -def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects]>, BASE_HLO_SortOp { +def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects, SameOperandsAndResultShape]>, BASE_HLO_SortOp { let arguments = (ins Variadic:$operands, DefaultValuedAttr:$dimension, DefaultValuedAttr:$is_stable ); - let results = (outs HLO_TensorOrTuple); + let results = (outs Variadic); let regions = (region SizedRegion<1>:$comparator); diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc index 1ebec66..31de088 100644 --- a/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -2261,10 +2261,7 @@ void SortOp::build(OpBuilder& builder, OperationState& state, state.addAttribute("dimension", builder.getI64IntegerAttr(dimension)); state.addAttribute("is_stable", builder.getBoolAttr(dimension)); - SmallVector element_types; - element_types.reserve(operands.size()); - for (Value operand : operands) element_types.push_back(operand.getType()); - state.addTypes(builder.getTupleType(element_types)); + for (Value operand : operands) state.addTypes(operand.getType()); state.addRegion(); } diff --git a/tests/ops.mlir b/tests/ops.mlir index aff2f7f..8cf8dba 100644 --- a/tests/ops.mlir +++ b/tests/ops.mlir @@ -1010,34 +1010,34 @@ func @constant_invalid() -> () { func @sort(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { // CHECK: mhlo.sort - %0 = "mhlo.sort"(%input0, %input1) ( { + %0:2 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () - }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) return } // ----- func @sort_no_operands() { - // expected-error @+1 {{op requires at least one input}} - %0 = "mhlo.sort"() ( { + // expected-error @+1 {{expected named operation to have atleast 1 result}} + %0:0 = "mhlo.sort"() ( { ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor): %7 = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () - }) {dimension = 1 : i64, is_stable = true} : () -> tuple<> + }) {dimension = 1 : i64, is_stable = true} : () -> () return } // ----- func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) { - %0 = "mhlo.sort"(%input0, %input1) ( { + %0:2 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () - }) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + }) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) return } @@ -1045,23 +1045,23 @@ func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) { func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) { // expected-error @+1 {{comparator block argument #0 should be of type 'tensor' but got 'tensor'}} - %0 = "mhlo.sort"(%input0, %input1) ( { + %0:2 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () - }) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + }) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) return } // ----- func @sort_different_dims(%input0: tensor<16x8xf32>, %input1: tensor<16x16xi32>) { - // expected-error @+1 {{op requires all inputs to have the same dimensions}} - %0 = "mhlo.sort"(%input0, %input1) ( { + // expected-error @+1 {{op requires the same shape for all operands and results}} + %0:2 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () - }) {dimension = 1 : i64, is_stable = true} : (tensor<16x8xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + }) {dimension = 1 : i64, is_stable = true} : (tensor<16x8xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) return } @@ -1069,11 +1069,11 @@ func @sort_different_dims(%input0: tensor<16x8xf32>, %input1: tensor<16x16xi32>) func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { // expected-error @+1 {{dimension attribute value must be in range [-2, 2), but found 10}} - %0 = "mhlo.sort"(%input0, %input1) ( { + %0:2 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () - }) {dimension = 10 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + }) {dimension = 10 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) return } @@ -1081,11 +1081,11 @@ func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi3 func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { // expected-error @+1 {{dimension attribute value must be in range [-2, 2), but found -3}} - %0 = "mhlo.sort"(%input0, %input1) ( { + %0:2 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () - }) {dimension = -3 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + }) {dimension = -3 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) return } @@ -1093,11 +1093,11 @@ func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi3 func @sort_wrong_block_arg_count(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { // expected-error @+1 {{op comparator block should have 4 arguments}} - %0 = "mhlo.sort"(%input0, %input1) ( { + %0:2 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor): %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () - }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) return } @@ -1105,11 +1105,11 @@ func @sort_wrong_block_arg_count(%input0: tensor<16x16xf32>, %input1: tensor<16x func @sort_wrong_block_arg_type(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { // expected-error @+1 {{op comparator block argument #3 should be of type 'tensor' but got 'tensor'}} - %0 = "mhlo.sort"(%input0, %input1) ( { + %0:2 = "mhlo.sort"(%input0, %input1) ( { ^bb0(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor): %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () - }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple, tensor<16x16xi32>> + }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>) return }