[MLIR][MHLO] Add polygamma op to the CHLO dialect
PiperOrigin-RevId: 357724465
This commit is contained in:
parent
02dab94054
commit
81abaf364d
|
@ -178,6 +178,15 @@ def HLOClient_BroadcastMulOp : HLOClient_BroadcastBinaryElementwiseOp<
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def HLOClient_BroadcastPolygammaOp : HLOClient_BroadcastBinaryElementwiseOp<
|
||||||
|
"broadcast_polygamma", [NoSideEffect, SameOperandsAndResultElementType]> {
|
||||||
|
let summary = "Polygamma function (with optional broadcasting)";
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
Returns `Polygamma(operand, operand)` element-wise.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def HLOClient_BroadcastPowOp : HLOClient_BroadcastBinaryElementwiseOp<
|
def HLOClient_BroadcastPowOp : HLOClient_BroadcastBinaryElementwiseOp<
|
||||||
"broadcast_power",
|
"broadcast_power",
|
||||||
[NoSideEffect, SameOperandsAndResultElementType]> {
|
[NoSideEffect, SameOperandsAndResultElementType]> {
|
||||||
|
@ -339,10 +348,9 @@ def HLOClient_BroadcastXorOp : HLOClient_BroadcastBinaryLogicalElementwiseOp<
|
||||||
// not part of the HLO compiler instructions as modelled by the MHLO dialect.
|
// not part of the HLO compiler instructions as modelled by the MHLO dialect.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
def HLOClient_ZetaOp : HLOClient_Op<"zeta",
|
def HLOClient_ZetaOp : HLOClient_Op<"zeta", [NoSideEffect,
|
||||||
[NoSideEffect, SameOperandsAndResultType]> {
|
SameOperandsAndResultType]> {
|
||||||
let summary = "Hurwitz zeta function";
|
let summary = "Hurwitz zeta function";
|
||||||
|
|
||||||
let description = [{
|
let description = [{
|
||||||
Returns `Zeta(operand, operand)` element-wise.
|
Returns `Zeta(operand, operand)` element-wise.
|
||||||
|
|
||||||
|
@ -351,15 +359,26 @@ def HLOClient_ZetaOp : HLOClient_Op<"zeta",
|
||||||
$$
|
$$
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins HLO_FpTensor:$x, HLO_FpTensor:$q);
|
||||||
HLO_FpTensor:$x,
|
let results = (outs HLO_FpTensor:$result);
|
||||||
HLO_FpTensor:$q
|
|
||||||
);
|
|
||||||
|
|
||||||
let results = (outs HLO_FpTensor);
|
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
$x `,` $q attr-dict `:` `(` type($x) `,` type($q) `)` `->` type(results)
|
$x `,` $q attr-dict `:` type($x) `,` type($q) `->` type(results)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def HLOClient_PolygammaOp : HLOClient_Op<"polygamma", [NoSideEffect,
|
||||||
|
SameOperandsAndResultType]> {
|
||||||
|
let summary = "Polygamma function";
|
||||||
|
let description = [{
|
||||||
|
Returns `Polygamma(operand, operand)` element-wise.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins HLO_FpTensor:$n, HLO_FpTensor:$x);
|
||||||
|
let results = (outs HLO_FpTensor:$result);
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$n `,` $x attr-dict `:` type($n) `,` type($x) `->` type(results)
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -318,6 +318,7 @@ BROADCAST_BINARY_OP_DEFS(BroadcastMaxOp);
|
||||||
BROADCAST_BINARY_OP_DEFS(BroadcastMinOp);
|
BROADCAST_BINARY_OP_DEFS(BroadcastMinOp);
|
||||||
BROADCAST_BINARY_OP_DEFS(BroadcastMulOp);
|
BROADCAST_BINARY_OP_DEFS(BroadcastMulOp);
|
||||||
BROADCAST_BINARY_OP_DEFS(BroadcastOrOp);
|
BROADCAST_BINARY_OP_DEFS(BroadcastOrOp);
|
||||||
|
BROADCAST_BINARY_OP_DEFS(BroadcastPolygammaOp);
|
||||||
BROADCAST_BINARY_OP_DEFS(BroadcastPowOp);
|
BROADCAST_BINARY_OP_DEFS(BroadcastPowOp);
|
||||||
BROADCAST_BINARY_OP_DEFS(BroadcastRemOp);
|
BROADCAST_BINARY_OP_DEFS(BroadcastRemOp);
|
||||||
BROADCAST_BINARY_OP_DEFS(BroadcastShiftLeftOp);
|
BROADCAST_BINARY_OP_DEFS(BroadcastShiftLeftOp);
|
||||||
|
|
|
@ -1110,7 +1110,6 @@ func @digamma_f16(%arg : tensor<f16>) -> tensor<f16> {
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: tensor<f16>,
|
// CHECK-SAME: %[[VAL_0:.*]]: tensor<f16>,
|
||||||
// CHECK-SAME: %[[VAL_1:.*]]: tensor<f16>) -> tensor<f16> {
|
// CHECK-SAME: %[[VAL_1:.*]]: tensor<f16>) -> tensor<f16> {
|
||||||
func @zeta_f16(%arg0: tensor<f16>, %arg1: tensor<f16>) -> tensor<f16> {
|
func @zeta_f16(%arg0: tensor<f16>, %arg1: tensor<f16>) -> tensor<f16> {
|
||||||
%0 = chlo.zeta %arg0, %arg1 : (tensor<f16>, tensor<f16>) -> tensor<f16>
|
|
||||||
// CHECK: %[[VAL_2:.*]] = "mhlo.convert"(%[[VAL_0]]) : (tensor<f16>) -> tensor<f32>
|
// CHECK: %[[VAL_2:.*]] = "mhlo.convert"(%[[VAL_0]]) : (tensor<f16>) -> tensor<f32>
|
||||||
// CHECK: %[[VAL_3:.*]] = "mhlo.convert"(%[[VAL_1]]) : (tensor<f16>) -> tensor<f32>
|
// CHECK: %[[VAL_3:.*]] = "mhlo.convert"(%[[VAL_1]]) : (tensor<f16>) -> tensor<f32>
|
||||||
// CHECK: %[[VAL_4:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
// CHECK: %[[VAL_4:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||||
|
@ -1283,6 +1282,7 @@ func @zeta_f16(%arg0: tensor<f16>, %arg1: tensor<f16>) -> tensor<f16> {
|
||||||
// CHECK: %[[VAL_172:.*]] = "mhlo.select"(%[[VAL_171]], %[[VAL_157]], %[[VAL_162]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
// CHECK: %[[VAL_172:.*]] = "mhlo.select"(%[[VAL_171]], %[[VAL_157]], %[[VAL_162]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||||
// CHECK: %[[VAL_173:.*]] = "mhlo.select"(%[[VAL_167]], %[[VAL_160]], %[[VAL_172]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
// CHECK: %[[VAL_173:.*]] = "mhlo.select"(%[[VAL_167]], %[[VAL_160]], %[[VAL_172]]) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||||
// CHECK: %[[VAL_174:.*]] = "mhlo.convert"(%[[VAL_173]]) : (tensor<f32>) -> tensor<f16>
|
// CHECK: %[[VAL_174:.*]] = "mhlo.convert"(%[[VAL_173]]) : (tensor<f32>) -> tensor<f16>
|
||||||
return %0 : tensor<f16>
|
|
||||||
// CHECK: return %[[VAL_174]] : tensor<f16>
|
// CHECK: return %[[VAL_174]] : tensor<f16>
|
||||||
|
%0 = chlo.zeta %arg0, %arg1 : tensor<f16>, tensor<f16> -> tensor<f16>
|
||||||
|
return %0 : tensor<f16>
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue