[MLIR][MHLO] Add polygamma op to the CHLO dialect
PiperOrigin-RevId: 357724465
This commit is contained in:
parent
02dab94054
commit
81abaf364d
|
@ -178,6 +178,15 @@ def HLOClient_BroadcastMulOp : HLOClient_BroadcastBinaryElementwiseOp<
|
|||
}];
|
||||
}
|
||||
|
||||
def HLOClient_BroadcastPolygammaOp : HLOClient_BroadcastBinaryElementwiseOp<
|
||||
"broadcast_polygamma", [NoSideEffect, SameOperandsAndResultElementType]> {
|
||||
let summary = "Polygamma function (with optional broadcasting)";
|
||||
|
||||
let description = [{
|
||||
Returns `Polygamma(operand, operand)` element-wise.
|
||||
}];
|
||||
}
|
||||
|
||||
def HLOClient_BroadcastPowOp : HLOClient_BroadcastBinaryElementwiseOp<
|
||||
"broadcast_power",
|
||||
[NoSideEffect, SameOperandsAndResultElementType]> {
|
||||
|
@ -339,10 +348,9 @@ def HLOClient_BroadcastXorOp : HLOClient_BroadcastBinaryLogicalElementwiseOp<
|
|||
// not part of the HLO compiler instructions as modelled by the MHLO dialect.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def HLOClient_ZetaOp : HLOClient_Op<"zeta",
|
||||
[NoSideEffect, SameOperandsAndResultType]> {
|
||||
def HLOClient_ZetaOp : HLOClient_Op<"zeta", [NoSideEffect,
|
||||
SameOperandsAndResultType]> {
|
||||
let summary = "Hurwitz zeta function";
|
||||
|
||||
let description = [{
|
||||
Returns `Zeta(operand, operand)` element-wise.
|
||||
|
||||
|
@ -351,15 +359,26 @@ def HLOClient_ZetaOp : HLOClient_Op<"zeta",
|
|||
$$
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
HLO_FpTensor:$x,
|
||||
HLO_FpTensor:$q
|
||||
);
|
||||
|
||||
let results = (outs HLO_FpTensor);
|
||||
let arguments = (ins HLO_FpTensor:$x, HLO_FpTensor:$q);
|
||||
let results = (outs HLO_FpTensor:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$x `,` $q attr-dict `:` `(` type($x) `,` type($q) `)` `->` type(results)
|
||||
$x `,` $q attr-dict `:` type($x) `,` type($q) `->` type(results)
|
||||
}];
|
||||
}
|
||||
|
||||
def HLOClient_PolygammaOp : HLOClient_Op<"polygamma", [NoSideEffect,
|
||||
SameOperandsAndResultType]> {
|
||||
let summary = "Polygamma function";
|
||||
let description = [{
|
||||
Returns `Polygamma(operand, operand)` element-wise.
|
||||
}];
|
||||
|
||||
let arguments = (ins HLO_FpTensor:$n, HLO_FpTensor:$x);
|
||||
let results = (outs HLO_FpTensor:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$n `,` $x attr-dict `:` type($n) `,` type($x) `->` type(results)
|
||||
}];
|
||||
}
|
||||
|
||||
|
|
|
@ -318,6 +318,7 @@ BROADCAST_BINARY_OP_DEFS(BroadcastMaxOp);
|
|||
BROADCAST_BINARY_OP_DEFS(BroadcastMinOp);
|
||||
BROADCAST_BINARY_OP_DEFS(BroadcastMulOp);
|
||||
BROADCAST_BINARY_OP_DEFS(BroadcastOrOp);
|
||||
BROADCAST_BINARY_OP_DEFS(BroadcastPolygammaOp);
|
||||
BROADCAST_BINARY_OP_DEFS(BroadcastPowOp);
|
||||
BROADCAST_BINARY_OP_DEFS(BroadcastRemOp);
|
||||
BROADCAST_BINARY_OP_DEFS(BroadcastShiftLeftOp);
|
||||
|
|
|
@ -1110,7 +1110,6 @@ func @digamma_f16(%arg : tensor<f16>) -> tensor<f16> {
|
|||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<f16>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: tensor<f16>) -> tensor<f16> {
|
||||
func @zeta_f16(%arg0: tensor<f16>, %arg1: tensor<f16>) -> tensor<f16> {
|
||||
%0 = chlo.zeta %arg0, %arg1 : (tensor<f16>, tensor<f16>) -> tensor<f16>
|
||||
// CHECK: %[[VAL_2:.*]] = "mhlo.convert"(%[[VAL_0]]) : (tensor<f16>) -> tensor<f32>
|
||||
// CHECK: %[[VAL_3:.*]] = "mhlo.convert"(%[[VAL_1]]) : (tensor<f16>) -> tensor<f32>
|
||||
// CHECK: %[[VAL_4:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
|
@ -1283,6 +1282,7 @@ func @zeta_f16(%arg0: tensor<f16>, %arg1: tensor<f16>) -> tensor<f16> {
|
|||
// CHECK: %[[VAL_172:.*]] = "mhlo.select"(%[[VAL_171]], %[[VAL_157]], %[[VAL_162]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: %[[VAL_173:.*]] = "mhlo.select"(%[[VAL_167]], %[[VAL_160]], %[[VAL_172]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: %[[VAL_174:.*]] = "mhlo.convert"(%[[VAL_173]]) : (tensor<f32>) -> tensor<f16>
|
||||
return %0 : tensor<f16>
|
||||
// CHECK: return %[[VAL_174]] : tensor<f16>
|
||||
%0 = chlo.zeta %arg0, %arg1 : tensor<f16>, tensor<f16> -> tensor<f16>
|
||||
return %0 : tensor<f16>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue