Make mhlo.sort return variadic results instead of a tuple
Tuple is only used on XLA's sort to return multiple inputs. MLIR supports multiple inputs, switch to a tuple return. PiperOrigin-RevId: 334226937
This commit is contained in:
parent
be2ffd2e21
commit
26ac5baae4
|
@ -1198,14 +1198,14 @@ def HLO_SetDimensionSizeOp: HLO_Op<"set_dimension_size", [NoSideEffect]>,
|
||||||
let results = (outs HLO_Tensor);
|
let results = (outs HLO_Tensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects]>, BASE_HLO_SortOp {
|
def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects, SameOperandsAndResultShape]>, BASE_HLO_SortOp {
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
Variadic<HLO_Tensor>:$operands,
|
Variadic<HLO_Tensor>:$operands,
|
||||||
DefaultValuedAttr<I64Attr, "-1">:$dimension,
|
DefaultValuedAttr<I64Attr, "-1">:$dimension,
|
||||||
DefaultValuedAttr<BoolAttr, "false">:$is_stable
|
DefaultValuedAttr<BoolAttr, "false">:$is_stable
|
||||||
);
|
);
|
||||||
|
|
||||||
let results = (outs HLO_TensorOrTuple);
|
let results = (outs Variadic<HLO_Tensor>);
|
||||||
|
|
||||||
let regions = (region SizedRegion<1>:$comparator);
|
let regions = (region SizedRegion<1>:$comparator);
|
||||||
|
|
||||||
|
|
|
@ -2261,10 +2261,7 @@ void SortOp::build(OpBuilder& builder, OperationState& state,
|
||||||
state.addAttribute("dimension", builder.getI64IntegerAttr(dimension));
|
state.addAttribute("dimension", builder.getI64IntegerAttr(dimension));
|
||||||
state.addAttribute("is_stable", builder.getBoolAttr(dimension));
|
state.addAttribute("is_stable", builder.getBoolAttr(dimension));
|
||||||
|
|
||||||
SmallVector<Type, 2> element_types;
|
for (Value operand : operands) state.addTypes(operand.getType());
|
||||||
element_types.reserve(operands.size());
|
|
||||||
for (Value operand : operands) element_types.push_back(operand.getType());
|
|
||||||
state.addTypes(builder.getTupleType(element_types));
|
|
||||||
|
|
||||||
state.addRegion();
|
state.addRegion();
|
||||||
}
|
}
|
||||||
|
|
|
@ -1010,34 +1010,34 @@ func @constant_invalid() -> () {
|
||||||
|
|
||||||
func @sort(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
|
func @sort(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
|
||||||
// CHECK: mhlo.sort
|
// CHECK: mhlo.sort
|
||||||
%0 = "mhlo.sort"(%input0, %input1) ( {
|
%0:2 = "mhlo.sort"(%input0, %input1) ( {
|
||||||
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
|
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
|
||||||
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
"mhlo.return"(%7) : (tensor<i1>) -> ()
|
"mhlo.return"(%7) : (tensor<i1>) -> ()
|
||||||
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
|
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @sort_no_operands() {
|
func @sort_no_operands() {
|
||||||
// expected-error @+1 {{op requires at least one input}}
|
// expected-error @+1 {{expected named operation to have atleast 1 result}}
|
||||||
%0 = "mhlo.sort"() ( {
|
%0:0 = "mhlo.sort"() ( {
|
||||||
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<i32>, %arg4: tensor<i32>):
|
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<i32>, %arg4: tensor<i32>):
|
||||||
%7 = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
%7 = "mhlo.compare"(%arg1, %arg2) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
"mhlo.return"(%7) : (tensor<i1>) -> ()
|
"mhlo.return"(%7) : (tensor<i1>) -> ()
|
||||||
}) {dimension = 1 : i64, is_stable = true} : () -> tuple<>
|
}) {dimension = 1 : i64, is_stable = true} : () -> ()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) {
|
func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) {
|
||||||
%0 = "mhlo.sort"(%input0, %input1) ( {
|
%0:2 = "mhlo.sort"(%input0, %input1) ( {
|
||||||
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
|
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
|
||||||
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
"mhlo.return"(%7) : (tensor<i1>) -> ()
|
"mhlo.return"(%7) : (tensor<i1>) -> ()
|
||||||
}) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
|
}) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1045,23 +1045,23 @@ func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) {
|
||||||
|
|
||||||
func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) {
|
func @sort_unknown_rank(%input0: tensor<*xf32>, %input1: tensor<16x16xi32>) {
|
||||||
// expected-error @+1 {{comparator block argument #0 should be of type 'tensor<f32>' but got 'tensor<i32>'}}
|
// expected-error @+1 {{comparator block argument #0 should be of type 'tensor<f32>' but got 'tensor<i32>'}}
|
||||||
%0 = "mhlo.sort"(%input0, %input1) ( {
|
%0:2 = "mhlo.sort"(%input0, %input1) ( {
|
||||||
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
|
^bb0(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
|
||||||
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||||
"mhlo.return"(%7) : (tensor<i1>) -> ()
|
"mhlo.return"(%7) : (tensor<i1>) -> ()
|
||||||
}) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
|
}) {dimension = 1 : i64, is_stable = true} : (tensor<*xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
func @sort_different_dims(%input0: tensor<16x8xf32>, %input1: tensor<16x16xi32>) {
|
func @sort_different_dims(%input0: tensor<16x8xf32>, %input1: tensor<16x16xi32>) {
|
||||||
// expected-error @+1 {{op requires all inputs to have the same dimensions}}
|
// expected-error @+1 {{op requires the same shape for all operands and results}}
|
||||||
%0 = "mhlo.sort"(%input0, %input1) ( {
|
%0:2 = "mhlo.sort"(%input0, %input1) ( {
|
||||||
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
|
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
|
||||||
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
"mhlo.return"(%7) : (tensor<i1>) -> ()
|
"mhlo.return"(%7) : (tensor<i1>) -> ()
|
||||||
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x8xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
|
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x8xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1069,11 +1069,11 @@ func @sort_different_dims(%input0: tensor<16x8xf32>, %input1: tensor<16x16xi32>)
|
||||||
|
|
||||||
func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
|
func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
|
||||||
// expected-error @+1 {{dimension attribute value must be in range [-2, 2), but found 10}}
|
// expected-error @+1 {{dimension attribute value must be in range [-2, 2), but found 10}}
|
||||||
%0 = "mhlo.sort"(%input0, %input1) ( {
|
%0:2 = "mhlo.sort"(%input0, %input1) ( {
|
||||||
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
|
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
|
||||||
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
"mhlo.return"(%7) : (tensor<i1>) -> ()
|
"mhlo.return"(%7) : (tensor<i1>) -> ()
|
||||||
}) {dimension = 10 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
|
}) {dimension = 10 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1081,11 +1081,11 @@ func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi3
|
||||||
|
|
||||||
func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
|
func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
|
||||||
// expected-error @+1 {{dimension attribute value must be in range [-2, 2), but found -3}}
|
// expected-error @+1 {{dimension attribute value must be in range [-2, 2), but found -3}}
|
||||||
%0 = "mhlo.sort"(%input0, %input1) ( {
|
%0:2 = "mhlo.sort"(%input0, %input1) ( {
|
||||||
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
|
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<i32>):
|
||||||
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
"mhlo.return"(%7) : (tensor<i1>) -> ()
|
"mhlo.return"(%7) : (tensor<i1>) -> ()
|
||||||
}) {dimension = -3 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
|
}) {dimension = -3 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1093,11 +1093,11 @@ func @sort_dim_out_of_range(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi3
|
||||||
|
|
||||||
func @sort_wrong_block_arg_count(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
|
func @sort_wrong_block_arg_count(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
|
||||||
// expected-error @+1 {{op comparator block should have 4 arguments}}
|
// expected-error @+1 {{op comparator block should have 4 arguments}}
|
||||||
%0 = "mhlo.sort"(%input0, %input1) ( {
|
%0:2 = "mhlo.sort"(%input0, %input1) ( {
|
||||||
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
|
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>):
|
||||||
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
"mhlo.return"(%7) : (tensor<i1>) -> ()
|
"mhlo.return"(%7) : (tensor<i1>) -> ()
|
||||||
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
|
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1105,11 +1105,11 @@ func @sort_wrong_block_arg_count(%input0: tensor<16x16xf32>, %input1: tensor<16x
|
||||||
|
|
||||||
func @sort_wrong_block_arg_type(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
|
func @sort_wrong_block_arg_type(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) {
|
||||||
// expected-error @+1 {{op comparator block argument #3 should be of type 'tensor<i32>' but got 'tensor<f32>'}}
|
// expected-error @+1 {{op comparator block argument #3 should be of type 'tensor<i32>' but got 'tensor<f32>'}}
|
||||||
%0 = "mhlo.sort"(%input0, %input1) ( {
|
%0:2 = "mhlo.sort"(%input0, %input1) ( {
|
||||||
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<f32>):
|
^bb0(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i32>, %arg3: tensor<f32>):
|
||||||
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
%7 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||||
"mhlo.return"(%7) : (tensor<i1>) -> ()
|
"mhlo.return"(%7) : (tensor<i1>) -> ()
|
||||||
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> tuple<tensor<16x16xf32>, tensor<16x16xi32>>
|
}) {dimension = 1 : i64, is_stable = true} : (tensor<16x16xf32>, tensor<16x16xi32>) -> (tensor<16x16xf32>, tensor<16x16xi32>)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue