From a4db6c57aa8e4446d595b3f8341223c4caabfad7 Mon Sep 17 00:00:00 2001 From: Itai Zukerman Date: Tue, 11 May 2021 09:47:40 -0700 Subject: [PATCH] Removed all (most) BASE_HLO_* ops. Moved the corresponding `summary` and `description` fields into the subclasses. Kept BASE_HLO_ConvOp for `hasWindowReversal()'. PiperOrigin-RevId: 373173025 --- include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td | 964 ++++++++++++-- .../mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td | 1186 ----------------- .../mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td | 27 +- include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td | 809 +++++++++-- 4 files changed, 1559 insertions(+), 1427 deletions(-) diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td index f8f1c45..1ec7013 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -50,8 +50,11 @@ def HLO_FusionKindAttr : StrEnumAttr<"FusionKind", "fusion kind", [ //===----------------------------------------------------------------------===// def HLO_ConstOp : HLO_Op<"constant", - [ConstantLike, NoSideEffect, AllTypesMatch<["value", "output"]>]>, - BASE_HLO_ConstOp { + [ConstantLike, NoSideEffect, AllTypesMatch<["value", "output"]>]> { + let summary = "Constant operator"; + let description = [{ + Represents a constant value. + }]; let arguments = (ins ElementsAttr:$value ); @@ -71,7 +74,11 @@ def HLO_ConstOp : HLO_Op<"constant", let hasCustomHLOConverter = 1; } -def HLO_IotaOp : HLO_Op<"iota", [NoSideEffect]>, BASE_HLO_IotaOp { +def HLO_IotaOp : HLO_Op<"iota", [NoSideEffect]> { + let summary = "Iota operator"; + let description = [{ + Creates a rank 1 array of values starting at zero and incrementing by one. 
+ }]; let arguments = (ins I64Attr:$iota_dimension); let results = (outs HLO_IntFpOrComplexTensor:$output); @@ -102,9 +109,9 @@ def HLO_DynamicIotaOp: HLO_Op<"dynamic_iota", [NoSideEffect]> { def HLO_CreateTokenOp : HLO_Op<"create_token", [NoSideEffect]> { - string summary = "Create Token operator"; + let summary = "Create Token operator"; - string description = [{ + let description = [{ Produces a HLO token. Tokens are used for ordering side-effecting perations. This is exported to HLO as an AfterAll operation with no operands to generate a token. @@ -149,18 +156,45 @@ class HLO_UnaryElementwiseOp traits, def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs", [NoSideEffect, DeclareOpInterfaceMethods], - TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>, BASE_HLO_AbsOp { + TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>> { + let summary = "Absolute value operator"; + let description = [{ + Returns `abs(operand)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; } def HLO_CbrtOp: HLO_UnaryElementwiseOp<"cbrt", - [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_CbrtOp; + [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor> { + let summary = "Cubic root operator"; + let description = [{ + Returns element-wise cubic root of the operand. + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} def HLO_CeilOp: HLO_UnaryElementwiseOp<"ceil", - [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_CeilOp; + [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor> { + let summary = "Ceil operator"; + let description = [{ + Returns `Ceil(operand)` element-wise. + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. 
+ }]; +} def HLO_ConvertOp : HLO_UnaryElementwiseOp<"convert", - [NoSideEffect, SameOperandsAndResultShape], HLO_Tensor>, - BASE_HLO_ConvertOp { + [NoSideEffect, SameOperandsAndResultShape], HLO_Tensor> { + let summary = "Convert operator"; + let description = [{ + Performs element-wise conversion of values from one type to another, e.g. + float to int. + + See https://www.tensorflow.org/xla/operation_semantics#convertelementtype. + }]; let builders = [ OpBuilder<(ins "Value":$operand, "Type":$result_element_ty)>]; @@ -170,101 +204,232 @@ def HLO_ConvertOp : HLO_UnaryElementwiseOp<"convert", } def HLO_ClzOp: HLO_UnaryElementwiseOp<"count_leading_zeros", - [NoSideEffect, SameOperandsAndResultType], HLO_IntTensor>, - BASE_HLO_ClzOp; + [NoSideEffect, SameOperandsAndResultType], HLO_IntTensor> { + let summary = "Count-leading-zeros (Clz) operator"; + let description = [{ + Returns the number of leading zeros in each operand element-wise. + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} def HLO_CosOp: HLO_UnaryElementwiseOp<"cosine", - [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>, - BASE_HLO_CosOp; + [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> { + let summary = "Cos operator"; + let description = [{ + Returns `Cos(operand)` element-wise. + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} def HLO_ExpOp: HLO_UnaryElementwiseOp<"exponential", - [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>, - BASE_HLO_ExpOp; + [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> { + let summary = "Exponential operator"; + let description = [{ + Returns `e^(operand)` element-wise. + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. 
+ }]; +} def HLO_Expm1Op: HLO_UnaryElementwiseOp<"exponential_minus_one", - [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>, - BASE_HLO_Expm1Op; + [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> { + let summary = "Exponential minus one operator"; + let description = [{ + Returns `e^(operand) - 1` element-wise. + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} def HLO_FloorOp: HLO_UnaryElementwiseOp<"floor", - [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_FloorOp; + [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor> { + let summary = "Floor operator"; + let description = [{ + Returns `Floor(operand)` element-wise. + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} def HLO_ImagOp: HLO_UnaryElementwiseOp<"imag", [NoSideEffect, DeclareOpInterfaceMethods], - HLO_ComplexTensor>, BASE_HLO_ImagOp { + HLO_ComplexTensor> { + let summary = "Imag operator"; + let description = [{ + Returns `Imag(operand)` element-wise. + }]; let results = (outs HLO_FpTensor); let hasFolder = 1; } def HLO_IsFiniteOp: HLO_UnaryElementwiseOp<"is_finite", [NoSideEffect, - DeclareOpInterfaceMethods], HLO_Tensor>, - BASE_HLO_IsFiniteOp { + DeclareOpInterfaceMethods], HLO_Tensor> { + let summary = "IsFinite operator"; + let description = [{ + Tests whether each element of operand is finite, i.e., is not positive or + negative infinity, and is not NaN. Returns a tensor of 1-bit integers with + the same shape as the input, where each element is nonzero (i.e. true) if + and only if the corresponding input element is finite. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. 
+ }]; let arguments = (ins HLO_FpTensor:$x); let results = (outs HLO_PredTensor:$y); } def HLO_LogOp: HLO_UnaryElementwiseOp<"log", - [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>, - BASE_HLO_LogOp; + [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> { + let summary = "Logarithm operator"; + let description = [{ + Returns `log(operand)` element-wise. + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} def HLO_Log1pOp: HLO_UnaryElementwiseOp<"log_plus_one", - [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>, - BASE_HLO_Log1pOp; + [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> { + let summary = "Log1p operator"; + let description = [{ + Returns `log(operand+1)` element-wise. + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} def HLO_LogisticOp: HLO_UnaryElementwiseOp<"logistic", - [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>, - BASE_HLO_LogisticOp; + [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> { + let summary = "Logistic operator"; + let description = [{ + Returns `logistic(operand)` element-wise. + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} def HLO_NotOp: HLO_UnaryElementwiseOp<"not", - [NoSideEffect, SameOperandsAndResultType], HLO_PredOrIntTensor>, - BASE_HLO_NotOp { + [NoSideEffect, SameOperandsAndResultType], HLO_PredOrIntTensor> { + let summary = "Not operator"; + let description = [{ + Returns `!operand` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. 
+ }]; let hasFolder = 1; } def HLO_NegOp: HLO_UnaryElementwiseOp<"negate", - [NoSideEffect, SameOperandsAndResultType], HLO_IntFpOrComplexTensor>, - BASE_HLO_NegOp { + [NoSideEffect, SameOperandsAndResultType], HLO_IntFpOrComplexTensor> { + let summary = "Negation operator"; + let description = [{ + Returns `-operand` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; let hasFolder = 1; } def HLO_PopulationCountOp: HLO_UnaryElementwiseOp<"popcnt", - [NoSideEffect, SameOperandsAndResultType], HLO_IntTensor>, - BASE_HLO_PopulationCountOp; + [NoSideEffect, SameOperandsAndResultType], HLO_IntTensor> { + let summary = "PopulationCount operator"; + let description = [{ + Returns the number of bits set in each operand element-wise. + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} def HLO_RealOp: HLO_UnaryElementwiseOp<"real", [NoSideEffect, DeclareOpInterfaceMethods], - HLO_ComplexTensor>, BASE_HLO_RealOp { + HLO_ComplexTensor> { + let summary = "Real operator"; + let description = [{ + Returns `Real(operand)` element-wise. + }]; let results = (outs HLO_FpTensor); let hasFolder = 1; } def HLO_RoundOp: HLO_UnaryElementwiseOp<"round_nearest_afz", - [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_RoundOp { + [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor> { + let summary = "Round operator"; + let description = [{ + Returns `Round(operand)` element-wise, rounding to nearest integer with + half-way cases rounding away from zero. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. 
+ }]; let hasFolder = 1; } def HLO_RsqrtOp: HLO_UnaryElementwiseOp<"rsqrt", - [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>, - BASE_HLO_RsqrtOp; + [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> { + let summary = "Reciprocal Square-root operator"; + let description = [{ + Returns `1.0 / sqrt(operand)` element-wise. + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} def HLO_SignOp: HLO_UnaryElementwiseOp<"sign", [NoSideEffect, SameOperandsAndResultType], - TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>, - BASE_HLO_SignOp; + TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>> { + let summary = "Sign operator"; + let description = [{ + Returns `sign(operand)` element-wise, where + ``` + sign(x) = -1 : x < 0 + = -0 : x = -0 + = NaN : x = NaN + = +0 : x = +0 + = 1 : x > 0 + ``` + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} def HLO_SinOp: HLO_UnaryElementwiseOp<"sine", - [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>, - BASE_HLO_SinOp; + [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> { + let summary = "Sin operator"; + let description = [{ + Returns `Sin(operand)` element-wise. + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} def HLO_SqrtOp: HLO_UnaryElementwiseOp<"sqrt", - [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>, - BASE_HLO_SqrtOp { + [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor> { + let summary = "Square-root operator"; + let description = [{ + Returns `sqrt(operand)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. 
+ }]; let hasFolder = 1; } def HLO_TanhOp: HLO_UnaryElementwiseOp<"tanh", [NoSideEffect, SameOperandsAndResultType], - HLO_FpOrComplexTensor>, BASE_HLO_TanhOp; + HLO_FpOrComplexTensor> { + let summary = "Tanh operator"; + let description = [{ + Returns `tanh(operand)` element-wise. + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} //===----------------------------------------------------------------------===// // MHLO binary elementwise op definitions. //===----------------------------------------------------------------------===// @@ -307,60 +472,151 @@ class HLO_BinaryElementwiseOp traits> : } def HLO_AddOp : HLO_BinaryElementwiseOp<"add", - [Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_AddOp { + [Commutative, NoSideEffect, SameOperandsAndResultType]> { + let summary = "Addition operator"; + let description = [{ + Returns `lhs + rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; let hasFolder = 1; } def HLO_Atan2Op : HLO_BinaryElementwiseOp<"atan2", - [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_Atan2Op; + [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Atan2 operator"; + let description = [{ + Returns `atan2(lhs/rhs)` element-wise. + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} def HLO_ComplexOp: HLO_BinaryElementwiseOp<"complex", - [NoSideEffect, DeclareOpInterfaceMethods]>, - BASE_HLO_ComplexOp { + [NoSideEffect, DeclareOpInterfaceMethods]> { + let summary = "Complex operator"; + let description = [{ + Performs element-wise conversion of a pair of real and imaginary values to + a complex value. 
+ }]; let arguments = (ins HLO_FpTensor:$lhs, HLO_FpTensor:$rhs); let results = (outs HLO_ComplexTensor); let hasFolder = 1; } def HLO_DivOp : HLO_BinaryElementwiseOp<"divide", - [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_DivOp { + [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Division operator"; + let description = [{ + Returns `lhs / rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; let hasFolder = 1; } def HLO_MaxOp : HLO_BinaryElementwiseOp<"maximum", - [Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_MaxOp { + [Commutative, NoSideEffect, SameOperandsAndResultType]> { + let summary = "Maximum operator"; + let description = [{ + Returns `max(lhs, rhs)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; let hasFolder = 1; } def HLO_MinOp : HLO_BinaryElementwiseOp<"minimum", - [Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_MinOp { + [Commutative, NoSideEffect, SameOperandsAndResultType]> { + let summary = "Minimum operator"; + let description = [{ + Returns `min(lhs, rhs)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; let hasFolder = 1; } def HLO_MulOp : HLO_BinaryElementwiseOp<"multiply", - [Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_MulOp { + [Commutative, NoSideEffect, SameOperandsAndResultType]> { + let summary = "Multiplication operator"; + let description = [{ + Returns `lhs * rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. 
+ }]; let hasFolder = 1; } def HLO_PowOp : HLO_BinaryElementwiseOp<"power", - [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_PowOp; + [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Power operator"; + let description = [{ + Returns `lhs ^ rhs` element-wise. + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder", - [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp { + [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Remainder operator"; + let description = [{ + Returns `lhs % rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; let hasFolder = 1; } def HLO_ShiftLeftOp : HLO_BinaryElementwiseOp<"shift_left", - [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftLeftOp; + [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Shift Left operator"; + let description = [{ + Returns `lhs << rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} def HLO_ShiftRightArithmeticOp : HLO_BinaryElementwiseOp<"shift_right_arithmetic", - [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftRightArithmeticOp; + [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Shift right arithmetic operator"; + let description = [{ + Returns arithmetic `lhs >> rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} def HLO_ShiftRightLogicalOp : HLO_BinaryElementwiseOp<"shift_right_logical", - [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftRightLogicalOp; + [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Shift right logical operator"; + let description = [{ + Returns logical `lhs >> rhs` element-wise. 
+ + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} def HLO_SubOp : HLO_BinaryElementwiseOp<"subtract", - [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_SubOp { + [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Subtraction operator"; + let description = [{ + Returns `lhs - rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; let hasFolder = 1; } @@ -380,9 +636,35 @@ class HLO_BinaryLogicalElementwiseOp : let hasFolder = 1; } -def HLO_AndOp: HLO_BinaryLogicalElementwiseOp<"and">, BASE_HLO_AndOp; -def HLO_OrOp: HLO_BinaryLogicalElementwiseOp<"or">, BASE_HLO_OrOp; -def HLO_XorOp : HLO_BinaryLogicalElementwiseOp<"xor">, BASE_HLO_XorOp; +def HLO_AndOp: HLO_BinaryLogicalElementwiseOp<"and"> { + let summary = "Logical and"; + let description = [{ + Returns `logical_and(lhs, rhs)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +def HLO_OrOp: HLO_BinaryLogicalElementwiseOp<"or"> { + let summary = "Logical or"; + let description = [{ + Returns `logical_or(lhs, rhs)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +def HLO_XorOp : HLO_BinaryLogicalElementwiseOp<"xor"> { + let summary = "Logical xor"; + let description = [{ + Returns `logical_xor(lhs, rhs)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} //===----------------------------------------------------------------------===// // MHLO communication op definitions. @@ -392,9 +674,9 @@ def HLO_XorOp : HLO_BinaryLogicalElementwiseOp<"xor">, BASE_HLO_XorOp; // InfeedWithToken allows ordering of infeed HLO instructions using tokens. 
def HLO_InfeedOp : HLO_Op<"infeed", []> { - string summary = "Infeed operator"; + let summary = "Infeed operator"; - string description = [{ + let description = [{ Reads a single data item from the implicit Infeed streaming interface of the device, interpreting the data as the given shape, and returns a XlaOp of the data. Multiple Infeed operations are allowed in a computation, but @@ -421,9 +703,9 @@ def HLO_InfeedOp : HLO_Op<"infeed", []> { // OutfeedWithToken allows ordering of outfeed HLO instructions using tokens. def HLO_OutfeedOp : HLO_Op<"outfeed", []> { - string summary = "Outfeed operator"; + let summary = "Outfeed operator"; - string description = [{ + let description = [{ Generates outgoing data transfers for the given data. It takes data and a token type operand and produces a token type value. Tokens are used for ordering side-effecting operations. @@ -442,9 +724,9 @@ def HLO_OutfeedOp : HLO_Op<"outfeed", []> { def HLO_SendOp : HLO_Op<"send", []> { - string summary = "Send operator"; + let summary = "Send operator"; - string description = [{ + let description = [{ Sends the given operand data to a Recv instruction in another computation that shares the same channel handle. Does not return any data. Similar to the Recv operation, Send operation represents synchronous communication, @@ -467,9 +749,9 @@ def HLO_SendOp : HLO_Op<"send", []> { def HLO_RecvOp : HLO_Op<"recv", []> { - string summary = "Recv operator"; + let summary = "Recv operator"; - string description = [{ + let description = [{ Receives data of the given shape from a Send instruction in another computation that shares the same channel handle. Returns a tuple containing value for the received data and a token. 
Recv operation represents @@ -495,8 +777,18 @@ def HLO_RecvOp : HLO_Op<"recv", []> { //===----------------------------------------------------------------------===// def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect, - DeclareOpInterfaceMethods]>, - BASE_HLO_ReplicaIdOp { + DeclareOpInterfaceMethods]> { + let summary = "ReplicaId operator"; + let description = [{ + Returns the unique ID (int32 scalar) of the replica. + + The unique ID of each replica is an unsigned integer in the interval [0, N), + where N is the number of replicas. Since all the replicas are running the + same program, a ReplicaId() call in the program will return a different + value on each replica. + + See https://www.tensorflow.org/xla/operation_semantics#replicaid. + }]; let results = (outs TensorOf<[UI32]>); } @@ -506,9 +798,9 @@ def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect, def HLO_AfterAllOp : HLO_Op<"after_all", [NoSideEffect]> { - string summary = "AfterAll operator"; + let summary = "AfterAll operator"; - string description = [{ + let description = [{ AfterAll takes a variadic number of tokens and produces a single token. Tokens are primitive types which can be threaded between side-effecting operations to enforce ordering. AfterAll can be used as a join of tokens @@ -527,9 +819,9 @@ def HLO_AfterAllOp : HLO_Op<"after_all", [NoSideEffect]> { def HLO_IfOp: HLO_Op<"if", [ RecursiveSideEffects, SingleBlockImplicitTerminator<"ReturnOp">]> { - string summary = "If operator"; + let summary = "If operator"; - string description = [{ + let description = [{ Returns the result of executing either a true or false function depending on the result of a condition function. @@ -559,7 +851,21 @@ def HLO_IfOp: HLO_Op<"if", [ def HLO_CaseOp: HLO_Op<"case", [ RecursiveSideEffects, SingleBlockImplicitTerminator<"ReturnOp"> - ]>, BASE_HLO_CaseOp { + ]> { + let summary = "Switch-Case operator"; + let description = [{ + Returns the result of executing `branches[index]`. 
If + `index` is < 0 or >= N, then `branches[N-1] is executed as + the default branch. + + Each branch `branches[b]` must take in a single argument of same type as + `branch_operands[b]` and will be invoked with `branch_operands[b]`. The type + of the returned value of each branch must be the same. + + Note that only one of the branches will be executed depending on the value + of index. + See https://www.tensorflow.org/xla/operation_semantics#conditional. + }]; let arguments = (ins I32Tensor:$index, @@ -580,7 +886,14 @@ def HLO_WhileOp: HLO_Op<"while", [ RecursiveSideEffects, SameOperandsAndResultType, SingleBlockImplicitTerminator<"ReturnOp"> - ]>, BASE_HLO_WhileOp { + ]> { + let summary = "While operator"; + let description = [{ + Returns the result of executing a body function until the cond body returns + true. + + See https://www.tensorflow.org/xla/operation_semantics#while. + }]; let arguments = (ins HLO_TensorOrTuple:$val); let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body); @@ -592,7 +905,13 @@ def HLO_WhileOp: HLO_Op<"while", [ } def HLO_AllReduceOp : HLO_Op<"all_reduce", - [SameOperandsAndResultType]>, BASE_HLO_AllReduceOp { + [SameOperandsAndResultType]> { + let summary = "AllReduce operator"; + let description = [{ + Performs a custom reduction across replicas. + + See https://www.tensorflow.org/xla/operation_semantics#allreduce. 
+ }]; let arguments = (ins HLO_Tensor:$operand, @@ -606,7 +925,7 @@ def HLO_AllReduceOp : HLO_Op<"all_reduce", } def HLO_AllToAllOp : HLO_Op<"all_to_all", - [NoSideEffect, SameOperandsElementType, SameOperandsShape]>, BASE_HLO_AllToAllOp { + [NoSideEffect, SameOperandsElementType, SameOperandsShape]> { let arguments = (ins HLO_Tensor:$operand, @@ -623,7 +942,14 @@ def HLO_ReduceOp: HLO_Op<"reduce", [ SameVariadicOperandSize, SingleBlockImplicitTerminator<"ReturnOp">, InferFusibilityOpInterface - ]>, BASE_HLO_ReduceOp { + ]> { + let summary = "Reduce operator"; + let description = [{ + Returns the result of executing a reduction function on one or more arrays + in parallel. + + See https://www.tensorflow.org/xla/operation_semantics#reduce. + }]; let arguments = (ins Variadic:$inputs, Variadic:$init_values, @@ -661,7 +987,13 @@ def HLO_ReduceOp: HLO_Op<"reduce", [ //===----------------------------------------------------------------------===// // MHLO tuple op definitions. //===----------------------------------------------------------------------===// -def HLO_GetTupleElementOp: HLO_Op<"get_tuple_element", [NoSideEffect]>, BASE_HLO_GetTupleElementOp { +def HLO_GetTupleElementOp: HLO_Op<"get_tuple_element", [NoSideEffect]> { + let summary = "GetTupleElement operator"; + let description = [{ + Returns a member of a tuple specified by an index. + + See https://www.tensorflow.org/xla/operation_semantics#gettupleelement. + }]; let arguments = (ins HLO_Tuple, I32Attr:$index @@ -675,7 +1007,13 @@ def HLO_GetTupleElementOp: HLO_Op<"get_tuple_element", [NoSideEffect]>, BASE_HLO OpBuilder<(ins "Value":$value, "int32_t":$index)>]; } -def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp { +def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]> { + let summary = "XLA's tuple op"; + let description = [{ + Groups a set of tensor inputs into a single tuple object. + + See https://www.tensorflow.org/xla/operation_semantics#tuple. 
+ }]; let arguments = (ins Variadic:$val); let results = (outs HLO_Tuple); @@ -688,8 +1026,17 @@ def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp { def HLO_CompareOp: HLO_Op<"compare", [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape, DeclareOpInterfaceMethods]>, - BASE_HLO_CompareOp { + ["inferReturnTypeComponents", "reifyReturnTypeShapes"]>]> { + let summary = "Comparison operator"; + let description = [{ + Compares `lhs` and `rhs` elementwise according to `comparison_direction` + and `compare_type`. If unspecified, `compare_type` is FLOAT for float element + types, SIGNED for signed element types and UNSIGNED for unsigned element + types. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations. + }]; let arguments = (ins HLO_Tensor:$lhs, HLO_Tensor:$rhs, @@ -731,7 +1078,13 @@ def HLO_SliceOp: HLO_Op< } def HLO_DynamicSliceOp: HLO_Op<"dynamic-slice", - [NoSideEffect, AllElementTypesMatch<["operand", "result"]>]>, BASE_HLO_DynamicSliceOp { + [NoSideEffect, AllElementTypesMatch<["operand", "result"]>]> { + let summary = "Dynamic Slice operator"; + let description = [{ + Extracts a sub-array from the input array at dynamic start_indices. + + See https://www.tensorflow.org/xla/operation_semantics#dynamicslice. + }]; let arguments = (ins HLO_Tensor:$operand, Variadic:$start_indices, @@ -744,7 +1097,14 @@ def HLO_DynamicSliceOp: HLO_Op<"dynamic-slice", def HLO_DynamicUpdateSliceOp: HLO_Op<"dynamic-update-slice", [NoSideEffect, AllElementTypesMatch<["operand", "update", "result"]>, - AllShapesMatch<["operand", "result"]>]>, BASE_HLO_DynamicUpdateSliceOp { + AllShapesMatch<["operand", "result"]>]> { + let summary = "Dynamic Update Slice operator"; + let description = [{ + DynamicUpdateSlice generates a result which is the value of the input array + operand, with a slice update overwritten at start_indices. + + See https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice. 
+ }]; let arguments = (ins HLO_Tensor:$operand, HLO_Tensor:$update, @@ -759,8 +1119,13 @@ def HLO_DynamicUpdateSliceOp: HLO_Op<"dynamic-update-slice", // MHLO Other op definitions. //===----------------------------------------------------------------------===// -def HLO_BatchNormGradOp : HLO_Op<"batch_norm_grad", [NoSideEffect]>, - BASE_HLO_BatchNormGradOp { +def HLO_BatchNormGradOp : HLO_Op<"batch_norm_grad", [NoSideEffect]> { + let summary = "Batch Normalization Gradient"; + let description = [{ + Calculates gradients of batch norm. + + See https://www.tensorflow.org/xla/operation_semantics#batchnormgrad + }]; let arguments = (ins HLO_Tensor:$operand, @@ -776,7 +1141,13 @@ def HLO_BatchNormGradOp : HLO_Op<"batch_norm_grad", [NoSideEffect]>, } def HLO_BatchNormInferenceOp : HLO_Op<"batch_norm_inference", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_BatchNormInferenceOp { + [NoSideEffect, SameOperandsAndResultElementType]> { + let summary = "Batch Normalization for Inference"; + let description = [{ + Normalizes an array across batch and spatial dimensions. + + See https://www.tensorflow.org/xla/operation_semantics#batchnorminference + }]; let arguments = (ins HLO_Tensor:$operand, @@ -791,8 +1162,13 @@ def HLO_BatchNormInferenceOp : HLO_Op<"batch_norm_inference", let results = (outs HLO_Tensor); } -def HLO_BatchNormTrainingOp : HLO_Op<"batch_norm_training", [NoSideEffect]>, - BASE_HLO_BatchNormTrainingOp { +def HLO_BatchNormTrainingOp : HLO_Op<"batch_norm_training", [NoSideEffect]> { + let summary = "Batch Normalization for Training"; + let description = [{ + Normalizes an array across batch and spatial dimensions. 
+ + See https://www.tensorflow.org/xla/operation_semantics#batchnormtraining + }]; let arguments = (ins HLO_Tensor:$operand, @@ -806,7 +1182,17 @@ def HLO_BatchNormTrainingOp : HLO_Op<"batch_norm_training", [NoSideEffect]>, } def HLO_BitcastConvertOp : HLO_Op<"bitcast_convert", - [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_BitcastConvertOp { + [NoSideEffect, SameOperandsAndResultShape]> { + let summary = "BitcastConvert operator"; + let description = [{ + Similar to a 'tf.bitcast' in TensorFlow, performs an element-wise bitcast + operation from a data shape to a target shape. The dimensions must match, + and the conversion is an element-wise one. Bitcast is implemented as a + low-level cast, so machines with different floating-point representations + will give different results. + + See https://www.tensorflow.org/xla/operation_semantics#bitcastconverttype. + }]; let arguments = (ins HLO_Tensor:$operand); let results = (outs HLO_Tensor); @@ -814,7 +1200,19 @@ def HLO_BitcastConvertOp : HLO_Op<"bitcast_convert", } def HLO_BroadcastOp : HLO_Op<"broadcast", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_BroadcastOp { + [NoSideEffect, SameOperandsAndResultElementType]> { + let summary = "Broadcast a tensor to a higher rank by prepending dimensions"; + let description = [{ + Broadcasts the operand tensor to a higher rank by prepending + `broadcast_sizes` to the dimensions. The current values of the operand are + copied into the other dimensions. + + This is a more limited form of broadcasting, that corresponds to the XLA + client Broadcast method. For a more general form of broadcasting, see the + BroadcastInDimOp. + + See https://www.tensorflow.org/xla/operation_semantics#broadcast. 
+ }]; let arguments = (ins HLO_Tensor:$operand, I64ElementsAttr:$broadcast_sizes @@ -824,7 +1222,24 @@ def HLO_BroadcastOp : HLO_Op<"broadcast", } def HLO_BroadcastInDimOp : HLO_Op<"broadcast_in_dim", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_BroadcastInDimOp { + [NoSideEffect, SameOperandsAndResultElementType]> { + let summary = "Broadcast a tensor into the given shape by adding dimensions."; + let description = [{ + Broadcasts the `operand` tensor to a higher rank. This is not the limited + form of broadcasting exposed as the XLA client broadcast op, but rather the + more powerful "InDim" broadcasting, which is closer to the HLO broadcast op + and exposed in the XLA client BroadcastInDim method. + + `broadcast_dimensions` maps the operand dimension number to the target shape + dimension number. It must have the same size as the rank of the operand. The + mapped dimensions must either be the same size or the dimension being + broadcast from must be size 1 (degenerate broadcasting). + + For a scalar (0D tensor) operand, `broadcast_dimensions` must be empty. The + The scalar value will be broadcast to every element in the target shape. + + See https://www.tensorflow.org/xla/broadcasting. + }]; let arguments = (ins HLO_Tensor:$operand, BroadcastDimAttr:$broadcast_dimensions @@ -840,8 +1255,8 @@ def HLO_BroadcastInDimOp : HLO_Op<"broadcast_in_dim", def HLO_DynamicBroadcastInDimOp : HLO_Op<"dynamic_broadcast_in_dim", [ NoSideEffect, DeclareOpInterfaceMethods]> { - string summary = "Broadcast a tensor into the given dynamic shape by adding dimensions."; - string description = [{ + let summary = "Broadcast a tensor into the given dynamic shape by adding dimensions."; + let description = [{ This is a generalization of the BroadcastInDimOp which accepts its output dimensions as an argument. 
It should eventually supersede the statically shaped original, but is being phased as a separate op in order to support @@ -866,7 +1281,29 @@ def HLO_DynamicBroadcastInDimOp : HLO_Op<"dynamic_broadcast_in_dim", [ // directly. def HLO_CholeskyOp : HLO_Op<"cholesky", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_CholeskyOp { + [NoSideEffect, SameOperandsAndResultElementType]> { + let summary = "Cholesky operator"; + let description = [{ + Computes the Cholesky decomposition of a batch of symmetric (Hermitian) + positive definite matrices. + + If lower is true, computes lower-triangular matrices l such that + `a=l.Transpose(l)`. If lower is false, computes upper-triangular matrices u such + that `a=Transpose(u).u`. + + Input data is read only from the lower/upper triangle of a, depending on the + value of lower. Values from the other triangle are ignored. Output data is + returned in the same triangle; the values in the other triangle are + implementation-defined and may be anything. + + If the rank of a is greater than 2, a is treated as a batch of matrices, where + all except the minor 2 dimensions are batch dimensions. + + If a is not symmetric (Hermitian) positive definite, the result is + implementation-defined. + + See https://www.tensorflow.org/xla/operation_semantics#cholesky. + }]; let arguments = (ins HLO_FpOrComplexTensor:$a, DefaultValuedAttr:$lower ); @@ -876,7 +1313,17 @@ def HLO_CholeskyOp : HLO_Op<"cholesky", } def HLO_ClampOp : HLO_Op<"clamp", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ClampOp { + [NoSideEffect, SameOperandsAndResultElementType]> { + let summary = "Clamp operator"; + let description = [{ + Clamps an operand to within the range between a minimum and maximum value. + + Note: All three arrays must be the same shape. Alternatively, as a + restricted form of broadcasting, min and/or max can be a scalar (0D + tensor) of the element type of the tensor operand.
+ + See https://www.tensorflow.org/xla/operation_semantics#clamp. + }]; let arguments = (ins HLO_Tensor:$min, HLO_Tensor:$operand, @@ -888,7 +1335,13 @@ def HLO_ClampOp : HLO_Op<"clamp", def HLO_ConcatenateOp : HLO_Op<"concatenate", [NoSideEffect, SameOperandsAndResultElementType, - DeclareOpInterfaceMethods]>, BASE_HLO_ConcatenateOp { + DeclareOpInterfaceMethods]> { + let summary = "XLA's concatenate op"; + let description = [{ + Concatenates a set of tensors along the specified dimension. + + See https://www.tensorflow.org/xla/operation_semantics#concatenate. + }]; let arguments = (ins Variadic:$val, @@ -908,7 +1361,20 @@ def HLO_ConcatenateOp : HLO_Op<"concatenate", } def HLO_CollectivePermuteOp: HLO_Op<"collective_permute", - [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_CollectivePermuteOp { + [NoSideEffect, SameOperandsAndResultType]> { + let summary = "CollectivePermute operator"; + let description = [{ + CollectivePermute is a collective operation that sends and receives data + cross replicas. + Note that there are the following restrictions on the source_target_pair: + - Any two pairs should not have the same target replica id, and they should + not have the same source replica id. + - If a replica id is not a target in any pair, then the output on that + replica is a tensor consists of 0(s) with the same shape as the input. + + See https://www.tensorflow.org/xla/operation_semantics#collectivepermute. + + }]; let arguments = (ins HLO_Tensor:$operand, @@ -928,15 +1394,30 @@ def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]>, BASE_HLO_ConvOp { let hasCustomHLOConverter = 1; } -def HLO_CopyOp: HLO_Op<"copy", [NoSideEffect, SameOperandsAndResultType]>, - BASE_HLO_CopyOp { +def HLO_CopyOp: HLO_Op<"copy", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Copy operator"; + let description = [{ + Returns a copy of `operand`. 
+ }]; let arguments = (ins HLO_Tensor); let results = (outs HLO_Tensor); let hasFolder = 1; } def HLO_CrossReplicaSumOp : HLO_Op<"cross-replica-sum", - [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_CrossReplicaSumOp { + [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Sums input across replicated instances."; + let description = [{ + For each of the replica groups, operands of the group devices are summed + so that each device has the sum. + + For example, suppose there are 8 TPU devices: `[A, B, C, D, E, F, G, H]`. + Passing group_assignment=`[[0,2,4,6],[1,3,5,7]]` sets `A, C, E, G` as group 0, + and `B, D, F, H` as group 1. Thus we get the outputs: + `[A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H]`. + + See https://www.tensorflow.org/xla/operation_semantics#crossreplicasum. + }]; let arguments = (ins HLO_Tensor:$operand, @@ -946,7 +1427,21 @@ def HLO_CrossReplicaSumOp : HLO_Op<"cross-replica-sum", let results = (outs HLO_Tensor); } -def HLO_CustomCallOp: HLO_Op<"custom_call", []>, BASE_HLO_CustomCallOp { +def HLO_CustomCallOp: HLO_Op<"custom_call", []> { + let summary = "CustomCall operator"; + let description = [{ + A custom call invokes code external to XLA. The `args` are passed to the + external code, and the external code is expected to produce a result of the + given type. The exact mechanism is backend-specific. For example, in the CPU + backend, a call instruction is emitted which targets a symbol with the name + `call_target_name`. + + `call_target_name` and `backend_config` can be arbitrary strings, but + `call_target_name` should be short as it may be used in labels. + `backend_config` can encode arbitrarily large amounts of information. + + See https://www.tensorflow.org/xla/operation_semantics#customcall. 
+ }]; let arguments = (ins Variadic:$args, StrAttr:$call_target_name, @@ -957,7 +1452,14 @@ def HLO_CustomCallOp: HLO_Op<"custom_call", []>, BASE_HLO_CustomCallOp { let hasCustomHLOConverter = 1; } -def HLO_DotOp: HLO_Op<"dot", [NoSideEffect]>, BASE_HLO_DotOp { +def HLO_DotOp: HLO_Op<"dot", [NoSideEffect]> { + let summary = "Dot operator"; + let description = [{ + Performs dot products between vectors, vector/matrix and matrix/matrix + multiplication. + + See https://www.tensorflow.org/xla/operation_semantics#dot. + }]; let arguments = ( ins HLO_Tensor:$lhs, HLO_Tensor:$rhs, @@ -966,8 +1468,14 @@ def HLO_DotOp: HLO_Op<"dot", [NoSideEffect]>, BASE_HLO_DotOp { let results = (outs HLO_Tensor); } -def HLO_DotGeneralOp: HLO_Op<"dot_general", [NoSideEffect]>, - BASE_HLO_DotGeneralOp { +def HLO_DotGeneralOp: HLO_Op<"dot_general", [NoSideEffect]> { + let summary = "General Dot operator"; + let description = [{ + Performs general dot products between vectors, vector/matrix and + matrix/matrix multiplication. + + See https://www.tensorflow.org/xla/operation_semantics#dotgeneral. + }]; let arguments = (ins HLO_Tensor:$lhs, HLO_Tensor:$rhs, @@ -1021,7 +1529,14 @@ def HLO_UnaryEinsumOp: HLO_Op<"unary_einsum", [NoSideEffect]>, BASE_EinsumOp { let hasCustomHLOConverter = 1; } -def HLO_FftOp: HLO_Op<"fft", [NoSideEffect]>, BASE_HLO_FftOp { +def HLO_FftOp: HLO_Op<"fft", [NoSideEffect]> { + let summary = "Fast fourier transform operator"; + let description = [{ + Returns the fast-fourier-transform of the input array. + + See + https://www.tensorflow.org/xla/operation_semantics#fft. 
+ }]; let arguments = (ins HLO_Tensor:$operand, HLO_FftTypeAttr: $fft_type, @@ -1031,7 +1546,7 @@ def HLO_FftOp: HLO_Op<"fft", [NoSideEffect]>, BASE_HLO_FftOp { let results = (outs HLO_Tensor); } -def HLO_GatherOp: HLO_Op<"gather", [NoSideEffect]>, BASE_HLO_GatherOp { +def HLO_GatherOp: HLO_Op<"gather", [NoSideEffect]> { let arguments = (ins HLO_Tensor:$operand, HLO_IntTensor:$start_indices, @@ -1045,8 +1560,14 @@ def HLO_GatherOp: HLO_Op<"gather", [NoSideEffect]>, BASE_HLO_GatherOp { let hasCanonicalizer = 1; } -def HLO_GetDimensionSizeOp: HLO_Op<"get_dimension_size", [NoSideEffect]>, - BASE_HLO_GetDimensionSizeOp { +def HLO_GetDimensionSizeOp: HLO_Op<"get_dimension_size", [NoSideEffect]> { + let summary = "GetDimensionSize operator"; + let description = [{ + Returns the size of the given dimension of the operand. + + See + https://www.tensorflow.org/xla/operation_semantics#getdimensionsize. + }]; let arguments = (ins HLO_Tensor:$operand, I64Attr:$dimension @@ -1061,8 +1582,20 @@ def HLO_GetDimensionSizeOp: HLO_Op<"get_dimension_size", [NoSideEffect]>, def HLO_MapOp: HLO_Op<"map", [RecursiveSideEffects, SameOperandsElementType, - SameOperandsAndResultShape, SingleBlockImplicitTerminator<"ReturnOp">]>, - BASE_HLO_MapOp { + SameOperandsAndResultShape, SingleBlockImplicitTerminator<"ReturnOp">]> { + let summary = "Map operator"; + let description = [{ + Applies a scalar function over the given operands arrays, producing an array + of the same dimensions where each element is the result of the mapped function + applied to the corresponding elements in the input arrays. + + The mapped function is an arbitrary computation with the restriction that it + has N inputs of scalar type T and a single output with type S. The output has + the same dimensions as the operands except that the element type T is replaced + with S. + + See https://www.tensorflow.org/xla/operation_semantics#map. 
+ }]; let arguments = (ins Variadic:$operands, I64ElementsAttr:$dimensions @@ -1074,7 +1607,13 @@ def HLO_MapOp: HLO_Op<"map", } def HLO_ReshapeOp: HLO_Op<"reshape", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ReshapeOp { + [NoSideEffect, SameOperandsAndResultElementType]> { + let summary = "Reshape operator"; + let description = [{ + Reshapes the dimensions of `operand` into a new configuration. + + See https://www.tensorflow.org/xla/operation_semantics#reshape. + }]; let arguments = (ins HLO_Tensor:$operand); let results = (outs HLO_StaticShapeTensor); @@ -1104,8 +1643,15 @@ def HLO_DynamicReshapeOp: HLO_Op<"dynamic_reshape", [NoSideEffect]> { let hasCustomHLOConverter = 1; } -def HLO_ScatterOp: HLO_Op<"scatter", [RecursiveSideEffects]>, - BASE_HLO_ScatterOp { +def HLO_ScatterOp: HLO_Op<"scatter", [RecursiveSideEffects]> { + let summary = "Scatter operator"; + let description = [{ + Generates a result which is the value of the input array `operand`, + with several slices (at indices specified by `scatter_indices`) + updated with the values in `updates` using `update_computation`. + + See https://www.tensorflow.org/xla/operation_semantics#scatter. + }]; let arguments = (ins HLO_Tensor:$operand, HLO_Tensor:$scatter_indices, @@ -1129,7 +1675,14 @@ def HLO_SelectOp: HLO_Op<"select", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, - ]>, BASE_HLO_SelectOp { + ]> { + let summary = "Select operator"; + let description = [{ + Constructs an output tensor from the elements of `on_true` and `on_false` + based on the values of `pred`. + + `pred`, `on_true` and `on_false` must be broadcast compatible. 
+ }]; let arguments = (ins HLO_PredTensor:$pred, HLO_Tensor:$on_true, @@ -1142,7 +1695,18 @@ def HLO_SelectOp: HLO_Op<"select", [NoSideEffect, } def HLO_SelectAndScatterOp: HLO_Op<"select_and_scatter", - [RecursiveSideEffects]>, BASE_HLO_SelectAndScatterOp { + [RecursiveSideEffects]> { + let summary = "SelectAndScatter operator"; + let description = [{ + Runs a windowed selection `select` function over `operand` with shape + `window_dimensions` and stride `window_strides`. This will produce an amount + of selected locations whose shape matches `source`. These are then scattered + to the output which is initialized with `init_value`. + Multiple scattered elements which land in the same output location are + combined using the `scatter` function. + + See https://www.tensorflow.org/xla/operation_semantics#selectandscatter. + }]; let arguments = (ins HLO_Tensor:$operand, HLO_Tensor:$source, @@ -1159,8 +1723,15 @@ def HLO_SelectAndScatterOp: HLO_Op<"select_and_scatter", let hasCustomHLOConverter = 1; } -def HLO_SetDimensionSizeOp: HLO_Op<"set_dimension_size", [NoSideEffect]>, - BASE_HLO_SetDimensionSizeOp { +def HLO_SetDimensionSizeOp: HLO_Op<"set_dimension_size", [NoSideEffect]> { + let summary = "SetDimensionSize operator"; + let description = [{ + Sets the dynamic size of operand's given dimension. Pass through the operand + as result, with dynamic dimension tracked by the compiler. Padded values + will be ignored by downstream reduction ops. + + See https://www.tensorflow.org/xla/operation_semantics#setdimensionsize. + }]; let arguments = (ins HLO_Tensor:$operand, I32Tensor:$size, @@ -1172,7 +1743,14 @@ def HLO_SetDimensionSizeOp: HLO_Op<"set_dimension_size", [NoSideEffect]>, } def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects, - SameOperandsAndResultShape]>, BASE_HLO_SortOp { + SameOperandsAndResultShape]> { + let summary = "Sort operator"; + let description = [{ + Sorts the given `operands` at the given `dimension` with the given + `comparator`. 
+ + See https://www.tensorflow.org/xla/operation_semantics#sort. + }]; let arguments = (ins Variadic:$operands, DefaultValuedAttr:$dimension, @@ -1192,7 +1770,14 @@ def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects, } def HLO_ReverseOp: HLO_Op<"reverse", - [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ReverseOp { + [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Reverse operator"; + let description = [{ + Reverses the specified dimensions of `operand` according to the given + `dimensions`. + + See https://www.tensorflow.org/xla/operation_semantics#rev_reverse. + }]; let arguments = (ins HLO_Tensor:$operand, I64ElementsAttr:$dimensions @@ -1204,7 +1789,14 @@ def HLO_ReverseOp: HLO_Op<"reverse", } def HLO_PadOp: HLO_Op<"pad", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_PadOp { + [NoSideEffect, SameOperandsAndResultElementType]> { + let summary = "Pad operator"; + let description = [{ + Pads the edges of `operand` with the `padding_value` and according to + the passed configuration. + + See https://www.tensorflow.org/xla/operation_semantics#pad. + }]; let arguments = (ins HLO_Tensor:$operand, HLO_Tensor:$padding_value, @@ -1225,7 +1817,11 @@ def HLO_PadOp: HLO_Op<"pad", let hasFolder = 1; } -def HLO_TraceOp: HLO_Op<"trace", []>, BASE_HLO_TraceOp { +def HLO_TraceOp: HLO_Op<"trace", []> { + let summary = "Trace operator"; + let description = [{ + Emits a logging message `tag` with the `operand`. + }]; let arguments = (ins HLO_Tensor:$operand, StrAttr:$tag @@ -1234,7 +1830,15 @@ def HLO_TraceOp: HLO_Op<"trace", []>, BASE_HLO_TraceOp { } def HLO_TransposeOp: HLO_Op<"transpose", - [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_TransposeOp { + [NoSideEffect, SameOperandsAndResultElementType]> { + let summary = "Transpose operator"; + let description = [{ + Permutes the dimensions of `operand` according to the given `permutation`. 
+ + `res_dimensions[i] = operand_dimensions[permutation[i]]` + + See https://www.tensorflow.org/xla/operation_semantics#transpose. + }]; let arguments = (ins HLO_Tensor:$operand, I64ElementsAttr:$permutation @@ -1245,8 +1849,27 @@ def HLO_TransposeOp: HLO_Op<"transpose", } def HLO_TriangularSolveOp: HLO_Op<"triangular_solve", - [NoSideEffect, SameOperandsAndResultElementType]>, - BASE_HLO_TriangularSolveOp { + [NoSideEffect, SameOperandsAndResultElementType]> { + let summary = "TriangularSolve operator"; + let description = [{ + Solves systems of linear equations with lower or upper triangular + coefficient matrices by forward- or back-substitution. Broadcasting along + leading dimensions, this routine solves one of the matrix systems + op(a) * x = b, or x * op(a) = b, for the variable x, given a and b, where + op(a) is either op(a) = a, or op(a) = Transpose(a), or + op(a) = Conj(Transpose(a)). + + Input data is read only from the lower/upper triangle of a, depending on the + value of lower. Values from the other triangle are ignored. Output data is + returned in the same triangle; the values in the other triangle are + implementation-defined and may be anything. + + If the rank of a and b are greater than 2, they are treated as batches of + matrices, where all except the minor 2 dimensions are batch dimensions. a + and b must have equal batch dimensions. + + See https://www.tensorflow.org/xla/operation_semantics#triangularsolve. + }]; let arguments = (ins HLO_FpOrComplexTensor:$a, HLO_FpOrComplexTensor:$b, @@ -1262,7 +1885,14 @@ def HLO_ReduceWindowOp: HLO_Op<"reduce_window", [ RecursiveSideEffects, SameVariadicOperandSize, SingleBlockImplicitTerminator<"ReturnOp"> - ]>, BASE_HLO_ReduceWindowOp { + ]> { + let summary = "ReduceWindow operator"; + let description = [{ + Returns the result of executing a reduction function over all elements in + each window of one or more arrays in parallel. + + See https://www.tensorflow.org/xla/operation_semantics#reducewindow. 
+ }]; // TODO(hinsu): Verify that padding attribute is 2-d and the remaining // attributes are 1-d. Attributes' leading dimension should match rank of the @@ -1347,8 +1977,16 @@ def HLO_TorchIndexSelectOp : HLO_Op<"torch_index_select", [NoSideEffect]> { //===----------------------------------------------------------------------===// def HLO_RngUniformOp : HLO_Op<"rng_uniform", - InferTensorType<["inferReturnTypeComponents"]>.traits>, - BASE_HLO_RngUniformOp { + InferTensorType<["inferReturnTypeComponents"]>.traits> { + let summary = "RNG with uniform distribution."; + let description = [{ + Constructs an output of a given shape with random numbers generated + following the uniform distribution over the interval `[a,b)`. The parameters + and output element type have to be a boolean type, an integral type or a + floating point types, and the types have to be consistent. + + See https://www.tensorflow.org/xla/operation_semantics#rnguniform. + }]; let arguments = (ins HLO_PredIntOrFpTensor:$a, HLO_PredIntOrFpTensor:$b, @@ -1368,8 +2006,16 @@ def HLO_RngUniformOp : HLO_Op<"rng_uniform", } def HLO_RngNormalOp : HLO_Op<"rng_normal", - InferTensorType<["inferReturnTypeComponents"]>.traits>, - BASE_HLO_RngNormalOp { + InferTensorType<["inferReturnTypeComponents"]>.traits> { + let summary = "RNG with normal distribution."; + let description = [{ + Constructs an output of a given shape with random numbers generated + following the normal distribution with parameters `mu` and `sigma`. The + parameters and output shape have to have a floating point elemental type. + The parameters furthermore have to be scalar valued. + + See https://www.tensorflow.org/xla/operation_semantics#rngnormal. 
+ }]; let arguments = (ins HLO_FpTensor:$mu, HLO_FpTensor:$sigma, HLO_DimensionTensor:$shape ); let results = (outs HLO_FpTensor:$result); @@ -1388,8 +2034,16 @@ def HLO_RngNormalOp : HLO_Op<"rng_normal", }]; } -def HLO_RngBitGeneratorOp : HLO_Op<"rng_bit_generator", [NoSideEffect]>, - BASE_HLO_RngBitGeneratorOp { +def HLO_RngBitGeneratorOp : HLO_Op<"rng_bit_generator", [NoSideEffect]> { + let summary = "Uniform random number generator operator"; + let description = [{ + Returns an output with a given shape filled with uniform random bits using + the specified algorithm (or backend default) and returns an updated state + (with the same shape as initial state) and the generated random data. + + See + https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator. + }]; let arguments = (ins // TODO(jpienaar): This could be an enum instead. I32Attr:$rng_algorithm, @@ -1405,8 +2059,19 @@ def HLO_RngBitGeneratorOp : HLO_Op<"rng_bit_generator", [NoSideEffect]>, //===----------------------------------------------------------------------===// // MHLO Quantize Operator. //===----------------------------------------------------------------------===// -def HLO_DequantizeOp : HLO_Op<"dequantize", [NoSideEffect]>, - BASE_HLO_DequantizeOp { +def HLO_DequantizeOp : HLO_Op<"dequantize", [NoSideEffect]> { + let summary = "Dequantize operator"; + let description = [{ + Dequantize the quantized input of packed uint32 to bfloat16. Only uint8 or + uint16 is supported for the original unpacked input. + + Returns a tensor of shape [d0,..., dn * unpack_size] if unpacked input shape + is [d0, ..., dn], where unpack_size = sizeof(uint32) / sizeof(T), where T is + the unpacked input type. If transpose_output is true, will return a tensor + of shape [dn * unpack_size, dn-1, ..., d1, d0]. transpose_output is faster + when input's rank is higher than 1. The input needs to be transposed to use + transpose_output feature.
+ }]; let arguments = (ins TensorOf<[I32]>:$input, F32Attr:$min_range, @@ -1446,15 +2111,34 @@ def HLO_FusionOp : HLO_Op<"fusion", []> { } // This is an op for purposes internal to XLA/GPU. -def HLO_BitcastOp : HLO_Op<"bitcast", [NoSideEffect]>, BASE_HLO_BitcastOp { +def HLO_BitcastOp : HLO_Op<"bitcast", [NoSideEffect]> { + let summary = "Bitcast operator"; + let description = [{ + This op changes the shape of the input in the way that the physical + arrangement of elements are unchanged. + + However, the op needs layout information to make sense of "physical + arrangement of elements". Layout support in MHLO is currently under + exploration. + }]; let arguments = (ins HLO_Tensor:$operand); let results = (outs HLO_Tensor); let hasCustomHLOConverter = 1; } def HLO_ReducePrecisionOp : - HLO_Op<"reduce_precision", [SameOperandsAndResultShape]>, - BASE_HLO_ReducePrecisionOp { + HLO_Op<"reduce_precision", [SameOperandsAndResultShape]> { + let summary = "Reduce precision operator"; + let description = [{ + Models the effect of converting floating - point values to a lower - + precision format(such as IEEE - FP16) and back to the original + format. The number of exponent and mantissa bits in the lower - + precision format can be specified arbitrarily, + although all bit sizes may not be supported on all hardware + implementations. + + See https://www.tensorflow.org/xla/operation_semantics#reduceprecision. + }]; let arguments = (ins HLO_FpTensor:$operand, I32Attr:$exponent_bits, diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td index 896fe0f..096161f 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td @@ -110,873 +110,6 @@ def HLO_LayoutAttr : Attr< let convertFromStorage = IndexElementsAttr.convertFromStorage; } -//===----------------------------------------------------------------------===// -// MHLO nullary op definitions. 
-//===----------------------------------------------------------------------===// - -class BASE_HLO_ConstOp { - string summary = "Constant operator"; - - string description = [{ - Represents a constant value. - }]; -} - -class BASE_HLO_IotaOp { - string summary = "Iota operator"; - - string description = [{ - Creates a rank 1 array of values starting at zero and incrementing by one. - }]; -} - -//===----------------------------------------------------------------------===// -// MHLO unary elementwise op definitions. -//===----------------------------------------------------------------------===// -// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions - -class BASE_HLO_AbsOp { - string summary = "Absolute value operator"; - - string description = [{ - Returns `abs(operand)` element-wise. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. - }]; -} - -class BASE_HLO_CbrtOp { - string summary = "Cubic root operator"; - - string description = [{ - Returns element-wise cubic root of the operand. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. - }]; -} - -class BASE_HLO_CeilOp { - string summary = "Ceil operator"; - - string description = [{ - Returns `Ceil(operand)` element-wise. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. - }]; -} - -class BASE_HLO_ClzOp { - string summary = "Count-leading-zeros (Clz) operator"; - - string description = [{ - Returns the number of leading zeros in each operand element-wise. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. - }]; -} - -class BASE_HLO_ConvertOp { - string summary = "Convert operator"; - - string description = [{ - Performs element-wise conversion of values from one type to another, e.g. - float to int. - - See https://www.tensorflow.org/xla/operation_semantics#convertelementtype. 
- }]; -} - -class BASE_HLO_CosOp { - string summary = "Cos operator"; - - string description = [{ - Returns `Cos(operand)` element-wise. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. - }]; -} - -class BASE_HLO_ExpOp { - string summary = "Exponential operator"; - - string description = [{ - Returns `e^(operand)` element-wise. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. - }]; -} - -class BASE_HLO_Expm1Op { - string summary = "Exponential minus one operator"; - - string description = [{ - Returns `e^(operand) - 1` element-wise. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. - }]; -} - -class BASE_HLO_FloorOp { - string summary = "Floor operator"; - - string description = [{ - Returns `Floor(operand)` element-wise. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. - }]; -} - -class BASE_HLO_GetDimensionSizeOp { - string summary = "GetDimensionSize operator"; - - string description = [{ - Returns the size of the given dimension of the operand. - - See - https://www.tensorflow.org/xla/operation_semantics#getdimensionsize. - }]; -} - -class BASE_HLO_ImagOp { - string summary = "Imag operator"; - - string description = [{ - Returns `Imag(operand)` element-wise. - }]; -} - -class BASE_HLO_IsFiniteOp { - string summary = "IsFinite operator"; - - string description = [{ - Tests whether each element of operand is finite, i.e., is not positive or - negative infinity, and is not NaN. Returns a tensor of 1-bit integers with - the same shape as the input, where each element is nonzero (i.e. true) if - and only if the corresponding input element is finite. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. - }]; -} - -class BASE_HLO_LogOp { - string summary = "Logarithm operator"; - - string description = [{ - Returns `log(operand)` element-wise. 
- - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. - }]; -} - -class BASE_HLO_Log1pOp { - string summary = "Log1p operator"; - - string description = [{ - Returns `log(operand+1)` element-wise. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. - }]; -} - -class BASE_HLO_LogisticOp { - string summary = "Logistic operator"; - - string description = [{ - Returns `logistic(operand)` element-wise. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. - }]; -} - -class BASE_HLO_NegOp { - string summary = "Negation operator"; - - string description = [{ - Returns `-operand` element-wise. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. - }]; -} - -class BASE_HLO_NotOp { - string summary = "Not operator"; - - string description = [{ - Returns `!operand` element-wise. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. - }]; -} - -class BASE_HLO_PopulationCountOp { - string summary = "PopulationCount operator"; - - string description = [{ - Returns the number of bits set in each operand element-wise. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. - }]; -} - -class BASE_HLO_RealOp { - string summary = "Real operator"; - - string description = [{ - Returns `Real(operand)` element-wise. - }]; -} - -class BASE_HLO_RngBitGeneratorOp { - string summary = "Uniform random number generator operator"; - - string description = [{ - Returns an output with a given shape filled with uniform random bits using - the specified algorithm (or backend default) and returns an updated state - (with the same shape as initial state) and the generated random data. - - See - https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator. 
- }]; -} - -class BASE_HLO_RoundOp { - string summary = "Round operator"; - - string description = [{ - Returns `Round(operand)` element-wise, rounding to nearest integer with - half-way cases rounding away from zero. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. - }]; -} - -class BASE_HLO_RsqrtOp { - string summary = "Reciprocal Square-root operator"; - - string description = [{ - Returns `1.0 / sqrt(operand)` element-wise. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. - }]; -} - -class BASE_HLO_SignOp { - string summary = "Sign operator"; - - string description = [{ - Returns `sign(operand)` element-wise, where - - ``` - sign(x) = -1 : x < 0 - = -0 : x = -0 - = NaN : x = NaN - = +0 : x = +0 - = 1 : x > 0 - ``` - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. - }]; -} - -class BASE_HLO_SinOp { - string summary = "Sin operator"; - - string description = [{ - Returns `Sin(operand)` element-wise. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. - }]; -} - -class BASE_HLO_SqrtOp { - string summary = "Square-root operator"; - - string description = [{ - Returns `sqrt(operand)` element-wise. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. - }]; -} - -class BASE_HLO_TanhOp { - string summary = "Tanh operator"; - - string description = [{ - Returns `tanh(operand)` element-wise. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. - }]; -} - -//===----------------------------------------------------------------------===// -// XLA binary elementwise op definitions. -//===----------------------------------------------------------------------===// - -class BASE_HLO_AddOp { - string summary = "Addition operator"; - - string description = [{ - Returns `lhs + rhs` element-wise. 
- - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. - }]; -} - -class BASE_HLO_ComplexOp { - string summary = "Complex operator"; - - string description = [{ - Performs element-wise conversion of a pair of real and imaginary values to - a complex value. - }]; -} - -class BASE_HLO_DivOp { - string summary = "Division operator"; - - string description = [{ - Returns `lhs / rhs` element-wise. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. - }]; -} - -class BASE_HLO_MaxOp { - string summary = "Maximum operator"; - - string description = [{ - Returns `max(lhs, rhs)` element-wise. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. - }]; -} - -class BASE_HLO_MinOp { - string summary = "Minimum operator"; - - string description = [{ - Returns `min(lhs, rhs)` element-wise. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. - }]; -} - -class BASE_HLO_MulOp { - string summary = "Multiplication operator"; - - string description = [{ - Returns `lhs * rhs` element-wise. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. - }]; -} -class BASE_HLO_PowOp { - string summary = "Power operator"; - - string description = [{ - Returns `lhs ^ rhs` element-wise. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. - }]; -} - -class BASE_HLO_RemOp { - string summary = "Remainder operator"; - - string description = [{ - Returns `lhs % rhs` element-wise. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. - }]; -} - -class BASE_HLO_SubOp { - string summary = "Subtraction operator"; - - string description = [{ - Returns `lhs - rhs` element-wise. 
- - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. - }]; -} - -class BASE_HLO_ShiftLeftOp { - string summary = "Shift Left operator"; - - string description = [{ - Returns `lhs << rhs` element-wise. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. - }]; -} - -class BASE_HLO_ShiftRightArithmeticOp { - string summary = "Shift right arithmetic operator"; - - string description = [{ - Returns arithmetic `lhs >> rhs` element-wise. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. - }]; -} - -class BASE_HLO_ShiftRightLogicalOp { - string summary = "Shift right logical operator"; - - string description = [{ - Returns logical `lhs >> rhs` element-wise. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. - }]; -} - -class BASE_HLO_Atan2Op { - string summary = "Atan2 operator"; - - string description = [{ - Returns `atan2(lhs/rhs)` element-wise. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. - }]; -} - -class BASE_HLO_AndOp { - string summary = "Logical and"; - - string description = [{ - Returns `logical_and(lhs, rhs)` element-wise. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. - }]; -} - -class BASE_HLO_OrOp { - string summary = "Logical or"; - - string description = [{ - Returns `logical_or(lhs, rhs)` element-wise. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. - }]; -} - -class BASE_HLO_XorOp { - string summary = "Logical xor"; - - string description = [{ - Returns `logical_xor(lhs, rhs)` element-wise. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. 
- }]; -} - -//===----------------------------------------------------------------------===// -// XLA control flow related op definitions. -//===----------------------------------------------------------------------===// - -class BASE_HLO_CaseOp { - string summary = "Switch-Case operator"; - - string description = [{ - Returns the result of executing `branches[index]`. If - `index` is < 0 or >= N, then `branches[N-1] is executed as - the default branch. - - Each branch `branches[b]` must take in a single argument of same type as - `branch_operands[b]` and will be invoked with `branch_operands[b]`. The type - of the returned value of each branch must be the same. - - Note that only one of the branches will be executed depending on the value - of index. - See https://www.tensorflow.org/xla/operation_semantics#conditional. - }]; - -} - -//===----------------------------------------------------------------------===// -// XLA parallelism related op definitions. -//===----------------------------------------------------------------------===// - -class BASE_HLO_ReplicaIdOp { - string summary = "ReplicaId operator"; - - string description = [{ - Returns the unique ID (int32 scalar) of the replica. - - The unique ID of each replica is an unsigned integer in the interval [0, N), - where N is the number of replicas. Since all the replicas are running the - same program, a ReplicaId() call in the program will return a different - value on each replica. - - See https://www.tensorflow.org/xla/operation_semantics#replicaid. - }]; -} - -class BASE_HLO_PartitionIdOp { - string summary = "PartitionId operator"; - - string description = [{ - Returns the unique ID (int32 scalar) of the partition. - }]; -} - -class BASE_HLO_AllGatherOp { - string summary = "AllGather operator"; - - string description = [{ - Performs concatenation across replicas. 
- - See https://www.tensorflow.org/xla/operation_semantics#allgather - }]; -} - -class BASE_HLO_AllReduceOp { - string summary = "AllReduce operator"; - - string description = [{ - Performs a custom reduction across replicas. - - See https://www.tensorflow.org/xla/operation_semantics#allreduce. - }]; -} - -class BASE_HLO_ReduceOp { - string summary = "Reduce operator"; - - string description = [{ - Returns the result of executing a reduction function on one or more arrays - in parallel. - - See https://www.tensorflow.org/xla/operation_semantics#reduce. - }]; -} - -class BASE_HLO_ReduceWindowOp { - string summary = "ReduceWindow operator"; - - string description = [{ - Returns the result of executing a reduction function over all elements in - each window of one or more arrays in parallel. - - See https://www.tensorflow.org/xla/operation_semantics#reducewindow. - }]; -} - -//===----------------------------------------------------------------------===// -// XLA tuple op definitions. -//===----------------------------------------------------------------------===// -class BASE_HLO_GetTupleElementOp { - string summary = "GetTupleElement operator"; - - string description = [{ - Returns a member of a tuple specified by an index. - - See https://www.tensorflow.org/xla/operation_semantics#gettupleelement. - }]; -} - -class BASE_HLO_TupleOp { - string summary = "XLA's tuple op"; - - string description = [{ - Groups a set of tensor inputs into a single tuple object. - - See https://www.tensorflow.org/xla/operation_semantics#tuple. - }]; -} - - - - -class BASE_HLO_CompareOp { - string summary = "Comparison operator"; - - string description = [{ - Compares `lhs` and `rhs` elementwise according to `comparison_direction` - and `compare_type`. If unspecified, `compare_type` is FLOAT for float element - types, SIGNED for signed element types and UNSIGNED for unsigned element - types. - - See - https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations. 
- }]; -} - -//===----------------------------------------------------------------------===// -// Quantize op definitions. -//===----------------------------------------------------------------------===// - -class BASE_HLO_DequantizeOp { - string summary = "Dequantize operator"; - - string description = [{ - Dequantize the quantized input of packed uint32 to bfloat16. Only uint8 or - uint16 is supported for the original unpacked input. - - Returns a tensor of shape [d0,..., dn * unpack_size] if unpacked input shape - is [d0, ..., dn], where unpack_size = sizeof(unit32) / sizeof(T), where T is - the unpacked input type. If transpose_output is true, will return a tensor - of shape [dn * unpack_size, dn-1, ..., d1, d0]. transpose_output is faster - when input's rank higher than 1. The input needs to be transposed to use - transpose_output feature. - }]; -} - -//===----------------------------------------------------------------------===// -// XLA Slice definitions. -//===----------------------------------------------------------------------===// - -class BASE_HLO_SliceOp { - string summary = "Slice operator"; - - string description = [{ - Slices a portion of the `operand` into a new configuration. - - See https://www.tensorflow.org/xla/operation_semantics#slice. - }]; -} - -class BASE_HLO_DynamicSliceOp { - string summary = "Dynamic Slice operator"; - - string description = [{ - Extracts a sub-array from the input array at dynamic start_indices. - - See https://www.tensorflow.org/xla/operation_semantics#dynamicslice. - }]; -} - -class BASE_HLO_DynamicUpdateSliceOp { - string summary = "Dynamic Update Slice operator"; - - string description = [{ - DynamicUpdateSlice generates a result which is the value of the input array - operand, with a slice update overwritten at start_indices. - - See https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice. 
- }]; -} - -//===----------------------------------------------------------------------===// -// XLA Other op definitions. -//===----------------------------------------------------------------------===// - -class BASE_HLO_AllToAllOp { - string summary = "AllToAll"; - - string description = [{ - AllToAll is a collective operation that sends data from all cores to all - cores. It has two phases: - - The scatter phase. On each core, the operand is split into `split_count` - number of blocks along the `split_dimension`, and the blocks are - scattered to all cores, e.g., the i-th block is sent to the i-th core. - - The gather phase. Each core concatenates the received blocks along the - `concat_dimension`. - - The participating cores can be configured by: - - replica_groups: each ReplicaGroup contains a list of replica id - participating in the computation (replica id for the current replica can - be retrieved using ReplicaId op). AllToAll will be applied within - subgroups in the specified order. For example, - `replica_groups` = {{1,2,3}, {4,5,0}} means that an AllToAll will be applied - within replicas {1, 2, 3}, and in the gather phase, the received blocks - will be concatenated in the same order of 1, 2, 3. Then, another AllToAll - will be applied within replicas 4, 5, 0, and the concatenation order is - also 4, 5, 0. If `replica_groups` is empty, all replicas belong to one - group, and the concatenation order is the numerical order (0, 1, 2, ...). - - Prerequisites: - - The dimension size of the operand on the split_dimension is divisible by - `split_count`. - - The operand's shape is not tuple. - - See https://www.tensorflow.org/xla/operation_semantics#alltoall - }]; -} - -class BASE_HLO_BatchNormGradOp { - string summary = "Batch Normalization Gradient"; - - string description = [{ - Calculates gradients of batch norm. 
- - See https://www.tensorflow.org/xla/operation_semantics#batchnormgrad - }]; -} - -class BASE_HLO_BatchNormInferenceOp { - string summary = "Batch Normalization for Inference"; - - string description = [{ - Normalizes an array across batch and spatial dimensions. - - See https://www.tensorflow.org/xla/operation_semantics#batchnorminference - }]; -} - -class BASE_HLO_BatchNormTrainingOp { - string summary = "Batch Normalization for Training"; - - string description = [{ - Normalizes an array across batch and spatial dimensions. - - See https://www.tensorflow.org/xla/operation_semantics#batchnormtraining - }]; -} - -class BASE_HLO_BitcastConvertOp { - string summary = "BitcastConvert operator"; - - string description = [{ - Similar to a 'tf.bitcast' in TensorFlow, performs an element-wise bitcast - operation from a data shape to a target shape. The dimensions must match, - and the conversion is an element-wise one. Bitcast is implemented as a - low-level cast, so machines with different floating-point representations - will give different results. - - See https://www.tensorflow.org/xla/operation_semantics#bitcastconverttype. - }]; -} - -class BASE_HLO_BroadcastOp { - string summary = "Broadcast a tensor to a higher rank by prepending dimensions"; - - string description = [{ - Broadcasts the operand tensor to a higher rank by prepending - `broadcast_sizes` to the dimensions. The current values of the operand are - copied into the other dimensions. - - This is a more limited form of broadcasting, that corresponds to the XLA - client Broadcast method. For a more general form of broadcasting, see the - BroadcastInDimOp. - - See https://www.tensorflow.org/xla/operation_semantics#broadcast. - }]; -} - -class BASE_HLO_BroadcastInDimOp { - string summary = "Broadcast a tensor into the given shape by adding dimensions."; - - string description = [{ - Broadcasts the `operand` tensor to a higher rank. 
This is not the limited - form of broadcasting exposed as the XLA client broadcast op, but rather the - more powerful "InDim" broadcasting, which is closer to the HLO broadcast op - and exposed in the XLA client BroadcastInDim method. - - `broadcast_dimensions` maps the operand dimension number to the target shape - dimension number. It must have the same size as the rank of the operand. The - mapped dimensions must either be the same size or the dimension being - broadcast from must be size 1 (degenerate broadcasting). - - For a scalar (0D tensor) operand, `broadcast_dimensions` must be empty. The - The scalar value will be broadcast to every element in the target shape. - - See https://www.tensorflow.org/xla/broadcasting. - }]; -} - -class BASE_HLO_CholeskyOp { - string summary = "Cholesky operator"; - - string description = [{ - Computes the Cholesky decomposition of a batch of symmetric (Hermitian) - positive definite matrices. - - If lower is true, computes lower-triangular matrices l such that - `a=l.Transpose(l)`. If lower is false, computes upper-triangular matrices u such - that `a=Transpose(u).u`. - - Input data is read only from the lower/upper triangle of a, depending on the - value of lower. Values from the other triangle are ignored. Output data is - returned in the same triangle; the values in the other triangle are - implementation-defined and may be anything. - - If the rank of a is greater than 2, a is treated as a batch of matrices, where - all except the minor 2 dimensions are batch dimensions. - - If a is not symmetric (Hermitian) positive definite, the result is - implementation-defined. - - See https://www.tensorflow.org/xla/operation_semantics#cholesky. - }]; -} - -class BASE_HLO_ClampOp { - string summary = "Clamp operator"; - - string description = [{ - Clamps an operand to within the range between a minimum and maximum value. - - Note: All three arrays must be the same shape. 
Alternatively, as a - restricted form of broadcasting, min and/or max can be a scalar (0D - tensor) of the element type of the tensor operand. - - See https://www.tensorflow.org/xla/operation_semantics#clamp. - }]; -} - -class BASE_HLO_CollectivePermuteOp { - string summary = "CollectivePermute operator"; - - string description = [{ - CollectivePermute is a collective operation that sends and receives data - cross replicas. - Note that there are the following restrictions on the source_target_pair: - - Any two pairs should not have the same target replica id, and they should - not have the same source replica id. - - If a replica id is not a target in any pair, then the output on that - replica is a tensor consists of 0(s) with the same shape as the input. - - See https://www.tensorflow.org/xla/operation_semantics#collectivepermute. - - }]; -} -class BASE_HLO_ConcatenateOp { - string summary = "XLA's concatenate op"; - - string description = [{ - Concatenates a set of tensors along the specified dimension. - - See https://www.tensorflow.org/xla/operation_semantics#concatenate. - }]; -} - //===----------------------------------------------------------------------===// // Common convolution attributes //===----------------------------------------------------------------------===// @@ -1034,323 +167,4 @@ class BASE_HLO_ConvOp { }]; } -class BASE_HLO_CopyOp { - string summary = "Copy operator"; - - string description = [{ - Returns a copy of `operand`. - }]; -} - -class BASE_HLO_CrossReplicaSumOp { - string summary = "Sums input across replicated instances."; - - string description = [{ - For each of the replica groups, operands of the group devices are summed - so that each device has the sum. - - For example, suppose there are 8 TPU devices: `[A, B, C, D, E, F, G, H]`. - Passing group_assignment=`[[0,2,4,6],[1,3,5,7]]` sets `A, C, E, G` as group 0, - and `B, D, F, H` as group 1. 
Thus we get the outputs: - `[A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H]`. - - See https://www.tensorflow.org/xla/operation_semantics#crossreplicasum. - }]; -} - - -class BASE_HLO_CustomCallOp { - string summary = "CustomCall operator"; - - string description = [{ - A custom call invokes code external to XLA. The `args` are passed to the - external code, and the external code is expected to produce a result of the - given type. The exact mechanism is backend-specific. For example, in the CPU - backend, a call instruction is emitted which targets a symbol with the name - `call_target_name`. - - `call_target_name` and `backend_config` can be arbitrary strings, but - `call_target_name` should be short as it may be used in labels. - `backend_config` can encode arbitrarily large amounts of information. - - See https://www.tensorflow.org/xla/operation_semantics#customcall. - }]; -} - -class BASE_HLO_DotOp { - string summary = "Dot operator"; - string description = [{ - Performs dot products between vectors, vector/matrix and matrix/matrix - multiplication. - - See https://www.tensorflow.org/xla/operation_semantics#dot. - }]; -} - -class BASE_HLO_DotGeneralOp { - string summary = "General Dot operator"; - string description = [{ - Performs general dot products between vectors, vector/matrix and - matrix/matrix multiplication. - - See https://www.tensorflow.org/xla/operation_semantics#dotgeneral. - }]; -} - -class BASE_HLO_FftOp { - string summary = "Fast fourier transform operator"; - - string description = [{ - Returns the fast-fourier-transform of the input array. - - See - https://www.tensorflow.org/xla/operation_semantics#fft. - }]; -} - -class BASE_HLO_GatherOp{ - string summary = "Gather operator"; - - string description = [{ - Stitches together several slices of an input array. - - See https://www.tensorflow.org/xla/operation_semantics#gather. 
- }]; -} - -class BASE_HLO_MapOp { - string summary = "Map operator"; - - string description = [{ - Applies a scalar function over the given operands arrays, producing an array - of the same dimensions where each element is the result of the mapped function - applied to the corresponding elements in the input arrays. - - The mapped function is an arbitrary computation with the restriction that it - has N inputs of scalar type T and a single output with type S. The output has - the same dimensions as the operands except that the element type T is replaced - with S. - - See https://www.tensorflow.org/xla/operation_semantics#map. - }]; -} - -class BASE_HLO_ReshapeOp { - string summary = "Reshape operator"; - - string description = [{ - Reshapes the dimensions of `operand` into a new configuration. - - See https://www.tensorflow.org/xla/operation_semantics#reshape. - }]; -} - -class BASE_HLO_ScatterOp { - string summary = "Scatter operator"; - - string description = [{ - Generates a result which is the value of the input array `operand`, - with several slices (at indices specified by `scatter_indices`) - updated with the values in `updates` using `update_computation`. - - See https://www.tensorflow.org/xla/operation_semantics#scatter. - }]; -} - -class BASE_HLO_SelectOp { - string summary = "Select operator"; - - string description = [{ - Constructs an output tensor from the elements of `on_true` and `on_false` - based on the values of `pred`. - - `pred`, `on_true` and `on_false` must be broadcast compatible. - }]; -} - -class BASE_HLO_SelectAndScatterOp { - string summary = "SelectAndScatter operator"; - - string description = [{ - Runs a windowed selection `select` function over `operand` with shape - `window_dimensions` and stride `window_strides`. This will produce an amount - of selected locations whose shape matches `source`. These are then scattered - to the output which is initialized with `init_value`. 
- Multiple scattered elements which land in the same output location are - combined using the `scatter` function. - - See https://www.tensorflow.org/xla/operation_semantics#selectandscatter. - }]; -} - -class BASE_HLO_SetDimensionSizeOp { - string summary = "SetDimensionSize operator"; - - string description = [{ - Sets the dynamic size of operand's given dimension. Pass through the operand - as result, with dynamic dimension tracked by the compiler. Padded values - will be ignored by downstream reduction ops. - - See https://www.tensorflow.org/xla/operation_semantics#setdimensionsize. - }]; -} - -class BASE_HLO_SortOp { - string summary = "Sort operator"; - - string description = [{ - Sorts the given `operands` at the given `dimension` with the given - `comparator`. - - See https://www.tensorflow.org/xla/operation_semantics#sort. - }]; -} - -class BASE_HLO_ReverseOp { - string summary = "Reverse operator"; - - string description = [{ - Reverses the specified dimensions of `operand` according to the given - `dimensions`. - - See https://www.tensorflow.org/xla/operation_semantics#rev_reverse. - }]; -} - -class BASE_HLO_PadOp { - string summary = "Pad operator"; - - string description = [{ - Pads the edges of `operand` with the `padding_value` and according to - the passed configuration. - - See https://www.tensorflow.org/xla/operation_semantics#pad. - }]; -} - -class BASE_HLO_TraceOp { - string summary = "Trace operator"; - - string description = [{ - Emits a logging message `tag` with the `operand`. - }]; -} - -class BASE_HLO_TransposeOp { - string summary = "Transpose operator"; - - string description = [{ - Permutes the dimensions of `operand` according to the given `permutation`. - - `res_dimensions[i] = operand_dimensions[permutation[i]]` - - See https://www.tensorflow.org/xla/operation_semantics#transpose. 
- }]; -} - -class BASE_HLO_TriangularSolveOp { - string summary = "TriangularSolve operator"; - - string description = [{ - Solves systems of linear equations with lower or upper triangular - coefficient matrices by forward- or back-substitution. Broadcasting along - leading dimensions, this routine solves one of the matrix systems - op(a) * x = b, or x * op(a) = b, for the variable x, given a and b, where - op(a) is either op(a) = a, or op(a) = Transpose(a), or - op(a) = Conj(Transpose(a)). - - Input data is read only from the lower/upper triangle of a, depending on the - value of lower. Values from the other triangle are ignored. Output data is - returned in the same triangle; the values in the other triangle are - implementation-defined and may be anything. - - If the rank of a and b are greater than 2, they are treated as batches of - matrices, where all except the minor 2 dimensions are batch dimensions. a - and b must have equal batch dimensions. - - See https://www.tensorflow.org/xla/operation_semantics#triangularsolve. - }]; - -} - -class BASE_HLO_RngUniformOp { - string summary = "RNG with uniform distribution."; - - string description = [{ - Constructs an output of a given shape with random numbers generated - following the uniform distribution over the interval `[a,b)`. The parameters - and output element type have to be a boolean type, an integral type or a - floating point types, and the types have to be consistent. - - See https://www.tensorflow.org/xla/operation_semantics#rnguniform. - }]; -} - -class BASE_HLO_RngNormalOp { - string summary = "RNG with normal distribution."; - - string description = [{ - Constructs an output of a given shape with random numbers generated - following the normal distribution with parameters `mu` and `sigma`. The - parameters and output shape have to have a floating point elemental type. - The parameters furthermore have to be scalar valued. - - See https://www.tensorflow.org/xla/operation_semantics#rngnormal. 
- }]; -} - -class BASE_HLO_ReducePrecisionOp { - string summary = "Reduce precision operator"; - - string description = [{ - Models the effect of converting floating - point values to a lower - - precision format(such as IEEE - FP16) and back to the original - format. The number of exponent and mantissa bits in the lower - - precision format can be specified arbitrarily, - although all bit sizes may not be supported on all hardware - implementations. - - See https://www.tensorflow.org/xla/operation_semantics#reduceprecision. - }]; -} - -class BASE_HLO_InfeedOp { - string summary = "Infeed operator"; - - string description = [{ - Reads a single data item from the implicit Infeed streaming interface of - the device, interpreting the data as the given shape and its layout, and - returns an LHLO op of the data. Multiple Infeed operations are allowed in a - computation, but there must be a total order among the Infeed operations. - For example, two Infeeds in the code below have a total order since there - is a dependency between the while loops. - - See https://www.tensorflow.org/xla/operation_semantics#infeed - }]; -} - -class BASE_HLO_WhileOp { - string summary = "While operator"; - - string description = [{ - Returns the result of executing a body function until the cond body returns - true. - - See https://www.tensorflow.org/xla/operation_semantics#while. - }]; -} - -class BASE_HLO_BitcastOp { - string summary = "Bitcast operator"; - - string description = [{ - This op changes the shape of the input in the way that the physical - arrangement of elements are unchanged. - - However, the op needs layout information to make sense of "physical - arrangement of elements". Layout support in MHLO is currently under - exploration. 
- }]; -} - #endif // HLO_OPS_BASE diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td index b9fe5fb..791d11d 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.td @@ -44,8 +44,13 @@ def I32Buffer : MemRefOf<[I32]>; // calls generate or consume standard deviation, whereas LHLO ops generate or // consume variance (= std-dev ^ 2). -def LHLOGPU_BatchNormGradOp : LHLOGPU_Op<"batch_norm_grad">, - BASE_HLO_BatchNormGradOp { +def LHLOGPU_BatchNormGradOp : LHLOGPU_Op<"batch_norm_grad"> { + let summary = "Batch Normalization Gradient"; + let description = [{ + Calculates gradients of batch norm. + + See https://www.tensorflow.org/xla/operation_semantics#batchnormgrad + }]; let arguments = (ins Arg:$operand, Arg:$scale, @@ -60,8 +65,13 @@ def LHLOGPU_BatchNormGradOp : LHLOGPU_Op<"batch_norm_grad">, ); } -def LHLOGPU_BatchNormInferenceOp : LHLOGPU_Op<"batch_norm_inference">, - BASE_HLO_BatchNormInferenceOp { +def LHLOGPU_BatchNormInferenceOp : LHLOGPU_Op<"batch_norm_inference"> { + let summary = "Batch Normalization for Inference"; + let description = [{ + Normalizes an array across batch and spatial dimensions. + + See https://www.tensorflow.org/xla/operation_semantics#batchnorminference + }]; let arguments = (ins Arg:$operand, Arg:$scale, @@ -73,8 +83,13 @@ def LHLOGPU_BatchNormInferenceOp : LHLOGPU_Op<"batch_norm_inference">, I64Attr:$feature_index); } -def LHLOGPU_BatchNormTrainingOp : LHLOGPU_Op<"batch_norm_training">, - BASE_HLO_BatchNormTrainingOp { +def LHLOGPU_BatchNormTrainingOp : LHLOGPU_Op<"batch_norm_training"> { + let summary = "Batch Normalization for Training"; + let description = [{ + Normalizes an array across batch and spatial dimensions. 
+ + See https://www.tensorflow.org/xla/operation_semantics#batchnormtraining + }]; let arguments = (ins Arg:$operand, diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index 6dacce1..344aaf6 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -52,7 +52,11 @@ class LHLO_Op traits> : Op], traits)>; -def LHLO_ConstOp : LHLO_Op<"constant", []>, BASE_HLO_ConstOp { +def LHLO_ConstOp : LHLO_Op<"constant", []> { + let summary = "Constant operator"; + let description = [{ + Represents a constant value. + }]; let arguments = (ins ElementsAttr:$value, Arg:$output @@ -61,7 +65,11 @@ def LHLO_ConstOp : LHLO_Op<"constant", []>, BASE_HLO_ConstOp { let hasCanonicalizer = 1; } -def LHLO_IotaOp : LHLO_Op<"iota", []>, BASE_HLO_IotaOp { +def LHLO_IotaOp : LHLO_Op<"iota", []> { + let summary = "Iota operator"; + let description = [{ + Creates a rank 1 array of values starting at zero and incrementing by one. + }]; let arguments = (ins I64Attr:$iota_dimension, Arg:$output); } @@ -80,70 +88,254 @@ class LHLO_UnaryElementwiseOp, BASE_HLO_AbsOp { +def LHLO_AbsOp: LHLO_UnaryElementwiseOp<"abs", LHLO_Buffer, [SameOperandsShape]> { + let summary = "Absolute value operator"; + let description = [{ + Returns `abs(operand)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; let verifier = [{ return Verify(*this); }]; } // TODO(timshen): add a custom verifier. def LHLO_BitcastConvertOp: - LHLO_UnaryElementwiseOp<"bitcast_convert", LHLO_Buffer, [SameOperandsShape]>, BASE_HLO_BitcastConvertOp; + LHLO_UnaryElementwiseOp<"bitcast_convert", LHLO_Buffer, [SameOperandsShape]> { + let summary = "BitcastConvert operator"; + let description = [{ + Similar to a 'tf.bitcast' in TensorFlow, performs an element-wise bitcast + operation from a data shape to a target shape. 
The dimensions must match, + and the conversion is an element-wise one. Bitcast is implemented as a + low-level cast, so machines with different floating-point representations + will give different results. -def LHLO_CbrtOp: LHLO_UnaryElementwiseOp<"cbrt", LHLO_FpBuffer>, BASE_HLO_CbrtOp; + See https://www.tensorflow.org/xla/operation_semantics#bitcastconverttype. + }]; +} +def LHLO_CbrtOp: LHLO_UnaryElementwiseOp<"cbrt", LHLO_FpBuffer> { + let summary = "Cubic root operator"; + let description = [{ + Returns element-wise cubic root of the operand. -def LHLO_CeilOp: LHLO_UnaryElementwiseOp<"ceil", LHLO_FpBuffer>, BASE_HLO_CeilOp; + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} +def LHLO_CeilOp: LHLO_UnaryElementwiseOp<"ceil", LHLO_FpBuffer> { + let summary = "Ceil operator"; + let description = [{ + Returns `Ceil(operand)` element-wise. -def LHLO_ClzOp: LHLO_UnaryElementwiseOp<"count_leading_zeros", LHLO_IntBuffer>, BASE_HLO_ClzOp; + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} +def LHLO_ClzOp: LHLO_UnaryElementwiseOp<"count_leading_zeros", LHLO_IntBuffer> { + let summary = "Count-leading-zeros (Clz) operator"; + let description = [{ + Returns the number of leading zeros in each operand element-wise. + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} // TODO(timshen): add a custom verifier. -def LHLO_ConvertOp : LHLO_UnaryElementwiseOp<"convert", LHLO_Buffer, [SameOperandsShape]>, BASE_HLO_ConvertOp; +def LHLO_ConvertOp : LHLO_UnaryElementwiseOp<"convert", LHLO_Buffer, [SameOperandsShape]> { + let summary = "Convert operator"; + let description = [{ + Performs element-wise conversion of values from one type to another, e.g. + float to int. -def LHLO_CosOp: LHLO_UnaryElementwiseOp<"cosine", LHLO_FpOrComplexBuffer>, BASE_HLO_CosOp; + See https://www.tensorflow.org/xla/operation_semantics#convertelementtype. 
+ }]; +} +def LHLO_CosOp: LHLO_UnaryElementwiseOp<"cosine", LHLO_FpOrComplexBuffer> { + let summary = "Cos operator"; + let description = [{ + Returns `Cos(operand)` element-wise. -def LHLO_ExpOp: LHLO_UnaryElementwiseOp<"exponential", LHLO_FpOrComplexBuffer>, BASE_HLO_ExpOp; + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} +def LHLO_ExpOp: LHLO_UnaryElementwiseOp<"exponential", LHLO_FpOrComplexBuffer> { + let summary = "Exponential operator"; + let description = [{ + Returns `e^(operand)` element-wise. -def LHLO_Expm1Op: LHLO_UnaryElementwiseOp<"exponential_minus_one", LHLO_FpOrComplexBuffer>, BASE_HLO_Expm1Op; + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} +def LHLO_Expm1Op: LHLO_UnaryElementwiseOp<"exponential_minus_one", LHLO_FpOrComplexBuffer> { + let summary = "Exponential minus one operator"; + let description = [{ + Returns `e^(operand) - 1` element-wise. -def LHLO_FloorOp: LHLO_UnaryElementwiseOp<"floor", LHLO_FpBuffer>, BASE_HLO_FloorOp; + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} +def LHLO_FloorOp: LHLO_UnaryElementwiseOp<"floor", LHLO_FpBuffer> { + let summary = "Floor operator"; + let description = [{ + Returns `Floor(operand)` element-wise. -def LHLO_ImagOp: LHLO_Op<"imag", [SameOperandsShape]>, BASE_HLO_ImagOp { + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} +def LHLO_ImagOp: LHLO_Op<"imag", [SameOperandsShape]> { + let summary = "Imag operator"; + let description = [{ + Returns `Imag(operand)` element-wise. 
+ }]; let arguments = (ins Arg:$input, Arg:$output); } -def LHLO_IsFiniteOp: LHLO_Op<"is_finite", [SameOperandsShape]>, BASE_HLO_IsFiniteOp { +def LHLO_IsFiniteOp: LHLO_Op<"is_finite", [SameOperandsShape]> { + let summary = "IsFinite operator"; + let description = [{ + Tests whether each element of operand is finite, i.e., is not positive or + negative infinity, and is not NaN. Returns a tensor of 1-bit integers with + the same shape as the input, where each element is nonzero (i.e. true) if + and only if the corresponding input element is finite. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; let arguments = (ins Arg:$input, Arg:$output); } -def LHLO_LogOp: LHLO_UnaryElementwiseOp<"log", LHLO_FpOrComplexBuffer>, BASE_HLO_LogOp; +def LHLO_LogOp: LHLO_UnaryElementwiseOp<"log", LHLO_FpOrComplexBuffer> { + let summary = "Logarithm operator"; + let description = [{ + Returns `log(operand)` element-wise. -def LHLO_LogisticOp : LHLO_UnaryElementwiseOp<"logistic", LHLO_FpOrComplexBuffer>, BASE_HLO_LogisticOp; + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} +def LHLO_LogisticOp : LHLO_UnaryElementwiseOp<"logistic", LHLO_FpOrComplexBuffer> { + let summary = "Logistic operator"; + let description = [{ + Returns `logistic(operand)` element-wise. -def LHLO_Log1pOp: LHLO_UnaryElementwiseOp<"log_plus_one", LHLO_FpOrComplexBuffer>, BASE_HLO_Log1pOp; + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} +def LHLO_Log1pOp: LHLO_UnaryElementwiseOp<"log_plus_one", LHLO_FpOrComplexBuffer> { + let summary = "Log1p operator"; + let description = [{ + Returns `log(operand+1)` element-wise. -def LHLO_NegOp: LHLO_UnaryElementwiseOp<"negate">, BASE_HLO_NegOp; + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. 
+ }]; +} +def LHLO_NegOp: LHLO_UnaryElementwiseOp<"negate"> { + let summary = "Negation operator"; + let description = [{ + Returns `-operand` element-wise. -def LHLO_NotOp: LHLO_UnaryElementwiseOp<"not", LHLO_PredOrIntBuffer>, BASE_HLO_NotOp; + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} +def LHLO_NotOp: LHLO_UnaryElementwiseOp<"not", LHLO_PredOrIntBuffer> { + let summary = "Not operator"; + let description = [{ + Returns `!operand` element-wise. -def LHLO_PopulationCountOp: LHLO_UnaryElementwiseOp<"popcnt", LHLO_IntBuffer>, BASE_HLO_PopulationCountOp; + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} +def LHLO_PopulationCountOp: LHLO_UnaryElementwiseOp<"popcnt", LHLO_IntBuffer> { + let summary = "PopulationCount operator"; + let description = [{ + Returns the number of bits set in each operand element-wise. -def LHLO_RealOp: LHLO_Op<"real", [SameOperandsShape]>, BASE_HLO_RealOp { + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} +def LHLO_RealOp: LHLO_Op<"real", [SameOperandsShape]> { + let summary = "Real operator"; + let description = [{ + Returns `Real(operand)` element-wise. + }]; let arguments = (ins Arg:$input, Arg:$output); } -def LHLO_RoundOp: LHLO_UnaryElementwiseOp<"round_nearest_afz", LHLO_FpBuffer>, BASE_HLO_RoundOp; +def LHLO_RoundOp: LHLO_UnaryElementwiseOp<"round_nearest_afz", LHLO_FpBuffer> { + let summary = "Round operator"; + let description = [{ + Returns `Round(operand)` element-wise, rounding to nearest integer with + half-way cases rounding away from zero. -def LHLO_RsqrtOp: LHLO_UnaryElementwiseOp<"rsqrt", LHLO_FpOrComplexBuffer>, BASE_HLO_RsqrtOp; + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. 
+ }]; +} +def LHLO_RsqrtOp: LHLO_UnaryElementwiseOp<"rsqrt", LHLO_FpOrComplexBuffer> { + let summary = "Reciprocal Square-root operator"; + let description = [{ + Returns `1.0 / sqrt(operand)` element-wise. -def LHLO_SqrtOp: LHLO_UnaryElementwiseOp<"sqrt", LHLO_FpOrComplexBuffer>, BASE_HLO_SqrtOp; + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} +def LHLO_SqrtOp: LHLO_UnaryElementwiseOp<"sqrt", LHLO_FpOrComplexBuffer> { + let summary = "Square-root operator"; + let description = [{ + Returns `sqrt(operand)` element-wise. -def LHLO_SignOp: LHLO_UnaryElementwiseOp<"sign">, BASE_HLO_SignOp; + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} +def LHLO_SignOp: LHLO_UnaryElementwiseOp<"sign"> { + let summary = "Sign operator"; + let description = [{ + Returns `sign(operand)` element-wise, where -def LHLO_SinOp: LHLO_UnaryElementwiseOp<"sine", LHLO_FpOrComplexBuffer>, BASE_HLO_SinOp; + ``` + sign(x) = -1 : x < 0 + = -0 : x = -0 + = NaN : x = NaN + = +0 : x = +0 + = 1 : x > 0 + ``` -def LHLO_TanhOp: LHLO_UnaryElementwiseOp<"tanh", LHLO_FpOrComplexBuffer>, BASE_HLO_TanhOp; + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} +def LHLO_SinOp: LHLO_UnaryElementwiseOp<"sine", LHLO_FpOrComplexBuffer> { + let summary = "Sin operator"; + let description = [{ + Returns `Sin(operand)` element-wise. + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} +def LHLO_TanhOp: LHLO_UnaryElementwiseOp<"tanh", LHLO_FpOrComplexBuffer> { + let summary = "Tanh operator"; + let description = [{ + Returns `tanh(operand)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} //===----------------------------------------------------------------------===// // LMHLO binary elementwise op definitions. 
//===----------------------------------------------------------------------===// @@ -160,13 +352,39 @@ class LHLO_BinaryElementwiseOp, BASE_HLO_AddOp; +def LHLO_AddOp : LHLO_BinaryElementwiseOp<"add"> { + let summary = "Addition operator"; + let description = [{ + Returns `lhs + rhs` element-wise. -def LHLO_AndOp: LHLO_BinaryElementwiseOp<"and", LHLO_PredOrIntBuffer>, BASE_HLO_AndOp; + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} +def LHLO_AndOp: LHLO_BinaryElementwiseOp<"and", LHLO_PredOrIntBuffer> { + let summary = "Logical and"; + let description = [{ + Returns `logical_and(lhs, rhs)` element-wise. -def LHLO_Atan2Op : LHLO_BinaryElementwiseOp<"atan2", LHLO_FpOrComplexBuffer>, BASE_HLO_Atan2Op; + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} +def LHLO_Atan2Op : LHLO_BinaryElementwiseOp<"atan2", LHLO_FpOrComplexBuffer> { + let summary = "Atan2 operator"; + let description = [{ + Returns `atan2(lhs/rhs)` element-wise. -def LHLO_ComplexOp: LHLO_Op<"complex", [SameOperandsShape]>, BASE_HLO_ComplexOp { + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} +def LHLO_ComplexOp: LHLO_Op<"complex", [SameOperandsShape]> { + let summary = "Complex operator"; + let description = [{ + Performs element-wise conversion of a pair of real and imaginary values to + a complex value. + }]; let arguments = (ins Arg:$lhs, Arg:$rhs, @@ -175,30 +393,114 @@ def LHLO_ComplexOp: LHLO_Op<"complex", [SameOperandsShape]>, BASE_HLO_ComplexOp ); } -def LHLO_DivOp : LHLO_BinaryElementwiseOp<"divide">, BASE_HLO_DivOp; +def LHLO_DivOp : LHLO_BinaryElementwiseOp<"divide"> { + let summary = "Division operator"; + let description = [{ + Returns `lhs / rhs` element-wise. 
-def LHLO_MaxOp : LHLO_BinaryElementwiseOp<"maximum">, BASE_HLO_MaxOp; + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} +def LHLO_MaxOp : LHLO_BinaryElementwiseOp<"maximum"> { + let summary = "Maximum operator"; + let description = [{ + Returns `max(lhs, rhs)` element-wise. -def LHLO_MinOp : LHLO_BinaryElementwiseOp<"minimum">, BASE_HLO_MinOp; + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} +def LHLO_MinOp : LHLO_BinaryElementwiseOp<"minimum"> { + let summary = "Minimum operator"; + let description = [{ + Returns `min(lhs, rhs)` element-wise. -def LHLO_MulOp : LHLO_BinaryElementwiseOp<"multiply">, BASE_HLO_MulOp; + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} +def LHLO_MulOp : LHLO_BinaryElementwiseOp<"multiply"> { + let summary = "Multiplication operator"; + let description = [{ + Returns `lhs * rhs` element-wise. -def LHLO_OrOp : LHLO_BinaryElementwiseOp<"or", LHLO_PredOrIntBuffer>, BASE_HLO_OrOp; + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} +def LHLO_OrOp : LHLO_BinaryElementwiseOp<"or", LHLO_PredOrIntBuffer> { + let summary = "Logical or"; + let description = [{ + Returns `logical_or(lhs, rhs)` element-wise. -def LHLO_PowOp : LHLO_BinaryElementwiseOp<"power">, BASE_HLO_PowOp; + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} +def LHLO_PowOp : LHLO_BinaryElementwiseOp<"power"> { + let summary = "Power operator"; + let description = [{ + Returns `lhs ^ rhs` element-wise. -def LHLO_RemOp : LHLO_BinaryElementwiseOp<"remainder", LHLO_IntOrFpBuffer>, BASE_HLO_RemOp; + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. 
+ }]; +} +def LHLO_RemOp : LHLO_BinaryElementwiseOp<"remainder", LHLO_IntOrFpBuffer> { + let summary = "Remainder operator"; + let description = [{ + Returns `lhs % rhs` element-wise. -def LHLO_ShiftLeftOp : LHLO_BinaryElementwiseOp<"shift_left", LHLO_IntBuffer>, BASE_HLO_ShiftLeftOp; + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} +def LHLO_ShiftLeftOp : LHLO_BinaryElementwiseOp<"shift_left", LHLO_IntBuffer> { + let summary = "Shift Left operator"; + let description = [{ + Returns `lhs << rhs` element-wise. -def LHLO_ShiftRightArithmeticOp : LHLO_BinaryElementwiseOp<"shift_right_arithmetic", LHLO_IntBuffer>, BASE_HLO_ShiftRightArithmeticOp; + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} +def LHLO_ShiftRightArithmeticOp : LHLO_BinaryElementwiseOp<"shift_right_arithmetic", LHLO_IntBuffer> { + let summary = "Shift right arithmetic operator"; + let description = [{ + Returns arithmetic `lhs >> rhs` element-wise. -def LHLO_ShiftRightLogicalOp : LHLO_BinaryElementwiseOp<"shift_right_logical", LHLO_IntBuffer>, BASE_HLO_ShiftRightLogicalOp; + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} +def LHLO_ShiftRightLogicalOp : LHLO_BinaryElementwiseOp<"shift_right_logical", LHLO_IntBuffer> { + let summary = "Shift right logical operator"; + let description = [{ + Returns logical `lhs >> rhs` element-wise. -def LHLO_SubOp : LHLO_BinaryElementwiseOp<"subtract">, BASE_HLO_SubOp; + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} +def LHLO_SubOp : LHLO_BinaryElementwiseOp<"subtract"> { + let summary = "Subtraction operator"; + let description = [{ + Returns `lhs - rhs` element-wise. 
-def LHLO_XorOp : LHLO_BinaryElementwiseOp<"xor", LHLO_PredOrIntBuffer>, BASE_HLO_XorOp; + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} +def LHLO_XorOp : LHLO_BinaryElementwiseOp<"xor", LHLO_PredOrIntBuffer> { + let summary = "Logical xor"; + let description = [{ + Returns `logical_xor(lhs, rhs)` element-wise. + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} //===----------------------------------------------------------------------===// // LMHLO control flow op definitions. //===----------------------------------------------------------------------===// @@ -208,7 +510,14 @@ def LHLO_XorOp : LHLO_BinaryElementwiseOp<"xor", LHLO_PredOrIntBuffer>, BASE_HLO // The region `body` may return lmhlo.TerminatorOp or mhlo.ReturnOp. We are // moving towards mhlo.ReturnOp, but some code that needs cleanup still assumes lmhlo.TerminatorOp. // TODO(timshen): cleanup lmhlo.TerminatorOp. -def LHLO_ReduceOp: LHLO_Op<"reduce", [SameVariadicOperandSize]>, BASE_HLO_ReduceOp { +def LHLO_ReduceOp: LHLO_Op<"reduce", [SameVariadicOperandSize]> { + let summary = "Reduce operator"; + let description = [{ + Returns the result of executing a reduction function on one or more arrays + in parallel. + + See https://www.tensorflow.org/xla/operation_semantics#reduce. + }]; let arguments = (ins Arg, "", [MemRead]>:$inputs, Arg, "", [MemRead]>:$init_values, @@ -221,8 +530,14 @@ def LHLO_ReduceOp: LHLO_Op<"reduce", [SameVariadicOperandSize]>, BASE_HLO_Reduce let hasCanonicalizer = 1; } -def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [SameVariadicOperandSize]>, - BASE_HLO_ReduceWindowOp { +def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [SameVariadicOperandSize]> { + let summary = "ReduceWindow operator"; + let description = [{ + Returns the result of executing a reduction function over all elements in + each window of one or more arrays in parallel. 
+ + See https://www.tensorflow.org/xla/operation_semantics#reducewindow. + }]; let arguments = (ins Arg, "", [MemRead]>:$inputs, Arg, "", [MemRead]>:$init_values, @@ -244,8 +559,21 @@ def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [SameVariadicOperandSize]>, // TODO(timshen): Add a custom syntax for this. def LHLO_CaseOp: LHLO_Op<"case", [ SingleBlockImplicitTerminator<"TerminatorOp">, - DeclareOpInterfaceMethods]>, - BASE_HLO_CaseOp { + DeclareOpInterfaceMethods]> { + let summary = "Switch-Case operator"; + let description = [{ + Returns the result of executing `branches[index]`. If + `index` is < 0 or >= N, then `branches[N-1]` is executed as + the default branch. + + Each branch `branches[b]` must take in a single argument of same type as + `branch_operands[b]` and will be invoked with `branch_operands[b]`. The type + of the returned value of each branch must be the same. + + Note that only one of the branches will be executed depending on the value + of index. + See https://www.tensorflow.org/xla/operation_semantics#conditional. + }]; let arguments = (ins Arg:$index); @@ -255,8 +583,14 @@ def LHLO_CaseOp: LHLO_Op<"case", [ // TODO(timshen): Add a custom syntax for this. def LHLO_WhileOp: LHLO_Op<"while", [ DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods]>, - BASE_HLO_WhileOp { + DeclareOpInterfaceMethods]> { + let summary = "While operator"; + let description = [{ + Returns the result of executing a body function while the cond function returns + true. + + See https://www.tensorflow.org/xla/operation_semantics#while.
+ }]; let arguments = (ins Arg, "", [MemWrite]>:$cond_val, OptionalAttr:$trip_count); @@ -264,8 +598,21 @@ def LHLO_WhileOp: LHLO_Op<"while", [ let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body); } -def LHLO_CustomCallOp : LHLO_Op<"custom_call", [AttrSizedOperandSegments]>, - BASE_HLO_CustomCallOp { +def LHLO_CustomCallOp : LHLO_Op<"custom_call", [AttrSizedOperandSegments]> { + let summary = "CustomCall operator"; + let description = [{ + A custom call invokes code external to XLA. The `args` are passed to the + external code, and the external code is expected to produce a result of the + given type. The exact mechanism is backend-specific. For example, in the CPU + backend, a call instruction is emitted which targets a symbol with the name + `call_target_name`. + + `call_target_name` and `backend_config` can be arbitrary strings, but + `call_target_name` should be short as it may be used in labels. + `backend_config` can encode arbitrarily large amounts of information. + + See https://www.tensorflow.org/xla/operation_semantics#customcall. + }]; let arguments = (ins Arg, "", [MemRead]>:$args, Arg, "", [MemWrite]>:$output, @@ -281,7 +628,17 @@ def LHLO_CustomCallOp : LHLO_Op<"custom_call", [AttrSizedOperandSegments]>, // LMHLO tuple op definitions. //===----------------------------------------------------------------------===// -def LHLO_CompareOp: LHLO_Op<"compare", []>, BASE_HLO_CompareOp { +def LHLO_CompareOp: LHLO_Op<"compare", []> { + let summary = "Comparison operator"; + let description = [{ + Compares `lhs` and `rhs` elementwise according to `comparison_direction` + and `compare_type`. If unspecified, `compare_type` is FLOAT for float element + types, SIGNED for signed element types and UNSIGNED for unsigned element + types. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations. 
+ }]; let arguments = (ins Arg:$lhs, Arg:$rhs, @@ -309,7 +666,13 @@ def LHLO_SliceOp: LHLO_Op< } def LHLO_DynamicSliceOp: LHLO_Op<"dynamic_slice", - [AllElementTypesMatch<["operand", "output"]>]>, BASE_HLO_DynamicSliceOp { + [AllElementTypesMatch<["operand", "output"]>]> { + let summary = "Dynamic Slice operator"; + let description = [{ + Extracts a sub-array from the input array at dynamic start_indices. + + See https://www.tensorflow.org/xla/operation_semantics#dynamicslice. + }]; let arguments = (ins Arg:$operand, Arg, "", [MemRead]>:$start_indices, @@ -318,7 +681,14 @@ def LHLO_DynamicSliceOp: LHLO_Op<"dynamic_slice", ); } -def LHLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []>, BASE_HLO_DynamicUpdateSliceOp { +def LHLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []> { + let summary = "Dynamic Update Slice operator"; + let description = [{ + DynamicUpdateSlice generates a result which is the value of the input array + operand, with a slice update overwritten at start_indices. + + See https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice. + }]; let arguments = (ins Arg:$operand, Arg:$update, @@ -331,8 +701,13 @@ def LHLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []>, BASE_HLO_Dyn // LMHLO Other op definitions. //===----------------------------------------------------------------------===// -def LHLO_BatchNormGradOp : LHLO_Op<"batch_norm_grad", []>, - BASE_HLO_BatchNormGradOp { +def LHLO_BatchNormGradOp : LHLO_Op<"batch_norm_grad", []> { + let summary = "Batch Normalization Gradient"; + let description = [{ + Calculates gradients of batch norm. 
+ + See https://www.tensorflow.org/xla/operation_semantics#batchnormgrad + }]; let arguments = (ins Arg:$operand, @@ -349,8 +724,13 @@ def LHLO_BatchNormGradOp : LHLO_Op<"batch_norm_grad", []>, } -def LHLO_BatchNormInferenceOp : LHLO_Op<"batch_norm_inference", []>, - BASE_HLO_BatchNormInferenceOp { +def LHLO_BatchNormInferenceOp : LHLO_Op<"batch_norm_inference", []> { + let summary = "Batch Normalization for Inference"; + let description = [{ + Normalizes an array across batch and spatial dimensions. + + See https://www.tensorflow.org/xla/operation_semantics#batchnorminference + }]; let arguments = (ins Arg:$operand, @@ -364,8 +744,13 @@ def LHLO_BatchNormInferenceOp : LHLO_Op<"batch_norm_inference", []>, ); } -def LHLO_BatchNormTrainingOp : LHLO_Op<"batch_norm_training", []>, - BASE_HLO_BatchNormTrainingOp { +def LHLO_BatchNormTrainingOp : LHLO_Op<"batch_norm_training", []> { + let summary = "Batch Normalization for Training"; + let description = [{ + Normalizes an array across batch and spatial dimensions. + + See https://www.tensorflow.org/xla/operation_semantics#batchnormtraining + }]; let arguments = (ins Arg:$operand, @@ -380,7 +765,19 @@ def LHLO_BatchNormTrainingOp : LHLO_Op<"batch_norm_training", []>, } def LHLO_BroadcastOp : LHLO_Op<"broadcast", - []>, BASE_HLO_BroadcastOp { + []> { + let summary = "Broadcast a tensor to a higher rank by prepending dimensions"; + let description = [{ + Broadcasts the operand tensor to a higher rank by prepending + `broadcast_sizes` to the dimensions. The current values of the operand are + copied into the other dimensions. + + This is a more limited form of broadcasting, that corresponds to the XLA + client Broadcast method. For a more general form of broadcasting, see the + BroadcastInDimOp. + + See https://www.tensorflow.org/xla/operation_semantics#broadcast. 
+ }]; let arguments = (ins Arg:$operand, Arg:$output, @@ -389,7 +786,24 @@ def LHLO_BroadcastOp : LHLO_Op<"broadcast", } def LHLO_BroadcastInDimOp : LHLO_Op<"broadcast_in_dim", - []>, BASE_HLO_BroadcastInDimOp { + []> { + let summary = "Broadcast a tensor into the given shape by adding dimensions."; + let description = [{ + Broadcasts the `operand` tensor to a higher rank. This is not the limited + form of broadcasting exposed as the XLA client broadcast op, but rather the + more powerful "InDim" broadcasting, which is closer to the HLO broadcast op + and exposed in the XLA client BroadcastInDim method. + + `broadcast_dimensions` maps the operand dimension number to the target shape + dimension number. It must have the same size as the rank of the operand. The + mapped dimensions must either be the same size or the dimension being + broadcast from must be size 1 (degenerate broadcasting). + + For a scalar (0D tensor) operand, `broadcast_dimensions` must be empty. + The scalar value will be broadcast to every element in the target shape. + + See https://www.tensorflow.org/xla/broadcasting. + }]; let arguments = (ins Arg:$operand, Arg:$output, @@ -397,7 +811,17 @@ def LHLO_BroadcastInDimOp : LHLO_Op<"broadcast_in_dim", ); } -def LHLO_ClampOp : LHLO_Op<"clamp", []>, BASE_HLO_ClampOp { +def LHLO_ClampOp : LHLO_Op<"clamp", []> { + let summary = "Clamp operator"; + let description = [{ + Clamps an operand to within the range between a minimum and maximum value. + + Note: All three arrays must be the same shape. Alternatively, as a + restricted form of broadcasting, min and/or max can be a scalar (0D + tensor) of the element type of the tensor operand. + + See https://www.tensorflow.org/xla/operation_semantics#clamp.
+ }]; let arguments = (ins Arg:$min, Arg:$operand, @@ -406,7 +830,13 @@ def LHLO_ClampOp : LHLO_Op<"clamp", []>, BASE_HLO_ClampOp { ); } -def LHLO_ConcatenateOp : LHLO_Op<"concatenate", []>, BASE_HLO_ConcatenateOp { +def LHLO_ConcatenateOp : LHLO_Op<"concatenate", []> { + let summary = "XLA's concatenate op"; + let description = [{ + Concatenates a set of tensors along the specified dimension. + + See https://www.tensorflow.org/xla/operation_semantics#concatenate. + }]; let arguments = (ins Arg, "", [MemRead]>:$val, Arg:$output, @@ -423,7 +853,11 @@ def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp { ConvolutionAttributes.attributes); } -def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]>, BASE_HLO_CopyOp { +def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]> { + let summary = "Copy operator"; + let description = [{ + Returns a copy of `operand`. + }]; let arguments = (ins Arg:$operand, Arg:$output @@ -435,7 +869,14 @@ def LHLO_CopyOp: LHLO_Op<"copy", [CopyOpInterface]>, BASE_HLO_CopyOp { }]; } -def LHLO_DotOp: LHLO_Op<"dot", []>, BASE_HLO_DotOp { +def LHLO_DotOp: LHLO_Op<"dot", []> { + let summary = "Dot operator"; + let description = [{ + Performs dot products between vectors, vector/matrix and matrix/matrix + multiplication. + + See https://www.tensorflow.org/xla/operation_semantics#dot. + }]; let arguments = (ins Arg:$lhs, Arg:$rhs, @@ -445,7 +886,7 @@ def LHLO_DotOp: LHLO_Op<"dot", []>, BASE_HLO_DotOp { ); } -def LHLO_GatherOp: LHLO_Op<"gather", []>, BASE_HLO_GatherOp { +def LHLO_GatherOp: LHLO_Op<"gather", []> { let arguments = (ins Arg:$operand, Arg:$start_indices, @@ -455,14 +896,28 @@ def LHLO_GatherOp: LHLO_Op<"gather", []>, BASE_HLO_GatherOp { ); } -def LHLO_ReshapeOp: LHLO_Op<"reshape", []>, BASE_HLO_ReshapeOp { +def LHLO_ReshapeOp: LHLO_Op<"reshape", []> { + let summary = "Reshape operator"; + let description = [{ + Reshapes the dimensions of `operand` into a new configuration. 
+ + See https://www.tensorflow.org/xla/operation_semantics#reshape. + }]; let arguments = (ins Arg:$operand, Arg:$output ); } -def LHLO_ScatterOp: LHLO_Op<"scatter", []>, BASE_HLO_ScatterOp { +def LHLO_ScatterOp: LHLO_Op<"scatter", []> { + let summary = "Scatter operator"; + let description = [{ + Generates a result which is the value of the input array `operand`, + with several slices (at indices specified by `scatter_indices`) + updated with the values in `updates` using `update_computation`. + + See https://www.tensorflow.org/xla/operation_semantics#scatter. + }]; let arguments = (ins Arg:$operand, Arg:$scatter_indices, @@ -476,7 +931,14 @@ def LHLO_ScatterOp: LHLO_Op<"scatter", []>, BASE_HLO_ScatterOp { let regions = (region SizedRegion<1>:$update_computation); } -def LHLO_SelectOp: LHLO_Op<"select", []>, BASE_HLO_SelectOp { +def LHLO_SelectOp: LHLO_Op<"select", []> { + let summary = "Select operator"; + let description = [{ + Constructs an output tensor from the elements of `on_true` and `on_false` + based on the values of `pred`. + + `pred`, `on_true` and `on_false` must be broadcast compatible. + }]; let arguments = (ins Arg:$pred, Arg:$on_true, @@ -485,8 +947,18 @@ def LHLO_SelectOp: LHLO_Op<"select", []>, BASE_HLO_SelectOp { ); } -def LHLO_SelectAndScatterOp: LHLO_Op<"select_and_scatter", []>, - BASE_HLO_SelectAndScatterOp { +def LHLO_SelectAndScatterOp: LHLO_Op<"select_and_scatter", []> { + let summary = "SelectAndScatter operator"; + let description = [{ + Runs a windowed selection `select` function over `operand` with shape + `window_dimensions` and stride `window_strides`. This will produce an amount + of selected locations whose shape matches `source`. These are then scattered + to the output which is initialized with `init_value`. + Multiple scattered elements which land in the same output location are + combined using the `scatter` function. + + See https://www.tensorflow.org/xla/operation_semantics#selectandscatter. 
+ }]; let arguments = (ins Arg:$operand, Arg:$source, @@ -500,7 +972,14 @@ def LHLO_SelectAndScatterOp: LHLO_Op<"select_and_scatter", []>, let regions = (region SizedRegion<1>:$select, SizedRegion<1>:$scatter); } -def LHLO_ReverseOp: LHLO_Op<"reverse", []>, BASE_HLO_ReverseOp { +def LHLO_ReverseOp: LHLO_Op<"reverse", []> { + let summary = "Reverse operator"; + let description = [{ + Reverses the specified dimensions of `operand` according to the given + `dimensions`. + + See https://www.tensorflow.org/xla/operation_semantics#rev_reverse. + }]; let arguments = (ins Arg:$operand, I64ElementsAttr:$dimensions, @@ -508,7 +987,14 @@ def LHLO_ReverseOp: LHLO_Op<"reverse", []>, BASE_HLO_ReverseOp { ); } -def LHLO_PadOp: LHLO_Op<"pad", []>, BASE_HLO_PadOp { +def LHLO_PadOp: LHLO_Op<"pad", []> { + let summary = "Pad operator"; + let description = [{ + Pads the edges of `operand` with the `padding_value` and according to + the passed configuration. + + See https://www.tensorflow.org/xla/operation_semantics#pad. + }]; let arguments = (ins Arg:$operand, Arg:$padding_value, @@ -519,7 +1005,15 @@ def LHLO_PadOp: LHLO_Op<"pad", []>, BASE_HLO_PadOp { ); } -def LHLO_TransposeOp: LHLO_Op<"transpose", []>, BASE_HLO_TransposeOp { +def LHLO_TransposeOp: LHLO_Op<"transpose", []> { + let summary = "Transpose operator"; + let description = [{ + Permutes the dimensions of `operand` according to the given `permutation`. + + `res_dimensions[i] = operand_dimensions[permutation[i]]` + + See https://www.tensorflow.org/xla/operation_semantics#transpose. 
+ }]; let arguments = (ins Arg:$operand, I64ElementsAttr:$permutation, @@ -527,8 +1021,18 @@ def LHLO_TransposeOp: LHLO_Op<"transpose", []>, BASE_HLO_TransposeOp { ); } -def LHLO_ReducePrecisionOp: LHLO_Op<"reduce_precision", [SameTypeOperands]>, - BASE_HLO_ReducePrecisionOp { +def LHLO_ReducePrecisionOp: LHLO_Op<"reduce_precision", [SameTypeOperands]> { + let summary = "Reduce precision operator"; + let description = [{ + Models the effect of converting floating-point values to a lower- + precision format (such as IEEE-FP16) and back to the original + format. The number of exponent and mantissa bits in the lower- + precision format can be specified arbitrarily, + although all bit sizes may not be supported on all hardware + implementations. + + See https://www.tensorflow.org/xla/operation_semantics#reduceprecision. + }]; let arguments = (ins Arg:$operand, Arg:$output, @@ -555,28 +1059,49 @@ class LHLO_CollectiveCommunicationOp traits = []> : }]; } -def LHLO_AllGatherOp : LHLO_CollectiveCommunicationOp<"all_gather">, - BASE_HLO_AllGatherOp { +def LHLO_AllGatherOp : LHLO_CollectiveCommunicationOp<"all_gather"> { + let summary = "AllGather operator"; + let description = [{ + Performs concatenation across replicas. + + See https://www.tensorflow.org/xla/operation_semantics#allgather + }]; let arguments = !con( arguments_base, (ins I64Attr:$all_gather_dimension)); } -def LHLO_AllReduceOp : LHLO_CollectiveCommunicationOp<"all_reduce", [SameOperandsElementType]>, - BASE_HLO_AllReduceOp { +def LHLO_AllReduceOp : LHLO_CollectiveCommunicationOp<"all_reduce", [SameOperandsElementType]> { + let summary = "AllReduce operator"; + let description = [{ + Performs a custom reduction across replicas. + + See https://www.tensorflow.org/xla/operation_semantics#allreduce.
+ }]; let arguments = arguments_base; let regions = (region SizedRegion<1>:$computation); } -def LHLO_AllToAllOp : LHLO_CollectiveCommunicationOp<"all_to_all", [SameOperandsElementType]>, - BASE_HLO_AllToAllOp { +def LHLO_AllToAllOp : LHLO_CollectiveCommunicationOp<"all_to_all", [SameOperandsElementType]> { let arguments = !con( arguments_base, (ins OptionalAttr:$split_dimension)); } -def LHLO_CollectivePermuteOp: LHLO_Op<"collective_permute", [SameTypeOperands]>, - BASE_HLO_CollectivePermuteOp { +def LHLO_CollectivePermuteOp: LHLO_Op<"collective_permute", [SameTypeOperands]> { + let summary = "CollectivePermute operator"; + let description = [{ + CollectivePermute is a collective operation that sends and receives data + across replicas. + Note that there are the following restrictions on the source_target_pair: + - Any two pairs should not have the same target replica id, and they should + not have the same source replica id. + - If a replica id is not a target in any pair, then the output on that + replica is a tensor that consists of 0(s) with the same shape as the input. + + See https://www.tensorflow.org/xla/operation_semantics#collectivepermute. + + }]; let arguments = (ins Arg:$operand, @@ -587,7 +1112,14 @@ def LHLO_CollectivePermuteOp: LHLO_Op<"collective_permute", [SameTypeOperands]>, let verifier = [{ return Verify(*this); }]; } -def LHLO_FftOp: LHLO_Op<"fft", []>, BASE_HLO_FftOp { +def LHLO_FftOp: LHLO_Op<"fft", []> { + let summary = "Fast Fourier transform operator"; + let description = [{ + Returns the fast-Fourier-transform of the input array. + + See + https://www.tensorflow.org/xla/operation_semantics#fft.
+ }]; let arguments = (ins Arg:$operand, Arg:$output, @@ -596,7 +1128,29 @@ def LHLO_FftOp: LHLO_Op<"fft", []>, BASE_HLO_FftOp { ); } -def LHLO_CholeskyOp: LHLO_Op<"cholesky", [SameOperandsElementType]>, BASE_HLO_CholeskyOp { +def LHLO_CholeskyOp: LHLO_Op<"cholesky", [SameOperandsElementType]> { + let summary = "Cholesky operator"; + let description = [{ + Computes the Cholesky decomposition of a batch of symmetric (Hermitian) + positive definite matrices. + + If lower is true, computes lower-triangular matrices l such that + `a=l.Transpose(l)`. If lower is false, computes upper-triangular matrices u such + that `a=Transpose(u).u`. + + Input data is read only from the lower/upper triangle of a, depending on the + value of lower. Values from the other triangle are ignored. Output data is + returned in the same triangle; the values in the other triangle are + implementation-defined and may be anything. + + If the rank of a is greater than 2, a is treated as a batch of matrices, where + all except the minor 2 dimensions are batch dimensions. + + If a is not symmetric (Hermitian) positive definite, the result is + implementation-defined. + + See https://www.tensorflow.org/xla/operation_semantics#cholesky. + }]; let arguments = (ins Arg:$a, Arg:$output, @@ -604,7 +1158,18 @@ def LHLO_CholeskyOp: LHLO_Op<"cholesky", [SameOperandsElementType]>, BASE_HLO_Ch ); } -def LHLO_InfeedOp: LHLO_Op<"infeed", []>, BASE_HLO_InfeedOp { +def LHLO_InfeedOp: LHLO_Op<"infeed", []> { + let summary = "Infeed operator"; + let description = [{ + Reads a single data item from the implicit Infeed streaming interface of + the device, interpreting the data as the given shape and its layout, and + returns an LHLO op of the data. Multiple Infeed operations are allowed in a + computation, but there must be a total order among the Infeed operations. + For example, two Infeeds in the code below have a total order since there + is a dependency between the while loops. 
+ + See https://www.tensorflow.org/xla/operation_semantics#infeed + }]; let arguments = (ins Arg, "", [MemWrite]>:$outputs, DefaultValuedAttr:$config @@ -618,16 +1183,50 @@ def LHLO_OutfeedOp: LHLO_Op<"outfeed", []> { ); } -def LHLO_ReplicaIdOp : LHLO_Op<"replica_id", []>, BASE_HLO_ReplicaIdOp { +def LHLO_ReplicaIdOp : LHLO_Op<"replica_id", []> { + let summary = "ReplicaId operator"; + let description = [{ + Returns the unique ID (int32 scalar) of the replica. + + The unique ID of each replica is an unsigned integer in the interval [0, N), + where N is the number of replicas. Since all the replicas are running the + same program, a ReplicaId() call in the program will return a different + value on each replica. + + See https://www.tensorflow.org/xla/operation_semantics#replicaid. + }]; let arguments = (ins Arg, "", [MemWrite]>); } -def LHLO_PartitionIdOp : LHLO_Op<"partition_id", []>, BASE_HLO_PartitionIdOp { +def LHLO_PartitionIdOp : LHLO_Op<"partition_id", []> { + let summary = "PartitionId operator"; + let description = [{ + Returns the unique ID (int32 scalar) of the partition. + }]; let arguments = (ins Arg, "", [MemWrite]>); } -def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType]>, - BASE_HLO_TriangularSolveOp { +def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType]> { + let summary = "TriangularSolve operator"; + let description = [{ + Solves systems of linear equations with lower or upper triangular + coefficient matrices by forward- or back-substitution. Broadcasting along + leading dimensions, this routine solves one of the matrix systems + op(a) * x = b, or x * op(a) = b, for the variable x, given a and b, where + op(a) is either op(a) = a, or op(a) = Transpose(a), or + op(a) = Conj(Transpose(a)). + + Input data is read only from the lower/upper triangle of a, depending on the + value of lower. Values from the other triangle are ignored. 
Output data is + returned in the same triangle; the values in the other triangle are + implementation-defined and may be anything. + + If the rank of a and b are greater than 2, they are treated as batches of + matrices, where all except the minor 2 dimensions are batch dimensions. a + and b must have equal batch dimensions. + + See https://www.tensorflow.org/xla/operation_semantics#triangularsolve. + }]; let arguments = (ins Arg:$a, Arg:$b, @@ -643,7 +1242,20 @@ def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType } // TODO(timshen): add a custom verifier. -def LHLO_MapOp: LHLO_Op<"map", [SameOperandsShape]>, BASE_HLO_MapOp { +def LHLO_MapOp: LHLO_Op<"map", [SameOperandsShape]> { + let summary = "Map operator"; + let description = [{ + Applies a scalar function over the given operands arrays, producing an array + of the same dimensions where each element is the result of the mapped function + applied to the corresponding elements in the input arrays. + + The mapped function is an arbitrary computation with the restriction that it + has N inputs of scalar type T and a single output with type S. The output has + the same dimensions as the operands except that the element type T is replaced + with S. + + See https://www.tensorflow.org/xla/operation_semantics#map. + }]; let arguments = (ins Arg, "", [MemRead]>:$operands, Arg:$output, @@ -660,7 +1272,14 @@ def LHLO_RngGetAndUpdateStateOp: LHLO_Op<"rng_get_and_update_state", []> { } // TODO(timshen): add a custom verifier. -def LHLO_SortOp: LHLO_Op<"sort", [SameVariadicOperandSize, SameOperandsShape]>, BASE_HLO_SortOp { +def LHLO_SortOp: LHLO_Op<"sort", [SameVariadicOperandSize, SameOperandsShape]> { + let summary = "Sort operator"; + let description = [{ + Sorts the given `operands` at the given `dimension` with the given + `comparator`. + + See https://www.tensorflow.org/xla/operation_semantics#sort. + }]; let arguments = (ins Arg, "", [MemRead]>:$operands, Arg, "", [MemWrite]>:$output,