diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td index 48adffc..a8a214e 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td @@ -66,7 +66,7 @@ class HLOClient_Op traits> : // broadcasting (via the broadcast_dimensions attribute) and implicit degenerate // shape broadcasting. // -// These correspond to operations in the mhlo dialect without the +// These correspond to operations in the chlo and mhlo dialects without the // "broadcast_" prefix, except that those ops require same-shaped operands and // results. // @@ -256,8 +256,31 @@ def HLOClient_BroadcastSubOp : HLOClient_BroadcastBinaryElementwiseOp< }]; } +def HLOCLient_BroadcastZetaOp : HLOClient_BroadcastBinaryElementwiseOp< + "broadcast_zeta", + [NoSideEffect, SameOperandsAndResultElementType]> { + let summary = "Hurwitz zeta function"; + + let description = [{ + Returns `Zeta(operand, operand)` element-wise. + + $$ + \(\zeta(x, q) = \sum_{n=0}^{\infty} (q + n)^{-x}\) + $$ + }]; + + let arguments = (ins + HLO_FpTensor:$lhs, + HLO_FpTensor:$rhs, + // Explicit rank-broadcast dimension mappings. Defaults to "numpy" prefix + // padded rank-broadcast semantics if omitted. + OptionalAttr:$broadcast_dimensions + ); + let results = (outs HLO_FpTensor); +} + //===----------------------------------------------------------------------===// -// XLA binary elementwise op definitions. +// XLA binary logical elementwise op definitions. // The same description as the arithmetic binary elementwise ops applies. //===----------------------------------------------------------------------===// @@ -309,6 +332,38 @@ def HLOClient_BroadcastXorOp : HLOClient_BroadcastBinaryLogicalElementwiseOp< }]; } +//===----------------------------------------------------------------------===// +// XLA non-broadcasting binary operations. +// +// These are operations that are supported by the XLA Builder API but that are +// not part of the HLO compiler instructions as modelled by the MHLO dialect. +//===----------------------------------------------------------------------===// + +def HLOClient_ZetaOp : HLOClient_Op<"zeta", + [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Hurwitz zeta function"; + + let description = [{ + Returns `Zeta(operand, operand)` element-wise. + + $$ + \(\zeta(x, q) = \sum_{n=0}^{\infty} (q + n)^{-x}\) + $$ + }]; + + let arguments = (ins + HLO_FpTensor:$lhs, + HLO_FpTensor:$rhs + ); + + let results = (outs HLO_FpTensor); + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` + `(` type($lhs) `,` type($rhs) `)` `->` type(results) + }]; +} + //===----------------------------------------------------------------------===// // Broadcasting complex op //===----------------------------------------------------------------------===// diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h index 316e650..2b1b07c 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_chlo_to_hlo_op.h @@ -53,7 +53,7 @@ struct HloCompareAdaptor { }; // Populate a pattern for each Broadcasting CHlo op. This requires the pattern -// to take a ChloOpTy, MhloOpTy, and an Adaptor as templated values. +// to take a ChloOpTy, NonBroadcastingOpTy, and an Adaptor as templated values. template