Add support for token operands to mhlo.tuple.
mhlo.get_tuple_element supports extracting a mhlo.token type from a tuple. This updates the creation of tuples to allow for mhlo.token typed operands. PiperOrigin-RevId: 324628663
This commit is contained in:
parent
3fe9a7d2db
commit
4c8fead3e0
|
@ -664,7 +664,7 @@ def HLO_GetTupleElementOp: HLO_Op<"get_tuple_element", [NoSideEffect]>, BASE_HLO
|
||||||
}
|
}
|
||||||
|
|
||||||
def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp {
|
def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp {
|
||||||
let arguments = (ins Variadic<HLO_TensorOrTuple>:$val);
|
let arguments = (ins Variadic<HLO_TensorOrTokenOrTuple>:$val);
|
||||||
let results = (outs HLO_Tuple);
|
let results = (outs HLO_Tuple);
|
||||||
|
|
||||||
let builders = [OpBuilder<
|
let builders = [OpBuilder<
|
||||||
|
|
|
@ -847,6 +847,13 @@ func @tuple(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) -> tuple<tensor<1xi32>
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
func @tuple_token(%arg0: tensor<f32>, %arg1: !mhlo.token) -> tuple<tensor<f32>, !mhlo.token> {
|
||||||
|
%0 = "mhlo.tuple"(%arg0, %arg1) : (tensor<f32>, !mhlo.token) -> tuple<tensor<f32>, !mhlo.token>
|
||||||
|
return %0 : tuple<tensor<f32>, !mhlo.token>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
func @tuple_arg_size_mismatch(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tuple<tensor<f32>, tensor<f32>, tensor<f32>> {
|
func @tuple_arg_size_mismatch(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tuple<tensor<f32>, tensor<f32>, tensor<f32>> {
|
||||||
// expected-error@+1 {{has return type tuple<tensor<f32>, tensor<f32>, tensor<f32>>, but expected tuple<tensor<f32>, tensor<f32>>}}
|
// expected-error@+1 {{has return type tuple<tensor<f32>, tensor<f32>, tensor<f32>>, but expected tuple<tensor<f32>, tensor<f32>>}}
|
||||||
%0 = "mhlo.tuple"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>, tensor<f32>>
|
%0 = "mhlo.tuple"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>, tensor<f32>>
|
||||||
|
|
Loading…
Reference in New Issue