Make mhlo.sort return variadic results instead of a tuple

Tuple is only used on XLA's sort to return multiple inputs. MLIR supports
multiple inputs, switch to a tuple return.

PiperOrigin-RevId: 334226937
This commit is contained in:
Robert Suderman 2020-09-28 13:31:28 -07:00 committed by TensorFlow MLIR Team
parent be2ffd2e21
commit 26ac5baae4
3 changed files with 23 additions and 26 deletions

View File

@ -1198,14 +1198,14 @@ def HLO_SetDimensionSizeOp: HLO_Op<"set_dimension_size", [NoSideEffect]>,
let results = (outs HLO_Tensor); let results = (outs HLO_Tensor);
} }
def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects]>, BASE_HLO_SortOp { def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects, SameOperandsAndResultShape]>, BASE_HLO_SortOp {
let arguments = (ins let arguments = (ins
Variadic<HLO_Tensor>:$operands, Variadic<HLO_Tensor>:$operands,
DefaultValuedAttr<I64Attr, "-1">:$dimension, DefaultValuedAttr<I64Attr, "-1">:$dimension,
DefaultValuedAttr<BoolAttr, "false">:$is_stable DefaultValuedAttr<BoolAttr, "false">:$is_stable
); );
let results = (outs HLO_TensorOrTuple); let results = (outs Variadic<HLO_Tensor>);
let regions = (region SizedRegion<1>:$comparator); let regions = (region SizedRegion<1>:$comparator);

View File

@ -2261,10 +2261,7 @@ void SortOp::build(OpBuilder& builder, OperationState& state,
state.addAttribute("dimension", builder.getI64IntegerAttr(dimension)); state.addAttribute("dimension", builder.getI64IntegerAttr(dimension));
state.addAttribute("is_stable", builder.getBoolAttr(dimension)); state.addAttribute("is_stable", builder.getBoolAttr(dimension));
SmallVector<Type, 2> element_types; for (Value operand : operands) state.addTypes(operand.getType());
element_types.reserve(operands.size());
for (Value operand : operands) element_types.push_back(operand.getType());
state.addTypes(builder.getTupleType(element_types));
state.addRegion(); state.addRegion();
} }

View File

@ -1010,34 +1010,34 @@ func @constant_invalid() -> () {
func @sort(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { func @sort(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
// CHECK: mhlo.sort // CHECK: mhlo.sort
%0 = "mhlo.sort"(%input0, %input1) ( { %0:2 = "mhlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>): ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1> %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> () "mhlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>> }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>)
return return
} }
// ----- // -----
func @sort_no_operands() { func @sort_no_operands() {
// expected-error @+1 {{op requires at least one input}} // expected-error @+1 {{expected named operation to have atleast 1 result}}
%0 = "mhlo.sort"() ( { %0:0 = "mhlo.sort"() ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<i32>, %arg4: tensor<i32>): ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<i32>, %arg4: tensor<i32>):
%7 = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1> %7 = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> () "mhlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : () -> tuple<> }) {dimension = 1 : i64, is_stable = true} : () -> ()
return return
} }
// ----- // -----
func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) { func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) {
%0 = "mhlo.sort"(%input0, %input1) ( { %0:2 = "mhlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>): ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1> %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> () "mhlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>> }) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>)
return return
} }
@ -1045,23 +1045,23 @@ func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) {
func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) { func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) {
// expected-error @+1 {{comparator block argument #0 should be of type 'tensor<f32>' but got 'tensor<i32>'}} // expected-error @+1 {{comparator block argument #0 should be of type 'tensor<f32>' but got 'tensor<i32>'}}
%0 = "mhlo.sort"(%input0, %input1) ( { %0:2 = "mhlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>): ^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1> %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> () "mhlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>> }) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>)
return return
} }
// ----- // -----
func @sort_different_dims(%input0: tensor<16x8xf32>, %input1: tensor<16x16xi32>) { func @sort_different_dims(%input0: tensor<16x8xf32>, %input1: tensor<16x16xi32>) {
// expected-error @+1 {{op requires all inputs to have the same dimensions}} // expected-error @+1 {{op requires the same shape for all operands and results}}
%0 = "mhlo.sort"(%input0, %input1) ( { %0:2 = "mhlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>): ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1> %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> () "mhlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x8xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>> }) {dimension = 1 : i64, is_stable = true} : (tensor<16x8xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>)
return return
} }
@ -1069,11 +1069,11 @@ func @sort_different_dims(%input0: tensor<16x8xf32>, %input1: tensor<16x16xi32>)
func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
// expected-error @+1 {{dimension attribute value must be in range [-2, 2), but found 10}} // expected-error @+1 {{dimension attribute value must be in range [-2, 2), but found 10}}
%0 = "mhlo.sort"(%input0, %input1) ( { %0:2 = "mhlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>): ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1> %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> () "mhlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 10 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>> }) {dimension = 10 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>)
return return
} }
@ -1081,11 +1081,11 @@ func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi3
func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
// expected-error @+1 {{dimension attribute value must be in range [-2, 2), but found -3}} // expected-error @+1 {{dimension attribute value must be in range [-2, 2), but found -3}}
%0 = "mhlo.sort"(%input0, %input1) ( { %0:2 = "mhlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>): ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1> %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> () "mhlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = -3 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>> }) {dimension = -3 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>)
return return
} }
@ -1093,11 +1093,11 @@ func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi3
func @sort_wrong_block_arg_count(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { func @sort_wrong_block_arg_count(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
// expected-error @+1 {{op comparator block should have 4 arguments}} // expected-error @+1 {{op comparator block should have 4 arguments}}
%0 = "mhlo.sort"(%input0, %input1) ( { %0:2 = "mhlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>): ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1> %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> () "mhlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>> }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>)
return return
} }
@ -1105,11 +1105,11 @@ func @sort_wrong_block_arg_count(%input0: tensor<16x16xf32>, %input1: tensor<16x
func @sort_wrong_block_arg_type(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { func @sort_wrong_block_arg_type(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
// expected-error @+1 {{op comparator block argument #3 should be of type 'tensor<i32>' but got 'tensor<f32>'}} // expected-error @+1 {{op comparator block argument #3 should be of type 'tensor<i32>' but got 'tensor<f32>'}}
%0 = "mhlo.sort"(%input0, %input1) ( { %0:2 = "mhlo.sort"(%input0, %input1) ( {
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<f32>): ^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<f32>):
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1> %7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> () "mhlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>> }) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>)
return return
} }