Add HLO RngBitGenerator
This adds the XlaBuilder RngBitGenerator to the MHLO dialect. The op is currently represented very directly using int attribute for random algorithm and direct import/export. PiperOrigin-RevId: 325814134
This commit is contained in:
parent
843af36e05
commit
17ccca7f4b
|
@ -1329,8 +1329,9 @@ def HLO_TorchIndexSelectOp : HLO_Op<"torch_index_select", [NoSideEffect]> {
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// MHLO RngUniform Operator.
|
// MHLO RNG Operators.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def HLO_RngUniformOp : HLO_Op<"rng_uniform", []>, BASE_HLO_RngUniformOp {
|
def HLO_RngUniformOp : HLO_Op<"rng_uniform", []>, BASE_HLO_RngUniformOp {
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
HLO_PredIntOrFpTensor:$a,
|
HLO_PredIntOrFpTensor:$a,
|
||||||
|
@ -1355,6 +1356,19 @@ def HLO_RngNormalOp : HLO_Op<"rng_normal", []>, BASE_HLO_RngNormalOp {
|
||||||
let hasCustomHLOConverter = 1;
|
let hasCustomHLOConverter = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def HLO_RngBitGeneratorOp : HLO_Op<"rng_bit_generator", [NoSideEffect]>, BASE_HLO_RngBitGeneratorOp {
|
||||||
|
let arguments = (ins
|
||||||
|
// TODO(jpienaar): This could be an enum instead.
|
||||||
|
I32Attr:$rng_algorithm,
|
||||||
|
HLO_IntOrFpTensor:$initial_state
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs HLO_TensorOrTuple:$result);
|
||||||
|
|
||||||
|
// TODO(jpienaar): This should not be needed.
|
||||||
|
let hasCustomHLOConverter = 1;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// MHLO Quantize Operator.
|
// MHLO Quantize Operator.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -316,6 +316,19 @@ class BASE_HLO_RealOp {
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class BASE_HLO_RngBitGeneratorOp {
|
||||||
|
string summary = "Uniform random number generator operator";
|
||||||
|
|
||||||
|
string description = [{
|
||||||
|
Returns an output with a given shape filled with uniform random bits using
|
||||||
|
the specified algorithm (or backend default) and returns an updated state
|
||||||
|
(with the same shape as initial state) and the generated random data.
|
||||||
|
|
||||||
|
See
|
||||||
|
https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
class BASE_HLO_RoundOp {
|
class BASE_HLO_RoundOp {
|
||||||
string summary = "Round operator";
|
string summary = "Round operator";
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue