diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index db98bd1..e83bf87 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -664,7 +664,7 @@ def HLO_GetTupleElementOp: HLO_Op<"get_tuple_element", [NoSideEffect]>, BASE_HLO } def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp { - let arguments = (ins Variadic:$val); + let arguments = (ins Variadic:$val); let results = (outs HLO_Tuple); let builders = [OpBuilder< diff --git a/tests/ops.mlir b/tests/ops.mlir index 920e62e..212e794 100644 --- a/tests/ops.mlir +++ b/tests/ops.mlir @@ -847,6 +847,13 @@ func @tuple(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) -> tuple // ----- +func @tuple_token(%arg0: tensor, %arg1: !mhlo.token) -> tuple, !mhlo.token> { + %0 = "mhlo.tuple"(%arg0, %arg1) : (tensor, !mhlo.token) -> tuple, !mhlo.token> + return %0 : tuple, !mhlo.token> +} + +// ----- + func @tuple_arg_size_mismatch(%arg0: tensor, %arg1: tensor) -> tuple, tensor, tensor> { // expected-error@+1 {{has return type tuple, tensor, tensor>, but expected tuple, tensor>}} %0 = "mhlo.tuple"(%arg0, %arg1) : (tensor, tensor) -> tuple, tensor, tensor>