Add support for token operands to mhlo.tuple.

mhlo.get_tuple_element supports extracting a mhlo.token type from a tuple. This updates the creation of tuples to allow for mhlo.token typed operands.

PiperOrigin-RevId: 324628663
This commit is contained in:
Andy Ly 2020-08-03 10:17:46 -07:00 committed by TensorFlow MLIR Team
parent 3fe9a7d2db
commit 4c8fead3e0
2 changed files with 8 additions and 1 deletions

View File

@ -664,7 +664,7 @@ def HLO_GetTupleElementOp: HLO_Op<"get_tuple_element", [NoSideEffect]>, BASE_HLO
} }
def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp { def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp {
let arguments = (ins Variadic<HLO_TensorOrTuple>:$val); let arguments = (ins Variadic<HLO_TensorOrTokenOrTuple>:$val);
let results = (outs HLO_Tuple); let results = (outs HLO_Tuple);
let builders = [OpBuilder< let builders = [OpBuilder<

View File

@ -847,6 +847,13 @@ func @tuple(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) -> tuple<tensor<1xi32>
// ----- // -----
func @tuple_token(%arg0: tensor<f32>, %arg1: !mhlo.token) -> tuple<tensor<f32>, !mhlo.token> {
%0 = "mhlo.tuple"(%arg0, %arg1) : (tensor<f32>, !mhlo.token) -> tuple<tensor<f32>, !mhlo.token>
return %0 : tuple<tensor<f32>, !mhlo.token>
}
// -----
func @tuple_arg_size_mismatch(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tuple<tensor<f32>, tensor<f32>, tensor<f32>> { func @tuple_arg_size_mismatch(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tuple<tensor<f32>, tensor<f32>, tensor<f32>> {
// expected-error@+1 {{has return type tuple<tensor<f32>, tensor<f32>, tensor<f32>>, but expected tuple<tensor<f32>, tensor<f32>>}} // expected-error@+1 {{has return type tuple<tensor<f32>, tensor<f32>, tensor<f32>>, but expected tuple<tensor<f32>, tensor<f32>>}}
%0 = "mhlo.tuple"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>, tensor<f32>> %0 = "mhlo.tuple"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>, tensor<f32>>