[MLIR] Add cbrt, reduce-precision, and bitcast ops to MHLO.

PiperOrigin-RevId: 335109804
This commit is contained in:
Tim Shen 2020-10-02 15:12:40 -07:00 committed by TensorFlow MLIR Team
parent bcf6fbf612
commit c708bfd6d0
3 changed files with 65 additions and 0 deletions

View File

@ -157,6 +157,9 @@ def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs",
>];
}
def HLO_CbrtOp: HLO_UnaryElementwiseOp<"cbrt",
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_CbrtOp;
def HLO_CeilOp: HLO_UnaryElementwiseOp<"ceil",
[NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_CeilOp;
@ -1423,4 +1426,21 @@ def HLO_FusionOp : HLO_Op<"fusion", []> {
let hasCustomHLOConverter = 1;
}
// This is an op for purposes internal to XLA/GPU.
def HLO_BitcastOp : HLO_Op<"bitcast", [NoSideEffect]>, BASE_HLO_BitcastOp {
let arguments = (ins HLO_Tensor:$operand);
let results = (outs HLO_Tensor);
let hasCustomHLOConverter = 1;
}
def HLO_ReducePrecisionOp: HLO_Op<"reduce_precision", [SameOperandsAndResultShape]>,
BASE_HLO_ReducePrecisionOp {
let arguments = (ins
HLO_FpTensor:$operand,
I32Attr:$exponent_bits,
I32Attr:$mantissa_bits
);
let results = (outs HLO_FpTensor:$output);
}
#endif // HLO_OPS

View File

@ -127,6 +127,17 @@ class BASE_HLO_AbsOp {
}];
}
class BASE_HLO_CbrtOp {
string summary = "Cubic root operator";
string description = [{
Returns element-wise cubic root of the operand.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
}];
}
class BASE_HLO_CeilOp {
string summary = "Ceil operator";
@ -1336,4 +1347,17 @@ class BASE_HLO_WhileOp {
}];
}
class BASE_HLO_BitcastOp {
string summary = "Bitcast operator";
string description = [{
This op changes the shape of the input in the way that the physical
arranggment of elements are unchanged.
However, the op needs layout information to make sense of "physical
arrangement of elements". Layout support in MHLO is currently under
exploration.
}];
}
#endif // HLO_OPS_BASE

View File

@ -1193,3 +1193,24 @@ func @incompatible_shapes(%arg0: tensor<?xf32>, %shape: tensor<2xindex>) -> tens
%0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor<?xf32>, tensor<2xindex>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
func @cbrt(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> {
%0 = "mhlo.cbrt"(%arg) : (tensor<2x4xf32>) -> tensor<2x4xf32>
return %0 : tensor<2x4xf32>
}
// -----
func @bitcast(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> {
%0 = "mhlo.bitcast"(%arg) : (tensor<2x4xf32>) -> tensor<2x4xf32>
return %0 : tensor<2x4xf32>
}
// -----
func @bitcast(%arg: tensor<2x4xf32>) -> tensor<2x4xf32> {
%0 = "mhlo.reduce_precision"(%arg) {exponent_bits=2 : i32, mantissa_bits=3 : i32} : (tensor<2x4xf32>) -> tensor<2x4xf32>
return %0 : tensor<2x4xf32>
}