712 lines
22 KiB
TableGen
712 lines
22 KiB
TableGen
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
// Defines "client" aligned HLO ops.
|
|
// These ops are not necessarily orthogonal or optimized for transformation but
|
|
// for ease of expression in certain cases deemed important for client
|
|
// libraries (i.e. implicit broadcasting, helper ops, etc).
|
|
// This dialect exists alongside the mhlo dialect to augment it for ergonomic
// needs; it is not meant to duplicate or replace it.
|
|
//
|
|
// The typical use of this dialect is for client libraries to be able to emit
|
|
// less constrained ops and rely on the conversion framework to lower any
|
|
// chlo ops to canonical mhlo ops.
|
|
//
|
|
// See: https://www.tensorflow.org/xla/operation_semantics
|
|
|
|
#ifndef CHLO_OPS
|
|
#define CHLO_OPS
|
|
|
|
include "mlir/IR/OpBase.td"
|
|
include "mlir/Interfaces/InferTypeOpInterface.td"
|
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
|
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
|
|
include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td"
|
|
|
|
// The "chlo" (client HLO) dialect record; its C++ classes are generated into
// the ::mlir::chlo namespace.
def HLOClient_Dialect : Dialect {
  let name = "chlo";

  let summary = [{
    Client HLO Ops
  }];

  let description = [{
    This dialect contains ops that align closely with the API surface area
    of the XlaBuilder C++ API, where such ops have semantics that go beyond
    what exists in the lower level dialects (such as `mhlo`). Essentially,
    whenever the client library uses syntactic sugar or composition
    of multiple ops for an API call, this dialect tries to model the API call
    and provide conversion patterns to fully materialize into lower level
    dialects.
  }];

  let cppNamespace = "::mlir::chlo";
}
|
|
|
|
// Common base class for the CHLO ops defined in this file. Every derived op
// is verified by a hand-written C++ `Verify(<op>)` overload.
class HLOClient_Op<string mnemonic, list<OpTrait> traits>
    : Op<HLOClient_Dialect, mnemonic, traits> {
  // TODO(b/129012527) Much of this custom verification should be expressed as
  // type constraints.
  let verifier = [{ return Verify(*this); }];
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CHLO binary elementwise op definitions.
|
|
// From the client perspective, each of these support both explicit rank
|
|
// broadcasting (via the broadcast_dimensions attribute) and implicit degenerate
|
|
// shape broadcasting.
|
|
//
|
|
// These correspond to operations in the chlo and mhlo dialects without the
|
|
// "broadcast_" prefix, except that those ops require same-shaped operands and
|
|
// results.
|
|
//
|
|
// See:
|
|
// https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
|
|
// https://www.tensorflow.org/xla/broadcasting
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Shared shape of the broadcasting binary elementwise ops: two tensor
// operands, an optional explicit rank-broadcast mapping, a single tensor
// result, and a `reifyReturnTypeShapes` hook from InferShapedTypeOpInterface.
class HLOClient_BroadcastBinaryElementwiseOp<string mnemonic,
                                             list<OpTrait> traits>
    : HLOClient_Op<mnemonic, !listconcat(traits, [
          DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
                                    ["reifyReturnTypeShapes"]>])> {
  let arguments = (ins
    HLO_Tensor:$lhs,
    HLO_Tensor:$rhs,
    // Explicit rank-broadcast dimension mappings. Defaults to "numpy" prefix
    // padded rank-broadcast semantics if omitted.
    OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
  );
  let results = (outs HLO_Tensor);

  // Custom textual form: `%r = chlo.<op> %a, %b {attrs} : (lhsT, rhsT) -> resT`.
  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:`
    `(` type($lhs) `,` type($rhs) `)` `->` type(results)
  }];

  // Convenience builder taking both operands and the broadcast mapping.
  let builders = [
    OpBuilderDAG<(ins "Value":$left, "Value":$right,
      "DenseIntElementsAttr":$broadcast_dimensions)>];
}
|
|
|
|
// Elementwise addition; operands may be broadcast against each other.
def HLOClient_BroadcastAddOp : HLOClient_BroadcastBinaryElementwiseOp<
    "broadcast_add",
    [Commutative, NoSideEffect, SameOperandsAndResultElementType]> {
  let summary = "Addition operator (with optional broadcasting)";

  let description = [{
    Returns `lhs + rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}
|
|
|
|
// Two-argument arctangent with optional broadcasting. The op is deliberately
// not marked Commutative: atan2(lhs, rhs) generally differs from
// atan2(rhs, lhs).
def HLOClient_BroadcastAtan2Op : HLOClient_BroadcastBinaryElementwiseOp<
    "broadcast_atan2",
    [NoSideEffect, SameOperandsAndResultElementType]> {
  let summary = "Atan2 operator (with optional broadcasting)";

  let description = [{
    Returns `atan2(lhs, rhs)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}
|
|
|
|
// Elementwise division (not commutative); operands may be broadcast.
def HLOClient_BroadcastDivOp : HLOClient_BroadcastBinaryElementwiseOp<
    "broadcast_divide",
    [NoSideEffect, SameOperandsAndResultElementType]> {
  let summary = "Division operator (with optional broadcasting)";

  let description = [{
    Returns `lhs / rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}
|
|
|
|
// Elementwise maximum; operands may be broadcast against each other.
def HLOClient_BroadcastMaxOp : HLOClient_BroadcastBinaryElementwiseOp<
    "broadcast_maximum",
    [Commutative, NoSideEffect, SameOperandsAndResultElementType]> {
  let summary = "Maximum operator (with optional broadcasting)";

  let description = [{
    Returns `max(lhs, rhs)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}
|
|
|
|
// Elementwise minimum; operands may be broadcast against each other.
def HLOClient_BroadcastMinOp : HLOClient_BroadcastBinaryElementwiseOp<
    "broadcast_minimum",
    [Commutative, NoSideEffect, SameOperandsAndResultElementType]> {
  let summary = "Minimum operator (with optional broadcasting)";

  let description = [{
    Returns `min(lhs, rhs)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}
|
|
|
|
// Elementwise multiplication; operands may be broadcast against each other.
def HLOClient_BroadcastMulOp : HLOClient_BroadcastBinaryElementwiseOp<
    "broadcast_multiply",
    [Commutative, NoSideEffect, SameOperandsAndResultElementType]> {
  let summary = "Multiplication operator (with optional broadcasting)";

  let description = [{
    Returns `lhs * rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}
|
|
|
|
// Broadcasting form of `chlo.polygamma`: `lhs` is the order `n` and `rhs` is
// the evaluation point `x`, matching that op's ($n, $x) operand order.
def HLOClient_BroadcastPolygammaOp : HLOClient_BroadcastBinaryElementwiseOp<
    "broadcast_polygamma", [NoSideEffect, SameOperandsAndResultElementType]> {
  let summary = "Polygamma function (with optional broadcasting)";

  let description = [{
    Returns `Polygamma(lhs, rhs)` element-wise, where `lhs` is the order `n`
    and `rhs` is the point `x` at which the polygamma function is evaluated.
  }];
}
|
|
|
|
// Elementwise exponentiation (not commutative); operands may be broadcast.
def HLOClient_BroadcastPowOp : HLOClient_BroadcastBinaryElementwiseOp<
    "broadcast_power",
    [NoSideEffect, SameOperandsAndResultElementType]> {
  let summary = "Power operator (with optional broadcasting)";

  // Spelled as `pow(lhs, rhs)` rather than `lhs ^ rhs` so it cannot be read as
  // bitwise XOR (see chlo.broadcast_xor).
  let description = [{
    Returns `pow(lhs, rhs)`, i.e. `lhs` raised to the power `rhs`, element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}
|
|
|
|
// Elementwise remainder (not commutative); operands may be broadcast.
def HLOClient_BroadcastRemOp : HLOClient_BroadcastBinaryElementwiseOp<
    "broadcast_remainder",
    [NoSideEffect, SameOperandsAndResultElementType]> {
  let summary = "Remainder operator (with optional broadcasting)";

  let description = [{
    Returns `lhs % rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}
|
|
|
|
// Elementwise left shift of `lhs` by `rhs`; operands may be broadcast.
def HLOClient_BroadcastShiftLeftOp : HLOClient_BroadcastBinaryElementwiseOp<
    "broadcast_shift_left",
    [NoSideEffect, SameOperandsAndResultElementType]> {
  let summary = "Shift left operator (with optional broadcasting)";

  let description = [{
    Returns `lhs << rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}
|
|
|
|
// Elementwise arithmetic right shift of `lhs` by `rhs`; operands may be
// broadcast. Contrast with chlo.broadcast_shift_right_logical.
def HLOClient_BroadcastShiftRightArithmeticOp
    : HLOClient_BroadcastBinaryElementwiseOp<
          "broadcast_shift_right_arithmetic",
          [NoSideEffect, SameOperandsAndResultElementType]> {
  let summary = "Shift right arithmetic operator (with optional broadcasting)";

  let description = [{
    Returns `lhs >> rhs` element-wise as an arithmetic shift: the vacated
    high-order bits are filled with copies of the sign bit of `lhs`.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}
|
|
|
|
// Elementwise logical right shift of `lhs` by `rhs`; operands may be
// broadcast. Contrast with chlo.broadcast_shift_right_arithmetic.
def HLOClient_BroadcastShiftRightLogicalOp
    : HLOClient_BroadcastBinaryElementwiseOp<
          "broadcast_shift_right_logical",
          [NoSideEffect, SameOperandsAndResultElementType]> {
  let summary = "Shift right logical operator (with optional broadcasting)";

  let description = [{
    Returns `lhs >> rhs` element-wise as a logical shift: the vacated
    high-order bits are filled with zeros regardless of the sign of `lhs`.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}
|
|
|
|
// Elementwise subtraction (not commutative); operands may be broadcast.
def HLOClient_BroadcastSubOp : HLOClient_BroadcastBinaryElementwiseOp<
    "broadcast_subtract",
    [NoSideEffect, SameOperandsAndResultElementType]> {
  let summary = "Subtraction operator (with optional broadcasting)";

  let description = [{
    Returns `lhs - rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}
|
|
|
|
// Broadcasting form of `chlo.zeta`: `lhs` is `x` and `rhs` is `q`, matching
// that op's ($x, $q) operand order. Operands and result are restricted to
// floating-point tensors, overriding the generic HLO_Tensor constraints.
def HLOClient_BroadcastZetaOp : HLOClient_BroadcastBinaryElementwiseOp<
    "broadcast_zeta",
    [NoSideEffect, SameOperandsAndResultElementType]> {
  let summary = "Hurwitz zeta function (with optional broadcasting)";

  let description = [{
    Returns `Zeta(lhs, rhs)` element-wise, where `lhs` is `x` and `rhs` is `q`:

    $$
    \zeta(x, q) = \sum_{n=0}^{\infty} (q + n)^{-x}
    $$
  }];

  let arguments = (ins
    HLO_FpTensor:$lhs,
    HLO_FpTensor:$rhs,
    // Explicit rank-broadcast dimension mappings. Defaults to "numpy" prefix
    // padded rank-broadcast semantics if omitted.
    OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
  );
  let results = (outs HLO_FpTensor);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// XLA binary logical elementwise op definitions.
|
|
// The same description as the arithmetic binary elementwise ops applies.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Base for the broadcasting logical ops (and/or/xor in this file). They are
// commutative and side-effect free, and their operands are narrowed to
// predicate or integer tensors.
class HLOClient_BroadcastBinaryLogicalElementwiseOp<string mnemonic>
    : HLOClient_BroadcastBinaryElementwiseOp<mnemonic,
                                             [Commutative, NoSideEffect]> {
  let arguments = (ins
    HLO_PredOrIntTensor:$lhs,
    HLO_PredOrIntTensor:$rhs,
    // Explicit rank-broadcast dimension mappings. Defaults to "numpy" prefix
    // padded rank-broadcast semantics if omitted.
    OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
  );
}
|
|
|
|
// Elementwise logical/bitwise AND over predicate or integer tensors.
def HLOClient_BroadcastAndOp
    : HLOClient_BroadcastBinaryLogicalElementwiseOp<"broadcast_and"> {
  let summary = "Logical and operator (with optional broadcasting)";

  let description = [{
    Returns `logical_and(lhs, rhs)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}
|
|
|
|
// Elementwise logical/bitwise OR over predicate or integer tensors.
def HLOClient_BroadcastOrOp
    : HLOClient_BroadcastBinaryLogicalElementwiseOp<"broadcast_or"> {
  let summary = "Logical or operator (with optional broadcasting)";

  let description = [{
    Returns `logical_or(lhs, rhs)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}
|
|
|
|
// Elementwise logical/bitwise XOR over predicate or integer tensors.
def HLOClient_BroadcastXorOp
    : HLOClient_BroadcastBinaryLogicalElementwiseOp<"broadcast_xor"> {
  let summary = "Logical xor operator (with optional broadcasting)";

  let description = [{
    Returns `logical_xor(lhs, rhs)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// XLA non-broadcasting binary operations.
|
|
//
|
|
// These are operations that are supported by the XLA Builder API but that are
|
|
// not part of the HLO compiler instructions as modelled by the MHLO dialect.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Non-broadcasting Hurwitz zeta; both operands and the result share one
// floating-point tensor type.
def HLOClient_ZetaOp : HLOClient_Op<"zeta", [NoSideEffect,
    SameOperandsAndResultType]> {
  let summary = "Hurwitz zeta function";
  let description = [{
    Returns `Zeta(x, q)` element-wise.

    $$
    \zeta(x, q) = \sum_{n=0}^{\infty} (q + n)^{-x}
    $$
  }];

  let arguments = (ins HLO_FpTensor:$x, HLO_FpTensor:$q);
  let results = (outs HLO_FpTensor:$result);

  let assemblyFormat = [{
    $x `,` $q attr-dict `:` type($x) `,` type($q) `->` type(results)
  }];
}
|
|
|
|
// Non-broadcasting polygamma; both operands and the result share one
// floating-point tensor type.
def HLOClient_PolygammaOp : HLOClient_Op<"polygamma", [NoSideEffect,
    SameOperandsAndResultType]> {
  let summary = "Polygamma function";
  let description = [{
    Returns `Polygamma(n, x)` element-wise: the polygamma function of order `n`
    evaluated at `x`.
  }];

  let arguments = (ins HLO_FpTensor:$n, HLO_FpTensor:$x);
  let results = (outs HLO_FpTensor:$result);

  let assemblyFormat = [{
    $n `,` $x attr-dict `:` type($n) `,` type($x) `->` type(results)
  }];
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Broadcasting complex op
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Builds complex values from broadcastable real (`lhs`) and imaginary (`rhs`)
// floating-point parts; the result is a complex tensor.
def HLOClient_BroadcastComplexOp
    : HLOClient_BroadcastBinaryElementwiseOp<"broadcast_complex",
                                             [NoSideEffect]> {
  let summary = "Complex operator (with optional broadcasting)";

  let description = [{
    Performs element-wise conversion of a pair of real and imaginary values to
    a complex value.
  }];

  let results = (outs HLO_ComplexTensor);
  let arguments = (ins
    HLO_FpTensor:$lhs,
    HLO_FpTensor:$rhs,
    // Explicit rank-broadcast dimension mappings. Defaults to "numpy" prefix
    // padded rank-broadcast semantics if omitted.
    OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
  );
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Unary op
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Base for the unary elementwise CHLO ops. Subclasses pick the operand and
// result tensor constraints; every subclass is side-effect free, preserves
// the operand shape, and implements InferFusibilityOpInterface.
class HLOClient_UnaryElementwiseOp<string mnemonic, list<OpTrait> traits,
                                   Type ArgTensorType, Type ResultTensorType>
    : HLOClient_Op<mnemonic, !listconcat(traits, [
          InferFusibilityOpInterface, NoSideEffect,
          SameOperandsAndResultShape])> {
  let arguments = (ins ArgTensorType:$operand);
  let results = (outs ResultTensorType:$result);

  // Custom textual form: `%r = chlo.<op> %x {attrs} : operandT -> resultT`.
  let assemblyFormat = [{
    $operand attr-dict `:` type($operand) `->` type($result)
  }];
}
|
|
|
|
// Elementwise arc cosine over float or complex tensors.
def HLOClient_AcosOp
    : HLOClient_UnaryElementwiseOp<"acos", [SameOperandsAndResultType],
                                   HLO_FpOrComplexTensor,
                                   HLO_FpOrComplexTensor> {
  let description = [{
    Returns `Acos(operand)` element-wise.

    $$
    \acos(x) = 2 * \atan(\sqrt(1 - x^2) / (1 + x)) if x != -1
             = pi                                  if x == -1
    $$
  }];

  let summary = "Acos operator";
}
|
|
|
|
// Elementwise inverse hyperbolic cosine over float or complex tensors.
def HLOClient_AcoshOp : HLOClient_UnaryElementwiseOp<"acosh",
    [SameOperandsAndResultType], HLO_FpOrComplexTensor, HLO_FpOrComplexTensor> {
  let summary = "Acosh operation";

  // For real x the formula is only defined for x >= 1: below that,
  // sqrt(x^2 - 1) (for -1 < x < 1) or the log argument (for x < -1) leaves
  // the reals, so the result is NaN.
  let description = [{
    Returns `Acosh(operand)` element-wise.

    For real-valued inputs:

    $$
    \acosh(x) = log(x + sqrt(x^2 - 1))  if x >= 1
              = nan                     if x < 1
    $$
  }];
}
|
|
|
|
// Elementwise arc sine over float or complex tensors.
def HLOClient_AsinOp
    : HLOClient_UnaryElementwiseOp<"asin", [SameOperandsAndResultType],
                                   HLO_FpOrComplexTensor,
                                   HLO_FpOrComplexTensor> {
  let description = [{
    Returns `Asin(operand)` element-wise.

    $$
    \asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2)))
    $$
  }];

  let summary = "Asin operator";
}
|
|
|
|
// Elementwise inverse hyperbolic sine over float or complex tensors.
def HLOClient_AsinhOp
    : HLOClient_UnaryElementwiseOp<"asinh", [SameOperandsAndResultType],
                                   HLO_FpOrComplexTensor,
                                   HLO_FpOrComplexTensor> {
  let description = [{
    Returns `Asinh(operand)` element-wise.

    $$
    \asinh(x) = log(x + sqrt(x^2 + 1))
    $$
  }];

  let summary = "Asinh operation";
}
|
|
|
|
// Elementwise arc tangent over float or complex tensors.
def HLOClient_AtanOp
    : HLOClient_UnaryElementwiseOp<"atan", [SameOperandsAndResultType],
                                   HLO_FpOrComplexTensor,
                                   HLO_FpOrComplexTensor> {
  let description = [{
    Returns `Atan(operand)` element-wise.

    $$
    \atan(x) = \atan2(x, 1)
    $$
  }];

  let summary = "Atan operator";
}
|
|
|
|
// Elementwise inverse hyperbolic tangent over float or complex tensors.
def HLOClient_AtanhOp
    : HLOClient_UnaryElementwiseOp<"atanh", [SameOperandsAndResultType],
                                   HLO_FpOrComplexTensor,
                                   HLO_FpOrComplexTensor> {
  let description = [{
    Returns `Atanh(operand)` element-wise.

    $$
    \atanh(x) = 0.5 * log((1 + x) / (1 - x)) if abs(x) <= 1
              = nan                          otherwise
    $$
  }];

  let summary = "Atanh operator";
}
|
|
|
|
// Elementwise complex conjugate over float or complex tensors.
def HLOClient_ConjOp
    : HLOClient_UnaryElementwiseOp<"conj", [SameOperandsAndResultType],
                                   HLO_FpOrComplexTensor,
                                   HLO_FpOrComplexTensor> {
  let description = [{
    Returns `Conj(operand)` element-wise.

    $$
    \conj(x) = (\real(x), \neg(\imag(x)))
    $$
  }];

  let summary = "Conj operator";
}
|
|
|
|
// Elementwise hyperbolic cosine over float or complex tensors.
def HLOClient_CoshOp
    : HLOClient_UnaryElementwiseOp<"cosh", [SameOperandsAndResultType],
                                   HLO_FpOrComplexTensor,
                                   HLO_FpOrComplexTensor> {
  let description = [{
    Returns `Cosh(operand)` element-wise.

    $$
    \cosh(x) = (e^x + e^-x) / 2
    $$
  }];

  let summary = "Cosh operator";
}
|
|
|
|
// Elementwise hyperbolic sine over float or complex tensors. The documented
// piecewise form folds the factor 1/2 into the exponent for |x| >= 1.
def HLOClient_SinhOp
    : HLOClient_UnaryElementwiseOp<"sinh", [SameOperandsAndResultType],
                                   HLO_FpOrComplexTensor,
                                   HLO_FpOrComplexTensor> {
  let description = [{
    Returns `Sinh(operand)` element-wise.

    $$
    \sinh(x) = (e^x - e^-x) / 2                     if |x| < 1
             = e^(x + log(1/2)) - e^(-x + log(1/2)) otherwise.
    $$
  }];

  let summary = "Sinh operation";
}
|
|
|
|
// Elementwise tangent over float or complex tensors.
def HLOClient_TanOp
    : HLOClient_UnaryElementwiseOp<"tan", [SameOperandsAndResultType],
                                   HLO_FpOrComplexTensor,
                                   HLO_FpOrComplexTensor> {
  let description = [{
    Returns `Tan(operand)` element-wise.

    $$
    \tan(x) = \sin(x) / \cos(x)
    $$
  }];

  let summary = "Tan operation";
}
|
|
|
|
// Splat constant shaped like `operand`. The result type is inferred rather
// than spelled out, via InferTypeOpInterface / InferTensorType. The op also
// has a canonicalizer.
def HLOClient_ConstantLikeOp : HLOClient_Op<"constant_like", [
    NoSideEffect, SameOperandsAndResultShape, InferTypeOpInterface,
    DeclareOpInterfaceMethods<InferShapedTypeOpInterface>,
    NativeOpTrait<"InferTensorType">]> {
  let summary = "Constant like operator";

  let description = [{
    Returns a splat constant of the same shape as the operand.
  }];

  let results = (outs HLO_Tensor);
  // TODO(jpienaar): value's type could be tightened.
  let arguments = (ins AnyAttr:$value, HLO_Tensor:$operand);

  let hasCanonicalizer = 1;
}
|
|
|
|
// Elementwise digamma over floating-point tensors.
def HLOClient_DigammaOp
    : HLOClient_UnaryElementwiseOp<"digamma", [SameOperandsAndResultType],
                                   HLO_FpTensor, HLO_FpTensor> {
  let description = [{
    Returns `Digamma(operand)` element-wise.
  }];

  let summary = "Digamma function";
}
|
|
|
|
// Elementwise Gauss error function over floating-point tensors.
def HLOClient_ErfOp : HLOClient_UnaryElementwiseOp<"erf",
    [SameOperandsAndResultType], HLO_FpTensor, HLO_FpTensor> {
  // Was "Erfc operator", a copy-paste slip from chlo.erfc.
  let summary = "Erf operator";

  let description = [{
    Computes the Gauss error function of `x` element-wise.

    erf(x) = erf_impl(x)            if |x| < 1
           = 1 - erfc_impl(x)       otherwise
  }];
}
|
|
|
|
// Elementwise complementary error function over floating-point tensors.
def HLOClient_ErfcOp
    : HLOClient_UnaryElementwiseOp<"erfc", [SameOperandsAndResultType],
                                   HLO_FpTensor, HLO_FpTensor> {
  let description = [{
    Computes an approximation of the error function complement (1 - erf(x)).

    erfc(x) = erfc_impl(x)           if |x| > 1
            = 1 - erf_impl(x)        otherwise
  }];

  let summary = "Erfc operator";
}
|
|
|
|
// Predicate test for +/-infinity. It maps a float tensor to a pred tensor of
// the same shape, and the result type is inferred.
def HLOClient_IsInfOp
    : HLOClient_UnaryElementwiseOp<"is_inf",
                                   [DeclareOpInterfaceMethods<InferTypeOpInterface>],
                                   HLO_FpTensor, HLO_PredTensor> {
  let description = [{
    Returns if a value is +/-inf element-wise.
  }];

  let summary = "IsInf predicate";
}
|
|
|
|
// Predicate test for negative infinity. It maps a float tensor to a pred
// tensor of the same shape, and the result type is inferred.
def HLOClient_IsNegInfOp
    : HLOClient_UnaryElementwiseOp<"is_neg_inf",
                                   [DeclareOpInterfaceMethods<InferTypeOpInterface>],
                                   HLO_FpTensor, HLO_PredTensor> {
  let description = [{
    Returns if a value is -inf element-wise.
  }];

  let summary = "IsNegInf predicate";
}
|
|
|
|
// Predicate test for positive infinity. It maps a float tensor to a pred
// tensor of the same shape, and the result type is inferred.
def HLOClient_IsPosInfOp
    : HLOClient_UnaryElementwiseOp<"is_pos_inf",
                                   [DeclareOpInterfaceMethods<InferTypeOpInterface>],
                                   HLO_FpTensor, HLO_PredTensor> {
  let description = [{
    Returns if a value is +inf element-wise.
  }];

  let summary = "IsPosInf predicate";
}
|
|
|
|
// Elementwise log-gamma over floating-point tensors.
def HLOClient_LgammaOp
    : HLOClient_UnaryElementwiseOp<"lgamma", [SameOperandsAndResultType],
                                   HLO_FpTensor, HLO_FpTensor> {
  let description = [{
    Returns `Lgamma(operand)` element-wise.
  }];

  let summary = "Lgamma function";
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Broadcasting compare op
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Broadcasting elementwise comparison producing a pred tensor. The direction
// is a required attribute; the comparison type is optional.
def HLOClient_BroadcastCompareOp
    : HLOClient_BroadcastBinaryElementwiseOp<"broadcast_compare",
                                             [NoSideEffect]> {
  let summary = "Compare operator (with optional broadcasting)";

  let description = [{
    Compares `lhs` and `rhs` elementwise according to `comparison_direction`
    and `compare_type`. If unspecified, `compare_type` is FLOAT for float element
    types, SIGNED for signed element types and UNSIGNED for unsigned element
    types.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations.
  }];

  let results = (outs HLO_PredTensor);
  let arguments = (ins
    HLO_Tensor:$lhs,
    HLO_Tensor:$rhs,
    OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions,
    HLO_ComparisonDirectionAttr:$comparison_direction,
    OptionalAttr<HLO_ComparisonTypeAttr>:$compare_type
  );

  // `compare_type` defaults to a null StringAttr when the caller omits it.
  let builders = [
    OpBuilderDAG<(ins "Value":$lhs, "Value":$rhs,
      "DenseIntElementsAttr":$broadcast_dimensions,
      "StringAttr":$comparison_direction, CArg<"StringAttr", "{}">:$compare_type)>];
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Broadcasting select op
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Elementwise select with numpy-style broadcasting across `pred`, `on_true`,
// and `on_false`. Result shapes come from InferShapedTypeOpInterface.
def HLOClient_BroadcastSelectOp : HLOClient_Op<"broadcast_select", [
    NoSideEffect, DeclareOpInterfaceMethods<InferShapedTypeOpInterface>]> {
  let summary = "Select operator (with optional numpy-style broadcasting)";

  let description = [{
    Constructs an output array from elements of two input arrays, based on the
    values of a predicate array.

    See https://www.tensorflow.org/xla/operation_semantics#select
  }];

  let results = (outs HLO_Tensor);
  let arguments = (ins
    HLO_PredTensor:$pred,
    HLO_Tensor:$on_true,
    HLO_Tensor:$on_false
  );

  let assemblyFormat = [{
    $pred `,` $on_true `,` $on_false attr-dict `:`
    `(` type($pred) `,` type($on_true) `,` type($on_false) `)` `->` type(results)
  }];
}
|
|
|
|
#endif // CHLO_OPS
|