Fix return type of replica-id to unsigned int 32 tensor.
PiperOrigin-RevId: 330012181
This commit is contained in:
parent
12b221f459
commit
b7248424ae
|
@ -491,9 +491,7 @@ def HLO_RecvOp : HLO_Op<"recv", []> {
|
|||
|
||||
def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect]>,
|
||||
BASE_HLO_ReplicaIdOp {
|
||||
// TODO(prakalps): The output should unsigned 32-bit integer but mlir does
|
||||
// not differentiate between signed and unsigned int.
|
||||
let results = (outs I32Tensor);
|
||||
let results = (outs TensorOf<[UI32]>);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -600,6 +600,14 @@ func @recv_non_token_second_result(%token: !mhlo.token) -> tuple<tensor<3x4xi32>
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @replica_id
|
||||
func @replica_id() -> tensor<ui32> {
|
||||
%0 = "mhlo.replica_id"() : () -> tensor<ui32>
|
||||
return %0 : tensor<ui32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @rng_uniform_invalid_type(%mu: tensor<complex<f32>>, %sigma: tensor<f32>) -> tensor<2x3x5xf32> {
|
||||
%shape = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64>
|
||||
// expected-error@+1 {{but got 'tensor<complex<f32>>'}}
|
||||
|
|
Loading…
Reference in New Issue