Move the HLO/LHLO dialects to a new directory: tensorflow/compiler/mlir/hlo

We're preparing to restructure the MLIR HLO ecosystem with 5 dialects:

- chlo: client dialect with explicit broadcast and multiple composite operations
- mhlo: hlo with dynamic shape, decouple from XLA for evolution purpose
- lmhlo: same as above, but after buffer assignment.
- xla_hlo: mapping 1:1 to the XLA HloInstruction class.
- xla_lhlo: same as above, but after buffer assignment.

The first three dialects are intended to live in the new tensorflow/compiler/mlir/hlo
path, the latter two will be created in tensorflow/compiler/mlir/xla.

This patch only moves the directory, will followup with other transformations and tests.

The structure of the new directory follows: https://llvm.discourse.group/t/rfc-canonical-file-paths-to-dialects/621 as we intend to make it a standalone buildable component (see also https://github.com/google/mlir-npcomp as another example).

PiperOrigin-RevId: 319273229
This commit is contained in:
Mehdi Amini 2020-07-01 19:18:52 +00:00 committed by Mehdi Amini
parent 5fe5c39ccc
commit fcf3df1541
22 changed files with 7267 additions and 0 deletions

View File

@ -0,0 +1,45 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_CHLO_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_CHLO_OPS_H_
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/DialectImplementation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpDefinition.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.h"
namespace mlir {
namespace xla_chlo {
class XlaHloClientDialect : public Dialect {
public:
explicit XlaHloClientDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "xla_chlo"; }
};
#define GET_OP_CLASSES
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h.inc"
} // namespace xla_chlo
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_CHLO_OPS_H_

View File

@ -0,0 +1,370 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Defines "client" aligned HLO ops.
// These ops are not necessarily orthogonal or optimized for transformation but
// for ease of expression in certain cases deemed important for client
// libraries (i.e. implicit broadcasting, helper ops, etc).
// This dialect is considered to exist in addition to augment the xla_hlo
// dialect for ergonomic needs, not duplicate/replace it.
//
// The typical use of this dialect is for client libraries to be able to emit
// less constrained ops and rely on the conversion framework to lower any
// xla_chlo ops to canonical xla_hlo ops.
//
// See: https://www.tensorflow.org/xla/operation_semantics
#ifndef CHLO_OPS
#define CHLO_OPS
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td"
include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.td"
include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
def HLOClient_Dialect : Dialect {
let name = "xla_chlo";
let cppNamespace = "xla_chlo";
let summary = [{
XLA Client HLO Ops
}];
let description = [{
This dialect contains ops that align closely with the API surface area
of the XlaBuilder C++ API, where such ops have semantics that go beyond
what exists in the lower level dialects (such as `xla_hlo`). Essentially,
whenever the client library uses syntactic sugar or composition
of multiple ops for an API call, this dialect tries to model the API call
and provide conversion patterns to fully materialize into lower level
dialects.
}];
}
class HLOClient_Op<string mnemonic, list<OpTrait> traits> :
Op<HLOClient_Dialect, mnemonic, traits> {
// TODO(b/129012527) Much of this custom verification should be expressed as
// type constraints.
let verifier = [{ return Verify(*this); }];
}
//===----------------------------------------------------------------------===//
// XLA binary elementwise op definitions.
// From the client perspective, each of these support both explicit rank
// broadcasting (via the broadcast_dimensions attribute) and implicit degenerate
// shape broadcasting.
//
// These correspond to operations in the xla_hlo dialect without the
// "broadcast_" prefix, except that those ops require same-shaped operands and
// results.
//
// See:
// https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
// https://www.tensorflow.org/xla/broadcasting
//===----------------------------------------------------------------------===//
class HLOClient_BroadcastBinaryElementwiseOp<
string mnemonic, list<OpTrait> traits> :
HLOClient_Op<mnemonic,
!listconcat(traits, [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface>])> {
let arguments = (ins
HLO_Tensor:$lhs,
HLO_Tensor:$rhs,
// Explicit rank-broadcast dimension mappings. Defaults to "numpy" prefix
// padded rank-broadcast semantics if omitted.
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
);
let builders = [OpBuilder<
"OpBuilder &builder, OperationState &result, Value left, Value right, "
"DenseIntElementsAttr broadcast_dimensions"
>];
let results = (outs HLO_Tensor);
let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:`
`(` type($lhs) `,` type($rhs) `)` `->` type(results)
}];
let extraClassDeclaration = [{
// TODO(laurenzo): It isn't clear to me why reifyReturnShapes does not
// have its declaration generated by DeclareOpInterfaceMethods.
LogicalResult reifyReturnTypeShapes(
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes);
}];
}
def HLOClient_BroadcastAddOp : HLOClient_BroadcastBinaryElementwiseOp<"broadcast_add",
[Commutative, NoSideEffect, SameOperandsAndResultElementType]> {
string summary = "Addition operator (with optional broadcasting)";
string description = [{
Returns `lhs + rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
def HLOClient_BroadcastAtan2Op : HLOClient_BroadcastBinaryElementwiseOp<
"broadcast_atan2",
[NoSideEffect, SameOperandsAndResultElementType]> {
string summary = "Atan2 operator (with optional broadcasting)";
string description = [{
Returns `atan2(lhs/rhs)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
def HLOClient_BroadcastDivOp : HLOClient_BroadcastBinaryElementwiseOp<
"broadcast_divide",
[NoSideEffect, SameOperandsAndResultElementType]> {
string summary = "Division operator (with optional broadcasting)";
string description = [{
Returns `lhs / rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
def HLOClient_BroadcastMaxOp : HLOClient_BroadcastBinaryElementwiseOp<
"broadcast_maximum",
[Commutative, NoSideEffect, SameOperandsAndResultElementType]> {
string summary = "Maximum operator (with optional broadcasting)";
string description = [{
Returns `max(lhs, rhs)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
def HLOClient_BroadcastMinOp : HLOClient_BroadcastBinaryElementwiseOp<
"broadcast_minimum",
[Commutative, NoSideEffect, SameOperandsAndResultElementType]> {
string summary = "Minimum operator (with optional broadcasting)";
string description = [{
Returns `min(lhs, rhs)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
def HLOClient_BroadcastMulOp : HLOClient_BroadcastBinaryElementwiseOp<
"broadcast_multiply",
[Commutative, NoSideEffect, SameOperandsAndResultElementType]> {
string summary = "Multiplication operator (with optional broadcasting)";
string description = [{
Returns `lhs * rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
def HLOClient_BroadcastPowOp : HLOClient_BroadcastBinaryElementwiseOp<
"broadcast_power",
[NoSideEffect, SameOperandsAndResultElementType]> {
string summary = "Power operator (with optional broadcasting)";
string description = [{
Returns `lhs ^ rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
def HLOClient_BroadcastRemOp : HLOClient_BroadcastBinaryElementwiseOp<
"broadcast_remainder",
[NoSideEffect, SameOperandsAndResultElementType]> {
string summary = "Remainder operator (with optional broadcasting)";
string description = [{
Returns `lhs % rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
def HLOClient_BroadcastShiftLeftOp : HLOClient_BroadcastBinaryElementwiseOp<
"broadcast_shift_left",
[NoSideEffect, SameOperandsAndResultElementType]> {
string summary = "Shift left operator (with optional broadcasting)";
string description = [{
Returns `lhs << rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
def HLOClient_BroadcastShiftRightArithmeticOp : HLOClient_BroadcastBinaryElementwiseOp<
"broadcast_shift_right_arithmetic",
[NoSideEffect, SameOperandsAndResultElementType]> {
string summary = "Shift right arithmetic operator (with optional broadcasting)";
string description = [{
Returns `lhs >> rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
def HLOClient_BroadcastShiftRightLogicalOp : HLOClient_BroadcastBinaryElementwiseOp<
"broadcast_shift_right_logical",
[NoSideEffect, SameOperandsAndResultElementType]> {
string summary = "Shift right logical operator (with optional broadcasting)";
string description = [{
Returns `lhs >> rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
def HLOClient_BroadcastSubOp : HLOClient_BroadcastBinaryElementwiseOp<
"broadcast_subtract",
[NoSideEffect, SameOperandsAndResultElementType]> {
string summary = "Subtraction operator (with optional broadcasting)";
string description = [{
Returns `lhs - rhs` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
//===----------------------------------------------------------------------===//
// XLA binary elementwise op definitions.
// The same description as the arithmetic binary elementwise ops applies.
//===----------------------------------------------------------------------===//
class HLOClient_BroadcastBinaryLogicalElementwiseOp<string mnemonic> :
HLOClient_BroadcastBinaryElementwiseOp<
mnemonic, [Commutative, NoSideEffect]> {
let arguments = (ins
HLO_PredOrIntTensor:$lhs,
HLO_PredOrIntTensor:$rhs,
// Explicit rank-broadcast dimension mappings. Defaults to "numpy" prefix
// padded rank-broadcast semantics if omitted.
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
);
}
def HLOClient_BroadcastAndOp: HLOClient_BroadcastBinaryLogicalElementwiseOp<
"broadcast_and"> {
string summary = "Logical and operator (with optional broadcasting)";
string description = [{
Returns `logical_and(lhs, rhs)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
def HLOClient_BroadcastOrOp: HLOClient_BroadcastBinaryLogicalElementwiseOp<
"broadcast_or"> {
string summary = "Logical or operator (with optional broadcasting)";
string description = [{
Returns `logical_or(lhs, rhs)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
def HLOClient_BroadcastXorOp : HLOClient_BroadcastBinaryLogicalElementwiseOp<
"broadcast_xor"> {
string summary = "Logical xor operator (with optional broadcasting)";
string description = [{
Returns `logical_xor(lhs, rhs)` element-wise.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations.
}];
}
//===----------------------------------------------------------------------===//
// Broadcasting complex op
//===----------------------------------------------------------------------===//
def HLOClient_BroadcastComplexOp : HLOClient_BroadcastBinaryElementwiseOp<
"broadcast_complex", [NoSideEffect]> {
string summary = "Complex operator (with optional broadcasting)";
string description = [{
Performs element-wise conversion of a pair of real and imaginary values to
a complex value.
}];
let arguments = (ins
HLO_FpTensor:$lhs,
HLO_FpTensor:$rhs,
// Explicit rank-broadcast dimension mappings. Defaults to "numpy" prefix
// padded rank-broadcast semantics if omitted.
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
);
let results = (outs HLO_ComplexTensor);
}
//===----------------------------------------------------------------------===//
// Broadcasting compare op
//===----------------------------------------------------------------------===//
def HLOClient_BroadcastCompareOp : HLOClient_BroadcastBinaryElementwiseOp<
"broadcast_compare", [NoSideEffect]> {
string summary = "Compare operator (with optional broadcasting)";
string description = [{
Compares `lhs` and `rhs` elementwise according to `comparison_direction`.
See
https://www.tensorflow.org/xla/operation_semantics#element-wise_comparison_operations.
}];
let arguments = (ins
HLO_Tensor:$lhs,
HLO_Tensor:$rhs,
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions,
HLO_ComparisonDirectionAttr:$comparison_direction
);
let results = (outs HLO_PredTensor);
let builders = [OpBuilder<
"OpBuilder &builder, OperationState &result, Value lhs, Value rhs, "
"DenseIntElementsAttr broadcast_dimensions, StringAttr comparison_direction"
>];
}
#endif // CHLO_OPS

View File

@ -0,0 +1,99 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines the operations used in the XLA dialect.
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_OPS_H_
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/DialectImplementation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpDefinition.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"
namespace mlir {
class OpBuilder;
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.h.inc"
namespace xla_hlo {
class XlaHloDialect : public Dialect {
public:
explicit XlaHloDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "xla_hlo"; }
// Registered hook to materialize a constant operation from a given attribute
// value with the desired resultant type.
Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type,
Location loc) override;
// Parses a type registered to this dialect.
Type parseType(DialectAsmParser &parser) const override;
// Prints a type registered to this dialect.
void printType(Type type, DialectAsmPrinter &os) const override;
};
namespace HLOTypes {
enum Kind {
Token = Type::FIRST_XLA_HLO_TYPE,
};
} // namespace HLOTypes
class TokenType : public Type::TypeBase<TokenType, Type, TypeStorage> {
public:
using Base::Base;
static TokenType get(MLIRContext *context) {
return Base::get(context, HLOTypes::Token);
}
// Support method to enable LLVM-style type casting.
static bool kindof(unsigned kind) { return kind == HLOTypes::Token; }
};
// Shape derivation function that computes the shape of the result based on
// the first argument. For a 2-dimensional input tensor, this produces IR of
// the form
//
// %0 = dim %arg0, 0 : memref<?x?xf32>
// %1 = index_cast %0 : index to i64
// %2 = dim %arg0, 1 : memref<?x?xf32>
// %3 = index_cast %2 : index to i64
// %4 = "xla_hlo.scalars_to_dimension_tensor"(%1, %3)
// : (i64, i64) -> tensor<2xi64>
//
// and returns %4 as the shape value.
LogicalResult deriveShapeFromFirstOperand(
OpBuilder *builder, Operation *op,
SmallVectorImpl<Value> *reifiedReturnShapes);
#define GET_OP_CLASSES
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc"
} // end namespace xla_hlo
} // end namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_OPS_H_

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,43 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the utils file for the HLO dialect.
#ifndef HLO_UTILS
#define HLO_UTILS
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td"
def NullArrayAttr : NativeCodeCall<"ArrayAttr()">;
def CastIntElementsAttr : NativeCodeCall<"$0.cast<DenseIntElementsAttr>()">;
class ConstantSplat<string value> : NativeCodeCall<
"xla::getSplat(&$_builder, $0, " # value # ")">;
def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">;
def BinBroadcastDimensions : NativeCodeCall<
"xla::getBroadcastDimensionsAttr(&$_builder, $0, $1)">;
def BinBroadcastDimensionsNonEmpty : NativeCodeCall<
"xla::getBroadcastDimensionsAttr(&$_builder, $0, $1, /*allow_empty=*/false)">;
// Here, the element type can be any integer or float type. But, note that only
// 32 bit integers are supported for the value.
class GetScalarOfType<int value> : NativeCodeCall<
"xla::GetScalarOfType(getElementTypeOrSelf($0)," # value # ")">;
#endif // HLO_UTILS

View File

@ -0,0 +1,28 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_INFER_FUSIBILITY_OP_INTERFACE_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_INFER_FUSIBILITY_OP_INTERFACE_H_
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
namespace mlir {
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h.inc"
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_INFER_FUSIBILITY_OP_INTERFACE_H_

View File

@ -0,0 +1,161 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file contains inferFusiblityOpInterface, which is used to guide
// fusion decision.
#ifndef MLIR_INFER_FUSIBILITY_OP_INTERFACE
#define MLIR_INFER_FUSIBILITY_OP_INTERFACE
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td"
// OpInterface to query if an op is fusible and to query the shape equality
// constraint among the inputs and outputs of an op.
def InferFusibilityOpInterface : OpInterface<"InferFusibilityOpInterface"> {
let description = [{
Interface to query if an op is fusible and to query the shape equality
constraint among the inputs and outputs of an op.
}];
let methods = [
InterfaceMethod<
/*desc=*/[{If true, this op can be fused with its operands
}],
/*retTy=*/"bool",
/*methodName=*/"isFusibleWithOperand",
/*args=*/(ins),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
/// Returns whether this op can be fused with its operands
return true;
}]
>,
InterfaceMethod<
/*desc=*/[{If true, this op can be fused with its consumers
}],
/*retTy=*/"bool",
/*methodName=*/"isFusibleWithConsumer",
/*args=*/(ins),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
/// Return whether this op can be fused withh its consumers
return true;
}]
>,
InterfaceMethod<
/*desc=*/"Return whether two inputs have the same shape (assuming no"
"implicit broadcasting).",
/*retTy=*/"bool",
/*methodName=*/"inferInputsShapeEquality",
/*args=*/(ins "int":$lhs, "int":$rhs),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
/// Return whether two inputs have the same shape.
Operation *op = this->getOperation();
assert(lhs < op->getNumOperands() && lhs >= 0 &&
rhs < op->getNumOperands() && rhs >= 0);
if (lhs == rhs) return true;
// if both lhs and rhs have static shapes, check them directly
Type lhs_ty = op->getOperand(lhs).getType();
Type rhs_ty = op->getOperand(rhs).getType();
auto lhs_shape_type = lhs_ty.dyn_cast_or_null<RankedTensorType>();
auto rhs_shape_type = rhs_ty.dyn_cast_or_null<RankedTensorType>();
if (!lhs_shape_type || !lhs_shape_type.hasStaticShape() ||
!rhs_shape_type || !rhs_shape_type.hasStaticShape() ||
lhs_shape_type.getRank() != rhs_shape_type.getRank()) {
return false;
}
return lhs_shape_type.getShape() == rhs_shape_type.getShape();
}]
>,
InterfaceMethod<
/*desc=*/"Return whether two outputs have the same shape (assuming no"
" implicit broadcasting).",
/*retTy=*/"bool",
/*methodName=*/"inferOutputsShapeEquality",
/*args=*/(ins "int":$lhs, "int":$rhs),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
/// Return whether two outputs have the same shape.
Operation *op = this->getOperation();
assert(lhs < op->getNumResults() && lhs >= 0 &&
rhs < op->getNumResults() && rhs >= 0);
if (lhs == rhs) return true;
// if both lhs and rhs have static shapes, check them directly
Type lhs_ty = op->getResult(lhs).getType();
Type rhs_ty = op->getResult(rhs).getType();
auto lhs_shape_type = lhs_ty.dyn_cast_or_null<RankedTensorType>();
auto rhs_shape_type = rhs_ty.dyn_cast_or_null<RankedTensorType>();
if (!lhs_shape_type || !lhs_shape_type.hasStaticShape() ||
!rhs_shape_type || !rhs_shape_type.hasStaticShape() ||
lhs_shape_type.getRank() != rhs_shape_type.getRank()) {
return false;
}
return lhs_shape_type.getShape() == rhs_shape_type.getShape();
}]
>,
InterfaceMethod<
/*desc=*/"Return whether the input and the output have the same"
" shape (assuming no implicit broadcasting).",
/*retTy=*/"bool",
/*methodName=*/"inferInputOutputShapeEquality",
/*args=*/(ins "int":$input, "int":$output),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
/// Return whether the input and the output have the same shape.
Operation *op = this->getOperation();
assert(input < op->getNumOperands() && input >= 0 &&
output < op->getNumResults() && output >= 0);
// if both input and output have static shapes, check them directly
Type input_ty = op->getOperand(input).getType();
Type output_ty = op->getResult(output).getType();
auto input_shape_type = input_ty.dyn_cast_or_null<RankedTensorType>();
auto output_shape_type = output_ty.dyn_cast_or_null<RankedTensorType>();
if (!input_shape_type || !input_shape_type.hasStaticShape() ||
!output_shape_type || !output_shape_type.hasStaticShape() ||
input_shape_type.getRank() != output_shape_type.getRank()) {
return false;
}
return input_shape_type.getShape() == output_shape_type.getShape();
}]
>,
InterfaceMethod<
/*desc=*/[{Return the effective workload shape for the operation.
Here the effective workload shape roughly represents the maximum
parallelism can be used during the codegen stage. It's used to check
the shape-compatibility of the operation. During fusion, we only
try to fuse shape-compatible ops for performace.
For example, the effective workload shape of an elementwise op is its
output shape, while the effective workload shape of a reduction op may
be its operand shape.
Return None if such an inference is not possible.
}],
/*retTy=*/"llvm::Optional<Value>",
/*methodName=*/"inferEffectiveWorkloadShape",
/*args=*/(ins),
/*methodBody=*/[{}],
/*defaultImplementation=*/[{
/// Return effective workload size if possible, otherwise None.
return {};
}]
>,
];
}
#endif

View File

@ -0,0 +1,52 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines the operations used in the LXLA dialect.
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_LHLO_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_LHLO_OPS_H_
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpDefinition.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/ViewLikeInterface.h"
namespace mlir {
class OpBuilder;
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.h.inc"
namespace xla_lhlo {
class XlaLhloDialect : public Dialect {
public:
explicit XlaLhloDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "xla_lhlo"; }
};
#define GET_OP_CLASSES
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
} // namespace xla_lhlo
} // end namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_LHLO_OPS_H_

View File

@ -0,0 +1,796 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the operation definition file for LXLA.
//
// This file largely overlaps with hlo_ops.td at a logic level. It's tempting to
// merge these two files together, but we need to consider the following
// obstacles:
// * We need to have a common representation for arguments. That is to say,
// HLO_Array<X> translates to HLO_Tensor<X> in HLO dialect, and
// Arg<LHLO_Buffer<X>, "", [Mem(Read|Write)]> in LHLO. Array types within tuples
// also need to be transformed.
// * As of now, TableGen's dag functions are not sufficient to accomplish the
// one above.
// * Traits aren't identical, but need to be coped. For example,
// SameOperandAndResultType in HLO corresponds to SameTypeOperands in LHLO.
// * Also, currently HLO describes the API in XLA's client side, not service
// side. LHLO aims for the service side.
#ifndef LHLO_OPS
#define LHLO_OPS
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td"
include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/SideEffectInterfaces.td"
include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/ViewLikeInterface.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
def LHLO_Dialect : Dialect {
let name = "xla_lhlo";
let cppNamespace = "xla_lhlo";
}
//===----------------------------------------------------------------------===//
// XLA type definitions.
//===----------------------------------------------------------------------===//
// Any integer tensor types
def LHLO_IntBuffer : MemRefOf<[HLO_Int]>;
// Any floating-point tensor types
def LHLO_FpBuffer : MemRefOf<[AnyFloat]>;
def LHLO_ComplexBuffer : MemRefOf<[AnyComplex]>;
def LHLO_FpOrComplexBuffer : MemRefOf<[AnyFloat, AnyComplex]>;
def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>;
// Any integer or floating-point tensor types
def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>;
def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>;
def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;
//===----------------------------------------------------------------------===//
// XLA nullary op definitions.
//===----------------------------------------------------------------------===//
class LHLO_Op<string mnemonic, list<OpTrait> traits> :
Op<LHLO_Dialect, mnemonic,
!listconcat([MemoryEffects<[MemRead, MemWrite]>], traits)>;
def LHLO_ConstOp : LHLO_Op<"constant", []>, BASE_HLO_ConstOp {
let arguments = (ins
ElementsAttr:$value,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}
def LHLO_IotaOp : LHLO_Op<"iota", []>, BASE_HLO_IotaOp {
let arguments = (ins I64Attr:$iota_dimension,
Arg<LHLO_Buffer, "", [MemWrite]>:$output);
}
//===----------------------------------------------------------------------===//
// XLA unary elementwise op definitions.
//===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions
class LHLO_UnaryElementwiseOp<string mnemonic,
Type BufferType = LHLO_Buffer,
list<OpTrait> traits = [SameTypeOperands]>
: LHLO_Op<mnemonic, traits> {
let arguments = (ins Arg<BufferType, "", [MemRead]>:$input,
Arg<BufferType, "", [MemWrite]>:$output);
}
def LHLO_AbsOp: LHLO_UnaryElementwiseOp<"abs">, BASE_HLO_AbsOp;
// TODO(timshen): add a custom verifier.
def LHLO_BitcastConvertOp:
LHLO_UnaryElementwiseOp<"bitcast_convert", LHLO_Buffer, [SameOperandsShape]>, BASE_HLO_BitcastConvertOp;
def LHLO_CeilOp: LHLO_UnaryElementwiseOp<"ceil", LHLO_FpBuffer>, BASE_HLO_CeilOp;
def LHLO_ClzOp: LHLO_UnaryElementwiseOp<"count_leading_zeros", LHLO_IntBuffer>, BASE_HLO_ClzOp;
// TODO(timshen): add a custom verifier.
def LHLO_ConvertOp : LHLO_UnaryElementwiseOp<"convert", LHLO_Buffer, [SameOperandsShape]>, BASE_HLO_ConvertOp;
def LHLO_CosOp: LHLO_UnaryElementwiseOp<"cosine", LHLO_FpOrComplexBuffer>, BASE_HLO_CosOp;
def LHLO_ExpOp: LHLO_UnaryElementwiseOp<"exponential", LHLO_FpOrComplexBuffer>, BASE_HLO_ExpOp;
def LHLO_Expm1Op: LHLO_UnaryElementwiseOp<"exponential_minus_one", LHLO_FpOrComplexBuffer>, BASE_HLO_Expm1Op;
def LHLO_FloorOp: LHLO_UnaryElementwiseOp<"floor", LHLO_FpBuffer>, BASE_HLO_FloorOp;
def LHLO_ImagOp: LHLO_Op<"imag", [SameOperandsShape]>, BASE_HLO_ImagOp {
let arguments = (ins Arg<LHLO_ComplexBuffer, "", [MemRead]>:$input,
Arg<LHLO_FpBuffer, "", [MemWrite]>:$output);
}
def LHLO_IsFiniteOp: LHLO_Op<"is_finite", [SameOperandsShape]>, BASE_HLO_IsFiniteOp {
let arguments = (ins Arg<LHLO_FpBuffer, "", [MemRead]>:$input,
Arg<LHLO_PredBuffer, "", [MemWrite]>:$output);
}
def LHLO_LogOp: LHLO_UnaryElementwiseOp<"log", LHLO_FpOrComplexBuffer>, BASE_HLO_LogOp;
def LHLO_Log1pOp: LHLO_UnaryElementwiseOp<"log_plus_one", LHLO_FpOrComplexBuffer>, BASE_HLO_Log1pOp;
def LHLO_NegOp: LHLO_UnaryElementwiseOp<"negate">, BASE_HLO_NegOp;
def LHLO_NotOp: LHLO_UnaryElementwiseOp<"not", LHLO_PredOrIntBuffer>, BASE_HLO_NotOp;
def LHLO_PopulationCountOp: LHLO_UnaryElementwiseOp<"popcnt", LHLO_IntBuffer>, BASE_HLO_PopulationCountOp;
def LHLO_RealOp: LHLO_Op<"real", [SameOperandsShape]>, BASE_HLO_RealOp {
let arguments = (ins Arg<LHLO_ComplexBuffer, "", [MemRead]>:$input,
Arg<LHLO_FpBuffer, "", [MemWrite]>:$output);
}
def LHLO_RoundOp: LHLO_UnaryElementwiseOp<"round_nearest_afz", LHLO_FpBuffer>, BASE_HLO_RoundOp;
def LHLO_RsqrtOp: LHLO_UnaryElementwiseOp<"rsqrt", LHLO_FpOrComplexBuffer>, BASE_HLO_RsqrtOp;
def LHLO_SqrtOp: LHLO_UnaryElementwiseOp<"sqrt", LHLO_FpOrComplexBuffer>, BASE_HLO_SqrtOp;
def LHLO_SignOp: LHLO_UnaryElementwiseOp<"sign">, BASE_HLO_SignOp;
def LHLO_SinOp: LHLO_UnaryElementwiseOp<"sine", LHLO_FpOrComplexBuffer>, BASE_HLO_SinOp;
def LHLO_TanhOp: LHLO_UnaryElementwiseOp<"tanh", LHLO_FpOrComplexBuffer>, BASE_HLO_TanhOp;
//===----------------------------------------------------------------------===//
// XLA binary elementwise op definitions.
//===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
class LHLO_BinaryElementwiseOp<string mnemonic, Type BufferType = LHLO_Buffer,
list<OpTrait> traits = [SameTypeOperands]> :
LHLO_Op<mnemonic, traits> {
let arguments = (ins
Arg<BufferType, "", [MemRead]>:$lhs,
Arg<BufferType, "", [MemRead]>:$rhs,
Arg<BufferType, "", [MemWrite]>:$out,
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
);
}
def LHLO_AddOp : LHLO_BinaryElementwiseOp<"add">, BASE_HLO_AddOp;
def LHLO_AndOp: LHLO_BinaryElementwiseOp<"and", LHLO_PredOrIntBuffer>, BASE_HLO_AndOp;
def LHLO_Atan2Op : LHLO_BinaryElementwiseOp<"atan2", LHLO_FpOrComplexBuffer>, BASE_HLO_Atan2Op;
def LHLO_ComplexOp: LHLO_Op<"complex", [SameOperandsShape]>, BASE_HLO_ComplexOp {
let arguments = (ins
Arg<LHLO_FpBuffer, "", [MemRead]>:$lhs,
Arg<LHLO_FpBuffer, "", [MemRead]>:$rhs,
Arg<LHLO_ComplexBuffer, "", [MemWrite]>:$output,
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
);
}
def LHLO_DivOp : LHLO_BinaryElementwiseOp<"divide">, BASE_HLO_DivOp;
def LHLO_MaxOp : LHLO_BinaryElementwiseOp<"maximum">, BASE_HLO_MaxOp;
def LHLO_MinOp : LHLO_BinaryElementwiseOp<"minimum">, BASE_HLO_MinOp;
def LHLO_MulOp : LHLO_BinaryElementwiseOp<"multiply">, BASE_HLO_MulOp;
def LHLO_OrOp : LHLO_BinaryElementwiseOp<"or", LHLO_PredOrIntBuffer>, BASE_HLO_OrOp;
def LHLO_PowOp : LHLO_BinaryElementwiseOp<"power">, BASE_HLO_PowOp;
def LHLO_RemOp : LHLO_BinaryElementwiseOp<"remainder", LHLO_IntOrFpBuffer>, BASE_HLO_RemOp;
def LHLO_ShiftLeftOp : LHLO_BinaryElementwiseOp<"shift_left", LHLO_IntBuffer>, BASE_HLO_ShiftLeftOp;
def LHLO_ShiftRightArithmeticOp : LHLO_BinaryElementwiseOp<"shift_right_arithmetic", LHLO_IntBuffer>, BASE_HLO_ShiftRightArithmeticOp;
def LHLO_ShiftRightLogicalOp : LHLO_BinaryElementwiseOp<"shift_right_logical", LHLO_IntBuffer>, BASE_HLO_ShiftRightLogicalOp;
def LHLO_SubOp : LHLO_BinaryElementwiseOp<"subtract">, BASE_HLO_SubOp;
def LHLO_XorOp : LHLO_BinaryElementwiseOp<"xor", LHLO_PredOrIntBuffer>, BASE_HLO_XorOp;
//===----------------------------------------------------------------------===//
// XLA control flow op definitions.
//===----------------------------------------------------------------------===//
// TODO(b/139813999): specify required function signature in a type-safe way.
def LHLO_ReduceOp: LHLO_Op<"reduce", [
SameVariadicOperandSize,
SingleBlockImplicitTerminator<"TerminatorOp">
]>, BASE_HLO_ReduceOp {
let arguments = (ins
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$init_values,
Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$out,
I64ElementsAttr:$dimensions
);
let regions = (region SizedRegion<1>:$body);
}
def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [
SingleBlockImplicitTerminator<"TerminatorOp">
]>, BASE_HLO_ReduceWindowOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$init_value,
Arg<LHLO_Buffer, "", [MemWrite]>:$out,
I64ElementsAttr:$window_dimensions,
// If strides or dilations attributes are missing then the default value is
// one for each of the input dimensions. Similarly, padding values are zero
// for both low and high in each of the dimensions, if not specified.
OptionalAttr<I64ElementsAttr>:$window_strides,
OptionalAttr<I64ElementsAttr>:$base_dilations,
OptionalAttr<I64ElementsAttr>:$window_dilations,
OptionalAttr<I64ElementsAttr>:$padding
);
let regions = (region SizedRegion<1>:$body);
}
// TODO(timshen): Add a custom parser to hide operand_segment_sizes. For example,
// A tuple-like pattern match syntax could work:
// xla_lhlo.case %index, (%input0, %input1, %input2), (%output0, %output1) {
// ...
// }, {
// ...
// } : (type_input0, type_input1, type_input2, type_output0, type_output1) -> ()
def LHLO_CaseOp: LHLO_Op<"case", [
AttrSizedOperandSegments,
SingleBlockImplicitTerminator<"TerminatorOp">
]>, BASE_HLO_CaseOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$index,
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$branch_operands,
Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$out
);
let regions = (region VariadicRegion<SizedRegion<1>>:$branches);
}
// TODO(timshen): Add a custom syntax for this.
def LHLO_WhileOp: LHLO_Op<"while", [SameVariadicOperandSize]>,
BASE_HLO_WhileOp {
let arguments = (ins
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$val,
Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$output
);
let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
}
//===----------------------------------------------------------------------===//
// XLA tuple op definitions.
//===----------------------------------------------------------------------===//
def LHLO_CompareOp: LHLO_Op<"compare", []>, BASE_HLO_CompareOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
Arg<LHLO_PredBuffer, "", [MemWrite]>:$out,
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions,
HLO_ComparisonDirectionAttr:$comparison_direction
);
}
//===----------------------------------------------------------------------===//
// XLA Slice definitions.
//===----------------------------------------------------------------------===//
def LHLO_SliceOp: LHLO_Op<
"slice",
[AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
I64ElementsAttr:$start_indices,
I64ElementsAttr:$limit_indices,
I64ElementsAttr:$strides
);
}
def HLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []> {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$update,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$start_indices
);
}
//===----------------------------------------------------------------------===//
// StaticMemRefCastOp
//===----------------------------------------------------------------------===//
def HLO_StaticMemRefCastOp: Op<LHLO_Dialect, "static_memref_cast",
[NoSideEffect, DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
let summary = [{
"modifies the offset, sizes and strides of a statically shaped memref.
}];
let description = [{
Allows to modify the offset, sizes and strides of a statically shaped memref.
Example:
```mlir
%buf_transformed =
xla_lhlo.static_memref_cast %buf
: memref<1x5xf32> -> memref<5xf32, offset: 2, strides: [1]>
// The result of the op is a rank-1 memref with `[5]` shape, stride 1 and
// offset 2.
```
}];
let arguments = (ins Arg<LHLO_Buffer, "", []>:$operand);
let results = (outs Res<LHLO_Buffer, "", []>:$result);
let builders = [OpBuilder<
"OpBuilder &builder, OperationState &result, MemRefType resultType, " #
"Value operand", [{
result.addOperands(operand);
result.types.push_back(resultType);
}]>];
let extraClassDeclaration = [{
MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
}];
let verifier = [{ return Verify(*this); }];
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `->` type($result)
}];
}
//===----------------------------------------------------------------------===//
// DynamicMemRefCastOp
//===----------------------------------------------------------------------===//
def HLO_DynamicMemRefCastOp: Op<LHLO_Dialect, "dynamic_memref_cast",
[SameVariadicOperandSize, NoSideEffect,
DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
let summary = "dynamic memref cast operation";
let description = [{
Change sizes and strides of a memref using the values computed in runtime.
Example:
```mlir
%buf_transformed =
xla_lhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%step_X, %step_Y]
: memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
// The result of the op is a type-erased memref with `[%size_X, %size_Y]`
// shape and `[%step_X, %step_Y]` strides. The offset will be inherited
// from the input.
```
}];
let arguments = (ins
Arg<LHLO_Buffer, "", []>:$operand,
Variadic<Index>:$sizes,
Variadic<Index>:$strides
);
let results = (outs Res<LHLO_Buffer, "", []>:$result);
let builders = [OpBuilder<
"OpBuilder &builder, OperationState &result, MemRefType resultType, " #
"Value operand, ValueRange sizes, ValueRange strides", [{
result.addOperands(operand);
result.addOperands(sizes);
result.addOperands(strides);
result.types.push_back(resultType);
}]>];
let extraClassDeclaration = [{
MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
}];
let verifier = [{ return Verify(*this); }];
let assemblyFormat = [{
$operand `(` $sizes `)` `[` $strides `]` attr-dict `:` type($operand) `->`
type($result)
}];
}
//===----------------------------------------------------------------------===//
// XLA Other op definitions.
//===----------------------------------------------------------------------===//
def LHLO_BatchNormGradOp : LHLO_Op<"batch_norm_grad", []>,
BASE_HLO_BatchNormGradOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$scale,
Arg<LHLO_Buffer, "", [MemRead]>:$mean,
Arg<LHLO_Buffer, "", [MemRead]>:$variance,
Arg<LHLO_Buffer, "", [MemRead]>:$grad_output,
Arg<LHLO_Buffer, "", [MemWrite]>:$grad_operand, // gradient of $operand.
Arg<LHLO_Buffer, "", [MemWrite]>:$grad_scale,
Arg<LHLO_Buffer, "", [MemWrite]>:$grad_offset,
F32Attr:$epsilon,
I64Attr:$feature_index
);
}
def LHLO_BatchNormInferenceOp : LHLO_Op<"batch_norm_inference", []>,
BASE_HLO_BatchNormInferenceOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$scale,
Arg<LHLO_Buffer, "", [MemRead]>:$offset,
Arg<LHLO_Buffer, "", [MemRead]>:$mean,
Arg<LHLO_Buffer, "", [MemRead]>:$variance,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
F32Attr:$epsilon,
I64Attr:$feature_index
);
}
def LHLO_BatchNormTrainingOp : LHLO_Op<"batch_norm_training", []>,
BASE_HLO_BatchNormTrainingOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$scale,
Arg<LHLO_Buffer, "", [MemRead]>:$offset,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
Arg<LHLO_Buffer, "", [MemWrite]>:$batch_mean,
Arg<LHLO_Buffer, "", [MemWrite]>:$batch_var,
F32Attr:$epsilon,
I64Attr:$feature_index
);
}
// TODO(timshen): add a custom verifier.
def LHLO_BitcastOp: LHLO_Op<"bitcast", []> {
let arguments = (ins Arg<LHLO_Buffer, "", [MemRead]>:$input,
Arg<LHLO_Buffer, "", [MemWrite]>:$output);
}
def LHLO_BroadcastOp : LHLO_Op<"broadcast",
[]>, BASE_HLO_BroadcastOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
I64ElementsAttr:$broadcast_sizes
);
}
def LHLO_BroadcastInDimOp : LHLO_Op<"broadcast_in_dim",
[]>, BASE_HLO_BroadcastInDimOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
BroadcastDimAttr:$broadcast_dimensions
);
}
def LHLO_ClampOp : LHLO_Op<"clamp", []>, BASE_HLO_ClampOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$min,
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$max,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}
def LHLO_ConcatenateOp : LHLO_Op<"concatenate", []>, BASE_HLO_ConcatenateOp {
let arguments = (ins
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$val,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
I64Attr:$dimension
);
}
// TODO(bondhugula): Make this struct dialect independent so that it can be
// shared between the HLO and LHLO dialects.
def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", LHLO_Dialect, [
StructFieldAttr<"input_batch_dimension",I64Attr>,
StructFieldAttr<"input_feature_dimension", I64Attr>,
StructFieldAttr<"input_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"kernel_input_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_output_feature_dimension", I64Attr>,
StructFieldAttr<"kernel_spatial_dimensions", I64ElementsAttr>,
StructFieldAttr<"output_batch_dimension", I64Attr>,
StructFieldAttr<"output_feature_dimension", I64Attr>,
StructFieldAttr<"output_spatial_dimensions", I64ElementsAttr>] > {
let description = "Structure of dimension information for conv op";
}
def LHLO_ConvOp : LHLO_Op<"convolution", []>, BASE_HLO_ConvOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$window_strides,
// Default value: zero for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$padding,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$lhs_dilation,
// Default value: one for each of the spatial dimension.
OptionalAttr<I64ElementsAttr>:$rhs_dilation,
ConvDimensionNumbers:$dimension_numbers,
I64Attr:$feature_group_count,
I64Attr:$batch_group_count,
HLO_PrecisionConfigAttr:$precision_config
);
}
def LHLO_CopyOp: LHLO_Op<"copy", []>, BASE_HLO_CopyOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}
def LHLO_DotOp: LHLO_Op<"dot", []>, BASE_HLO_DotOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
HLO_PrecisionConfigAttr:$precision_config,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}
def LHLO_GatherOp: LHLO_Op<"gather", []>, BASE_HLO_GatherOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_IntBuffer, "", [MemRead]>:$start_indices,
I64Attr:$index_vector_dim,
I64ElementsAttr:$offset_dims,
I64ElementsAttr:$slice_sizes,
I64ElementsAttr:$collapsed_slice_dims,
I64ElementsAttr:$start_index_map,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}
def LHLO_ReshapeOp: LHLO_Op<"reshape", []>, BASE_HLO_ReshapeOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}
def LHLO_ScatterOp: LHLO_Op<"scatter", []>, BASE_HLO_ScatterOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$scatter_indices,
Arg<LHLO_Buffer, "", [MemRead]>:$updates,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
ScatterDimensionNumbers<LHLO_Dialect>:$scatter_dimension_numbers,
DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted,
DefaultValuedAttr<BoolAttr, "false">:$unique_indices
);
let regions = (region SizedRegion<1>:$update_computation);
}
def LHLO_SelectOp: LHLO_Op<"select", []>, BASE_HLO_SelectOp {
let arguments = (ins
Arg<LHLO_PredBuffer, "", [MemRead]>:$pred,
Arg<LHLO_Buffer, "", [MemRead]>:$on_true,
Arg<LHLO_Buffer, "", [MemRead]>:$on_false,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}
def LHLO_SelectAndScatterOp: LHLO_Op<"select_and_scatter", []>,
BASE_HLO_SelectAndScatterOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$source,
Arg<LHLO_Buffer, "", [MemRead]>:$init_value,
Arg<LHLO_Buffer, "", [MemWrite]>:$out,
OptionalAttr<I64ElementsAttr>:$window_dimensions,
OptionalAttr<I64ElementsAttr>:$window_strides,
OptionalAttr<I64ElementsAttr>:$padding
);
let regions = (region SizedRegion<1>:$select, SizedRegion<1>:$scatter);
}
def LHLO_ReverseOp: LHLO_Op<"reverse", []>, BASE_HLO_ReverseOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
I64ElementsAttr:$dimensions,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}
def LHLO_PadOp: LHLO_Op<"pad", []>, BASE_HLO_PadOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$padding_value,
I64ElementsAttr:$edge_padding_low,
I64ElementsAttr:$edge_padding_high,
I64ElementsAttr:$interior_padding,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}
def LHLO_TransposeOp: LHLO_Op<"transpose", []>, BASE_HLO_TransposeOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
I64ElementsAttr:$permutation,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}
def LHLO_ReducePrecisionOp: LHLO_Op<"reduce_precision", [SameTypeOperands]>,
BASE_HLO_ReducePrecisionOp {
let arguments = (ins
Arg<LHLO_FpBuffer, "", [MemRead]>:$operand,
Arg<LHLO_FpBuffer, "", [MemWrite]>:$output,
I32Attr:$exponent_bits,
I32Attr:$mantissa_bits
);
}
def LHLO_AllReduceOp : LHLO_Op<"all_reduce", [SameTypeOperands]>,
BASE_HLO_AllReduceOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
I64ElementsAttr:$replica_groups,
DefaultValuedAttr<BoolAttr, "false">:$constrain_layout,
OptionalAttr<ChannelHandle<LHLO_Dialect>>:$channel_id,
DefaultValuedAttr<BoolAttr, "false">:$use_global_device_ids
);
let regions = (region SizedRegion<1>:$computation);
}
def LHLO_CollectivePermuteOp: LHLO_Op<"collective_permute", [SameTypeOperands]>,
BASE_HLO_CollectivePermuteOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
I64ElementsAttr:$source_target_pairs,
OptionalAttr<ChannelHandle<LHLO_Dialect>>:$channel_id
);
}
def LHLO_FftOp: LHLO_Op<"fft", []>, BASE_HLO_FftOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
HLO_FftTypeAttr:$fft_type,
I64ElementsAttr:$fft_length
);
}
def LHLO_CholeskyOp: LHLO_Op<"cholesky", [SameOperandsElementType]>, BASE_HLO_CholeskyOp {
let arguments = (ins
Arg<LHLO_FpOrComplexBuffer, "", [MemRead]>:$a,
Arg<LHLO_FpOrComplexBuffer, "", [MemWrite]>:$output,
DefaultValuedAttr<BoolAttr, "false">:$lower
);
}
def LHLO_Infeed: LHLO_Op<"infeed", []>, BASE_HLO_InfeedOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
DefaultValuedAttr<StrAttr, "">:$config
);
}
def LHLO_Outfeed: LHLO_Op<"outfeed", []> {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
DefaultValuedAttr<StrAttr, "">:$config
);
}
def LHLO_ReplicaIdOp : LHLO_Op<"replica_id", []>, BASE_HLO_ReplicaIdOp {
let arguments = (ins Arg<MemRefOf<[UI32]>, "", [MemWrite]>);
}
def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType]>,
BASE_HLO_TriangularSolveOp {
let arguments = (ins
Arg<LHLO_FpOrComplexBuffer, "", [MemRead]>:$a,
Arg<LHLO_FpOrComplexBuffer, "", [MemRead]>:$b,
Arg<LHLO_FpOrComplexBuffer, "", [MemWrite]>:$output,
BoolAttr:$left_side,
BoolAttr:$lower,
BoolAttr:$unit_diagonal,
HLO_TransposeAttr:$transpose_a
);
}
// TODO(timshen): add a custom verifier.
def LHLO_MapOp: LHLO_Op<"map", [SameOperandsShape]>, BASE_HLO_MapOp {
let arguments = (ins
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
I64ElementsAttr:$dimensions
);
let regions = (region SizedRegion<1>:$computation);
}
def LHLO_RngGetAndUpdateStateOp: LHLO_Op<"rng_get_and_update_state", []> {
let arguments = (ins
Arg<MemRefOf<[UI64]>, "", [MemRead, MemWrite]>:$state,
I64Attr:$delta
);
}
// TODO(timshen): add a custom verifier.
def LHLO_SortOp: LHLO_Op<"sort", [SameVariadicOperandSize, SameOperandsShape]>, BASE_HLO_SortOp {
let arguments = (ins
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
Arg<Variadic<LHLO_Buffer>, "", [MemWrite]>:$output,
DefaultValuedAttr<I64Attr, "-1">:$dimension,
DefaultValuedAttr<BoolAttr, "false">:$is_stable
);
let regions = (region SizedRegion<1>:$comparator);
}
//===----------------------------------------------------------------------===//
// Late operations
//===----------------------------------------------------------------------===//
def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]> {
let summary = "Fusion operator";
let description = [{
Models the fusion instruction generated by the XLA compiler's fusion pass.
Fusion instructions are generated by the fusion pass of the XLA compiler.
They serve as a hint to the backend that it is beneficial to emit the
contained instructions into a single loop nest or kernel. The XLA fusion
pass is designed such that it only generates fusion nodes that can be
handled by the XLA compilers backends.
The XLA runtime expects this hint to be followed, as it expects a single
kernel per HLO instruction. This restriction might be lifted in the future.
}];
let regions = (region SizedRegion<1>:$region);
let skipDefaultBuilders = 1;
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &result, "
"ArrayRef<NamedAttribute> attributes">
];
}
def TerminatorOp :
LHLO_Op<"terminator", [Terminator]> {
let summary = "LHLO termination operation";
let description = [{
Terminator operation for the LHLO dialect.
}];
let builders = [OpBuilder<
"OpBuilder &b, OperationState &result, ValueRange operands",
[{ build(b, result, llvm::None, operands, llvm::None); }]
>];
}
#endif // LHLO_OPS

View File

@ -0,0 +1,49 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_BROADCAST_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_BROADCAST_UTILS_H_
// Utilities relating to implementing HLO broadcasting.
// Note: This file should not depend on any non-MLIR TensorFlow libraries.
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LLVM.h"
namespace mlir {
namespace xla {
// Checks whether the given operand types and broadcast_dims attr represent a
// legal combination for "numpy" style broadcasting (where 1-dims are prepended
// to the smaller ranked operand until it is of the same rank as the larger).
// See: https://docs.scipy.org/doc/numpy/reference/ufuncs.html
bool IsLegalNumpyRankedBroadcast(Value lhs, Value rhs,
DenseIntElementsAttr broadcast_dims);
// Emits shape dialect ops to compute the result shape for a broadcasting
// binary elementwise op which broadcasts according to "numpy" semantics
// (see above), returning an extents tensor of the resulting shape.
Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs,
Value rhs,
OpBuilder& builder);
} // namespace xla
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_BROADCAST_UTILS_H_

View File

@ -0,0 +1,33 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
namespace mlir {
namespace xla {
// Converts the given elements attr to the specified elements type.
// Requires type of the elements and new_type to be either integer or float
// type.
mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
mlir::Type new_type);
} // namespace xla
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_

View File

@ -0,0 +1,74 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
namespace mlir {
namespace xla {
// Computes the broadcast dimensions attr for an elementwise binary operator
// between two ranked tensors.
// If `allow_empty` is true, then null can be returned to mean that the
// broadcast is an "identity".
mlir::DenseIntElementsAttr getBroadcastDimensionsAttr(mlir::Builder* b,
mlir::Value x,
mlir::Value y,
bool allow_empty = true);
// Get a constant splat for the given value of type. Requires value to be of
// type static shaped RankedTensorType.
template <typename T>
static ElementsAttr getSplat(Builder* b, RankedTensorType ty, T constant) {
Type element_ty = getElementTypeOrSelf(ty);
if (element_ty.isSignlessInteger())
return DenseElementsAttr::get(ty, b->getIntegerAttr(element_ty, constant));
if (element_ty.isa<FloatType>())
return DenseElementsAttr::get(ty, b->getFloatAttr(element_ty, constant));
if (auto complex_ty = element_ty.dyn_cast<ComplexType>()) {
auto complex_element_ty = complex_ty.getElementType();
if (complex_element_ty.isF32())
return DenseElementsAttr::get(ty,
static_cast<std::complex<float>>(constant));
if (complex_element_ty.isF64())
return DenseElementsAttr::get(
ty, static_cast<std::complex<double>>(constant));
}
llvm_unreachable("unhandled element type");
}
template <typename T>
static ElementsAttr getSplat(Builder* b, Value val, T constant) {
return getSplat(b, val.getType().cast<RankedTensorType>(), constant);
}
// Returns DenseElementsAttr of rank zero with the given element type and the
// value.
// Requires `ty` to be either FloatType of IntegerType.
DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value);
} // namespace xla
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_UTILS_H_

View File

@ -0,0 +1,278 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Diagnostics.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h"
namespace mlir {
namespace xla_chlo {
template <typename T>
static LogicalResult Verify(T op) {
return success();
}
//===----------------------------------------------------------------------===//
// BinaryOps
//===----------------------------------------------------------------------===//
namespace {
// Gets the resulting type from a broadcast between two types.
static Type GetBroadcastType(Type x, Type y, Type element_type,
DenseIntElementsAttr broadcast_dimensions_attr) {
auto x_ranked = x.dyn_cast<RankedTensorType>();
auto y_ranked = y.dyn_cast<RankedTensorType>();
if (!x_ranked || !y_ranked) {
return UnrankedTensorType::get(element_type);
}
auto shape_x = x_ranked.getShape();
auto shape_y = y_ranked.getShape();
if (shape_x.size() == shape_y.size()) {
llvm::SmallVector<int64_t, 4> out_shape(shape_x.size());
for (int i = 0, e = shape_x.size(); i < e; i++) {
auto x_val = shape_x[i];
auto y_val = shape_y[i];
if (x_val == -1 || y_val == -1) {
out_shape[i] = -1;
} else {
out_shape[i] = std::max(x_val, y_val);
}
}
return RankedTensorType::get(out_shape, element_type);
}
auto shape_large = shape_x.size() > shape_y.size() ? shape_x : shape_y;
auto shape_small = shape_x.size() <= shape_y.size() ? shape_x : shape_y;
llvm::SmallVector<int64_t, 4> broadcast_dimensions;
if (broadcast_dimensions_attr) {
// Explicit broadcast dimensions.
for (const APInt& int_value : broadcast_dimensions_attr.getIntValues()) {
broadcast_dimensions.push_back(int_value.getSExtValue());
}
if (broadcast_dimensions.size() != shape_small.size()) {
// Signal illegal broadcast_dimensions as unranked.
return UnrankedTensorType::get(element_type);
}
} else {
// If no broadcast dimensions, assume "numpy" broadcasting.
broadcast_dimensions = llvm::to_vector<4>(llvm::seq<int64_t>(
shape_large.size() - shape_small.size(), shape_large.size()));
}
llvm::SmallVector<int64_t, 4> out_shape(shape_large.begin(),
shape_large.end());
// Update according to the broadcast dimensions.
for (auto index_pair : llvm::enumerate(broadcast_dimensions)) {
auto old_value = out_shape[index_pair.value()];
auto new_value = shape_small[index_pair.index()];
if (old_value != -1 && (new_value == -1 || new_value > old_value)) {
out_shape[index_pair.value()] = new_value;
}
}
return RankedTensorType::get(out_shape, element_type);
}
LogicalResult InferBroadcastBinaryOpReturnTypeComponents(
MLIRContext* context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, Type element_type,
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
// Find broadcast_dimensions.
DenseIntElementsAttr broadcast_dimensions =
attributes.get("broadcast_dimensions")
.dyn_cast_or_null<DenseIntElementsAttr>();
ShapedType lhs_type = operands[0].getType().dyn_cast<ShapedType>();
ShapedType rhs_type = operands[1].getType().dyn_cast<ShapedType>();
if (!lhs_type || !rhs_type ||
lhs_type.getElementType() != rhs_type.getElementType()) {
return emitOptionalError(location, "mismatched operand types");
}
if (!element_type) element_type = lhs_type.getElementType();
Type result_type =
GetBroadcastType(lhs_type, rhs_type, element_type, broadcast_dimensions);
if (auto ranked_result_type = result_type.dyn_cast<RankedTensorType>()) {
inferedReturnShapes.emplace_back(ranked_result_type.getShape(),
element_type);
return success();
}
// TODO(laurenzo): This should be constructing with `element_type` but that
// constructor variant needs to be added upstream.
inferedReturnShapes.emplace_back(/* element_type */);
return success();
}
LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes(
OpBuilder& builder, Operation* op,
SmallVectorImpl<Value>& reifiedReturnShapes) {
auto loc = op->getLoc();
auto lhs = op->getOperand(0);
auto rhs = op->getOperand(1);
// Check for "numpy"-style rank broadcast.
auto broadcast_dimensions = op->getAttr("broadcast_dimensions")
.dyn_cast_or_null<DenseIntElementsAttr>();
if (broadcast_dimensions &&
!xla::IsLegalNumpyRankedBroadcast(lhs, rhs, broadcast_dimensions)) {
// Note: It is unclear whether the general specification of explicit
// broadcast_dimensions on binary ops is a feature we want to carry
// forward. While it can technically be implemented for ranked-dynamic,
// it is incompatible with unranked inputs. If this warning is emitted
// in real programs, it is an indication that the feature should be
// implemented versus just falling back on the more standard definition
// of numpy-like prefix-padding.
return op->emitWarning()
<< "unsupported non prefix-padded dynamic rank "
<< "broadcast_dimensions = " << broadcast_dimensions;
}
Value computed_shape = xla::ComputeBinaryElementwiseBroadcastingResultExtents(
loc, lhs, rhs, builder);
if (!computed_shape) return failure();
reifiedReturnShapes.push_back(computed_shape);
return success();
}
} // namespace
//===----------------------------------------------------------------------===//
// BroadcastComplexOp (has custom type inference due to different result type).
//===----------------------------------------------------------------------===//
LogicalResult BroadcastComplexOp::inferReturnTypeComponents(
MLIRContext* context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
ShapedType lhs_type = operands[0].getType().dyn_cast<ShapedType>();
if (!lhs_type) {
return emitOptionalError(location, "expected ShapedType");
}
Type element_type = ComplexType::get(lhs_type.getElementType());
return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands,
attributes, element_type,
inferedReturnShapes);
}
LogicalResult BroadcastComplexOp::reifyReturnTypeShapes(
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
reifiedReturnShapes);
}
//===----------------------------------------------------------------------===//
// BroadcastCompareOp (has custom type inference due to different result type).
//===----------------------------------------------------------------------===//
void BroadcastCompareOp::build(OpBuilder& builder, OperationState& result,
Value lhs, Value rhs,
DenseIntElementsAttr broadcast_dimensions,
StringAttr comparison_direction) {
auto new_type = GetBroadcastType(lhs.getType(), rhs.getType(),
builder.getI1Type(), broadcast_dimensions);
build(builder, result, new_type, lhs, rhs, broadcast_dimensions,
comparison_direction);
}
LogicalResult BroadcastCompareOp::inferReturnTypeComponents(
MLIRContext* context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
Type element_type = IntegerType::get(1, context);
return InferBroadcastBinaryOpReturnTypeComponents(context, location, operands,
attributes, element_type,
inferedReturnShapes);
}
LogicalResult BroadcastCompareOp::reifyReturnTypeShapes(
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) {
return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(),
reifiedReturnShapes);
}
//===----------------------------------------------------------------------===//
// Macros for method definitions that are common to most broadcasting ops.
//===----------------------------------------------------------------------===//
#define BROADCAST_INFER_SHAPE_TYPE_OP_DEFS(Op) \
LogicalResult Op::inferReturnTypeComponents( \
MLIRContext* context, Optional<Location> location, ValueRange operands, \
DictionaryAttr attributes, RegionRange regions, \
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) { \
return InferBroadcastBinaryOpReturnTypeComponents( \
context, location, operands, attributes, /*element_type=*/nullptr, \
inferedReturnShapes); \
} \
LogicalResult Op::reifyReturnTypeShapes( \
OpBuilder& builder, SmallVectorImpl<Value>& reifiedReturnShapes) { \
return ReifyBroadcastBinaryOpReturnTypeShapes(builder, getOperation(), \
reifiedReturnShapes); \
}
#define BROADCAST_BINARY_OP_DEFS(Op) \
void Op::build(OpBuilder& builder, OperationState& result, Value left, \
Value right, DenseIntElementsAttr broadcast_dimensions) { \
auto type = GetBroadcastType( \
left.getType().cast<ShapedType>(), right.getType().cast<ShapedType>(), \
getElementTypeOrSelf(right.getType()), broadcast_dimensions); \
return Op::build(builder, result, type, left, right, \
broadcast_dimensions); \
} \
BROADCAST_INFER_SHAPE_TYPE_OP_DEFS(Op)
BROADCAST_BINARY_OP_DEFS(BroadcastAddOp);
BROADCAST_BINARY_OP_DEFS(BroadcastAndOp);
BROADCAST_BINARY_OP_DEFS(BroadcastAtan2Op);
BROADCAST_BINARY_OP_DEFS(BroadcastDivOp);
BROADCAST_BINARY_OP_DEFS(BroadcastMaxOp);
BROADCAST_BINARY_OP_DEFS(BroadcastMinOp);
BROADCAST_BINARY_OP_DEFS(BroadcastMulOp);
BROADCAST_BINARY_OP_DEFS(BroadcastOrOp);
BROADCAST_BINARY_OP_DEFS(BroadcastPowOp);
BROADCAST_BINARY_OP_DEFS(BroadcastRemOp);
BROADCAST_BINARY_OP_DEFS(BroadcastShiftLeftOp);
BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightArithmeticOp);
BROADCAST_BINARY_OP_DEFS(BroadcastShiftRightLogicalOp);
BROADCAST_BINARY_OP_DEFS(BroadcastSubOp);
BROADCAST_BINARY_OP_DEFS(BroadcastXorOp);
#undef BROADCAST_INFER_SHAPE_TYPE_OP_DEFS
#undef BROADCAST_BINARY_OP_DEFS
#define GET_OP_CLASSES
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"
//===----------------------------------------------------------------------===//
// xla_chlo Dialect Constructor
//===----------------------------------------------------------------------===//
XlaHloClientDialect::XlaHloClientDialect(MLIRContext* context)
: Dialect(getDialectNamespace(), context) {
addOperations<
#define GET_OP_LIST
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"
>();
}
} // namespace xla_chlo
} // namespace mlir

View File

@ -0,0 +1,24 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
// Static initialization for XLA dialect registration.
static mlir::DialectRegistration<mlir::xla_hlo::XlaHloDialect> xla_hlo_ops;
static mlir::DialectRegistration<mlir::xla_chlo::XlaHloClientDialect>
xla_chlo_ops;
static mlir::DialectRegistration<mlir::xla_lhlo::XlaLhloDialect> xla_lhlo_ops;

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,22 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h"
namespace mlir {
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.cc.inc"
} // namespace mlir

View File

@ -0,0 +1,102 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines the operations used in the XLA dialect.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/APFloat.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/APInt.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/SmallVector.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/FormatVariadic.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpDefinition.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpImplementation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OperationSupport.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Value.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
namespace mlir {
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.cc.inc"
namespace xla_lhlo {
XlaLhloDialect::XlaLhloDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
addOperations<
#define GET_OP_LIST
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc"
>();
}
//===----------------------------------------------------------------------===//
// StaticMemRefCastOp
//===----------------------------------------------------------------------===//
Value StaticMemRefCastOp::getViewSource() { return *getODSOperands(0).begin(); }
static LogicalResult Verify(StaticMemRefCastOp op) {
if (!op.operand().getType().cast<ShapedType>().hasStaticShape())
return op.emitOpError("operand must have static shape");
if (!op.getType().hasStaticShape())
return op.emitOpError("result must have static shape");
return success();
}
//===----------------------------------------------------------------------===//
// DynamicMemRefCastOp
//===----------------------------------------------------------------------===//
Value DynamicMemRefCastOp::getViewSource() {
return *getODSOperands(0).begin();
}
static LogicalResult Verify(DynamicMemRefCastOp op) {
// Check if `sizes` and `strides` args are compatible with the result type.
if (op.sizes().size() != op.getType().getRank())
return op.emitOpError(
"`sizes` args count must be equal to the rank of the output memref");
return success();
}
#define GET_OP_CLASSES
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc"
// TODO(cheshire): Support folding, reuse code from hlo_ops.cc.
void FusionOp::build(OpBuilder &builder, OperationState &result,
ArrayRef<NamedAttribute> attributes) {
result.addAttributes(attributes);
Region *bodyRegion = result.addRegion();
FusionOp::ensureTerminator(*bodyRegion, builder, result.location);
}
} // namespace xla_lhlo
} // namespace mlir

View File

@ -0,0 +1,30 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the canonicalize pattern definition file.
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td"
def UnaryToBinaryEinsumEq : NativeCodeCall<
"$_builder.getStringAttr(\",\" + $0.getValue().str())">;
// Convert UnaryEinsumOp to EinsumOp with two operands with redundant first
// operand.
def UnaryEinsumToEinsum : Pat<
(HLO_UnaryEinsumOp $operand, $equation),
(HLO_EinsumOp (HLO_ConstOp (GetScalarOfType<1> $operand)),
$operand, (UnaryToBinaryEinsumEq $equation))>;

View File

@ -0,0 +1,74 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h"
#include <algorithm>
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/Sequence.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/SmallVector.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Diagnostics.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
namespace mlir {
namespace xla {
bool IsLegalNumpyRankedBroadcast(Value lhs, Value rhs,
DenseIntElementsAttr broadcast_dims) {
RankedTensorType lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
RankedTensorType rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
if (!lhs_type || !rhs_type) return false;
if (lhs_type.getRank() == rhs_type.getRank()) return true;
// Otherwise, verify that broadcast_dims strictly performs left-padding.
auto smaller_rank = std::min(lhs_type.getRank(), rhs_type.getRank());
auto larger_rank = std::max(lhs_type.getRank(), rhs_type.getRank());
if (smaller_rank != broadcast_dims.getNumElements()) {
return false;
}
auto expected_extents =
llvm::seq<int64_t>(larger_rank - smaller_rank, larger_rank);
return std::equal(expected_extents.begin(), expected_extents.end(),
broadcast_dims.getIntValues().begin());
}
Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs,
Value rhs,
OpBuilder& builder) {
auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
if (!lhs_type || !rhs_type) {
emitError(loc) << "shape computation for broadcasting elementwise ops "
<< "is only implemented for ranked tensors";
return nullptr;
}
int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank());
auto shape_type = shape::ShapeType::get(builder.getContext());
Value lhs_shape_v =
builder.createOrFold<shape::ShapeOfOp>(loc, shape_type, lhs);
Value rhs_shape_v =
builder.createOrFold<shape::ShapeOfOp>(loc, shape_type, rhs);
Value result_shape_v = builder.createOrFold<shape::BroadcastOp>(
loc, shape_type, lhs_shape_v, rhs_shape_v, nullptr /* error */);
return builder.createOrFold<shape::ToExtentTensorOp>(
loc, RankedTensorType::get({result_rank}, builder.getIndexType()),
result_shape_v);
}
} // namespace xla
} // namespace mlir

View File

@ -0,0 +1,86 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines helpers useful when creating or manipulating lhlo/hlo.
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/convert_op_folder.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
namespace mlir {
namespace xla {
mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
mlir::Type new_type) {
auto old_type = getElementTypeOrSelf(elements);
size_t bit_width = new_type.isBF16() ? 64 : new_type.getIntOrFloatBitWidth();
if (old_type.isa<mlir::FloatType>()) {
// mapValues always takes a function returning APInt, even when the output
// is actually float.
using func_type = mlir::APInt(const llvm::APFloat&);
if (auto newFloatType = new_type.dyn_cast<mlir::FloatType>()) {
// Float -> Float
return elements.mapValues(
new_type, llvm::function_ref<func_type>(
[&newFloatType](const llvm::APFloat& floatVal) {
llvm::APFloat newDouble(
mlir::FloatAttr::getValueAsDouble(floatVal));
bool loses_info = false;
newDouble.convert(newFloatType.getFloatSemantics(),
llvm::APFloat::rmNearestTiesToEven,
&loses_info);
return newDouble.bitcastToAPInt();
}));
}
// Float -> Int
return elements.mapValues(
new_type, llvm::function_ref<func_type>(
[&bit_width](const llvm::APFloat& floatVal) {
return llvm::APInt(
bit_width,
mlir::FloatAttr::getValueAsDouble(floatVal));
}));
}
// old_type is Integer
// mapValues always takes a function returning APInt, even when the output
// is actually float.
using func_type = llvm::APInt(const llvm::APInt&);
if (auto newFloatType = new_type.dyn_cast<mlir::FloatType>()) {
// Int -> Float
return elements.mapValues(
new_type, llvm::function_ref<func_type>([&newFloatType](
const llvm::APInt& intVal) {
llvm::APFloat newDouble(static_cast<double>(intVal.getSExtValue()));
bool loses_info = false;
newDouble.convert(newFloatType.getFloatSemantics(),
llvm::APFloat::rmNearestTiesToEven, &loses_info);
return newDouble.bitcastToAPInt();
}));
}
// new_type is Integer
// Int -> Int
return elements.mapValues(
new_type,
llvm::function_ref<func_type>([&bit_width](const llvm::APInt& intVal) {
return llvm::APInt(bit_width, intVal.getSExtValue());
}));
}
} // namespace xla
} // namespace mlir

70
lib/utils/hlo_utils.cc Normal file
View File

@ -0,0 +1,70 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"
#include <numeric>
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
namespace mlir {
namespace xla {
DenseIntElementsAttr getBroadcastDimensionsAttr(Builder *b, Value x, Value y,
bool allow_empty) {
TensorType xType = x.getType().dyn_cast<RankedTensorType>();
TensorType yType = y.getType().dyn_cast<RankedTensorType>();
if (!xType || !yType) return {};
if (allow_empty && xType == yType) return {};
// If the shapes have the same rank, then there is nothing to do.
auto xRank = xType.getRank(), yRank = yType.getRank();
if (allow_empty && xRank == yRank) return {};
// Otherwise if the ranks of the inputs don't match, TensorFlow automatically
// reshapes the smaller by padding with dimensions of size 1 as a prefix. In
// other words to pad a 5-vector to a 3-dimensional tensor it is reshaped to
// have shape [1,1,5]. XLA's automatic broadcast code is able to broadcast
// from lower to higher rank, but doesn't assume you want to pad as a prefix
// of the dimensions, and instead needs to be told which dimensions of the
// higher rank tensor to match to the lower rank tensor.
auto maxRank = std::max(xRank, yRank);
auto minRank = std::min(xRank, yRank);
// Match the lower rank tensor along the larger-numbered dimensions of the
// higher rank tensor.
SmallVector<int64_t, 4> broadcastDimensions(minRank);
std::iota(broadcastDimensions.begin(), broadcastDimensions.end(),
maxRank - minRank);
RankedTensorType type =
RankedTensorType::get({minRank}, b->getIntegerType(64));
return DenseIntElementsAttr::get(type, broadcastDimensions);
}
DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) {
RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
if (auto float_ty = ty.dyn_cast<FloatType>()) {
APFloat value(float_ty.getFloatSemantics(), raw_value);
return DenseElementsAttr::get(scalar_ty, value);
}
auto int_ty = ty.cast<IntegerType>();
APInt value(int_ty.getWidth(), static_cast<int64_t>(raw_value), true);
return DenseElementsAttr::get(scalar_ty, value);
}
} // namespace xla
} // namespace mlir