[MLIR][MHLO] Add polygamma op to the CHLO dialect

PiperOrigin-RevId: 357724465
This commit is contained in:
A. Unique TensorFlower 2021-02-16 08:31:19 -08:00 committed by TensorFlow MLIR Team
parent 02dab94054
commit 81abaf364d
3 changed files with 204 additions and 184 deletions

View File

@ -178,6 +178,15 @@ def HLOClient_BroadcastMulOp : HLOClient_BroadcastBinaryElementwiseOp<
}];
}
def HLOClient_BroadcastPolygammaOp : HLOClient_BroadcastBinaryElementwiseOp<
"broadcast_polygamma", [NoSideEffect, SameOperandsAndResultElementType]> {
let summary = "Polygamma function (with optional broadcasting)";
let description = [{
Returns `Polygamma(operand, operand)` element-wise.
}];
}
def HLOClient_BroadcastPowOp : HLOClient_BroadcastBinaryElementwiseOp<
"broadcast_power",
[NoSideEffect, SameOperandsAndResultElementType]> {
@ -339,10 +348,9 @@ def HLOClient_BroadcastXorOp : HLOClient_BroadcastBinaryLogicalElementwiseOp<
// not part of the HLO compiler instructions as modelled by the MHLO dialect.
//===----------------------------------------------------------------------===//
def HLOClient_ZetaOp : HLOClient_Op<"zeta",
[NoSideEffect, SameOperandsAndResultType]> {
def HLOClient_ZetaOp : HLOClient_Op<"zeta", [NoSideEffect,
SameOperandsAndResultType]> {
let summary = "Hurwitz zeta function";
let description = [{
Returns `Zeta(operand, operand)` element-wise.
@ -351,15 +359,26 @@ def HLOClient_ZetaOp : HLOClient_Op<"zeta",
$$
}];
let arguments = (ins
HLO_FpTensor:$x,
HLO_FpTensor:$q
);
let results = (outs HLO_FpTensor);
let arguments = (ins HLO_FpTensor:$x, HLO_FpTensor:$q);
let results = (outs HLO_FpTensor:$result);
let assemblyFormat = [{
$x `,` $q attr-dict `:` `(` type($x) `,` type($q) `)` `->` type(results)
$x `,` $q attr-dict `:` type($x) `,` type($q) `->` type(results)
}];
}
def HLOClient_PolygammaOp : HLOClient_Op<"polygamma", [NoSideEffect,
SameOperandsAndResultType]> {
let summary = "Polygamma function";
let description = [{
Returns `Polygamma(operand, operand)` element-wise.
}];
let arguments = (ins HLO_FpTensor:$n, HLO_FpTensor:$x);
let results = (outs HLO_FpTensor:$result);
let assemblyFormat = [{
$n `,` $x attr-dict `:` type($n) `,` type($x) `->` type(results)
}];
}

View File

@ -318,6 +318,7 @@ BROADCAST_BINARY_OP_DEFS(BroadcastMaxOp);
BROADCAST_BINARY_OP_DEFS(BroadcastMinOp);
BROADCAST_BINARY_OP_DEFS(BroadcastMulOp);
BROADCAST_BINARY_OP_DEFS(BroadcastOrOp);
BROADCAST_BINARY_OP_DEFS(BroadcastPolygammaOp);
BROADCAST_BINARY_OP_DEFS(BroadcastPowOp);
BROADCAST_BINARY_OP_DEFS(BroadcastRemOp);
BROADCAST_BINARY_OP_DEFS(BroadcastShiftLeftOp);

View File

@ -1110,7 +1110,6 @@ func @digamma_f16(%arg : tensor<f16>) -> tensor<f16> {
// CHECK-SAME: %[[VAL_0:.*]]: tensor<f16>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<f16>) -> tensor<f16> {
func @zeta_f16(%arg0: tensor<f16>, %arg1: tensor<f16>) -> tensor<f16> {
%0 = chlo.zeta %arg0, %arg1 : (tensor<f16>, tensor<f16>) -> tensor<f16>
// CHECK: %[[VAL_2:.*]] = "mhlo.convert"(%[[VAL_0]]) : (tensor<f16>) -> tensor<f32>
// CHECK: %[[VAL_3:.*]] = "mhlo.convert"(%[[VAL_1]]) : (tensor<f16>) -> tensor<f32>
// CHECK: %[[VAL_4:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
@ -1283,6 +1282,7 @@ func @zeta_f16(%arg0: tensor<f16>, %arg1: tensor<f16>) -> tensor<f16> {
// CHECK: %[[VAL_172:.*]] = "mhlo.select"(%[[VAL_171]], %[[VAL_157]], %[[VAL_162]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: %[[VAL_173:.*]] = "mhlo.select"(%[[VAL_167]], %[[VAL_160]], %[[VAL_172]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: %[[VAL_174:.*]] = "mhlo.convert"(%[[VAL_173]]) : (tensor<f32>) -> tensor<f16>
return %0 : tensor<f16>
// CHECK: return %[[VAL_174]] : tensor<f16>
%0 = chlo.zeta %arg0, %arg1 : tensor<f16>, tensor<f16> -> tensor<f16>
return %0 : tensor<f16>
}