Fix return type of replica-id to unsigned int 32 tensor.

PiperOrigin-RevId: 330012181
This commit is contained in:
Prakalp Srivastava 2020-09-03 16:07:55 -07:00 committed by TensorFlow MLIR Team
parent 12b221f459
commit b7248424ae
2 changed files with 9 additions and 3 deletions

View File

@ -491,9 +491,7 @@ def HLO_RecvOp : HLO_Op<"recv", []> {
def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect]>,
BASE_HLO_ReplicaIdOp {
// TODO(prakalps): The output should unsigned 32-bit integer but mlir does
// not differentiate between signed and unsigned int.
let results = (outs I32Tensor);
let results = (outs TensorOf<[UI32]>);
}
//===----------------------------------------------------------------------===//

View File

@ -600,6 +600,14 @@ func @recv_non_token_second_result(%token: !mhlo.token) -> tuple<tensor<3x4xi32>
// -----
// CHECK-LABEL: func @replica_id
func @replica_id() -> tensor<ui32> {
%0 = "mhlo.replica_id"() : () -> tensor<ui32>
return %0 : tensor<ui32>
}
// -----
func @rng_uniform_invalid_type(%mu: tensor<complex<f32>>, %sigma: tensor<f32>) -> tensor<2x3x5xf32> {
%shape = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64>
// expected-error@+1 {{but got 'tensor<complex<f32>>'}}