Move the HLO/LHLO dialects to a new directory: tensorflow/compiler/mlir/hlo
We're preparing to restructure the MLIR HLO ecosystem around five dialects:

- chlo: client dialect with explicit broadcasts and multiple composite operations
- mhlo: HLO with dynamic shapes, decoupled from XLA so it can evolve independently
- lmhlo: same as above, but after buffer assignment
- xla_hlo: maps 1:1 to the XLA HloInstruction class
- xla_lhlo: same as above, but after buffer assignment

The first three dialects are intended to live in the new tensorflow/compiler/mlir/hlo path; the latter two will be created in tensorflow/compiler/mlir/xla. This patch only moves the directory; follow-ups will handle the other transformations and tests.

The structure of the new directory follows https://llvm.discourse.group/t/rfc-canonical-file-paths-to-dialects/621, as we intend to make it a standalone buildable component (see https://github.com/google/mlir-npcomp for another example).

PiperOrigin-RevId: 319273229
This commit is contained in:
parent 5fe5c39ccc
commit fcf3df1541
@@ -0,0 +1,45 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_CHLO_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_CHLO_OPS_H_

#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/DialectImplementation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpDefinition.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.h"

namespace mlir {
namespace xla_chlo {

class XlaHloClientDialect : public Dialect {
 public:
  explicit XlaHloClientDialect(MLIRContext *context);
  static StringRef getDialectNamespace() { return "xla_chlo"; }
};

#define GET_OP_CLASSES
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h.inc"

}  // namespace xla_chlo
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_XLA_IR_CHLO_OPS_H_
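For orientation, here is a minimal C++ sketch of making the new client dialect available, assuming the static dialect-registration mechanism MLIR used at this revision; the include path and the registration object's placement are illustrative, not part of this commit.

#include "mlir/IR/Dialect.h"
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"  // assumed include path per the RFC layout

// Registering the dialect once makes "xla_chlo.*" ops parse and verify in any
// MLIRContext created afterwards (global registration era API).
static mlir::DialectRegistration<mlir::xla_chlo::XlaHloClientDialect>
    xla_chlo_registration;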
@@ -0,0 +1,370 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Defines "client" aligned HLO ops.
// These ops are not necessarily orthogonal or optimized for transformation;
// they exist for ease of expression in certain cases deemed important for
// client libraries (i.e. implicit broadcasting, helper ops, etc).
// This dialect exists to augment the xla_hlo dialect for ergonomic needs, not
// to duplicate or replace it.
//
// The typical use of this dialect is for client libraries to be able to emit
// less constrained ops and rely on the conversion framework to lower any
// xla_chlo ops to canonical xla_hlo ops.
//
// See: https://www.tensorflow.org/xla/operation_semantics

#ifndef CHLO_OPS
#define CHLO_OPS

include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td"
include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.td"
include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"

def HLOClient_Dialect : Dialect {
  let name = "xla_chlo";
  let cppNamespace = "xla_chlo";
  let summary = [{
    XLA Client HLO Ops
  }];

  let description = [{
    This dialect contains ops that align closely with the API surface area
    of the XlaBuilder C++ API, where such ops have semantics that go beyond
    what exists in the lower level dialects (such as `xla_hlo`). Essentially,
    whenever the client library uses syntactic sugar or composition
    of multiple ops for an API call, this dialect tries to model the API call
    and provide conversion patterns to fully materialize into lower level
    dialects.
  }];
}

class HLOClient_Op<string mnemonic, list<OpTrait> traits> :
    Op<HLOClient_Dialect, mnemonic, traits> {
  // TODO(b/129012527) Much of this custom verification should be expressed as
  // type constraints.
  let verifier = [{ return Verify(*this); }];
}

//===----------------------------------------------------------------------===//
// XLA binary elementwise op definitions.
// From the client perspective, each of these supports both explicit rank
// broadcasting (via the broadcast_dimensions attribute) and implicit degenerate
// shape broadcasting.
//
// These correspond to operations in the xla_hlo dialect without the
// "broadcast_" prefix, except that those ops require same-shaped operands and
// results.
//
// See:
//   https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
//   https://www.tensorflow.org/xla/broadcasting
//===----------------------------------------------------------------------===//

class HLOClient_BroadcastBinaryElementwiseOp<
  string mnemonic, list<OpTrait> traits> :
        HLOClient_Op<mnemonic,
            !listconcat(traits, [
              DeclareOpInterfaceMethods<InferShapedTypeOpInterface>])> {
  let arguments = (ins
    HLO_Tensor:$lhs,
    HLO_Tensor:$rhs,
    // Explicit rank-broadcast dimension mappings. Defaults to "numpy" prefix
    // padded rank-broadcast semantics if omitted.
    OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
  );

  let builders = [OpBuilder<
    "OpBuilder &builder, OperationState &result, Value left, Value right, "
    "DenseIntElementsAttr broadcast_dimensions"
  >];

  let results = (outs HLO_Tensor);

  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:`
    `(` type($lhs) `,` type($rhs) `)` `->` type(results)
  }];

  let extraClassDeclaration = [{
    // TODO(laurenzo): It isn't clear why reifyReturnTypeShapes does not have
    // its declaration generated by DeclareOpInterfaceMethods.
    LogicalResult reifyReturnTypeShapes(
         OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes);
  }];
}
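To make the broadcast semantics concrete, here is a hedged C++ sketch of creating one of these ops with the builder declared above; the function, value names, include path, and choice of broadcast dimension are illustrative.

#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"  // assumed include path

// Adds a rank-1 bias onto the trailing dimension of a rank-2 value by mapping
// the bias's only dimension to result dimension 1 (explicit broadcast).
mlir::Value BuildBiasAdd(mlir::OpBuilder &b, mlir::Location loc,
                         mlir::Value matrix, mlir::Value bias) {
  mlir::DenseIntElementsAttr dims = b.getI64TensorAttr({1});
  return b.create<mlir::xla_chlo::BroadcastAddOp>(loc, matrix, bias, dims);
}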
def HLOClient_BroadcastAddOp : HLOClient_BroadcastBinaryElementwiseOp<"broadcast_add",
    [Commutative, NoSideEffect, SameOperandsAndResultElementType]> {
  string summary = "Addition operator (with optional broadcasting)";

  string description = [{
    Returns `lhs + rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

def HLOClient_BroadcastAtan2Op : HLOClient_BroadcastBinaryElementwiseOp<
    "broadcast_atan2",
    [NoSideEffect, SameOperandsAndResultElementType]> {
  string summary = "Atan2 operator (with optional broadcasting)";

  string description = [{
    Returns `atan2(lhs, rhs)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

def HLOClient_BroadcastDivOp : HLOClient_BroadcastBinaryElementwiseOp<
    "broadcast_divide",
    [NoSideEffect, SameOperandsAndResultElementType]> {
  string summary = "Division operator (with optional broadcasting)";

  string description = [{
    Returns `lhs / rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

def HLOClient_BroadcastMaxOp : HLOClient_BroadcastBinaryElementwiseOp<
    "broadcast_maximum",
    [Commutative, NoSideEffect, SameOperandsAndResultElementType]> {
  string summary = "Maximum operator (with optional broadcasting)";

  string description = [{
    Returns `max(lhs, rhs)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

def HLOClient_BroadcastMinOp : HLOClient_BroadcastBinaryElementwiseOp<
    "broadcast_minimum",
    [Commutative, NoSideEffect, SameOperandsAndResultElementType]> {
  string summary = "Minimum operator (with optional broadcasting)";

  string description = [{
    Returns `min(lhs, rhs)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

def HLOClient_BroadcastMulOp : HLOClient_BroadcastBinaryElementwiseOp<
    "broadcast_multiply",
    [Commutative, NoSideEffect, SameOperandsAndResultElementType]> {
  string summary = "Multiplication operator (with optional broadcasting)";

  string description = [{
    Returns `lhs * rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

def HLOClient_BroadcastPowOp : HLOClient_BroadcastBinaryElementwiseOp<
    "broadcast_power",
    [NoSideEffect, SameOperandsAndResultElementType]> {
  string summary = "Power operator (with optional broadcasting)";

  string description = [{
    Returns `lhs ^ rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

def HLOClient_BroadcastRemOp : HLOClient_BroadcastBinaryElementwiseOp<
    "broadcast_remainder",
    [NoSideEffect, SameOperandsAndResultElementType]> {
  string summary = "Remainder operator (with optional broadcasting)";

  string description = [{
    Returns `lhs % rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

def HLOClient_BroadcastShiftLeftOp : HLOClient_BroadcastBinaryElementwiseOp<
    "broadcast_shift_left",
    [NoSideEffect, SameOperandsAndResultElementType]> {
  string summary = "Shift left operator (with optional broadcasting)";

  string description = [{
    Returns `lhs << rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

def HLOClient_BroadcastShiftRightArithmeticOp : HLOClient_BroadcastBinaryElementwiseOp<
    "broadcast_shift_right_arithmetic",
    [NoSideEffect, SameOperandsAndResultElementType]> {
  string summary = "Shift right arithmetic operator (with optional broadcasting)";

  string description = [{
    Returns `lhs >> rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

def HLOClient_BroadcastShiftRightLogicalOp : HLOClient_BroadcastBinaryElementwiseOp<
    "broadcast_shift_right_logical",
    [NoSideEffect, SameOperandsAndResultElementType]> {
  string summary = "Shift right logical operator (with optional broadcasting)";

  string description = [{
    Returns `lhs >> rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

def HLOClient_BroadcastSubOp : HLOClient_BroadcastBinaryElementwiseOp<
    "broadcast_subtract",
    [NoSideEffect, SameOperandsAndResultElementType]> {
  string summary = "Subtraction operator (with optional broadcasting)";

  string description = [{
    Returns `lhs - rhs` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

//===----------------------------------------------------------------------===//
// XLA binary logical elementwise op definitions.
// The same description as the arithmetic binary elementwise ops applies.
//===----------------------------------------------------------------------===//

class HLOClient_BroadcastBinaryLogicalElementwiseOp<string mnemonic> :
    HLOClient_BroadcastBinaryElementwiseOp<
      mnemonic, [Commutative, NoSideEffect]> {
  let arguments = (ins
    HLO_PredOrIntTensor:$lhs,
    HLO_PredOrIntTensor:$rhs,
    // Explicit rank-broadcast dimension mappings. Defaults to "numpy" prefix
    // padded rank-broadcast semantics if omitted.
    OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
  );
}

def HLOClient_BroadcastAndOp: HLOClient_BroadcastBinaryLogicalElementwiseOp<
    "broadcast_and"> {
  string summary = "Logical and operator (with optional broadcasting)";

  string description = [{
    Returns `logical_and(lhs, rhs)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

def HLOClient_BroadcastOrOp: HLOClient_BroadcastBinaryLogicalElementwiseOp<
    "broadcast_or"> {
  string summary = "Logical or operator (with optional broadcasting)";

  string description = [{
    Returns `logical_or(lhs, rhs)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

def HLOClient_BroadcastXorOp : HLOClient_BroadcastBinaryLogicalElementwiseOp<
    "broadcast_xor"> {
  string summary = "Logical xor operator (with optional broadcasting)";

  string description = [{
    Returns `logical_xor(lhs, rhs)` element-wise.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
  }];
}

//===----------------------------------------------------------------------===//
// Broadcasting complex op
//===----------------------------------------------------------------------===//

def HLOClient_BroadcastComplexOp : HLOClient_BroadcastBinaryElementwiseOp<
    "broadcast_complex", [NoSideEffect]> {
  string summary = "Complex operator (with optional broadcasting)";

  string description = [{
    Performs element-wise conversion of a pair of real and imaginary values to
    a complex value.
  }];

  let arguments = (ins
    HLO_FpTensor:$lhs,
    HLO_FpTensor:$rhs,
    // Explicit rank-broadcast dimension mappings. Defaults to "numpy" prefix
    // padded rank-broadcast semantics if omitted.
    OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
  );
  let results = (outs HLO_ComplexTensor);
}

//===----------------------------------------------------------------------===//
// Broadcasting compare op
//===----------------------------------------------------------------------===//

def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp<
    "broadcast_compare", [NoSideEffect]> {
  string summary = "Compare operator (with optional broadcasting)";

  string description = [{
    Compares `lhs` and `rhs` elementwise according to `comparison_direction`.

    See
    https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations.
  }];

  let arguments = (ins
    HLO_Tensor:$lhs,
    HLO_Tensor:$rhs,
    OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions,
    HLO_ComparisonDirectionAttr:$comparison_direction
  );
  let results = (outs HLO_PredTensor);

  let builders = [OpBuilder<
    "OpBuilder &builder, OperationState &result, Value lhs, Value rhs, "
    "DenseIntElementsAttr broadcast_dimensions, StringAttr comparison_direction"
  >];
}

#endif  // CHLO_OPS
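A matching hedged sketch for the compare builder declared above; the "EQ" direction string is assumed from XLA's comparison-direction vocabulary, and the helper name and include path are illustrative.

#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"  // assumed include path

// Element-wise equality with implicit ("numpy"-style) broadcasting; passing a
// null broadcast_dimensions attribute selects the prefix-padded default.
mlir::Value BuildBroadcastEq(mlir::OpBuilder &b, mlir::Location loc,
                             mlir::Value lhs, mlir::Value rhs) {
  return b.create<mlir::xla_chlo::BroadcastCompareOp>(
      loc, lhs, rhs, /*broadcast_dimensions=*/mlir::DenseIntElementsAttr(),
      b.getStringAttr("EQ"));
}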
@@ -0,0 +1,99 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file defines the operations used in the XLA dialect.

#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_OPS_H_

#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/DialectImplementation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpDefinition.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"

namespace mlir {
class OpBuilder;

#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.h.inc"

namespace xla_hlo {

class XlaHloDialect : public Dialect {
 public:
  explicit XlaHloDialect(MLIRContext *context);
  static StringRef getDialectNamespace() { return "xla_hlo"; }

  // Registered hook to materialize a constant operation from a given attribute
  // value with the desired resultant type.
  Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
                                 Location loc) override;

  // Parses a type registered to this dialect.
  Type parseType(DialectAsmParser &parser) const override;

  // Prints a type registered to this dialect.
  void printType(Type type, DialectAsmPrinter &os) const override;
};

namespace HLOTypes {
enum Kind {
  Token = Type::FIRST_XLA_HLO_TYPE,
};
}  // namespace HLOTypes

class TokenType : public Type::TypeBase<TokenType, Type, TypeStorage> {
 public:
  using Base::Base;

  static TokenType get(MLIRContext *context) {
    return Base::get(context, HLOTypes::Token);
  }

  // Support method to enable LLVM-style type casting.
  static bool kindof(unsigned kind) { return kind == HLOTypes::Token; }
};

// Shape derivation function that computes the shape of the result based on
// the first argument. For a 2-dimensional input tensor, this produces IR of
// the form
//
//  %0 = dim %arg0, 0 : memref<?x?xf32>
//  %1 = index_cast %0 : index to i64
//  %2 = dim %arg0, 1 : memref<?x?xf32>
//  %3 = index_cast %2 : index to i64
//  %4 = "xla_hlo.scalars_to_dimension_tensor"(%1, %3)
//    : (i64, i64) -> tensor<2xi64>
//
// and returns %4 as the shape value.
LogicalResult deriveShapeFromFirstOperand(
    OpBuilder *builder, Operation *op,
    SmallVectorImpl<Value> *reifiedReturnShapes);

#define GET_OP_CLASSES
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc"

}  // end namespace xla_hlo
}  // end namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_OPS_H_
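A small hedged sketch of the token type API declared above, assuming the dialect (and its type) has been registered with the context; the function wrapper is illustrative.

#include <cassert>
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"  // assumed include path

// Token values thread side-effect ordering (e.g. infeed/outfeed) through HLO
// programs; the type is a singleton keyed by HLOTypes::Token.
void MakeToken(mlir::MLIRContext *context) {
  mlir::Type token = mlir::xla_hlo::TokenType::get(context);
  assert(token.isa<mlir::xla_hlo::TokenType>());  // dispatches to kindof()
  (void)token;
}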
										
											
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -0,0 +1,43 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is the utils file for the HLO dialect.

#ifndef HLO_UTILS
#define HLO_UTILS

include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td"

def NullArrayAttr : NativeCodeCall<"ArrayAttr()">;

def CastIntElementsAttr : NativeCodeCall<"$0.cast<DenseIntElementsAttr>()">;

class ConstantSplat<string value> : NativeCodeCall<
    "xla::getSplat(&$_builder, $0, " # value # ")">;

def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">;

def BinBroadcastDimensions : NativeCodeCall<
    "xla::getBroadcastDimensionsAttr(&$_builder, $0, $1)">;

def BinBroadcastDimensionsNonEmpty : NativeCodeCall<
    "xla::getBroadcastDimensionsAttr(&$_builder, $0, $1, /*allow_empty=*/false)">;

// Here, the element type can be any integer or float type. Note that only
// 32-bit integers are supported for the value.
class GetScalarOfType<int value> : NativeCodeCall<
  "xla::GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">;

#endif // HLO_UTILS
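These NativeCodeCall definitions expand into C++ inside TableGen-generated rewrite patterns. A hedged sketch of what the broadcast-dimension helpers amount to in a handwritten pattern, with the helper signature assumed from the strings above and all names illustrative:

#include "mlir/IR/PatternMatch.h"

void AttachBroadcastDims(mlir::PatternRewriter &rewriter,
                         mlir::Value lhs, mlir::Value rhs) {
  // Numpy-style prefix-padded mapping between the two operand ranks.
  mlir::DenseIntElementsAttr dims =
      xla::getBroadcastDimensionsAttr(&rewriter, lhs, rhs);
  // The *NonEmpty variant forces an explicit attribute even when the operand
  // ranks already match.
  mlir::DenseIntElementsAttr nonEmpty =
      xla::getBroadcastDimensionsAttr(&rewriter, lhs, rhs,
                                      /*allow_empty=*/false);
  (void)dims;
  (void)nonEmpty;
}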
@@ -0,0 +1,28 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_INFER_FUSIBILITY_OP_INTERFACE_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_INFER_FUSIBILITY_OP_INTERFACE_H_

#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"

namespace mlir {

#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h.inc"

}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_XLA_IR_INFER_FUSIBILITY_OP_INTERFACE_H_
@@ -0,0 +1,161 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file contains InferFusibilityOpInterface, which is used to guide
// fusion decisions.

#ifndef MLIR_INFER_FUSIBILITY_OP_INTERFACE
#define MLIR_INFER_FUSIBILITY_OP_INTERFACE

include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td"

// OpInterface to query if an op is fusible and to query the shape equality
// constraint among the inputs and outputs of an op.
def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> {
  let description = [{
    Interface to query if an op is fusible and to query the shape equality
    constraint among the inputs and outputs of an op.
  }];

  let methods = [
    InterfaceMethod<
      /*desc=*/[{If true, this op can be fused with its operands.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"isFusibleWithOperand",
      /*args=*/(ins),
      /*methodBody=*/[{}],
      /*defaultImplementation=*/[{
        /// Returns whether this op can be fused with its operands.
        return true;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{If true, this op can be fused with its consumers.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"isFusibleWithConsumer",
      /*args=*/(ins),
      /*methodBody=*/[{}],
      /*defaultImplementation=*/[{
        /// Returns whether this op can be fused with its consumers.
        return true;
      }]
    >,
    InterfaceMethod<
      /*desc=*/"Return whether two inputs have the same shape (assuming no"
               " implicit broadcasting).",
      /*retTy=*/"bool",
      /*methodName=*/"inferInputsShapeEquality",
      /*args=*/(ins "int":$lhs, "int":$rhs),
      /*methodBody=*/[{}],
      /*defaultImplementation=*/[{
        /// Return whether two inputs have the same shape.
        Operation *op = this->getOperation();
        assert(lhs < op->getNumOperands() && lhs >= 0 &&
               rhs < op->getNumOperands() && rhs >= 0);
        if (lhs == rhs) return true;

        // If both lhs and rhs have static shapes, check them directly.
        Type lhs_ty = op->getOperand(lhs).getType();
        Type rhs_ty = op->getOperand(rhs).getType();
        auto lhs_shape_type = lhs_ty.dyn_cast_or_null<RankedTensorType>();
        auto rhs_shape_type = rhs_ty.dyn_cast_or_null<RankedTensorType>();
        if (!lhs_shape_type || !lhs_shape_type.hasStaticShape() ||
            !rhs_shape_type || !rhs_shape_type.hasStaticShape() ||
            lhs_shape_type.getRank() != rhs_shape_type.getRank()) {
          return false;
        }
        return lhs_shape_type.getShape() == rhs_shape_type.getShape();
      }]
    >,
    InterfaceMethod<
      /*desc=*/"Return whether two outputs have the same shape (assuming no"
               " implicit broadcasting).",
      /*retTy=*/"bool",
      /*methodName=*/"inferOutputsShapeEquality",
      /*args=*/(ins "int":$lhs, "int":$rhs),
      /*methodBody=*/[{}],
      /*defaultImplementation=*/[{
        /// Return whether two outputs have the same shape.
        Operation *op = this->getOperation();
        assert(lhs < op->getNumResults() && lhs >= 0 &&
               rhs < op->getNumResults() && rhs >= 0);
        if (lhs == rhs) return true;

        // If both lhs and rhs have static shapes, check them directly.
        Type lhs_ty = op->getResult(lhs).getType();
        Type rhs_ty = op->getResult(rhs).getType();
        auto lhs_shape_type = lhs_ty.dyn_cast_or_null<RankedTensorType>();
        auto rhs_shape_type = rhs_ty.dyn_cast_or_null<RankedTensorType>();
        if (!lhs_shape_type || !lhs_shape_type.hasStaticShape() ||
            !rhs_shape_type || !rhs_shape_type.hasStaticShape() ||
            lhs_shape_type.getRank() != rhs_shape_type.getRank()) {
          return false;
        }
        return lhs_shape_type.getShape() == rhs_shape_type.getShape();
      }]
    >,
    InterfaceMethod<
      /*desc=*/"Return whether the input and the output have the same"
               " shape (assuming no implicit broadcasting).",
      /*retTy=*/"bool",
      /*methodName=*/"inferInputOutputShapeEquality",
      /*args=*/(ins "int":$input, "int":$output),
      /*methodBody=*/[{}],
      /*defaultImplementation=*/[{
        /// Return whether the input and the output have the same shape.
        Operation *op = this->getOperation();
        assert(input < op->getNumOperands() && input >= 0 &&
               output < op->getNumResults() && output >= 0);

        // If both input and output have static shapes, check them directly.
        Type input_ty = op->getOperand(input).getType();
        Type output_ty = op->getResult(output).getType();
        auto input_shape_type = input_ty.dyn_cast_or_null<RankedTensorType>();
        auto output_shape_type = output_ty.dyn_cast_or_null<RankedTensorType>();
        if (!input_shape_type || !input_shape_type.hasStaticShape() ||
            !output_shape_type || !output_shape_type.hasStaticShape() ||
            input_shape_type.getRank() != output_shape_type.getRank()) {
          return false;
        }
        return input_shape_type.getShape() == output_shape_type.getShape();
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{Return the effective workload shape for the operation.

      Here the effective workload shape roughly represents the maximum
      parallelism that can be used during the codegen stage. It's used to
      check the shape-compatibility of the operation. During fusion, we only
      try to fuse shape-compatible ops for performance.
      For example, the effective workload shape of an elementwise op is its
      output shape, while the effective workload shape of a reduction op may
      be its operand shape.
      Return None if such an inference is not possible.
      }],
      /*retTy=*/"llvm::Optional<Value>",
      /*methodName=*/"inferEffectiveWorkloadShape",
      /*args=*/(ins),
      /*methodBody=*/[{}],
      /*defaultImplementation=*/[{
        /// Return the effective workload shape if possible, otherwise None.
        return {};
      }]
    >,
  ];
}

#endif  // MLIR_INFER_FUSIBILITY_OP_INTERFACE
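A hedged C++ sketch of how a fusion pass might consume this interface; the pass wiring and function name are illustrative.

#include "mlir/IR/Operation.h"

// Checks both sides of a producer/consumer edge via the interface defaults
// (or an op's override, if it provides one).
bool CanFusePair(mlir::Operation *producer, mlir::Operation *consumer) {
  auto p = llvm::dyn_cast<mlir::InferFusibilityOpInterface>(producer);
  auto c = llvm::dyn_cast<mlir::InferFusibilityOpInterface>(consumer);
  return p && c && p.isFusibleWithConsumer() && c.isFusibleWithOperand();
}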
@@ -0,0 +1,52 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file defines the operations used in the LHLO dialect.

#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_LHLO_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_LHLO_OPS_H_

#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpDefinition.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/ViewLikeInterface.h"

namespace mlir {
class OpBuilder;

#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.h.inc"

namespace xla_lhlo {

class XlaLhloDialect : public Dialect {
 public:
  explicit XlaLhloDialect(MLIRContext *context);
  static StringRef getDialectNamespace() { return "xla_lhlo"; }
};

#define GET_OP_CLASSES
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"

}  // namespace xla_lhlo
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_XLA_IR_LHLO_OPS_H_
@@ -0,0 +1,796 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is the operation definition file for LHLO.
//
// This file largely overlaps with hlo_ops.td at a logical level. It's tempting
// to merge these two files together, but we need to consider the following
// obstacles:
// * We need to have a common representation for arguments. That is to say,
// HLO_Array<X> translates to HLO_Tensor<X> in HLO dialect, and
// Arg<LHLO_Buffer<X>, "", [Mem(Read|Write)]> in LHLO. Array types within tuples
// also need to be transformed.
// * As of now, TableGen's dag functions are not sufficient to accomplish the
// transformation above.
// * Traits aren't identical, but need to be mapped. For example,
// SameOperandAndResultType in HLO corresponds to SameTypeOperands in LHLO.
// * Also, currently HLO describes the API on XLA's client side, not the
// service side. LHLO aims for the service side.

#ifndef LHLO_OPS
#define LHLO_OPS

include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td"
include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.td"
include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/ViewLikeInterface.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"

def LHLO_Dialect : Dialect {
  let name = "xla_lhlo";
  let cppNamespace = "xla_lhlo";
}

//===----------------------------------------------------------------------===//
// XLA type definitions.
//===----------------------------------------------------------------------===//

// Any integer buffer types
def LHLO_IntBuffer : MemRefOf<[HLO_Int]>;

// Any floating-point buffer types
def LHLO_FpBuffer : MemRefOf<[AnyFloat]>;

def LHLO_ComplexBuffer : MemRefOf<[AnyComplex]>;

def LHLO_FpOrComplexBuffer : MemRefOf<[AnyFloat, AnyComplex]>;

def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>;

// Any integer or floating-point buffer types
def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>;

def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>;

def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;

//===----------------------------------------------------------------------===//
// XLA nullary op definitions.
//===----------------------------------------------------------------------===//

class LHLO_Op<string mnemonic, list<OpTrait> traits> :
  Op<LHLO_Dialect, mnemonic,
    !listconcat([MemoryEffects<[MemRead, MemWrite]>], traits)>;

def LHLO_ConstOp : LHLO_Op<"constant", []>, BASE_HLO_ConstOp {
  let arguments = (ins
    ElementsAttr:$value,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output
  );
}

def LHLO_IotaOp : LHLO_Op<"iota", []>, BASE_HLO_IotaOp {
  let arguments = (ins I64Attr:$iota_dimension,
                   Arg<LHLO_Buffer, "", [MemWrite]>:$output);
}

//===----------------------------------------------------------------------===//
// XLA unary elementwise op definitions.
//===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions

class LHLO_UnaryElementwiseOp<string mnemonic,
                              Type BufferType = LHLO_Buffer,
                              list<OpTrait> traits = [SameTypeOperands]>
    : LHLO_Op<mnemonic, traits> {
  let arguments = (ins Arg<BufferType, "", [MemRead]>:$input,
                       Arg<BufferType, "", [MemWrite]>:$output);
}

def LHLO_AbsOp: LHLO_UnaryElementwiseOp<"abs">, BASE_HLO_AbsOp;

// TODO(timshen): add a custom verifier.
def LHLO_BitcastConvertOp:
    LHLO_UnaryElementwiseOp<"bitcast_convert", LHLO_Buffer, [SameOperandsShape]>, BASE_HLO_BitcastConvertOp;

def LHLO_CeilOp: LHLO_UnaryElementwiseOp<"ceil", LHLO_FpBuffer>, BASE_HLO_CeilOp;

def LHLO_ClzOp: LHLO_UnaryElementwiseOp<"count_leading_zeros", LHLO_IntBuffer>, BASE_HLO_ClzOp;

// TODO(timshen): add a custom verifier.
def LHLO_ConvertOp : LHLO_UnaryElementwiseOp<"convert", LHLO_Buffer, [SameOperandsShape]>, BASE_HLO_ConvertOp;

def LHLO_CosOp: LHLO_UnaryElementwiseOp<"cosine", LHLO_FpOrComplexBuffer>, BASE_HLO_CosOp;

def LHLO_ExpOp: LHLO_UnaryElementwiseOp<"exponential", LHLO_FpOrComplexBuffer>, BASE_HLO_ExpOp;

def LHLO_Expm1Op: LHLO_UnaryElementwiseOp<"exponential_minus_one", LHLO_FpOrComplexBuffer>, BASE_HLO_Expm1Op;

def LHLO_FloorOp: LHLO_UnaryElementwiseOp<"floor", LHLO_FpBuffer>, BASE_HLO_FloorOp;

def LHLO_ImagOp: LHLO_Op<"imag", [SameOperandsShape]>, BASE_HLO_ImagOp {
  let arguments = (ins Arg<LHLO_ComplexBuffer, "", [MemRead]>:$input,
                       Arg<LHLO_FpBuffer, "", [MemWrite]>:$output);
}

def LHLO_IsFiniteOp: LHLO_Op<"is_finite", [SameOperandsShape]>, BASE_HLO_IsFiniteOp {
  let arguments = (ins Arg<LHLO_FpBuffer, "", [MemRead]>:$input,
                       Arg<LHLO_PredBuffer, "", [MemWrite]>:$output);
}

def LHLO_LogOp: LHLO_UnaryElementwiseOp<"log", LHLO_FpOrComplexBuffer>, BASE_HLO_LogOp;

def LHLO_Log1pOp: LHLO_UnaryElementwiseOp<"log_plus_one", LHLO_FpOrComplexBuffer>, BASE_HLO_Log1pOp;

def LHLO_NegOp: LHLO_UnaryElementwiseOp<"negate">, BASE_HLO_NegOp;

def LHLO_NotOp: LHLO_UnaryElementwiseOp<"not", LHLO_PredOrIntBuffer>, BASE_HLO_NotOp;

def LHLO_PopulationCountOp: LHLO_UnaryElementwiseOp<"popcnt", LHLO_IntBuffer>, BASE_HLO_PopulationCountOp;

def LHLO_RealOp: LHLO_Op<"real", [SameOperandsShape]>, BASE_HLO_RealOp {
  let arguments = (ins Arg<LHLO_ComplexBuffer, "", [MemRead]>:$input,
                       Arg<LHLO_FpBuffer, "", [MemWrite]>:$output);
}

def LHLO_RoundOp: LHLO_UnaryElementwiseOp<"round_nearest_afz", LHLO_FpBuffer>, BASE_HLO_RoundOp;

def LHLO_RsqrtOp: LHLO_UnaryElementwiseOp<"rsqrt", LHLO_FpOrComplexBuffer>, BASE_HLO_RsqrtOp;

def LHLO_SqrtOp: LHLO_UnaryElementwiseOp<"sqrt", LHLO_FpOrComplexBuffer>, BASE_HLO_SqrtOp;

def LHLO_SignOp: LHLO_UnaryElementwiseOp<"sign">, BASE_HLO_SignOp;

def LHLO_SinOp: LHLO_UnaryElementwiseOp<"sine", LHLO_FpOrComplexBuffer>, BASE_HLO_SinOp;

def LHLO_TanhOp: LHLO_UnaryElementwiseOp<"tanh", LHLO_FpOrComplexBuffer>, BASE_HLO_TanhOp;

//===----------------------------------------------------------------------===//
// XLA binary elementwise op definitions.
//===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations

class LHLO_BinaryElementwiseOp<string mnemonic, Type BufferType = LHLO_Buffer,
                               list<OpTrait> traits = [SameTypeOperands]> :
        LHLO_Op<mnemonic, traits> {
  let arguments = (ins
      Arg<BufferType, "", [MemRead]>:$lhs,
      Arg<BufferType, "", [MemRead]>:$rhs,
      Arg<BufferType, "", [MemWrite]>:$out,
      OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
  );
}
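For contrast with the tensor-based chlo/hlo ops above, a hedged C++ sketch of emitting one of these buffer ops; the generated builder signature is assumed from the argument list above, and all names and the include path are illustrative.

#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"  // assumed include path

// After buffer assignment there are no SSA tensor results: the op reads the
// lhs/rhs memrefs and writes into a caller-allocated out memref.
void EmitLhloAdd(mlir::OpBuilder &b, mlir::Location loc,
                 mlir::Value lhs, mlir::Value rhs, mlir::Value out) {
  b.create<mlir::xla_lhlo::AddOp>(
      loc, lhs, rhs, out,
      /*broadcast_dimensions=*/mlir::DenseIntElementsAttr());
}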
def LHLO_AddOp : LHLO_BinaryElementwiseOp<"add">, BASE_HLO_AddOp;

def LHLO_AndOp: LHLO_BinaryElementwiseOp<"and", LHLO_PredOrIntBuffer>, BASE_HLO_AndOp;

def LHLO_Atan2Op : LHLO_BinaryElementwiseOp<"atan2", LHLO_FpOrComplexBuffer>, BASE_HLO_Atan2Op;

def LHLO_ComplexOp: LHLO_Op<"complex", [SameOperandsShape]>, BASE_HLO_ComplexOp {
  let arguments = (ins
      Arg<LHLO_FpBuffer, "", [MemRead]>:$lhs,
      Arg<LHLO_FpBuffer, "", [MemRead]>:$rhs,
      Arg<LHLO_ComplexBuffer, "", [MemWrite]>:$output,
      OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
  );
}

def LHLO_DivOp : LHLO_BinaryElementwiseOp<"divide">, BASE_HLO_DivOp;

def LHLO_MaxOp : LHLO_BinaryElementwiseOp<"maximum">, BASE_HLO_MaxOp;

def LHLO_MinOp : LHLO_BinaryElementwiseOp<"minimum">, BASE_HLO_MinOp;

def LHLO_MulOp : LHLO_BinaryElementwiseOp<"multiply">, BASE_HLO_MulOp;

def LHLO_OrOp : LHLO_BinaryElementwiseOp<"or", LHLO_PredOrIntBuffer>, BASE_HLO_OrOp;

def LHLO_PowOp : LHLO_BinaryElementwiseOp<"power">, BASE_HLO_PowOp;

def LHLO_RemOp : LHLO_BinaryElementwiseOp<"remainder", LHLO_IntOrFpBuffer>, BASE_HLO_RemOp;

def LHLO_ShiftLeftOp : LHLO_BinaryElementwiseOp<"shift_left", LHLO_IntBuffer>, BASE_HLO_ShiftLeftOp;

def LHLO_ShiftRightArithmeticOp : LHLO_BinaryElementwiseOp<"shift_right_arithmetic", LHLO_IntBuffer>, BASE_HLO_ShiftRightArithmeticOp;

def LHLO_ShiftRightLogicalOp : LHLO_BinaryElementwiseOp<"shift_right_logical", LHLO_IntBuffer>, BASE_HLO_ShiftRightLogicalOp;

def LHLO_SubOp : LHLO_BinaryElementwiseOp<"subtract">, BASE_HLO_SubOp;

def LHLO_XorOp : LHLO_BinaryElementwiseOp<"xor", LHLO_PredOrIntBuffer>, BASE_HLO_XorOp;

//===----------------------------------------------------------------------===//
// XLA control flow op definitions.
//===----------------------------------------------------------------------===//

// TODO(b/139813999): specify required function signature in a type-safe way.
def LHLO_ReduceOp: LHLO_Op<"reduce", [
      SameVariadicOperandSize,
      SingleBlockImplicitTerminator<"TerminatorOp">
    ]>, BASE_HLO_ReduceOp {
  let arguments = (ins
    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$init_values,
    Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$out,
    I64ElementsAttr:$dimensions
  );

  let regions = (region SizedRegion<1>:$body);
}
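For illustration, a plausible `xla_lhlo.reduce` instance might look as follows; the shapes and attribute values are assumptions for the sketch, not part of this change. The single-block region computes the reduction step, and the implicit `TerminatorOp` ends it:

```mlir
// Hypothetical sketch: sum-reduce a 10-element buffer into a rank-1 result.
func @reduce_sum(%input: memref<10xf32>, %init: memref<f32>,
                 %out: memref<1xf32>) {
  "xla_lhlo.reduce"(%input, %init, %out) ( {
  ^bb0(%a: memref<f32>, %b: memref<f32>, %c: memref<f32>):
    "xla_lhlo.add"(%a, %b, %c) : (memref<f32>, memref<f32>, memref<f32>) -> ()
    "xla_lhlo.terminator"() : () -> ()
  } ) {dimensions = dense<[0]> : tensor<1xi64>}
      : (memref<10xf32>, memref<f32>, memref<1xf32>) -> ()
  return
}
```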
def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [
      SingleBlockImplicitTerminator<"TerminatorOp">
    ]>, BASE_HLO_ReduceWindowOp {

  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemRead]>:$init_value,
    Arg<LHLO_Buffer, "", [MemWrite]>:$out,
    I64ElementsAttr:$window_dimensions,
    // If the strides or dilations attributes are missing, the default value is
    // one for each of the input dimensions. Similarly, if padding is not
    // specified, it defaults to zero for both low and high in each dimension.
    OptionalAttr<I64ElementsAttr>:$window_strides,
    OptionalAttr<I64ElementsAttr>:$base_dilations,
    OptionalAttr<I64ElementsAttr>:$window_dilations,
    OptionalAttr<I64ElementsAttr>:$padding
  );

  let regions = (region SizedRegion<1>:$body);
}

// TODO(timshen): Add a custom parser to hide operand_segment_sizes. For
// example, a tuple-like pattern match syntax could work:
// xla_lhlo.case %index, (%input0, %input1, %input2), (%output0, %output1) {
//   ...
// }, {
//   ...
// } : (type_input0, type_input1, type_input2, type_output0, type_output1) -> ()
def LHLO_CaseOp: LHLO_Op<"case", [
      AttrSizedOperandSegments,
      SingleBlockImplicitTerminator<"TerminatorOp">
    ]>, BASE_HLO_CaseOp {

  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$index,
    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$branch_operands,
    Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$out
  );

  let regions = (region VariadicRegion<SizedRegion<1>>:$branches);
}
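Until such a custom parser exists, a `case` instance spells out `operand_segment_sizes` in the generic form. A hedged sketch (the branch bodies, shapes, and segment sizes are invented for illustration):

```mlir
// Hypothetical sketch: two-way branch selecting which input to copy to %out.
// operand_segment_sizes = [1 index, 2 branch_operands, 1 out].
"xla_lhlo.case"(%index, %operand_1, %operand_2, %out) ( {
  "xla_lhlo.copy"(%operand_1, %out) : (memref<4xf32>, memref<4xf32>) -> ()
  "xla_lhlo.terminator"() : () -> ()
}, {
  "xla_lhlo.copy"(%operand_2, %out) : (memref<4xf32>, memref<4xf32>) -> ()
  "xla_lhlo.terminator"() : () -> ()
}) {operand_segment_sizes = dense<[1, 2, 1]> : vector<3xi32>}
    : (memref<i32>, memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
```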
// TODO(timshen): Add a custom syntax for this.
def LHLO_WhileOp: LHLO_Op<"while", [SameVariadicOperandSize]>,
                  BASE_HLO_WhileOp {
  let arguments = (ins
    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$val,
    Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$output
  );

  let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
}
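A rough sketch of how such a loop might be spelled in the generic form; every name, shape, and scratch buffer below is an assumption for illustration, since the exact verifier rules are not part of this patch:

```mlir
// Hypothetical sketch: the cond region writes a pred buffer, the body
// updates %val through a scratch buffer %tmp.
func @while_loop(%val: memref<i64>, %bound: memref<i64>, %one: memref<i64>,
                 %pred: memref<i1>, %tmp: memref<i64>, %output: memref<i64>) {
  "xla_lhlo.while"(%val, %output) ( {
    "xla_lhlo.compare"(%val, %bound, %pred) {comparison_direction = "LT"}
        : (memref<i64>, memref<i64>, memref<i1>) -> ()
    "xla_lhlo.terminator"() : () -> ()
  }, {
    "xla_lhlo.add"(%val, %one, %tmp)
        : (memref<i64>, memref<i64>, memref<i64>) -> ()
    "xla_lhlo.copy"(%tmp, %val) : (memref<i64>, memref<i64>) -> ()
    "xla_lhlo.terminator"() : () -> ()
  }) : (memref<i64>, memref<i64>) -> ()
  return
}
```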
//===----------------------------------------------------------------------===//
// XLA compare op definitions.
//===----------------------------------------------------------------------===//
def LHLO_CompareOp: LHLO_Op<"compare", []>, BASE_HLO_CompareOp {
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
    Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
    Arg<LHLO_PredBuffer, "", [MemWrite]>:$out,
    OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions,
    HLO_ComparisonDirectionAttr:$comparison_direction
  );
}
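The result element type is always a predicate, so `$out` is a `LHLO_PredBuffer`. A minimal sketch (shapes assumed for illustration):

```mlir
// Hypothetical sketch: elementwise greater-than into an i1 buffer.
func @compare_gt(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
                 %out: memref<2x2xi1>) {
  "xla_lhlo.compare"(%lhs, %rhs, %out) {comparison_direction = "GT"}
      : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xi1>) -> ()
  return
}
```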
//===----------------------------------------------------------------------===//
// XLA Slice definitions.
//===----------------------------------------------------------------------===//

def LHLO_SliceOp: LHLO_Op<
      "slice",
      [AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> {
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    I64ElementsAttr:$start_indices,
    I64ElementsAttr:$limit_indices,
    I64ElementsAttr:$strides
  );
}
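A minimal sketch of a static slice (the shapes and index values are illustrative assumptions):

```mlir
// Hypothetical sketch: copy rows 1..2 (stride 1) of a 4x2 buffer.
func @slice(%operand: memref<4x2xf32>, %output: memref<2x2xf32>) {
  "xla_lhlo.slice"(%operand, %output) {
    start_indices = dense<[1, 0]> : tensor<2xi64>,
    limit_indices = dense<[3, 2]> : tensor<2xi64>,
    strides = dense<[1, 1]> : tensor<2xi64>
  } : (memref<4x2xf32>, memref<2x2xf32>) -> ()
  return
}
```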
def HLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []> {
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemRead]>:$update,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$start_indices
  );
}

//===----------------------------------------------------------------------===//
// StaticMemRefCastOp
//===----------------------------------------------------------------------===//
def HLO_StaticMemRefCastOp: Op<LHLO_Dialect, "static_memref_cast",
    [NoSideEffect, DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
  let summary = [{
    modifies the offset, sizes and strides of a statically shaped memref
  }];
  let description = [{
    Modifies the offset, sizes and strides of a statically shaped memref.

    Example:
    ```mlir
    %buf_transformed =
        xla_lhlo.static_memref_cast %buf
        : memref<1x5xf32> -> memref<5xf32, offset: 2, strides: [1]>

    // The result of the op is a rank-1 memref with `[5]` shape, stride 1 and
    // offset 2.
    ```
  }];

  let arguments = (ins Arg<LHLO_Buffer, "", []>:$operand);
  let results = (outs Res<LHLO_Buffer, "", []>:$result);

  let builders = [OpBuilder<
    "OpBuilder &builder, OperationState &result, MemRefType resultType, " #
    "Value operand", [{
       result.addOperands(operand);
       result.types.push_back(resultType);
     }]>];

  let extraClassDeclaration = [{
    MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
  }];

  let verifier = [{ return Verify(*this); }];
  let assemblyFormat = [{
    $operand attr-dict `:` type($operand) `->` type($result)
  }];
}
//===----------------------------------------------------------------------===//
// DynamicMemRefCastOp
//===----------------------------------------------------------------------===//

def HLO_DynamicMemRefCastOp: Op<LHLO_Dialect, "dynamic_memref_cast",
    [SameVariadicOperandSize, NoSideEffect,
     DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
  let summary = "dynamic memref cast operation";
  let description = [{
    Changes the sizes and strides of a memref using values computed at runtime.

    Example:
    ```mlir
    %buf_transformed =
        xla_lhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%step_X, %step_Y]
        : memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
    // The result of the op is a type-erased memref with `[%size_X, %size_Y]`
    // shape and `[%step_X, %step_Y]` strides. The offset will be inherited
    // from the input.
    ```
  }];

  let arguments = (ins
    Arg<LHLO_Buffer, "", []>:$operand,
    Variadic<Index>:$sizes,
    Variadic<Index>:$strides
  );
  let results = (outs Res<LHLO_Buffer, "", []>:$result);

  let builders = [OpBuilder<
    "OpBuilder &builder, OperationState &result, MemRefType resultType, " #
    "Value operand, ValueRange sizes, ValueRange strides", [{
       result.addOperands(operand);
       result.addOperands(sizes);
       result.addOperands(strides);
       result.types.push_back(resultType);
     }]>];

  let extraClassDeclaration = [{
    MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
  }];

  let verifier = [{ return Verify(*this); }];
  let assemblyFormat = [{
    $operand `(` $sizes `)` `[` $strides `]` attr-dict `:` type($operand) `->`
    type($result)
  }];
}
//===----------------------------------------------------------------------===//
// XLA Other op definitions.
//===----------------------------------------------------------------------===//

def LHLO_BatchNormGradOp : LHLO_Op<"batch_norm_grad", []>,
    BASE_HLO_BatchNormGradOp {

  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemRead]>:$scale,
    Arg<LHLO_Buffer, "", [MemRead]>:$mean,
    Arg<LHLO_Buffer, "", [MemRead]>:$variance,
    Arg<LHLO_Buffer, "", [MemRead]>:$grad_output,
    Arg<LHLO_Buffer, "", [MemWrite]>:$grad_operand,  // gradient of $operand.
    Arg<LHLO_Buffer, "", [MemWrite]>:$grad_scale,
    Arg<LHLO_Buffer, "", [MemWrite]>:$grad_offset,
    F32Attr:$epsilon,
    I64Attr:$feature_index
  );
}

def LHLO_BatchNormInferenceOp : LHLO_Op<"batch_norm_inference", []>,
    BASE_HLO_BatchNormInferenceOp {

  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemRead]>:$scale,
    Arg<LHLO_Buffer, "", [MemRead]>:$offset,
    Arg<LHLO_Buffer, "", [MemRead]>:$mean,
    Arg<LHLO_Buffer, "", [MemRead]>:$variance,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    F32Attr:$epsilon,
    I64Attr:$feature_index
  );
}

def LHLO_BatchNormTrainingOp : LHLO_Op<"batch_norm_training", []>,
    BASE_HLO_BatchNormTrainingOp {

  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemRead]>:$scale,
    Arg<LHLO_Buffer, "", [MemRead]>:$offset,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    Arg<LHLO_Buffer, "", [MemWrite]>:$batch_mean,
    Arg<LHLO_Buffer, "", [MemWrite]>:$batch_var,
    F32Attr:$epsilon,
    I64Attr:$feature_index
  );
}

// TODO(timshen): add a custom verifier.
def LHLO_BitcastOp: LHLO_Op<"bitcast", []> {
  let arguments = (ins Arg<LHLO_Buffer, "", [MemRead]>:$input,
                       Arg<LHLO_Buffer, "", [MemWrite]>:$output);
}

def LHLO_BroadcastOp : LHLO_Op<"broadcast",
      []>, BASE_HLO_BroadcastOp {
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    I64ElementsAttr:$broadcast_sizes
  );
}

def LHLO_BroadcastInDimOp : LHLO_Op<"broadcast_in_dim",
      []>, BASE_HLO_BroadcastInDimOp {
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    BroadcastDimAttr:$broadcast_dimensions
  );
}
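Here `$broadcast_dimensions` maps each operand dimension to an output dimension. A minimal sketch (shapes assumed for illustration):

```mlir
// Hypothetical sketch: broadcast a length-5 vector across 10 columns.
// Operand dim 0 maps to output dim 0; output dim 1 is the broadcast dim.
func @bcast(%operand: memref<5xf32>, %output: memref<5x10xf32>) {
  "xla_lhlo.broadcast_in_dim"(%operand, %output)
      {broadcast_dimensions = dense<[0]> : tensor<1xi64>}
      : (memref<5xf32>, memref<5x10xf32>) -> ()
  return
}
```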
def LHLO_ClampOp : LHLO_Op<"clamp", []>, BASE_HLO_ClampOp {
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$min,
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemRead]>:$max,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output
  );
}

def LHLO_ConcatenateOp : LHLO_Op<"concatenate", []>, BASE_HLO_ConcatenateOp {
  let arguments = (ins
    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$val,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    I64Attr:$dimension
  );
}

// TODO(bondhugula): Make this struct dialect independent so that it can be
// shared between the HLO and LHLO dialects.
def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", LHLO_Dialect, [
  StructFieldAttr<"input_batch_dimension", I64Attr>,
  StructFieldAttr<"input_feature_dimension", I64Attr>,
  StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
  StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
  StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
  StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
  StructFieldAttr<"output_batch_dimension", I64Attr>,
  StructFieldAttr<"output_feature_dimension", I64Attr>,
  StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>]> {

  let description = "Structure of dimension information for conv op";
}
def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp {
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
    Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    // Default value: one for each spatial dimension.
    OptionalAttr<I64ElementsAttr>:$window_strides,
    // Default value: zero for each spatial dimension.
    OptionalAttr<I64ElementsAttr>:$padding,
    // Default value: one for each spatial dimension.
    OptionalAttr<I64ElementsAttr>:$lhs_dilation,
    // Default value: one for each spatial dimension.
    OptionalAttr<I64ElementsAttr>:$rhs_dilation,
    ConvDimensionNumbers:$dimension_numbers,
    I64Attr:$feature_group_count,
    I64Attr:$batch_group_count,
    HLO_PrecisionConfigAttr:$precision_config
  );
}
def LHLO_CopyOp: LHLO_Op<"copy", []>, BASE_HLO_CopyOp {
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output
  );
}

def LHLO_DotOp: LHLO_Op<"dot", []>, BASE_HLO_DotOp {
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
    Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
    HLO_PrecisionConfigAttr:$precision_config,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output
  );
}

def LHLO_GatherOp: LHLO_Op<"gather", []>, BASE_HLO_GatherOp {
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_IntBuffer, "", [MemRead]>:$start_indices,
    I64Attr:$index_vector_dim,
    I64ElementsAttr:$offset_dims,
    I64ElementsAttr:$slice_sizes,
    I64ElementsAttr:$collapsed_slice_dims,
    I64ElementsAttr:$start_index_map,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output
  );
}

def LHLO_ReshapeOp: LHLO_Op<"reshape", []>, BASE_HLO_ReshapeOp {
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output
  );
}

def LHLO_ScatterOp: LHLO_Op<"scatter", []>, BASE_HLO_ScatterOp {
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemRead]>:$scatter_indices,
    Arg<LHLO_Buffer, "", [MemRead]>:$updates,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    ScatterDimensionNumbers<LHLO_Dialect>:$scatter_dimension_numbers,
    DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted,
    DefaultValuedAttr<BoolAttr, "false">:$unique_indices
  );

  let regions = (region SizedRegion<1>:$update_computation);
}

def LHLO_SelectOp: LHLO_Op<"select", []>, BASE_HLO_SelectOp {
  let arguments = (ins
    Arg<LHLO_PredBuffer, "", [MemRead]>:$pred,
    Arg<LHLO_Buffer, "", [MemRead]>:$on_true,
    Arg<LHLO_Buffer, "", [MemRead]>:$on_false,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output
  );
}

def LHLO_SelectAndScatterOp: LHLO_Op<"select_and_scatter", []>,
      BASE_HLO_SelectAndScatterOp {
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemRead]>:$source,
    Arg<LHLO_Buffer, "", [MemRead]>:$init_value,
    Arg<LHLO_Buffer, "", [MemWrite]>:$out,
    OptionalAttr<I64ElementsAttr>:$window_dimensions,
    OptionalAttr<I64ElementsAttr>:$window_strides,
    OptionalAttr<I64ElementsAttr>:$padding
  );

  let regions = (region SizedRegion<1>:$select, SizedRegion<1>:$scatter);
}

def LHLO_ReverseOp: LHLO_Op<"reverse", []>, BASE_HLO_ReverseOp {
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    I64ElementsAttr:$dimensions,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output
  );
}

def LHLO_PadOp: LHLO_Op<"pad", []>, BASE_HLO_PadOp {
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemRead]>:$padding_value,
    I64ElementsAttr:$edge_padding_low,
    I64ElementsAttr:$edge_padding_high,
    I64ElementsAttr:$interior_padding,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output
  );
}

def LHLO_TransposeOp: LHLO_Op<"transpose", []>, BASE_HLO_TransposeOp {
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    I64ElementsAttr:$permutation,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output
  );
}
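A minimal sketch of a 2-D transpose (shapes assumed for illustration):

```mlir
// Hypothetical sketch: transpose a 2x3 buffer into a pre-allocated 3x2 one.
func @transpose(%operand: memref<2x3xf32>, %output: memref<3x2xf32>) {
  "xla_lhlo.transpose"(%operand, %output)
      {permutation = dense<[1, 0]> : tensor<2xi64>}
      : (memref<2x3xf32>, memref<3x2xf32>) -> ()
  return
}
```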
def LHLO_ReducePrecisionOp: LHLO_Op<"reduce_precision", [SameTypeOperands]>,
                            BASE_HLO_ReducePrecisionOp {
  let arguments = (ins
    Arg<LHLO_FpBuffer, "", [MemRead]>:$operand,
    Arg<LHLO_FpBuffer, "", [MemWrite]>:$output,
    I32Attr:$exponent_bits,
    I32Attr:$mantissa_bits
  );
}

def LHLO_AllReduceOp : LHLO_Op<"all_reduce", [SameTypeOperands]>,
                       BASE_HLO_AllReduceOp {
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    I64ElementsAttr:$replica_groups,
    DefaultValuedAttr<BoolAttr, "false">:$constrain_layout,
    OptionalAttr<ChannelHandle<LHLO_Dialect>>:$channel_id,
    DefaultValuedAttr<BoolAttr, "false">:$use_global_device_ids
  );
  let regions = (region SizedRegion<1>:$computation);
}

def LHLO_CollectivePermuteOp: LHLO_Op<"collective_permute", [SameTypeOperands]>,
                              BASE_HLO_CollectivePermuteOp {

  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    I64ElementsAttr:$source_target_pairs,
    OptionalAttr<ChannelHandle<LHLO_Dialect>>:$channel_id
  );
}

def LHLO_FftOp: LHLO_Op<"fft", []>, BASE_HLO_FftOp {
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    HLO_FftTypeAttr:$fft_type,
    I64ElementsAttr:$fft_length
  );
}

def LHLO_CholeskyOp: LHLO_Op<"cholesky", [SameOperandsElementType]>, BASE_HLO_CholeskyOp {
  let arguments = (ins
    Arg<LHLO_FpOrComplexBuffer, "", [MemRead]>:$a,
    Arg<LHLO_FpOrComplexBuffer, "", [MemWrite]>:$output,
    DefaultValuedAttr<BoolAttr, "false">:$lower
  );
}

def LHLO_Infeed: LHLO_Op<"infeed", []>, BASE_HLO_InfeedOp {
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    DefaultValuedAttr<StrAttr, "">:$config
  );
}

def LHLO_Outfeed: LHLO_Op<"outfeed", []> {
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    DefaultValuedAttr<StrAttr, "">:$config
  );
}

def LHLO_ReplicaIdOp : LHLO_Op<"replica_id", []>, BASE_HLO_ReplicaIdOp {
  let arguments = (ins Arg<MemRefOf<[UI32]>, "", [MemWrite]>);
}

def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType]>,
                            BASE_HLO_TriangularSolveOp {
  let arguments = (ins
    Arg<LHLO_FpOrComplexBuffer, "", [MemRead]>:$a,
    Arg<LHLO_FpOrComplexBuffer, "", [MemRead]>:$b,
    Arg<LHLO_FpOrComplexBuffer, "", [MemWrite]>:$output,
    BoolAttr:$left_side,
    BoolAttr:$lower,
    BoolAttr:$unit_diagonal,
    HLO_TransposeAttr:$transpose_a
  );
}

// TODO(timshen): add a custom verifier.
def LHLO_MapOp: LHLO_Op<"map", [SameOperandsShape]>, BASE_HLO_MapOp {
  let arguments = (ins
    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    I64ElementsAttr:$dimensions
  );
  let regions = (region SizedRegion<1>:$computation);
}

def LHLO_RngGetAndUpdateStateOp: LHLO_Op<"rng_get_and_update_state", []> {
  let arguments = (ins
    Arg<MemRefOf<[UI64]>, "", [MemRead, MemWrite]>:$state,
    I64Attr:$delta
  );
}

// TODO(timshen): add a custom verifier.
def LHLO_SortOp: LHLO_Op<"sort", [SameVariadicOperandSize, SameOperandsShape]>, BASE_HLO_SortOp {
  let arguments = (ins
    Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
    Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$output,
    DefaultValuedAttr<I64Attr, "-1">:$dimension,
    DefaultValuedAttr<BoolAttr, "false">:$is_stable
  );

  let regions = (region SizedRegion<1>:$comparator);
}
//===----------------------------------------------------------------------===//
// Late operations
//===----------------------------------------------------------------------===//

def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]> {
  let summary = "Fusion operator";
  let description = [{
    Models the fusion instruction generated by the XLA compiler's fusion pass.

    Fusion instructions are generated by the fusion pass of the XLA compiler.
    They serve as a hint to the backend that it is beneficial to emit the
    contained instructions into a single loop nest or kernel. The XLA fusion
    pass is designed such that it only generates fusion nodes that can be
    handled by the XLA compiler's backends.
    The XLA runtime expects this hint to be followed, as it expects a single
    kernel per HLO instruction. This restriction might be lifted in the future.
  }];
  let regions = (region SizedRegion<1>:$region);

  let skipDefaultBuilders = 1;
  let builders = [
     OpBuilder<"OpBuilder &builder, OperationState &result, "
               "ArrayRef<NamedAttribute> attributes">
   ];
}
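A hedged sketch of what a fusion region might contain (the shapes and the particular fused computation are assumptions for illustration): `tensor_load`/`tensor_store` bridge between the buffer world outside the region and the value-based `xla_hlo` ops inside it.

```mlir
// Hypothetical sketch: fuse two adds into one region.
func @fusion(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10xf32>,
             %out: memref<10xf32>) {
  "xla_lhlo.fusion"() ( {
    %0 = tensor_load %a : memref<10xf32>
    %1 = tensor_load %b : memref<10xf32>
    %2 = "xla_hlo.add"(%0, %1)
        : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
    %3 = tensor_load %c : memref<10xf32>
    %4 = "xla_hlo.add"(%2, %3)
        : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
    tensor_store %4, %out : memref<10xf32>
    "xla_lhlo.terminator"() : () -> ()
  } ) : () -> ()
  return
}
```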
def TerminatorOp :
    LHLO_Op<"terminator", [Terminator]> {
  let summary = "LHLO termination operation";
  let description = [{
    Terminator operation for the LHLO dialect.
  }];
  let builders = [OpBuilder<
    "OpBuilder &b, OperationState &result, ValueRange operands",
    [{ build(b, result, llvm::None, operands, llvm::None); }]
  >];
}

#endif // LHLO_OPS

@ -0,0 +1,49 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_BROADCAST_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_BROADCAST_UTILS_H_

// Utilities relating to implementing HLO broadcasting.
// Note: This file should not depend on any non-MLIR TensorFlow libraries.

#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LLVM.h"

namespace mlir {
namespace xla {
// Checks whether the given operand types and broadcast_dims attr represent a
// legal combination for "numpy" style broadcasting (where 1-dims are prepended
// to the smaller ranked operand until it is of the same rank as the larger).
// See: https://docs.scipy.org/doc/numpy/reference/ufuncs.html
bool IsLegalNumpyRankedBroadcast(Value lhs, Value rhs,
                                 DenseIntElementsAttr broadcast_dims);

// Emits shape dialect ops to compute the result shape for a broadcasting
// binary elementwise op which broadcasts according to "numpy" semantics
// (see above), returning an extents tensor of the resulting shape.
Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs,
                                                        Value rhs,
                                                        OpBuilder& builder);
}  // namespace xla
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_XLA_IR_BROADCAST_UTILS_H_

@ -0,0 +1,33 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_

#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"

namespace mlir {
namespace xla {

// Converts the given elements attr to the specified elements type.
// Requires the type of the elements and new_type to be either integer or
// float type.
mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
                                       mlir::Type new_type);

}  // namespace xla
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_

@ -0,0 +1,74 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_

#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"

namespace mlir {
namespace xla {

// Computes the broadcast dimensions attr for an elementwise binary operator
// between two ranked tensors.
// If `allow_empty` is true, then null can be returned to mean that the
// broadcast is an "identity".
mlir::DenseIntElementsAttr getBroadcastDimensionsAttr(mlir::Builder* b,
                                                      mlir::Value x,
                                                      mlir::Value y,
                                                      bool allow_empty = true);
// Returns a constant splat of the given value. Requires `ty` to be a
// statically shaped RankedTensorType.
template <typename T>
static ElementsAttr getSplat(Builder* b, RankedTensorType ty, T constant) {
  Type element_ty = getElementTypeOrSelf(ty);

  if (element_ty.isSignlessInteger())
    return DenseElementsAttr::get(ty, b->getIntegerAttr(element_ty, constant));

  if (element_ty.isa<FloatType>())
    return DenseElementsAttr::get(ty, b->getFloatAttr(element_ty, constant));

  if (auto complex_ty = element_ty.dyn_cast<ComplexType>()) {
    auto complex_element_ty = complex_ty.getElementType();
    if (complex_element_ty.isF32())
      return DenseElementsAttr::get(ty,
                                    static_cast<std::complex<float>>(constant));
    if (complex_element_ty.isF64())
      return DenseElementsAttr::get(
          ty, static_cast<std::complex<double>>(constant));
  }
  llvm_unreachable("unhandled element type");
}

template <typename T>
static ElementsAttr getSplat(Builder* b, Value val, T constant) {
  return getSplat(b, val.getType().cast<RankedTensorType>(), constant);
}

// Returns a DenseElementsAttr of rank zero with the given element type and
// value.
// Requires `ty` to be either FloatType or IntegerType.
DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value);
}  // namespace xla
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_

@ -0,0 +1,278 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"

#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Diagnostics.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h"

namespace mlir {
namespace xla_chlo {

template <typename T>
static LogicalResult Verify(T op) {
  return success();
}

//===----------------------------------------------------------------------===//
// BinaryOps
//===----------------------------------------------------------------------===//
namespace {
// Gets the resulting type from a broadcast between two types.
static Type GetBroadcastType(Type x, Type y, Type element_type,
                             DenseIntElementsAttr broadcast_dimensions_attr) {
  auto x_ranked = x.dyn_cast<RankedTensorType>();
  auto y_ranked = y.dyn_cast<RankedTensorType>();
  if (!x_ranked || !y_ranked) {
    return UnrankedTensorType::get(element_type);
  }

  auto shape_x = x_ranked.getShape();
  auto shape_y = y_ranked.getShape();

  if (shape_x.size() == shape_y.size()) {
    llvm::SmallVector<int64_t, 4> out_shape(shape_x.size());
    for (int i = 0, e = shape_x.size(); i < e; i++) {
      auto x_val = shape_x[i];
      auto y_val = shape_y[i];
      if (x_val == -1 || y_val == -1) {
        out_shape[i] = -1;
      } else {
        out_shape[i] = std::max(x_val, y_val);
      }
    }
    return RankedTensorType::get(out_shape, element_type);
  }

  auto shape_large = shape_x.size() > shape_y.size() ? shape_x : shape_y;
  auto shape_small = shape_x.size() <= shape_y.size() ? shape_x : shape_y;

  llvm::SmallVector<int64_t, 4> broadcast_dimensions;
  if (broadcast_dimensions_attr) {
    // Explicit broadcast dimensions.
    for (const APInt& int_value : broadcast_dimensions_attr.getIntValues()) {
      broadcast_dimensions.push_back(int_value.getSExtValue());
    }
    if (broadcast_dimensions.size() != shape_small.size()) {
      // Signal illegal broadcast_dimensions as unranked.
      return UnrankedTensorType::get(element_type);
    }
  } else {
    // If no broadcast dimensions, assume "numpy" broadcasting.
    broadcast_dimensions = llvm::to_vector<4>(llvm::seq<int64_t>(
        shape_large.size() - shape_small.size(), shape_large.size()));
  }

  llvm::SmallVector<int64_t, 4> out_shape(shape_large.begin(),
                                          shape_large.end());

  // Update according to the broadcast dimensions.
  for (auto index_pair : llvm::enumerate(broadcast_dimensions)) {
    auto old_value = out_shape[index_pair.value()];
    auto new_value = shape_small[index_pair.index()];
    if (old_value != -1 && (new_value == -1 || new_value > old_value)) {
      out_shape[index_pair.value()] = new_value;
    }
  }

  return RankedTensorType::get(out_shape, element_type);
}
LogicalResult InferBroadcastBinaryOpReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, Type element_type,
    SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
  // Find broadcast_dimensions.
  DenseIntElementsAttr broadcast_dimensions =
      attributes.get("broadcast_dimensions")
          .dyn_cast_or_null<DenseIntElementsAttr>();

  ShapedType lhs_type = operands[0].getType().dyn_cast<ShapedType>();
  ShapedType rhs_type = operands[1].getType().dyn_cast<ShapedType>();
  if (!lhs_type || !rhs_type ||
      lhs_type.getElementType() != rhs_type.getElementType()) {
    return emitOptionalError(location, "mismatched operand types");
  }
  if (!element_type) element_type = lhs_type.getElementType();
  Type result_type =
      GetBroadcastType(lhs_type, rhs_type, element_type, broadcast_dimensions);

  if (auto ranked_result_type = result_type.dyn_cast<RankedTensorType>()) {
    inferedReturnShapes.emplace_back(ranked_result_type.getShape(),
                                     element_type);
    return success();
  }

  // TODO(laurenzo): This should be constructing with `element_type` but that
  // constructor variant needs to be added upstream.
  inferedReturnShapes.emplace_back(/* element_type */);
  return success();
}

LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes(
    OpBuilder& builder, Operation* op,
    SmallVectorImpl<Value>& reifiedReturnShapes) {
  auto loc = op->getLoc();
  auto lhs = op->getOperand(0);
  auto rhs = op->getOperand(1);

  // Check for "numpy"-style rank broadcast.
  auto broadcast_dimensions = op->getAttr("broadcast_dimensions")
                                  .dyn_cast_or_null<DenseIntElementsAttr>();
  if (broadcast_dimensions &&
      !xla::IsLegalNumpyRankedBroadcast(lhs, rhs, broadcast_dimensions)) {
    // Note: It is unclear whether the general specification of explicit
    // broadcast_dimensions on binary ops is a feature we want to carry
    // forward. While it can technically be implemented for ranked-dynamic,
    // it is incompatible with unranked inputs. If this warning is emitted
    // in real programs, it is an indication that the feature should be
    // implemented versus just falling back on the more standard definition
    // of numpy-like prefix-padding.
    return op->emitWarning()
           << "unsupported non prefix-padded dynamic rank "
           << "broadcast_dimensions = " << broadcast_dimensions;
  }

  Value computed_shape = xla::ComputeBinaryElementwiseBroadcastingResultExtents(
      loc, lhs, rhs, builder);
  if (!computed_shape) return failure();
  reifiedReturnShapes.push_back(computed_shape);
  return success();
}
}  // namespace
//===----------------------------------------------------------------------===//
// BroadcastComplexOp (has custom type inference due to different result type).
//===----------------------------------------------------------------------===//

LogicalResult BroadcastComplexOp::inferReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
  ShapedType lhs_type = operands[0].getType().dyn_cast<ShapedType>();
  if (!lhs_type) {
    return emitOptionalError(location, "expected ShapedType");
  }
  Type element_type = ComplexType::get(lhs_type.getElementType());
  return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands,
                                                    attributes, element_type,
                                                    inferedReturnShapes);
}

LogicalResult BroadcastComplexOp::reifyReturnTypeShapes(
    OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
  return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
                                                reifiedReturnShapes);
}
//===----------------------------------------------------------------------===//
// BroadcastCompareOp (has custom type inference due to different result type).
//===----------------------------------------------------------------------===//

void BroadcastCompareOp::build(OpBuilder& builder, OperationState& result,
                               Value lhs, Value rhs,
                               DenseIntElementsAttr broadcast_dimensions,
                               StringAttr comparison_direction) {
  auto new_type = GetBroadcastType(lhs.getType(), rhs.getType(),
                                   builder.getI1Type(), broadcast_dimensions);
  build(builder, result, new_type, lhs, rhs, broadcast_dimensions,
        comparison_direction);
}

LogicalResult BroadcastCompareOp::inferReturnTypeComponents(
    MLIRContext* context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
  Type element_type = IntegerType::get(1, context);
  return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands,
                                                    attributes, element_type,
                                                    inferedReturnShapes);
}

LogicalResult BroadcastCompareOp::reifyReturnTypeShapes(
    OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
  return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
                                                reifiedReturnShapes);
}
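To illustrate the intent of these client ops, a hedged IR sketch follows (the asm form, shapes, and function name are assumptions for illustration): rank mismatches are resolved numpy-style by prefix-padding, and `broadcast_compare` yields an `i1` element type regardless of its operands.

```mlir
// Hypothetical sketch: implicit rank broadcast of a scalar against a vector,
// then a compare whose result element type is i1.
func @chlo_examples(%scalar: tensor<f32>, %vector: tensor<4xf32>)
    -> tensor<4xi1> {
  %sum = "xla_chlo.broadcast_add"(%scalar, %vector)
      : (tensor<f32>, tensor<4xf32>) -> tensor<4xf32>
  %cmp = "xla_chlo.broadcast_compare"(%sum, %vector)
      {comparison_direction = "EQ"}
      : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
  return %cmp : tensor<4xi1>
}
```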
//===----------------------------------------------------------------------===//
// Macros for method definitions that are common to most broadcasting ops.
//===----------------------------------------------------------------------===//

#define BROADCAST_INFER_SHAPE_TYPE_OP_DEFS(Op)                                \
  LogicalResult Op::inferReturnTypeComponents(                                \
      MLIRContext* context, Optional<Location> location, ValueRange operands, \
      DictionaryAttr attributes, RegionRange regions,                         \
      SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {           \
    return InferBroadcastBinaryOpReturnTypeComponents(                        \
        context, location, operands, attributes, /*element_type=*/nullptr,    \
        inferedReturnShapes);                                                 \
  }                                                                           \
  LogicalResult Op::reifyReturnTypeShapes(                                    \
      OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {      \
    return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),    \
                                                  reifiedReturnShapes);       \
  }

#define BROADCAST_BINARY_OP_DEFS(Op)                                           \
  void Op::build(OpBuilder& builder, OperationState& result, Value left,       \
                 Value right, DenseIntElementsAttr broadcast_dimensions) {     \
    auto type = GetBroadcastType(                                              \
        left.getType().cast<ShapedType>(), right.getType().cast<ShapedType>(), \
        getElementTypeOrSelf(right.getType()), broadcast_dimensions);          \
    return Op::build(builder, result, type, left, right,                       \
                     broadcast_dimensions);                                    \
  }                                                                            \
  BROADCAST_INFER_SHAPE_TYPE_OP_DEFS(Op)

BROADCAST_BINARY_OP_DEFS(BroadcastAddOp);
BROADCAST_BINARY_OP_DEFS(BroadcastAndOp);
BROADCAST_BINARY_OP_DEFS(BroadcastAtan2Op);
BROADCAST_BINARY_OP_DEFS(BroadcastDivOp);
BROADCAST_BINARY_OP_DEFS(BroadcastMaxOp);
BROADCAST_BINARY_OP_DEFS(BroadcastMinOp);
BROADCAST_BINARY_OP_DEFS(BroadcastMulOp);
BROADCAST_BINARY_OP_DEFS(BroadcastOrOp);
BROADCAST_BINARY_OP_DEFS(BroadcastPowOp);
BROADCAST_BINARY_OP_DEFS(BroadcastRemOp);
BROADCAST_BINARY_OP_DEFS(BroadcastShiftLeftOp);
BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightArithmeticOp);
BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightLogicalOp);
BROADCAST_BINARY_OP_DEFS(BroadcastSubOp);
BROADCAST_BINARY_OP_DEFS(BroadcastXorOp);

#undef BROADCAST_INFER_SHAPE_TYPE_OP_DEFS
#undef BROADCAST_BINARY_OP_DEFS
#define GET_OP_CLASSES
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"

//===----------------------------------------------------------------------===//
// xla_chlo Dialect Constructor
//===----------------------------------------------------------------------===//

XlaHloClientDialect::XlaHloClientDialect(MLIRContext* context)
    : Dialect(getDialectNamespace(), context) {
  addOperations<
#define GET_OP_LIST
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"
      >();
}

}  // namespace xla_chlo
}  // namespace mlir

@ -0,0 +1,24 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"

// Static initialization for XLA dialect registration.
static mlir::DialectRegistration<mlir::xla_hlo::XlaHloDialect> xla_hlo_ops;
static mlir::DialectRegistration<mlir::xla_chlo::XlaHloClientDialect>
    xla_chlo_ops;
static mlir::DialectRegistration<mlir::xla_lhlo::XlaLhloDialect> xla_lhlo_ops;
File diff suppressed because it is too large.

@ -0,0 +1,22 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"
 | 
			
		||||
 | 
			
		||||
namespace mlir {
 | 
			
		||||
 | 
			
		||||
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.cc.inc"
 | 
			
		||||
 | 
			
		||||
}  // namespace mlir
 | 
			
		||||

@@ -0,0 +1,102 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file defines the operations used in the XLA LHLO dialect.

#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/APFloat.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/APInt.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/SmallVector.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/FormatVariadic.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpDefinition.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpImplementation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OperationSupport.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Value.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"

namespace mlir {
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.cc.inc"
namespace xla_lhlo {

XlaLhloDialect::XlaLhloDialect(MLIRContext *context)
    : Dialect(getDialectNamespace(), context) {
  addOperations<
#define GET_OP_LIST
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc"
      >();
}

//===----------------------------------------------------------------------===//
// StaticMemRefCastOp
//===----------------------------------------------------------------------===//

Value StaticMemRefCastOp::getViewSource() { return *getODSOperands(0).begin(); }

static LogicalResult Verify(StaticMemRefCastOp op) {
  if (!op.operand().getType().cast<ShapedType>().hasStaticShape())
    return op.emitOpError("operand must have static shape");
  if (!op.getType().hasStaticShape())
    return op.emitOpError("result must have static shape");
  return success();
}
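
As a sketch of what this verifier accepts (the IR below is illustrative, not quoted from the .td):

// %cast = xla_lhlo.static_memref_cast %buf
//     : memref<1x5xf32> -> memref<5xf32, offset: 2, strides: [1]>
//
// Operand and result both have fully static shapes, so Verify succeeds; a
// memref<?xf32> on either side would trigger one of the errors above.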

//===----------------------------------------------------------------------===//
// DynamicMemRefCastOp
//===----------------------------------------------------------------------===//

Value DynamicMemRefCastOp::getViewSource() {
  return *getODSOperands(0).begin();
}

static LogicalResult Verify(DynamicMemRefCastOp op) {
  // Check that the number of `sizes` args matches the rank of the result type.
  if (op.sizes().size() != op.getType().getRank())
    return op.emitOpError(
        "`sizes` args count must be equal to the rank of the output memref");
  return success();
}
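
And a hypothetical dynamic counterpart (again illustrative): note that only the count of `sizes` operands is compared against the result rank here.

// %cast = xla_lhlo.dynamic_memref_cast %buf(%size0, %size1)[%stride0, %stride1]
//     : memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
//
// Two `sizes` operands for a rank-2 result type, so Verify succeeds.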

#define GET_OP_CLASSES
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc"

// TODO(cheshire): Support folding, reuse code from hlo_ops.cc.

void FusionOp::build(OpBuilder &builder, OperationState &result,
                     ArrayRef<NamedAttribute> attributes) {
  result.addAttributes(attributes);
  // Create the body region and give it the op's default terminator so the
  // fusion is structurally valid immediately after creation.
  Region *bodyRegion = result.addRegion();
  FusionOp::ensureTerminator(*bodyRegion, builder, result.location);
}

}  // namespace xla_lhlo
}  // namespace mlir

@@ -0,0 +1,30 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is the canonicalization pattern definition file.

include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td"

def UnaryToBinaryEinsumEq : NativeCodeCall<
  "$_builder.getStringAttr(\",\" + $0.getValue().str())">;

// Convert UnaryEinsumOp to an EinsumOp with two operands, where the
// prepended first operand is a redundant scalar constant of one.
def UnaryEinsumToEinsum : Pat<
  (HLO_UnaryEinsumOp $operand, $equation),
  (HLO_EinsumOp (HLO_ConstOp (GetScalarOfType<1> $operand)),
                $operand, (UnaryToBinaryEinsumEq $equation))>;
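
Worked example (op and attribute spellings are illustrative): a unary einsum over equation "ab->a" is rewritten into a binary einsum whose first operand is a scalar constant 1 and whose equation gains a leading comma; multiplying by a scalar one leaves the result unchanged.

// Before:
//   %0 = "xla_hlo.unary_einsum"(%x) {einsum_config = "ab->a"}
// After:
//   %one = scalar constant 1 matching %x's element type
//   %0 = "xla_hlo.einsum"(%one, %x) {einsum_config = ",ab->a"}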

@@ -0,0 +1,74 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h"

#include <algorithm>

#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/Sequence.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/SmallVector.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Diagnostics.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"

namespace mlir {
namespace xla {

bool IsLegalNumpyRankedBroadcast(Value lhs, Value rhs,
                                 DenseIntElementsAttr broadcast_dims) {
  RankedTensorType lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
  RankedTensorType rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
  if (!lhs_type || !rhs_type) return false;
  if (lhs_type.getRank() == rhs_type.getRank()) return true;

  // Otherwise, verify that broadcast_dims strictly performs left-padding.
  auto smaller_rank = std::min(lhs_type.getRank(), rhs_type.getRank());
  auto larger_rank = std::max(lhs_type.getRank(), rhs_type.getRank());

  if (smaller_rank != broadcast_dims.getNumElements()) {
    return false;
  }
  auto expected_extents =
      llvm::seq<int64_t>(larger_rank - smaller_rank, larger_rank);
  return std::equal(expected_extents.begin(), expected_extents.end(),
                    broadcast_dims.getIntValues().begin());
}
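
A concrete check, derived directly from the code above:

// lhs : tensor<2x3x4xf32>, rhs : tensor<4xf32>
//   smaller_rank = 1, larger_rank = 3
//   expected_extents = llvm::seq<int64_t>(2, 3) = [2]
//   legal iff broadcast_dims == dense<[2]> : tensor<1xi64>
// i.e. the lower-rank operand may only align with the trailing dimensions,
// which is NumPy's implicit left-padding rule.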

Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs,
                                                        Value rhs,
                                                        OpBuilder& builder) {
  auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
  auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
  if (!lhs_type || !rhs_type) {
    emitError(loc) << "shape computation for broadcasting elementwise ops "
                   << "is only implemented for ranked tensors";
    return nullptr;
  }

  int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank());
  auto shape_type = shape::ShapeType::get(builder.getContext());
  Value lhs_shape_v =
      builder.createOrFold<shape::ShapeOfOp>(loc, shape_type, lhs);
  Value rhs_shape_v =
      builder.createOrFold<shape::ShapeOfOp>(loc, shape_type, rhs);
  Value result_shape_v = builder.createOrFold<shape::BroadcastOp>(
      loc, shape_type, lhs_shape_v, rhs_shape_v, nullptr /* error */);
  return builder.createOrFold<shape::ToExtentTensorOp>(
      loc, RankedTensorType::get({result_rank}, builder.getIndexType()),
      result_shape_v);
}
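
For two ranked operands this helper materializes roughly the following IR (a sketch; SSA names invented, and folding may simplify it further):

// %lhs_shape = shape.shape_of %lhs
// %rhs_shape = shape.shape_of %rhs
// %bcast     = shape.broadcast %lhs_shape, %rhs_shape
// %extents   = shape.to_extent_tensor %bcast : tensor<Rxindex>  // R = max rank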

}  // namespace xla
}  // namespace mlir

@@ -0,0 +1,86 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file defines helpers useful when creating or manipulating lhlo/hlo.

#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/convert_op_folder.h"

#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"

namespace mlir {
namespace xla {

mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
                                       mlir::Type new_type) {
  auto old_type = getElementTypeOrSelf(elements);
  // bf16 is historically stored as a 64-bit double in MLIR's attribute
  // storage, hence the 64-bit width; other types report their own bit width.
  size_t bit_width = new_type.isBF16() ? 64 : new_type.getIntOrFloatBitWidth();

  if (old_type.isa<mlir::FloatType>()) {
    // mapValues always takes a function returning APInt, even when the output
    // is actually float.
    using func_type = mlir::APInt(const llvm::APFloat&);
    if (auto newFloatType = new_type.dyn_cast<mlir::FloatType>()) {
      // Float -> Float
      return elements.mapValues(
          new_type, llvm::function_ref<func_type>(
                        [&newFloatType](const llvm::APFloat& floatVal) {
                          llvm::APFloat newDouble(
                              mlir::FloatAttr::getValueAsDouble(floatVal));
                          bool loses_info = false;
                          newDouble.convert(newFloatType.getFloatSemantics(),
                                            llvm::APFloat::rmNearestTiesToEven,
                                            &loses_info);
                          return newDouble.bitcastToAPInt();
                        }));
    }
    // Float -> Int
    return elements.mapValues(
        new_type, llvm::function_ref<func_type>(
                      [&bit_width](const llvm::APFloat& floatVal) {
                        return llvm::APInt(
                            bit_width,
                            mlir::FloatAttr::getValueAsDouble(floatVal));
                      }));
  }

  // old_type is Integer
  // mapValues always takes a function returning APInt, even when the output
  // is actually float.
  using func_type = llvm::APInt(const llvm::APInt&);
  if (auto newFloatType = new_type.dyn_cast<mlir::FloatType>()) {
    // Int -> Float
    return elements.mapValues(
        new_type, llvm::function_ref<func_type>([&newFloatType](
                                                    const llvm::APInt& intVal) {
          llvm::APFloat newDouble(static_cast<double>(intVal.getSExtValue()));
          bool loses_info = false;
          newDouble.convert(newFloatType.getFloatSemantics(),
                            llvm::APFloat::rmNearestTiesToEven, &loses_info);
          return newDouble.bitcastToAPInt();
        }));
  }
  // new_type is Integer
  // Int -> Int
  return elements.mapValues(
      new_type,
      llvm::function_ref<func_type>([&bit_width](const llvm::APInt& intVal) {
        return llvm::APInt(bit_width, intVal.getSExtValue());
      }));
}
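
A hypothetical call, assuming an mlir::Builder is in scope, to show the intended semantics of the Int -> Float branch:

// Illustrative only: dense<[1, 2]> : tensor<2xi32> converted to f32 should
// produce dense<[1.0, 2.0]> : tensor<2xf32>.
mlir::ElementsAttr ConvertI32ToF32Example(mlir::Builder& b) {
  auto src = mlir::DenseElementsAttr::get(
      mlir::RankedTensorType::get({2}, b.getIntegerType(32)),
      llvm::ArrayRef<llvm::APInt>{llvm::APInt(32, 1), llvm::APInt(32, 2)});
  return mlir::xla::ConvertElementsAttr(src, b.getF32Type());
}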

}  // namespace xla
}  // namespace mlir

@@ -0,0 +1,70 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"

#include <numeric>

#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"

namespace mlir {
namespace xla {

DenseIntElementsAttr getBroadcastDimensionsAttr(Builder *b, Value x, Value y,
                                                bool allow_empty) {
  TensorType xType = x.getType().dyn_cast<RankedTensorType>();
  TensorType yType = y.getType().dyn_cast<RankedTensorType>();
  if (!xType || !yType) return {};
  if (allow_empty && xType == yType) return {};

  // If the shapes have the same rank, then there is nothing to do.
  auto xRank = xType.getRank(), yRank = yType.getRank();
  if (allow_empty && xRank == yRank) return {};

  // Otherwise, if the ranks of the inputs don't match, TensorFlow
  // automatically reshapes the smaller by padding with dimensions of size 1
  // as a prefix. In other words, to pad a 5-vector to a 3-dimensional tensor
  // it is reshaped to have shape [1,1,5]. XLA's automatic broadcast code is
  // able to broadcast from lower to higher rank, but doesn't assume you want
  // to pad as a prefix of the dimensions, and instead needs to be told which
  // dimensions of the higher rank tensor to match to the lower rank tensor.
  auto maxRank = std::max(xRank, yRank);
  auto minRank = std::min(xRank, yRank);

  // Match the lower rank tensor along the larger-numbered dimensions of the
  // higher rank tensor.
  SmallVector<int64_t, 4> broadcastDimensions(minRank);
  std::iota(broadcastDimensions.begin(), broadcastDimensions.end(),
            maxRank - minRank);

  RankedTensorType type =
      RankedTensorType::get({minRank}, b->getIntegerType(64));
  return DenseIntElementsAttr::get(type, broadcastDimensions);
}
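
Worked through on concrete ranks (illustrative):

// x : tensor<2x3x4xf32>, y : tensor<4xf32>
//   minRank = 1, maxRank = 3, std::iota starts at 2
//   => dense<[2]> : tensor<1xi64>
// The vector is matched against the last dimension of the rank-3 operand,
// mirroring the prefix-padding convention described above.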

DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) {
  RankedTensorType scalar_ty = RankedTensorType::get({}, ty);

  if (auto float_ty = ty.dyn_cast<FloatType>()) {
    APFloat value(float_ty.getFloatSemantics(), raw_value);
    return DenseElementsAttr::get(scalar_ty, value);
  }
  auto int_ty = ty.cast<IntegerType>();
  APInt value(int_ty.getWidth(), static_cast<int64_t>(raw_value), true);
  return DenseElementsAttr::get(scalar_ty, value);
}
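
Illustrative calls (a Builder is assumed in scope); the first is the kind of value the UnaryEinsumToEinsum pattern above splices in:

//   GetScalarOfType(builder.getF32Type(), 1)        -> dense<1.0> : tensor<f32>
//   GetScalarOfType(builder.getIntegerType(32), 0)  -> dense<0> : tensor<i32>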

}  // namespace xla
}  // namespace mlir