Add HLO RngBitGenerator

This adds the XlaBuilder RngBitGenerator to the MHLO dialect. The op is currently represented very directly using int attribute for random algorithm and direct import/export.

PiperOrigin-RevId: 325814134
This commit is contained in:
Jacques Pienaar 2020-08-10 08:59:35 -07:00 committed by TensorFlow MLIR Team
parent 843af36e05
commit 17ccca7f4b
2 changed files with 28 additions and 1 deletions

View File

@ -1329,8 +1329,9 @@ def HLO_TorchIndexSelectOp : HLO_Op<"torch_index_select", [NoSideEffect]> {
}
//===----------------------------------------------------------------------===//
// MHLO RngUniform Operator.
// MHLO RNG Operators.
//===----------------------------------------------------------------------===//
def HLO_RngUniformOp : HLO_Op<"rng_uniform", []>, BASE_HLO_RngUniformOp {
let arguments = (ins
HLO_PredIntOrFpTensor:$a,
@ -1355,6 +1356,19 @@ def HLO_RngNormalOp : HLO_Op<"rng_normal", []>, BASE_HLO_RngNormalOp {
let hasCustomHLOConverter = 1;
}
def HLO_RngBitGeneratorOp : HLO_Op<"rng_bit_generator", [NoSideEffect]>, BASE_HLO_RngBitGeneratorOp {
let arguments = (ins
// TODO(jpienaar): This could be an enum instead.
I32Attr:$rng_algorithm,
HLO_IntOrFpTensor:$initial_state
);
let results = (outs HLO_TensorOrTuple:$result);
// TODO(jpienaar): This should not be needed.
let hasCustomHLOConverter = 1;
}
//===----------------------------------------------------------------------===//
// MHLO Quantize Operator.
//===----------------------------------------------------------------------===//

View File

@ -316,6 +316,19 @@ class BASE_HLO_RealOp {
}];
}
class BASE_HLO_RngBitGeneratorOp {
string summary = "Uniform random number generator operator";
string description = [{
Returns an output with a given shape filled with uniform random bits using
the specified algorithm (or backend default) and returns an updated state
(with the same shape as initial state) and the generated random data.
See
https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator.
}];
}
class BASE_HLO_RoundOp {
string summary = "Round operator";