diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h new file mode 100644 index 0000000..222b808 --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h @@ -0,0 +1,45 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_CHLO_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_IR_CHLO_OPS_H_ + +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/DialectImplementation.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpDefinition.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.h" + +namespace mlir { +namespace xla_chlo { + +class XlaHloClientDialect : public Dialect { + public: + explicit XlaHloClientDialect(MLIRContext *context); + static StringRef 
getDialectNamespace() { return "xla_chlo"; } +}; + +#define GET_OP_CLASSES +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h.inc" + +} // namespace xla_chlo +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_CHLO_OPS_H_ diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td new file mode 100644 index 0000000..4cf48c6 --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td @@ -0,0 +1,370 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Defines "client" aligned HLO ops. +// These ops are not necessarily orthogonal or optimized for transformation but +// for ease of expression in certain cases deemed important for client +// libraries (i.e. implicit broadcasting, helper ops, etc). +// This dialect is considered to exist alongside and to augment the xla_hlo +// dialect for ergonomic needs, not to duplicate/replace it. +// +// The typical use of this dialect is for client libraries to be able to emit +// less constrained ops and rely on the conversion framework to lower any +// xla_chlo ops to canonical xla_hlo ops. 
+// +// See: https://www.tensorflow.org/xla/operation_semantics + +#ifndef CHLO_OPS +#define CHLO_OPS + +include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td" +include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.td" +include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.td" +include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" + +def HLOClient_Dialect : Dialect { + let name = "xla_chlo"; + let cppNamespace = "xla_chlo"; + let summary = [{ + XLA Client HLO Ops + }]; + + let description = [{ + This dialect contains ops that align closely with the API surface area + of the XlaBuilder C++ API, where such ops have semantics that go beyond + what exists in the lower level dialects (such as `xla_hlo`). Essentially, + whenever the client library uses syntactic sugar or composition + of multiple ops for an API call, this dialect tries to model the API call + and provide conversion patterns to fully materialize into lower level + dialects. + }]; +} + +class HLOClient_Op traits> : + Op { + // TODO(b/129012527) Much of this custom verification should be expressed as + // type constraints. + let verifier = [{ return Verify(*this); }]; +} + +//===----------------------------------------------------------------------===// +// XLA binary elementwise op definitions. +// From the client perspective, each of these support both explicit rank +// broadcasting (via the broadcast_dimensions attribute) and implicit degenerate +// shape broadcasting. +// +// These correspond to operations in the xla_hlo dialect without the +// "broadcast_" prefix, except that those ops require same-shaped operands and +// results. 
+// +// See: +// https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations +// https://www.tensorflow.org/xla/broadcasting +//===----------------------------------------------------------------------===// + +class HLOClient_BroadcastBinaryElementwiseOp< + string mnemonic, list traits> : + HLOClient_Op])> { + let arguments = (ins + HLO_Tensor:$lhs, + HLO_Tensor:$rhs, + // Explicit rank-broadcast dimension mappings. Defaults to "numpy" prefix + // padded rank-broadcast semantics if omitted. + OptionalAttr:$broadcast_dimensions + ); + + let builders = [OpBuilder< + "OpBuilder &builder, OperationState &result, Value left, Value right, " + "DenseIntElementsAttr broadcast_dimensions" + >]; + + let results = (outs HLO_Tensor); + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` + `(` type($lhs) `,` type($rhs) `)` `->` type(results) + }]; + + let extraClassDeclaration = [{ + // TODO(laurenzo): It isn't clear to me why reifyReturnShapes does not + // have its declaration generated by DeclareOpInterfaceMethods. + LogicalResult reifyReturnTypeShapes( + OpBuilder& builder, SmallVectorImpl& reifiedReturnShapes); + }]; +} + +def HLOClient_BroadcastAddOp : HLOClient_BroadcastBinaryElementwiseOp<"broadcast_add", + [Commutative, NoSideEffect, SameOperandsAndResultElementType]> { + string summary = "Addition operator (with optional broadcasting)"; + + string description = [{ + Returns `lhs + rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +def HLOClient_BroadcastAtan2Op : HLOClient_BroadcastBinaryElementwiseOp< + "broadcast_atan2", + [NoSideEffect, SameOperandsAndResultElementType]> { + string summary = "Atan2 operator (with optional broadcasting)"; + + string description = [{ + Returns `atan2(lhs/rhs)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. 
+ }]; +} + +def HLOClient_BroadcastDivOp : HLOClient_BroadcastBinaryElementwiseOp< + "broadcast_divide", + [NoSideEffect, SameOperandsAndResultElementType]> { + string summary = "Division operator (with optional broadcasting)"; + + string description = [{ + Returns `lhs / rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +def HLOClient_BroadcastMaxOp : HLOClient_BroadcastBinaryElementwiseOp< + "broadcast_maximum", + [Commutative, NoSideEffect, SameOperandsAndResultElementType]> { + string summary = "Maximum operator (with optional broadcasting)"; + + string description = [{ + Returns `max(lhs, rhs)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +def HLOClient_BroadcastMinOp : HLOClient_BroadcastBinaryElementwiseOp< + "broadcast_minimum", + [Commutative, NoSideEffect, SameOperandsAndResultElementType]> { + string summary = "Minimum operator (with optional broadcasting)"; + + string description = [{ + Returns `min(lhs, rhs)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +def HLOClient_BroadcastMulOp : HLOClient_BroadcastBinaryElementwiseOp< + "broadcast_multiply", + [Commutative, NoSideEffect, SameOperandsAndResultElementType]> { + string summary = "Multiplication operator (with optional broadcasting)"; + + string description = [{ + Returns `lhs * rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +def HLOClient_BroadcastPowOp : HLOClient_BroadcastBinaryElementwiseOp< + "broadcast_power", + [NoSideEffect, SameOperandsAndResultElementType]> { + string summary = "Power operator (with optional broadcasting)"; + + string description = [{ + Returns `lhs ^ rhs` element-wise. 
+ + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +def HLOClient_BroadcastRemOp : HLOClient_BroadcastBinaryElementwiseOp< + "broadcast_remainder", + [NoSideEffect, SameOperandsAndResultElementType]> { + string summary = "Remainder operator (with optional broadcasting)"; + + string description = [{ + Returns `lhs % rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +def HLOClient_BroadcastShiftLeftOp : HLOClient_BroadcastBinaryElementwiseOp< + "broadcast_shift_left", + [NoSideEffect, SameOperandsAndResultElementType]> { + string summary = "Shift left operator (with optional broadcasting)"; + + string description = [{ + Returns `lhs << rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +def HLOClient_BroadcastShiftRightArithmeticOp : HLOClient_BroadcastBinaryElementwiseOp< + "broadcast_shift_right_arithmetic", + [NoSideEffect, SameOperandsAndResultElementType]> { + string summary = "Shift right arithmetic operator (with optional broadcasting)"; + + string description = [{ + Returns `lhs >> rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +def HLOClient_BroadcastShiftRightLogicalOp : HLOClient_BroadcastBinaryElementwiseOp< + "broadcast_shift_right_logical", + [NoSideEffect, SameOperandsAndResultElementType]> { + string summary = "Shift right logical operator (with optional broadcasting)"; + + string description = [{ + Returns `lhs >> rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. 
+ }]; +} + +def HLOClient_BroadcastSubOp : HLOClient_BroadcastBinaryElementwiseOp< + "broadcast_subtract", + [NoSideEffect, SameOperandsAndResultElementType]> { + string summary = "Subtraction operator (with optional broadcasting)"; + + string description = [{ + Returns `lhs - rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +//===----------------------------------------------------------------------===// +// XLA binary elementwise op definitions. +// The same description as the arithmetic binary elementwise ops applies. +//===----------------------------------------------------------------------===// + +class HLOClient_BroadcastBinaryLogicalElementwiseOp : + HLOClient_BroadcastBinaryElementwiseOp< + mnemonic, [Commutative, NoSideEffect]> { + let arguments = (ins + HLO_PredOrIntTensor:$lhs, + HLO_PredOrIntTensor:$rhs, + // Explicit rank-broadcast dimension mappings. Defaults to "numpy" prefix + // padded rank-broadcast semantics if omitted. + OptionalAttr:$broadcast_dimensions + ); +} + +def HLOClient_BroadcastAndOp: HLOClient_BroadcastBinaryLogicalElementwiseOp< + "broadcast_and"> { + string summary = "Logical and operator (with optional broadcasting)"; + + string description = [{ + Returns `logical_and(lhs, rhs)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +def HLOClient_BroadcastOrOp: HLOClient_BroadcastBinaryLogicalElementwiseOp< + "broadcast_or"> { + string summary = "Logical or operator (with optional broadcasting)"; + + string description = [{ + Returns `logical_or(lhs, rhs)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. 
+ }]; +} + +def HLOClient_BroadcastXorOp : HLOClient_BroadcastBinaryLogicalElementwiseOp< + "broadcast_xor"> { + string summary = "Logical xor operator (with optional broadcasting)"; + + string description = [{ + Returns `logical_xor(lhs, rhs)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +//===----------------------------------------------------------------------===// +// Broadcasting complex op +//===----------------------------------------------------------------------===// + +def HLOClient_BroadcastComplexOp : HLOClient_BroadcastBinaryElementwiseOp< + "broadcast_complex", [NoSideEffect]> { + string summary = "Complex operator (with optional broadcasting)"; + + string description = [{ + Performs element-wise conversion of a pair of real and imaginary values to + a complex value. + }]; + + let arguments = (ins + HLO_FpTensor:$lhs, + HLO_FpTensor:$rhs, + // Explicit rank-broadcast dimension mappings. Defaults to "numpy" prefix + // padded rank-broadcast semantics if omitted. + OptionalAttr:$broadcast_dimensions + ); + let results = (outs HLO_ComplexTensor); +} + +//===----------------------------------------------------------------------===// +// Broadcasting compare op +//===----------------------------------------------------------------------===// + +def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp< + "broadcast_compare", [NoSideEffect]> { + string summary = "Compare operator (with optional broadcasting)"; + + string description = [{ + Compares `lhs` and `rhs` elementwise according to `comparison_direction`. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations. 
+ }]; + + let arguments = (ins + HLO_Tensor:$lhs, + HLO_Tensor:$rhs, + OptionalAttr:$broadcast_dimensions, + HLO_ComparisonDirectionAttr:$comparison_direction + ); + let results = (outs HLO_PredTensor); + + let builders = [OpBuilder< + "OpBuilder &builder, OperationState &result, Value lhs, Value rhs, " + "DenseIntElementsAttr broadcast_dimensions, StringAttr comparison_direction" + >]; +} + +#endif // CHLO_OPS diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h new file mode 100644 index 0000000..b6360b7 --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h @@ -0,0 +1,99 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the operations used in the XLA dialect. 
+ +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_OPS_H_ + +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/DialectImplementation.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpDefinition.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h" + +namespace mlir { +class OpBuilder; + +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.h.inc" + +namespace xla_hlo { + +class XlaHloDialect : public Dialect { + public: + explicit XlaHloDialect(MLIRContext *context); + static StringRef getDialectNamespace() { return "xla_hlo"; } + + // Registered hook to materialize a constant operation from a given attribute + // value with the desired resultant type. + Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type, + Location loc) override; + + // Parses a type registered to this dialect. + Type parseType(DialectAsmParser &parser) const override; + + // Prints a type registered to this dialect. 
+ void printType(Type type, DialectAsmPrinter &os) const override; +}; + +namespace HLOTypes { +enum Kind { + Token = Type::FIRST_XLA_HLO_TYPE, +}; +} // namespace HLOTypes + +class TokenType : public Type::TypeBase { + public: + using Base::Base; + + static TokenType get(MLIRContext *context) { + return Base::get(context, HLOTypes::Token); + } + + // Support method to enable LLVM-style type casting. + static bool kindof(unsigned kind) { return kind == HLOTypes::Token; } +}; + +// Shape derivation function that computes the shape of the result based on +// the first argument. For a 2-dimensional input tensor, this produces IR of +// the form +// +// %0 = dim %arg0, 0 : memref +// %1 = index_cast %0 : index to i64 +// %2 = dim %arg0, 1 : memref +// %3 = index_cast %2 : index to i64 +// %4 = "xla_hlo.scalars_to_dimension_tensor"(%1, %3) +// : (i64, i64) -> tensor<2xi64> +// +// and returns %4 as the shape value. +LogicalResult deriveShapeFromFirstOperand( + OpBuilder *builder, Operation *op, + SmallVectorImpl *reifiedReturnShapes); + +#define GET_OP_CLASSES +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc" + +} // end namespace xla_hlo +} // end namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_OPS_H_ diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td new file mode 100644 index 0000000..97a10d9 --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td @@ -0,0 +1,1390 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is the operation definition file for XLA HLO ops which map to the +// traditional definition in xla_data.proto (or are aligned with the goals +// thereof). +// See: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/xla_data.proto + +#ifndef HLO_OPS +#define HLO_OPS + +include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td" +include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.td" +include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.td" +include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" +include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td" +include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td" + +def HLO_Dialect : Dialect { + let name = "xla_hlo"; + let cppNamespace = "xla_hlo"; +} + +class HLO_Op traits> : + Op { + // Whether this operation has a custom conversion to HLO or not. + bit hasCustomHLOConverter = 0b0; + + // TODO(b/129012527) Much of this custom verification should be expressed as + // type constraints. + let verifier = [{ return Verify(*this); }]; +} + +//===----------------------------------------------------------------------===// +// XLA nullary op definitions. 
+//===----------------------------------------------------------------------===// + +def HLO_ConstOp : HLO_Op<"constant", + [ConstantLike, NoSideEffect, AllTypesMatch<["value", "output"]>]>, + BASE_HLO_ConstOp { + let arguments = (ins + ElementsAttr:$value + ); + + let results = (outs + HLO_Tensor:$output + ); + + let builders = [OpBuilder< + "OpBuilder &builder, OperationState &result, Attribute value" + >]; + + let printer = [{ return Print(*this, &p); }]; + let parser = [{ return ParseConstOp(&parser, &result); }]; + + let hasFolder = 1; + + // Constant has special conversion logic to HLO. + let hasCustomHLOConverter = 1; +} + +def HLO_IotaOp : HLO_Op<"iota", [NoSideEffect]>, BASE_HLO_IotaOp { + let arguments = (ins I64Attr:$iota_dimension); + + let results = (outs HLO_IntFpOrComplexTensor:$output); + + // TODO(b/130357376): Iota has special conversion logic to HLO. + let hasCustomHLOConverter = 1; +} + +def HLO_DynamicIotaOp: HLO_Op<"dynamic_iota", [NoSideEffect]> { + let summary = "Create linear increasing values from 0 to length -1."; + let description = [{ + Produces an HLO Tensor of the specified shape, with an incremental set of + values along the specified dimension starting at 0. + + Requires: + - The output length of the tensor result. + }]; + + let arguments = (ins HLO_DimensionTensor:$output_shape, I64Attr:$iota_dimension); + let results = (outs HLO_Tensor:$result); + + let hasCanonicalizer = 1; + // Cannot be exported to legacy formats. + let hasCustomHLOConverter = 1; +} + + +def HLO_CreateTokenOp : HLO_Op<"create_token", [NoSideEffect]> { + string summary = "Create Token operator"; + + string description = [{ + Produces a HLO token. Tokens are used for ordering side-effecting perations. + This is exported to HLO as an AfterAll operation with no operands to + generate a token. 
+ }]; + + let results = (outs HLO_Token:$output); +} + +//===----------------------------------------------------------------------===// +// XLA unary elementwise op definitions. +//===----------------------------------------------------------------------===// +// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions + +class HLO_UnaryElementwiseOp traits, + Type TensorType>: HLO_Op { + let arguments = (ins TensorType:$operand); + let results = (outs TensorType); + let extraClassDeclaration = [{ + static LogicalResult inferReturnTypeComponents( + MLIRContext* context, Optional location, + ValueRange operands, DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl& inferedReturnShapes) { + return failure(); + } + LogicalResult reifyReturnTypeShapes( + OpBuilder& builder, SmallVectorImpl& reifiedReturnShapes) { + return deriveShapeFromFirstOperand(&builder, getOperation(), + &reifiedReturnShapes); + } + bool inferInputOutputShapeEquality(int input, int output) { + return true; + } + llvm::Optional inferEffectiveWorkloadShape() { + return getOperation()->getResult(0); + } + }]; +} + +// Abs supports complex to real, so element type is not guaranteed to match. 
+def HLO_AbsOp: HLO_UnaryElementwiseOp<"abs", + [NoSideEffect, SameOperandsAndResultShape], + TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>, BASE_HLO_AbsOp { + let builders = [OpBuilder< + "OpBuilder &builder, OperationState &result, Value operand" + >]; +} + +def HLO_CeilOp: HLO_UnaryElementwiseOp<"ceil", + [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_CeilOp; + +def HLO_ConvertOp : HLO_UnaryElementwiseOp< + "convert", [NoSideEffect, SameOperandsAndResultShape], HLO_Tensor>, + BASE_HLO_ConvertOp { + + let builders = [OpBuilder< + "OpBuilder &, OperationState &tblgen_state, Value operand, " + "Type result_element_ty" + >]; + + let hasFolder = 1; + + let hasCustomHLOConverter = 1; +} + +def HLO_ClzOp: HLO_UnaryElementwiseOp<"count_leading_zeros", + [NoSideEffect, SameOperandsAndResultType], HLO_IntTensor>, + BASE_HLO_ClzOp; + +def HLO_CosOp: HLO_UnaryElementwiseOp<"cosine", + [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>, + BASE_HLO_CosOp; + +def HLO_ExpOp: HLO_UnaryElementwiseOp<"exponential", + [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>, + BASE_HLO_ExpOp; + +def HLO_Expm1Op: HLO_UnaryElementwiseOp<"exponential_minus_one", + [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>, + BASE_HLO_Expm1Op; + +def HLO_FloorOp: HLO_UnaryElementwiseOp<"floor", + [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_FloorOp; + +def HLO_ImagOp: HLO_Op< + "imag", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_ImagOp { + let builders = [OpBuilder< + "OpBuilder &, OperationState &tblgen_state, Value val">]; + + let arguments = (ins HLO_ComplexTensor); + let results = (outs HLO_FpTensor); + let hasFolder = 1; +} + +def HLO_IsFiniteOp: HLO_UnaryElementwiseOp<"is_finite", + [NoSideEffect, SameOperandsAndResultShape], HLO_Tensor>, + BASE_HLO_IsFiniteOp { + let arguments = (ins HLO_FpTensor:$x); + let results = (outs HLO_PredTensor:$y); +} + +def HLO_LogOp: HLO_UnaryElementwiseOp<"log", + 
[NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>, + BASE_HLO_LogOp; + +def HLO_Log1pOp: HLO_UnaryElementwiseOp<"log_plus_one", + [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>, + BASE_HLO_Log1pOp; + +def HLO_LogisticOp: HLO_UnaryElementwiseOp<"logistic", + [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>, + BASE_HLO_LogisticOp; + +def HLO_NotOp: HLO_UnaryElementwiseOp<"not", + [NoSideEffect, SameOperandsAndResultType], HLO_PredOrIntTensor>, + BASE_HLO_NotOp; + +def HLO_NegOp: HLO_UnaryElementwiseOp<"negate", + [NoSideEffect, SameOperandsAndResultType], HLO_IntFpOrComplexTensor>, + BASE_HLO_NegOp; + +def HLO_PopulationCountOp: HLO_UnaryElementwiseOp<"popcnt", + [NoSideEffect, SameOperandsAndResultType], HLO_IntTensor>, + BASE_HLO_PopulationCountOp; + +def HLO_RealOp: HLO_Op< + "real", [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_RealOp { + let builders = [OpBuilder< + "OpBuilder &, OperationState &tblgen_state, Value val">]; + + let arguments = (ins HLO_ComplexTensor); + let results = (outs HLO_FpTensor); + let hasFolder = 1; +} + +def HLO_RoundOp: HLO_UnaryElementwiseOp<"round_nearest_afz", + [NoSideEffect, SameOperandsAndResultType], HLO_FpTensor>, BASE_HLO_RoundOp; + +def HLO_RsqrtOp: HLO_UnaryElementwiseOp<"rsqrt", + [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>, + BASE_HLO_RsqrtOp; + +def HLO_SignOp: HLO_UnaryElementwiseOp<"sign", + [NoSideEffect, SameOperandsAndResultType], + TensorOf<[HLO_SInt, AnyFloat, HLO_Complex]>>, + BASE_HLO_SignOp; + +def HLO_SinOp: HLO_UnaryElementwiseOp<"sine", + [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>, + BASE_HLO_SinOp; + +def HLO_SqrtOp: HLO_UnaryElementwiseOp<"sqrt", + [NoSideEffect, SameOperandsAndResultType], HLO_FpOrComplexTensor>, + BASE_HLO_SqrtOp; + +def HLO_TanhOp: HLO_UnaryElementwiseOp<"tanh", + [NoSideEffect, SameOperandsAndResultType], + HLO_FpOrComplexTensor>, BASE_HLO_TanhOp; + 
+//===----------------------------------------------------------------------===// +// XLA binary elementwise op definitions. +//===----------------------------------------------------------------------===// +// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations + +class HLO_BinaryElementwiseOp traits> : + HLO_Op { + let arguments = (ins + HLO_Tensor:$lhs, + HLO_Tensor:$rhs + ); + + let extraClassDeclaration = [{ + static LogicalResult inferReturnTypeComponents( + MLIRContext* context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl& inferedReturnShapes) { + return failure(); + } + LogicalResult reifyReturnTypeShapes( + OpBuilder& builder, SmallVectorImpl& reifiedReturnShapes) { + return deriveShapeFromFirstOperand(&builder, getOperation(), + &reifiedReturnShapes); + } + bool inferInputsShapeEquality(int lhs, int rhs) { + return true; + } + bool inferInputOutputShapeEquality(int input, int output) { + return true; + } + llvm::Optional inferEffectiveWorkloadShape() { + return getOperation()->getResult(0); + } + }]; + + let results = (outs HLO_Tensor); + let parser = [{ return mlir::impl::parseOneResultSameOperandTypeOp(parser, result); }]; + let printer = [{ return mlir::impl::printOneResultOp(getOperation(), p); }]; +} + +def HLO_AddOp : HLO_BinaryElementwiseOp<"add", + [Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_AddOp { + let hasFolder = 1; +} + +def HLO_Atan2Op : HLO_BinaryElementwiseOp<"atan2", + [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_Atan2Op; + +def HLO_ComplexOp: HLO_Op<"complex", + [NoSideEffect, SameOperandsAndResultShape]>, + BASE_HLO_ComplexOp { + let builders = [OpBuilder< + "OpBuilder &, OperationState &tblgen_state, Value lhs, Value rhs">]; + + let arguments = (ins HLO_FpTensor:$lhs, HLO_FpTensor:$rhs); + let results = (outs HLO_ComplexTensor); + let hasFolder = 1; +} + +def HLO_DivOp : 
HLO_BinaryElementwiseOp<"divide", + [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_DivOp { + let hasFolder = 1; +} + +def HLO_MaxOp : HLO_BinaryElementwiseOp<"maximum", + [Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_MaxOp { + let hasFolder = 1; +} + +def HLO_MinOp : HLO_BinaryElementwiseOp<"minimum", + [Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_MinOp { + let hasFolder = 1; +} + +def HLO_MulOp : HLO_BinaryElementwiseOp<"multiply", + [Commutative, NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_MulOp { + let hasFolder = 1; +} + +def HLO_PowOp : HLO_BinaryElementwiseOp<"power", + [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_PowOp; + +def HLO_RemOp : HLO_BinaryElementwiseOp<"remainder", + [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_RemOp; + +def HLO_ShiftLeftOp : HLO_BinaryElementwiseOp<"shift_left", + [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftLeftOp; + +def HLO_ShiftRightArithmeticOp : HLO_BinaryElementwiseOp<"shift_right_arithmetic", + [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftRightArithmeticOp; + +def HLO_ShiftRightLogicalOp : HLO_BinaryElementwiseOp<"shift_right_logical", + [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ShiftRightLogicalOp; + +def HLO_SubOp : HLO_BinaryElementwiseOp<"subtract", + [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_SubOp { + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// XLA binary logical elementwise op definitions. 
//===----------------------------------------------------------------------===//

// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
// NOTE(review): the `<string mnemonic>` parameter list was stripped in
// transit; restored so the `mnemonic` reference below resolves.
class HLO_BinaryLogicalElementwiseOp<string mnemonic> :
    HLO_BinaryElementwiseOp<
      mnemonic, [Commutative, NoSideEffect, SameOperandsAndResultType]> {
  let arguments = (ins
    HLO_PredOrIntTensor:$lhs,
    HLO_PredOrIntTensor:$rhs
  );
}

def HLO_AndOp: HLO_BinaryLogicalElementwiseOp<"and">, BASE_HLO_AndOp;
def HLO_OrOp: HLO_BinaryLogicalElementwiseOp<"or">, BASE_HLO_OrOp;
def HLO_XorOp : HLO_BinaryLogicalElementwiseOp<"xor">, BASE_HLO_XorOp;

//===----------------------------------------------------------------------===//
// XLA communication op definitions.
//===----------------------------------------------------------------------===//

// InfeedOp corresponds to 'InfeedWithToken' xla client API and not 'Infeed'.
// InfeedWithToken allows ordering of infeed HLO instructions using tokens.
def HLO_InfeedOp : HLO_Op<"infeed", []> {

  string summary = "Infeed operator";

  string description = [{
    Reads a single data item from the implicit Infeed streaming interface of
    the device, interpreting the data as the given shape, and returns a XlaOp
    of the data. Multiple Infeed operations are allowed in a computation, but
    there must be a total order among the Infeed operations.

    See https://www.tensorflow.org/xla/operation_semantics#infeed.
  }];

  let arguments = (ins
    HLO_Token:$token,
    DefaultValuedAttr<StrAttr, "">:$infeed_config
  );
  let results = (outs HLO_Tuple);
  let hasCustomHLOConverter = 1;
}

// OutfeedOp corresponds to 'OutfeedWithToken' xla client API and not 'Outfeed'.
// OutfeedWithToken allows ordering of outfeed HLO instructions using tokens.
def HLO_OutfeedOp : HLO_Op<"outfeed", []> {

  string summary = "Outfeed operator";

  string description = [{
    Generates outgoing data transfers for the given data. It takes data and a
    token type operand and produces a token type value. Tokens are used for
    ordering side-effecting operations.

    See https://www.tensorflow.org/xla/operation_semantics#outfeed.
  }];

  let arguments = (ins
    HLO_TensorOrTuple:$operand,
    HLO_Token:$token,
    DefaultValuedAttr<StrAttr, "">:$outfeed_config
  );
  let results = (outs HLO_Token);
  let hasCustomHLOConverter = 1;
}

def HLO_SendOp : HLO_Op<"send", []> {

  string summary = "Send operator";

  string description = [{
    Sends the given operand data to a Recv instruction in another computation
    that shares the same channel handle. Does not return any data. Similar to
    the Recv operation, Send operation represents synchronous communication,
    and is internally decomposed into 2 HLO instructions (Send and SendDone) to
    enable asynchronous data transfers.

    See https://www.tensorflow.org/xla/operation_semantics#send.
  }];

  let arguments = (ins
    HLO_TensorOrTuple:$operand,
    HLO_Token:$token,
    ChannelHandle:$channel_id,
    DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer
  );

  let results = (outs HLO_Token);
  let hasCustomHLOConverter = 1;
}

def HLO_RecvOp : HLO_Op<"recv", []> {

  string summary = "Recv operator";

  string description = [{
    Receives data of the given shape from a Send instruction in another
    computation that shares the same channel handle. Returns a tuple containing
    value for the received data and a token. Recv operation represents
    synchronous communication. However, the instruction is internally decomposed
    into 2 HLO instructions (Recv and RecvDone) to enable asynchronous data
    transfers.

    See https://www.tensorflow.org/xla/operation_semantics#recv.
  }];

  let arguments = (ins
    HLO_Token:$token,
    ChannelHandle:$channel_id,
    DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer
  );

  let results = (outs HLO_Tuple);
  let hasCustomHLOConverter = 1;
}

//===----------------------------------------------------------------------===//
// XLA parallelism related op definitions.
//===----------------------------------------------------------------------===//

def HLO_ReplicaIdOp : HLO_Op<"replica_id", [NoSideEffect]>,
    BASE_HLO_ReplicaIdOp {
  // TODO(prakalps): The output should unsigned 32-bit integer but mlir does
  // not differentiate between signed and unsigned int.
  let results = (outs I32Tensor);
}

//===----------------------------------------------------------------------===//
// XLA control flow op definitions.
//===----------------------------------------------------------------------===//

def HLO_AfterAllOp : HLO_Op<"after_all", [NoSideEffect]> {

  string summary = "AfterAll operator";

  string description = [{
    AfterAll takes a variadic number of tokens and produces a single token.
    Tokens are primitive types which can be threaded between side-effecting
    operations to enforce ordering. AfterAll can be used as a join of tokens
    for ordering a operation after a set operations.

    See https://www.tensorflow.org/xla/operation_semantics#afterall.
  }];

  let arguments = (ins Variadic<HLO_Token>:$operands);
  let results = (outs HLO_Token);
}

// Xla Client API has two separate calls for indexed and predicated conditional,
// although both eventually map to kConditional HLO. IfOp maps to predicated
// conditional use of kConditional HLO.
def HLO_IfOp: HLO_Op<"if", [RecursiveSideEffects]> {
  string summary = "If operator";

  string description = [{
    Returns the result of executing either a true or false function depending on
    the result of a condition function.

    See https://www.tensorflow.org/xla/operation_semantics#conditional.
  }];

  let arguments = (ins
    HLO_PredTensor:$pred,
    HLO_TensorOrTuple:$true_arg,
    HLO_TensorOrTuple:$false_arg
  );

  let regions = (region AnyRegion:$true_branch,
                        AnyRegion:$false_branch);

  let results = (outs HLO_TensorOrTuple);

  // TODO(b/129422361): ConditionalOp has special conversion logic to HLO.
  let hasCustomHLOConverter = 1;
}

// Xla Client API has two separate calls for indexed and predicated conditional,
// although both eventually map to kConditional HLO. CaseOp maps to indexed
// conditional use of kConditional HLO.
def HLO_CaseOp: HLO_Op<"case", [RecursiveSideEffects]>,
    BASE_HLO_CaseOp {

  let arguments = (ins
    I32Tensor:$index,
    Variadic<HLO_TensorOrTuple>:$branch_operands
  );

  let regions = (region VariadicRegion<SizedRegion<1>>:$branches);

  let results = (outs Variadic<HLO_TensorOrTuple>);

  let hasCustomHLOConverter = 1;
}


def HLO_WhileOp: HLO_Op<"while", [RecursiveSideEffects,
                                  SameOperandsAndResultType]>,
    BASE_HLO_WhileOp {
  let arguments = (ins HLO_TensorOrTuple:$val);

  let regions = (region AnyRegion:$cond, AnyRegion:$body);

  let results = (outs HLO_TensorOrTuple);

  // TODO(b/129422361): WhileOp has special conversion logic to HLO.
  let hasCustomHLOConverter = 1;
}

def HLO_AllReduceOp : HLO_Op<"all_reduce",
    [SameOperandsAndResultType]>, BASE_HLO_AllReduceOp {

  let arguments = (ins
    HLO_Tensor:$operand,
    I64ElementsAttr:$replica_groups,
    OptionalAttr<ChannelHandle>:$channel_id
  );
  let regions = (region SizedRegion<1>:$computation);
  let results = (outs HLO_Tensor);

  let hasCustomHLOConverter = 1;
}

def HLO_AllToAllOp : HLO_Op<"all_to_all",
    [NoSideEffect, SameOperandsElementType, SameOperandsShape]>, BASE_HLO_AllToAllOp {

  let arguments = (ins
    HLO_Tensor:$operand,
    I64Attr:$split_dimension,
    I64Attr:$concat_dimension,
    I64Attr:$split_count,
    I64ElementsAttr:$replica_groups
  );
  let results = (outs HLO_Tensor);
}

def HLO_ReduceOp: HLO_Op<"reduce", [
      RecursiveSideEffects,
      SameVariadicOperandSize,
      SingleBlockImplicitTerminator<"ReturnOp">,
      InferFusibilityOpInterface
    ]>, BASE_HLO_ReduceOp {
  let arguments = (ins
    Variadic<HLO_TensorOrTuple>:$operands,
    Variadic<HLO_TensorOrTuple>:$init_values,
    I64ElementsAttr:$dimensions
  );

  let results = (outs Variadic<HLO_TensorOrTuple>);

  let builders = [OpBuilder<
    "OpBuilder &, OperationState &state, ValueRange operands, "
    "ValueRange init_values, DenseIntElementsAttr dimensions"
  >];

  let extraClassDeclaration = [{
    bool isFusibleWithConsumer() {
      return false;
    }
    llvm::Optional<Value> inferEffectiveWorkloadShape() {
      return getOperation()->getOperand(0);
    }
  }];

  let hasFolder = 1;

  // TODO(hinsu): Verify that the attached body arguments and results are
  // compatible with reduce op's operands.
  let regions = (region SizedRegion<1>:$body);

  // TODO(hinsu): Implement custom printer and parser.

  // TODO(b/129422361): ReduceOp has special conversion logic to HLO.
  let hasCustomHLOConverter = 1;
}

//===----------------------------------------------------------------------===//
// XLA tuple op definitions.
//===----------------------------------------------------------------------===//
def HLO_GetTupleElementOp: HLO_Op<"get_tuple_element", [NoSideEffect]>, BASE_HLO_GetTupleElementOp {
  let arguments = (ins
    HLO_Tuple,
    I32Attr:$index
  );

  let results = (outs HLO_TensorOrTokenOrTuple);

  let hasFolder = 1;

  let builders = [OpBuilder<
    "OpBuilder &builder, OperationState &results, "
    "Value value, int32_t index">];
}

def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp {
  // NOTE(review): element-type restoration — `Variadic<HLO_TensorOrTokenOrTuple>`
  // was stripped to bare `Variadic`; verify against upstream hlo_ops.td.
  let arguments = (ins Variadic<HLO_TensorOrTokenOrTuple>:$val);
  let results = (outs HLO_Tuple);

  let builders = [OpBuilder<
    "OpBuilder &builder, OperationState &results, "
    "ValueRange values">];

}

def HLO_CompareOp: HLO_Op<"compare",
    [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]>,
    BASE_HLO_CompareOp {
  let arguments = (ins
    HLO_Tensor:$lhs,
    HLO_Tensor:$rhs,
    HLO_ComparisonDirectionAttr:$comparison_direction
  );
  let results = (outs HLO_PredTensor);

  let builders = [OpBuilder<
    "OpBuilder &builder, OperationState &result, Value lhs, Value rhs, "
    "StringAttr comparison_direction"
  >];
}

//===----------------------------------------------------------------------===//
// XLA Slice definitions.
//===----------------------------------------------------------------------===//

def HLO_SliceOp: HLO_Op<
      "slice",
      [NoSideEffect, SameOperandsAndResultElementType,
       AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> {
  let arguments = (ins
    HLO_Tensor:$operand,
    I64ElementsAttr:$start_indices,
    I64ElementsAttr:$limit_indices,
    I64ElementsAttr:$strides
  );

  let results = (outs HLO_Tensor);

  let hasCanonicalizer = 1;
  let hasFolder = 1;

  let builders = [OpBuilder<
    "OpBuilder &builder, OperationState &result, Value operand, "
    "DenseIntElementsAttr start_indices, DenseIntElementsAttr limit_indices, "
    "DenseIntElementsAttr strides"
  >];

  let extraClassDeclaration = [{
    // Infers output type for given operand and attributes. Result type is
    // unranked if any of the attributes is illegal.
    static Type InferOutputTypes(Builder *builder, Value operand,
                                 DenseIntElementsAttr start_indices,
                                 DenseIntElementsAttr limit_indices,
                                 DenseIntElementsAttr strides);
  }];
}

def HLO_DynamicSliceOp: HLO_Op<"dynamic-slice",
    [NoSideEffect, AllElementTypesMatch<["operand", "result"]>]> {
  // NOTE(review): `Variadic<HLO_Tensor>` restored from stripped `Variadic`.
  let arguments = (ins
    HLO_Tensor:$operand,
    Variadic<HLO_Tensor>:$start_indices,
    I64ElementsAttr:$slice_sizes
  );

  let results = (outs HLO_Tensor:$result);
  let hasCanonicalizer = 1;
}

def HLO_DynamicUpdateSliceOp: HLO_Op<"dynamic-update-slice",
    [NoSideEffect, AllElementTypesMatch<["operand", "update", "result"]>,
     AllShapesMatch<["operand", "result"]>]> {
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_Tensor:$update,
    Variadic<HLO_Tensor>:$start_indices
  );

  let results = (outs HLO_Tensor:$result);
}


//===----------------------------------------------------------------------===//
// XLA Other op definitions.
//===----------------------------------------------------------------------===//

def HLO_BatchNormGradOp : HLO_Op<"batch_norm_grad", [NoSideEffect]>,
    BASE_HLO_BatchNormGradOp {

  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_Tensor:$scale,
    HLO_Tensor:$mean,
    HLO_Tensor:$variance,
    HLO_Tensor:$grad_output,
    F32Attr:$epsilon,
    I64Attr:$feature_index
  );

  let results = (outs HLO_Tuple);
}

def HLO_BatchNormInferenceOp : HLO_Op<"batch_norm_inference",
    [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_BatchNormInferenceOp {

  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_Tensor:$scale,
    HLO_Tensor:$offset,
    HLO_Tensor:$mean,
    HLO_Tensor:$variance,
    F32Attr:$epsilon,
    I64Attr:$feature_index
  );

  let results = (outs HLO_Tensor);
}

def HLO_BatchNormTrainingOp : HLO_Op<"batch_norm_training", [NoSideEffect]>,
    BASE_HLO_BatchNormTrainingOp {

  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_Tensor:$scale,
    HLO_Tensor:$offset,
    F32Attr:$epsilon,
    I64Attr:$feature_index
  );

  let results = (outs HLO_Tuple);
}

def HLO_BitcastConvertOp : HLO_Op<"bitcast_convert",
    [NoSideEffect, SameOperandsAndResultShape]>, BASE_HLO_BitcastConvertOp {

  let arguments = (ins HLO_Tensor:$operand);
  let results = (outs HLO_Tensor);
  let hasCustomHLOConverter = 1;
}

def HLO_BroadcastOp : HLO_Op<"broadcast",
    [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_BroadcastOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    I64ElementsAttr:$broadcast_sizes
  );

  let results = (outs HLO_Tensor);
}

def HLO_BroadcastInDimOp : HLO_Op<"broadcast_in_dim",
    [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_BroadcastInDimOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    BroadcastDimAttr:$broadcast_dimensions
  );

  let results = (outs HLO_StaticShapeTensor);

  let hasFolder = 1;
  // Only handles a static subset of the legacy format.
  let hasCustomHLOConverter = 1;
}

def HLO_DynamicBroadcastInDimOp : HLO_Op<"dynamic_broadcast_in_dim",
    [NoSideEffect]> {
  string summary = "Broadcast a tensor into the given dynamic shape by adding dimensions.";
  string description = [{
    This is a generalization of the BroadcastInDimOp which accepts its output
    dimensions as an argument. It should eventually supercede the statically
    shaped original, but is being phased as a separate op in order to support
    compatibility with lowerings and translations that precede dynamic
    shapes.
  }];
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_DimensionTensor:$output_dimensions,
    BroadcastDimAttr:$broadcast_dimensions
  );

  let results = (outs HLO_Tensor);

  let hasCanonicalizer = 1;
  // Cannot be exported to legacy formats.
  let hasCustomHLOConverter = 1;
}

// Note: There is no HLO_CallOp because the standard call operation mlir::CallOp
// is used instead. A mlir::CallOp is exported to a HLO call instruction
// directly.

def HLO_CholeskyOp : HLO_Op<"cholesky",
    [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_CholeskyOp {
  // NOTE(review): `DefaultValuedAttr<BoolAttr, "false">` restored from the
  // stripped `DefaultValuedAttr`.
  let arguments = (ins
    HLO_FpOrComplexTensor:$a,
    DefaultValuedAttr<BoolAttr, "false">:$lower
  );

  let results = (outs HLO_FpOrComplexTensor);
}

def HLO_ClampOp : HLO_Op<"clamp",
    [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ClampOp {
  let arguments = (ins
    HLO_Tensor:$min,
    HLO_Tensor:$operand,
    HLO_Tensor:$max
  );

  let results = (outs HLO_Tensor);
}

def HLO_ConcatenateOp : HLO_Op<"concatenate",
    [NoSideEffect, SameOperandsAndResultElementType,
     DeclareOpInterfaceMethods<InferTypeOpInterface>]>, BASE_HLO_ConcatenateOp {

  let arguments = (ins
    Variadic<HLO_Tensor>:$val,
    I64Attr: $dimension
  );

  let results = (outs HLO_Tensor);

  let hasCanonicalizer = 1;
  let hasFolder = 1;

}

def HLO_CollectivePermuteOp: HLO_Op<"collective_permute",
    [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_CollectivePermuteOp {

  let arguments = (ins
    HLO_Tensor:$operand,
    I64ElementsAttr:$source_target_pairs
  );
  let results = (outs HLO_Tensor);
}

// TODO(hinsu): Make this struct dialect independent so that it can be shared
// between HLO and LHLO dialect.
def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", HLO_Dialect, [
    StructFieldAttr<"input_batch_dimension",I64Attr>,
    StructFieldAttr<"input_feature_dimension", I64Attr>,
    StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
    StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
    StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
    StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
    StructFieldAttr<"output_batch_dimension", I64Attr>,
    StructFieldAttr<"output_feature_dimension", I64Attr>,
    StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {

  let description = "Structure of dimension information for conv op";
}

def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]>, BASE_HLO_ConvOp {
  // NOTE(review): the `OptionalAttr<I64ElementsAttr>` element types below were
  // stripped to bare `OptionalAttr`; restored per upstream hlo_ops.td.
  let arguments = (ins
    HLO_Tensor:$lhs,
    HLO_Tensor:$rhs,
    // Default value: one for each of the spatial dimension.
    OptionalAttr<I64ElementsAttr>:$window_strides,
    // Default value: zero for each of the spatial dimension.
    OptionalAttr<I64ElementsAttr>:$padding,
    // Default value: one for each of the spatial dimension.
    OptionalAttr<I64ElementsAttr>:$lhs_dilation,
    // Default value: one for each of the spatial dimension.
    OptionalAttr<I64ElementsAttr>:$rhs_dilation,
    ConvDimensionNumbers:$dimension_numbers,
    I64Attr:$feature_group_count,
    I64Attr:$batch_group_count,
    HLO_PrecisionConfigAttr:$precision_config
  );

  let results = (outs HLO_Tensor);
}

def HLO_CopyOp: HLO_Op<"copy", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_CopyOp {
  let arguments = (ins HLO_Tensor);
  let results = (outs HLO_Tensor);
  let hasFolder = 1;
}

def HLO_CrossReplicaSumOp : HLO_Op<"cross-replica-sum",
    [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_CrossReplicaSumOp {

  let arguments = (ins
    HLO_Tensor:$operand,
    I64ElementsAttr:$replica_groups
  );

  let results = (outs HLO_Tensor);
}

def HLO_CustomCallOp: HLO_Op<"custom_call", []>, BASE_HLO_CustomCallOp {
  let arguments = (ins
    Variadic<HLO_Tensor>:$args,
    StrAttr:$call_target_name,
    DefaultValuedAttr<BoolAttr, "false">:$has_side_effect,
    DefaultValuedAttr<StrAttr, "">:$backend_config
  );
  let results = (outs HLO_Tensor);
  let hasCustomHLOConverter = 1;
}

def HLO_DotOp: HLO_Op<"dot", [NoSideEffect]>, BASE_HLO_DotOp {
  let arguments = (
    ins HLO_Tensor:$lhs,
        HLO_Tensor:$rhs,
        HLO_PrecisionConfigAttr:$precision_config
  );
  let results = (outs HLO_Tensor);
}

def DotDimensionNumbers : StructAttr<"DotDimensionNumbers", HLO_Dialect, [
    StructFieldAttr<"lhs_batching_dimensions", I64ElementsAttr>,
    StructFieldAttr<"rhs_batching_dimensions", I64ElementsAttr>,
    StructFieldAttr<"lhs_contracting_dimensions", I64ElementsAttr>,
    StructFieldAttr<"rhs_contracting_dimensions", I64ElementsAttr>
  ]> {
  let description = "Structure of dimension information for dot product";
}

def HLO_DotGeneralOp: HLO_Op<"dot_general", [NoSideEffect]>, BASE_HLO_DotGeneralOp {
  let arguments = (ins
    HLO_Tensor:$lhs,
    HLO_Tensor:$rhs,
    DotDimensionNumbers:$dot_dimension_numbers,
    HLO_PrecisionConfigAttr:$precision_config
  );

  let results = (outs HLO_Tensor);
  let verifier = [{ return Verify(*this); }];
}

// Define Base Einsum op within the HLO dialect as these are
// client ops and
// therefore this class is not common between HLO and LHLO ops.
class BASE_EinsumOp {
  string summary = "Einsum operator";

  string description = [{
    Returns a tensor whose elements are defined by equation, which is written
    in a shorthand form inspired by the Einstein summation convention.
  }];
}

def HLO_EinsumOp: HLO_Op<"einsum", [NoSideEffect]>, BASE_EinsumOp {
  let arguments = (ins
    HLO_Tensor:$lhs,
    HLO_Tensor:$rhs,
    StrAttr:$einsum_config
  );

  let results = (outs HLO_Tensor);

  // TODO(hinsu): Canonicalize to lower this client side HLO op to server
  // side HLO ops.
}

def HLO_UnaryEinsumOp: HLO_Op<"unary_einsum", [NoSideEffect]>, BASE_EinsumOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    StrAttr:$einsum_config
  );

  let results = (outs HLO_Tensor);

  let hasCanonicalizer = 1;

  // UnaryEinsumOp is unconditionally canonicalized to the binary EinsumOp so
  // the HLO converter shouldn't be invoked.
  let hasCustomHLOConverter = 1;
}

def HLO_FftOp: HLO_Op<"fft", [NoSideEffect]>, BASE_HLO_FftOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_FftTypeAttr: $fft_type,
    I64ElementsAttr:$fft_length
  );

  let results = (outs HLO_Tensor);
}

def GatherDimensionNumbers : StructAttr<"GatherDimensionNumbers", HLO_Dialect,
    [StructFieldAttr<"offset_dims", I64ElementsAttr>,
     StructFieldAttr<"collapsed_slice_dims", I64ElementsAttr>,
     StructFieldAttr<"start_index_map", I64ElementsAttr>,
     StructFieldAttr<"index_vector_dim", I64Attr>]> {
  let description = "Structure of dimension information for gather";
}

def HLO_GatherOp: HLO_Op<"gather", [NoSideEffect]>, BASE_HLO_GatherOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_IntTensor:$start_indices,
    GatherDimensionNumbers:$dimension_numbers,
    I64ElementsAttr:$slice_sizes,
    DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted
  );

  let results = (outs HLO_Tensor);
}

def HLO_GetDimensionSizeOp: HLO_Op<"get_dimension_size", [NoSideEffect]>,
    BASE_HLO_GetDimensionSizeOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    I32Attr:$dimension
  );
  let results = (outs HLO_IntTensor);
}

def HLO_MapOp: HLO_Op<"map",
    [RecursiveSideEffects, SameOperandsElementType,
     SameOperandsAndResultShape, SingleBlockImplicitTerminator<"ReturnOp">]>,
    BASE_HLO_MapOp {
  let arguments = (ins
    Variadic<HLO_Tensor>:$operands,
    I64ElementsAttr:$dimensions
  );
  let regions = (region SizedRegion<1>:$computation);
  let results = (outs HLO_Tensor);
  let hasCustomHLOConverter = 1;
}

def HLO_ReshapeOp: HLO_Op<"reshape",
    [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_ReshapeOp {
  let arguments = (ins HLO_Tensor:$operand);

  let results = (outs HLO_StaticShapeTensor);
  let hasFolder = 1;

  let hasCustomHLOConverter = 1;
}

def HLO_DynamicReshapeOp: HLO_Op<"dynamic_reshape", [NoSideEffect]> {
  let summary = "Reshape a tensor to a given, possibly dynamic, shape.";
  let description = [{
    Reshapes `operand` to `output_shape`.

    Requires:
    - The length of `output_shape` is equal to the rank of `result`.
    - The number of elements in `operand` (that is, the product of extents of
      its shape) is equal to the number of elements in `output_shape` (that is,
      the product of values in `output_shape`).
  }];

  let arguments = (ins HLO_Tensor:$operand, HLO_DimensionTensor:$output_shape);
  let results = (outs HLO_Tensor:$result);

  let hasCanonicalizer = 1;
  // Cannot be exported to legacy formats.
  let hasCustomHLOConverter = 1;
}

def HLO_ScatterOp: HLO_Op<"scatter", [RecursiveSideEffects]>,
    BASE_HLO_ScatterOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_Tensor:$scatter_indices,
    HLO_Tensor:$updates,
    ScatterDimensionNumbers:$scatter_dimension_numbers,
    DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted,
    DefaultValuedAttr<BoolAttr, "false">:$unique_indices
  );

  let regions = (region SizedRegion<1>:$update_computation);

  let results = (outs HLO_Tensor);

  let hasCustomHLOConverter = 1;
}

// TODO(jpienaar): Add broadcastable trait.
def HLO_SelectOp: HLO_Op<"select",
    [NoSideEffect, DeclareOpInterfaceMethods<InferShapedTypeOpInterface>]>,
    BASE_HLO_SelectOp {
  let arguments = (ins
    HLO_PredTensor:$pred,
    HLO_Tensor:$on_true,
    HLO_Tensor:$on_false
  );

  let results = (outs HLO_Tensor);
}

def HLO_SelectAndScatterOp: HLO_Op<"select_and_scatter",
    [RecursiveSideEffects]>, BASE_HLO_SelectAndScatterOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_Tensor:$source,
    HLO_Tensor:$init_value,
    OptionalAttr<I64ElementsAttr>:$window_dimensions,
    OptionalAttr<I64ElementsAttr>:$window_strides,
    OptionalAttr<I64ElementsAttr>:$padding
  );

  let regions = (region SizedRegion<1>:$select, SizedRegion<1>:$scatter);

  let results = (outs HLO_Tensor);

  let hasCustomHLOConverter = 1;
}

def HLO_SetDimensionSizeOp: HLO_Op<"set_dimension_size", [NoSideEffect]>,
    BASE_HLO_SetDimensionSizeOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    I32Tensor:$size,
    I32Attr:$dimension
  );
  let results = (outs HLO_Tensor);
}

def HLO_SortOp : HLO_Op<"sort", [RecursiveSideEffects]>, BASE_HLO_SortOp {
  let arguments = (ins
    Variadic<HLO_Tensor>:$operands,
    DefaultValuedAttr<I64Attr, "-1">:$dimension,
    DefaultValuedAttr<BoolAttr, "false">:$is_stable
  );

  let results = (outs HLO_TensorOrTuple);

  let regions = (region SizedRegion<1>:$comparator);

  let builders = [OpBuilder<
    "OpBuilder &builder, OperationState &state, ValueRange operands, "
    "int64_t dimension = -1, bool is_stable = false"
  >];

  // TODO(b/129422361): SortOp has special conversion logic to HLO.
  let hasCustomHLOConverter = 1;
}

def HLO_ReverseOp: HLO_Op<"reverse",
    [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_ReverseOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    I64ElementsAttr:$dimensions
  );

  let results = (outs HLO_Tensor);

  let hasFolder = 1;
}

def HLO_PadOp: HLO_Op<"pad",
    [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_PadOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_Tensor:$padding_value,
    I64ElementsAttr: $edge_padding_low,
    I64ElementsAttr: $edge_padding_high,
    I64ElementsAttr: $interior_padding
  );

  let results = (outs HLO_Tensor);

  let description = [{
    Pads the `operand` according to TBD.
  }];

  // TODO(b/129422361): PadOp has a custom constructor for HLO.
  let hasCustomHLOConverter = 1;
}

def HLO_TraceOp: HLO_Op<"trace", []>, BASE_HLO_TraceOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    StrAttr:$tag
  );
  let hasCustomHLOConverter = 1;
}

def HLO_TransposeOp: HLO_Op<"transpose",
    [NoSideEffect, SameOperandsAndResultElementType]>, BASE_HLO_TransposeOp {
  let arguments = (ins
    HLO_Tensor:$operand,
    I64ElementsAttr:$permutation
  );
  let results = (outs HLO_Tensor);

  let hasFolder = 1;
}

def HLO_TriangularSolveOp: HLO_Op<"triangular_solve",
    [NoSideEffect, SameOperandsAndResultElementType]>,
    BASE_HLO_TriangularSolveOp {
  let arguments = (ins
    HLO_FpOrComplexTensor:$a,
    HLO_FpOrComplexTensor:$b,
    BoolAttr:$left_side,
    BoolAttr:$lower,
    BoolAttr:$unit_diagonal,
    HLO_TransposeAttr:$transpose_a
  );
  let results = (outs HLO_FpOrComplexTensor);
}

def HLO_ReduceWindowOp: HLO_Op<"reduce_window", [
      RecursiveSideEffects,
      SingleBlockImplicitTerminator<"ReturnOp">
    ]>, BASE_HLO_ReduceWindowOp {

  // TODO(hinsu): Verify that padding attribute is 2-d and the remaining
  // attributes are 1-d. Attributes' leading dimension should match rank of the
  // inputs.
  let arguments = (ins
    HLO_Tensor:$operand,
    HLO_Tensor:$init_value,
    I64ElementsAttr:$window_dimensions,
    // If strides or dilations attributes are missing then the default value is
    // one for each of the input dimensions. Similarly, padding values are zero
    // for both low and high in each of the dimensions, if not specified.
    OptionalAttr<I64ElementsAttr>:$window_strides,
    OptionalAttr<I64ElementsAttr>:$base_dilations,
    OptionalAttr<I64ElementsAttr>:$window_dilations,
    OptionalAttr<I64ElementsAttr>:$padding
  );

  let results = (outs HLO_Tensor);

  // TODO(hinsu): Verify that the attached body arguments and results are
  // compatible with reduce op's operands.
  let regions = (region SizedRegion<1>:$body);

  let hasCustomHLOConverter = 1;

  // TODO(hinsu): Implement custom printer and parser.
}

def HLO_ReturnOp : HLO_Op<"return", [NoSideEffect, Terminator]> {
  let summary = [{
    The `hlo.return` operation terminates a region and returns values.
  }];

  let arguments = (ins
    Variadic<HLO_TensorOrTokenOrTuple>:$results
  );

  // Disable conversion operator for return op as the op is not an actual XLA
  // instruction and is only used as a terminator for regions.
  let hasCustomHLOConverter = 1;

  // TODO(hinsu): Implement custom printer and parser.
}

def HLO_TorchIndexSelectOp : HLO_Op<"torch_index_select", [NoSideEffect]> {
  let arguments = (ins
    HLO_Tensor:$input,
    HLO_Tensor:$index,
    I64Attr:$dim,
    I64Attr:$batch_dims
  );

  let results = (outs HLO_Tensor);

  // TODO(hinsu): Canonicalize to lower this client side HLO op to server
  // side HLO ops.
}

//===----------------------------------------------------------------------===//
// XLA RngUniform Operator.
//===----------------------------------------------------------------------===//
def HLO_RngUniformOp : HLO_Op<"rng_uniform", []>, BASE_HLO_RngUniformOp {
  let arguments = (ins
    HLO_PredIntOrFpTensor:$a,
    HLO_PredIntOrFpTensor:$b,
    I64Tensor:$shape
  );

  let results = (outs HLO_PredIntOrFpTensor);

  let hasCustomHLOConverter = 1;
}

def HLO_RngNormalOp : HLO_Op<"rng_normal", []>, BASE_HLO_RngNormalOp {
  let arguments = (ins
    HLO_FpTensor:$mu,
    HLO_FpTensor:$sigma,
    I64Tensor:$shape
  );

  let results = (outs HLO_FpTensor);

  let hasCustomHLOConverter = 1;
}

//===----------------------------------------------------------------------===//
// XLA Quantize Operator.
//===----------------------------------------------------------------------===//
def HLO_DequantizeOp : HLO_Op<"dequantize", [NoSideEffect]>,
    BASE_HLO_DequantizeOp {
  // NOTE(review): `DefaultValuedAttr<BoolAttr, "false">` restored from the
  // stripped `DefaultValuedAttr`.
  let arguments = (ins
    TensorOf<[I32]>:$input,
    F32Attr:$min_range,
    F32Attr:$max_range,
    HLO_DequantizeModeAttr:$mode,
    BoolAttr:$transpose_output,
    DefaultValuedAttr<BoolAttr, "false">:$is_16bits
  );

  let results = (outs TensorOf<[BF16]>:$output);

  let hasCustomHLOConverter = 1;
}

def HLO_FusionOp : HLO_Op<"fusion", []> {
  let summary = "Fusion operator";
  let description = [{
    Models the fusion instruction.

    A fusion op is consists of a group of basic ops (represented as a region
    attached to it). It serves as a hint to the backend that it is beneficial
    to emit the contained ops into a single loop nest or kernel.
  }];
  let regions = (region SizedRegion<1>:$fused_computation);

  let arguments = (ins
    Variadic<HLO_TensorOrTuple>:$operands
  );

  let results = (outs
    Variadic<HLO_TensorOrTuple>:$results
  );

  // FusionOp has special conversion logic to HLO.
+ let hasCustomHLOConverter = 1; +} + +#endif // HLO_OPS diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td new file mode 100644 index 0000000..98a0de3 --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td @@ -0,0 +1,1331 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef HLO_OPS_BASE +#define HLO_OPS_BASE + +include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td" + +def HLO_Pred : TypeAlias; + +// TODO(hinsu): Use signed integers instead of signless integer which is being +// used for legacy reasons. +def HLO_SInt : SignlessIntOfWidths<[8, 16, 32, 64]>; +def HLO_UInt : UnsignedIntOfWidths<[8, 16, 32, 64]>; +def HLO_Int : AnyTypeOf<[HLO_SInt, HLO_UInt]>; + +def HLO_Complex : Complex>; + +// The broadcasting dimensions correspond to a tuple that describes how a +// smaller rank shape is broadcast into a larger rank shape. For example, +// given a 2x3x4 cuboid and a 3x4 matrix, a broadcasting tuple (1,2) means +// matching the matrix to dimensions 1 and 2 of the cuboid. +defvar BroadcastDimAttr = I64ElementsAttr; + +//===----------------------------------------------------------------------===// +// XLA on tensors type definitions. +//===----------------------------------------------------------------------===// + +// Token type. 
// NOTE(review): the `CPred<"$_self.isa<TokenType>()">` predicate was stripped
// in transit; restored so the token type constraint is well-formed.
def HLO_Token : Type<CPred<"$_self.isa<TokenType>()">, "token">;

// Any integer tensor types
def HLO_IntTensor : TensorOf<[HLO_Int]>;

// Any integer tensor type with rank 0 (i.e. representing a single integer).
def HLO_ScalarIntTensor : ShapedContainerType<
    [HLO_Int], And<[IsTensorTypePred, HasAnyRankOfPred<[0]>]>,
    "a 0-dim integer tensor">;

// Any floating-point tensor types
def HLO_FpTensor : TensorOf<[AnyFloat]>;

def HLO_PredTensor : TensorOf<[HLO_Pred]>;

def HLO_Tensor : TensorOf<[AnyFloat, HLO_Pred, HLO_Int, HLO_Complex]>;

def HLO_ComplexTensor : TensorOf<[HLO_Complex]>;

def HLO_Tuple : NestedTupleOf<[HLO_Tensor, HLO_Token]>;

def HLO_TensorOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Tuple]>;

def HLO_TensorOrTokenOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Token, HLO_Tuple]>;

def HLO_DimensionValue : AnyTypeOf<[Index, HLO_Pred, HLO_Int]>;

// Dynamic representation of a shape vector as a tensor.
def HLO_DimensionTensor : ShapedContainerType<
    [HLO_DimensionValue],
    And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>,
    "a 1D tensor of dimensions">;

// In general, static shaped tensor constraints should be avoided unless
// it is for a legacy op which is only correct with static shapes.
def HLO_StaticShapeTensor : StaticShapeTensorOf<[
    AnyFloat, HLO_Pred, HLO_Int, HLO_Complex]>;

//===----------------------------------------------------------------------===//
// XLA on tensors combined type definitions.
//===----------------------------------------------------------------------===//

// Any integer or floating-point tensor types
def HLO_IntOrFpTensor : TensorOf<[HLO_Int, AnyFloat]>;

// Any integer or predicate tensor types
def HLO_PredOrIntTensor : TensorOf<[HLO_Pred, HLO_Int]>;

// Any floating-point or complex tensor types
def HLO_FpOrComplexTensor : TensorOf<[AnyFloat, HLO_Complex]>;

// Any int, floating-point or complex tensor types
def HLO_IntFpOrComplexTensor : TensorOf<[HLO_Int, AnyFloat, HLO_Complex]>;

// Any pred, int or floating-point tensor types
def HLO_PredIntOrFpTensor : TensorOf<[HLO_Pred, HLO_Int, AnyFloat]>;

//===----------------------------------------------------------------------===//
// XLA nullary op definitions.
//===----------------------------------------------------------------------===//

class BASE_HLO_ConstOp {
  string summary = "Constant operator";

  string description = [{
    Represents a constant value.
  }];
}

class BASE_HLO_IotaOp {
  string summary = "Iota operator";

  string description = [{
    Creates a rank 1 array of values starting at zero and incrementing by one.
  }];
}

//===----------------------------------------------------------------------===//
// XLA unary elementwise op definitions.
//===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions

class BASE_HLO_AbsOp {
  string summary = "Absolute value operator";

  string description = [{
    Returns `abs(operand)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
  }];
}

class BASE_HLO_CeilOp {
  string summary = "Ceil operator";

  string description = [{
    Returns `Ceil(operand)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions.
+ }]; +} + +class BASE_HLO_ClzOp { + string summary = "Count-leading-zeros (Clz) operator"; + + string description = [{ + Returns the number of leading zeros in each operand element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} + +class BASE_HLO_ConvertOp { + string summary = "Convert operator"; + + string description = [{ + Performs element-wise conversion of values from one type to another, e.g. + float to int. + + See https://www.tensorflow.org/xla/operation_semantics#convertelementtype. + }]; +} + +class BASE_HLO_CosOp { + string summary = "Cos operator"; + + string description = [{ + Returns `Cos(operand)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} + +class BASE_HLO_ExpOp { + string summary = "Exponential operator"; + + string description = [{ + Returns `e^(operand)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} + +class BASE_HLO_Expm1Op { + string summary = "Exponential minus one operator"; + + string description = [{ + Returns `e^(operand) - 1` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} + +class BASE_HLO_FloorOp { + string summary = "Floor operator"; + + string description = [{ + Returns `Floor(operand)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} + +class BASE_HLO_GetDimensionSizeOp { + string summary = "GetDimensionSize operator"; + + string description = [{ + Returns the size of the given dimension of the operand. + + See + https://www.tensorflow.org/xla/operation_semantics#getdimensionsize. + }]; +} + +class BASE_HLO_ImagOp { + string summary = "Imag operator"; + + string description = [{ + Returns `Imag(operand)` element-wise. 
+ }]; +} + +class BASE_HLO_IsFiniteOp { + string summary = "IsFinite operator"; + + string description = [{ + Tests whether each element of operand is finite, i.e., is not positive or + negative infinity, and is not NaN. Returns a tensor of 1-bit integers with + the same shape as the input, where each element is nonzero (i.e. true) if + and only if the corresponding input element is finite. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} + +class BASE_HLO_LogOp { + string summary = "Logarithm operator"; + + string description = [{ + Returns `log(operand)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} + +class BASE_HLO_Log1pOp { + string summary = "Log1p operator"; + + string description = [{ + Returns `log(operand+1)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} + +class BASE_HLO_LogisticOp { + string summary = "Logistic operator"; + + string description = [{ + Returns `logistic(operand)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} + +class BASE_HLO_NegOp { + string summary = "Negation operator"; + + string description = [{ + Returns `-operand` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} + +class BASE_HLO_NotOp { + string summary = "Not operator"; + + string description = [{ + Returns `!operand` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} + +class BASE_HLO_PopulationCountOp { + string summary = "PopulationCount operator"; + + string description = [{ + Returns the number of bits set in each operand element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. 
+ }]; +} + +class BASE_HLO_RealOp { + string summary = "Real operator"; + + string description = [{ + Returns `Real(operand)` element-wise. + }]; +} + +class BASE_HLO_RoundOp { + string summary = "Round operator"; + + string description = [{ + Returns `Round(operand)` element-wise, rounding to nearest integer with + half-way cases rounding away from zero. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} + +class BASE_HLO_RsqrtOp { + string summary = "Reciprocal Square-root operator"; + + string description = [{ + Returns `1.0 / sqrt(operand)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} + +class BASE_HLO_SignOp { + string summary = "Sign operator"; + + string description = [{ + Returns `sign(operand)` element-wise, where + + ``` + sign(x) = -1 : x < 0 + = -0 : x = -0 + = NaN : x = NaN + = +0 : x = +0 + = 1 : x > 0 + ``` + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} + +class BASE_HLO_SinOp { + string summary = "Sin operator"; + + string description = [{ + Returns `Sin(operand)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} + +class BASE_HLO_SqrtOp { + string summary = "Square-root operator"; + + string description = [{ + Returns `sqrt(operand)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} + +class BASE_HLO_TanhOp { + string summary = "Tanh operator"; + + string description = [{ + Returns `tanh(operand)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions. + }]; +} + +//===----------------------------------------------------------------------===// +// XLA binary elementwise op definitions. 
+//===----------------------------------------------------------------------===// + +class BASE_HLO_AddOp { + string summary = "Addition operator"; + + string description = [{ + Returns `lhs + rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +class BASE_HLO_ComplexOp { + string summary = "Complex operator"; + + string description = [{ + Performs element-wise conversion of a pair of real and imaginary values to + a complex value. + }]; +} + +class BASE_HLO_DivOp { + string summary = "Division operator"; + + string description = [{ + Returns `lhs / rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +class BASE_HLO_MaxOp { + string summary = "Maximum operator"; + + string description = [{ + Returns `max(lhs, rhs)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +class BASE_HLO_MinOp { + string summary = "Minimum operator"; + + string description = [{ + Returns `min(lhs, rhs)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +class BASE_HLO_MulOp { + string summary = "Multiplication operator"; + + string description = [{ + Returns `lhs * rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} +class BASE_HLO_PowOp { + string summary = "Power operator"; + + string description = [{ + Returns `lhs ^ rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +class BASE_HLO_RemOp { + string summary = "Remainder operator"; + + string description = [{ + Returns `lhs % rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. 
+ }]; +} + +class BASE_HLO_SubOp { + string summary = "Subtraction operator"; + + string description = [{ + Returns `lhs - rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +class BASE_HLO_ShiftLeftOp { + string summary = "Shift Left operator"; + + string description = [{ + Returns `lhs << rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +class BASE_HLO_ShiftRightArithmeticOp { + string summary = "Shift right arithmetic operator"; + + string description = [{ + Returns arithmetic `lhs >> rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +class BASE_HLO_ShiftRightLogicalOp { + string summary = "Shift right logical operator"; + + string description = [{ + Returns logical `lhs >> rhs` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +class BASE_HLO_Atan2Op { + string summary = "Atan2 operator"; + + string description = [{ + Returns `atan2(lhs/rhs)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +class BASE_HLO_AndOp { + string summary = "Logical and"; + + string description = [{ + Returns `logical_and(lhs, rhs)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +class BASE_HLO_OrOp { + string summary = "Logical or"; + + string description = [{ + Returns `logical_or(lhs, rhs)` element-wise. + + See + https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations. + }]; +} + +class BASE_HLO_XorOp { + string summary = "Logical xor"; + + string description = [{ + Returns `logical_xor(lhs, rhs)` element-wise. 
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// XLA control flow related op definitions.
+//===----------------------------------------------------------------------===//
+
+class BASE_HLO_CaseOp {
+  string summary = "Switch-Case operator";
+
+  string description = [{
+    Returns the result of executing `branches[index]`. If
+    `index` is < 0 or >= N, then `branches[N-1]` is executed as
+    the default branch.
+
+    Each branch `branches[b]` must take in a single argument of same type as
+    `branch_operands[b]` and will be invoked with `branch_operands[b]`. The type
+    of the returned value of each branch must be the same.
+
+    Note that only one of the branches will be executed depending on the value
+    of index.
+    See https://www.tensorflow.org/xla/operation_semantics#conditional.
+  }];
+
+}
+
+//===----------------------------------------------------------------------===//
+// XLA parallelism related op definitions.
+//===----------------------------------------------------------------------===//
+
+// Represents a unique identifier for each Send/Recv instruction pair or
+// optionally for collective instructions (AllReduce, CollectivePermute,
+// AllToAll). Non-positive channel_id handle is equivalent to no channel id.
+class ChannelHandle<Dialect dialect> : StructAttr<"ChannelHandle", dialect, [
+    StructFieldAttr<"handle", I64Attr>,
+    StructFieldAttr<"type", I64Attr>]> {
+  let description = "two 64-bit integers 'handle' and 'type'";
+}
+
+class BASE_HLO_ReplicaIdOp {
+  string summary = "ReplicaId operator";
+
+  string description = [{
+    Returns the unique ID (int32 scalar) of the replica.
+
+    The unique ID of each replica is an unsigned integer in the interval [0, N),
+    where N is the number of replicas.
Since all the replicas are running the + same program, a ReplicaId() call in the program will return a different + value on each replica. + + See https://www.tensorflow.org/xla/operation_semantics#replicaid. + }]; +} + + +class BASE_HLO_AllReduceOp { + string summary = "AllReduce operator"; + + string description = [{ + Performs a custom reduction across replicas. + + See https://www.tensorflow.org/xla/operation_semantics#allreduce. + }]; +} + +class BASE_HLO_ReduceOp { + string summary = "Reduce operator"; + + string description = [{ + Returns the result of executing a reduction function on one or more arrays + in parallel. + + See https://www.tensorflow.org/xla/operation_semantics#reduce. + }]; +} + +class BASE_HLO_ReduceWindowOp { + string summary = "ReduceWindow operator"; + + string description = [{ + Returns the result of executing a reduction function over all elements in + each window of one or more arrays in parallel. + + See https://www.tensorflow.org/xla/operation_semantics#reducewindow. + }]; +} + +//===----------------------------------------------------------------------===// +// XLA tuple op definitions. +//===----------------------------------------------------------------------===// +class BASE_HLO_GetTupleElementOp { + string summary = "GetTupleElement operator"; + + string description = [{ + Returns a member of a tuple specified by an index. + + See https://www.tensorflow.org/xla/operation_semantics#gettupleelement. + }]; +} + +class BASE_HLO_TupleOp { + string summary = "XLA's tuple op"; + + string description = [{ + Groups a set of tensor inputs into a single tuple object. + + See https://www.tensorflow.org/xla/operation_semantics#tuple. + }]; +} + +//===----------------------------------------------------------------------===// +// Precision Config enum definitions. +//===----------------------------------------------------------------------===// + +// These mirror the XLA PrecisionConfig proto enum. 
+def HLO_PRECISION_DEFAULT : StrEnumAttrCase<"DEFAULT">;
+def HLO_PRECISION_HIGH : StrEnumAttrCase<"HIGH">;
+def HLO_PRECISION_HIGHEST : StrEnumAttrCase<"HIGHEST">;
+
+def HLO_PrecisionAttr : StrEnumAttr<"Precision",
+    "XLA precision for an operand. Has backend specific meaning.",
+    [HLO_PRECISION_DEFAULT, HLO_PRECISION_HIGH, HLO_PRECISION_HIGHEST]>;
+
+// TODO(b/129153247) See if it's possible to also validate the size.
+def HLO_PrecisionConfigAttr:
+    OptionalAttr<
+        TypedArrayAttrBase<HLO_PrecisionAttr, "Precision Config attribute">>;
+
+//===----------------------------------------------------------------------===//
+// Fast Fourier Transform Type enum definitions.
+//===----------------------------------------------------------------------===//
+
+// These mirror the XLA FftType proto enum.
+def HLO_FFT_TYPE_FFT : StrEnumAttrCase<"FFT">;
+def HLO_FFT_TYPE_IFFT : StrEnumAttrCase<"IFFT">;
+def HLO_FFT_TYPE_RFFT : StrEnumAttrCase<"RFFT">;
+def HLO_FFT_TYPE_IRFFT : StrEnumAttrCase<"IRFFT">;
+
+def HLO_FftTypeAttr : StrEnumAttr<"FftType",
+    "XLA fast fourier transform type.",
+    [HLO_FFT_TYPE_FFT, HLO_FFT_TYPE_IFFT,
+     HLO_FFT_TYPE_RFFT, HLO_FFT_TYPE_IRFFT]>;
+
+//===----------------------------------------------------------------------===//
+// Comparison op definitions.
+//===----------------------------------------------------------------------===//
+
+// These mirror the XLA ComparisonDirection enum.
+def HLO_COMPARISON_DIRECTION_EQ : StrEnumAttrCase<"EQ">;
+def HLO_COMPARISON_DIRECTION_NE : StrEnumAttrCase<"NE">;
+def HLO_COMPARISON_DIRECTION_GE : StrEnumAttrCase<"GE">;
+def HLO_COMPARISON_DIRECTION_GT : StrEnumAttrCase<"GT">;
+def HLO_COMPARISON_DIRECTION_LE : StrEnumAttrCase<"LE">;
+def HLO_COMPARISON_DIRECTION_LT : StrEnumAttrCase<"LT">;
+
+def HLO_ComparisonDirectionAttr : StrEnumAttr<"ComparisonDirection",
+    "Which comparison operation to perform.",
+    [
+      HLO_COMPARISON_DIRECTION_EQ,
+      HLO_COMPARISON_DIRECTION_NE,
+      HLO_COMPARISON_DIRECTION_GE,
+      HLO_COMPARISON_DIRECTION_GT,
+      HLO_COMPARISON_DIRECTION_LE,
+      HLO_COMPARISON_DIRECTION_LT
+    ]>;
+
+class BASE_HLO_CompareOp {
+  string summary = "Comparison operator";
+
+  string description = [{
+    Compares `lhs` and `rhs` elementwise according to `comparison_direction`.
+
+    See
+    https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations.
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// Quantize op definitions.
+//===----------------------------------------------------------------------===//
+
+// These mirror the XLA DequantizeMode enum.
+def HLO_MIN_COMBINED : StrEnumAttrCase<"MIN_COMBINED">;
+
+def HLO_DequantizeModeAttr : StrEnumAttr<"DequantizeMode",
+    "Dequantization mode. Only MIN_COMBINED is supported.",
+    [HLO_MIN_COMBINED]>;
+
+class BASE_HLO_DequantizeOp {
+  string summary = "Dequantize operator";
+
+  string description = [{
+    Dequantize the quantized input of packed uint32 to bfloat16. Only uint8 or
+    uint16 is supported for the original unpacked input.
+
+    Returns a tensor of shape [d0,..., dn * unpack_size] if unpacked input shape
+    is [d0, ..., dn], where unpack_size = sizeof(uint32) / sizeof(T), where T is
+    the unpacked input type. If transpose_output is true, will return a tensor
+    of shape [dn * unpack_size, dn-1, ..., d1, d0]. transpose_output is faster
+    when input's rank is higher than 1.
The input needs to be transposed to use + transpose_output feature. + }]; +} + +//===----------------------------------------------------------------------===// +// XLA Slice definitions. +//===----------------------------------------------------------------------===// + +class BASE_HLO_SliceOp { + string summary = "Slice operator"; + + string description = [{ + Slices a portion of the `operand` into a new configuration. + + See https://www.tensorflow.org/xla/operation_semantics#slice. + }]; +} + +class BASE_HLO_DynamicSliceOp { + string summary = "Dynamic Slice operator"; + + string description = [{ + Extracts a sub-array from the input array at dynamic start_indices. + + See https://www.tensorflow.org/xla/operation_semantics#dynamicslice. + }]; +} + +class BASE_HLO_DynamicUpdateSliceOp { + string summary = "Dynamic Update Slice operator"; + + string description = [{ + DynamicUpdateSlice generates a result which is the value of the input array + operand, with a slice update overwritten at start_indices. + + See https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice. + }]; +} + +//===----------------------------------------------------------------------===// +// XLA Other op definitions. +//===----------------------------------------------------------------------===// + +class BASE_HLO_AllToAllOp { + string summary = "AllToAll"; + + string description = [{ + AllToAll is a collective operation that sends data from all cores to all + cores. It has two phases: + - The scatter phase. On each core, the operand is split into `split_count` + number of blocks along the `split_dimension`, and the blocks are + scattered to all cores, e.g., the i-th block is sent to the i-th core. + - The gather phase. Each core concatenates the received blocks along the + `concat_dimension`. 
+ + The participating cores can be configured by: + - replica_groups: each ReplicaGroup contains a list of replica id + participating in the computation (replica id for the current replica can + be retrieved using ReplicaId op). AllToAll will be applied within + subgroups in the specified order. For example, + `replica_groups` = {{1,2,3}, {4,5,0}} means that an AllToAll will be applied + within replicas {1, 2, 3}, and in the gather phase, the received blocks + will be concatenated in the same order of 1, 2, 3. Then, another AllToAll + will be applied within replicas 4, 5, 0, and the concatenation order is + also 4, 5, 0. If `replica_groups` is empty, all replicas belong to one + group, and the concatenation order is the numerical order (0, 1, 2, ...). + + Prerequisites: + - The dimension size of the operand on the split_dimension is divisible by + `split_count`. + - The operand's shape is not tuple. + + See https://www.tensorflow.org/xla/operation_semantics#alltoall + }]; +} + +class BASE_HLO_BatchNormGradOp { + string summary = "Batch Normalization Gradient"; + + string description = [{ + Calculates gradients of batch norm. + + See https://www.tensorflow.org/xla/operation_semantics#batchnormgrad + }]; +} + +class BASE_HLO_BatchNormInferenceOp { + string summary = "Batch Normalization for Inference"; + + string description = [{ + Normalizes an array across batch and spatial dimensions. + + See https://www.tensorflow.org/xla/operation_semantics#batchnorminference + }]; +} + +class BASE_HLO_BatchNormTrainingOp { + string summary = "Batch Normalization for Training"; + + string description = [{ + Normalizes an array across batch and spatial dimensions. + + See https://www.tensorflow.org/xla/operation_semantics#batchnormtraining + }]; +} + +class BASE_HLO_BitcastConvertOp { + string summary = "BitcastConvert operator"; + + string description = [{ + Similar to a 'tf.bitcast' in TensorFlow, performs an element-wise bitcast + operation from a data shape to a target shape. 
The dimensions must match,
+    and the conversion is an element-wise one. Bitcast is implemented as a
+    low-level cast, so machines with different floating-point representations
+    will give different results.
+
+    See https://www.tensorflow.org/xla/operation_semantics#bitcastconverttype.
+  }];
+}
+
+class BASE_HLO_BroadcastOp {
+  string summary = "Broadcast a tensor to a higher rank by prepending dimensions";
+
+  string description = [{
+    Broadcasts the operand tensor to a higher rank by prepending
+    `broadcast_sizes` to the dimensions. The current values of the operand are
+    copied into the other dimensions.
+
+    This is a more limited form of broadcasting, that corresponds to the XLA
+    client Broadcast method. For a more general form of broadcasting, see the
+    BroadcastInDimOp.
+
+    See https://www.tensorflow.org/xla/operation_semantics#broadcast.
+  }];
+}
+
+class BASE_HLO_BroadcastInDimOp {
+  string summary = "Broadcast a tensor into the given shape by adding dimensions.";
+
+  string description = [{
+    Broadcasts the `operand` tensor to a higher rank. This is not the limited
+    form of broadcasting exposed as the XLA client broadcast op, but rather the
+    more powerful "InDim" broadcasting, which is closer to the HLO broadcast op
+    and exposed in the XLA client BroadcastInDim method.
+
+    `broadcast_dimensions` maps the operand dimension number to the target shape
+    dimension number. It must have the same size as the rank of the operand. The
+    mapped dimensions must either be the same size or the dimension being
+    broadcast from must be size 1 (degenerate broadcasting).
+
+    For a scalar (0D tensor) operand, `broadcast_dimensions` must be empty. The
+    scalar value will be broadcast to every element in the target shape.
+
+    See https://www.tensorflow.org/xla/broadcasting.
+ }]; +} + +class BASE_HLO_CholeskyOp { + string summary = "Cholesky operator"; + + string description = [{ + Computes the Cholesky decomposition of a batch of symmetric (Hermitian) + positive definite matrices. + + If lower is true, computes lower-triangular matrices l such that + `a=l.Transpose(l)`. If lower is false, computes upper-triangular matrices u such + that `a=Transpose(u).u`. + + Input data is read only from the lower/upper triangle of a, depending on the + value of lower. Values from the other triangle are ignored. Output data is + returned in the same triangle; the values in the other triangle are + implementation-defined and may be anything. + + If the rank of a is greater than 2, a is treated as a batch of matrices, where + all except the minor 2 dimensions are batch dimensions. + + If a is not symmetric (Hermitian) positive definite, the result is + implementation-defined. + + See https://www.tensorflow.org/xla/operation_semantics#cholesky. + }]; +} + +class BASE_HLO_ClampOp { + string summary = "Clamp operator"; + + string description = [{ + Clamps an operand to within the range between a minimum and maximum value. + + Note: All three arrays must be the same shape. Alternatively, as a + restricted form of broadcasting, min and/or max can be a scalar (0D + tensor) of the element type of the tensor operand. + + See https://www.tensorflow.org/xla/operation_semantics#clamp. + }]; +} + +class BASE_HLO_CollectivePermuteOp { + string summary = "CollectivePermute operator"; + + string description = [{ + CollectivePermute is a collective operation that sends and receives data + cross replicas. + Note that there are the following restrictions on the source_target_pair: + - Any two pairs should not have the same target replica id, and they should + not have the same source replica id. + - If a replica id is not a target in any pair, then the output on that + replica is a tensor consists of 0(s) with the same shape as the input. 
+ + See https://www.tensorflow.org/xla/operation_semantics#collectivepermute. + + }]; +} +class BASE_HLO_ConcatenateOp { + string summary = "XLA's concatenate op"; + + string description = [{ + Concatenates a set of tensors along the specified dimension. + + See https://www.tensorflow.org/xla/operation_semantics#concatenate. + }]; +} + +class BASE_HLO_ConvOp { + string summary = "Convolution operator"; + + string description = [{ + Computes a convolution of the kind used in neural networks. + + See https://www.tensorflow.org/xla/operation_semantics#conv_convolution. + }]; +} + +class BASE_HLO_CopyOp { + string summary = "Copy operator"; + + string description = [{ + Returns a copy of `operand`. + }]; +} + +class BASE_HLO_CrossReplicaSumOp { + string summary = "Sums input across replicated instances."; + + string description = [{ + For each of the replica groups, operands of the group devices are summed + so that each device has the sum. + + For example, suppose there are 8 TPU devices: `[A, B, C, D, E, F, G, H]`. + Passing group_assignment=`[[0,2,4,6],[1,3,5,7]]` sets `A, C, E, G` as group 0, + and `B, D, F, H` as group 1. Thus we get the outputs: + `[A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H]`. + + See https://www.tensorflow.org/xla/operation_semantics#crossreplicasum. + }]; +} + + +class BASE_HLO_CustomCallOp { + string summary = "CustomCall operator"; + + string description = [{ + A custom call invokes code external to XLA. The `args` are passed to the + external code, and the external code is expected to produce a result of the + given type. The exact mechanism is backend-specific. For example, in the CPU + backend, a call instruction is emitted which targets a symbol with the name + `call_target_name`. + + `call_target_name` and `backend_config` can be arbitrary strings, but + `call_target_name` should be short as it may be used in labels. + `backend_config` can encode arbitrarily large amounts of information. 
+ + See https://www.tensorflow.org/xla/operation_semantics#customcall. + }]; +} + +class BASE_HLO_DotOp { + string summary = "Dot operator"; + string description = [{ + Performs dot products between vectors, vector/matrix and matrix/matrix + multiplication. + + See https://www.tensorflow.org/xla/operation_semantics#dot. + }]; +} + +class BASE_HLO_DotGeneralOp { + string summary = "General Dot operator"; + string description = [{ + Performs general dot products between vectors, vector/matrix and + matrix/matrix multiplication. + + See https://www.tensorflow.org/xla/operation_semantics#dotgeneral. + }]; +} + +class BASE_HLO_FftOp { + string summary = "Fast fourier transform operator"; + + string description = [{ + Returns the fast-fourier-transform of the input array. + + See + https://www.tensorflow.org/xla/operation_semantics#fft. + }]; +} + +class BASE_HLO_GatherOp{ + string summary = "Gather operator"; + + string description = [{ + Stitches together several slices of an input array. + + See https://www.tensorflow.org/xla/operation_semantics#gather. + }]; +} + +class BASE_HLO_MapOp { + string summary = "Map operator"; + + string description = [{ + Applies a scalar function over the given operands arrays, producing an array + of the same dimensions where each element is the result of the mapped function + applied to the corresponding elements in the input arrays. + + The mapped function is an arbitrary computation with the restriction that it + has N inputs of scalar type T and a single output with type S. The output has + the same dimensions as the operands except that the element type T is replaced + with S. + + See https://www.tensorflow.org/xla/operation_semantics#map. + }]; +} + +class BASE_HLO_ReshapeOp { + string summary = "Reshape operator"; + + string description = [{ + Reshapes the dimensions of `operand` into a new configuration. + + See https://www.tensorflow.org/xla/operation_semantics#reshape. 
+  }];
+}
+
+class ScatterDimensionNumbers<Dialect dialect> : StructAttr<
+    "ScatterDimensionNumbers", dialect, [
+      StructFieldAttr<"update_window_dims", I64ElementsAttr>,
+      StructFieldAttr<"inserted_window_dims", I64ElementsAttr>,
+      StructFieldAttr<"scatter_dims_to_operand_dims", I64ElementsAttr>,
+      StructFieldAttr<"index_vector_dim", I64Attr>]> {
+  let description = "Structure of dimension information for scatter";
+}
+
+class BASE_HLO_ScatterOp {
+  string summary = "Scatter operator";
+
+  string description = [{
+    Generates a result which is the value of the input array `operand`,
+    with several slices (at indices specified by `scatter_indices`)
+    updated with the values in `updates` using `update_computation`.
+
+    See https://www.tensorflow.org/xla/operation_semantics#scatter.
+  }];
+}
+
+class BASE_HLO_SelectOp {
+  string summary = "Select operator";
+
+  string description = [{
+    Constructs an output tensor from the elements of `on_true` and `on_false`
+    based on the values of `pred`.
+
+    `pred`, `on_true` and `on_false` must be broadcast compatible.
+  }];
+}
+
+class BASE_HLO_SelectAndScatterOp {
+  string summary = "SelectAndScatter operator";
+
+  string description = [{
+    Runs a windowed selection `select` function over `operand` with shape
+    `window_dimensions` and stride `window_strides`. This will produce an amount
+    of selected locations whose shape matches `source`. These are then scattered
+    to the output which is initialized with `init_value`.
+    Multiple scattered elements which land in the same output location are
+    combined using the `scatter` function.
+
+    See https://www.tensorflow.org/xla/operation_semantics#selectandscatter.
+  }];
+}
+
+class BASE_HLO_SetDimensionSizeOp {
+  string summary = "SetDimensionSize operator";
+
+  string description = [{
+    Sets the dynamic size of operand's given dimension. Pass through the operand
+    as result, with dynamic dimension tracked by the compiler. Padded values
+    will be ignored by downstream reduction ops.
+ + See https://www.tensorflow.org/xla/operation_semantics#setdimensionsize. + }]; +} + +class BASE_HLO_SortOp { + string summary = "Sort operator"; + + string description = [{ + Sorts the given `operands` at the given `dimension` with the given + `comparator`. + + See https://www.tensorflow.org/xla/operation_semantics#sort. + }]; +} + +class BASE_HLO_ReverseOp { + string summary = "Reverse operator"; + + string description = [{ + Reverses the specified dimensions of `operand` according to the given + `dimensions`. + + See https://www.tensorflow.org/xla/operation_semantics#rev_reverse. + }]; +} + +class BASE_HLO_PadOp { + string summary = "Pad operator"; + + string description = [{ + Pads the edges of `operand` with the `padding_value` and according to + the passed configuration. + + See https://www.tensorflow.org/xla/operation_semantics#pad. + }]; +} + +class BASE_HLO_TraceOp { + string summary = "Trace operator"; + + string description = [{ + Emits a logging message `tag` with the `operand`. + }]; +} + +class BASE_HLO_TransposeOp { + string summary = "Transpose operator"; + + string description = [{ + Permutes the dimensions of `operand` according to the given `permutation`. + + `res_dimensions[i] = operand_dimensions[permutation[i]]` + + See https://www.tensorflow.org/xla/operation_semantics#transpose. + }]; +} + +// These mirror the XLA Transpose enum in Triangular Solve options. 
+def HLO_TRANSPOSE_INVALID : StrEnumAttrCase<"TRANSPOSE_INVALID">; +def HLO_NO_TRANSPOSE : StrEnumAttrCase<"NO_TRANSPOSE">; +def HLO_TRANSPOSE : StrEnumAttrCase<"TRANSPOSE">; +def HLO_ADJOINT : StrEnumAttrCase<"ADJOINT">; + +def HLO_TransposeAttr : StrEnumAttr<"Transpose", + "Transpose options", + [ + HLO_TRANSPOSE_INVALID, + HLO_NO_TRANSPOSE, + HLO_TRANSPOSE, + HLO_ADJOINT + ]>; + +class BASE_HLO_TriangularSolveOp { + string summary = "TriangularSolve operator"; + + string description = [{ + Solves systems of linear equations with lower or upper triangular + coefficient matrices by forward- or back-substitution. Broadcasting along + leading dimensions, this routine solves one of the matrix systems + op(a) * x = b, or x * op(a) = b, for the variable x, given a and b, where + op(a) is either op(a) = a, or op(a) = Transpose(a), or + op(a) = Conj(Transpose(a)). + + Input data is read only from the lower/upper triangle of a, depending on the + value of lower. Values from the other triangle are ignored. Output data is + returned in the same triangle; the values in the other triangle are + implementation-defined and may be anything. + + If the rank of a and b are greater than 2, they are treated as batches of + matrices, where all except the minor 2 dimensions are batch dimensions. a + and b must have equal batch dimensions. + + See https://www.tensorflow.org/xla/operation_semantics#triangularsolve. + }]; + +} + +class BASE_HLO_RngUniformOp { + string summary = "RNG with uniform distribution."; + + string description = [{ + Constructs an output of a given shape with random numbers generated + following the uniform distribution over the interval `[a,b)`. The parameters + and output element type have to be a boolean type, an integral type or a + floating point types, and the types have to be consistent. + + See https://www.tensorflow.org/xla/operation_semantics#rnguniform. 
+  }]; +} + +class BASE_HLO_RngNormalOp { + string summary = "RNG with normal distribution."; + + string description = [{ + Constructs an output of a given shape with random numbers generated + following the normal distribution with parameters `mu` and `sigma`. The + parameters and output shape have to have a floating point elemental type. + The parameters furthermore have to be scalar valued. + + See https://www.tensorflow.org/xla/operation_semantics#rngnormal. + }]; +} + +class BASE_HLO_ReducePrecisionOp { + string summary = "Reduce precision operator"; + + string description = [{ + Models the effect of converting floating-point values to a lower- + precision format (such as IEEE-FP16) and back to the original + format. The number of exponent and mantissa bits in the lower- + precision format can be specified arbitrarily, + although all bit sizes may not be supported on all hardware + implementations. + + See https://www.tensorflow.org/xla/operation_semantics#reduceprecision. + }]; +} + +class BASE_HLO_InfeedOp { + string summary = "Infeed operator"; + + string description = [{ + Reads a single data item from the implicit Infeed streaming interface of + the device, interpreting the data as the given shape and its layout, and + returns an LHLO op of the data. Multiple Infeed operations are allowed in a + computation, but there must be a total order among the Infeed operations. + For example, two Infeeds in the code below have a total order since there + is a dependency between the while loops. + + See https://www.tensorflow.org/xla/operation_semantics#infeed + }]; +} + +class BASE_HLO_WhileOp { + string summary = "While operator"; + + string description = [{ + Returns the result of executing a body function until the cond body returns + true. + + See https://www.tensorflow.org/xla/operation_semantics#while.
+ }]; +} + +#endif // HLO_OPS_BASE diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td b/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td new file mode 100644 index 0000000..fdb301d --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td @@ -0,0 +1,43 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is the utils file for the HLO dialect. + +#ifndef HLO_UTILS +#define HLO_UTILS + +include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td" + +def NullArrayAttr : NativeCodeCall<"ArrayAttr()">; + +def CastIntElementsAttr : NativeCodeCall<"$0.cast()">; + +class ConstantSplat : NativeCodeCall< + "xla::getSplat(&$_builder, $0, " # value # ")">; + +def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">; + +def BinBroadcastDimensions : NativeCodeCall< + "xla::getBroadcastDimensionsAttr(&$_builder, $0, $1)">; + +def BinBroadcastDimensionsNonEmpty : NativeCodeCall< + "xla::getBroadcastDimensionsAttr(&$_builder, $0, $1, /*allow_empty=*/false)">; + +// Here, the element type can be any integer or float type. But, note that only +// 32 bit integers are supported for the value. 
+class GetScalarOfType : NativeCodeCall< + "xla::GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">; + +#endif // HLO_UTILS diff --git a/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h b/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h new file mode 100644 index 0000000..412711b --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h @@ -0,0 +1,28 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_INFER_FUSIBILITY_OP_INTERFACE_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_IR_INFER_FUSIBILITY_OP_INTERFACE_H_ + +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" + +namespace mlir { + +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h.inc" + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_INFER_FUSIBILITY_OP_INTERFACE_H_ diff --git a/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td b/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td new file mode 100644 index 0000000..017a185 --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td @@ -0,0 +1,161 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file contains InferFusibilityOpInterface, which is used to guide +// fusion decisions. + +#ifndef MLIR_INFER_FUSIBILITY_OP_INTERFACE +#define MLIR_INFER_FUSIBILITY_OP_INTERFACE + +include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td" + +// OpInterface to query if an op is fusible and to query the shape equality +// constraint among the inputs and outputs of an op. +def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> { + let description = [{ + Interface to query if an op is fusible and to query the shape equality + constraint among the inputs and outputs of an op.
+  }]; + + let methods = [ + InterfaceMethod< + /*desc=*/[{If true, this op can be fused with its operands + }], + /*retTy=*/"bool", + /*methodName=*/"isFusibleWithOperand", + /*args=*/(ins), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ + /// Returns whether this op can be fused with its operands + return true; + }] + >, + InterfaceMethod< + /*desc=*/[{If true, this op can be fused with its consumers + }], + /*retTy=*/"bool", + /*methodName=*/"isFusibleWithConsumer", + /*args=*/(ins), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ + /// Return whether this op can be fused with its consumers + return true; + }] + >, + InterfaceMethod< + /*desc=*/"Return whether two inputs have the same shape (assuming no" + " implicit broadcasting).", + /*retTy=*/"bool", + /*methodName=*/"inferInputsShapeEquality", + /*args=*/(ins "int":$lhs, "int":$rhs), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ + /// Return whether two inputs have the same shape. + Operation *op = this->getOperation(); + assert(lhs < op->getNumOperands() && lhs >= 0 && + rhs < op->getNumOperands() && rhs >= 0); + if (lhs == rhs) return true; + + // if both lhs and rhs have static shapes, check them directly + Type lhs_ty = op->getOperand(lhs).getType(); + Type rhs_ty = op->getOperand(rhs).getType(); + auto lhs_shape_type = lhs_ty.dyn_cast_or_null(); + auto rhs_shape_type = rhs_ty.dyn_cast_or_null(); + if (!lhs_shape_type || !lhs_shape_type.hasStaticShape() || + !rhs_shape_type || !rhs_shape_type.hasStaticShape() || + lhs_shape_type.getRank() != rhs_shape_type.getRank()) { + return false; + } + return lhs_shape_type.getShape() == rhs_shape_type.getShape(); + }] + >, + InterfaceMethod< + /*desc=*/"Return whether two outputs have the same shape (assuming no" + " implicit broadcasting).", + /*retTy=*/"bool", + /*methodName=*/"inferOutputsShapeEquality", + /*args=*/(ins "int":$lhs, "int":$rhs), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ + /// Return whether two outputs have the
same shape. + Operation *op = this->getOperation(); + assert(lhs < op->getNumResults() && lhs >= 0 && + rhs < op->getNumResults() && rhs >= 0); + if (lhs == rhs) return true; + + // if both lhs and rhs have static shapes, check them directly + Type lhs_ty = op->getResult(lhs).getType(); + Type rhs_ty = op->getResult(rhs).getType(); + auto lhs_shape_type = lhs_ty.dyn_cast_or_null(); + auto rhs_shape_type = rhs_ty.dyn_cast_or_null(); + if (!lhs_shape_type || !lhs_shape_type.hasStaticShape() || + !rhs_shape_type || !rhs_shape_type.hasStaticShape() || + lhs_shape_type.getRank() != rhs_shape_type.getRank()) { + return false; + } + return lhs_shape_type.getShape() == rhs_shape_type.getShape(); + }] + >, + InterfaceMethod< + /*desc=*/"Return whether the input and the output have the same" + " shape (assuming no implicit broadcasting).", + /*retTy=*/"bool", + /*methodName=*/"inferInputOutputShapeEquality", + /*args=*/(ins "int":$input, "int":$output), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ + /// Return whether the input and the output have the same shape. + Operation *op = this->getOperation(); + assert(input < op->getNumOperands() && input >= 0 && + output < op->getNumResults() && output >= 0); + + // if both input and output have static shapes, check them directly + Type input_ty = op->getOperand(input).getType(); + Type output_ty = op->getResult(output).getType(); + auto input_shape_type = input_ty.dyn_cast_or_null(); + auto output_shape_type = output_ty.dyn_cast_or_null(); + if (!input_shape_type || !input_shape_type.hasStaticShape() || + !output_shape_type || !output_shape_type.hasStaticShape() || + input_shape_type.getRank() != output_shape_type.getRank()) { + return false; + } + return input_shape_type.getShape() == output_shape_type.getShape(); + }] + >, + InterfaceMethod< + /*desc=*/[{Return the effective workload shape for the operation. 
+ + + Here the effective workload shape roughly represents the maximum + parallelism that can be used during the codegen stage. It's used to check + the shape-compatibility of the operation. During fusion, we only + try to fuse shape-compatible ops for performance. + For example, the effective workload shape of an elementwise op is its + output shape, while the effective workload shape of a reduction op may + be its operand shape. + Return None if such an inference is not possible. + }], + /*retTy=*/"llvm::Optional", + /*methodName=*/"inferEffectiveWorkloadShape", + /*args=*/(ins), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ + /// Return effective workload size if possible, otherwise None. + return {}; + }] + >, + ]; +} + +#endif diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h new file mode 100644 index 0000000..554252a --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h @@ -0,0 +1,52 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the operations used in the LXLA dialect.
+ +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_LHLO_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_IR_LHLO_OPS_H_ + +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpDefinition.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/ViewLikeInterface.h" + +namespace mlir { +class OpBuilder; + +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.h.inc" + +namespace xla_lhlo { + +class XlaLhloDialect : public Dialect { + public: + explicit XlaLhloDialect(MLIRContext *context); + static StringRef getDialectNamespace() { return "xla_lhlo"; } +}; + +#define GET_OP_CLASSES +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc" + +} // namespace xla_lhlo +} // end namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_LHLO_OPS_H_ diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td new file mode 100644 index 0000000..4e4235d --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -0,0 +1,796 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is the operation definition file for LXLA. +// +// This file largely overlaps with hlo_ops.td at a logic level. It's tempting to +// merge these two files together, but we need to consider the following +// obstacles: +// * We need to have a common representation for arguments. That is to say, +// HLO_Array translates to HLO_Tensor in HLO dialect, and +// Arg, "", [Mem(Read|Write)]> in LHLO. Array types within tuples +// also need to be transformed. +// * As of now, TableGen's dag functions are not sufficient to accomplish the +// one above. +// * Traits aren't identical, but need to be copied. For example, +// SameOperandAndResultType in HLO corresponds to SameTypeOperands in LHLO. +// * Also, currently HLO describes the API in XLA's client side, not service +// side. LHLO aims for the service side. + +#ifndef LHLO_OPS +#define LHLO_OPS + +include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td" +include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.td" +include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/ViewLikeInterface.td" +include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" + +def LHLO_Dialect : Dialect { + let name = "xla_lhlo"; + let cppNamespace = "xla_lhlo"; +} + +//===----------------------------------------------------------------------===// +// XLA type definitions.
+//===----------------------------------------------------------------------===// + +// Any integer tensor types +def LHLO_IntBuffer : MemRefOf<[HLO_Int]>; + +// Any floating-point tensor types +def LHLO_FpBuffer : MemRefOf<[AnyFloat]>; + +def LHLO_ComplexBuffer : MemRefOf<[AnyComplex]>; + +def LHLO_FpOrComplexBuffer : MemRefOf<[AnyFloat, AnyComplex]>; + +def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>; + +// Any integer or floating-point tensor types +def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>; + +def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>; + +def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>; + +//===----------------------------------------------------------------------===// +// XLA nullary op definitions. +//===----------------------------------------------------------------------===// + +class LHLO_Op traits> : + Op], traits)>; + +def LHLO_ConstOp : LHLO_Op<"constant", []>, BASE_HLO_ConstOp { + let arguments = (ins + ElementsAttr:$value, + Arg:$output + ); +} + +def LHLO_IotaOp : LHLO_Op<"iota", []>, BASE_HLO_IotaOp { + let arguments = (ins I64Attr:$iota_dimension, + Arg:$output); +} + +//===----------------------------------------------------------------------===// +// XLA unary elementwise op definitions. +//===----------------------------------------------------------------------===// +// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions + +class LHLO_UnaryElementwiseOp traits = [SameTypeOperands]> + : LHLO_Op { + let arguments = (ins Arg:$input, + Arg:$output); +} + +def LHLO_AbsOp: LHLO_UnaryElementwiseOp<"abs">, BASE_HLO_AbsOp; + +// TODO(timshen): add a custom verifier. 
+def LHLO_BitcastConvertOp: + LHLO_UnaryElementwiseOp<"bitcast_convert", LHLO_Buffer, [SameOperandsShape]>, BASE_HLO_BitcastConvertOp; + +def LHLO_CeilOp: LHLO_UnaryElementwiseOp<"ceil", LHLO_FpBuffer>, BASE_HLO_CeilOp; + +def LHLO_ClzOp: LHLO_UnaryElementwiseOp<"count_leading_zeros", LHLO_IntBuffer>, BASE_HLO_ClzOp; + +// TODO(timshen): add a custom verifier. +def LHLO_ConvertOp : LHLO_UnaryElementwiseOp<"convert", LHLO_Buffer, [SameOperandsShape]>, BASE_HLO_ConvertOp; + +def LHLO_CosOp: LHLO_UnaryElementwiseOp<"cosine", LHLO_FpOrComplexBuffer>, BASE_HLO_CosOp; + +def LHLO_ExpOp: LHLO_UnaryElementwiseOp<"exponential", LHLO_FpOrComplexBuffer>, BASE_HLO_ExpOp; + +def LHLO_Expm1Op: LHLO_UnaryElementwiseOp<"exponential_minus_one", LHLO_FpOrComplexBuffer>, BASE_HLO_Expm1Op; + +def LHLO_FloorOp: LHLO_UnaryElementwiseOp<"floor", LHLO_FpBuffer>, BASE_HLO_FloorOp; + +def LHLO_ImagOp: LHLO_Op<"imag", [SameOperandsShape]>, BASE_HLO_ImagOp { + let arguments = (ins Arg:$input, + Arg:$output); +} + +def LHLO_IsFiniteOp: LHLO_Op<"is_finite", [SameOperandsShape]>, BASE_HLO_IsFiniteOp { + let arguments = (ins Arg:$input, + Arg:$output); +} + +def LHLO_LogOp: LHLO_UnaryElementwiseOp<"log", LHLO_FpOrComplexBuffer>, BASE_HLO_LogOp; + +def LHLO_Log1pOp: LHLO_UnaryElementwiseOp<"log_plus_one", LHLO_FpOrComplexBuffer>, BASE_HLO_Log1pOp; + +def LHLO_NegOp: LHLO_UnaryElementwiseOp<"negate">, BASE_HLO_NegOp; + +def LHLO_NotOp: LHLO_UnaryElementwiseOp<"not", LHLO_PredOrIntBuffer>, BASE_HLO_NotOp; + +def LHLO_PopulationCountOp: LHLO_UnaryElementwiseOp<"popcnt", LHLO_IntBuffer>, BASE_HLO_PopulationCountOp; + +def LHLO_RealOp: LHLO_Op<"real", [SameOperandsShape]>, BASE_HLO_RealOp { + let arguments = (ins Arg:$input, + Arg:$output); +} + +def LHLO_RoundOp: LHLO_UnaryElementwiseOp<"round_nearest_afz", LHLO_FpBuffer>, BASE_HLO_RoundOp; + +def LHLO_RsqrtOp: LHLO_UnaryElementwiseOp<"rsqrt", LHLO_FpOrComplexBuffer>, BASE_HLO_RsqrtOp; + +def LHLO_SqrtOp: LHLO_UnaryElementwiseOp<"sqrt", 
LHLO_FpOrComplexBuffer>, BASE_HLO_SqrtOp; + +def LHLO_SignOp: LHLO_UnaryElementwiseOp<"sign">, BASE_HLO_SignOp; + +def LHLO_SinOp: LHLO_UnaryElementwiseOp<"sine", LHLO_FpOrComplexBuffer>, BASE_HLO_SinOp; + +def LHLO_TanhOp: LHLO_UnaryElementwiseOp<"tanh", LHLO_FpOrComplexBuffer>, BASE_HLO_TanhOp; + +//===----------------------------------------------------------------------===// +// XLA binary elementwise op definitions. +//===----------------------------------------------------------------------===// +// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations + +class LHLO_BinaryElementwiseOp traits = [SameTypeOperands]> : + LHLO_Op { + let arguments = (ins + Arg:$lhs, + Arg:$rhs, + Arg:$out, + OptionalAttr:$broadcast_dimensions + ); +} + +def LHLO_AddOp : LHLO_BinaryElementwiseOp<"add">, BASE_HLO_AddOp; + +def LHLO_AndOp: LHLO_BinaryElementwiseOp<"and", LHLO_PredOrIntBuffer>, BASE_HLO_AndOp; + +def LHLO_Atan2Op : LHLO_BinaryElementwiseOp<"atan2", LHLO_FpOrComplexBuffer>, BASE_HLO_Atan2Op; + +def LHLO_ComplexOp: LHLO_Op<"complex", [SameOperandsShape]>, BASE_HLO_ComplexOp { + let arguments = (ins + Arg:$lhs, + Arg:$rhs, + Arg:$output, + OptionalAttr:$broadcast_dimensions + ); +} + +def LHLO_DivOp : LHLO_BinaryElementwiseOp<"divide">, BASE_HLO_DivOp; + +def LHLO_MaxOp : LHLO_BinaryElementwiseOp<"maximum">, BASE_HLO_MaxOp; + +def LHLO_MinOp : LHLO_BinaryElementwiseOp<"minimum">, BASE_HLO_MinOp; + +def LHLO_MulOp : LHLO_BinaryElementwiseOp<"multiply">, BASE_HLO_MulOp; + +def LHLO_OrOp : LHLO_BinaryElementwiseOp<"or", LHLO_PredOrIntBuffer>, BASE_HLO_OrOp; + +def LHLO_PowOp : LHLO_BinaryElementwiseOp<"power">, BASE_HLO_PowOp; + +def LHLO_RemOp : LHLO_BinaryElementwiseOp<"remainder", LHLO_IntOrFpBuffer>, BASE_HLO_RemOp; + +def LHLO_ShiftLeftOp : LHLO_BinaryElementwiseOp<"shift_left", LHLO_IntBuffer>, BASE_HLO_ShiftLeftOp; + +def LHLO_ShiftRightArithmeticOp : LHLO_BinaryElementwiseOp<"shift_right_arithmetic", LHLO_IntBuffer>, 
BASE_HLO_ShiftRightArithmeticOp; + +def LHLO_ShiftRightLogicalOp : LHLO_BinaryElementwiseOp<"shift_right_logical", LHLO_IntBuffer>, BASE_HLO_ShiftRightLogicalOp; + +def LHLO_SubOp : LHLO_BinaryElementwiseOp<"subtract">, BASE_HLO_SubOp; + +def LHLO_XorOp : LHLO_BinaryElementwiseOp<"xor", LHLO_PredOrIntBuffer>, BASE_HLO_XorOp; + +//===----------------------------------------------------------------------===// +// XLA control flow op definitions. +//===----------------------------------------------------------------------===// + +// TODO(b/139813999): specify required function signature in a type-safe way. +def LHLO_ReduceOp: LHLO_Op<"reduce", [ + SameVariadicOperandSize, + SingleBlockImplicitTerminator<"TerminatorOp"> + ]>, BASE_HLO_ReduceOp { + let arguments = (ins + Arg, "", [MemRead]>:$operands, + Arg, "", [MemRead]>:$init_values, + Arg, "", [MemWrite]>:$out, + I64ElementsAttr:$dimensions + ); + + let regions = (region SizedRegion<1>:$body); +} + +def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [ + SingleBlockImplicitTerminator<"TerminatorOp"> + ]>, BASE_HLO_ReduceWindowOp { + + let arguments = (ins + Arg:$operand, + Arg:$init_value, + Arg:$out, + I64ElementsAttr:$window_dimensions, + // If strides or dilations attributes are missing then the default value is + // one for each of the input dimensions. Similarly, padding values are zero + // for both low and high in each of the dimensions, if not specified. + OptionalAttr:$window_strides, + OptionalAttr:$base_dilations, + OptionalAttr:$window_dilations, + OptionalAttr:$padding + ); + + let regions = (region SizedRegion<1>:$body); +} + +// TODO(timshen): Add a custom parser to hide operand_segment_sizes. For example, +// A tuple-like pattern match syntax could work: +// xla_lhlo.case %index, (%input0, %input1, %input2), (%output0, %output1) { +// ... +// }, { +// ... 
+// } : (type_input0, type_input1, type_input2, type_output0, type_output1) -> () +def LHLO_CaseOp: LHLO_Op<"case", [ + AttrSizedOperandSegments, + SingleBlockImplicitTerminator<"TerminatorOp"> + ]>, BASE_HLO_CaseOp { + + let arguments = (ins + Arg:$index, + Arg, "", [MemRead]>:$branch_operands, + Arg, "", [MemWrite]>:$out + ); + + let regions = (region VariadicRegion>:$branches); +} + +// TODO(timshen): Add a custom syntax for this. +def LHLO_WhileOp: LHLO_Op<"while", [SameVariadicOperandSize]>, + BASE_HLO_WhileOp { + let arguments = (ins + Arg, "", [MemRead]>:$val, + Arg, "", [MemWrite]>:$output + ); + + let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body); +} + +//===----------------------------------------------------------------------===// +// XLA tuple op definitions. +//===----------------------------------------------------------------------===// + +def LHLO_CompareOp: LHLO_Op<"compare", []>, BASE_HLO_CompareOp { + let arguments = (ins + Arg:$lhs, + Arg:$rhs, + Arg:$out, + OptionalAttr:$broadcast_dimensions, + HLO_ComparisonDirectionAttr:$comparison_direction + ); +} + +//===----------------------------------------------------------------------===// +// XLA Slice definitions. 
+//===----------------------------------------------------------------------===// + +def LHLO_SliceOp: LHLO_Op< + "slice", + [AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> { + let arguments = (ins + Arg:$operand, + Arg:$output, + I64ElementsAttr:$start_indices, + I64ElementsAttr:$limit_indices, + I64ElementsAttr:$strides + ); +} + +def HLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []> { + let arguments = (ins + Arg:$operand, + Arg:$update, + Arg:$output, + Arg, "", [MemRead]>:$start_indices + ); +} + +//===----------------------------------------------------------------------===// +// StaticMemRefCastOp +//===----------------------------------------------------------------------===// + +def HLO_StaticMemRefCastOp: Op]> { + let summary = [{ + "modifies the offset, sizes and strides of a statically shaped memref. + }]; + let description = [{ + Allows to modify the offset, sizes and strides of a statically shaped memref. + + Example: + ```mlir + %buf_transformed = + xla_lhlo.static_memref_cast %buf + : memref<1x5xf32> -> memref<5xf32, offset: 2, strides: [1]> + + // The result of the op is a rank-1 memref with `[5]` shape, stride 1 and + // offset 2. 
+ ``` + }]; + + let arguments = (ins Arg:$operand); + let results = (outs Res:$result); + + let builders = [OpBuilder< + "OpBuilder &builder, OperationState &result, MemRefType resultType, " # + "Value operand", [{ + result.addOperands(operand); + result.types.push_back(resultType); + }]>]; + + let extraClassDeclaration = [{ + MemRefType getType() { return getResult().getType().cast(); } + }]; + + let verifier = [{ return Verify(*this); }]; + let assemblyFormat = [{ + $operand attr-dict `:` type($operand) `->` type($result) + }]; +} + +//===----------------------------------------------------------------------===// +// DynamicMemRefCastOp +//===----------------------------------------------------------------------===// + +def HLO_DynamicMemRefCastOp: Op]> { + let summary = "dynamic memref cast operation"; + let description = [{ + Change sizes and strides of a memref using the values computed in runtime. + + Example: + ```mlir + %buf_transformed = + xla_lhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%step_X, %step_Y] + : memref -> memref + // The result of the op is a type-erased memref with `[%size_X, %size_Y]` + // shape and `[%step_X, %step_Y]` strides. The offset will be inherited + // from the input. 
+ ``` + }]; + + let arguments = (ins + Arg:$operand, + Variadic:$sizes, + Variadic:$strides + ); + let results = (outs Res:$result); + + let builders = [OpBuilder< + "OpBuilder &builder, OperationState &result, MemRefType resultType, " # + "Value operand, ValueRange sizes, ValueRange strides", [{ + result.addOperands(operand); + result.addOperands(sizes); + result.addOperands(strides); + result.types.push_back(resultType); + }]>]; + + let extraClassDeclaration = [{ + MemRefType getType() { return getResult().getType().cast(); } + }]; + + let verifier = [{ return Verify(*this); }]; + let assemblyFormat = [{ + $operand `(` $sizes `)` `[` $strides `]` attr-dict `:` type($operand) `->` + type($result) + }]; +} + +//===----------------------------------------------------------------------===// +// XLA Other op definitions. +//===----------------------------------------------------------------------===// + +def LHLO_BatchNormGradOp : LHLO_Op<"batch_norm_grad", []>, + BASE_HLO_BatchNormGradOp { + + let arguments = (ins + Arg:$operand, + Arg:$scale, + Arg:$mean, + Arg:$variance, + Arg:$grad_output, + Arg:$grad_operand, // gradient of $operand. + Arg:$grad_scale, + Arg:$grad_offset, + F32Attr:$epsilon, + I64Attr:$feature_index + ); + +} + +def LHLO_BatchNormInferenceOp : LHLO_Op<"batch_norm_inference", []>, + BASE_HLO_BatchNormInferenceOp { + + let arguments = (ins + Arg:$operand, + Arg:$scale, + Arg:$offset, + Arg:$mean, + Arg:$variance, + Arg:$output, + F32Attr:$epsilon, + I64Attr:$feature_index + ); +} + +def LHLO_BatchNormTrainingOp : LHLO_Op<"batch_norm_training", []>, + BASE_HLO_BatchNormTrainingOp { + + let arguments = (ins + Arg:$operand, + Arg:$scale, + Arg:$offset, + Arg:$output, + Arg:$batch_mean, + Arg:$batch_var, + F32Attr:$epsilon, + I64Attr:$feature_index + ); +} + +// TODO(timshen): add a custom verifier. 
+def LHLO_BitcastOp: LHLO_Op<"bitcast", []> { + let arguments = (ins Arg:$input, + Arg:$output); +} + +def LHLO_BroadcastOp : LHLO_Op<"broadcast", + []>, BASE_HLO_BroadcastOp { + let arguments = (ins + Arg:$operand, + Arg:$output, + I64ElementsAttr:$broadcast_sizes + ); +} + +def LHLO_BroadcastInDimOp : LHLO_Op<"broadcast_in_dim", + []>, BASE_HLO_BroadcastInDimOp { + let arguments = (ins + Arg:$operand, + Arg:$output, + BroadcastDimAttr:$broadcast_dimensions + ); +} + +def LHLO_ClampOp : LHLO_Op<"clamp", []>, BASE_HLO_ClampOp { + let arguments = (ins + Arg:$min, + Arg:$operand, + Arg:$max, + Arg:$output + ); +} + +def LHLO_ConcatenateOp : LHLO_Op<"concatenate", []>, BASE_HLO_ConcatenateOp { + let arguments = (ins + Arg, "", [MemRead]>:$val, + Arg:$output, + I64Attr:$dimension + ); +} + +// TODO(bondhugula): Make this struct dialect independent so that it can be +// shared between the HLO and LHLO dialects. +def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", LHLO_Dialect, [ + StructFieldAttr<"input_batch_dimension",I64Attr>, + StructFieldAttr<"input_feature_dimension", I64Attr>, + StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>, + StructFieldAttr<"kernel_input_feature_dimension", I64Attr>, + StructFieldAttr<"kernel_output_feature_dimension", I64Attr>, + StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>, + StructFieldAttr<"output_batch_dimension", I64Attr>, + StructFieldAttr<"output_feature_dimension", I64Attr>, + StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > { + + let description = "Structure of dimension information for conv op"; +} + +def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp { + let arguments = (ins + Arg:$lhs, + Arg:$rhs, + Arg:$output, + // Default value: one for each of the spatial dimension. + OptionalAttr:$window_strides, + // Default value: zero for each of the spatial dimension. + OptionalAttr:$padding, + // Default value: one for each of the spatial dimension. 
+ OptionalAttr:$lhs_dilation, + // Default value: one for each of the spatial dimension. + OptionalAttr:$rhs_dilation, + ConvDimensionNumbers:$dimension_numbers, + I64Attr:$feature_group_count, + I64Attr:$batch_group_count, + HLO_PrecisionConfigAttr:$precision_config + ); +} + +def LHLO_CopyOp: LHLO_Op<"copy", []>, BASE_HLO_CopyOp { + let arguments = (ins + Arg:$operand, + Arg:$output + ); +} + +def LHLO_DotOp: LHLO_Op<"dot", []>, BASE_HLO_DotOp { + let arguments = (ins + Arg:$lhs, + Arg:$rhs, + HLO_PrecisionConfigAttr:$precision_config, + Arg:$output + ); +} + +def LHLO_GatherOp: LHLO_Op<"gather", []>, BASE_HLO_GatherOp { + let arguments = (ins + Arg:$operand, + Arg:$start_indices, + I64Attr:$index_vector_dim, + I64ElementsAttr:$offset_dims, + I64ElementsAttr:$slice_sizes, + I64ElementsAttr:$collapsed_slice_dims, + I64ElementsAttr:$start_index_map, + Arg:$output + ); +} + +def LHLO_ReshapeOp: LHLO_Op<"reshape", []>, BASE_HLO_ReshapeOp { + let arguments = (ins + Arg:$operand, + Arg:$output + ); +} + +def LHLO_ScatterOp: LHLO_Op<"scatter", []>, BASE_HLO_ScatterOp { + let arguments = (ins + Arg:$operand, + Arg:$scatter_indices, + Arg:$updates, + Arg:$output, + ScatterDimensionNumbers:$scatter_dimension_numbers, + DefaultValuedAttr:$indices_are_sorted, + DefaultValuedAttr:$unique_indices + ); + + let regions = (region SizedRegion<1>:$update_computation); +} + +def LHLO_SelectOp: LHLO_Op<"select", []>, BASE_HLO_SelectOp { + let arguments = (ins + Arg:$pred, + Arg:$on_true, + Arg:$on_false, + Arg:$output + ); +} + +def LHLO_SelectAndScatterOp: LHLO_Op<"select_and_scatter", []>, + BASE_HLO_SelectAndScatterOp { + let arguments = (ins + Arg:$operand, + Arg:$source, + Arg:$init_value, + Arg:$out, + OptionalAttr:$window_dimensions, + OptionalAttr:$window_strides, + OptionalAttr:$padding + ); + + let regions = (region SizedRegion<1>:$select, SizedRegion<1>:$scatter); +} + +def LHLO_ReverseOp: LHLO_Op<"reverse", []>, BASE_HLO_ReverseOp { + let arguments = (ins + Arg:$operand, 
+ I64ElementsAttr:$dimensions, + Arg:$output + ); +} + +def LHLO_PadOp: LHLO_Op<"pad", []>, BASE_HLO_PadOp { + let arguments = (ins + Arg:$operand, + Arg:$padding_value, + I64ElementsAttr:$edge_padding_low, + I64ElementsAttr:$edge_padding_high, + I64ElementsAttr:$interior_padding, + Arg:$output + ); +} + +def LHLO_TransposeOp: LHLO_Op<"transpose", []>, BASE_HLO_TransposeOp { + let arguments = (ins + Arg:$operand, + I64ElementsAttr:$permutation, + Arg:$output + ); +} + +def LHLO_ReducePrecisionOp: LHLO_Op<"reduce_precision", [SameTypeOperands]>, + BASE_HLO_ReducePrecisionOp { + let arguments = (ins + Arg:$operand, + Arg:$output, + I32Attr:$exponent_bits, + I32Attr:$mantissa_bits + ); +} + +def LHLO_AllReduceOp : LHLO_Op<"all_reduce", [SameTypeOperands]>, + BASE_HLO_AllReduceOp { + let arguments = (ins + Arg:$operand, + Arg:$output, + I64ElementsAttr:$replica_groups, + DefaultValuedAttr:$constrain_layout, + OptionalAttr>:$channel_id, + DefaultValuedAttr:$use_global_device_ids + ); + let regions = (region SizedRegion<1>:$computation); +} + +def LHLO_CollectivePermuteOp: LHLO_Op<"collective_permute", [SameTypeOperands]>, + BASE_HLO_CollectivePermuteOp { + + let arguments = (ins + Arg:$operand, + Arg:$output, + I64ElementsAttr:$source_target_pairs, + OptionalAttr>:$channel_id + ); +} + +def LHLO_FftOp: LHLO_Op<"fft", []>, BASE_HLO_FftOp { + let arguments = (ins + Arg:$operand, + Arg:$output, + HLO_FftTypeAttr:$fft_type, + I64ElementsAttr:$fft_length + ); +} + +def LHLO_CholeskyOp: LHLO_Op<"cholesky", [SameOperandsElementType]>, BASE_HLO_CholeskyOp { + let arguments = (ins + Arg:$a, + Arg:$output, + DefaultValuedAttr:$lower + ); +} + +def LHLO_Infeed: LHLO_Op<"infeed", []>, BASE_HLO_InfeedOp { + let arguments = (ins + Arg:$output, + DefaultValuedAttr:$config + ); +} + +def LHLO_Outfeed: LHLO_Op<"outfeed", []> { + let arguments = (ins + Arg:$operand, + DefaultValuedAttr:$config + ); +} + +def LHLO_ReplicaIdOp : LHLO_Op<"replica_id", []>, BASE_HLO_ReplicaIdOp { + let 
arguments = (ins Arg, "", [MemWrite]>); +} + +def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType]>, + BASE_HLO_TriangularSolveOp { + let arguments = (ins + Arg:$a, + Arg:$b, + Arg:$output, + BoolAttr:$left_side, + BoolAttr:$lower, + BoolAttr:$unit_diagonal, + HLO_TransposeAttr:$transpose_a + ); +} + +// TODO(timshen): add a custom verifier. +def LHLO_MapOp: LHLO_Op<"map", [SameOperandsShape]>, BASE_HLO_MapOp { + let arguments = (ins + Arg, "", [MemRead]>:$operands, + Arg:$output, + I64ElementsAttr:$dimensions + ); + let regions = (region SizedRegion<1>:$computation); +} + +def LHLO_RngGetAndUpdateStateOp: LHLO_Op<"rng_get_and_update_state", []> { + let arguments = (ins + Arg, "", [MemRead, MemWrite]>:$state, + I64Attr:$delta + ); +} + +// TODO(timshen): add a custom verifier. +def LHLO_SortOp: LHLO_Op<"sort", [SameVariadicOperandSize, SameOperandsShape]>, BASE_HLO_SortOp { + let arguments = (ins + Arg, "", [MemRead]>:$operands, + Arg, "", [MemWrite]>:$output, + DefaultValuedAttr:$dimension, + DefaultValuedAttr:$is_stable + ); + + let regions = (region SizedRegion<1>:$comparator); +} + +//===----------------------------------------------------------------------===// +// Late operations +//===----------------------------------------------------------------------===// + +def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]> { + let summary = "Fusion operator"; + let description = [{ + Models the fusion instruction generated by the XLA compiler's fusion pass. + + Fusion instructions are generated by the fusion pass of the XLA compiler. + They serve as a hint to the backend that it is beneficial to emit the + contained instructions into a single loop nest or kernel. The XLA fusion + pass is designed such that it only generates fusion nodes that can be + handled by the XLA compilers backends. + The XLA runtime expects this hint to be followed, as it expects a single + kernel per HLO instruction. 
This restriction might be lifted in the future. + }]; + let regions = (region SizedRegion<1>:$region); + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<"OpBuilder &builder, OperationState &result, " + "ArrayRef attributes"> + ]; +} + +def TerminatorOp : + LHLO_Op<"terminator", [Terminator]> { + let summary = "LHLO termination operation"; + let description = [{ + Terminator operation for the LHLO dialect. + }]; + let builders = [OpBuilder< + "OpBuilder &b, OperationState &result, ValueRange operands", + [{ build(b, result, llvm::None, operands, llvm::None); }] + >]; +} + +#endif // LHLO_OPS diff --git a/include/mlir-hlo/utils/broadcast_utils.h b/include/mlir-hlo/utils/broadcast_utils.h new file mode 100644 index 0000000..957d5f9 --- /dev/null +++ b/include/mlir-hlo/utils/broadcast_utils.h @@ -0,0 +1,49 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_BROADCAST_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_IR_BROADCAST_UTILS_H_ + +// Utilities relating to implementing HLO broadcasting. +// Note: This file should not depend on any non-MLIR TensorFlow libraries. 
+ +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LLVM.h" + +namespace mlir { +namespace xla { + +// Checks whether the given operand types and broadcast_dims attr represent a +// legal combination for "numpy" style broadcasting (where 1-dims are prepended +// to the smaller ranked operand until it is of the same rank as the larger). +// See: https://docs.scipy.org/doc/numpy/reference/ufuncs.html +bool IsLegalNumpyRankedBroadcast(Value lhs, Value rhs, + DenseIntElementsAttr broadcast_dims); + +// Emits shape dialect ops to compute the result shape for a broadcasting +// binary elementwise op which broadcasts according to "numpy" semantics +// (see above), returning an extents tensor of the resulting shape. +Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs, + Value rhs, + OpBuilder& builder); + +} // namespace xla +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_BROADCAST_UTILS_H_ diff --git a/include/mlir-hlo/utils/convert_op_folder.h b/include/mlir-hlo/utils/convert_op_folder.h new file mode 100644 index 0000000..dcda285 --- /dev/null +++ b/include/mlir-hlo/utils/convert_op_folder.h @@ -0,0 +1,33 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_ + +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" + +namespace mlir { +namespace xla { + +// Converts the given elements attr to the specified elements type. +// Requires type of the elements and new_type to be either integer or float +// type. +mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements, + mlir::Type new_type); +} // namespace xla +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_ diff --git a/include/mlir-hlo/utils/hlo_utils.h b/include/mlir-hlo/utils/hlo_utils.h new file mode 100644 index 0000000..cfb7184 --- /dev/null +++ b/include/mlir-hlo/utils/hlo_utils.h @@ -0,0 +1,74 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_ + +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h" + +namespace mlir { +namespace xla { + +// Computes the broadcast dimensions attr for an elementwise binary operator +// between two ranked tensors. +// If `allow_empty` is true, then null can be returned to mean that the +// broadcast is an "identity". +mlir::DenseIntElementsAttr getBroadcastDimensionsAttr(mlir::Builder* b, + mlir::Value x, + mlir::Value y, + bool allow_empty = true); + +// Get a constant splat for the given value of type. Requires value to be of +// type static shaped RankedTensorType. 
+template +static ElementsAttr getSplat(Builder* b, RankedTensorType ty, T constant) { + Type element_ty = getElementTypeOrSelf(ty); + + if (element_ty.isSignlessInteger()) + return DenseElementsAttr::get(ty, b->getIntegerAttr(element_ty, constant)); + + if (element_ty.isa()) + return DenseElementsAttr::get(ty, b->getFloatAttr(element_ty, constant)); + + if (auto complex_ty = element_ty.dyn_cast()) { + auto complex_element_ty = complex_ty.getElementType(); + if (complex_element_ty.isF32()) + return DenseElementsAttr::get(ty, + static_cast>(constant)); + if (complex_element_ty.isF64()) + return DenseElementsAttr::get( + ty, static_cast>(constant)); + } + llvm_unreachable("unhandled element type"); +} + +template +static ElementsAttr getSplat(Builder* b, Value val, T constant) { + return getSplat(b, val.getType().cast(), constant); +} + +// Returns DenseElementsAttr of rank zero with the given element type and the +// value. +// Requires `ty` to be either FloatType of IntegerType. +DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value); + +} // namespace xla +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_ diff --git a/lib/Dialect/mhlo/IR/chlo_ops.cc b/lib/Dialect/mhlo/IR/chlo_ops.cc new file mode 100644 index 0000000..fd2e441 --- /dev/null +++ b/lib/Dialect/mhlo/IR/chlo_ops.cc @@ -0,0 +1,278 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" + +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Diagnostics.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h" + +namespace mlir { +namespace xla_chlo { + +template +static LogicalResult Verify(T op) { + return success(); +} + +//===----------------------------------------------------------------------===// +// BinaryOps +//===----------------------------------------------------------------------===// + +namespace { +// Gets the resulting type from a broadcast between two types. +static Type GetBroadcastType(Type x, Type y, Type element_type, + DenseIntElementsAttr broadcast_dimensions_attr) { + auto x_ranked = x.dyn_cast(); + auto y_ranked = y.dyn_cast(); + if (!x_ranked || !y_ranked) { + return UnrankedTensorType::get(element_type); + } + + auto shape_x = x_ranked.getShape(); + auto shape_y = y_ranked.getShape(); + + if (shape_x.size() == shape_y.size()) { + llvm::SmallVector out_shape(shape_x.size()); + for (int i = 0, e = shape_x.size(); i < e; i++) { + auto x_val = shape_x[i]; + auto y_val = shape_y[i]; + if (x_val == -1 || y_val == -1) { + out_shape[i] = -1; + } else { + out_shape[i] = std::max(x_val, y_val); + } + } + return RankedTensorType::get(out_shape, element_type); + } + + auto shape_large = shape_x.size() > shape_y.size() ? shape_x : shape_y; + auto shape_small = shape_x.size() <= shape_y.size() ? 
shape_x : shape_y; + + llvm::SmallVector broadcast_dimensions; + if (broadcast_dimensions_attr) { + // Explicit broadcast dimensions. + for (const APInt& int_value : broadcast_dimensions_attr.getIntValues()) { + broadcast_dimensions.push_back(int_value.getSExtValue()); + } + if (broadcast_dimensions.size() != shape_small.size()) { + // Signal illegal broadcast_dimensions as unranked. + return UnrankedTensorType::get(element_type); + } + } else { + // If no broadcast dimensions, assume "numpy" broadcasting. + broadcast_dimensions = llvm::to_vector<4>(llvm::seq( + shape_large.size() - shape_small.size(), shape_large.size())); + } + + llvm::SmallVector out_shape(shape_large.begin(), + shape_large.end()); + + // Update according to the broadcast dimensions. + for (auto index_pair : llvm::enumerate(broadcast_dimensions)) { + auto old_value = out_shape[index_pair.value()]; + auto new_value = shape_small[index_pair.index()]; + if (old_value != -1 && (new_value == -1 || new_value > old_value)) { + out_shape[index_pair.value()] = new_value; + } + } + + return RankedTensorType::get(out_shape, element_type); +} + +LogicalResult InferBroadcastBinaryOpReturnTypeComponents( + MLIRContext* context, Optional location, ValueRange operands, + DictionaryAttr attributes, Type element_type, + SmallVectorImpl& inferedReturnShapes) { + // Find broadcast_dimensions. 
+ DenseIntElementsAttr broadcast_dimensions = + attributes.get("broadcast_dimensions") + .dyn_cast_or_null(); + + ShapedType lhs_type = operands[0].getType().dyn_cast(); + ShapedType rhs_type = operands[1].getType().dyn_cast(); + if (!lhs_type || !rhs_type || + lhs_type.getElementType() != rhs_type.getElementType()) { + return emitOptionalError(location, "mismatched operand types"); + } + if (!element_type) element_type = lhs_type.getElementType(); + Type result_type = + GetBroadcastType(lhs_type, rhs_type, element_type, broadcast_dimensions); + + if (auto ranked_result_type = result_type.dyn_cast()) { + inferedReturnShapes.emplace_back(ranked_result_type.getShape(), + element_type); + return success(); + } + + // TODO(laurenzo): This should be constructing with `element_type` but that + // constructor variant needs to be added upstream. + inferedReturnShapes.emplace_back(/* element_type */); + return success(); +} + +LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes( + OpBuilder& builder, Operation* op, + SmallVectorImpl& reifiedReturnShapes) { + auto loc = op->getLoc(); + auto lhs = op->getOperand(0); + auto rhs = op->getOperand(1); + + // Check for "numpy"-style rank broadcast. + auto broadcast_dimensions = op->getAttr("broadcast_dimensions") + .dyn_cast_or_null(); + if (broadcast_dimensions && + !xla::IsLegalNumpyRankedBroadcast(lhs, rhs, broadcast_dimensions)) { + // Note: It is unclear whether the general specification of explicit + // broadcast_dimensions on binary ops is a feature we want to carry + // forward. While it can technically be implemented for ranked-dynamic, + // it is incompatible with unranked inputs. If this warning is emitted + // in real programs, it is an indication that the feature should be + // implemented versus just falling back on the more standard definition + // of numpy-like prefix-padding. 
+ return op->emitWarning() + << "unsupported non prefix-padded dynamic rank " + << "broadcast_dimensions = " << broadcast_dimensions; + } + + Value computed_shape = xla::ComputeBinaryElementwiseBroadcastingResultExtents( + loc, lhs, rhs, builder); + if (!computed_shape) return failure(); + reifiedReturnShapes.push_back(computed_shape); + return success(); +} +} // namespace + +//===----------------------------------------------------------------------===// +// BroadcastComplexOp (has custom type inference due to different result type). +//===----------------------------------------------------------------------===// + +LogicalResult BroadcastComplexOp::inferReturnTypeComponents( + MLIRContext* context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl& inferedReturnShapes) { + ShapedType lhs_type = operands[0].getType().dyn_cast(); + if (!lhs_type) { + return emitOptionalError(location, "expected ShapedType"); + } + Type element_type = ComplexType::get(lhs_type.getElementType()); + return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands, + attributes, element_type, + inferedReturnShapes); +} +LogicalResult BroadcastComplexOp::reifyReturnTypeShapes( + OpBuilder& builder, SmallVectorImpl& reifiedReturnShapes) { + return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(), + reifiedReturnShapes); +} + +//===----------------------------------------------------------------------===// +// BroadcastCompareOp (has custom type inference due to different result type). 
+//===----------------------------------------------------------------------===// + +void BroadcastCompareOp::build(OpBuilder& builder, OperationState& result, + Value lhs, Value rhs, + DenseIntElementsAttr broadcast_dimensions, + StringAttr comparison_direction) { + auto new_type = GetBroadcastType(lhs.getType(), rhs.getType(), + builder.getI1Type(), broadcast_dimensions); + build(builder, result, new_type, lhs, rhs, broadcast_dimensions, + comparison_direction); +} + +LogicalResult BroadcastCompareOp::inferReturnTypeComponents( + MLIRContext* context, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl& inferedReturnShapes) { + Type element_type = IntegerType::get(1, context); + return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands, + attributes, element_type, + inferedReturnShapes); +} +LogicalResult BroadcastCompareOp::reifyReturnTypeShapes( + OpBuilder& builder, SmallVectorImpl& reifiedReturnShapes) { + return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(), + reifiedReturnShapes); +} + +//===----------------------------------------------------------------------===// +// Macros for method definitions that are common to most broadcasting ops. 
+//===----------------------------------------------------------------------===// + +#define BROADCAST_INFER_SHAPE_TYPE_OP_DEFS(Op) \ + LogicalResult Op::inferReturnTypeComponents( \ + MLIRContext* context, Optional location, ValueRange operands, \ + DictionaryAttr attributes, RegionRange regions, \ + SmallVectorImpl& inferedReturnShapes) { \ + return InferBroadcastBinaryOpReturnTypeComponents( \ + context, location, operands, attributes, /*element_type=*/nullptr, \ + inferedReturnShapes); \ + } \ + LogicalResult Op::reifyReturnTypeShapes( \ + OpBuilder& builder, SmallVectorImpl& reifiedReturnShapes) { \ + return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(), \ + reifiedReturnShapes); \ + } + +#define BROADCAST_BINARY_OP_DEFS(Op) \ + void Op::build(OpBuilder& builder, OperationState& result, Value left, \ + Value right, DenseIntElementsAttr broadcast_dimensions) { \ + auto type = GetBroadcastType( \ + left.getType().cast(), right.getType().cast(), \ + getElementTypeOrSelf(right.getType()), broadcast_dimensions); \ + return Op::build(builder, result, type, left, right, \ + broadcast_dimensions); \ + } \ + BROADCAST_INFER_SHAPE_TYPE_OP_DEFS(Op) + +BROADCAST_BINARY_OP_DEFS(BroadcastAddOp); +BROADCAST_BINARY_OP_DEFS(BroadcastAndOp); +BROADCAST_BINARY_OP_DEFS(BroadcastAtan2Op); +BROADCAST_BINARY_OP_DEFS(BroadcastDivOp); +BROADCAST_BINARY_OP_DEFS(BroadcastMaxOp); +BROADCAST_BINARY_OP_DEFS(BroadcastMinOp); +BROADCAST_BINARY_OP_DEFS(BroadcastMulOp); +BROADCAST_BINARY_OP_DEFS(BroadcastOrOp); +BROADCAST_BINARY_OP_DEFS(BroadcastPowOp); +BROADCAST_BINARY_OP_DEFS(BroadcastRemOp); +BROADCAST_BINARY_OP_DEFS(BroadcastShiftLeftOp); +BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightArithmeticOp); +BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightLogicalOp); +BROADCAST_BINARY_OP_DEFS(BroadcastSubOp); +BROADCAST_BINARY_OP_DEFS(BroadcastXorOp); + +#undef BROADCAST_INFER_SHAPE_TYPE_OP_DEFS +#undef BROADCAST_BINARY_OP_DEFS + +#define GET_OP_CLASSES +#include 
"third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc" + +//===----------------------------------------------------------------------===// +// xla_chlo Dialect Constructor +//===----------------------------------------------------------------------===// + +XlaHloClientDialect::XlaHloClientDialect(MLIRContext* context) + : Dialect(getDialectNamespace(), context) { + addOperations< +#define GET_OP_LIST +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc" + >(); +} + +} // namespace xla_chlo +} // namespace mlir diff --git a/lib/Dialect/mhlo/IR/dialect_registration.cc b/lib/Dialect/mhlo/IR/dialect_registration.cc new file mode 100644 index 0000000..855c026 --- /dev/null +++ b/lib/Dialect/mhlo/IR/dialect_registration.cc @@ -0,0 +1,24 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" + +// Static initialization for XLA dialect registration. 
+static mlir::DialectRegistration xla_hlo_ops; +static mlir::DialectRegistration + xla_chlo_ops; +static mlir::DialectRegistration xla_lhlo_ops; diff --git a/lib/Dialect/mhlo/IR/hlo_ops.cc b/lib/Dialect/mhlo/IR/hlo_ops.cc new file mode 100644 index 0000000..0130f4b --- /dev/null +++ b/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -0,0 +1,2110 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the operations used in the XLA dialect. 
+ +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" + +#include +#include +#include + +#include +#include + +#include "third_party/absl/container/flat_hash_set.h" +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/APFloat.h" +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/APInt.h" +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h" +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h" +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/SmallVector.h" +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h" +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/iterator_range.h" +#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/Casting.h" +#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/FormatVariadic.h" +#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/MathExtras.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Matchers.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpDefinition.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpImplementation.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OperationSupport.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" 
+#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Value.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LLVM.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LogicalResult.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/InliningUtils.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/convert_op_folder.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h" + +namespace mlir { +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.cc.inc" +namespace xla_hlo { + +Operation* XlaHloDialect::materializeConstant(OpBuilder& builder, + Attribute value, Type type, + Location loc) { + // HLO dialect constants only support ElementsAttr unlike standard dialect + // constant which supports all attributes. + if (value.isa()) + return builder.create(loc, type, + value.cast()); + return nullptr; +} + +template +static LogicalResult Verify(T op) { + return success(); +} + +namespace { + +//===----------------------------------------------------------------------===// +// Utilities for the canonicalize patterns +//===----------------------------------------------------------------------===// + +// Returns 1D 64-bit dense elements attribute with the given values. +DenseIntElementsAttr GetI64ElementsAttr(ArrayRef values, + Builder* builder) { + RankedTensorType ty = RankedTensorType::get( + {static_cast(values.size())}, builder->getIntegerType(64)); + return DenseIntElementsAttr::get(ty, values); +} + +// Given the start indices and slice sizes for a dynamic-slice that can be +// converted to a static slice, returns the limits for the static slice. 
+DenseIntElementsAttr BuildSliceLimits(DenseIntElementsAttr start_indices, + DenseIntElementsAttr slice_sizes, + Builder* builder) { + SmallVector slice_limits; + for (int64_t i = 0; i < slice_sizes.getNumElements(); ++i) { + int64_t start_index = start_indices.getValue(i).getInt(); + int64_t slice_size = slice_sizes.getValue(i).getInt(); + slice_limits.push_back(start_index + slice_size); + } + return GetI64ElementsAttr(slice_limits, builder); +} + +#include "third_party/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_canonicalize.inc" +} // namespace + +//===----------------------------------------------------------------------===// +// ConstOp +//===----------------------------------------------------------------------===// + +static void Print(ConstOp op, OpAsmPrinter* printer) { + // Print op name. + *printer << op.getOperationName(); + + // Elide attribute value while printing the attribute dictionary. + SmallVector elided_attrs; + elided_attrs.push_back("value"); + printer->printOptionalAttrDict(op.getAttrs(), elided_attrs); + + *printer << ' ' << op.value(); +} + +static ParseResult ParseConstOp(OpAsmParser* parser, OperationState* result) { + if (parser->parseOptionalAttrDict(result->attributes)) return failure(); + + // If colon is not present after attribute dictionary, it should be short form + // and attribute 'value' is outside the dictionary. + if (failed(parser->parseOptionalColon())) { + Attribute value; + if (parser->parseAttribute(value, "value", result->attributes)) + return failure(); + return parser->addTypeToList(value.getType(), result->types); + } + + // Long form should have type of the result after colon. + Type ty; + if (parser->parseType(ty)) return failure(); + result->types.push_back(ty); + return success(); +} + +OpFoldResult ConstOp::fold(ArrayRef operands) { + assert(operands.empty() && "constant has no operands"); + + // Return the held attribute value. 
+ return value(); +} + +// Builds a constant op with the specified attribute `value`. +void ConstOp::build(OpBuilder& builder, OperationState& result, + Attribute value) { + Type type; + if (auto elemAttr = value.dyn_cast()) { + type = elemAttr.getType(); + } else if (value.isa() || value.isa() || + value.isa()) { + // All XLA types must be tensor types. In the build() method, we want to + // provide more flexibility by allowing attributes of scalar types. But we + // need to wrap it up with ElementsAttr to construct valid XLA constants. + type = RankedTensorType::get(/*shape=*/{}, value.getType()); + value = DenseElementsAttr::get(type.cast(), value); + } + + // TODO: support other XLA specific types. + assert(type && "unsupported attribute type for building xla_hlo.constant"); + result.types.push_back(type); + result.addAttribute("value", value); +} + +//===----------------------------------------------------------------------===// +// DotGeneralOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(DotGeneralOp op) { + auto dot_dimension_numbers = op.dot_dimension_numbers(); + int64_t lhs_batching_dimensions_size = llvm::size( + dot_dimension_numbers.lhs_batching_dimensions().getValues()); + int64_t rhs_batching_dimensions_size = llvm::size( + dot_dimension_numbers.rhs_batching_dimensions().getValues()); + if (lhs_batching_dimensions_size != rhs_batching_dimensions_size) { + return op.emitError() + << "lhs and rhs should have the same number of batching dimensions"; + } + int64_t lhs_contracting_dimensions_size = llvm::size( + dot_dimension_numbers.lhs_contracting_dimensions().getValues()); + int64_t rhs_contracting_dimensions_size = llvm::size( + dot_dimension_numbers.rhs_contracting_dimensions().getValues()); + if (lhs_contracting_dimensions_size != rhs_contracting_dimensions_size) { + return op.emitError() << "lhs and rhs should have the same number of " + "contracting dimensions"; + } + return 
success(); +} + +//===----------------------------------------------------------------------===// +// IotaOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(IotaOp op) { + auto shape = op.getType().cast(); + if (!shape.hasRank()) return success(); + + if (shape.getRank() == 0) + return op.emitOpError() << "does not support scalars."; + + auto iota_dimension = op.iota_dimension().getSExtValue(); + if (iota_dimension >= shape.getRank() || iota_dimension < 0) + return op.emitOpError() << "iota dimension cannot go beyond the output " + "rank or be negative."; + return success(); +} + +//===----------------------------------------------------------------------===// +// DynamicIotaOp +//===----------------------------------------------------------------------===// + +namespace { + +struct DynamicIotaIsStatic : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DynamicIotaOp iota, + PatternRewriter& rewriter) const override { + auto result_ty = iota.getType().cast(); + if (!result_ty.hasStaticShape()) { + return failure(); + } + + rewriter.replaceOpWithNewOp(iota, result_ty, iota.iota_dimension()); + return success(); + } +}; + +} // namespace + +void DynamicIotaOp::getCanonicalizationPatterns( + OwningRewritePatternList& results, MLIRContext* context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// AbsOp +//===----------------------------------------------------------------------===// + +void AbsOp::build(OpBuilder& builder, OperationState& result, Value operand) { + auto shaped_type = operand.getType().cast(); + Type new_type; + if (!shaped_type.getElementType().isa()) { + new_type = operand.getType(); + } else if (shaped_type.hasRank()) { + new_type = RankedTensorType::get(shaped_type.getShape(), operand.getType()); + } else { + new_type = UnrankedTensorType::get(operand.getType()); 
+ } + + return AbsOp::build(builder, result, new_type, operand); +} + +//===----------------------------------------------------------------------===// +// CollectivePermuteOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(CollectivePermuteOp op) { + // Check that source target pair is Nx2 tensor. + auto type = op.source_target_pairs().getType().dyn_cast(); + if (type.getRank() != 2) + return op.emitError() << "expect source_target_pairs attribute to be of " + "rank 2, but got rank " + << type.getRank(); + if (type.getShape()[1] != 2) + return op.emitError() + << "expect source_target_pairs attribute of shape (N, 2), but got (" + << type.getShape() << ")"; + // Check source target pairs for duplicate sources or targets + absl::flat_hash_set sources; + absl::flat_hash_set targets; + for (auto i = op.source_target_pairs().begin(), + e = op.source_target_pairs().end(); + i != e; ++i) { + auto val = (*i).getSExtValue(); + if (i.getIndex() % 2 == 0) { + bool is_unique = sources.insert(val).second; + if (!is_unique) return op.emitError() << "duplicate sources not allowed."; + } else { + bool is_unique = targets.insert(val).second; + if (!is_unique) return op.emitError() << "duplicate targets not allowed."; + } + } + return success(); +} + +//===----------------------------------------------------------------------===// +// ConvertOp +//===----------------------------------------------------------------------===// + +void ConvertOp::build(OpBuilder& builder, OperationState& result, Value operand, + Type result_element_ty) { + Type result_ty; + Type operand_ty = operand.getType(); + if (auto ranked_ty = operand_ty.dyn_cast()) { + result_ty = RankedTensorType::get(ranked_ty.getShape(), result_element_ty); + } else { + result_ty = UnrankedTensorType::get(result_element_ty); + } + build(builder, result, result_ty, operand); +} + +OpFoldResult ConvertOp::fold(ArrayRef operands) { + if (getOperand().getType() == 
getResult().getType()) return getOperand(); + + // If the result has non-static shape, a convert op is necessary to go from + // static shape to non-static shape. + if (!getResult().getType().cast().hasStaticShape()) return {}; + + // If the operand is constant, we can do the conversion now. + if (auto elementsAttr = operands.front().dyn_cast_or_null()) { + return xla::ConvertElementsAttr(elementsAttr, + getElementTypeOrSelf(getResult())); + } + + return {}; +} + +//===----------------------------------------------------------------------===// +// DequantizeOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(DequantizeOp op) { + auto input_type = op.input().getType().dyn_cast(); + auto output_type = op.output().getType().dyn_cast(); + if (!input_type || !output_type) { + return op.emitError() << "ranked input and output."; + } + auto input_shape = input_type.getShape(); + auto output_shape = output_type.getShape().vec(); + if (op.transpose_output()) { + std::reverse(output_shape.begin(), output_shape.end()); + } + + // Check the input rank and output rank are same, and also the lower + // dimensions are same. + if (input_shape.size() != output_shape.size() || + !std::equal(input_shape.begin(), + std::next(input_shape.begin(), input_shape.size() - 1), + output_shape.begin())) { + return op.emitError() << "mismatched dimensions."; + } + + // Check that the last dimension of the output is 2x or 4x of that of the + // input depending on the unpacked input is 16 or 8 bits. + int input_last_dim = *input_shape.rbegin(); + int output_last_dim = *output_shape.rbegin(); + int scale_factor = op.is_16bits() ? 
2 : 4; + if (output_last_dim != scale_factor * input_last_dim) { + return op.emitError() << "last dimension of output should be " + << scale_factor << "x of the input."; + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// GetTupleElementOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(GetTupleElementOp op) { + auto indexVal = op.index().getZExtValue(); + auto operandType = op.getOperand().getType().cast(); + if (indexVal >= operandType.size()) { + return op.emitOpError( + llvm::formatv("index {0} is out of bounds of operand with size {1}", + indexVal, operandType.size())); + } + + auto expectedType = operandType.getType(indexVal); + if (op.getType() != expectedType) { + return op.emitOpError(llvm::formatv("has return type {0}, but expected {1}", + op.getType(), expectedType)); + } + return success(); +} + +OpFoldResult GetTupleElementOp::fold(ArrayRef operands) { + if (auto tupleOp = + dyn_cast_or_null(getOperand().getDefiningOp())) { + return tupleOp.getOperand(index().getLimitedValue()); + } + + return {}; +} + +//===----------------------------------------------------------------------===// +// TupleOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(TupleOp op) { + SmallVector operandTypes = {op.operand_type_begin(), + op.operand_type_end()}; + auto expectedType = TupleType::get(operandTypes, op.getContext()); + if (op.getType() != expectedType) { + return op.emitOpError(llvm::formatv("has return type {0}, but expected {1}", + op.getType(), expectedType)); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// AllToAllOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(AllToAllOp op) { + // If operand is ranked, size of split dimension should be a 
multiple of split + // count. + auto type = op.getOperand().getType().dyn_cast(); + if (!type) return success(); + auto split_dim_size = type.getDimSize(op.split_dimension().getSExtValue()); + auto split_count = op.split_count().getSExtValue(); + if (split_dim_size % split_count != 0) { + return op.emitError() << "split dimension has size " << split_dim_size + << ", expected to be a multiple of split_count " + << split_count; + } + return success(); +} + +//===----------------------------------------------------------------------===// +// BroadcastOp +//===----------------------------------------------------------------------===// + +// TODO(b/129012527) These should be expressed as type constraints. +static LogicalResult Verify(BroadcastOp op) { + auto sizes = op.broadcast_sizes(); + auto sizesType = sizes.getType(); + auto sizesRank = sizesType.getRank(); + if (sizesRank != 1) { + return op.emitOpError(llvm::formatv( + "broadcast_sizes has rank {0} instead of rank 1", sizesRank)); + } + + auto resultType = op.getResult().getType().cast(); + auto resultRank = resultType.getRank(); + auto operandType = op.operand().getType().cast(); + auto operandRank = operandType.getRank(); + auto sizesSize = sizesType.getNumElements(); + auto expectedRank = operandRank + sizesSize; + + if (resultRank != expectedRank) { + return op.emitOpError( + llvm::formatv("result rank ({0}) does not match operand rank " + "({1}) plus size of broadcast_sizes ({2})", + resultRank, operandRank, sizesSize)); + } + + llvm::SmallVector expectedShape(sizes.getValues()); + + auto operandShape = operandType.getShape(); + expectedShape.insert(expectedShape.end(), operandShape.begin(), + operandShape.end()); + + auto resultShape = resultType.getShape(); + if (resultShape != llvm::makeArrayRef(expectedShape)) { + return op.emitOpError(llvm::formatv( + "result has shape [{0}] instead of [{1}]", + llvm::make_range(resultShape.begin(), resultShape.end()), + llvm::make_range(expectedShape.begin(), 
expectedShape.end()))); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// BroadcastInDimOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(BroadcastInDimOp op) { + auto operandType = op.operand().getType().dyn_cast(); + auto operandRank = operandType.getRank(); + if (!op.broadcast_dimensions()) { + if (operandRank == 0) { + return success(); + } + return op.emitOpError( + llvm::formatv("broadcast_dimensions is absent, but required because " + "operand has non-zero rank ({0})", + operandRank)); + } + + auto dimensions = op.broadcast_dimensions(); + auto dimensionsType = op.broadcast_dimensions().getType(); + auto dimensionsRank = dimensionsType.getRank(); + if (dimensionsRank != 1) { + return op.emitOpError(llvm::formatv( + "broadcast_dimensions has rank {0} instead of rank 1", dimensionsRank)); + } + + auto dimensionsSize = dimensionsType.getNumElements(); + if (dimensionsSize != operandRank) { + return op.emitOpError(llvm::formatv( + "broadcast_dimensions size ({0}) does not match operand rank ({1})", + dimensionsSize, operandRank)); + } + + auto resultType = op.getResult().getType().cast(); + auto resultRank = resultType.getRank(); + if (resultRank < operandRank) { + return op.emitOpError( + llvm::formatv("result rank ({0}) is less than operand rank ({1})", + resultRank, operandRank)); + } + + for (int i = 0; i != dimensionsSize; ++i) { + auto dimIndex = dimensions.getValue(i); + if (dimIndex >= resultRank) { + return op.emitOpError( + llvm::formatv("broadcast_dimensions contains invalid value {0} for " + "result result with rank {1}", + dimIndex, resultRank)); + } + + auto dimSize = operandType.getDimSize(i); + auto resultDimSize = resultType.getDimSize(dimIndex); + if (dimSize != 1 && dimSize != resultDimSize) { + return op.emitOpError( + llvm::formatv("size of operand dimension {0} ({1}) is not equal to " + "1 or size of result 
dimension {2} ({3})", + i, dimSize, dimIndex, resultDimSize)); + } + } + + return success(); +} + +OpFoldResult BroadcastInDimOp::fold(ArrayRef) { + auto type = getType().cast(); + if (type != getOperand().getType()) { + return nullptr; + } + auto broadcast_values = broadcast_dimensions().getValues(); + if (!std::equal(broadcast_values.begin(), broadcast_values.end(), + llvm::seq(0, type.getRank()).begin())) { + return nullptr; + } + return getOperand(); +} + +//===----------------------------------------------------------------------===// +// DynamicBroadcastInDimOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(DynamicBroadcastInDimOp op) { + auto operandType = op.operand().getType().dyn_cast(); + auto resultType = op.getResult().getType().dyn_cast(); + + // If either the operand or result are unranked, there is very little + // to verify statically. + if (!operandType || !resultType) { + return success(); + } + + auto outputDimensionsType = + op.output_dimensions().getType().cast(); + auto outputDimensionsSize = outputDimensionsType.getDimSize(0); + auto operandRank = operandType.getRank(); + auto resultRank = resultType.getRank(); + + // Verify broadcast_dimensions. + auto bcastDimensions = op.broadcast_dimensions(); + auto bcastDimensionsType = op.broadcast_dimensions().getType(); + auto bcastDimensionsRank = bcastDimensionsType.getRank(); + // TODO(laurenzo): Update the BroadcastDimAttr to constrain its rank to 1. 
+ if (bcastDimensionsRank != 1) { + return op.emitOpError( + llvm::formatv("broadcast_dimensions has rank {0} instead of rank 1", + bcastDimensionsRank)); + } + + auto bcastDimensionsSize = bcastDimensionsType.getNumElements(); + if (bcastDimensionsSize != operandRank) { + return op.emitOpError(llvm::formatv( + "broadcast_dimensions size ({0}) does not match operand rank ({1})", + bcastDimensionsSize, operandRank)); + } + + if (resultRank < operandRank) { + return op.emitOpError( + llvm::formatv("result rank ({0}) is less than operand rank ({1})", + resultRank, operandRank)); + } + + for (int i = 0; i != bcastDimensionsSize; ++i) { + auto dimIndex = bcastDimensions.getValue(i); + if (dimIndex >= resultRank) { + return op.emitOpError( + llvm::formatv("broadcast_dimensions contains invalid value {0} for " + "result result with rank {1}", + dimIndex, resultRank)); + } + + auto dimSize = operandType.getDimSize(i); + auto resultDimSize = resultType.getDimSize(dimIndex); + if (dimSize != 1 && dimSize != resultDimSize) { + return op.emitOpError( + llvm::formatv("size of operand dimension {0} ({1}) is not equal to " + "1 or size of result dimension {2} ({3})", + i, dimSize, dimIndex, resultDimSize)); + } + } + + if (outputDimensionsSize != resultRank) { + return op.emitOpError( + llvm::formatv("result rank ({0}) is not equal to number of output " + "dimensions ({1})", + resultRank, outputDimensionsSize)); + } + + return success(); +} + +// If a DynamicBroadCastInDimOp is not actually dynamic, use an ordinary +// BroadcastInDimOp. 
+class DynamicBroadcastInDimOpNotActuallyDynamic + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(DynamicBroadcastInDimOp op, + PatternRewriter& rewriter) const override { + auto type = op.getType().dyn_cast(); + if (!type || !type.hasStaticShape()) { + return rewriter.notifyMatchFailure(op, "requires static shape"); + } + rewriter.replaceOpWithNewOp( + op, op.getType(), op.operand(), op.broadcast_dimensions()); + return success(); + } +}; + +void DynamicBroadcastInDimOp::getCanonicalizationPatterns( + OwningRewritePatternList& results, MLIRContext* context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// ClampOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(ClampOp op) { + auto operandType = op.operand().getType().cast(); + auto operandShape = operandType.getShape(); + auto minType = op.min().getType().cast(); + + auto minShape = minType.getShape(); + if (minShape != operandShape && minType.getRank() != 0) { + return op.emitOpError(llvm::formatv( + "min shape [{0}] is not scalar and does not match operand shape [{1}]", + llvm::make_range(minShape.begin(), minShape.end()), + llvm::make_range(operandShape.begin(), operandShape.end()))); + } + + auto maxType = op.max().getType().cast(); + auto maxShape = maxType.getShape(); + if (maxShape != operandShape && maxType.getRank() != 0) { + return op.emitOpError(llvm::formatv( + "max shape [{0}] is not scalar and does not match operand shape [{1}]", + llvm::make_range(maxShape.begin(), maxShape.end()), + llvm::make_range(operandShape.begin(), operandShape.end()))); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// ComplexOp +//===----------------------------------------------------------------------===// + +void ComplexOp::build(OpBuilder& builder, OperationState& 
state, Value lhs, + Value rhs) { + auto type = lhs.getType(); + auto element_ty = ComplexType::get(getElementTypeOrSelf(type)); + Type result_ty; + if (auto ranked_type = type.dyn_cast()) { + result_ty = RankedTensorType::get(ranked_type.getShape(), element_ty); + } else if (type.isa()) { + result_ty = UnrankedTensorType::get(element_ty); + } else { + result_ty = element_ty; + } + + build(builder, state, result_ty, lhs, rhs); +} + +OpFoldResult ComplexOp::fold(ArrayRef operands) { + auto real_op = + dyn_cast_or_null(getOperand(0).getDefiningOp()); + auto imag_op = + dyn_cast_or_null(getOperand(1).getDefiningOp()); + if (real_op && imag_op && real_op.getOperand() == imag_op.getOperand()) { + return real_op.getOperand(); + } + + return {}; +} + +namespace { +Type CreateRealType(Type type) { + auto element_ty = getElementTypeOrSelf(type); + if (auto complex_ty = element_ty.dyn_cast()) { + element_ty = complex_ty.getElementType(); + } + + if (auto ranked_type = type.dyn_cast()) { + return RankedTensorType::get(ranked_type.getShape(), element_ty); + } else if (type.dyn_cast()) { + return UnrankedTensorType::get(element_ty); + } + + return element_ty; +} +} // namespace + +void ImagOp::build(OpBuilder& builder, OperationState& state, Value val) { + build(builder, state, CreateRealType(val.getType()), val); +} + +OpFoldResult ImagOp::fold(ArrayRef operands) { + if (auto complex_op = + dyn_cast_or_null(getOperand().getDefiningOp())) { + return complex_op.getOperand(1); + } + + return {}; +} + +void RealOp::build(OpBuilder& builder, OperationState& state, Value val) { + build(builder, state, CreateRealType(val.getType()), val); +} + +OpFoldResult RealOp::fold(ArrayRef operands) { + if (auto complex_op = + dyn_cast_or_null(getOperand().getDefiningOp())) { + return complex_op.getOperand(0); + } + + return {}; +} + +//===----------------------------------------------------------------------===// +// ConcatenateOp 
+//===----------------------------------------------------------------------===// + +namespace { +class ConcatenateOperandRemoval : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ConcatenateOp op, + PatternRewriter& rewriter) const override { + auto axis = op.dimension().getLimitedValue(); + llvm::SmallVector new_operands; + for (auto operand : op.getOperands()) { + auto ty = operand.getType().cast(); + if (ty.getDimSize(axis) != 0) { + new_operands.push_back(operand); + } + } + + if (!new_operands.empty() && new_operands.size() < op.getNumOperands()) { + rewriter.replaceOpWithNewOp(op, op.getResult().getType(), + new_operands, op.dimension()); + return success(); + } + + return failure(); + } +}; +} // namespace + +LogicalResult ConcatenateOp::inferReturnTypes( + MLIRContext*, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl& inferredReturnTypes) { + if (operands.empty()) { + return failure(); + } + + auto dimension_attr = attributes.get("dimension").cast(); + auto dimension = dimension_attr.getInt(); + + auto first_type = (*operands.begin()).getType().cast(); + auto out_element = first_type.getElementType(); + + for (auto operand : operands.getTypes()) { + auto element_type = getElementTypeOrSelf(operand); + if (element_type != out_element) { + return failure(); + } + } + + // If an input is unranked the output shape is unranked. + if (!first_type.hasRank()) { + inferredReturnTypes.push_back(UnrankedTensorType::get(out_element)); + return success(); + } + + auto out_shape = llvm::to_vector<6>(first_type.getShape()); + out_shape[dimension] = 0; + + for (auto operand : operands.getTypes()) { + auto type = operand.cast(); + if (!type.hasRank()) { + inferredReturnTypes.push_back(UnrankedTensorType::get(out_element)); + return success(); + } + + // If the dimension is dynamic we know the output dimension is dynamic. 
+ auto dim = type.getShape()[dimension]; + if (dim == -1) { + out_shape[dimension] = -1; + break; + } + + out_shape[dimension] += dim; + } + + inferredReturnTypes.push_back(RankedTensorType::get(out_shape, out_element)); + + return success(); +} + +void ConcatenateOp::getCanonicalizationPatterns( + OwningRewritePatternList& results, MLIRContext* context) { + results.insert(context); +} + +template +static Attribute foldConcatenateHelper(ConcatenateOp* op, + ArrayRef operands) { + auto axis = op->dimension().getLimitedValue(); + auto type = op->getType().cast(); + + SmallVector values; + auto shape = type.getShape(); + + size_t top_size = 1; + for (int i = 0, e = axis; i < e; i++) { + top_size = top_size * shape[i]; + } + + for (size_t i = 0; i < top_size; i++) { + for (auto operand : operands) { + DenseElementsAttr attr = operand.cast(); + size_t bottom_size = attr.getNumElements() / top_size; + auto iter = attr.getValues().begin() + i * bottom_size; + values.append(iter, iter + bottom_size); + } + } + + return DenseElementsAttr::get(type, values); +} + +static Attribute foldConcatenate(ConcatenateOp* op, + ArrayRef operands) { + for (auto operand : operands) { + if (!operand) return {}; + } + + auto type = op->getResult().getType().cast(); + auto etype = type.getElementType(); + if (etype.isa()) { + return foldConcatenateHelper(op, operands); + } + + if (etype.isa()) { + return foldConcatenateHelper(op, operands); + } + + return {}; +} + +OpFoldResult ConcatenateOp::fold(ArrayRef operands) { + if (getNumOperands() == 1) return getOperand(0); + + ShapedType type = getResult().getType().cast(); + if (!type.hasStaticShape()) return {}; + + auto axis = dimension().getLimitedValue(); + if (auto attr = foldConcatenate(this, operands)) { + return attr; + } + + llvm::SmallVector new_operands; + for (auto operand : getOperands()) { + auto ty = operand.getType().cast(); + if (ty.getDimSize(axis) != 0) { + return {}; + } + } + + return DenseElementsAttr::get(type, 
ArrayRef()); +} + +static LogicalResult Verify(ConcatenateOp op) { + Type element_type = getElementTypeOrSelf(op.getOperand(0).getType()); + RankedTensorType first_ranked_type; + int num_operands = op.getNumOperands(); + for (int i = 0; i < num_operands; i++) { + auto second_type = op.getOperand(i).getType().dyn_cast(); + if (second_type.getElementType() != element_type) { + return op.emitOpError( + llvm::formatv("operands (0) and ({0}) do not match element type", i)); + } + + if (!second_type.hasRank()) { + continue; + } + + if (!first_ranked_type) { + first_ranked_type = second_type.cast(); + continue; + } + + if (first_ranked_type.getRank() != second_type.getRank()) { + return op.emitOpError( + llvm::formatv("operands (0) and ({0}) do not match rank", i)); + } + + auto first_shape = second_type.getShape(); + auto second_shape = second_type.getShape(); + for (int d = 0; d < first_ranked_type.getRank(); ++d) { + if (first_shape[d] != second_shape[d] && d != op.dimension()) { + return op.emitOpError(llvm::formatv( + "operands (0) and ({0}) non-concat dimensions do not match " + "({1}) != ({2})", + i, llvm::make_range(first_shape.begin(), first_shape.end()), + llvm::make_range(second_shape.begin(), second_shape.end()))); + } + } + } + return success(); +} + +//===----------------------------------------------------------------------===// +// DynamicReshapeOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(DynamicReshapeOp op) { + auto result_type = op.result().getType().dyn_cast(); + auto output_shape_type = + op.output_shape().getType().dyn_cast(); + if (result_type && output_shape_type && output_shape_type.hasStaticShape() && + output_shape_type.getDimSize(0) != result_type.getRank()) { + return op.emitError() << "output should have a rank equal to the number of " + "elements in output_shape"; + } + return success(); +} + +namespace { +class DynamicReshapeOpNotActuallyDynamic + : public 
OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(DynamicReshapeOp op, + PatternRewriter& rewriter) const override { + auto type = op.result().getType().dyn_cast(); + if (!type || !type.hasStaticShape()) { + return rewriter.notifyMatchFailure(op, "requires static shape tensor"); + } + rewriter.replaceOpWithNewOp(op, op.getType(), op.operand()); + return success(); + } +}; +} // namespace + +void DynamicReshapeOp::getCanonicalizationPatterns( + OwningRewritePatternList& results, MLIRContext* context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// DynamicSliceOp +//===----------------------------------------------------------------------===// + +namespace { +// Canonicalizes DynamicSlice ops that can be replaced instead with Slice ops. +// This canonicalization is applied the case when the `begin` input values are +// compile time constants and thus can be made into a tensor. +struct DynamicSliceToSlice : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DynamicSliceOp dynamic_slice, + PatternRewriter& rewriter) const override { + Value input = dynamic_slice.operand(); + auto input_tensor = input.getType().dyn_cast(); + if (!input_tensor) return failure(); + + SmallVector temp_start_indices; + for (Value start : dynamic_slice.start_indices()) { + APInt val; + if (!matchPattern(start, m_ConstantInt(&val))) { + return failure(); + } + temp_start_indices.push_back(*(val.getRawData())); + } + + // At this point we've determined that the start indices are all constants; + // pack them into a single tensor. 
+ auto loc = dynamic_slice.getLoc(); + int64_t input_rank = input_tensor.getRank(); + auto slice_start_indices = + GetI64ElementsAttr(temp_start_indices, &rewriter); + DenseIntElementsAttr slice_limits = BuildSliceLimits( + slice_start_indices, dynamic_slice.slice_sizes(), &rewriter); + DenseIntElementsAttr slice_strides = + GetI64ElementsAttr(SmallVector(input_rank, 1), &rewriter); + auto result = rewriter.create(loc, input, slice_start_indices, + slice_limits, slice_strides); + rewriter.replaceOp(dynamic_slice, {result}); + return success(); + } +}; + +} // namespace + +void DynamicSliceOp::getCanonicalizationPatterns( + OwningRewritePatternList& results, MLIRContext* context) { + results.insert(context); +} + +// Verifies that the number of slice sizes and the number of start indices match +static LogicalResult Verify(DynamicSliceOp op) { + int num_slice_sizes = op.slice_sizes().getNumElements(); + int num_start_indices = op.start_indices().size(); + if (num_start_indices != num_slice_sizes) { + return op.emitOpError() + << "has mismatched number of slice sizes (" << num_slice_sizes + << ") and number of start indices (" << num_start_indices << ")"; + } + return success(); +} + +//===----------------------------------------------------------------------===// +// InfeedOp +//===----------------------------------------------------------------------===// + +// Checks that the result type is of the form `tuple< any_type, token >`. 
+static LogicalResult Verify(InfeedOp op) { + auto result_ty = op.getResult().getType().cast(); + auto subtypes = result_ty.getTypes(); + if (subtypes.size() != 2) + return op.emitOpError() + << "result is expected to be a tuple of size 2, but got " + << subtypes.size(); + if (!subtypes[1].isa()) + return op.emitOpError() << "second element of result tuple is expected to " + "be of token type, but got " + << subtypes[1]; + return success(); +} + +//===----------------------------------------------------------------------===// +// MapOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(MapOp op) { + // Checks if the number of `operands` match the arity of the map `computation` + // region. + auto& computation_block = op.computation().front(); + auto computation_args = computation_block.getArguments(); + if (op.operands().size() != computation_args.size()) + return op.emitOpError() + << "expects number of operands to match the arity " + "of map computation, but got: " + << op.operands().size() << " and " << computation_args.size(); + + // The parameters of computation should all be scalars and match the element + // type of operands. 
+ auto operand_type = op.operands()[0].getType().cast(); + auto operand_elem_ty = operand_type.getElementType(); + + for (auto indexed_arg : llvm::enumerate(computation_args)) { + auto arg_type = indexed_arg.value().getType().dyn_cast(); + if (!arg_type || arg_type.getRank() != 0) + return op.emitOpError() + << "computation arguments must be 0-rank tensor, but got: arg #" + << indexed_arg.index() << " of type " + << indexed_arg.value().getType(); + if (arg_type.getElementType() != operand_elem_ty) { + return op.emitOpError() + << "element type of operands and computation arguments must " + "match, but got: " + << operand_elem_ty << " and " << arg_type.getElementType(); + } + } + + // Mapped computation must return single output + auto computation_outputs = computation_block.getTerminator()->getOperands(); + if (computation_outputs.size() != 1) + return op.emitOpError() + << "computation must return single output, but got: " + << computation_outputs.size(); + + // The output of computation must be scalar and have the same element type + // as op result. + auto computation_output_type = + computation_outputs[0].getType().dyn_cast(); + if (!computation_output_type || computation_output_type.getRank() != 0) + return op.emitOpError() + << "computation must return 0-rank tensor, but got: " + << computation_outputs[0].getType(); + + auto result_type = op.getType().cast(); + if (computation_output_type.getElementType() != result_type.getElementType()) + return op.emitOpError() << "element type of result and computation output " + "must match, but got: " + << result_type.getElementType() << " and " + << computation_output_type.getElementType(); + + // Checks that the requested map dimension numbers are monotonically + // increasing. 
+ auto values = op.dimensions().getValues(); + auto dimensions = std::vector{values.begin(), values.end()}; + for (int i = 0, e = dimensions.size(); i < e; ++i) { + if (dimensions[i] != i) + return op.emitOpError() << "requires monotonically increasing dimension " + "numbers, but got: " + << op.dimensions(); + } + + // Checks that number of dimensions of operands matches the size of + // `dimensions` since we currently only support mapping across all + // dimensions: i.e., scalar map functions. + if (operand_type.hasRank()) { + if (dimensions.size() != operand_type.getShape().size()) + return op.emitOpError() + << "applied to a subset of dimensions currently not supported: " + "operand dimensions = " + << operand_type.getShape().size() + << ", requested map dimensions size = " << dimensions.size(); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// RecvOp +//===----------------------------------------------------------------------===// + +// Checks that the result type is of the form `tuple` +static LogicalResult Verify(RecvOp op) { + auto result_ty = op.getResult().getType().cast(); + auto subtypes = result_ty.getTypes(); + if (subtypes.size() != 2) + return op.emitOpError() + << "result is expected to be a tuple of size 2, but got " + << subtypes.size(); + if (!subtypes[1].isa()) + return op.emitOpError() << "second element of result tuple is expected to " + "be of token type, but got " + << subtypes[1]; + return success(); +} + +//===----------------------------------------------------------------------===// +// CopyOp +//===----------------------------------------------------------------------===// + +OpFoldResult CopyOp::fold(ArrayRef operands) { return getOperand(); } + +//===----------------------------------------------------------------------===// +// ReverseOp +//===----------------------------------------------------------------------===// + +OpFoldResult ReverseOp::fold(ArrayRef operands) 
{ + auto input = operand(); + + // No dimensions to reverse. + if (dimensions().getNumElements() == 0) return input; + + llvm::SmallVector new_dims; + new_dims.reserve(dimensions().getNumElements()); + + auto shaped_type = input.getType().cast(); + for (auto dim : dimensions().getValues()) { + if (shaped_type.getDimSize(dim.getLimitedValue()) != 1) { + return nullptr; + } + } + + return input; +} + +//===----------------------------------------------------------------------===// +// ReduceOp +//===----------------------------------------------------------------------===// + +// Returns the result type after reducing operand of the given type across the +// specified dimensions. +static TensorType GetReduceResultType(Type operand_ty, + DenseIntElementsAttr dimensions, + Builder* builder) { + Type element_ty = getElementTypeOrSelf(operand_ty); + + auto ranked_ty = operand_ty.dyn_cast(); + if (!ranked_ty) return UnrankedTensorType::get(element_ty); + + int64_t rank = ranked_ty.getRank(); + llvm::SmallVector dims_mask(rank, false); + for (int64_t dim : dimensions.getValues()) dims_mask[dim] = true; + + SmallVector shape; + for (int64_t i = 0; i < rank; ++i) { + if (!dims_mask[i]) shape.push_back(ranked_ty.getDimSize(i)); + } + + return RankedTensorType::get(shape, element_ty); +} + +void ReduceOp::build(OpBuilder& builder, OperationState& state, + ValueRange operands, ValueRange init_values, + DenseIntElementsAttr dimensions) { + SmallVector result_ty; + result_ty.reserve(operands.size()); + + for (Value operand : operands) { + result_ty.push_back( + GetReduceResultType(operand.getType(), dimensions, &builder)); + } + build(builder, state, result_ty, operands, init_values, dimensions); +} + +LogicalResult ReduceOp::fold(ArrayRef operands, + SmallVectorImpl& results) { + // No dimensions to reduce. 
+ if (dimensions().getNumElements() == 0) { + for (Value input : this->operands()) { + results.push_back(input); + } + return success(); + } + return failure(); +} + +//===----------------------------------------------------------------------===// +// SelectOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(SelectOp op) { + // TODO(jpienaar): Update to allow broadcastable and unranked inputs. This + // corresponds to the client side HLO. + return success(); +} + +// Makes it such that a SelectOp that is a non-root operation in a DRR infers +// the return type based on operand type. +LogicalResult SelectOp::inferReturnTypes( + MLIRContext*, Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl& inferredReturnTypes) { + auto x_type = operands[1].getType(); + auto y_type = operands[2].getType(); + auto x_tensor = x_type.cast(); + auto y_tensor = y_type.cast(); + + // Check for type compatibility in the select op. This requires that the two + // non-predicate operands: + // (a) have the same element type + // (b) have compatible shapes (i.e. the same shape and/or at least one + // dynamic shape) + if (x_tensor.getElementType() != y_tensor.getElementType() || + failed(mlir::verifyCompatibleShape(x_type, y_type))) { + return emitOptionalError(location, "incompatible operand types: ", x_type, + " and ", y_type); + } + + // TODO(lucyfox): Support output shape inference when operands have compatible + // shapes. (The output shape should be the most general of the operand shapes + // at each dimension.) For now, handle the straightforward cases and fail + // otherwise. When this is fully implemented, this logic should move into + // reusable functionality in MLIR Core. 
+ Type output_type; + if (x_type == y_type || !x_tensor.hasRank()) { + output_type = x_type; + } else if (!y_tensor.hasRank()) { + output_type = y_type; + } else { + return emitOptionalError(location, + "currently unsupported operand types: ", x_type, + " and ", y_type); + } + inferredReturnTypes.assign({output_type}); + return success(); +} + +//===----------------------------------------------------------------------===// +// PadOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(PadOp op) { + auto input_type = op.operand().getType().cast(); + auto pad_type = op.padding_value().getType().cast(); + + if (pad_type.getRank() != 0) { + return op.emitOpError( + llvm::formatv("padding value type should be a rank-0 " + "tensor, is rank {0}", + pad_type.getRank())); + } + + const auto& padding_low = op.edge_padding_low(); + if (padding_low.getType().getNumElements() != input_type.getRank()) { + return op.emitOpError(llvm::formatv( + "edge_padding_low length ({0}) must match operand rank ({1})", + padding_low.getType().getNumElements(), input_type.getRank())); + } + + const auto& padding_high = op.edge_padding_high(); + if (padding_high.getType().getNumElements() != input_type.getRank()) { + return op.emitOpError(llvm::formatv( + "edge_padding_high length ({0}) must match operand rank ({1})", + padding_high.getType().getNumElements(), input_type.getRank())); + } + + const auto& padding_interior = op.interior_padding(); + if (padding_interior.getType().getNumElements() != input_type.getRank()) { + return op.emitOpError(llvm::formatv( + "interior_padding length ({0}) must match operand rank ({1})", + padding_interior.getType().getNumElements(), input_type.getRank())); + } + + auto input_shape = input_type.getShape(); + auto output_shape = + op.getResult().getType().cast().getShape(); + if (input_shape.size() != output_shape.size()) { + return op.emitOpError( + llvm::formatv("operand rank ({0}) and result 
rank({0}) should match", + input_shape.size(), output_shape.size())); + } + + for (int i = 0, e = input_shape.size(); i < e; i++) { + int padding_low_val = padding_low.getValue(i).getInt(); + int padding_high_val = padding_high.getValue(i).getInt(); + int padding_interior_val = + padding_interior.getValue(i).getInt(); + int expected_output = + input_shape[i] + padding_low_val + padding_high_val + + std::max(input_shape[i] - 1, 0LL) * padding_interior_val; + if (expected_output != output_shape[i]) { + return op.emitOpError(llvm::formatv( + "expected output shape's dimension #{0} to be {1} but found {2}", i, + expected_output, output_shape[i])); + } + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// ReshapeOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(ReshapeOp op) { + // If the operand type is dynamically shaped there is nothing to verify. + auto operand_ty = op.operand().getType().cast(); + if (!operand_ty || !operand_ty.hasStaticShape()) return success(); + + // If the operand type is statically shaped (not required) the number of + // elements must match that of the result type. 
+ auto result_ty = op.getType().cast(); + assert(result_ty && result_ty.hasStaticShape() && + "result type must be statically shaped"); + int64_t num_result_elements = result_ty.getNumElements(); + int64_t num_operand_elements = operand_ty.getNumElements(); + if (num_result_elements != num_operand_elements) + return op.emitOpError() + << "number of output elements (" << num_result_elements + << ") doesn't match expected number of elements (" + << num_operand_elements << ")"; + + return success(); +} + +OpFoldResult ReshapeOp::fold(ArrayRef operands) { + if (getOperand().getType() == getType()) { + return getOperand(); + } + + if (auto prev_op = + dyn_cast_or_null(getOperand().getDefiningOp())) { + setOperand(prev_op.getOperand()); + return getResult(); + } + + if (auto elements = operands.front().dyn_cast_or_null()) { + return elements.reshape(getResult().getType().cast()); + } + + return {}; +} + +//===----------------------------------------------------------------------===// +// Case Op +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(CaseOp op) { + auto num_branches = op.branches().size(); + if (op.branch_operands().size() != num_branches) + return op.emitOpError() << "expects number of branches " << num_branches + << " to be same as number of branch operands " + << op.branch_operands().size(); + + MutableArrayRef branches = op.branches(); + OperandRange branch_operands = op.branch_operands(); + for (unsigned i = 0; i < num_branches; ++i) { + mlir::Region& branch_region = branches[i]; + if (branch_region.empty()) + return op.emitOpError() << "cannot have empty regions"; + mlir::Block& entry_block = branch_region.front(); + if (entry_block.getNumArguments() != 1) + return op.emitOpError() + << "expects branch regions to have single argument, but found " + << entry_block.getNumArguments() << " for branch " << i; + auto operand = branch_operands[i]; + if (entry_block.getArgument(0).getType() != 
operand.getType()) + return op.emitOpError() + << "expects operand " << i + 1 << " to be of type " + << entry_block.getArgument(0).getType() << ", but found " + << operand.getType(); + WalkResult walker = branch_region.walk([&](ReturnOp return_op) { + if (return_op.getOperands().getTypes() != op.getResultTypes()) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + if (walker.wasInterrupted()) + return op.emitOpError() + << "branch " << i + << " returned values do not match op result types"; + } + return success(); +} + +//===----------------------------------------------------------------------===// +// BinaryOps +//===----------------------------------------------------------------------===// + +namespace { + +// Updates the element type of a (presumed) tensor type 'x', returning either +// a permuted UnrankedTensorType or RankedTensorType. +static Type UpdateResultElementType(Builder* builder, Type x, + Type element_type) { + auto x_ranked = x.dyn_cast(); + if (!x_ranked) { + return UnrankedTensorType::get(element_type); + } + + auto shape_x = x_ranked.getShape(); + return RankedTensorType::get(shape_x, element_type); +} +} // namespace + +template +static Attribute BinaryFolder(Op* op, ArrayRef attrs) { + if (!attrs[0] || !attrs[1]) return {}; + + DenseElementsAttr lhs = attrs[0].dyn_cast(); + DenseElementsAttr rhs = attrs[1].dyn_cast(); + if (!lhs || !rhs) return {}; + + ShapedType type = op->getType().template cast(); + if (!type.hasStaticShape()) { + return {}; + } + + Type etype = type.getElementType(); + + // Evaluate for integer values. 
+ if (!etype.isa()) { + return {}; + } + + SmallVector values; + values.reserve(lhs.getNumElements()); + for (const auto zip : + llvm::zip(lhs.getValues(), rhs.getValues())) { + values.push_back(Convert()(std::get<0>(zip), std::get<1>(zip))); + } + + return DenseElementsAttr::get(type, values); +} + +template +struct divide : std::divides {}; + +template <> +struct divide { + APInt operator()(const APInt& a, const APInt& b) const { return a.sdiv(b); } +}; + +template +struct max { + T operator()(const T& a, const T& b) const { return std::max(a, b); } +}; + +template <> +struct max { + APInt operator()(const APInt& a, const APInt& b) const { + return llvm::APIntOps::smax(a, b); + } +}; + +template +struct min { + T operator()(const T& a, const T& b) const { return std::min(a, b); } +}; + +template <> +struct min { + APInt operator()(const APInt& a, const APInt& b) const { + return llvm::APIntOps::smin(a, b); + } +}; + +#define BINARY_FOLDER(Op, Func) \ + OpFoldResult Op::fold(ArrayRef attrs) { \ + if (getElementTypeOrSelf(getType()).isa()) \ + return BinaryFolder>(this, attrs); \ + if (getElementTypeOrSelf(getType()).isa()) \ + return BinaryFolder>(this, attrs); \ + return {}; \ + } + +// Addition, subtraction and multiplication use the std:: versions of the ops. +// Due to the other ops behaving differently in signed vs unsigned integers, +// APInts need a special implementation. Currently, it replicates signed int +// op behavior. 
+BINARY_FOLDER(AddOp, std::plus); +BINARY_FOLDER(SubOp, std::minus); +BINARY_FOLDER(MulOp, std::multiplies); +BINARY_FOLDER(DivOp, divide); +BINARY_FOLDER(MaxOp, max); +BINARY_FOLDER(MinOp, min); + +#undef BINARY_FOLDER + +//===----------------------------------------------------------------------===// +// SliceOp +//===----------------------------------------------------------------------===// + +void SliceOp::build(OpBuilder& builder, OperationState& result, Value operand, + DenseIntElementsAttr start_indices, + DenseIntElementsAttr limit_indices, + DenseIntElementsAttr strides) { + return build(builder, result, + InferOutputTypes(&builder, operand, start_indices, limit_indices, + strides), + operand, start_indices, limit_indices, strides); +} + +template +static void SliceElements(I values, ArrayRef sizes, + ArrayRef starts, ArrayRef limits, + ArrayRef strides, + llvm::SmallVectorImpl* out_values) { + assert(starts.size() == limits.size()); + assert(starts.size() == strides.size()); + if (starts.empty()) return; + + int64_t start = starts.front(); + int64_t limit = limits.front(); + int64_t stride = strides.front(); + if (starts.size() == 1) { + for (int i = start; i < limit; i += stride) { + out_values->push_back(*(values + i)); + } + return; + } + + for (; start < limit; start += stride) { + auto begin = values + start * sizes.front(); + SliceElements(begin, sizes.drop_front(), starts.drop_front(), + limits.drop_front(), strides.drop_front(), out_values); + } +} + +template +static Attribute FoldSlice(SliceOp* op, I values) { + auto start = llvm::to_vector<6>(op->start_indices().getValues()); + auto limit = llvm::to_vector<6>(op->limit_indices().getValues()); + auto stride = llvm::to_vector<6>(op->strides().getValues()); + + auto result_type = op->operand().getType().cast(); + if (!result_type.hasStaticShape()) return {}; + + auto shape = result_type.getShape(); + int64_t count = result_type.getNumElements(); + // Compute the striding for each dimension. 
+ llvm::SmallVector sizes; + sizes.reserve(shape.size()); + for (auto v : shape) { + count = count / v; + sizes.push_back(count); + } + + llvm::SmallVector out_values; + out_values.reserve(result_type.getNumElements()); + SliceElements(values, sizes, start, limit, stride, &out_values); + + return DenseElementsAttr::get(op->getResult().getType().cast(), + out_values); +} + +OpFoldResult SliceOp::fold(ArrayRef operands) { + // Check if the SliceOp is a NoOp operation. + auto operand_shape = getOperand().getType().cast().getShape(); + auto result_type = getResult().getType().cast(); + auto result_shape = result_type.getShape(); + + if (result_type.hasStaticShape() && (operand_shape == result_shape)) { + return getOperand(); + } + + if (operands.empty() || !operands.front()) return {}; + + // Evaluate for statically valued inputs. + DenseElementsAttr elements = operands.front().dyn_cast(); + if (!elements) return {}; + + auto etype = elements.getType().getElementType(); + if (etype.isa()) { + return FoldSlice( + this, elements.getIntValues().begin()); + } else if (etype.isa()) { + return FoldSlice< + llvm::mapped_iterator>, + APFloat>(this, elements.getFloatValues().begin()); + } + + return {}; +} + +namespace { +// In cases where a concat is fed into a slice, it is possible the concat +// can be simplified or bypassed. This checks which inputs to the concat are +// used by the slice, either reducing the number of concatenated values or +// entirely removes the concat. 
+struct SimplifyConcatSlice : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SliceOp slice, + PatternRewriter& rewriter) const override { + auto result_ty = slice.getType().cast(); + if (!result_ty.hasStaticShape()) { + return failure(); + } + + auto slice_input = slice.operand(); + auto slice_input_ty = slice_input.getType().cast(); + auto concat = dyn_cast_or_null(slice_input.getDefiningOp()); + if (!concat) { + return failure(); + } + + auto dimension = concat.dimension().getSExtValue(); + + auto start = slice.start_indices().getIntValues(); + auto limit = slice.limit_indices().getIntValues(); + + auto slice_start = (*(start.begin() + dimension)).getSExtValue(); + auto slice_limit = (*(limit.begin() + dimension)).getSExtValue(); + + // We need to determine what inputs from the concat affect the slice, and + // how the bounds of the slice need to be updated for the minimally required + // inputs. + int64_t running_size = 0; + int64_t front_offset = slice_input_ty.getShape()[dimension]; + + auto subset_start = concat.operand_end(); + auto subset_end = concat.operand_end(); + for (auto it = concat.operand_begin(); it < concat.operand_end(); ++it) { + auto input = *it; + ShapedType input_ty = input.getType().cast(); + if (input_ty.isDynamicDim(dimension)) { + return failure(); + } + auto dim_size = input_ty.getShape()[dimension]; + + // If this position is in the slice its the start of the subset and we + // need to update the start and limit values. + if (running_size + dim_size > slice_start && + subset_start == concat.operand_end()) { + subset_start = it; + front_offset = running_size; + } + + // Determine the last required offset. + if (running_size < slice_limit) { + subset_end = it + 1; + } + + running_size += dim_size; + } + + auto subset_size = subset_end - subset_start; + // We need all inputs so no optimization. 
+ if (subset_size == concat.getNumOperands()) { + return failure(); + } + + if (subset_size > 1 && !concat.getResult().hasOneUse()) { + return failure(); + } + + auto concat_range = OperandRange(subset_start, subset_end); + auto new_concat = rewriter.create( + concat.getLoc(), concat_range, concat.dimension()); + + llvm::SmallVector new_start(start); + llvm::SmallVector new_limit(limit); + new_start[dimension] -= front_offset; + new_limit[dimension] -= front_offset; + + auto attr_type = slice.start_indices().getType().cast(); + auto create = rewriter.create( + slice.getLoc(), new_concat, + DenseIntElementsAttr::get(attr_type, new_start), + DenseIntElementsAttr::get(attr_type, new_limit), slice.strides()); + rewriter.replaceOp(slice, create.getResult()); + return success(); + } +}; +} // namespace + +void SliceOp::getCanonicalizationPatterns(OwningRewritePatternList& results, + MLIRContext* context) { + results.insert(context); +} + +// Returns output dimension size for slice result for the given arguments. +// Returns -1 if arguments are illegal. +static int64_t InferSliceDim(int64_t input_dim, int64_t start, int64_t end, + int64_t stride) { + if (input_dim == -1 || start < 0 || start > end || end > input_dim || + stride == 0) + return -1; + + return llvm::divideCeil(end - start, stride); +} + +Type SliceOp::InferOutputTypes(Builder* builder, Value operand, + DenseIntElementsAttr start_indices, + DenseIntElementsAttr limit_indices, + DenseIntElementsAttr strides) { + Type ty = operand.getType(); + RankedTensorType ranked_ty = ty.dyn_cast(); + if (!ranked_ty) return ty; + int64_t rank = ranked_ty.getRank(); + + // Illegal attributes. 
+ ShapedType attr_ty = start_indices.getType(); + if (attr_ty.getRank() != 1 || attr_ty.getNumElements() != rank || + !attr_ty.getElementType().isSignlessInteger(64) || + limit_indices.getType() != attr_ty || strides.getType() != attr_ty) + return ty; + + SmallVector start(start_indices.getValues()); + SmallVector limit(limit_indices.getValues()); + SmallVector stride_vals(strides.getValues()); + + SmallVector shape; + shape.reserve(rank); + for (int64_t i = 0, e = rank; i != e; i++) { + shape.push_back(InferSliceDim(ranked_ty.getDimSize(i), start[i], limit[i], + stride_vals[i])); + } + return RankedTensorType::get(shape, ranked_ty.getElementType()); +} + +//===----------------------------------------------------------------------===// +// SortOp +//===----------------------------------------------------------------------===// + +void SortOp::build(OpBuilder& builder, OperationState& state, + ValueRange operands, int64_t dimension, bool is_stable) { + state.addOperands(operands); + state.addAttribute("dimension", builder.getI64IntegerAttr(dimension)); + state.addAttribute("is_stable", builder.getBoolAttr(dimension)); + + SmallVector element_types; + element_types.reserve(operands.size()); + for (Value operand : operands) element_types.push_back(operand.getType()); + state.addTypes(builder.getTupleType(element_types)); + + state.addRegion(); +} + +static LogicalResult Verify(SortOp op) { + Operation::operand_range operands = op.operands(); + if (operands.empty()) return op.emitOpError("requires at least one input"); + + // TODO(antiagainst): verify partionally dynamic shapes + if (llvm::all_of(operands, [](Value operand) { + return operand.getType().cast().hasRank(); + })) { + ArrayRef input_shape = + (*operands.begin()).getType().cast().getShape(); + + if (llvm::any_of(llvm::drop_begin(operands, 1), [&](Value operand) { + return operand.getType().cast().getShape() != input_shape; + })) + return op.emitOpError("requires all inputs to have the same dimensions"); + + 
int64_t rank = input_shape.size(); + int64_t cmp_dim = op.dimension().getSExtValue(); + if (cmp_dim < -rank || cmp_dim >= rank) + return op.emitOpError("dimension attribute value must be in range [-") + << rank << ", " << rank << "), but found " << cmp_dim; + } + + Block& block = op.comparator().front(); + size_t num_operands = op.getOperation()->getNumOperands(); + if (block.getNumArguments() != 2 * num_operands) + return op.emitOpError("comparator block should have ") + << 2 * num_operands << " arguments"; + + for (auto indexed_operand : llvm::enumerate(operands)) { + int index = indexed_operand.index(); + Type element_type = + indexed_operand.value().getType().cast().getElementType(); + Type tensor_type = RankedTensorType::get({}, element_type); + for (int i : {2 * index, 2 * index + 1}) { + Type arg_type = block.getArgument(i).getType(); + if (arg_type != tensor_type) + return op.emitOpError("comparator block argument #") + << i << " should be of type " << tensor_type << " but got " + << arg_type; + } + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// TransposeOp +//===----------------------------------------------------------------------===// + +OpFoldResult TransposeOp::fold(ArrayRef operands) { + for (auto it : llvm::enumerate(permutation().getValues())) { + if (it.index() != it.value()) { + return {}; + } + } + return getOperand(); +} + +static LogicalResult Verify(TransposeOp op) { + // permutation is an attribute of the op so it has static shape. 
+ auto permutationType = op.permutation().getType(); + auto permutationRank = permutationType.getRank(); + if (permutationRank != 1) { + return op.emitOpError(llvm::formatv( + "permutation has rank {0} instead of rank 1", permutationRank)); + } + auto permutationSize = permutationType.getNumElements(); + + auto operandType = op.operand().getType().dyn_cast(); + if (operandType) { + auto operandRank = operandType.getRank(); + if (operandRank != permutationSize) { + return op.emitOpError(llvm::formatv( + "operand rank ({0}) does not match permutation size ({1})", + operandRank, permutationSize)); + } + } + + auto resultType = op.getResult().getType().dyn_cast(); + if (resultType) { + auto resultRank = resultType.getRank(); + if (resultRank != permutationSize) { + return op.emitOpError(llvm::formatv( + "result rank ({0}) does not match permutation size ({1})", resultRank, + permutationSize)); + } + } + + if (!resultType || !operandType) return success(); + + auto operandRank = operandType.getRank(); + SmallVector expectedShape(operandRank); + for (int i = 0; i != operandRank; ++i) { + auto permutedDim = op.permutation().getValue(i).getInt(); + expectedShape[i] = operandType.getDimSize(permutedDim); + } + + auto expectedType = + RankedTensorType::get(expectedShape, resultType.getElementType()); + if (failed(verifyCompatibleShape(resultType, expectedType))) { + return op.emitOpError(llvm::formatv( + "result type {0} is incompatible with the expected type {1}", + resultType, expectedType)); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// TriangularSolveOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(TriangularSolveOp op) { + auto a_type = op.a().getType().dyn_cast(); + + // Skip verifier if a is unranked tensor. 
+ if (!a_type) return success(); + + // Check that a should have rank >= 2 + auto a_rank = a_type.getRank(); + if (a_rank < 2) + return op.emitOpError() + << "operand 'a' must have rank >= 2, but got " << a_type; + + // The two minor dimensions of a must have same size. + if (a_type.getDimSize(a_rank - 2) != a_type.getDimSize(a_rank - 1)) + return op.emitOpError() << "two minor dimensions of operand 'a' must have " + "equal size, but got " + << a_type; + + auto b_type = op.b().getType().dyn_cast(); + // If b is unranked skip remaining checks. + if (!b_type) return success(); + + // Check that a and b have same rank. + auto b_rank = b_type.getRank(); + if (a_rank != b_rank) + return op.emitOpError() << "operands must have equal rank, but got " + << a_type << " and " << b_type; + + // The shared dimension of a and b should match. + if (a_type.getDimSize(a_rank - 1) != + b_type.getDimSize(b_rank - (op.left_side() ? 2 : 1))) + return op.emitOpError() << "shared dimension of operands 'a' and 'b' does " + "not match, but got " + << a_type << " and " << b_type; + + // The leading batch dimensions of a and b must be equal. + auto a_batch_dims = a_type.getShape().drop_back(2); + auto b_batch_dims = b_type.getShape().drop_back(2); + if (a_batch_dims != b_batch_dims) + return op.emitOpError() + << "leading batch dimensions of the operands must be same, but got " + << a_type << " and " << b_type; + + // Result and argument b must have same shape. 
+ auto result_type = op.getType().dyn_cast(); + if (!result_type) return success(); + if (result_type != b_type) + return op.emitOpError() + << "result and operand 'b' must have same shape, but got " + << result_type << " and " << b_type; + return success(); +} + +//===----------------------------------------------------------------------===// +// GetTupleElementOp +//===----------------------------------------------------------------------===// + +void GetTupleElementOp::build(OpBuilder& builder, OperationState& result, + Value tuple, int32_t index) { + if (auto tuple_type = tuple.getType().dyn_cast()) { + auto element_type = tuple_type.getType(index); + build(builder, result, element_type, tuple, + builder.getI32IntegerAttr(index)); + return; + } + + build(builder, result, tuple.getType(), tuple, + builder.getI32IntegerAttr(index)); +} + +//===----------------------------------------------------------------------===// +// TupleOp +//===----------------------------------------------------------------------===// + +void TupleOp::build(OpBuilder& builder, OperationState& result, + ValueRange values) { + SmallVector types; + types.reserve(values.size()); + for (auto val : values) { + types.push_back(val.getType()); + } + + build(builder, result, builder.getTupleType(types), values); +} + +//===----------------------------------------------------------------------===// +// UnaryEinsumOp +//===----------------------------------------------------------------------===// + +void UnaryEinsumOp::getCanonicalizationPatterns( + OwningRewritePatternList& results, MLIRContext* context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// CompareOp +//===----------------------------------------------------------------------===// + +void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs, + Value rhs, StringAttr comparison_direction) { + auto new_type = + UpdateResultElementType(&builder, 
lhs.getType(), builder.getI1Type()); + build(builder, result, new_type, lhs, rhs, comparison_direction); +} + +#define GET_OP_CLASSES +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc" + +//===----------------------------------------------------------------------===// +// xla_hlo Dialect Interfaces +//===----------------------------------------------------------------------===// + +namespace { +struct HLOInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + // We don't have any special restrictions on what can be inlined into + // destination regions (e.g. while/conditional bodies). Always allow it. + bool isLegalToInline(Region* dest, Region* src, + BlockAndValueMapping& valueMapping) const final { + return true; + } + // Operations in xla_hlo dialect are always legal to inline since they are + // pure. + bool isLegalToInline(Operation*, Region*, BlockAndValueMapping&) const final { + return true; + } +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// xla_hlo Dialect Constructor +//===----------------------------------------------------------------------===// + +XlaHloDialect::XlaHloDialect(MLIRContext* context) + : Dialect(getDialectNamespace(), context) { + addOperations< +#define GET_OP_LIST +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc" + >(); + addInterfaces(); + addTypes(); + // Support unknown operations because not all XLA operations are registered. 
+ // allowUnknownOperations(); +} + +Type XlaHloDialect::parseType(DialectAsmParser& parser) const { + StringRef data_type; + if (parser.parseKeyword(&data_type)) return Type(); + + if (data_type == "token") return TokenType::get(getContext()); + parser.emitError(parser.getNameLoc()) + << "unknown xla_hlo type: " << data_type; + return nullptr; +} + +void XlaHloDialect::printType(Type type, DialectAsmPrinter& os) const { + if (type.isa()) { + os << "token"; + return; + } + os << ""; +} + +//===----------------------------------------------------------------------===// +// Shape inference +//===----------------------------------------------------------------------===// + +LogicalResult deriveShapeFromFirstOperand( + OpBuilder* builder, Operation* op, + SmallVectorImpl* reifiedReturnShapes) { + Value operand = op->getOperand(0); + ShapedType operand_type = operand.getType().dyn_cast(); + if (!operand_type) { + op->emitOpError() << "first operand is not a shaped type"; + return failure(); + } + auto loc = op->getLoc(); + SmallVector shape_values; + shape_values.reserve(operand_type.getRank()); + auto shape_scalar_type = builder->getIntegerType(64); + for (auto element : llvm::enumerate(operand_type.getShape())) { + if (element.value() == ShapedType::kDynamicSize) { + Value dim = builder->create(loc, operand, element.index()); + shape_values.push_back( + builder->create(loc, dim, shape_scalar_type)); + } else { + shape_values.push_back(builder->create( + loc, builder->getI64IntegerAttr(element.value()))); + } + } + *reifiedReturnShapes = SmallVector{ + builder->create(loc, shape_values)}; + return success(); +} + +} // namespace xla_hlo +} // namespace mlir diff --git a/lib/Dialect/mhlo/IR/infer_fusibility_op_interface.cc b/lib/Dialect/mhlo/IR/infer_fusibility_op_interface.cc new file mode 100644 index 0000000..23712d1 --- /dev/null +++ b/lib/Dialect/mhlo/IR/infer_fusibility_op_interface.cc @@ -0,0 +1,22 @@ +/* Copyright 2020 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h" + +namespace mlir { + +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.cc.inc" + +} // namespace mlir diff --git a/lib/Dialect/mhlo/IR/lhlo_ops.cc b/lib/Dialect/mhlo/IR/lhlo_ops.cc new file mode 100644 index 0000000..3e374a4 --- /dev/null +++ b/lib/Dialect/mhlo/IR/lhlo_ops.cc @@ -0,0 +1,102 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the operations used in the XLA dialect. 
+ +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" + +#include +#include +#include + +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/APFloat.h" +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/APInt.h" +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h" +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h" +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/SmallVector.h" +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h" +#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/FormatVariadic.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpDefinition.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpImplementation.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OperationSupport.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Value.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc" + +namespace mlir { +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.cc.inc" +namespace xla_lhlo { + 
+XlaLhloDialect::XlaLhloDialect(MLIRContext *context) + : Dialect(getDialectNamespace(), context) { + addOperations< +#define GET_OP_LIST +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc" + >(); +} + +//===----------------------------------------------------------------------===// +// StaticMemRefCastOp +//===----------------------------------------------------------------------===// + +Value StaticMemRefCastOp::getViewSource() { return *getODSOperands(0).begin(); } + +static LogicalResult Verify(StaticMemRefCastOp op) { + if (!op.operand().getType().cast().hasStaticShape()) + return op.emitOpError("operand must have static shape"); + if (!op.getType().hasStaticShape()) + return op.emitOpError("result must have static shape"); + return success(); +} + +//===----------------------------------------------------------------------===// +// DynamicMemRefCastOp +//===----------------------------------------------------------------------===// + +Value DynamicMemRefCastOp::getViewSource() { + return *getODSOperands(0).begin(); +} + +static LogicalResult Verify(DynamicMemRefCastOp op) { + // Check if `sizes` and `strides` args are compatible with the result type. + if (op.sizes().size() != op.getType().getRank()) + return op.emitOpError( + "`sizes` args count must be equal to the rank of the output memref"); + return success(); +} + +#define GET_OP_CLASSES +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc" + +// TODO(cheshire): Support folding, reuse code from hlo_ops.cc. 
+ +void FusionOp::build(OpBuilder &builder, OperationState &result, + ArrayRef attributes) { + result.addAttributes(attributes); + Region *bodyRegion = result.addRegion(); + FusionOp::ensureTerminator(*bodyRegion, builder, result.location); +} + +} // namespace xla_lhlo +} // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/canonicalize.td b/lib/Dialect/mhlo/transforms/canonicalize.td new file mode 100644 index 0000000..a6435bc --- /dev/null +++ b/lib/Dialect/mhlo/transforms/canonicalize.td @@ -0,0 +1,30 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is the canonicalize pattern definition file. + +include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td" +include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" +include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td" + +def UnaryToBinaryEinsumEq : NativeCodeCall< + "$_builder.getStringAttr(\",\" + $0.getValue().str())">; + +// Convert UnaryEinsumOp to EinsumOp with two operands with redundant first +// operand. 
+def UnaryEinsumToEinsum : Pat< + (HLO_UnaryEinsumOp $operand, $equation), + (HLO_EinsumOp (HLO_ConstOp (GetScalarOfType<1> $operand)), + $operand, (UnaryToBinaryEinsumEq $equation))>; diff --git a/lib/utils/broadcast_utils.cc b/lib/utils/broadcast_utils.cc new file mode 100644 index 0000000..1c1499a --- /dev/null +++ b/lib/utils/broadcast_utils.cc @@ -0,0 +1,74 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h" + +#include + +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/Sequence.h" +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/SmallVector.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Diagnostics.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" + +namespace mlir { +namespace xla { + +bool IsLegalNumpyRankedBroadcast(Value lhs, Value rhs, + DenseIntElementsAttr broadcast_dims) { + RankedTensorType lhs_type = lhs.getType().dyn_cast(); + RankedTensorType rhs_type = rhs.getType().dyn_cast(); + if (!lhs_type || !rhs_type) return false; + if (lhs_type.getRank() == rhs_type.getRank()) return true; + + // Otherwise, verify that broadcast_dims strictly performs left-padding. 
+ auto smaller_rank = std::min(lhs_type.getRank(), rhs_type.getRank()); + auto larger_rank = std::max(lhs_type.getRank(), rhs_type.getRank()); + + if (smaller_rank != broadcast_dims.getNumElements()) { + return false; + } + auto expected_extents = + llvm::seq(larger_rank - smaller_rank, larger_rank); + return std::equal(expected_extents.begin(), expected_extents.end(), + broadcast_dims.getIntValues().begin()); +} + +Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs, + Value rhs, + OpBuilder& builder) { + auto lhs_type = lhs.getType().dyn_cast(); + auto rhs_type = rhs.getType().dyn_cast(); + if (!lhs_type || !rhs_type) { + emitError(loc) << "shape computation for broadcasting elementwise ops " + << "is only implemented for ranked tensors"; + return nullptr; + } + + int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank()); + auto shape_type = shape::ShapeType::get(builder.getContext()); + Value lhs_shape_v = + builder.createOrFold(loc, shape_type, lhs); + Value rhs_shape_v = + builder.createOrFold(loc, shape_type, rhs); + Value result_shape_v = builder.createOrFold( + loc, shape_type, lhs_shape_v, rhs_shape_v, nullptr /* error */); + return builder.createOrFold( + loc, RankedTensorType::get({result_rank}, builder.getIndexType()), + result_shape_v); +} + +} // namespace xla +} // namespace mlir diff --git a/lib/utils/convert_op_folder.cc b/lib/utils/convert_op_folder.cc new file mode 100644 index 0000000..cf6b56f --- /dev/null +++ b/lib/utils/convert_op_folder.cc @@ -0,0 +1,86 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines helpers useful when creating or manipulating lhlo/hlo. + +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/convert_op_folder.h" + +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h" + +namespace mlir { +namespace xla { + +mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements, + mlir::Type new_type) { + auto old_type = getElementTypeOrSelf(elements); + size_t bit_width = new_type.isBF16() ? 64 : new_type.getIntOrFloatBitWidth(); + + if (old_type.isa()) { + // mapValues always takes a function returning APInt, even when the output + // is actually float. 
+ using func_type = mlir::APInt(const llvm::APFloat&); + if (auto newFloatType = new_type.dyn_cast()) { + // Float -> Float + return elements.mapValues( + new_type, llvm::function_ref( + [&newFloatType](const llvm::APFloat& floatVal) { + llvm::APFloat newDouble( + mlir::FloatAttr::getValueAsDouble(floatVal)); + bool loses_info = false; + newDouble.convert(newFloatType.getFloatSemantics(), + llvm::APFloat::rmNearestTiesToEven, + &loses_info); + return newDouble.bitcastToAPInt(); + })); + } + // Float -> Int + return elements.mapValues( + new_type, llvm::function_ref( + [&bit_width](const llvm::APFloat& floatVal) { + return llvm::APInt( + bit_width, + mlir::FloatAttr::getValueAsDouble(floatVal)); + })); + } + + // old_type is Integer + // mapValues always takes a function returning APInt, even when the output + // is actually float. + using func_type = llvm::APInt(const llvm::APInt&); + if (auto newFloatType = new_type.dyn_cast()) { + // Int -> Float + return elements.mapValues( + new_type, llvm::function_ref([&newFloatType]( + const llvm::APInt& intVal) { + llvm::APFloat newDouble(static_cast(intVal.getSExtValue())); + bool loses_info = false; + newDouble.convert(newFloatType.getFloatSemantics(), + llvm::APFloat::rmNearestTiesToEven, &loses_info); + return newDouble.bitcastToAPInt(); + })); + } + // new_type is Integer + // Int -> Int + return elements.mapValues( + new_type, + llvm::function_ref([&bit_width](const llvm::APInt& intVal) { + return llvm::APInt(bit_width, intVal.getSExtValue()); + })); +} + +} // namespace xla +} // namespace mlir diff --git a/lib/utils/hlo_utils.cc b/lib/utils/hlo_utils.cc new file mode 100644 index 0000000..75c01a9 --- /dev/null +++ b/lib/utils/hlo_utils.cc @@ -0,0 +1,70 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h" + +#include + +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" + +namespace mlir { +namespace xla { + +DenseIntElementsAttr getBroadcastDimensionsAttr(Builder *b, Value x, Value y, + bool allow_empty) { + TensorType xType = x.getType().dyn_cast(); + TensorType yType = y.getType().dyn_cast(); + if (!xType || !yType) return {}; + if (allow_empty && xType == yType) return {}; + + // If the shapes have the same rank, then there is nothing to do. + auto xRank = xType.getRank(), yRank = yType.getRank(); + if (allow_empty && xRank == yRank) return {}; + + // Otherwise if the ranks of the inputs don't match, TensorFlow automatically + // reshapes the smaller by padding with dimensions of size 1 as a prefix. In + // other words to pad a 5-vector to a 3-dimensional tensor it is reshaped to + // have shape [1,1,5]. XLA's automatic broadcast code is able to broadcast + // from lower to higher rank, but doesn't assume you want to pad as a prefix + // of the dimensions, and instead needs to be told which dimensions of the + // higher rank tensor to match to the lower rank tensor. + auto maxRank = std::max(xRank, yRank); + auto minRank = std::min(xRank, yRank); + + // Match the lower rank tensor along the larger-numbered dimensions of the + // higher rank tensor. 
+ SmallVector broadcastDimensions(minRank); + std::iota(broadcastDimensions.begin(), broadcastDimensions.end(), + maxRank - minRank); + + RankedTensorType type = + RankedTensorType::get({minRank}, b->getIntegerType(64)); + return DenseIntElementsAttr::get(type, broadcastDimensions); +} + +DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) { + RankedTensorType scalar_ty = RankedTensorType::get({}, ty); + + if (auto float_ty = ty.dyn_cast()) { + APFloat value(float_ty.getFloatSemantics(), raw_value); + return DenseElementsAttr::get(scalar_ty, value); + } + auto int_ty = ty.cast(); + APInt value(int_ty.getWidth(), static_cast(raw_value), true); + return DenseElementsAttr::get(scalar_ty, value); +} + +} // namespace xla +} // namespace mlir