Move the HLO/LHLO dialects to a new directory: tensorflow/compiler/mlir/hlo
We're preparing to restructure the MLIR HLO ecosystem with 5 dialects: - chlo: client dialect with explicit broadcast and multiple composite operations - mhlo: hlo with dynamic shape, decouple from XLA for evolution purpose - lmhlo: same as above, but after buffer assignment. - xla_hlo: mapping 1:1 to the XLA HloInstruction class. - xla_lhlo: same as above, but after buffer assignment. The first three dialects are intended to live in the new tensorflow/compiler/mlir/hlo path, the latter two will be created in tensorflow/compiler/mlir/xla. This patch only moves the directory, will followup with other transformations and tests. The structure of the new directory follows: https://llvm.discourse.group/t/rfc-canonical-file-paths-to-dialects/621 as we intend to make it a standalone buildable component (see also https://github.com/google/mlir-npcomp as another example). PiperOrigin-RevId: 319273229
This commit is contained in:
parent
5fe5c39ccc
commit
fcf3df1541
|
@ -0,0 +1,45 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_CHLO_OPS_H_
|
||||||
|
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_CHLO_OPS_H_
|
||||||
|
|
||||||
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/DialectImplementation.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpDefinition.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace xla_chlo {
|
||||||
|
|
||||||
|
class XlaHloClientDialect : public Dialect {
|
||||||
|
public:
|
||||||
|
explicit XlaHloClientDialect(MLIRContext *context);
|
||||||
|
static StringRef getDialectNamespace() { return "xla_chlo"; }
|
||||||
|
};
|
||||||
|
|
||||||
|
#define GET_OP_CLASSES
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h.inc"
|
||||||
|
|
||||||
|
} // namespace xla_chlo
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_CHLO_OPS_H_
|
|
@ -0,0 +1,370 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// Defines "client" aligned HLO ops.
|
||||||
|
// These ops are not necessarily orthogonal or optimized for transformation but
|
||||||
|
// for ease of expression in certain cases deemed important for client
|
||||||
|
// libraries (i.e. implicit broadcasting, helper ops, etc).
|
||||||
|
// This dialect is considered to exist in addition to augment the xla_hlo
|
||||||
|
// dialect for ergonomic needs, not duplicate/replace it.
|
||||||
|
//
|
||||||
|
// The typical use of this dialect is for client libraries to be able to emit
|
||||||
|
// less constrained ops and rely on the conversion framework to lower any
|
||||||
|
// xla_chlo ops to canonical xla_hlo ops.
|
||||||
|
//
|
||||||
|
// See: https://www.tensorflow.org/xla/operation_semantics
|
||||||
|
|
||||||
|
#ifndef CHLO_OPS
|
||||||
|
#define CHLO_OPS
|
||||||
|
|
||||||
|
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td"
|
||||||
|
include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.td"
|
||||||
|
include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.td"
|
||||||
|
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
|
||||||
|
|
||||||
|
def HLOClient_Dialect : Dialect {
|
||||||
|
let name = "xla_chlo";
|
||||||
|
let cppNamespace = "xla_chlo";
|
||||||
|
let summary = [{
|
||||||
|
XLA Client HLO Ops
|
||||||
|
}];
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
This dialect contains ops that align closely with the API surface area
|
||||||
|
of the XlaBuilder C++ API, where such ops have semantics that go beyond
|
||||||
|
what exists in the lower level dialects (such as `xla_hlo`). Essentially,
|
||||||
|
whenever the client library uses syntactic sugar or composition
|
||||||
|
of multiple ops for an API call, this dialect tries to model the API call
|
||||||
|
and provide conversion patterns to fully materialize into lower level
|
||||||
|
dialects.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
class HLOClient_Op<string mnemonic, list<OpTrait> traits> :
|
||||||
|
Op<HLOClient_Dialect, mnemonic, traits> {
|
||||||
|
// TODO(b/129012527) Much of this custom verification should be expressed as
|
||||||
|
// type constraints.
|
||||||
|
let verifier = [{ return Verify(*this); }];
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// XLA binary elementwise op definitions.
|
||||||
|
// From the client perspective, each of these support both explicit rank
|
||||||
|
// broadcasting (via the broadcast_dimensions attribute) and implicit degenerate
|
||||||
|
// shape broadcasting.
|
||||||
|
//
|
||||||
|
// These correspond to operations in the xla_hlo dialect without the
|
||||||
|
// "broadcast_" prefix, except that those ops require same-shaped operands and
|
||||||
|
// results.
|
||||||
|
//
|
||||||
|
// See:
|
||||||
|
// https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
|
||||||
|
// https://www.tensorflow.org/xla/broadcasting
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
class HLOClient_BroadcastBinaryElementwiseOp<
|
||||||
|
string mnemonic, list<OpTrait> traits> :
|
||||||
|
HLOClient_Op<mnemonic,
|
||||||
|
!listconcat(traits, [
|
||||||
|
DeclareOpInterfaceMethods<InferShapedTypeOpInterface>])> {
|
||||||
|
let arguments = (ins
|
||||||
|
HLO_Tensor:$lhs,
|
||||||
|
HLO_Tensor:$rhs,
|
||||||
|
// Explicit rank-broadcast dimension mappings. Defaults to "numpy" prefix
|
||||||
|
// padded rank-broadcast semantics if omitted.
|
||||||
|
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
|
||||||
|
);
|
||||||
|
|
||||||
|
let builders = [OpBuilder<
|
||||||
|
"OpBuilder &builder, OperationState &result, Value left, Value right, "
|
||||||
|
"DenseIntElementsAttr broadcast_dimensions"
|
||||||
|
>];
|
||||||
|
|
||||||
|
let results = (outs HLO_Tensor);
|
||||||
|
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$lhs `,` $rhs attr-dict `:`
|
||||||
|
`(` type($lhs) `,` type($rhs) `)` `->` type(results)
|
||||||
|
}];
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
// TODO(laurenzo): It isn't clear to me why reifyReturnShapes does not
|
||||||
|
// have its declaration generated by DeclareOpInterfaceMethods.
|
||||||
|
LogicalResult reifyReturnTypeShapes(
|
||||||
|
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes);
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def HLOClient_BroadcastAddOp : HLOClient_BroadcastBinaryElementwiseOp<"broadcast_add",
|
||||||
|
[Commutative, NoSideEffect, SameOperandsAndResultElementType]> {
|
||||||
|
string summary = "Addition operator (with optional broadcasting)";
|
||||||
|
|
||||||
|
string description = [{
|
||||||
|
Returns `lhs + rhs` element-wise.
|
||||||
|
|
||||||
|
See
|
||||||
|
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def HLOClient_BroadcastAtan2Op : HLOClient_BroadcastBinaryElementwiseOp<
|
||||||
|
"broadcast_atan2",
|
||||||
|
[NoSideEffect, SameOperandsAndResultElementType]> {
|
||||||
|
string summary = "Atan2 operator (with optional broadcasting)";
|
||||||
|
|
||||||
|
string description = [{
|
||||||
|
Returns `atan2(lhs/rhs)` element-wise.
|
||||||
|
|
||||||
|
See
|
||||||
|
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def HLOClient_BroadcastDivOp : HLOClient_BroadcastBinaryElementwiseOp<
|
||||||
|
"broadcast_divide",
|
||||||
|
[NoSideEffect, SameOperandsAndResultElementType]> {
|
||||||
|
string summary = "Division operator (with optional broadcasting)";
|
||||||
|
|
||||||
|
string description = [{
|
||||||
|
Returns `lhs / rhs` element-wise.
|
||||||
|
|
||||||
|
See
|
||||||
|
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def HLOClient_BroadcastMaxOp : HLOClient_BroadcastBinaryElementwiseOp<
|
||||||
|
"broadcast_maximum",
|
||||||
|
[Commutative, NoSideEffect, SameOperandsAndResultElementType]> {
|
||||||
|
string summary = "Maximum operator (with optional broadcasting)";
|
||||||
|
|
||||||
|
string description = [{
|
||||||
|
Returns `max(lhs, rhs)` element-wise.
|
||||||
|
|
||||||
|
See
|
||||||
|
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def HLOClient_BroadcastMinOp : HLOClient_BroadcastBinaryElementwiseOp<
|
||||||
|
"broadcast_minimum",
|
||||||
|
[Commutative, NoSideEffect, SameOperandsAndResultElementType]> {
|
||||||
|
string summary = "Minimum operator (with optional broadcasting)";
|
||||||
|
|
||||||
|
string description = [{
|
||||||
|
Returns `min(lhs, rhs)` element-wise.
|
||||||
|
|
||||||
|
See
|
||||||
|
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def HLOClient_BroadcastMulOp : HLOClient_BroadcastBinaryElementwiseOp<
|
||||||
|
"broadcast_multiply",
|
||||||
|
[Commutative, NoSideEffect, SameOperandsAndResultElementType]> {
|
||||||
|
string summary = "Multiplication operator (with optional broadcasting)";
|
||||||
|
|
||||||
|
string description = [{
|
||||||
|
Returns `lhs * rhs` element-wise.
|
||||||
|
|
||||||
|
See
|
||||||
|
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def HLOClient_BroadcastPowOp : HLOClient_BroadcastBinaryElementwiseOp<
|
||||||
|
"broadcast_power",
|
||||||
|
[NoSideEffect, SameOperandsAndResultElementType]> {
|
||||||
|
string summary = "Power operator (with optional broadcasting)";
|
||||||
|
|
||||||
|
string description = [{
|
||||||
|
Returns `lhs ^ rhs` element-wise.
|
||||||
|
|
||||||
|
See
|
||||||
|
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def HLOClient_BroadcastRemOp : HLOClient_BroadcastBinaryElementwiseOp<
|
||||||
|
"broadcast_remainder",
|
||||||
|
[NoSideEffect, SameOperandsAndResultElementType]> {
|
||||||
|
string summary = "Remainder operator (with optional broadcasting)";
|
||||||
|
|
||||||
|
string description = [{
|
||||||
|
Returns `lhs % rhs` element-wise.
|
||||||
|
|
||||||
|
See
|
||||||
|
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def HLOClient_BroadcastShiftLeftOp : HLOClient_BroadcastBinaryElementwiseOp<
|
||||||
|
"broadcast_shift_left",
|
||||||
|
[NoSideEffect, SameOperandsAndResultElementType]> {
|
||||||
|
string summary = "Shift left operator (with optional broadcasting)";
|
||||||
|
|
||||||
|
string description = [{
|
||||||
|
Returns `lhs << rhs` element-wise.
|
||||||
|
|
||||||
|
See
|
||||||
|
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def HLOClient_BroadcastShiftRightArithmeticOp : HLOClient_BroadcastBinaryElementwiseOp<
|
||||||
|
"broadcast_shift_right_arithmetic",
|
||||||
|
[NoSideEffect, SameOperandsAndResultElementType]> {
|
||||||
|
string summary = "Shift right arithmetic operator (with optional broadcasting)";
|
||||||
|
|
||||||
|
string description = [{
|
||||||
|
Returns `lhs >> rhs` element-wise.
|
||||||
|
|
||||||
|
See
|
||||||
|
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def HLOClient_BroadcastShiftRightLogicalOp : HLOClient_BroadcastBinaryElementwiseOp<
|
||||||
|
"broadcast_shift_right_logical",
|
||||||
|
[NoSideEffect, SameOperandsAndResultElementType]> {
|
||||||
|
string summary = "Shift right logical operator (with optional broadcasting)";
|
||||||
|
|
||||||
|
string description = [{
|
||||||
|
Returns `lhs >> rhs` element-wise.
|
||||||
|
|
||||||
|
See
|
||||||
|
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def HLOClient_BroadcastSubOp : HLOClient_BroadcastBinaryElementwiseOp<
|
||||||
|
"broadcast_subtract",
|
||||||
|
[NoSideEffect, SameOperandsAndResultElementType]> {
|
||||||
|
string summary = "Subtraction operator (with optional broadcasting)";
|
||||||
|
|
||||||
|
string description = [{
|
||||||
|
Returns `lhs - rhs` element-wise.
|
||||||
|
|
||||||
|
See
|
||||||
|
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// XLA binary elementwise op definitions.
|
||||||
|
// The same description as the arithmetic binary elementwise ops applies.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
class HLOClient_BroadcastBinaryLogicalElementwiseOp<string mnemonic> :
|
||||||
|
HLOClient_BroadcastBinaryElementwiseOp<
|
||||||
|
mnemonic, [Commutative, NoSideEffect]> {
|
||||||
|
let arguments = (ins
|
||||||
|
HLO_PredOrIntTensor:$lhs,
|
||||||
|
HLO_PredOrIntTensor:$rhs,
|
||||||
|
// Explicit rank-broadcast dimension mappings. Defaults to "numpy" prefix
|
||||||
|
// padded rank-broadcast semantics if omitted.
|
||||||
|
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def HLOClient_BroadcastAndOp: HLOClient_BroadcastBinaryLogicalElementwiseOp<
|
||||||
|
"broadcast_and"> {
|
||||||
|
string summary = "Logical and operator (with optional broadcasting)";
|
||||||
|
|
||||||
|
string description = [{
|
||||||
|
Returns `logical_and(lhs, rhs)` element-wise.
|
||||||
|
|
||||||
|
See
|
||||||
|
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def HLOClient_BroadcastOrOp: HLOClient_BroadcastBinaryLogicalElementwiseOp<
|
||||||
|
"broadcast_or"> {
|
||||||
|
string summary = "Logical or operator (with optional broadcasting)";
|
||||||
|
|
||||||
|
string description = [{
|
||||||
|
Returns `logical_or(lhs, rhs)` element-wise.
|
||||||
|
|
||||||
|
See
|
||||||
|
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def HLOClient_BroadcastXorOp : HLOClient_BroadcastBinaryLogicalElementwiseOp<
|
||||||
|
"broadcast_xor"> {
|
||||||
|
string summary = "Logical xor operator (with optional broadcasting)";
|
||||||
|
|
||||||
|
string description = [{
|
||||||
|
Returns `logical_xor(lhs, rhs)` element-wise.
|
||||||
|
|
||||||
|
See
|
||||||
|
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Broadcasting complex op
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def HLOClient_BroadcastComplexOp : HLOClient_BroadcastBinaryElementwiseOp<
|
||||||
|
"broadcast_complex", [NoSideEffect]> {
|
||||||
|
string summary = "Complex operator (with optional broadcasting)";
|
||||||
|
|
||||||
|
string description = [{
|
||||||
|
Performs element-wise conversion of a pair of real and imaginary values to
|
||||||
|
a complex value.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
HLO_FpTensor:$lhs,
|
||||||
|
HLO_FpTensor:$rhs,
|
||||||
|
// Explicit rank-broadcast dimension mappings. Defaults to "numpy" prefix
|
||||||
|
// padded rank-broadcast semantics if omitted.
|
||||||
|
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
|
||||||
|
);
|
||||||
|
let results = (outs HLO_ComplexTensor);
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Broadcasting compare op
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp<
|
||||||
|
"broadcast_compare", [NoSideEffect]> {
|
||||||
|
string summary = "Compare operator (with optional broadcasting)";
|
||||||
|
|
||||||
|
string description = [{
|
||||||
|
Compares `lhs` and `rhs` elementwise according to `comparison_direction`.
|
||||||
|
|
||||||
|
See
|
||||||
|
https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
HLO_Tensor:$lhs,
|
||||||
|
HLO_Tensor:$rhs,
|
||||||
|
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions,
|
||||||
|
HLO_ComparisonDirectionAttr:$comparison_direction
|
||||||
|
);
|
||||||
|
let results = (outs HLO_PredTensor);
|
||||||
|
|
||||||
|
let builders = [OpBuilder<
|
||||||
|
"OpBuilder &builder, OperationState &result, Value lhs, Value rhs, "
|
||||||
|
"DenseIntElementsAttr broadcast_dimensions, StringAttr comparison_direction"
|
||||||
|
>];
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // CHLO_OPS
|
|
@ -0,0 +1,99 @@
|
||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// This file defines the operations used in the XLA dialect.
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_OPS_H_
|
||||||
|
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_OPS_H_
|
||||||
|
|
||||||
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/DialectImplementation.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpDefinition.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
class OpBuilder;
|
||||||
|
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.h.inc"
|
||||||
|
|
||||||
|
namespace xla_hlo {
|
||||||
|
|
||||||
|
class XlaHloDialect : public Dialect {
|
||||||
|
public:
|
||||||
|
explicit XlaHloDialect(MLIRContext *context);
|
||||||
|
static StringRef getDialectNamespace() { return "xla_hlo"; }
|
||||||
|
|
||||||
|
// Registered hook to materialize a constant operation from a given attribute
|
||||||
|
// value with the desired resultant type.
|
||||||
|
Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
|
||||||
|
Location loc) override;
|
||||||
|
|
||||||
|
// Parses a type registered to this dialect.
|
||||||
|
Type parseType(DialectAsmParser &parser) const override;
|
||||||
|
|
||||||
|
// Prints a type registered to this dialect.
|
||||||
|
void printType(Type type, DialectAsmPrinter &os) const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
namespace HLOTypes {
|
||||||
|
enum Kind {
|
||||||
|
Token = Type::FIRST_XLA_HLO_TYPE,
|
||||||
|
};
|
||||||
|
} // namespace HLOTypes
|
||||||
|
|
||||||
|
class TokenType : public Type::TypeBase<TokenType, Type, TypeStorage> {
|
||||||
|
public:
|
||||||
|
using Base::Base;
|
||||||
|
|
||||||
|
static TokenType get(MLIRContext *context) {
|
||||||
|
return Base::get(context, HLOTypes::Token);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Support method to enable LLVM-style type casting.
|
||||||
|
static bool kindof(unsigned kind) { return kind == HLOTypes::Token; }
|
||||||
|
};
|
||||||
|
|
||||||
|
// Shape derivation function that computes the shape of the result based on
|
||||||
|
// the first argument. For a 2-dimensional input tensor, this produces IR of
|
||||||
|
// the form
|
||||||
|
//
|
||||||
|
// %0 = dim %arg0, 0 : memref<?x?xf32>
|
||||||
|
// %1 = index_cast %0 : index to i64
|
||||||
|
// %2 = dim %arg0, 1 : memref<?x?xf32>
|
||||||
|
// %3 = index_cast %2 : index to i64
|
||||||
|
// %4 = "xla_hlo.scalars_to_dimension_tensor"(%1, %3)
|
||||||
|
// : (i64, i64) -> tensor<2xi64>
|
||||||
|
//
|
||||||
|
// and returns %4 as the shape value.
|
||||||
|
LogicalResult deriveShapeFromFirstOperand(
|
||||||
|
OpBuilder *builder, Operation *op,
|
||||||
|
SmallVectorImpl<Value> *reifiedReturnShapes);
|
||||||
|
|
||||||
|
#define GET_OP_CLASSES
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc"
|
||||||
|
|
||||||
|
} // end namespace xla_hlo
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_OPS_H_
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,43 @@
|
||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// This is the utils file for the HLO dialect.
|
||||||
|
|
||||||
|
#ifndef HLO_UTILS
|
||||||
|
#define HLO_UTILS
|
||||||
|
|
||||||
|
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td"
|
||||||
|
|
||||||
|
def NullArrayAttr : NativeCodeCall<"ArrayAttr()">;
|
||||||
|
|
||||||
|
def CastIntElementsAttr : NativeCodeCall<"$0.cast<DenseIntElementsAttr>()">;
|
||||||
|
|
||||||
|
class ConstantSplat<string value> : NativeCodeCall<
|
||||||
|
"xla::getSplat(&$_builder, $0, " # value # ")">;
|
||||||
|
|
||||||
|
def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">;
|
||||||
|
|
||||||
|
def BinBroadcastDimensions : NativeCodeCall<
|
||||||
|
"xla::getBroadcastDimensionsAttr(&$_builder, $0, $1)">;
|
||||||
|
|
||||||
|
def BinBroadcastDimensionsNonEmpty : NativeCodeCall<
|
||||||
|
"xla::getBroadcastDimensionsAttr(&$_builder, $0, $1, /*allow_empty=*/false)">;
|
||||||
|
|
||||||
|
// Here, the element type can be any integer or float type. But, note that only
|
||||||
|
// 32 bit integers are supported for the value.
|
||||||
|
class GetScalarOfType<int value> : NativeCodeCall<
|
||||||
|
"xla::GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">;
|
||||||
|
|
||||||
|
#endif // HLO_UTILS
|
|
@ -0,0 +1,28 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_INFER_FUSIBILITY_OP_INTERFACE_H_
|
||||||
|
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_INFER_FUSIBILITY_OP_INTERFACE_H_
|
||||||
|
|
||||||
|
#include "mlir/IR/OpDefinition.h"
|
||||||
|
#include "mlir/IR/StandardTypes.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h.inc"
|
||||||
|
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_INFER_FUSIBILITY_OP_INTERFACE_H_
|
|
@ -0,0 +1,161 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// This file contains inferFusiblityOpInterface, which is used to guide
|
||||||
|
// fusion decision.
|
||||||
|
|
||||||
|
#ifndef MLIR_INFER_FUSIBILITY_OP_INTERFACE
|
||||||
|
#define MLIR_INFER_FUSIBILITY_OP_INTERFACE
|
||||||
|
|
||||||
|
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td"
|
||||||
|
|
||||||
|
// OpInterface to query if an op is fusible and to query the shape equality
|
||||||
|
// constraint among the inputs and outputs of an op.
|
||||||
|
def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> {
|
||||||
|
let description = [{
|
||||||
|
Interface to query if an op is fusible and to query the shape equality
|
||||||
|
constraint among the inputs and outputs of an op.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let methods = [
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{If true, this op can be fused with its operands
|
||||||
|
}],
|
||||||
|
/*retTy=*/"bool",
|
||||||
|
/*methodName=*/"isFusibleWithOperand",
|
||||||
|
/*args=*/(ins),
|
||||||
|
/*methodBody=*/[{}],
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
/// Returns whether this op can be fused with its operands
|
||||||
|
return true;
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{If true, this op can be fused with its consumers
|
||||||
|
}],
|
||||||
|
/*retTy=*/"bool",
|
||||||
|
/*methodName=*/"isFusibleWithConsumer",
|
||||||
|
/*args=*/(ins),
|
||||||
|
/*methodBody=*/[{}],
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
/// Return whether this op can be fused withh its consumers
|
||||||
|
return true;
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/"Return whether two inputs have the same shape (assuming no"
|
||||||
|
"implicit broadcasting).",
|
||||||
|
/*retTy=*/"bool",
|
||||||
|
/*methodName=*/"inferInputsShapeEquality",
|
||||||
|
/*args=*/(ins "int":$lhs, "int":$rhs),
|
||||||
|
/*methodBody=*/[{}],
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
/// Return whether two inputs have the same shape.
|
||||||
|
Operation *op = this->getOperation();
|
||||||
|
assert(lhs < op->getNumOperands() && lhs >= 0 &&
|
||||||
|
rhs < op->getNumOperands() && rhs >= 0);
|
||||||
|
if (lhs == rhs) return true;
|
||||||
|
|
||||||
|
// if both lhs and rhs have static shapes, check them directly
|
||||||
|
Type lhs_ty = op->getOperand(lhs).getType();
|
||||||
|
Type rhs_ty = op->getOperand(rhs).getType();
|
||||||
|
auto lhs_shape_type = lhs_ty.dyn_cast_or_null<RankedTensorType>();
|
||||||
|
auto rhs_shape_type = rhs_ty.dyn_cast_or_null<RankedTensorType>();
|
||||||
|
if (!lhs_shape_type || !lhs_shape_type.hasStaticShape() ||
|
||||||
|
!rhs_shape_type || !rhs_shape_type.hasStaticShape() ||
|
||||||
|
lhs_shape_type.getRank() != rhs_shape_type.getRank()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return lhs_shape_type.getShape() == rhs_shape_type.getShape();
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/"Return whether two outputs have the same shape (assuming no"
|
||||||
|
" implicit broadcasting).",
|
||||||
|
/*retTy=*/"bool",
|
||||||
|
/*methodName=*/"inferOutputsShapeEquality",
|
||||||
|
/*args=*/(ins "int":$lhs, "int":$rhs),
|
||||||
|
/*methodBody=*/[{}],
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
/// Return whether two outputs have the same shape.
|
||||||
|
Operation *op = this->getOperation();
|
||||||
|
assert(lhs < op->getNumResults() && lhs >= 0 &&
|
||||||
|
rhs < op->getNumResults() && rhs >= 0);
|
||||||
|
if (lhs == rhs) return true;
|
||||||
|
|
||||||
|
// if both lhs and rhs have static shapes, check them directly
|
||||||
|
Type lhs_ty = op->getResult(lhs).getType();
|
||||||
|
Type rhs_ty = op->getResult(rhs).getType();
|
||||||
|
auto lhs_shape_type = lhs_ty.dyn_cast_or_null<RankedTensorType>();
|
||||||
|
auto rhs_shape_type = rhs_ty.dyn_cast_or_null<RankedTensorType>();
|
||||||
|
if (!lhs_shape_type || !lhs_shape_type.hasStaticShape() ||
|
||||||
|
!rhs_shape_type || !rhs_shape_type.hasStaticShape() ||
|
||||||
|
lhs_shape_type.getRank() != rhs_shape_type.getRank()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return lhs_shape_type.getShape() == rhs_shape_type.getShape();
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/"Return whether the input and the output have the same"
|
||||||
|
" shape (assuming no implicit broadcasting).",
|
||||||
|
/*retTy=*/"bool",
|
||||||
|
/*methodName=*/"inferInputOutputShapeEquality",
|
||||||
|
/*args=*/(ins "int":$input, "int":$output),
|
||||||
|
/*methodBody=*/[{}],
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
/// Return whether the input and the output have the same shape.
|
||||||
|
Operation *op = this->getOperation();
|
||||||
|
assert(input < op->getNumOperands() && input >= 0 &&
|
||||||
|
output < op->getNumResults() && output >= 0);
|
||||||
|
|
||||||
|
// if both input and output have static shapes, check them directly
|
||||||
|
Type input_ty = op->getOperand(input).getType();
|
||||||
|
Type output_ty = op->getResult(output).getType();
|
||||||
|
auto input_shape_type = input_ty.dyn_cast_or_null<RankedTensorType>();
|
||||||
|
auto output_shape_type = output_ty.dyn_cast_or_null<RankedTensorType>();
|
||||||
|
if (!input_shape_type || !input_shape_type.hasStaticShape() ||
|
||||||
|
!output_shape_type || !output_shape_type.hasStaticShape() ||
|
||||||
|
input_shape_type.getRank() != output_shape_type.getRank()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return input_shape_type.getShape() == output_shape_type.getShape();
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
InterfaceMethod<
|
||||||
|
/*desc=*/[{Return the effective workload shape for the operation.
|
||||||
|
|
||||||
|
Here the effective workload shape roughly represents the maximum
|
||||||
|
parallelism can be used during the codegen stage. It's used to check
|
||||||
|
the shape-compatibility of the operation. During fusion, we only
|
||||||
|
try to fuse shape-compatible ops for performace.
|
||||||
|
For example, the effective workload shape of an elementwise op is its
|
||||||
|
output shape, while the effective workload shape of a reduction op may
|
||||||
|
be its operand shape.
|
||||||
|
Return None if such an inference is not possible.
|
||||||
|
}],
|
||||||
|
/*retTy=*/"llvm::Optional<Value>",
|
||||||
|
/*methodName=*/"inferEffectiveWorkloadShape",
|
||||||
|
/*args=*/(ins),
|
||||||
|
/*methodBody=*/[{}],
|
||||||
|
/*defaultImplementation=*/[{
|
||||||
|
/// Return effective workload size if possible, otherwise None.
|
||||||
|
return {};
|
||||||
|
}]
|
||||||
|
>,
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
|
@ -0,0 +1,52 @@
|
||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// This file defines the operations used in the LXLA dialect.
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_LHLO_OPS_H_
|
||||||
|
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_LHLO_OPS_H_
|
||||||
|
|
||||||
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpDefinition.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/ViewLikeInterface.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
class OpBuilder;
|
||||||
|
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.h.inc"
|
||||||
|
|
||||||
|
namespace xla_lhlo {
|
||||||
|
|
||||||
|
class XlaLhloDialect : public Dialect {
|
||||||
|
public:
|
||||||
|
explicit XlaLhloDialect(MLIRContext *context);
|
||||||
|
static StringRef getDialectNamespace() { return "xla_lhlo"; }
|
||||||
|
};
|
||||||
|
|
||||||
|
#define GET_OP_CLASSES
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
|
||||||
|
|
||||||
|
} // namespace xla_lhlo
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_LHLO_OPS_H_
|
|
@ -0,0 +1,796 @@
|
||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// This is the operation definition file for LXLA.
|
||||||
|
//
|
||||||
|
// This file largely overlaps with hlo_ops.td at a logic level. It's tempting to
|
||||||
|
// merge these two files together, but we need to consider the following
|
||||||
|
// obstacles:
|
||||||
|
// * We need to have a common representation for arguments. That is to say,
|
||||||
|
// HLO_Array<X> translates to HLO_Tensor<X> in HLO dialect, and
|
||||||
|
// Arg<LHLO_Buffer<X>, "", [Mem(Read|Write)]> in LHLO. Array types within tuples
|
||||||
|
// also need to be transformed.
|
||||||
|
// * As of now, TableGen's dag functions are not sufficient to accomplish the
|
||||||
|
// one above.
|
||||||
|
// * Traits aren't identical, but need to be coped. For example,
|
||||||
|
// SameOperandAndResultType in HLO corresponds to SameTypeOperands in LHLO.
|
||||||
|
// * Also, currently HLO describes the API in XLA's client side, not service
|
||||||
|
// side. LHLO aims for the service side.
|
||||||
|
|
||||||
|
#ifndef LHLO_OPS
|
||||||
|
#define LHLO_OPS
|
||||||
|
|
||||||
|
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td"
|
||||||
|
include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.td"
|
||||||
|
include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/ViewLikeInterface.td"
|
||||||
|
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
|
||||||
|
|
||||||
|
def LHLO_Dialect : Dialect {
|
||||||
|
let name = "xla_lhlo";
|
||||||
|
let cppNamespace = "xla_lhlo";
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// XLA type definitions.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Any integer tensor types
|
||||||
|
def LHLO_IntBuffer : MemRefOf<[HLO_Int]>;
|
||||||
|
|
||||||
|
// Any floating-point tensor types
|
||||||
|
def LHLO_FpBuffer : MemRefOf<[AnyFloat]>;
|
||||||
|
|
||||||
|
def LHLO_ComplexBuffer : MemRefOf<[AnyComplex]>;
|
||||||
|
|
||||||
|
def LHLO_FpOrComplexBuffer : MemRefOf<[AnyFloat, AnyComplex]>;
|
||||||
|
|
||||||
|
def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>;
|
||||||
|
|
||||||
|
// Any integer or floating-point tensor types
|
||||||
|
def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>;
|
||||||
|
|
||||||
|
def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>;
|
||||||
|
|
||||||
|
def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// XLA nullary op definitions.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
class LHLO_Op<string mnemonic, list<OpTrait> traits> :
|
||||||
|
Op<LHLO_Dialect, mnemonic,
|
||||||
|
!listconcat([MemoryEffects<[MemRead, MemWrite]>], traits)>;
|
||||||
|
|
||||||
|
def LHLO_ConstOp : LHLO_Op<"constant", []>, BASE_HLO_ConstOp {
|
||||||
|
let arguments = (ins
|
||||||
|
ElementsAttr:$value,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_IotaOp : LHLO_Op<"iota", []>, BASE_HLO_IotaOp {
|
||||||
|
let arguments = (ins I64Attr:$iota_dimension,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output);
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// XLA unary elementwise op definitions.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions
|
||||||
|
|
||||||
|
class LHLO_UnaryElementwiseOp<string mnemonic,
|
||||||
|
Type BufferType = LHLO_Buffer,
|
||||||
|
list<OpTrait> traits = [SameTypeOperands]>
|
||||||
|
: LHLO_Op<mnemonic, traits> {
|
||||||
|
let arguments = (ins Arg<BufferType, "", [MemRead]>:$input,
|
||||||
|
Arg<BufferType, "", [MemWrite]>:$output);
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_AbsOp: LHLO_UnaryElementwiseOp<"abs">, BASE_HLO_AbsOp;
|
||||||
|
|
||||||
|
// TODO(timshen): add a custom verifier.
|
||||||
|
def LHLO_BitcastConvertOp:
|
||||||
|
LHLO_UnaryElementwiseOp<"bitcast_convert", LHLO_Buffer, [SameOperandsShape]>, BASE_HLO_BitcastConvertOp;
|
||||||
|
|
||||||
|
def LHLO_CeilOp: LHLO_UnaryElementwiseOp<"ceil", LHLO_FpBuffer>, BASE_HLO_CeilOp;
|
||||||
|
|
||||||
|
def LHLO_ClzOp: LHLO_UnaryElementwiseOp<"count_leading_zeros", LHLO_IntBuffer>, BASE_HLO_ClzOp;
|
||||||
|
|
||||||
|
// TODO(timshen): add a custom verifier.
|
||||||
|
def LHLO_ConvertOp : LHLO_UnaryElementwiseOp<"convert", LHLO_Buffer, [SameOperandsShape]>, BASE_HLO_ConvertOp;
|
||||||
|
|
||||||
|
def LHLO_CosOp: LHLO_UnaryElementwiseOp<"cosine", LHLO_FpOrComplexBuffer>, BASE_HLO_CosOp;
|
||||||
|
|
||||||
|
def LHLO_ExpOp: LHLO_UnaryElementwiseOp<"exponential", LHLO_FpOrComplexBuffer>, BASE_HLO_ExpOp;
|
||||||
|
|
||||||
|
def LHLO_Expm1Op: LHLO_UnaryElementwiseOp<"exponential_minus_one", LHLO_FpOrComplexBuffer>, BASE_HLO_Expm1Op;
|
||||||
|
|
||||||
|
def LHLO_FloorOp: LHLO_UnaryElementwiseOp<"floor", LHLO_FpBuffer>, BASE_HLO_FloorOp;
|
||||||
|
|
||||||
|
def LHLO_ImagOp: LHLO_Op<"imag", [SameOperandsShape]>, BASE_HLO_ImagOp {
|
||||||
|
let arguments = (ins Arg<LHLO_ComplexBuffer, "", [MemRead]>:$input,
|
||||||
|
Arg<LHLO_FpBuffer, "", [MemWrite]>:$output);
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_IsFiniteOp: LHLO_Op<"is_finite", [SameOperandsShape]>, BASE_HLO_IsFiniteOp {
|
||||||
|
let arguments = (ins Arg<LHLO_FpBuffer, "", [MemRead]>:$input,
|
||||||
|
Arg<LHLO_PredBuffer, "", [MemWrite]>:$output);
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_LogOp: LHLO_UnaryElementwiseOp<"log", LHLO_FpOrComplexBuffer>, BASE_HLO_LogOp;
|
||||||
|
|
||||||
|
def LHLO_Log1pOp: LHLO_UnaryElementwiseOp<"log_plus_one", LHLO_FpOrComplexBuffer>, BASE_HLO_Log1pOp;
|
||||||
|
|
||||||
|
def LHLO_NegOp: LHLO_UnaryElementwiseOp<"negate">, BASE_HLO_NegOp;
|
||||||
|
|
||||||
|
def LHLO_NotOp: LHLO_UnaryElementwiseOp<"not", LHLO_PredOrIntBuffer>, BASE_HLO_NotOp;
|
||||||
|
|
||||||
|
def LHLO_PopulationCountOp: LHLO_UnaryElementwiseOp<"popcnt", LHLO_IntBuffer>, BASE_HLO_PopulationCountOp;
|
||||||
|
|
||||||
|
def LHLO_RealOp: LHLO_Op<"real", [SameOperandsShape]>, BASE_HLO_RealOp {
|
||||||
|
let arguments = (ins Arg<LHLO_ComplexBuffer, "", [MemRead]>:$input,
|
||||||
|
Arg<LHLO_FpBuffer, "", [MemWrite]>:$output);
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_RoundOp: LHLO_UnaryElementwiseOp<"round_nearest_afz", LHLO_FpBuffer>, BASE_HLO_RoundOp;
|
||||||
|
|
||||||
|
def LHLO_RsqrtOp: LHLO_UnaryElementwiseOp<"rsqrt", LHLO_FpOrComplexBuffer>, BASE_HLO_RsqrtOp;
|
||||||
|
|
||||||
|
def LHLO_SqrtOp: LHLO_UnaryElementwiseOp<"sqrt", LHLO_FpOrComplexBuffer>, BASE_HLO_SqrtOp;
|
||||||
|
|
||||||
|
def LHLO_SignOp: LHLO_UnaryElementwiseOp<"sign">, BASE_HLO_SignOp;
|
||||||
|
|
||||||
|
def LHLO_SinOp: LHLO_UnaryElementwiseOp<"sine", LHLO_FpOrComplexBuffer>, BASE_HLO_SinOp;
|
||||||
|
|
||||||
|
def LHLO_TanhOp: LHLO_UnaryElementwiseOp<"tanh", LHLO_FpOrComplexBuffer>, BASE_HLO_TanhOp;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// XLA binary elementwise op definitions.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
|
||||||
|
|
||||||
|
class LHLO_BinaryElementwiseOp<string mnemonic, Type BufferType = LHLO_Buffer,
|
||||||
|
list<OpTrait> traits = [SameTypeOperands]> :
|
||||||
|
LHLO_Op<mnemonic, traits> {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<BufferType, "", [MemRead]>:$lhs,
|
||||||
|
Arg<BufferType, "", [MemRead]>:$rhs,
|
||||||
|
Arg<BufferType, "", [MemWrite]>:$out,
|
||||||
|
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_AddOp : LHLO_BinaryElementwiseOp<"add">, BASE_HLO_AddOp;
|
||||||
|
|
||||||
|
def LHLO_AndOp: LHLO_BinaryElementwiseOp<"and", LHLO_PredOrIntBuffer>, BASE_HLO_AndOp;
|
||||||
|
|
||||||
|
def LHLO_Atan2Op : LHLO_BinaryElementwiseOp<"atan2", LHLO_FpOrComplexBuffer>, BASE_HLO_Atan2Op;
|
||||||
|
|
||||||
|
def LHLO_ComplexOp: LHLO_Op<"complex", [SameOperandsShape]>, BASE_HLO_ComplexOp {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<LHLO_FpBuffer, "", [MemRead]>:$lhs,
|
||||||
|
Arg<LHLO_FpBuffer, "", [MemRead]>:$rhs,
|
||||||
|
Arg<LHLO_ComplexBuffer, "", [MemWrite]>:$output,
|
||||||
|
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_DivOp : LHLO_BinaryElementwiseOp<"divide">, BASE_HLO_DivOp;
|
||||||
|
|
||||||
|
def LHLO_MaxOp : LHLO_BinaryElementwiseOp<"maximum">, BASE_HLO_MaxOp;
|
||||||
|
|
||||||
|
def LHLO_MinOp : LHLO_BinaryElementwiseOp<"minimum">, BASE_HLO_MinOp;
|
||||||
|
|
||||||
|
def LHLO_MulOp : LHLO_BinaryElementwiseOp<"multiply">, BASE_HLO_MulOp;
|
||||||
|
|
||||||
|
def LHLO_OrOp : LHLO_BinaryElementwiseOp<"or", LHLO_PredOrIntBuffer>, BASE_HLO_OrOp;
|
||||||
|
|
||||||
|
def LHLO_PowOp : LHLO_BinaryElementwiseOp<"power">, BASE_HLO_PowOp;
|
||||||
|
|
||||||
|
def LHLO_RemOp : LHLO_BinaryElementwiseOp<"remainder", LHLO_IntOrFpBuffer>, BASE_HLO_RemOp;
|
||||||
|
|
||||||
|
def LHLO_ShiftLeftOp : LHLO_BinaryElementwiseOp<"shift_left", LHLO_IntBuffer>, BASE_HLO_ShiftLeftOp;
|
||||||
|
|
||||||
|
def LHLO_ShiftRightArithmeticOp : LHLO_BinaryElementwiseOp<"shift_right_arithmetic", LHLO_IntBuffer>, BASE_HLO_ShiftRightArithmeticOp;
|
||||||
|
|
||||||
|
def LHLO_ShiftRightLogicalOp : LHLO_BinaryElementwiseOp<"shift_right_logical", LHLO_IntBuffer>, BASE_HLO_ShiftRightLogicalOp;
|
||||||
|
|
||||||
|
def LHLO_SubOp : LHLO_BinaryElementwiseOp<"subtract">, BASE_HLO_SubOp;
|
||||||
|
|
||||||
|
def LHLO_XorOp : LHLO_BinaryElementwiseOp<"xor", LHLO_PredOrIntBuffer>, BASE_HLO_XorOp;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// XLA control flow op definitions.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// TODO(b/139813999): specify required function signature in a type-safe way.
|
||||||
|
def LHLO_ReduceOp: LHLO_Op<"reduce", [
|
||||||
|
SameVariadicOperandSize,
|
||||||
|
SingleBlockImplicitTerminator<"TerminatorOp">
|
||||||
|
]>, BASE_HLO_ReduceOp {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
|
||||||
|
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$init_values,
|
||||||
|
Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$out,
|
||||||
|
I64ElementsAttr:$dimensions
|
||||||
|
);
|
||||||
|
|
||||||
|
let regions = (region SizedRegion<1>:$body);
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [
|
||||||
|
SingleBlockImplicitTerminator<"TerminatorOp">
|
||||||
|
]>, BASE_HLO_ReduceWindowOp {
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$init_value,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$out,
|
||||||
|
I64ElementsAttr:$window_dimensions,
|
||||||
|
// If strides or dilations attributes are missing then the default value is
|
||||||
|
// one for each of the input dimensions. Similarly, padding values are zero
|
||||||
|
// for both low and high in each of the dimensions, if not specified.
|
||||||
|
OptionalAttr<I64ElementsAttr>:$window_strides,
|
||||||
|
OptionalAttr<I64ElementsAttr>:$base_dilations,
|
||||||
|
OptionalAttr<I64ElementsAttr>:$window_dilations,
|
||||||
|
OptionalAttr<I64ElementsAttr>:$padding
|
||||||
|
);
|
||||||
|
|
||||||
|
let regions = (region SizedRegion<1>:$body);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(timshen): Add a custom parser to hide operand_segment_sizes. For example,
|
||||||
|
// A tuple-like pattern match syntax could work:
|
||||||
|
// xla_lhlo.case %index, (%input0, %input1, %input2), (%output0, %output1) {
|
||||||
|
// ...
|
||||||
|
// }, {
|
||||||
|
// ...
|
||||||
|
// } : (type_input0, type_input1, type_input2, type_output0, type_output1) -> ()
|
||||||
|
def LHLO_CaseOp: LHLO_Op<"case", [
|
||||||
|
AttrSizedOperandSegments,
|
||||||
|
SingleBlockImplicitTerminator<"TerminatorOp">
|
||||||
|
]>, BASE_HLO_CaseOp {
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$index,
|
||||||
|
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$branch_operands,
|
||||||
|
Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$out
|
||||||
|
);
|
||||||
|
|
||||||
|
let regions = (region VariadicRegion<SizedRegion<1>>:$branches);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(timshen): Add a custom syntax for this.
|
||||||
|
def LHLO_WhileOp: LHLO_Op<"while", [SameVariadicOperandSize]>,
|
||||||
|
BASE_HLO_WhileOp {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$val,
|
||||||
|
Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$output
|
||||||
|
);
|
||||||
|
|
||||||
|
let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// XLA tuple op definitions.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def LHLO_CompareOp: LHLO_Op<"compare", []>, BASE_HLO_CompareOp {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
|
||||||
|
Arg<LHLO_PredBuffer, "", [MemWrite]>:$out,
|
||||||
|
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions,
|
||||||
|
HLO_ComparisonDirectionAttr:$comparison_direction
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// XLA Slice definitions.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def LHLO_SliceOp: LHLO_Op<
|
||||||
|
"slice",
|
||||||
|
[AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||||
|
I64ElementsAttr:$start_indices,
|
||||||
|
I64ElementsAttr:$limit_indices,
|
||||||
|
I64ElementsAttr:$strides
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def HLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []> {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$update,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||||
|
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$start_indices
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// StaticMemRefCastOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def HLO_StaticMemRefCastOp: Op<LHLO_Dialect, "static_memref_cast",
|
||||||
|
[NoSideEffect, DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
|
||||||
|
let summary = [{
|
||||||
|
"modifies the offset, sizes and strides of a statically shaped memref.
|
||||||
|
}];
|
||||||
|
let description = [{
|
||||||
|
Allows to modify the offset, sizes and strides of a statically shaped memref.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```mlir
|
||||||
|
%buf_transformed =
|
||||||
|
xla_lhlo.static_memref_cast %buf
|
||||||
|
: memref<1x5xf32> -> memref<5xf32, offset: 2, strides: [1]>
|
||||||
|
|
||||||
|
// The result of the op is a rank-1 memref with `[5]` shape, stride 1 and
|
||||||
|
// offset 2.
|
||||||
|
```
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins Arg<LHLO_Buffer, "", []>:$operand);
|
||||||
|
let results = (outs Res<LHLO_Buffer, "", []>:$result);
|
||||||
|
|
||||||
|
let builders = [OpBuilder<
|
||||||
|
"OpBuilder &builder, OperationState &result, MemRefType resultType, " #
|
||||||
|
"Value operand", [{
|
||||||
|
result.addOperands(operand);
|
||||||
|
result.types.push_back(resultType);
|
||||||
|
}]>];
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
|
||||||
|
}];
|
||||||
|
|
||||||
|
let verifier = [{ return Verify(*this); }];
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$operand attr-dict `:` type($operand) `->` type($result)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// DynamicMemRefCastOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def HLO_DynamicMemRefCastOp: Op<LHLO_Dialect, "dynamic_memref_cast",
|
||||||
|
[SameVariadicOperandSize, NoSideEffect,
|
||||||
|
DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
|
||||||
|
let summary = "dynamic memref cast operation";
|
||||||
|
let description = [{
|
||||||
|
Change sizes and strides of a memref using the values computed in runtime.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```mlir
|
||||||
|
%buf_transformed =
|
||||||
|
xla_lhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%step_X, %step_Y]
|
||||||
|
: memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
|
||||||
|
// The result of the op is a type-erased memref with `[%size_X, %size_Y]`
|
||||||
|
// shape and `[%step_X, %step_Y]` strides. The offset will be inherited
|
||||||
|
// from the input.
|
||||||
|
```
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<LHLO_Buffer, "", []>:$operand,
|
||||||
|
Variadic<Index>:$sizes,
|
||||||
|
Variadic<Index>:$strides
|
||||||
|
);
|
||||||
|
let results = (outs Res<LHLO_Buffer, "", []>:$result);
|
||||||
|
|
||||||
|
let builders = [OpBuilder<
|
||||||
|
"OpBuilder &builder, OperationState &result, MemRefType resultType, " #
|
||||||
|
"Value operand, ValueRange sizes, ValueRange strides", [{
|
||||||
|
result.addOperands(operand);
|
||||||
|
result.addOperands(sizes);
|
||||||
|
result.addOperands(strides);
|
||||||
|
result.types.push_back(resultType);
|
||||||
|
}]>];
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
|
||||||
|
}];
|
||||||
|
|
||||||
|
let verifier = [{ return Verify(*this); }];
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$operand `(` $sizes `)` `[` $strides `]` attr-dict `:` type($operand) `->`
|
||||||
|
type($result)
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// XLA Other op definitions.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def LHLO_BatchNormGradOp : LHLO_Op<"batch_norm_grad", []>,
|
||||||
|
BASE_HLO_BatchNormGradOp {
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$scale,
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$mean,
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$variance,
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$grad_output,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$grad_operand, // gradient of $operand.
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$grad_scale,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$grad_offset,
|
||||||
|
F32Attr:$epsilon,
|
||||||
|
I64Attr:$feature_index
|
||||||
|
);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_BatchNormInferenceOp : LHLO_Op<"batch_norm_inference", []>,
|
||||||
|
BASE_HLO_BatchNormInferenceOp {
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$scale,
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$offset,
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$mean,
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$variance,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||||
|
F32Attr:$epsilon,
|
||||||
|
I64Attr:$feature_index
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_BatchNormTrainingOp : LHLO_Op<"batch_norm_training", []>,
|
||||||
|
BASE_HLO_BatchNormTrainingOp {
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$scale,
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$offset,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$batch_mean,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$batch_var,
|
||||||
|
F32Attr:$epsilon,
|
||||||
|
I64Attr:$feature_index
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(timshen): add a custom verifier.
|
||||||
|
def LHLO_BitcastOp: LHLO_Op<"bitcast", []> {
|
||||||
|
let arguments = (ins Arg<LHLO_Buffer, "", [MemRead]>:$input,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output);
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_BroadcastOp : LHLO_Op<"broadcast",
|
||||||
|
[]>, BASE_HLO_BroadcastOp {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||||
|
I64ElementsAttr:$broadcast_sizes
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_BroadcastInDimOp : LHLO_Op<"broadcast_in_dim",
|
||||||
|
[]>, BASE_HLO_BroadcastInDimOp {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||||
|
BroadcastDimAttr:$broadcast_dimensions
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_ClampOp : LHLO_Op<"clamp", []>, BASE_HLO_ClampOp {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$min,
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$max,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_ConcatenateOp : LHLO_Op<"concatenate", []>, BASE_HLO_ConcatenateOp {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$val,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||||
|
I64Attr:$dimension
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(bondhugula): Make this struct dialect independent so that it can be
|
||||||
|
// shared between the HLO and LHLO dialects.
|
||||||
|
def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", LHLO_Dialect, [
|
||||||
|
StructFieldAttr<"input_batch_dimension",I64Attr>,
|
||||||
|
StructFieldAttr<"input_feature_dimension", I64Attr>,
|
||||||
|
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
|
||||||
|
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
|
||||||
|
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
|
||||||
|
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
|
||||||
|
StructFieldAttr<"output_batch_dimension", I64Attr>,
|
||||||
|
StructFieldAttr<"output_feature_dimension", I64Attr>,
|
||||||
|
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
|
||||||
|
|
||||||
|
let description = "Structure of dimension information for conv op";
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||||
|
// Default value: one for each of the spatial dimension.
|
||||||
|
OptionalAttr<I64ElementsAttr>:$window_strides,
|
||||||
|
// Default value: zero for each of the spatial dimension.
|
||||||
|
OptionalAttr<I64ElementsAttr>:$padding,
|
||||||
|
// Default value: one for each of the spatial dimension.
|
||||||
|
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
|
||||||
|
// Default value: one for each of the spatial dimension.
|
||||||
|
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
|
||||||
|
ConvDimensionNumbers:$dimension_numbers,
|
||||||
|
I64Attr:$feature_group_count,
|
||||||
|
I64Attr:$batch_group_count,
|
||||||
|
HLO_PrecisionConfigAttr:$precision_config
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_CopyOp: LHLO_Op<"copy", []>, BASE_HLO_CopyOp {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_DotOp: LHLO_Op<"dot", []>, BASE_HLO_DotOp {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
|
||||||
|
HLO_PrecisionConfigAttr:$precision_config,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_GatherOp: LHLO_Op<"gather", []>, BASE_HLO_GatherOp {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||||
|
Arg<LHLO_IntBuffer, "", [MemRead]>:$start_indices,
|
||||||
|
I64Attr:$index_vector_dim,
|
||||||
|
I64ElementsAttr:$offset_dims,
|
||||||
|
I64ElementsAttr:$slice_sizes,
|
||||||
|
I64ElementsAttr:$collapsed_slice_dims,
|
||||||
|
I64ElementsAttr:$start_index_map,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_ReshapeOp: LHLO_Op<"reshape", []>, BASE_HLO_ReshapeOp {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_ScatterOp: LHLO_Op<"scatter", []>, BASE_HLO_ScatterOp {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$scatter_indices,
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$updates,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||||
|
ScatterDimensionNumbers<LHLO_Dialect>:$scatter_dimension_numbers,
|
||||||
|
DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted,
|
||||||
|
DefaultValuedAttr<BoolAttr, "false">:$unique_indices
|
||||||
|
);
|
||||||
|
|
||||||
|
let regions = (region SizedRegion<1>:$update_computation);
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_SelectOp: LHLO_Op<"select", []>, BASE_HLO_SelectOp {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<LHLO_PredBuffer, "", [MemRead]>:$pred,
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$on_true,
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$on_false,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_SelectAndScatterOp: LHLO_Op<"select_and_scatter", []>,
|
||||||
|
BASE_HLO_SelectAndScatterOp {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$source,
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$init_value,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$out,
|
||||||
|
OptionalAttr<I64ElementsAttr>:$window_dimensions,
|
||||||
|
OptionalAttr<I64ElementsAttr>:$window_strides,
|
||||||
|
OptionalAttr<I64ElementsAttr>:$padding
|
||||||
|
);
|
||||||
|
|
||||||
|
let regions = (region SizedRegion<1>:$select, SizedRegion<1>:$scatter);
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_ReverseOp: LHLO_Op<"reverse", []>, BASE_HLO_ReverseOp {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||||
|
I64ElementsAttr:$dimensions,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_PadOp: LHLO_Op<"pad", []>, BASE_HLO_PadOp {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$padding_value,
|
||||||
|
I64ElementsAttr:$edge_padding_low,
|
||||||
|
I64ElementsAttr:$edge_padding_high,
|
||||||
|
I64ElementsAttr:$interior_padding,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_TransposeOp: LHLO_Op<"transpose", []>, BASE_HLO_TransposeOp {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||||
|
I64ElementsAttr:$permutation,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_ReducePrecisionOp: LHLO_Op<"reduce_precision", [SameTypeOperands]>,
|
||||||
|
BASE_HLO_ReducePrecisionOp {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<LHLO_FpBuffer, "", [MemRead]>:$operand,
|
||||||
|
Arg<LHLO_FpBuffer, "", [MemWrite]>:$output,
|
||||||
|
I32Attr:$exponent_bits,
|
||||||
|
I32Attr:$mantissa_bits
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_AllReduceOp : LHLO_Op<"all_reduce", [SameTypeOperands]>,
|
||||||
|
BASE_HLO_AllReduceOp {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||||
|
I64ElementsAttr:$replica_groups,
|
||||||
|
DefaultValuedAttr<BoolAttr, "false">:$constrain_layout,
|
||||||
|
OptionalAttr<ChannelHandle<LHLO_Dialect>>:$channel_id,
|
||||||
|
DefaultValuedAttr<BoolAttr, "false">:$use_global_device_ids
|
||||||
|
);
|
||||||
|
let regions = (region SizedRegion<1>:$computation);
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_CollectivePermuteOp: LHLO_Op<"collective_permute", [SameTypeOperands]>,
|
||||||
|
BASE_HLO_CollectivePermuteOp {
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||||
|
I64ElementsAttr:$source_target_pairs,
|
||||||
|
OptionalAttr<ChannelHandle<LHLO_Dialect>>:$channel_id
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_FftOp: LHLO_Op<"fft", []>, BASE_HLO_FftOp {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||||
|
HLO_FftTypeAttr:$fft_type,
|
||||||
|
I64ElementsAttr:$fft_length
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_CholeskyOp: LHLO_Op<"cholesky", [SameOperandsElementType]>, BASE_HLO_CholeskyOp {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<LHLO_FpOrComplexBuffer, "", [MemRead]>:$a,
|
||||||
|
Arg<LHLO_FpOrComplexBuffer, "", [MemWrite]>:$output,
|
||||||
|
DefaultValuedAttr<BoolAttr, "false">:$lower
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_Infeed: LHLO_Op<"infeed", []>, BASE_HLO_InfeedOp {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||||
|
DefaultValuedAttr<StrAttr, "">:$config
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_Outfeed: LHLO_Op<"outfeed", []> {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||||
|
DefaultValuedAttr<StrAttr, "">:$config
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_ReplicaIdOp : LHLO_Op<"replica_id", []>, BASE_HLO_ReplicaIdOp {
|
||||||
|
let arguments = (ins Arg<MemRefOf<[UI32]>, "", [MemWrite]>);
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType]>,
|
||||||
|
BASE_HLO_TriangularSolveOp {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<LHLO_FpOrComplexBuffer, "", [MemRead]>:$a,
|
||||||
|
Arg<LHLO_FpOrComplexBuffer, "", [MemRead]>:$b,
|
||||||
|
Arg<LHLO_FpOrComplexBuffer, "", [MemWrite]>:$output,
|
||||||
|
BoolAttr:$left_side,
|
||||||
|
BoolAttr:$lower,
|
||||||
|
BoolAttr:$unit_diagonal,
|
||||||
|
HLO_TransposeAttr:$transpose_a
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(timshen): add a custom verifier.
|
||||||
|
def LHLO_MapOp: LHLO_Op<"map", [SameOperandsShape]>, BASE_HLO_MapOp {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
|
||||||
|
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||||
|
I64ElementsAttr:$dimensions
|
||||||
|
);
|
||||||
|
let regions = (region SizedRegion<1>:$computation);
|
||||||
|
}
|
||||||
|
|
||||||
|
def LHLO_RngGetAndUpdateStateOp: LHLO_Op<"rng_get_and_update_state", []> {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<MemRefOf<[UI64]>, "", [MemRead, MemWrite]>:$state,
|
||||||
|
I64Attr:$delta
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(timshen): add a custom verifier.
|
||||||
|
def LHLO_SortOp: LHLO_Op<"sort", [SameVariadicOperandSize, SameOperandsShape]>, BASE_HLO_SortOp {
|
||||||
|
let arguments = (ins
|
||||||
|
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
|
||||||
|
Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$output,
|
||||||
|
DefaultValuedAttr<I64Attr, "-1">:$dimension,
|
||||||
|
DefaultValuedAttr<BoolAttr, "false">:$is_stable
|
||||||
|
);
|
||||||
|
|
||||||
|
let regions = (region SizedRegion<1>:$comparator);
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Late operations
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]> {
|
||||||
|
let summary = "Fusion operator";
|
||||||
|
let description = [{
|
||||||
|
Models the fusion instruction generated by the XLA compiler's fusion pass.
|
||||||
|
|
||||||
|
Fusion instructions are generated by the fusion pass of the XLA compiler.
|
||||||
|
They serve as a hint to the backend that it is beneficial to emit the
|
||||||
|
contained instructions into a single loop nest or kernel. The XLA fusion
|
||||||
|
pass is designed such that it only generates fusion nodes that can be
|
||||||
|
handled by the XLA compilers backends.
|
||||||
|
The XLA runtime expects this hint to be followed, as it expects a single
|
||||||
|
kernel per HLO instruction. This restriction might be lifted in the future.
|
||||||
|
}];
|
||||||
|
let regions = (region SizedRegion<1>:$region);
|
||||||
|
|
||||||
|
let skipDefaultBuilders = 1;
|
||||||
|
let builders = [
|
||||||
|
OpBuilder<"OpBuilder &builder, OperationState &result, "
|
||||||
|
"ArrayRef<NamedAttribute> attributes">
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
|
def TerminatorOp :
|
||||||
|
LHLO_Op<"terminator", [Terminator]> {
|
||||||
|
let summary = "LHLO termination operation";
|
||||||
|
let description = [{
|
||||||
|
Terminator operation for the LHLO dialect.
|
||||||
|
}];
|
||||||
|
let builders = [OpBuilder<
|
||||||
|
"OpBuilder &b, OperationState &result, ValueRange operands",
|
||||||
|
[{ build(b, result, llvm::None, operands, llvm::None); }]
|
||||||
|
>];
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // LHLO_OPS
|
|
@ -0,0 +1,49 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_BROADCAST_UTILS_H_
|
||||||
|
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_BROADCAST_UTILS_H_
|
||||||
|
|
||||||
|
// Utilities relating to implementing HLO broadcasting.
|
||||||
|
// Note: This file should not depend on any non-MLIR TensorFlow libraries.
|
||||||
|
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LLVM.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace xla {
|
||||||
|
|
||||||
|
// Checks whether the given operand types and broadcast_dims attr represent a
|
||||||
|
// legal combination for "numpy" style broadcasting (where 1-dims are prepended
|
||||||
|
// to the smaller ranked operand until it is of the same rank as the larger).
|
||||||
|
// See: https://docs.scipy.org/doc/numpy/reference/ufuncs.html
|
||||||
|
bool IsLegalNumpyRankedBroadcast(Value lhs, Value rhs,
|
||||||
|
DenseIntElementsAttr broadcast_dims);
|
||||||
|
|
||||||
|
// Emits shape dialect ops to compute the result shape for a broadcasting
|
||||||
|
// binary elementwise op which broadcasts according to "numpy" semantics
|
||||||
|
// (see above), returning an extents tensor of the resulting shape.
|
||||||
|
Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs,
|
||||||
|
Value rhs,
|
||||||
|
OpBuilder& builder);
|
||||||
|
|
||||||
|
} // namespace xla
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_BROADCAST_UTILS_H_
|
|
@ -0,0 +1,33 @@
|
||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_
|
||||||
|
#define TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_
|
||||||
|
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace xla {
|
||||||
|
|
||||||
|
// Converts the given elements attr to the specified elements type.
|
||||||
|
// Requires type of the elements and new_type to be either integer or float
|
||||||
|
// type.
|
||||||
|
mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
|
||||||
|
mlir::Type new_type);
|
||||||
|
} // namespace xla
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_
|
|
@ -0,0 +1,74 @@
|
||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_
|
||||||
|
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_
|
||||||
|
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace xla {
|
||||||
|
|
||||||
|
// Computes the broadcast dimensions attr for an elementwise binary operator
|
||||||
|
// between two ranked tensors.
|
||||||
|
// If `allow_empty` is true, then null can be returned to mean that the
|
||||||
|
// broadcast is an "identity".
|
||||||
|
mlir::DenseIntElementsAttr getBroadcastDimensionsAttr(mlir::Builder* b,
|
||||||
|
mlir::Value x,
|
||||||
|
mlir::Value y,
|
||||||
|
bool allow_empty = true);
|
||||||
|
|
||||||
|
// Get a constant splat for the given value of type. Requires value to be of
|
||||||
|
// type static shaped RankedTensorType.
|
||||||
|
template <typename T>
|
||||||
|
static ElementsAttr getSplat(Builder* b, RankedTensorType ty, T constant) {
|
||||||
|
Type element_ty = getElementTypeOrSelf(ty);
|
||||||
|
|
||||||
|
if (element_ty.isSignlessInteger())
|
||||||
|
return DenseElementsAttr::get(ty, b->getIntegerAttr(element_ty, constant));
|
||||||
|
|
||||||
|
if (element_ty.isa<FloatType>())
|
||||||
|
return DenseElementsAttr::get(ty, b->getFloatAttr(element_ty, constant));
|
||||||
|
|
||||||
|
if (auto complex_ty = element_ty.dyn_cast<ComplexType>()) {
|
||||||
|
auto complex_element_ty = complex_ty.getElementType();
|
||||||
|
if (complex_element_ty.isF32())
|
||||||
|
return DenseElementsAttr::get(ty,
|
||||||
|
static_cast<std::complex<float>>(constant));
|
||||||
|
if (complex_element_ty.isF64())
|
||||||
|
return DenseElementsAttr::get(
|
||||||
|
ty, static_cast<std::complex<double>>(constant));
|
||||||
|
}
|
||||||
|
llvm_unreachable("unhandled element type");
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static ElementsAttr getSplat(Builder* b, Value val, T constant) {
|
||||||
|
return getSplat(b, val.getType().cast<RankedTensorType>(), constant);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns DenseElementsAttr of rank zero with the given element type and the
|
||||||
|
// value.
|
||||||
|
// Requires `ty` to be either FloatType of IntegerType.
|
||||||
|
DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value);
|
||||||
|
|
||||||
|
} // namespace xla
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_
|
|
@ -0,0 +1,278 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||||
|
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Diagnostics.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace xla_chlo {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static LogicalResult Verify(T op) {
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// BinaryOps
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Gets the resulting type from a broadcast between two types.
|
||||||
|
static Type GetBroadcastType(Type x, Type y, Type element_type,
|
||||||
|
DenseIntElementsAttr broadcast_dimensions_attr) {
|
||||||
|
auto x_ranked = x.dyn_cast<RankedTensorType>();
|
||||||
|
auto y_ranked = y.dyn_cast<RankedTensorType>();
|
||||||
|
if (!x_ranked || !y_ranked) {
|
||||||
|
return UnrankedTensorType::get(element_type);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto shape_x = x_ranked.getShape();
|
||||||
|
auto shape_y = y_ranked.getShape();
|
||||||
|
|
||||||
|
if (shape_x.size() == shape_y.size()) {
|
||||||
|
llvm::SmallVector<int64_t, 4> out_shape(shape_x.size());
|
||||||
|
for (int i = 0, e = shape_x.size(); i < e; i++) {
|
||||||
|
auto x_val = shape_x[i];
|
||||||
|
auto y_val = shape_y[i];
|
||||||
|
if (x_val == -1 || y_val == -1) {
|
||||||
|
out_shape[i] = -1;
|
||||||
|
} else {
|
||||||
|
out_shape[i] = std::max(x_val, y_val);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return RankedTensorType::get(out_shape, element_type);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto shape_large = shape_x.size() > shape_y.size() ? shape_x : shape_y;
|
||||||
|
auto shape_small = shape_x.size() <= shape_y.size() ? shape_x : shape_y;
|
||||||
|
|
||||||
|
llvm::SmallVector<int64_t, 4> broadcast_dimensions;
|
||||||
|
if (broadcast_dimensions_attr) {
|
||||||
|
// Explicit broadcast dimensions.
|
||||||
|
for (const APInt& int_value : broadcast_dimensions_attr.getIntValues()) {
|
||||||
|
broadcast_dimensions.push_back(int_value.getSExtValue());
|
||||||
|
}
|
||||||
|
if (broadcast_dimensions.size() != shape_small.size()) {
|
||||||
|
// Signal illegal broadcast_dimensions as unranked.
|
||||||
|
return UnrankedTensorType::get(element_type);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// If no broadcast dimensions, assume "numpy" broadcasting.
|
||||||
|
broadcast_dimensions = llvm::to_vector<4>(llvm::seq<int64_t>(
|
||||||
|
shape_large.size() - shape_small.size(), shape_large.size()));
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::SmallVector<int64_t, 4> out_shape(shape_large.begin(),
|
||||||
|
shape_large.end());
|
||||||
|
|
||||||
|
// Update according to the broadcast dimensions.
|
||||||
|
for (auto index_pair : llvm::enumerate(broadcast_dimensions)) {
|
||||||
|
auto old_value = out_shape[index_pair.value()];
|
||||||
|
auto new_value = shape_small[index_pair.index()];
|
||||||
|
if (old_value != -1 && (new_value == -1 || new_value > old_value)) {
|
||||||
|
out_shape[index_pair.value()] = new_value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return RankedTensorType::get(out_shape, element_type);
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult InferBroadcastBinaryOpReturnTypeComponents(
|
||||||
|
MLIRContext* context, Optional<Location> location, ValueRange operands,
|
||||||
|
DictionaryAttr attributes, Type element_type,
|
||||||
|
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
|
||||||
|
// Find broadcast_dimensions.
|
||||||
|
DenseIntElementsAttr broadcast_dimensions =
|
||||||
|
attributes.get("broadcast_dimensions")
|
||||||
|
.dyn_cast_or_null<DenseIntElementsAttr>();
|
||||||
|
|
||||||
|
ShapedType lhs_type = operands[0].getType().dyn_cast<ShapedType>();
|
||||||
|
ShapedType rhs_type = operands[1].getType().dyn_cast<ShapedType>();
|
||||||
|
if (!lhs_type || !rhs_type ||
|
||||||
|
lhs_type.getElementType() != rhs_type.getElementType()) {
|
||||||
|
return emitOptionalError(location, "mismatched operand types");
|
||||||
|
}
|
||||||
|
if (!element_type) element_type = lhs_type.getElementType();
|
||||||
|
Type result_type =
|
||||||
|
GetBroadcastType(lhs_type, rhs_type, element_type, broadcast_dimensions);
|
||||||
|
|
||||||
|
if (auto ranked_result_type = result_type.dyn_cast<RankedTensorType>()) {
|
||||||
|
inferedReturnShapes.emplace_back(ranked_result_type.getShape(),
|
||||||
|
element_type);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(laurenzo): This should be constructing with `element_type` but that
|
||||||
|
// constructor variant needs to be added upstream.
|
||||||
|
inferedReturnShapes.emplace_back(/* element_type */);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes(
|
||||||
|
OpBuilder& builder, Operation* op,
|
||||||
|
SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||||
|
auto loc = op->getLoc();
|
||||||
|
auto lhs = op->getOperand(0);
|
||||||
|
auto rhs = op->getOperand(1);
|
||||||
|
|
||||||
|
// Check for "numpy"-style rank broadcast.
|
||||||
|
auto broadcast_dimensions = op->getAttr("broadcast_dimensions")
|
||||||
|
.dyn_cast_or_null<DenseIntElementsAttr>();
|
||||||
|
if (broadcast_dimensions &&
|
||||||
|
!xla::IsLegalNumpyRankedBroadcast(lhs, rhs, broadcast_dimensions)) {
|
||||||
|
// Note: It is unclear whether the general specification of explicit
|
||||||
|
// broadcast_dimensions on binary ops is a feature we want to carry
|
||||||
|
// forward. While it can technically be implemented for ranked-dynamic,
|
||||||
|
// it is incompatible with unranked inputs. If this warning is emitted
|
||||||
|
// in real programs, it is an indication that the feature should be
|
||||||
|
// implemented versus just falling back on the more standard definition
|
||||||
|
// of numpy-like prefix-padding.
|
||||||
|
return op->emitWarning()
|
||||||
|
<< "unsupported non prefix-padded dynamic rank "
|
||||||
|
<< "broadcast_dimensions = " << broadcast_dimensions;
|
||||||
|
}
|
||||||
|
|
||||||
|
Value computed_shape = xla::ComputeBinaryElementwiseBroadcastingResultExtents(
|
||||||
|
loc, lhs, rhs, builder);
|
||||||
|
if (!computed_shape) return failure();
|
||||||
|
reifiedReturnShapes.push_back(computed_shape);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// BroadcastComplexOp (has custom type inference due to different result type).
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
LogicalResult BroadcastComplexOp::inferReturnTypeComponents(
|
||||||
|
MLIRContext* context, Optional<Location> location, ValueRange operands,
|
||||||
|
DictionaryAttr attributes, RegionRange regions,
|
||||||
|
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
|
||||||
|
ShapedType lhs_type = operands[0].getType().dyn_cast<ShapedType>();
|
||||||
|
if (!lhs_type) {
|
||||||
|
return emitOptionalError(location, "expected ShapedType");
|
||||||
|
}
|
||||||
|
Type element_type = ComplexType::get(lhs_type.getElementType());
|
||||||
|
return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands,
|
||||||
|
attributes, element_type,
|
||||||
|
inferedReturnShapes);
|
||||||
|
}
|
||||||
|
LogicalResult BroadcastComplexOp::reifyReturnTypeShapes(
|
||||||
|
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||||
|
return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
|
||||||
|
reifiedReturnShapes);
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// BroadcastCompareOp (has custom type inference due to different result type).
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void BroadcastCompareOp::build(OpBuilder& builder, OperationState& result,
|
||||||
|
Value lhs, Value rhs,
|
||||||
|
DenseIntElementsAttr broadcast_dimensions,
|
||||||
|
StringAttr comparison_direction) {
|
||||||
|
auto new_type = GetBroadcastType(lhs.getType(), rhs.getType(),
|
||||||
|
builder.getI1Type(), broadcast_dimensions);
|
||||||
|
build(builder, result, new_type, lhs, rhs, broadcast_dimensions,
|
||||||
|
comparison_direction);
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult BroadcastCompareOp::inferReturnTypeComponents(
|
||||||
|
MLIRContext* context, Optional<Location> location, ValueRange operands,
|
||||||
|
DictionaryAttr attributes, RegionRange regions,
|
||||||
|
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
|
||||||
|
Type element_type = IntegerType::get(1, context);
|
||||||
|
return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands,
|
||||||
|
attributes, element_type,
|
||||||
|
inferedReturnShapes);
|
||||||
|
}
|
||||||
|
LogicalResult BroadcastCompareOp::reifyReturnTypeShapes(
|
||||||
|
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
|
||||||
|
return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
|
||||||
|
reifiedReturnShapes);
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Macros for method definitions that are common to most broadcasting ops.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#define BROADCAST_INFER_SHAPE_TYPE_OP_DEFS(Op) \
|
||||||
|
LogicalResult Op::inferReturnTypeComponents( \
|
||||||
|
MLIRContext* context, Optional<Location> location, ValueRange operands, \
|
||||||
|
DictionaryAttr attributes, RegionRange regions, \
|
||||||
|
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) { \
|
||||||
|
return InferBroadcastBinaryOpReturnTypeComponents( \
|
||||||
|
context, location, operands, attributes, /*element_type=*/nullptr, \
|
||||||
|
inferedReturnShapes); \
|
||||||
|
} \
|
||||||
|
LogicalResult Op::reifyReturnTypeShapes( \
|
||||||
|
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) { \
|
||||||
|
return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(), \
|
||||||
|
reifiedReturnShapes); \
|
||||||
|
}
|
||||||
|
|
||||||
|
#define BROADCAST_BINARY_OP_DEFS(Op) \
|
||||||
|
void Op::build(OpBuilder& builder, OperationState& result, Value left, \
|
||||||
|
Value right, DenseIntElementsAttr broadcast_dimensions) { \
|
||||||
|
auto type = GetBroadcastType( \
|
||||||
|
left.getType().cast<ShapedType>(), right.getType().cast<ShapedType>(), \
|
||||||
|
getElementTypeOrSelf(right.getType()), broadcast_dimensions); \
|
||||||
|
return Op::build(builder, result, type, left, right, \
|
||||||
|
broadcast_dimensions); \
|
||||||
|
} \
|
||||||
|
BROADCAST_INFER_SHAPE_TYPE_OP_DEFS(Op)
|
||||||
|
|
||||||
|
BROADCAST_BINARY_OP_DEFS(BroadcastAddOp);
|
||||||
|
BROADCAST_BINARY_OP_DEFS(BroadcastAndOp);
|
||||||
|
BROADCAST_BINARY_OP_DEFS(BroadcastAtan2Op);
|
||||||
|
BROADCAST_BINARY_OP_DEFS(BroadcastDivOp);
|
||||||
|
BROADCAST_BINARY_OP_DEFS(BroadcastMaxOp);
|
||||||
|
BROADCAST_BINARY_OP_DEFS(BroadcastMinOp);
|
||||||
|
BROADCAST_BINARY_OP_DEFS(BroadcastMulOp);
|
||||||
|
BROADCAST_BINARY_OP_DEFS(BroadcastOrOp);
|
||||||
|
BROADCAST_BINARY_OP_DEFS(BroadcastPowOp);
|
||||||
|
BROADCAST_BINARY_OP_DEFS(BroadcastRemOp);
|
||||||
|
BROADCAST_BINARY_OP_DEFS(BroadcastShiftLeftOp);
|
||||||
|
BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightArithmeticOp);
|
||||||
|
BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightLogicalOp);
|
||||||
|
BROADCAST_BINARY_OP_DEFS(BroadcastSubOp);
|
||||||
|
BROADCAST_BINARY_OP_DEFS(BroadcastXorOp);
|
||||||
|
|
||||||
|
#undef BROADCAST_INFER_SHAPE_TYPE_OP_DEFS
|
||||||
|
#undef BROADCAST_BINARY_OP_DEFS
|
||||||
|
|
||||||
|
#define GET_OP_CLASSES
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// xla_chlo Dialect Constructor
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
XlaHloClientDialect::XlaHloClientDialect(MLIRContext* context)
|
||||||
|
: Dialect(getDialectNamespace(), context) {
|
||||||
|
addOperations<
|
||||||
|
#define GET_OP_LIST
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"
|
||||||
|
>();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace xla_chlo
|
||||||
|
} // namespace mlir
|
|
@ -0,0 +1,24 @@
|
||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||||
|
|
||||||
|
// Static initialization for XLA dialect registration.
|
||||||
|
static mlir::DialectRegistration<mlir::xla_hlo::XlaHloDialect> xla_hlo_ops;
|
||||||
|
static mlir::DialectRegistration<mlir::xla_chlo::XlaHloClientDialect>
|
||||||
|
xla_chlo_ops;
|
||||||
|
static mlir::DialectRegistration<mlir::xla_lhlo::XlaLhloDialect> xla_lhlo_ops;
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,22 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.cc.inc"
|
||||||
|
|
||||||
|
} // namespace mlir
|
|
@ -0,0 +1,102 @@
|
||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// This file defines the operations used in the XLA dialect.
|
||||||
|
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/APFloat.h"
|
||||||
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/APInt.h"
|
||||||
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h"
|
||||||
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
|
||||||
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/SmallVector.h"
|
||||||
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
|
||||||
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/FormatVariadic.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpDefinition.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpImplementation.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OperationSupport.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Value.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.cc.inc"
|
||||||
|
namespace xla_lhlo {
|
||||||
|
|
||||||
|
XlaLhloDialect::XlaLhloDialect(MLIRContext *context)
|
||||||
|
: Dialect(getDialectNamespace(), context) {
|
||||||
|
addOperations<
|
||||||
|
#define GET_OP_LIST
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc"
|
||||||
|
>();
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// StaticMemRefCastOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
Value StaticMemRefCastOp::getViewSource() { return *getODSOperands(0).begin(); }
|
||||||
|
|
||||||
|
static LogicalResult Verify(StaticMemRefCastOp op) {
|
||||||
|
if (!op.operand().getType().cast<ShapedType>().hasStaticShape())
|
||||||
|
return op.emitOpError("operand must have static shape");
|
||||||
|
if (!op.getType().hasStaticShape())
|
||||||
|
return op.emitOpError("result must have static shape");
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// DynamicMemRefCastOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
Value DynamicMemRefCastOp::getViewSource() {
|
||||||
|
return *getODSOperands(0).begin();
|
||||||
|
}
|
||||||
|
|
||||||
|
static LogicalResult Verify(DynamicMemRefCastOp op) {
|
||||||
|
// Check if `sizes` and `strides` args are compatible with the result type.
|
||||||
|
if (op.sizes().size() != op.getType().getRank())
|
||||||
|
return op.emitOpError(
|
||||||
|
"`sizes` args count must be equal to the rank of the output memref");
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
#define GET_OP_CLASSES
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc"
|
||||||
|
|
||||||
|
// TODO(cheshire): Support folding, reuse code from hlo_ops.cc.
|
||||||
|
|
||||||
|
void FusionOp::build(OpBuilder &builder, OperationState &result,
|
||||||
|
ArrayRef<NamedAttribute> attributes) {
|
||||||
|
result.addAttributes(attributes);
|
||||||
|
Region *bodyRegion = result.addRegion();
|
||||||
|
FusionOp::ensureTerminator(*bodyRegion, builder, result.location);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace xla_lhlo
|
||||||
|
} // namespace mlir
|
|
@ -0,0 +1,30 @@
|
||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// This is the canonicalize pattern definition file.
|
||||||
|
|
||||||
|
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td"
|
||||||
|
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
|
||||||
|
include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td"
|
||||||
|
|
||||||
|
def UnaryToBinaryEinsumEq : NativeCodeCall<
|
||||||
|
"$_builder.getStringAttr(\",\" + $0.getValue().str())">;
|
||||||
|
|
||||||
|
// Convert UnaryEinsumOp to EinsumOp with two operands with redundant first
|
||||||
|
// operand.
|
||||||
|
def UnaryEinsumToEinsum : Pat<
|
||||||
|
(HLO_UnaryEinsumOp $operand, $equation),
|
||||||
|
(HLO_EinsumOp (HLO_ConstOp (GetScalarOfType<1> $operand)),
|
||||||
|
$operand, (UnaryToBinaryEinsumEq $equation))>;
|
|
@ -0,0 +1,74 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/Sequence.h"
|
||||||
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/SmallVector.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Diagnostics.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace xla {
|
||||||
|
|
||||||
|
bool IsLegalNumpyRankedBroadcast(Value lhs, Value rhs,
|
||||||
|
DenseIntElementsAttr broadcast_dims) {
|
||||||
|
RankedTensorType lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
|
||||||
|
RankedTensorType rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!lhs_type || !rhs_type) return false;
|
||||||
|
if (lhs_type.getRank() == rhs_type.getRank()) return true;
|
||||||
|
|
||||||
|
// Otherwise, verify that broadcast_dims strictly performs left-padding.
|
||||||
|
auto smaller_rank = std::min(lhs_type.getRank(), rhs_type.getRank());
|
||||||
|
auto larger_rank = std::max(lhs_type.getRank(), rhs_type.getRank());
|
||||||
|
|
||||||
|
if (smaller_rank != broadcast_dims.getNumElements()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto expected_extents =
|
||||||
|
llvm::seq<int64_t>(larger_rank - smaller_rank, larger_rank);
|
||||||
|
return std::equal(expected_extents.begin(), expected_extents.end(),
|
||||||
|
broadcast_dims.getIntValues().begin());
|
||||||
|
}
|
||||||
|
|
||||||
|
Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs,
|
||||||
|
Value rhs,
|
||||||
|
OpBuilder& builder) {
|
||||||
|
auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
|
||||||
|
auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!lhs_type || !rhs_type) {
|
||||||
|
emitError(loc) << "shape computation for broadcasting elementwise ops "
|
||||||
|
<< "is only implemented for ranked tensors";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank());
|
||||||
|
auto shape_type = shape::ShapeType::get(builder.getContext());
|
||||||
|
Value lhs_shape_v =
|
||||||
|
builder.createOrFold<shape::ShapeOfOp>(loc, shape_type, lhs);
|
||||||
|
Value rhs_shape_v =
|
||||||
|
builder.createOrFold<shape::ShapeOfOp>(loc, shape_type, rhs);
|
||||||
|
Value result_shape_v = builder.createOrFold<shape::BroadcastOp>(
|
||||||
|
loc, shape_type, lhs_shape_v, rhs_shape_v, nullptr /* error */);
|
||||||
|
return builder.createOrFold<shape::ToExtentTensorOp>(
|
||||||
|
loc, RankedTensorType::get({result_rank}, builder.getIndexType()),
|
||||||
|
result_shape_v);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace xla
|
||||||
|
} // namespace mlir
|
|
@ -0,0 +1,86 @@
|
||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// This file defines helpers useful when creating or manipulating lhlo/hlo.
|
||||||
|
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/convert_op_folder.h"
|
||||||
|
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace xla {
|
||||||
|
|
||||||
|
mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
|
||||||
|
mlir::Type new_type) {
|
||||||
|
auto old_type = getElementTypeOrSelf(elements);
|
||||||
|
size_t bit_width = new_type.isBF16() ? 64 : new_type.getIntOrFloatBitWidth();
|
||||||
|
|
||||||
|
if (old_type.isa<mlir::FloatType>()) {
|
||||||
|
// mapValues always takes a function returning APInt, even when the output
|
||||||
|
// is actually float.
|
||||||
|
using func_type = mlir::APInt(const llvm::APFloat&);
|
||||||
|
if (auto newFloatType = new_type.dyn_cast<mlir::FloatType>()) {
|
||||||
|
// Float -> Float
|
||||||
|
return elements.mapValues(
|
||||||
|
new_type, llvm::function_ref<func_type>(
|
||||||
|
[&newFloatType](const llvm::APFloat& floatVal) {
|
||||||
|
llvm::APFloat newDouble(
|
||||||
|
mlir::FloatAttr::getValueAsDouble(floatVal));
|
||||||
|
bool loses_info = false;
|
||||||
|
newDouble.convert(newFloatType.getFloatSemantics(),
|
||||||
|
llvm::APFloat::rmNearestTiesToEven,
|
||||||
|
&loses_info);
|
||||||
|
return newDouble.bitcastToAPInt();
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
// Float -> Int
|
||||||
|
return elements.mapValues(
|
||||||
|
new_type, llvm::function_ref<func_type>(
|
||||||
|
[&bit_width](const llvm::APFloat& floatVal) {
|
||||||
|
return llvm::APInt(
|
||||||
|
bit_width,
|
||||||
|
mlir::FloatAttr::getValueAsDouble(floatVal));
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
// old_type is Integer
|
||||||
|
// mapValues always takes a function returning APInt, even when the output
|
||||||
|
// is actually float.
|
||||||
|
using func_type = llvm::APInt(const llvm::APInt&);
|
||||||
|
if (auto newFloatType = new_type.dyn_cast<mlir::FloatType>()) {
|
||||||
|
// Int -> Float
|
||||||
|
return elements.mapValues(
|
||||||
|
new_type, llvm::function_ref<func_type>([&newFloatType](
|
||||||
|
const llvm::APInt& intVal) {
|
||||||
|
llvm::APFloat newDouble(static_cast<double>(intVal.getSExtValue()));
|
||||||
|
bool loses_info = false;
|
||||||
|
newDouble.convert(newFloatType.getFloatSemantics(),
|
||||||
|
llvm::APFloat::rmNearestTiesToEven, &loses_info);
|
||||||
|
return newDouble.bitcastToAPInt();
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
// new_type is Integer
|
||||||
|
// Int -> Int
|
||||||
|
return elements.mapValues(
|
||||||
|
new_type,
|
||||||
|
llvm::function_ref<func_type>([&bit_width](const llvm::APInt& intVal) {
|
||||||
|
return llvm::APInt(bit_width, intVal.getSExtValue());
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace xla
|
||||||
|
} // namespace mlir
|
|
@ -0,0 +1,70 @@
|
||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"
|
||||||
|
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace xla {
|
||||||
|
|
||||||
|
DenseIntElementsAttr getBroadcastDimensionsAttr(Builder *b, Value x, Value y,
|
||||||
|
bool allow_empty) {
|
||||||
|
TensorType xType = x.getType().dyn_cast<RankedTensorType>();
|
||||||
|
TensorType yType = y.getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!xType || !yType) return {};
|
||||||
|
if (allow_empty && xType == yType) return {};
|
||||||
|
|
||||||
|
// If the shapes have the same rank, then there is nothing to do.
|
||||||
|
auto xRank = xType.getRank(), yRank = yType.getRank();
|
||||||
|
if (allow_empty && xRank == yRank) return {};
|
||||||
|
|
||||||
|
// Otherwise if the ranks of the inputs don't match, TensorFlow automatically
|
||||||
|
// reshapes the smaller by padding with dimensions of size 1 as a prefix. In
|
||||||
|
// other words to pad a 5-vector to a 3-dimensional tensor it is reshaped to
|
||||||
|
// have shape [1,1,5]. XLA's automatic broadcast code is able to broadcast
|
||||||
|
// from lower to higher rank, but doesn't assume you want to pad as a prefix
|
||||||
|
// of the dimensions, and instead needs to be told which dimensions of the
|
||||||
|
// higher rank tensor to match to the lower rank tensor.
|
||||||
|
auto maxRank = std::max(xRank, yRank);
|
||||||
|
auto minRank = std::min(xRank, yRank);
|
||||||
|
|
||||||
|
// Match the lower rank tensor along the larger-numbered dimensions of the
|
||||||
|
// higher rank tensor.
|
||||||
|
SmallVector<int64_t, 4> broadcastDimensions(minRank);
|
||||||
|
std::iota(broadcastDimensions.begin(), broadcastDimensions.end(),
|
||||||
|
maxRank - minRank);
|
||||||
|
|
||||||
|
RankedTensorType type =
|
||||||
|
RankedTensorType::get({minRank}, b->getIntegerType(64));
|
||||||
|
return DenseIntElementsAttr::get(type, broadcastDimensions);
|
||||||
|
}
|
||||||
|
|
||||||
|
DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) {
|
||||||
|
RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
|
||||||
|
|
||||||
|
if (auto float_ty = ty.dyn_cast<FloatType>()) {
|
||||||
|
APFloat value(float_ty.getFloatSemantics(), raw_value);
|
||||||
|
return DenseElementsAttr::get(scalar_ty, value);
|
||||||
|
}
|
||||||
|
auto int_ty = ty.cast<IntegerType>();
|
||||||
|
APInt value(int_ty.getWidth(), static_cast<int64_t>(raw_value), true);
|
||||||
|
return DenseElementsAttr::get(scalar_ty, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace xla
|
||||||
|
} // namespace mlir
|
Loading…
Reference in New Issue