From 17ccca7f4b437b51336b84dd633105b3aac7366e Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Mon, 10 Aug 2020 08:59:35 -0700 Subject: [PATCH] Add HLO RngBitGenerator This adds the XlaBuilder RngBitGenerator to the MHLO dialect. The op is currently represented very directly using int attribute for random algorithm and direct import/export. PiperOrigin-RevId: 325814134 --- include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td | 16 +++++++++++++++- include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td | 13 +++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index b8b1926..d0abbe0 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -1329,8 +1329,9 @@ def HLO_TorchIndexSelectOp : HLO_Op<"torch_index_select", [NoSideEffect]> { } //===----------------------------------------------------------------------===// -// MHLO RngUniform Operator. +// MHLO RNG Operators. //===----------------------------------------------------------------------===// + def HLO_RngUniformOp : HLO_Op<"rng_uniform", []>, BASE_HLO_RngUniformOp { let arguments = (ins HLO_PredIntOrFpTensor:$a, @@ -1355,6 +1356,19 @@ def HLO_RngNormalOp : HLO_Op<"rng_normal", []>, BASE_HLO_RngNormalOp { let hasCustomHLOConverter = 1; } +def HLO_RngBitGeneratorOp : HLO_Op<"rng_bit_generator", [NoSideEffect]>, BASE_HLO_RngBitGeneratorOp { + let arguments = (ins + // TODO(jpienaar): This could be an enum instead. + I32Attr:$rng_algorithm, + HLO_IntOrFpTensor:$initial_state + ); + + let results = (outs HLO_TensorOrTuple:$result); + + // TODO(jpienaar): This should not be needed. + let hasCustomHLOConverter = 1; +} + //===----------------------------------------------------------------------===// // MHLO Quantize Operator. //===----------------------------------------------------------------------===// diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td index 7f9784d..2f80545 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td @@ -316,6 +316,19 @@ class BASE_HLO_RealOp { }]; } +class BASE_HLO_RngBitGeneratorOp { + string summary = "Uniform random number generator operator"; + + string description = [{ + Returns an output with a given shape filled with uniform random bits using + the specified algorithm (or backend default) and returns an updated state + (with the same shape as initial state) and the generated random data. + + See + https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator. + }]; +} + class BASE_HLO_RoundOp { string summary = "Round operator";