Move XLA-independent transforms to the new MLIR-HLO directory

This is as straighforward as possible, more cleanup/rewrite to come.

PiperOrigin-RevId: 319849713
This commit is contained in:
Mehdi Amini 2020-07-06 20:57:00 +00:00 committed by Mehdi Amini
parent 72010faaa7
commit 31dc1b21eb
37 changed files with 7031 additions and 12 deletions

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_CHLO_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_CHLO_OPS_H_
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_CHLO_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_CHLO_OPS_H_
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h"
@ -42,4 +42,4 @@ class XlaHloClientDialect : public Dialect {
} // namespace xla_chlo
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_CHLO_OPS_H_
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_CHLO_OPS_H_

View File

@ -15,8 +15,8 @@ limitations under the License.
// This file defines the operations used in the XLA dialect.
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_OPS_H_
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
@ -96,4 +96,4 @@ LogicalResult deriveShapeFromFirstOperand(
} // end namespace xla_hlo
} // end namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_OPS_H_
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_

View File

@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_INFER_FUSIBILITY_OP_INTERFACE_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_INFER_FUSIBILITY_OP_INTERFACE_H_
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_INFER_FUSIBILITY_OP_INTERFACE_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_INFER_FUSIBILITY_OP_INTERFACE_H_
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
@ -25,4 +25,4 @@ namespace mlir {
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_INFER_FUSIBILITY_OP_INTERFACE_H_
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_INFER_FUSIBILITY_OP_INTERFACE_H_

View File

@ -15,8 +15,8 @@ limitations under the License.
// This file defines the operations used in the LXLA dialect.
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_LHLO_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_LHLO_OPS_H_
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
@ -49,4 +49,4 @@ class XlaLhloDialect : public Dialect {
} // namespace xla_lhlo
} // end namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_LHLO_OPS_H_
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_

View File

@ -0,0 +1,80 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_
#include <type_traits>
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
namespace mlir {
namespace xla_hlo {
template <typename HloOpTy>
struct HloToLhloOpImpl {
using Type = std::false_type;
};
template <typename HloOpTy>
using HloToLhloOp = typename HloToLhloOpImpl<HloOpTy>::Type;
#define MAP_HLO_TO_LHLO(OpName) \
template <> \
struct HloToLhloOpImpl<xla_hlo::OpName> { \
using Type = xla_lhlo::OpName; \
}
MAP_HLO_TO_LHLO(AbsOp);
MAP_HLO_TO_LHLO(AddOp);
MAP_HLO_TO_LHLO(AndOp);
MAP_HLO_TO_LHLO(BroadcastInDimOp);
MAP_HLO_TO_LHLO(CeilOp);
MAP_HLO_TO_LHLO(ConstOp);
MAP_HLO_TO_LHLO(CompareOp);
MAP_HLO_TO_LHLO(ComplexOp);
MAP_HLO_TO_LHLO(ConvOp);
MAP_HLO_TO_LHLO(ConvertOp);
MAP_HLO_TO_LHLO(CopyOp);
MAP_HLO_TO_LHLO(CosOp);
MAP_HLO_TO_LHLO(DivOp);
MAP_HLO_TO_LHLO(DotOp);
MAP_HLO_TO_LHLO(ExpOp);
MAP_HLO_TO_LHLO(GatherOp);
MAP_HLO_TO_LHLO(ImagOp);
MAP_HLO_TO_LHLO(IotaOp);
MAP_HLO_TO_LHLO(LogOp);
MAP_HLO_TO_LHLO(MaxOp);
MAP_HLO_TO_LHLO(MinOp);
MAP_HLO_TO_LHLO(MulOp);
MAP_HLO_TO_LHLO(NegOp);
MAP_HLO_TO_LHLO(RealOp);
MAP_HLO_TO_LHLO(ReduceOp);
MAP_HLO_TO_LHLO(ReshapeOp);
MAP_HLO_TO_LHLO(RemOp);
MAP_HLO_TO_LHLO(RsqrtOp);
MAP_HLO_TO_LHLO(SelectOp);
MAP_HLO_TO_LHLO(SignOp);
MAP_HLO_TO_LHLO(SinOp);
MAP_HLO_TO_LHLO(SqrtOp);
MAP_HLO_TO_LHLO(SubOp);
MAP_HLO_TO_LHLO(TanhOp);
#undef MAP_HLO_TO_LHLO
} // namespace xla_hlo
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_

View File

@ -0,0 +1,510 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_XLA_TO_SCALAR_OP_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_XLA_TO_SCALAR_OP_H_
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
namespace mlir {
namespace xla_lhlo {
namespace impl {
// A struct to map LhloBinaryOpTy type to the corresponding floating-point and
// integer scalar operation types.
template <typename LhloBinaryOpTy>
struct LhloToScalarOp;
template <>
struct LhloToScalarOp<xla_lhlo::AddOp> {
using FOp = ::mlir::AddFOp;
using IOp = ::mlir::AddIOp;
};
template <>
struct LhloToScalarOp<xla_lhlo::CompareOp> {
using FOp = ::mlir::CmpFOp;
using IOp = ::mlir::CmpIOp;
};
template <>
struct LhloToScalarOp<xla_lhlo::DivOp> {
using FOp = ::mlir::DivFOp;
using IOp = ::mlir::SignedDivIOp;
};
template <>
struct LhloToScalarOp<xla_lhlo::MulOp> {
using FOp = ::mlir::MulFOp;
using IOp = ::mlir::MulIOp;
};
template <>
struct LhloToScalarOp<xla_lhlo::RemOp> {
using FOp = ::mlir::RemFOp;
using IOp = ::mlir::SignedRemIOp;
};
template <>
struct LhloToScalarOp<xla_lhlo::SubOp> {
using FOp = ::mlir::SubFOp;
using IOp = ::mlir::SubIOp;
};
template <typename LhloBinaryOpTy>
struct ScalarOp {
using FOp = typename LhloToScalarOp<LhloBinaryOpTy>::FOp;
using IOp = typename LhloToScalarOp<LhloBinaryOpTy>::IOp;
};
// Alias for the map from LHLO binary op type to STD floating-point op type.
template <typename LhloOp>
using ScalarFOp = typename ScalarOp<LhloOp>::FOp;
// Alias for the map from LHLO binary op type to STD integer op type.
template <typename LhloOp>
using ScalarIOp = typename ScalarOp<LhloOp>::IOp;
template <typename... Args>
struct MapLhloOpToStdScalarOpImpl {
Value operator()(Location loc, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
return nullptr;
}
};
template <typename StdScalarOp>
struct MapLhloOpToStdScalarOpImpl<StdScalarOp> {
Value operator()(Location loc, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
return b->template create<StdScalarOp>(loc, result_types, args, mlir::None);
}
};
template <typename SupportedType, typename StdScalarOp, typename... Args>
struct MapLhloOpToStdScalarOpImpl<SupportedType, StdScalarOp, Args...> {
Value operator()(Location loc, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
Type element_type = args.front().getType();
if (element_type.isa<SupportedType>()) {
return b->template create<StdScalarOp>(loc, result_types, args,
mlir::None);
}
return MapLhloOpToStdScalarOpImpl<Args...>{}(loc, result_types, args, b);
}
};
// Inserts the computation that corresponds to the body of the loop for lowered
// LHLO unary/binary op. Returns the value for the result.
template <typename LhloOpTy>
inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<IntegerType, ScalarIOp<LhloOpTy>, FloatType,
ScalarFOp<LhloOpTy>>{}(loc, result_types,
args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::AbsOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
Type element_type = args.front().getType();
if (element_type.isa<FloatType>()) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::AbsFOp>{}(
loc, result_types, args, b);
}
if (element_type.isa<IntegerType>()) {
// xla_lhlo.abs(x, result) -> result = select((x > 0), x, sub(0, x))
Value lhs = args[0];
auto integer_type = element_type.dyn_cast<IntegerType>();
auto zero_intval =
b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
auto lhs_gt_zero = b->create<ScalarIOp<CompareOp>>(loc, CmpIPredicate::sge,
lhs, zero_intval);
auto neg_val = b->create<ScalarIOp<xla_lhlo::SubOp>>(loc, zero_intval, lhs);
return b->create<::mlir::SelectOp>(loc, lhs_gt_zero, lhs, neg_val);
}
return nullptr;
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::AndOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::AndOp>{}(
loc, result_types, args, b);
}
template <typename PredicateType>
inline Optional<PredicateType> getCmpPredicate(
StringRef xla_comparison_direction) {
return llvm::None;
}
template <>
inline Optional<CmpFPredicate> getCmpPredicate<CmpFPredicate>(
StringRef xla_comparison_direction) {
return llvm::StringSwitch<Optional<CmpFPredicate>>(xla_comparison_direction)
.Case("EQ", CmpFPredicate::OEQ)
.Case("NE", CmpFPredicate::ONE)
.Case("GE", CmpFPredicate::OGE)
.Case("GT", CmpFPredicate::OGT)
.Case("LE", CmpFPredicate::OLE)
.Case("LT", CmpFPredicate::OLT)
.Default(llvm::None);
}
template <>
inline Optional<CmpIPredicate> getCmpPredicate<CmpIPredicate>(
StringRef xla_comparison_direction) {
return llvm::StringSwitch<Optional<CmpIPredicate>>(xla_comparison_direction)
.Case("EQ", CmpIPredicate::eq)
.Case("NE", CmpIPredicate::ne)
.Case("GE", CmpIPredicate::sge)
.Case("GT", CmpIPredicate::sgt)
.Case("LE", CmpIPredicate::sle)
.Case("LT", CmpIPredicate::slt)
.Default(llvm::None);
}
template <typename XLACompareOpTy>
inline Value MapXlaCompareOpToStdScalarOp(Location loc,
StringRef comparison_direction,
ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
const auto& lhs = args[0];
const auto& rhs = args[1];
Type element_type = lhs.getType();
if (element_type.isSignlessInteger()) {
Optional<CmpIPredicate> predicate =
getCmpPredicate<CmpIPredicate>(comparison_direction);
assert(predicate.hasValue() && "expected valid comparison direction");
return b->create<ScalarIOp<XLACompareOpTy>>(loc, predicate.getValue(), lhs,
rhs);
}
if (element_type.isa<FloatType>()) {
Optional<CmpFPredicate> predicate =
getCmpPredicate<CmpFPredicate>(comparison_direction);
assert(predicate.hasValue() && "expected valid comparison direction");
return b->create<ScalarFOp<XLACompareOpTy>>(loc, predicate.getValue(), lhs,
rhs);
}
return nullptr;
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CopyOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return args.front();
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ExpOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::ExpOp>{}(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CeilOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CeilFOp>{}(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ComplexOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<CreateComplexOp>{}(loc, result_types, args,
b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::RealOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<ReOp>{}(loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ImagOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<ImOp>{}(loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ConvertOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
Type sourceType = args.front().getType();
Type targetType = result_types.front();
if (mlir::SIToFPOp::areCastCompatible(sourceType, targetType)) {
return b->create<mlir::SIToFPOp>(loc, result_types, args, mlir::None);
} else if (sourceType.isa<FloatType>() && targetType.isa<FloatType>()) {
FloatType src = sourceType.cast<FloatType>();
FloatType res = targetType.cast<FloatType>();
if (src.getWidth() > res.getWidth()) {
return b->create<mlir::FPTruncOp>(loc, result_types, args, mlir::None);
} else if (src.getWidth() < res.getWidth()) {
return b->create<mlir::FPExtOp>(loc, result_types, args, mlir::None);
}
// No conversion is needed for the same width floats
return args.front();
}
if (sourceType.isSignlessInteger() && targetType.isSignlessInteger()) {
IntegerType src = sourceType.cast<IntegerType>();
IntegerType res = targetType.cast<IntegerType>();
if (src.getWidth() > res.getWidth()) {
return b->create<mlir::TruncateIOp>(loc, result_types, args, mlir::None);
} else if (src.getWidth() < res.getWidth()) {
return b->create<mlir::ZeroExtendIOp>(loc, result_types, args,
mlir::None);
}
// No conversion is needed for the same width integers
return args.front();
}
if (mlir::FPToSIOp::areCastCompatible(sourceType, targetType)) {
return b->create<mlir::FPToSIOp>(loc, result_types, args, mlir::None);
}
return nullptr;
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::DotOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
// Dot Op converter from lhlo to affine only accepts float and integer types.
const auto& lhs = args[0];
const auto& rhs = args[1];
const auto& result = args[2];
Type element_type = lhs.getType();
if (element_type.isa<FloatType>()) {
Value float_mul = MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::MulFOp>{}(
loc, result_types, {lhs, rhs}, b);
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::AddFOp>{}(
loc, result_types, {float_mul, result}, b);
}
if (element_type.isa<IntegerType>()) {
Value int_mul = MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::MulIOp>{}(
loc, result_types, {lhs, rhs}, b);
return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::AddIOp>{}(
loc, result_types, {int_mul, result}, b);
}
return nullptr;
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CosOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CosOp>{}(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SinOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::SinOp>{}(
loc, result_types, args, b);
}
/// Implements the conversion of XLA op to scalar op (to use within region of a
/// linalg.generic op) for compare-select style operations like min/max.
template <typename... Args>
struct XlaCompareSelectOpToStdScalarOp {
static Value map(Location loc, StringRef comparison_direction,
ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return nullptr;
}
};
/// Specialization which allows converting to a comparison operation in standard
/// dialect with a given predicate based on the element type of the operand.
template <typename SupportedType, typename StdCompareOp, typename Predicate,
typename... Args>
struct XlaCompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate,
Args...> {
static Value map(Location loc, StringRef comparison_direction,
ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
Type element_type = args.front().getType();
if (element_type.isa<SupportedType>()) {
auto predicate = getCmpPredicate<Predicate>(comparison_direction);
assert(predicate.hasValue() && "expected valid comparison direction");
auto cmp = b->template create<StdCompareOp>(loc, predicate.getValue(),
args[0], args[1]);
return b->create<::mlir::SelectOp>(loc, cmp, args[0], args[1]);
}
return XlaCompareSelectOpToStdScalarOp<Args...>::map(
loc, comparison_direction, result_types, args, b);
}
};
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::LogOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::LogOp>{}(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::MaxOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return XlaCompareSelectOpToStdScalarOp<
IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType,
ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>::map(loc, "GT",
result_types, args,
b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::MinOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return XlaCompareSelectOpToStdScalarOp<
IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType,
ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>::map(loc, "LT",
result_types, args,
b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::NegOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
Type element_type = args.front().getType();
if (element_type.isa<FloatType>()) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::NegFOp>{}(
loc, result_types, args, b);
}
if (element_type.isa<IntegerType>()) {
// xla_lhlo.neg(x, result) -> result = sub(0, x)
Value lhs = args[0];
auto integer_type = element_type.dyn_cast<IntegerType>();
auto zero_intval =
b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
return b->create<ScalarIOp<xla_lhlo::SubOp>>(loc, zero_intval, lhs);
}
return nullptr;
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::RsqrtOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::RsqrtOp>{}(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SelectOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<::mlir::SelectOp>{}(loc, result_types, args,
b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SignOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
Type element_type = args.front().getType();
if (element_type.isa<FloatType>()) {
FloatType float_type = element_type.cast<FloatType>();
APFloat const_value = float_type.isF32() ? APFloat(1.0f) : APFloat(1.0);
Value one = b->create<mlir::ConstantFloatOp>(loc, const_value, float_type);
return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]);
}
return nullptr;
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SqrtOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::SqrtOp>{}(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::TanhOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::TanhOp>{}(
loc, result_types, args, b);
}
} // namespace impl
struct XlaOpToStdScalarOp {
// Implementation for LHLO ops except xla_lhlo::CompareOp.
template <typename XlaOpTy, typename LhloOpTy = XlaOpTy,
typename = std::enable_if_t<
!std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
std::is_same<typename xla_hlo::HloToLhloOp<LhloOpTy>,
std::false_type>::value>>
static Value map(XlaOpTy op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) {
return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types,
args, b);
}
// Implementation for HLO ops except xla_hlo::CompareOp.
template <typename XlaOpTy, typename LhloOpTy = xla_hlo::HloToLhloOp<XlaOpTy>,
typename = std::enable_if_t<
!std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
!std::is_same<LhloOpTy, std::false_type>::value>>
static Value map(XlaOpTy op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b, int i = 0) {
return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types,
args, b);
}
// Implementation for xla_lhlo::CompareOp.
template <typename LhloOpTy, typename = std::enable_if_t<std::is_same<
LhloOpTy, xla_lhlo::CompareOp>::value>>
static Value map(xla_lhlo::CompareOp op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
auto comparison_direction = op.comparison_direction();
return impl::MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(
op.getLoc(), comparison_direction, result_types, args, b);
}
// Implementation for xla_hlo::CompareOp.
template <typename HloOpTy, typename = std::enable_if_t<std::is_same<
HloOpTy, xla_hlo::CompareOp>::value>>
static Value map(xla_hlo::CompareOp op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
auto comparison_direction = op.comparison_direction();
return impl::MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(
op.getLoc(), comparison_direction, result_types, args, b);
}
};
} // namespace xla_lhlo
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_XLA_TO_SCALAR_OP_H_

View File

@ -0,0 +1,105 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSES_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSES_H_
#include <memory>
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h"
namespace mlir {
class FuncOp;
class ModuleOp;
class Operation;
template <typename T>
class OperationPass;
class Pass;
namespace xla_hlo {
/// Lowers HLO control flow ops to the Standard dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeControlFlowPass();
/// Lowers from HLO dialect to Standard dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass();
/// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
/// buffers if necessary. If `results_escape_functions` is set to true,
/// allocated buffers for function results will be returned and escape the
/// function. Otherwise, the signature is rewritten with extra arguments for the
/// buffers that are to be used for results.
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
bool results_escape_functions = false);
// Lowers from HLO dialect to Linalg dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass();
// Transforms unranked HLO operations to ranked ones where possible.
std::unique_ptr<OperationPass<FuncOp>> createTransformUnrankedHloPass();
// Sinks constants implicitly captured in control flow regions. This is
// necessary to export to XLA.
std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass();
// fuse xla_hlo ops to kLoop/kInput fusion patterns
std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusionPass();
} // namespace xla_hlo
namespace xla_lhlo {
// Lowers from LHLO dialect to Affine dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToAffinePass();
// Lowers from LHLO dialect to Linalg dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToLinalgPass();
// Lowers from LHLO dialect to GPU dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToGpuPass();
// Fuses linalg ops obtained after LHLO lowering. To enable fusion,
// operations are first tiled.
//
// When 'use_parallel_loops' is set, the tiling will use scf.parallel
// operations. Otherwise, scf.for operations are used.
//
// 'tile_sizes' provides the tile sizes to use for tiling. If the linalg
// operation has more dimensions than tile sizes provided, 1 is used as
// default.
std::unique_ptr<OperationPass<FuncOp>> createLhloFuseLinalg(
bool use_parallel_loops = false, llvm::ArrayRef<unsigned> tile_sizes = {});
// Removes unnecessary LHLO copies which copy from the allocated buffers to the
// block arguments. The block arguments are used instead of all uses of these
// buffers. The buffers are freed. This pass only works in regions that contain
// a single block.
std::unique_ptr<Pass> createLhloCopyRemovalPass();
// Lowers from LHLO dialect to parallel loops.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass();
} // namespace xla_lhlo
namespace xla {
/// Lowers the standard TanhOp to an approximation that does not use intrinsics.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTanhToApproximationPass();
} // namespace xla
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSES_H_

View File

@ -0,0 +1,106 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H_
#include <memory>
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
namespace mlir {
class LLVMTypeConverter;
class LowerToLLVMOptions;
class OwningRewritePatternList;
class BufferAssignmentPlacer;
namespace xla_hlo {
// Collection of rewrite patterns for lowering a general dot product.
void PopulateGeneralDotOpLoweringPatterns(OwningRewritePatternList *patterns,
MLIRContext *ctx);
// Collection of rewrite patterns for lowering complex operations to equivalent
// float operations.
void PopulateComplexLoweringPatterns(MLIRContext *context,
OwningRewritePatternList *patterns);
void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns,
MLIRContext *ctx);
// Collection of rewrite patterns for lowering of HLO to LHLO dialect.
void populateHLOToLHLOConversionPattern(
MLIRContext *context, BufferAssignmentPlacer *bufferAssignment,
TypeConverter *converter, OwningRewritePatternList *patterns);
// Collection of rewrite patterns for lowering of HLO to Linalg dialect.
void populateHLOToLinalgConversionPattern(MLIRContext *context,
OwningRewritePatternList *patterns);
// Sets up legality definitions for materializing broadcasts.
void SetupMaterializeBroadcastsLegality(MLIRContext *context,
ConversionTarget *conversionTarget);
// Populates a collection of rewrite patterns for materializing broadcast
// attributes to equivalent sequences of ops.
void PopulateMaterializeBroadcastsPatterns(MLIRContext *context,
OwningRewritePatternList *patterns);
// Sets up legality definitions for element-wise operations on ranked tensors.
void SetupTransformUnrankedHloLegality(MLIRContext *context,
ConversionTarget *conversionTarget);
// Populates a collection of rewrite patterns to realize element-wise operations
// on ranked tensors where possible.
void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
OwningRewritePatternList *patterns);
// Populate a collection of conversion patterns for un-fusing
// batch_norm_inference and batch_norm_training into constituent HLO ops.
// TODO(laurenzo): Implement un-fusing of batch_norm_training.
void PopulateUnfuseBatchNormPatterns(MLIRContext *context,
OwningRewritePatternList *patterns);
} // namespace xla_hlo
namespace xla_lhlo {
/// Collect a set of patterns to convert from the LHLO dialect to LLVM.
void PopulateLhloToLLVMConversionPatterns(const LowerToLLVMOptions &options,
LLVMTypeConverter *converter,
OwningRewritePatternList *patterns);
} // namespace xla_lhlo
namespace xla_chlo {
// Populates a collection of conversion patterns for legalizing client-HLO to
// HLO.
void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
OwningRewritePatternList *patterns);
} // namespace xla_chlo
namespace xla {
// Populates a pattern that translates the standard TanhOp to an approximation
// that does not use intrinsics.
void PopulateTanhToApproximationPatterns(MLIRContext *context,
OwningRewritePatternList *patterns);
} // namespace xla
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H_

View File

@ -0,0 +1,165 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_CYCLE_DETECTOR_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_CYCLE_DETECTOR_H_
#include <vector>
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/DenseMap.h"
namespace mlir {
// -------------------------------------------------------------------
// This file contains a light version of GraphCycles implemented in
// tensorflow/compiler/jit/graphcycles/graphcycles.h
//
// We re-implement it here because we do not want to rely
// on TensorFlow data structures, and hence we can move
// corresponding passes to llvm repo. easily in case necessnary.
// --------------------------------------------------------------------
// This is a set data structure that provides a deterministic iteration order.
// The iteration order of elements only depends on the sequence of
// inserts/deletes, so as long as the inserts/deletes happen in the same
// sequence, the set will have the same iteration order.
//
// Assumes that T can be cheaply copied for simplicity.
template <typename T>
class OrderedSet {
public:
// Inserts `value` into the ordered set. Returns true if the value was not
// present in the set before the insertion.
bool Insert(T value) {
bool new_insertion =
value_to_index_.insert({value, value_sequence_.size()}).second;
if (new_insertion) {
value_sequence_.push_back(value);
}
return new_insertion;
}
// Removes `value` from the set. Assumes `value` is already present in the
// set.
void Erase(T value) {
auto it = value_to_index_.find(value);
// Since we don't want to move values around in `value_sequence_` we swap
// the value in the last position and with value to be deleted and then
// pop_back.
value_to_index_[value_sequence_.back()] = it->second;
std::swap(value_sequence_[it->second], value_sequence_.back());
value_sequence_.pop_back();
value_to_index_.erase(it);
}
void Reserve(size_t new_size) {
value_to_index_.reserve(new_size);
value_sequence_.reserve(new_size);
}
void Clear() {
value_to_index_.clear();
value_sequence_.clear();
}
bool Contains(T value) const { return value_to_index_.count(value); }
size_t Size() const { return value_sequence_.size(); }
const std::vector<T>& GetSequence() const { return value_sequence_; }
private:
// The stable order that we maintain through insertions and deletions.
std::vector<T> value_sequence_;
// Maps values to their indices in `value_sequence_`.
llvm::DenseMap<T, int> value_to_index_;
};
// ---------------------------------------------------------------------
// GraphCycles detects the introduction of a cycle into a directed
// graph that is being built up incrementally.
//
// Nodes are identified by small integers. It is not possible to
// record multiple edges with the same (source, destination) pair;
// requests to add an edge where one already exists are silently
// ignored.
//
// It is also not possible to introduce a cycle; an attempt to insert
// an edge that would introduce a cycle fails and returns false.
//
// GraphCycles uses no internal locking; calls into it should be
// serialized externally.
// Performance considerations:
// Works well on sparse graphs, poorly on dense graphs.
// Extra information is maintained incrementally to detect cycles quickly.
// InsertEdge() is very fast when the edge already exists, and reasonably fast
// otherwise.
// FindPath() is linear in the size of the graph.
// The current implementation uses O(|V|+|E|) space.
class GraphCycles {
public:
explicit GraphCycles(int32_t num_nodes);
~GraphCycles();
// Attempt to insert an edge from x to y. If the
// edge would introduce a cycle, return false without making any
// changes. Otherwise add the edge and return true.
bool InsertEdge(int32_t x, int32_t y);
// Remove any edge that exists from x to y.
void RemoveEdge(int32_t x, int32_t y);
// Return whether there is an edge directly from x to y.
bool HasEdge(int32_t x, int32_t y) const;
// Contracts the edge from 'a' to node 'b', merging nodes 'a' and 'b'. One of
// the nodes is removed from the graph, and edges to/from it are added to
// the remaining one, which is returned. If contracting the edge would create
// a cycle, does nothing and return no value.
llvm::Optional<int32_t> ContractEdge(int32_t a, int32_t b);
// Return whether dest_node `y` is reachable from source_node `x`
// by following edges. This is non-thread-safe version.
bool IsReachable(int32_t x, int32_t y);
// Return a copy of the successors set. This is needed for code using the
// collection while modifying the GraphCycles.
std::vector<int32_t> SuccessorsCopy(int32_t node) const;
// Returns all nodes in post order.
//
// If there is a path from X to Y then X appears after Y in the
// returned vector.
std::vector<int32_t> AllNodesInPostOrder() const;
// ----------------------------------------------------
struct Rep;
private:
GraphCycles(const GraphCycles&) = delete;
GraphCycles& operator=(const GraphCycles&) = delete;
Rep* rep_; // opaque representation
};
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_CYCLE_DETECTOR_H_

View File

@ -0,0 +1,242 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OperationSupport.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h"
namespace mlir {
namespace xla_chlo {
namespace {
// Converts binary ops that statically are determined to not broadcast directly
// to the corresponding xla_hlo non-broadcasting op.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
using OpRewritePattern<ChloOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(ChloOpTy op,
PatternRewriter &rewriter) const override {
// Only rewrite for statically determinable non-broadcasting cases.
auto lhs_type = op.lhs().getType().template dyn_cast<RankedTensorType>();
auto rhs_type = op.rhs().getType().template dyn_cast<RankedTensorType>();
if (!lhs_type || !rhs_type) return failure();
// Requires rank broadcast.
if (lhs_type.getRank() != rhs_type.getRank()) return failure();
// Any dynamic dimension may require broadcasting and requires more
// analysis.
if (!lhs_type.hasStaticShape() || !rhs_type.hasStaticShape())
return failure();
for (auto extents : llvm::zip(lhs_type.getShape(), rhs_type.getShape())) {
auto lhs_extent = std::get<0>(extents);
auto rhs_extent = std::get<1>(extents);
if (lhs_extent != rhs_extent) {
return failure();
}
}
rewriter.replaceOp(op, {Adaptor::CreateOp(op, op.getResult().getType(),
op.lhs(), op.rhs(), rewriter)});
return success();
}
};
// Converts a binary op with ranked broadcasting operands to explicitly
// broadcast and invoke the corresponding xla_hlo non-broadcasting op.
// Note that dynamic broadcasting supported by this pattern is only valid for
// "numpy" broadcasting semantics as defined here:
// https://docs.scipy.org/doc/numpy/reference/ufuncs.html
// Specifically, this includes the following cases:
// - Same rank broadcast (operands have the same static rank).
// - Different-rank broadcast, either without a broadcast_dims attribte or
// with the broadcast_dims attribute set to map to a prefix padding.
// - Legal combinations of degenerate (1-dim) implicit broadcasting.
// The restriction on broadcast_dims derives from the definition of the
// `shape.broadcast` op, which only supports prefix-padding.
//
// It may be possible to expand this pattern to operate on unranked tensors in
// the future by emitting more code to dynamically differentiate based on rank.
// Whether that is of any practical benefit remains to be seen.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertRankedDynamicBroadcastBinaryOp
: public OpRewritePattern<ChloOpTy> {
using OpRewritePattern<ChloOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(ChloOpTy op,
PatternRewriter &rewriter) const override {
// Only support ranked operands.
Value lhs = op.lhs();
Value rhs = op.rhs();
auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
auto result_type =
op.getResult().getType().template dyn_cast<RankedTensorType>();
if (!lhs_type || !rhs_type || !result_type) return failure();
// Check for "numpy"-style rank broadcast.
auto broadcast_dimensions = op.broadcast_dimensions();
if (broadcast_dimensions &&
!xla::IsLegalNumpyRankedBroadcast(lhs, rhs, *broadcast_dimensions)) {
// Note: It is unclear whether the general specification of explicit
// broadcast_dimensions on binary ops is a feature we want to carry
// forward. While it can technically be implemented for ranked-dynamic,
// it is incompatible with unranked inputs. If this warning is emitted
// in real programs, it is an indication that the feature should be
// implemented versus just falling back on the more standard definition
// of numpy-like prefix-padding.
op.emitWarning() << "unsupported non prefix-padded dynamic rank "
<< "broadcast_dimensions = " << *broadcast_dimensions;
return failure();
}
// Compute result shape.
auto loc = op.getLoc();
// Insert a constraint on the shapes being broadcastable and insert all
// future code into an assuming block reliant on the constraint.
Value lhs_shape = rewriter.create<shape::ShapeOfOp>(loc, lhs);
Value rhs_shape = rewriter.create<shape::ShapeOfOp>(loc, rhs);
auto broadcastable_cstr =
rewriter.create<shape::CstrBroadcastableOp>(loc, lhs_shape, rhs_shape);
auto assuming_op = rewriter.create<shape::AssumingOp>(
loc, ArrayRef<Type>{result_type}, broadcastable_cstr.result());
OpBuilder::InsertionGuard guard(rewriter);
rewriter.createBlock(&assuming_op.doRegion());
int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank());
Value result_extents =
xla::ComputeBinaryElementwiseBroadcastingResultExtents(loc, lhs, rhs,
rewriter);
// Note that we unconditionally emit DynamicBroadcastInDim ops and let
// downstream canonicalizations fold them away if possible. This is
// because, in the dynamic case, there are many corner cases regarding
// when it is safe to omit, and some of them require analysis to prove
// properly.
auto lhs_broadcast_dimensions = llvm::to_vector<4>(
llvm::seq<int64_t>(result_rank - lhs_type.getRank(), result_rank));
Value broadcasted_lhs = rewriter.create<xla_hlo::DynamicBroadcastInDimOp>(
loc,
RankedTensorType::get(result_type.getShape(),
lhs_type.getElementType()),
lhs, result_extents,
rewriter.getI64TensorAttr(lhs_broadcast_dimensions));
auto rhs_broadcast_dimensions = llvm::to_vector<4>(
llvm::seq<int64_t>(result_rank - rhs_type.getRank(), result_rank));
Value broadcasted_rhs = rewriter.create<xla_hlo::DynamicBroadcastInDimOp>(
loc,
RankedTensorType::get(result_type.getShape(),
rhs_type.getElementType()),
rhs, result_extents,
rewriter.getI64TensorAttr(rhs_broadcast_dimensions));
// And generate the final non-broadcasted binary op.
Value final_result = Adaptor::CreateOp(op, result_type, broadcasted_lhs,
broadcasted_rhs, rewriter);
rewriter.create<shape::AssumingYieldOp>(loc, final_result);
rewriter.replaceOp(op, {assuming_op.getResult(0)});
return success();
}
};
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
void PopulateForBinaryOp(MLIRContext *context,
OwningRewritePatternList *patterns) {
patterns
->insert<ConvertTrivialNonBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>(
context, 10);
patterns->insert<
ConvertRankedDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>(
context, 5);
}
template <typename FromOpTy, typename ToOpTy>
struct HloBinaryElementwiseAdaptor {
static ToOpTy CreateOp(FromOpTy from_op, Type result_type,
Value broadcasted_lhs, Value broadcasted_rhs,
OpBuilder &builder) {
return builder.create<ToOpTy>(from_op.getLoc(), result_type,
broadcasted_lhs, broadcasted_rhs);
}
};
struct HloComplexAdaptor {
static xla_hlo::ComplexOp CreateOp(BroadcastComplexOp from_op,
Type result_type, Value broadcasted_lhs,
Value broadcasted_rhs,
OpBuilder &builder) {
return builder.create<xla_hlo::ComplexOp>(from_op.getLoc(), result_type,
broadcasted_lhs, broadcasted_rhs);
}
};
struct HloCompareAdaptor {
static xla_hlo::CompareOp CreateOp(BroadcastCompareOp from_op,
Type result_type, Value broadcasted_lhs,
Value broadcasted_rhs,
OpBuilder &builder) {
return builder.create<xla_hlo::CompareOp>(from_op.getLoc(), result_type,
broadcasted_lhs, broadcasted_rhs,
from_op.comparison_direction());
}
};
} // namespace
void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
OwningRewritePatternList *patterns) {
// Instantiate conversion templates for conforming binary elementwise ops
// that do not have different dtypes between operands and results and do
// not have special attributes that need to be preserved.
#define POPULATE_BCAST(ChloOp, HloOp) \
PopulateForBinaryOp<ChloOp, HloOp, \
HloBinaryElementwiseAdaptor<ChloOp, HloOp>>(context, \
patterns);
POPULATE_BCAST(BroadcastAddOp, xla_hlo::AddOp);
POPULATE_BCAST(BroadcastAndOp, xla_hlo::AndOp);
POPULATE_BCAST(BroadcastAtan2Op, xla_hlo::Atan2Op);
POPULATE_BCAST(BroadcastDivOp, xla_hlo::DivOp);
POPULATE_BCAST(BroadcastMaxOp, xla_hlo::MaxOp);
POPULATE_BCAST(BroadcastMinOp, xla_hlo::MinOp);
POPULATE_BCAST(BroadcastMulOp, xla_hlo::MulOp);
POPULATE_BCAST(BroadcastOrOp, xla_hlo::OrOp);
POPULATE_BCAST(BroadcastPowOp, xla_hlo::PowOp);
POPULATE_BCAST(BroadcastRemOp, xla_hlo::RemOp);
POPULATE_BCAST(BroadcastShiftLeftOp, xla_hlo::ShiftLeftOp);
POPULATE_BCAST(BroadcastShiftRightArithmeticOp,
xla_hlo::ShiftRightArithmeticOp);
POPULATE_BCAST(BroadcastShiftRightLogicalOp, xla_hlo::ShiftRightLogicalOp);
POPULATE_BCAST(BroadcastSubOp, xla_hlo::SubOp);
POPULATE_BCAST(BroadcastXorOp, xla_hlo::XorOp);
// Broadcasting ops requiring special construction.
PopulateForBinaryOp<BroadcastComplexOp, xla_hlo::ComplexOp,
HloComplexAdaptor>(context, patterns);
PopulateForBinaryOp<BroadcastCompareOp, xla_hlo::CompareOp,
HloCompareAdaptor>(context, patterns);
}
} // namespace xla_chlo
} // namespace mlir

View File

@ -0,0 +1,57 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir {
namespace xla_chlo {
namespace {
struct TestChloLegalizeToHloPass
: public PassWrapper<TestChloLegalizeToHloPass, FunctionPass> {
void runOnFunction() override {
ConversionTarget conversionTarget(getContext());
OwningRewritePatternList conversionPatterns;
conversionTarget.addIllegalDialect<XlaHloClientDialect>();
// Consider the xla_hlo dialect legal for tests.
conversionTarget.addLegalDialect<xla_hlo::XlaHloDialect>();
// The conversion uses helpers from the Standard dialect.
conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
conversionTarget.addLegalDialect<mlir::shape::ShapeDialect>();
PopulateLegalizeChloToHloPatterns(&getContext(), &conversionPatterns);
if (failed(applyPartialConversion(getFunction(), conversionTarget,
conversionPatterns))) {
return signalPassFailure();
}
}
};
} // namespace
} // namespace xla_chlo
} // namespace mlir
static mlir::PassRegistration<mlir::xla_chlo::TestChloLegalizeToHloPass> pass(
"test-xla-chlo-legalize-to-hlo",
"Test pass for applying chlo -> hlo legalization patterns");

View File

@ -0,0 +1,493 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements logic for lowering HLO dialect to LHLO dialect.
#include "third_party/absl/memory/memory.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/AffineMap.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BlockAndValueMapping.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/BufferPlacement.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir {
namespace xla_hlo {
namespace {
template <typename T>
using BaseOpConversion = BufferAssignmentOpConversionPattern<T>;
using StdReturnOpConverter =
detail::BufferAssignmentReturnOpConverter<mlir::ReturnOp, mlir::ReturnOp,
xla_lhlo::CopyOp, true>;
Value InsertDynamicAllocAndDealloc(Location loc, Value result,
Value shape_operand,
ConversionPatternRewriter* rewriter) {
auto result_type = result.getType().dyn_cast<ShapedType>();
if (!result_type) {
result.getDefiningOp()->emitOpError()
<< "tensor to buffer conversion expects ranked results";
}
auto memref_type =
MemRefType::get(result_type.getShape(), result_type.getElementType());
Operation* op = result.getDefiningOp();
// Extract the required element out of the vector.
SmallVector<Value, 4> dynamic_operands;
for (auto shape_element : llvm::enumerate(result_type.getShape())) {
if (shape_element.value() != ShapedType::kDynamicSize) continue;
Value index = rewriter->create<ConstantOp>(
loc, rewriter->getIntegerAttr(rewriter->getIndexType(),
shape_element.index()));
Value alloc_operand = rewriter->create<ExtractElementOp>(loc, shape_operand,
ValueRange{index});
if (!alloc_operand.getType().isIndex()) {
alloc_operand = rewriter->create<IndexCastOp>(loc, alloc_operand,
rewriter->getIndexType());
}
dynamic_operands.push_back(alloc_operand);
}
// Insert in front of op to ensure sizes are available.
OpBuilder allocBuilder(op);
auto alloc = allocBuilder.create<AllocOp>(loc, memref_type, dynamic_operands);
return alloc;
}
Value InsertAlloc(Location loc, OpResult result,
BufferAssignmentPlacer* bufferAssignment,
ConversionPatternRewriter* rewriter) {
auto result_type = result.getType().dyn_cast<ShapedType>();
if (!result_type || !result_type.hasStaticShape()) {
result.getDefiningOp()->emitOpError()
<< "tensor to buffer conversion expects statically shaped results";
}
auto memref_type =
MemRefType::get(result_type.getShape(), result_type.getElementType());
OpBuilder::InsertionGuard guard(*rewriter);
rewriter->restoreInsertionPoint(
bufferAssignment->computeAllocPosition(result));
auto alloc = rewriter->create<AllocOp>(loc, memref_type);
return alloc;
}
template <typename HloOpTy>
class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
public:
using BaseOpConversion<HloOpTy>::BaseOpConversion;
LogicalResult matchAndRewrite(
HloOpTy hloOp, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
Operation* op = hloOp.getOperation();
const auto& original_results = op->getResults();
SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
for (auto result : llvm::enumerate(original_results)) {
RankedTensorType resultType =
result.value().getType().dyn_cast<RankedTensorType>();
if (!resultType) {
return failure();
}
if (resultType.hasStaticShape()) {
buffer_args.push_back(InsertAlloc(op->getLoc(), result.value(),
this->bufferAssignment, &rewriter));
} else {
SmallVector<Value, 1> results_shape;
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
if (!shape_type_op) return failure();
if (failed(
shape_type_op.reifyReturnTypeShapes(rewriter, results_shape)))
return failure();
buffer_args.push_back(InsertDynamicAllocAndDealloc(
op->getLoc(), result.value(), results_shape.front(), &rewriter));
}
}
rewriter.create<xla_hlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
buffer_args, op->getAttrs());
rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
return success();
}
};
struct HloToLhloDynamicBroadcastInDimOpConverter
: public BaseOpConversion<xla_hlo::DynamicBroadcastInDimOp> {
public:
using BaseOpConversion<xla_hlo::DynamicBroadcastInDimOp>::BaseOpConversion;
LogicalResult matchAndRewrite(
xla_hlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
auto loc = op.getLoc();
Value resultBuffer = InsertDynamicAllocAndDealloc(
loc, op.getResult(), op.output_dimensions(), &rewriter);
Value transformed_operand =
InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
rewriter.create<xla_lhlo::BroadcastInDimOp>(
loc, transformed_operand, resultBuffer, op.broadcast_dimensions());
rewriter.replaceOp(op, {resultBuffer});
return success();
}
private:
// Inserts dynamic memref to change the layout of the memref to put 0-stride
// and size of the target dimension if size-1 dimension expansion is
// necessary.
xla_lhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp(
xla_hlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
auto loc = op.getLoc();
auto operand_type = operand.getType().cast<MemRefType>();
auto operand_shape = operand_type.getShape();
SmallVector<Value, 2> sizes, strides;
sizes.reserve(operand_shape.size());
strides.reserve(operand_shape.size());
Value zero = b->create<ConstantIndexOp>(loc, 0);
Value one = b->create<ConstantIndexOp>(loc, 1);
for (auto dim : llvm::enumerate(op.broadcast_dimensions())) {
Value broadcast_dim_value =
b->create<ConstantIndexOp>(loc, dim.value().getSExtValue());
Value result_dim_size = b->create<ExtractElementOp>(
loc, op.output_dimensions(), broadcast_dim_value);
Value operand_dim_size =
ShapedType::isDynamic(operand_shape[dim.index()])
? b->create<DimOp>(loc, operand, dim.index()).getResult()
: b->create<ConstantIndexOp>(loc, operand_shape[dim.index()])
.getResult();
// TODO(pifon): Revisit if this cast is needed. Maybe we can use
// tensor<index> for `output_dimensions` as well.
if (!result_dim_size.getType().isIndex()) {
result_dim_size =
b->create<IndexCastOp>(loc, result_dim_size, b->getIndexType());
}
// There can be two cases:
// 1) Operand dim == result dim => expansion is not needed => stride := 1.
// 2) Operand dim < result dim => expansion is needed => stride := 0.
Value is_expansion = b->create<CmpIOp>(loc, CmpIPredicate::slt,
operand_dim_size, result_dim_size);
strides.push_back(
b->create<mlir::SelectOp>(loc, is_expansion, zero, one));
// Size of input dim can be set to the size of the corresponding output
// dimension for both cases.
sizes.push_back(result_dim_size);
}
// Type-erased memref type with static rank, dynamic sizes and strides.
SmallVector<int64_t, 2> dynamic_layout(operand_shape.size(),
MemRefType::kDynamicStrideOrOffset);
SmallVector<int64_t, 2> dynamic_shape(operand_shape.size(),
MemRefType::kDynamicSize);
auto type_erased_memref_type = MemRefType::get(
dynamic_shape, operand_type.getElementType(),
makeStridedLinearLayoutMap(dynamic_layout,
/*offset=*/0, b->getContext()));
auto transformed_operand = b->create<xla_lhlo::DynamicMemRefCastOp>(
loc, type_erased_memref_type, operand, sizes, strides);
return transformed_operand;
}
};
struct HloToLhloReduceOpConverter : public BaseOpConversion<xla_hlo::ReduceOp> {
public:
using BaseOpConversion<xla_hlo::ReduceOp>::BaseOpConversion;
LogicalResult matchAndRewrite(
xla_hlo::ReduceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
auto loc = op.getLoc();
// TODO(b/137624192) Implement variadic reduce.
if (op.getNumResults() != 1) return failure();
if (!llvm::hasSingleElement(op.body())) {
return op.emitOpError()
<< "tensor to buffer conversion expects a single block "
"in the region containing the operation";
}
const auto& original_results = op.getResults();
SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
for (auto result : original_results) {
buffer_args.push_back(
InsertAlloc(loc, result, this->bufferAssignment, &rewriter));
}
auto new_op = rewriter.create<xla_lhlo::ReduceOp>(
loc, llvm::None, buffer_args, op.getAttrs());
// Copy over the operations inside the region.
rewriter.inlineRegionBefore(op.body(), new_op.body(), new_op.body().end());
// Create new block arguments with correct type.
auto& entry_block = new_op.body().front();
int original_arg_count = entry_block.getNumArguments();
for (int i = 0; i < original_arg_count; ++i) {
auto old_arg = entry_block.getArgument(i);
auto old_type = old_arg.getType().cast<TensorType>();
auto new_type =
MemRefType::get(old_type.getShape(), old_type.getElementType());
auto new_arg = entry_block.addArgument(new_type);
rewriter.replaceUsesOfBlockArgument(old_arg, new_arg);
}
// Add an argument for the result.
entry_block.addArgument(
entry_block.getArgument(original_arg_count).getType());
// Remove the old arguments.
for (int i = original_arg_count - 1; i >= 0; --i) {
entry_block.eraseArgument(i);
}
// Insert terminator at the end.
rewriter.setInsertionPointToEnd(&entry_block);
rewriter.create<xla_lhlo::TerminatorOp>(loc);
rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
return success();
}
};
class HloToLhloTensorLoadOpConverter
: public BaseOpConversion<mlir::TensorLoadOp> {
public:
using BaseOpConversion<mlir::TensorLoadOp>::BaseOpConversion;
LogicalResult matchAndRewrite(
mlir::TensorLoadOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
rewriter.replaceOp(op, operands);
return success();
}
};
// TODO(b/137624192): Rewrite into a copy and elide copy if possible.
class HloToLhloTensorStoreOpConverter
: public BaseOpConversion<mlir::TensorStoreOp> {
public:
using BaseOpConversion<mlir::TensorStoreOp>::BaseOpConversion;
LogicalResult matchAndRewrite(
mlir::TensorStoreOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
rewriter.replaceOpWithNewOp<xla_lhlo::CopyOp>(
op, llvm::None, operands.front(), operands.back());
return success();
}
};
// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
// buffers if necessary.
//
// Example fusion with HLO ops.
//
// func @fusion(%arg0: memref<2x2xf32>,
// %arg1: memref<2x2xf32>,
// %arg2: memref<2x2xf32>,
// %arg3: memref<2x2xf32>) {
// "xla_lhlo.fusion"() ({
// %0 = tensor_load %arg1 : memref<2x2xf32>
// %1 = tensor_load %arg2 : memref<2x2xf32>
// %2 = "xla_hlo.add"(%0, %1) :
// (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// %3 = tensor_load %arg0 : memref<2x2xf32>
// %4 = "xla_hlo.multiply"(%2, %3) :
// (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// tensor_store %4, %arg3 : memref<2x2xf32>
// "xla_lhlo.terminator"() : () -> ()
// }) : () -> ()
// return
// }
//
// Transformed fusion with LHLO ops.
// func @fusion(%arg0: memref<2x2xf32>,
// %arg1: memref<2x2xf32>,
// %arg2: memref<2x2xf32>,
// %arg3: memref<2x2xf32>) {
// "xla_lhlo.fusion"() ( {
// %0 = alloc() : memref<2x2xf32>
// "xla_lhlo.add"(%arg1, %arg2, %0) :
// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
// "xla_lhlo.multiply"(%0, %arg0, %arg3) :
// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
// "xla_lhlo.terminator"() : () -> ()
// }) : () -> ()
// return
// }
//
// FuncOp signature conversion example:
//
// func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// %0 = "xla_hlo.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) ->
// tensor<4xf32> %1 = "xla_hlo.add"(%arg0, %0) : (tensor<4xf32>,
// tensor<4xf32>) -> tensor<4xf32> return %1 : tensor<4xf32>
// }
//
// Transformed function with an extra argument for the result. The types have
// been converted from tensor to memref.
//
// func @func_op(%arg0: memref<4xf32>,
// %arg1: memref<4xf32>,
// %arg2: memref<4xf32>) {
// %0 = alloc() : memref<4xf32>
// "xla_lhlo.maximum"(%arg0, %arg1, %0) :
// (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
// %1 = alloc() : memref<4xf32>
// "xla_lhlo.add"(%arg0, %0, %1) :
// (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
// "xla_lhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
// "xla_lhlo.terminator"() : () -> ()
// }
struct HloLegalizeToLhlo
: public PassWrapper<HloLegalizeToLhlo, OperationPass<ModuleOp>> {
public:
HloLegalizeToLhlo() = default;
HloLegalizeToLhlo(const HloLegalizeToLhlo& o) {
this->results_escape_function = o.results_escape_function.getValue();
}
explicit HloLegalizeToLhlo(bool results_escape_function) {
this->results_escape_function.setValue(results_escape_function);
}
void runOnOperation() override {
OwningRewritePatternList patterns;
auto& context = getContext();
ConversionTarget target(context);
target.addLegalDialect<xla_lhlo::XlaLhloDialect>();
target.addLegalDialect<StandardOpsDialect>();
target.addLegalOp<ModuleOp>();
target.addIllegalOp<mlir::TensorLoadOp>();
target.addIllegalOp<mlir::TensorStoreOp>();
target.addLegalOp<ModuleTerminatorOp>();
target.addLegalOp<TensorFromElementsOp>();
target.addIllegalDialect<xla_hlo::XlaHloDialect>();
BufferAssignmentTypeConverter converter;
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
auto inputs = op.getType().getInputs();
return llvm::all_of(inputs,
[](Type input) { return input.isa<MemRefType>(); }) &&
converter.isLegal(&op.getBody());
});
target.addDynamicallyLegalOp<mlir::ReturnOp>([&](mlir::ReturnOp returnOp) {
return std::all_of(returnOp.operand_type_begin(),
returnOp.operand_type_end(),
[](Type type) { return type.isa<MemRefType>(); });
});
auto module = getOperation();
WalkResult result = module.walk([&](FuncOp func) -> WalkResult {
BufferAssignmentPlacer bufferAssignment(func);
OwningRewritePatternList patterns;
populateHLOToLHLOConversionPattern(func.getContext(), &bufferAssignment,
&converter, &patterns);
if (results_escape_function) {
populateWithBufferAssignmentOpConversionPatterns<
mlir::ReturnOp, mlir::ReturnOp, xla_lhlo::CopyOp,
/*allowMemrefFunctionResults=*/true>(&context, &bufferAssignment,
&converter, &patterns);
} else {
populateWithBufferAssignmentOpConversionPatterns<
mlir::ReturnOp, mlir::ReturnOp, xla_lhlo::CopyOp,
/*allowMemrefFunctionResults=*/false>(&context, &bufferAssignment,
&converter, &patterns);
}
return applyPartialConversion(func, target, patterns);
});
if (result.wasInterrupted()) {
signalPassFailure();
}
}
private:
Option<bool> results_escape_function{
*this, "results-escape-function",
llvm::cl::desc(
"Allocate the results of functions within the functions body"),
llvm::cl::init(false)};
};
} // namespace
void populateHLOToLHLOConversionPattern(
MLIRContext* context, BufferAssignmentPlacer* bufferAssignment,
TypeConverter* converter, OwningRewritePatternList* patterns) {
// clang-format off
patterns->insert<
HloToLhloDynamicBroadcastInDimOpConverter,
HloToLhloOpConverter<xla_hlo::AbsOp>,
HloToLhloOpConverter<xla_hlo::AddOp>,
HloToLhloOpConverter<xla_hlo::AndOp>,
HloToLhloOpConverter<xla_hlo::BroadcastInDimOp>,
HloToLhloOpConverter<xla_hlo::CeilOp>,
HloToLhloOpConverter<xla_hlo::CompareOp>,
HloToLhloOpConverter<xla_hlo::ComplexOp>,
HloToLhloOpConverter<xla_hlo::ConstOp>,
HloToLhloOpConverter<xla_hlo::ConvOp>,
HloToLhloOpConverter<xla_hlo::ConvertOp>,
HloToLhloOpConverter<xla_hlo::CopyOp>,
HloToLhloOpConverter<xla_hlo::CosOp>,
HloToLhloOpConverter<xla_hlo::DivOp>,
HloToLhloOpConverter<xla_hlo::DotOp>,
HloToLhloOpConverter<xla_hlo::ExpOp>,
HloToLhloOpConverter<xla_hlo::GatherOp>,
HloToLhloOpConverter<xla_hlo::ImagOp>,
HloToLhloOpConverter<xla_hlo::IotaOp>,
HloToLhloOpConverter<xla_hlo::LogOp>,
HloToLhloOpConverter<xla_hlo::MaxOp>,
HloToLhloOpConverter<xla_hlo::MinOp>,
HloToLhloOpConverter<xla_hlo::MulOp>,
HloToLhloOpConverter<xla_hlo::NegOp>,
HloToLhloOpConverter<xla_hlo::RealOp>,
HloToLhloOpConverter<xla_hlo::RemOp>,
HloToLhloOpConverter<xla_hlo::RsqrtOp>,
HloToLhloOpConverter<xla_hlo::ReshapeOp>,
HloToLhloOpConverter<xla_hlo::SelectOp>,
HloToLhloOpConverter<xla_hlo::SignOp>,
HloToLhloOpConverter<xla_hlo::SqrtOp>,
HloToLhloOpConverter<xla_hlo::SubOp>,
HloToLhloOpConverter<xla_hlo::TanhOp>,
HloToLhloReduceOpConverter,
HloToLhloTensorLoadOpConverter,
HloToLhloTensorStoreOpConverter
>(context, bufferAssignment, converter);
// clang-format on
}
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
bool results_escape_function) {
return absl::make_unique<HloLegalizeToLhlo>(results_escape_function);
}
static PassRegistration<HloLegalizeToLhlo> legalize_pass(
"hlo-legalize-to-lhlo", "Legalize from HLO dialect to LHLO dialect");
} // namespace xla_hlo
} // namespace mlir

View File

@ -0,0 +1,237 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements logic for lowering XLA dialect to Standard dialect.
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/Casting.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Block.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BlockAndValueMapping.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/PassRegistry.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LogicalResult.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
using mlir::PassRegistration;
namespace mlir {
namespace xla_hlo {
namespace {
struct LegalizeControlFlow
: public mlir::PassWrapper<LegalizeControlFlow, FunctionPass> {
// Perform the lowering to MLIR control flow.
void runOnFunction() override;
};
// Replaces terminators for the newly created blocks from a targe region.
// These terminators are replaced with branch operations to a target block.
LogicalResult ReplaceTerminators(Region* region, Block* target_block,
Location loc,
const BlockAndValueMapping& mapper,
OpBuilder* builder) {
for (auto& old_block : region->getBlocks()) {
Block* block = mapper.lookup(&old_block);
auto return_op = dyn_cast<xla_hlo::ReturnOp>(block->getTerminator());
if (!return_op) continue;
builder->setInsertionPointToEnd(block);
builder->create<mlir::BranchOp>(loc, target_block, return_op.getOperands());
return_op.erase();
}
return success();
}
LogicalResult LowerIfOp(mlir::xla_hlo::IfOp if_op) {
Operation* op_inst = if_op.getOperation();
mlir::OpBuilder builder(if_op);
auto orig_block = op_inst->getBlock();
auto* tail_block = orig_block->splitBlock(op_inst);
auto loc = if_op.getLoc();
// Duplicate the true and false regions in the block between the sections
// before and after the conditional.
BlockAndValueMapping mapper;
if_op.true_branch().cloneInto(orig_block->getParent(),
Region::iterator(tail_block), mapper);
if_op.false_branch().cloneInto(orig_block->getParent(),
Region::iterator(tail_block), mapper);
// Determine the blocks for the start of the true and false regions.
Block* true_block = mapper.lookup(&if_op.true_branch().front());
Block* false_block = mapper.lookup(&if_op.false_branch().front());
// Perform the conditional branch into the true/false cases.
builder.setInsertionPointToEnd(orig_block);
// Extract the predicate for checking branching, then branch to the true and
// false regions appropriately.
auto cond_value = builder.create<mlir::ExtractElementOp>(loc, if_op.pred());
builder.create<mlir::CondBranchOp>(loc, cond_value, true_block,
if_op.true_arg(), false_block,
if_op.false_arg());
// Replace the true case's return operations with a branch to the tail of
// the condition.
if (failed(ReplaceTerminators(&if_op.true_branch(), tail_block, loc, mapper,
&builder)))
return failure();
if (failed(ReplaceTerminators(&if_op.false_branch(), tail_block, loc, mapper,
&builder)))
return failure();
tail_block->addArguments(if_op.getResult().getType());
if_op.getResult().replaceAllUsesWith(tail_block->getArgument(0));
op_inst->erase();
return success();
}
LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
// Converts an XLA while loop into control flow. This generates a set of MLIR
// blocks and branches, along with inlining the regions provided by the XLA
// while loop. The structure should be similar to below:
//
// <prior operations>
// %0 = "xla_hlo.while"(%arg0) {^cond(...){...}, ^body(...){...}}
// <post operations>
auto* op_inst = while_op.getOperation();
mlir::OpBuilder builder(while_op);
auto loc = while_op.getLoc();
// Break the block into four sections:
// orig_block - operations before the while and the branch into looping check.
// tail_block - operations after the while loop completes.
// cond_block - check the looping condition, then conditionally branch into
// the loop or, if condition is false, jump to the tail branch.
// body_block - inlined loop body, then jump back to the condition block.
auto* orig_block = op_inst->getBlock();
auto* tail_block = orig_block->splitBlock(op_inst);
BlockAndValueMapping mapper;
while_op.cond().cloneInto(orig_block->getParent(),
Region::iterator(tail_block), mapper);
while_op.body().cloneInto(orig_block->getParent(),
Region::iterator(tail_block), mapper);
// Lookup the entry blocks for both condition and body.
auto* cond_block = mapper.lookup(&while_op.cond().front());
auto* body_block = mapper.lookup(&while_op.body().front());
// Setup the end of the original block:
// <prior operations>
// br ^cond(%arg0) // Jumps to the condition statement.
builder.setInsertionPointToEnd(orig_block);
builder.create<mlir::BranchOp>(loc, cond_block, while_op.getOperand());
// Updates the inlined condition blocks by replacing the return op with an
// extract_element and conditional branch. This changes the block below:
// ^cond(%0):
// <inlined conditional region>
// "xla_hlo".return(%1)
//
// Into:
// ^cond(%0):
// <inlined conditional region>
// %2 = extract_element %1[] : tensor<i1> // Extract the condition value.
// cond_br %2, ^body(%0), ^tail(%0) // Branch.
builder.setInsertionPointToStart(cond_block);
// Replace the xla_hlo::ReturnOp with a branch back to the condition block.
// This is required as the xla_hlo::ReturnOp is used to mark the end of a
// block for regions nested inside of a operations (MLIR ReturnOp cannot be
// nested within an non-function region).
for (auto& block : while_op.cond()) {
auto new_block = mapper.lookup(&block);
auto return_op = dyn_cast<xla_hlo::ReturnOp>(new_block->getTerminator());
if (!return_op) continue;
builder.setInsertionPointToEnd(new_block);
auto return_value = return_op.getOperand(0);
auto cond_value = builder.create<mlir::ExtractElementOp>(loc, return_value);
// Get the body block arguments.
llvm::SmallVector<Value, 4> successor_args(cond_block->args_begin(),
cond_block->args_end());
builder.create<mlir::CondBranchOp>(loc, cond_value, body_block,
successor_args, tail_block,
successor_args);
return_op.erase();
}
// Updates the body blocks by replace the return op with an branch to the
// conditional block. This changes the block below:
// ^body(%0):
// <inlined body block>
// "xla_hlo".return(%1)
//
// Into:
// ^body(%0):
// <inlined body block>
// br ^cond(%0) // Branch.
for (auto& block : while_op.body()) {
auto new_block = mapper.lookup(&block);
auto return_op =
dyn_cast<mlir::xla_hlo::ReturnOp>(new_block->getTerminator());
if (!return_op) continue;
builder.setInsertionPointToEnd(new_block);
builder.create<mlir::BranchOp>(loc, cond_block, return_op.getOperands());
return_op.erase();
}
// Erase the original while loop.
tail_block->addArgument(while_op.getType());
while_op.getResult().replaceAllUsesWith(tail_block->getArgument(0));
op_inst->erase();
return success();
}
void LegalizeControlFlow::runOnFunction() {
auto func = getFunction();
llvm::SmallVector<IfOp, 4> if_ops;
func.walk([&](IfOp op) { if_ops.push_back(op); });
for (auto& op : if_ops) {
if (failed(LowerIfOp(op))) return signalPassFailure();
}
llvm::SmallVector<WhileOp, 4> while_ops;
func.walk([&](WhileOp op) { while_ops.push_back(op); });
for (auto& op : while_ops) {
if (failed(LowerWhileOp(op))) return signalPassFailure();
}
}
} // namespace
} // namespace xla_hlo
} // namespace mlir
std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
mlir::xla_hlo::createLegalizeControlFlowPass() {
return std::make_unique<LegalizeControlFlow>();
}
static PassRegistration<mlir::xla_hlo::LegalizeControlFlow> legalize_cf_pass(
"xla-legalize-control-flow",
"Legalize from XLA control flow to MLIR control flow");

View File

@ -0,0 +1,156 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements logic for lowering the tanh standard ops to an
// approximation.
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir {
namespace xla {
namespace {
/// Emits the fast tanh approximation that is also used by XLA.
Value EmitTanhApproximation(Value input, Location loc,
PatternRewriter &rewriter) {
// For small values of x, we can approximate tanh(x)=x. For extremely small
// values of x (|x| < 1e-37), the other approximation would evaluate
// tanh(x) = 0.
constexpr float kCanUseApprox = 0.0004;
Value abs_value = rewriter.create<AbsFOp>(loc, input);
Value can_use_approx =
rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(kCanUseApprox));
Value return_input = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT,
abs_value, can_use_approx);
// Clamp the input to [-c, c].
Value max_clamp = rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(7.90531110763549805f));
Value smaller_than_max =
rewriter.create<CmpFOp>(loc, CmpFPredicate::ULE, input, max_clamp);
Value clamped_half =
rewriter.create<SelectOp>(loc, smaller_than_max, input, max_clamp);
Value min_clamp = rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(-7.90531110763549805f));
Value larger_than_min =
rewriter.create<CmpFOp>(loc, CmpFPredicate::UGE, clamped_half, min_clamp);
Value input_clamped =
rewriter.create<SelectOp>(loc, larger_than_min, clamped_half, min_clamp);
static constexpr std::array<float, 7> numerator_coeffs{
-2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f,
5.12229709037114e-08f, 1.48572235717979e-05f, 6.37261928875436e-04f,
4.89352455891786e-03f};
static constexpr std::array<float, 4> denominator_coeffs{
1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f,
4.89352518554385e-03f};
Value input_squared =
rewriter.create<MulFOp>(loc, input_clamped, input_clamped);
Value numerator = rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(numerator_coeffs[0]));
for (int i = 1; i < numerator_coeffs.size(); i++) {
numerator = rewriter.create<AddFOp>(
loc, rewriter.create<MulFOp>(loc, input_squared, numerator),
rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(numerator_coeffs[i])));
}
numerator = rewriter.create<MulFOp>(loc, input_clamped, numerator);
Value denominator = rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(denominator_coeffs[0]));
for (int i = 1; i < denominator_coeffs.size(); i++) {
denominator = rewriter.create<AddFOp>(
loc, rewriter.create<MulFOp>(loc, input_squared, denominator),
rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(denominator_coeffs[i])));
}
Value approx = rewriter.create<DivFOp>(loc, numerator, denominator);
return rewriter.create<SelectOp>(loc, return_input, input, approx);
}
class ApproximateTanhLowering : public OpRewritePattern<TanhOp> {
public:
explicit ApproximateTanhLowering(MLIRContext *ctx)
: OpRewritePattern<TanhOp>(ctx, 100) {}
LogicalResult matchAndRewrite(TanhOp tanhOp,
PatternRewriter &rewriter) const override {
Type operand_type = tanhOp.getType();
if (operand_type.isF64()) {
// Similar to XLA, do not rewrite f64 as precision might matter.
return failure();
}
Location loc = tanhOp.getLoc();
Value input = tanhOp.operand();
if (operand_type.isF16()) {
input = rewriter.create<FPExtOp>(loc, input, rewriter.getF32Type());
}
// If we still do not have f32, fail.
if (!input.getType().isF32()) {
return failure();
}
Value result = EmitTanhApproximation(input, loc, rewriter);
// Truncate back if needed.
if (operand_type.isF16()) {
result = rewriter.create<FPTruncOp>(loc, result, rewriter.getF16Type());
}
rewriter.replaceOp(tanhOp, {result});
return success();
}
};
struct LegalizeTanhToApproximation
: public PassWrapper<LegalizeTanhToApproximation, FunctionPass> {
/// Perform the lowering of standard dialect operations to approximations.
void runOnFunction() override {
OwningRewritePatternList patterns;
PopulateTanhToApproximationPatterns(&getContext(), &patterns);
applyPatternsAndFoldGreedily(getFunction(), patterns);
}
};
} // anonymous namespace
std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
createLegalizeTanhToApproximationPass() {
return std::make_unique<LegalizeTanhToApproximation>();
}
void PopulateTanhToApproximationPatterns(mlir::MLIRContext *context,
OwningRewritePatternList *patterns) {
patterns->insert<ApproximateTanhLowering>(context);
}
static PassRegistration<LegalizeTanhToApproximation> legalize_pass(
"xla-legalize-tanh-to-approximation",
"Legalize tanh from standard dialect to an approximation");
} // namespace xla
} // namespace mlir

View File

@ -0,0 +1,208 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements logic for lowering XLA dialect to Standard dialect.
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir {
namespace {
#include "third_party/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_legalize_to_standard.inc"
} // end anonymous namespace
namespace xla_hlo {
namespace {
class CompareIConvert : public OpRewritePattern<xla_hlo::CompareOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(xla_hlo::CompareOp op,
PatternRewriter &rewriter) const override {
auto lhs = op.lhs();
auto rhs = op.rhs();
auto lhs_type = lhs.getType().cast<TensorType>();
auto rhs_type = rhs.getType().cast<TensorType>();
// Broadcasting not supported by this rewrite.
if (lhs_type.getShape() != rhs_type.getShape()) return failure();
if (!lhs_type.getElementType().isSignlessInteger() ||
!rhs_type.getElementType().isSignlessInteger())
return failure();
auto comparison_direction = op.comparison_direction();
auto compare_predicate =
llvm::StringSwitch<Optional<CmpIPredicate>>(comparison_direction)
.Case("EQ", CmpIPredicate::eq)
.Case("NE", CmpIPredicate::ne)
.Case("LT", CmpIPredicate::slt)
.Case("LE", CmpIPredicate::sle)
.Case("GT", CmpIPredicate::sgt)
.Case("GE", CmpIPredicate::sge)
.Default(llvm::None);
if (!compare_predicate.hasValue()) return failure();
rewriter.replaceOpWithNewOp<CmpIOp>(op, compare_predicate.getValue(), lhs,
rhs);
return success();
}
};
class CompareFConvert : public OpRewritePattern<xla_hlo::CompareOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(xla_hlo::CompareOp op,
PatternRewriter &rewriter) const override {
auto lhs = op.lhs();
auto rhs = op.rhs();
auto lhs_type = lhs.getType().cast<TensorType>();
auto rhs_type = rhs.getType().cast<TensorType>();
// Broadcasting not supported by this rewrite.
if (lhs_type.getShape() != rhs_type.getShape()) return failure();
if (!lhs_type.getElementType().isa<FloatType>() ||
!rhs_type.getElementType().isa<FloatType>())
return failure();
auto comparison_direction = op.comparison_direction();
auto compare_predicate =
llvm::StringSwitch<Optional<CmpFPredicate>>(comparison_direction)
.Case("EQ", CmpFPredicate::OEQ)
.Case("NE", CmpFPredicate::UNE)
.Case("LT", CmpFPredicate::OLT)
.Case("LE", CmpFPredicate::OLE)
.Case("GT", CmpFPredicate::OGT)
.Case("GE", CmpFPredicate::OGE)
.Default(llvm::None);
if (!compare_predicate.hasValue()) return failure();
rewriter.replaceOpWithNewOp<CmpFOp>(op, compare_predicate.getValue(), lhs,
rhs);
return success();
}
};
// Replace IotaOp with an integer constant. A ConvertOp is added to
// convert the integer constant to iota result type. For complex types, the real
// part is replaced with the generated constant and the imaginary part is
// replaced with zero tensor.
class ConvertIotaOp : public OpRewritePattern<xla_hlo::IotaOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(xla_hlo::IotaOp op,
PatternRewriter &rewriter) const override {
auto output_type = op.getType().cast<ShapedType>();
auto output_size = output_type.getNumElements();
auto dimension = op.iota_dimension().getSExtValue();
auto max_dim_size = output_type.getDimSize(dimension);
auto element_type = output_type.getElementType();
int bitwidth;
auto complex_ty = element_type.dyn_cast<ComplexType>();
Type int_or_float_ty = element_type;
if (complex_ty) int_or_float_ty = complex_ty.getElementType();
bitwidth = int_or_float_ty.getIntOrFloatBitWidth();
llvm::SmallVector<APInt, 10> values;
values.reserve(output_size);
int64_t increase_stride = output_size;
for (int i = 0; i <= dimension; i++) {
increase_stride /= output_type.getDimSize(i);
}
int64_t current_value = 0;
for (int i = 0; i < output_size; i++) {
int64_t value = (current_value / increase_stride) % max_dim_size;
values.push_back(APInt(bitwidth, value));
++current_value;
}
auto int_shape_type = RankedTensorType::get(
output_type.getShape(),
IntegerType::get(bitwidth, rewriter.getContext()));
auto loc = op.getLoc();
auto integer_const = rewriter.create<mlir::ConstantOp>(
loc, DenseIntElementsAttr::get(int_shape_type, values));
auto int_or_float_shape_ty =
RankedTensorType::get(output_type.getShape(), int_or_float_ty);
auto iota_const =
rewriter.create<ConvertOp>(loc, int_or_float_shape_ty, integer_const);
// For int/float types we are done, replace op and return.
if (!complex_ty) {
rewriter.replaceOp(op, iota_const.getResult());
return success();
}
// For complex types, generate a constant tensor of zeroes for the imaginary
// part and use iota_const for real part.
auto zeroes = rewriter.create<mlir::ConstantOp>(
loc, DenseIntElementsAttr::get(int_shape_type, APInt(bitwidth, 0)));
auto imag_zeroes =
rewriter.create<ConvertOp>(loc, int_or_float_shape_ty, zeroes);
rewriter.replaceOpWithNewOp<xla_hlo::ComplexOp>(op, iota_const,
imag_zeroes);
return success();
}
};
} // end anonymous namespace
namespace {
struct LegalizeToStandard
: public PassWrapper<LegalizeToStandard, FunctionPass> {
/// Perform the lowering to Standard dialect.
void runOnFunction() override;
};
} // end anonymous namespace
std::unique_ptr<mlir::OperationPass<mlir::FuncOp>> createLegalizeToStdPass() {
return std::make_unique<LegalizeToStandard>();
}
void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns,
mlir::MLIRContext *ctx) {
mlir::populateWithGenerated(ctx, patterns);
patterns->insert<CompareFConvert, CompareIConvert, ConvertIotaOp>(ctx);
}
/// Perform the lowering to standard dialect.
void LegalizeToStandard::runOnFunction() {
OwningRewritePatternList patterns;
mlir::xla_hlo::PopulateXlaToStdPatterns(&patterns, &getContext());
applyPatternsAndFoldGreedily(getFunction(), patterns);
}
static PassRegistration<LegalizeToStandard> legalize_pass(
"xla-legalize-to-std", "Legalize from XLA dialect to standard dialect");
} // end namespace xla_hlo
} // end namespace mlir

View File

@ -0,0 +1,71 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the legalization pattern definition file for XLA to StandardOps.
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td"
include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td"
include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
//===----------------------------------------------------------------------===//
// Nullary op patterns.
//===----------------------------------------------------------------------===//
def : Pat<(HLO_ConstOp ElementsAttr:$value),
(ConstantOp $value)>;
//===----------------------------------------------------------------------===//
// Binary op patterns.
//===----------------------------------------------------------------------===//
def IsSameSizePred : CPred<
"$0.getType().cast<ShapedType>().getShape() "
"== $1.getType().cast<ShapedType>().getShape()">;
def IsSameSizeConstraint : Constraint<IsSameSizePred, "inputs are same size">;
def : Pat<(HLO_AndOp HLO_PredTensor:$l, HLO_PredTensor:$r),
(AndOp $l, $r),
[(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_AddOp HLO_FpTensor:$l, HLO_FpTensor:$r),
(AddFOp $l, $r),
[(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_SubOp HLO_FpTensor:$l, HLO_FpTensor:$r),
(SubFOp $l, $r),
[(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_MulOp HLO_FpTensor:$l, HLO_FpTensor:$r),
(MulFOp $l, $r),
[(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_DivOp HLO_FpTensor:$l, HLO_FpTensor:$r),
(DivFOp $l, $r),
[(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_RemOp HLO_FpTensor:$l, HLO_FpTensor:$r),
(RemFOp $l, $r),
[(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_AddOp HLO_IntTensor:$l, HLO_IntTensor:$r),
(AddIOp $l, $r),
[(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_SubOp HLO_IntTensor:$l, HLO_IntTensor:$r),
(SubIOp $l, $r),
[(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_MulOp HLO_IntTensor:$l, HLO_IntTensor:$r),
(MulIOp $l, $r),
[(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_DivOp HLO_IntTensor:$l, HLO_IntTensor:$r),
(SignedDivIOp $l, $r),
[(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_RemOp HLO_IntTensor:$l, HLO_IntTensor:$r),
(SignedRemIOp $l, $r),
[(IsSameSizeConstraint $l, $r)]>;

View File

@ -0,0 +1,105 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements a pass to remove redundant LHLO copy operations.
#include "third_party/absl/memory/memory.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
namespace mlir {
namespace xla_lhlo {
namespace {
// Removes LHLO copy operations that copy from allocated buffers to block
// arguments. All uses of each buffer are replaced with the corresponding block
// argument and the buffer is freed. Note that this pass only works in regions
// with a single block.
struct LhloCopyRemoval : mlir::PassWrapper<LhloCopyRemoval, OperationPass<>> {
void runOnOperation() override {
llvm::SmallVector<mlir::Operation*, 2> eraseList;
auto operation = getOperation();
operation->walk([&](mlir::xla_lhlo::CopyOp copyOp) {
// If this region contains more than one block, then ignore this copy
// operation.
if (copyOp.getParentRegion()->getBlocks().size() > 1) {
return;
}
mlir::Value fromOperand = copyOp.operand();
mlir::Value toOperand = copyOp.output();
// If the fromOperand value is a block argument or the toOperand
// value is not a block argument, then ignore this copy operation.
if (!fromOperand.getDefiningOp() || toOperand.getDefiningOp()) {
return;
}
// The copy operation removal is illegal if there is at least a single use
// of toOperand value that lies between the first use of fromOperand value
// and the copy operation.
auto fromOperandUsers = fromOperand.getUsers();
auto firstUser = *fromOperandUsers.begin();
for (auto op : fromOperandUsers) {
if (op->isBeforeInBlock(firstUser)) firstUser = op;
}
for (auto op : toOperand.getUsers()) {
if (op->isBeforeInBlock(copyOp) && firstUser->isBeforeInBlock(op)) {
return;
}
}
// TODO(DFKI): Use live variable analysis to solve aliasing issues among
// block arguments.
// Remove the associated alloc operation.
auto allocOp = fromOperand.getDefiningOp();
eraseList.push_back(allocOp);
// Iterate over all uses of the fromOperand to find the associated
// deallocOp (if any).
for (auto op : fromOperandUsers) {
if (isa<mlir::DeallocOp>(op)) {
eraseList.push_back(op);
break;
}
}
// Replace all uses of the fromOperand with the toOperand. This rewires
// all references pointing to the original alloc operation to the new
// target operation in order to safely remove the copy op.
fromOperand.replaceAllUsesWith(toOperand);
copyOp.erase();
});
for (auto op : eraseList) {
op->erase();
}
};
};
} // namespace
std::unique_ptr<Pass> createLhloCopyRemovalPass() {
return absl::make_unique<LhloCopyRemoval>();
}
static PassRegistration<LhloCopyRemoval> copy_removal_pass(
"lhlo-copy-removal", "Removes redundant LHLO copy operations");
} // namespace xla_lhlo
} // namespace mlir

View File

@ -0,0 +1,151 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements logic for fusing linalg ops obtained after LHLO
// lowering.
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "third_party/absl/memory/memory.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/FoldUtils.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
namespace mlir {
namespace xla_lhlo {
namespace {
using linalg::LinalgOp;
class LhloFuseLinalg : public PassWrapper<LhloFuseLinalg, FunctionPass> {
public:
LhloFuseLinalg() = default;
LhloFuseLinalg(const LhloFuseLinalg&) {}
LhloFuseLinalg(bool use_parallel_loops, llvm::ArrayRef<unsigned> tile_sizes) {
tile_sizes_ = tile_sizes;
use_parallel_loops_.setValue(use_parallel_loops);
}
void runOnFunction() override {
auto func = getFunction();
// TODO(pifon): Remove assumption that the function has a single block.
if (!llvm::hasSingleElement(func)) {
emitError(func.getLoc(), "The function needs to have a single block.");
signalPassFailure();
return;
}
// The fusion in Linalg is currently possible only when the consumer op is
// tiled. In order to greedily fuse the ops, we have to start from the tiled
// root linalg ops, i.e. linalg ops that write to output buffers of the
// function or are returned in case of escaping allocations.
llvm::SmallDenseSet<Value> result_buffers;
for (auto func_arg : func.getArguments()) {
result_buffers.insert(func_arg);
}
for (auto& block : func) {
auto returnOp = mlir::dyn_cast<mlir::ReturnOp>(block.getTerminator());
if (!returnOp) continue;
for (auto operand : returnOp.getOperands()) {
result_buffers.insert(operand);
}
}
MLIRContext* ctx = func.getContext();
OpBuilder b(func);
OperationFolder folder(ctx);
func.walk([&](linalg::GenericOp generic_op) {
SmallVector<int64_t, 2> tile_sizes(tile_sizes_.begin(),
tile_sizes_.end());
if (tile_sizes.empty()) {
tile_sizes = SmallVector<int64_t, 2>(generic_op.getNumLoops(), 1);
}
auto op = cast<LinalgOp>(generic_op.getOperation());
for (const Value result : op.getOutputBuffers()) {
if (!result_buffers.count(result)) continue;
if (tileGenericOp(op, tile_sizes, &b)) {
generic_op.erase();
return;
}
}
});
auto patterns = linalg::getLinalgTilingCanonicalizationPatterns(ctx);
applyPatternsAndFoldGreedily(func, patterns);
// Fuse producers of tiled linalg ops.
llvm::SmallDenseSet<Operation*> erase_set;
SmallVector<Operation*, 8> linalg_ops;
func.walk([&](LinalgOp op) { linalg_ops.push_back(op); });
for (auto* op : llvm::reverse(linalg_ops)) {
for (unsigned id = 0, e = LinalgOp(op).getNumInputs(); id < e; ++id) {
linalg::Aliases aliases;
linalg::LinalgDependenceGraph graph(aliases, linalg_ops);
if (auto info = fuseProducerOf(b, op, id, graph, &folder)) {
auto originalOp = info->originalProducer.getOperation();
erase_set.insert(originalOp);
auto originalOpInLinalgOpsVector = std::find_if(
linalg_ops.begin(), linalg_ops.end(),
[&](const Operation* op) { return op == originalOp; });
*originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
}
}
auto patterns = linalg::getLinalgTilingCanonicalizationPatterns(ctx);
applyPatternsAndFoldGreedily(func, patterns);
}
for (auto* e : erase_set) e->erase();
}
private:
bool tileGenericOp(LinalgOp op, ArrayRef<int64_t> tile_sizes, OpBuilder* b) {
auto loopType = use_parallel_loops_
? linalg::LinalgTilingLoopType::ParallelLoops
: linalg::LinalgTilingLoopType::Loops;
auto tiled_generic_op = linalg::tileLinalgOp(*b, op,
linalg::LinalgTilingOptions()
.setTileSizes(tile_sizes)
.setLoopType(loopType));
return tiled_generic_op.hasValue();
}
Option<bool> use_parallel_loops_{
*this, "use-parallel-loops",
llvm::cl::desc(
"Tiles GenericOp consumer to parallel loops before linalg fusion"),
llvm::cl::init(false)};
ListOption<unsigned> tile_sizes_{
*this, "tile-sizes",
llvm::cl::desc(
"Tile sizes by which to tile linalg generic before linalg fusion"),
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>> createLhloFuseLinalg(
bool use_parallel_loops, ArrayRef<unsigned> tile_sizes) {
return absl::make_unique<LhloFuseLinalg>(use_parallel_loops, tile_sizes);
}
static PassRegistration<LhloFuseLinalg> legalize_pass(
"lhlo-fuse-linalg",
"Greedily fuse linalg ops obtained after LHLO lowering.");
} // namespace xla_lhlo
} // namespace mlir

View File

@ -0,0 +1,161 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements logic for lowering LHLO dialect to Affine dialect.
#include "third_party/absl/memory/memory.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h"
namespace mlir {
namespace xla_lhlo {
namespace {
// Builds an affine loop nest iterating from zeros to "upper_bounds" with unit
// steps, and populates the body of the innermost loop using "body_builder".
static void BuildBoundedAffineLoopNest(
OpBuilder& builder, Location location, ArrayRef<int64_t> upper_bounds,
function_ref<void(OpBuilder&, Location, ValueRange)> body_builder) {
SmallVector<int64_t, 3> lower_bounds(upper_bounds.size(), /*Value=*/0);
SmallVector<int64_t, 3> steps(upper_bounds.size(), /*Value=*/1);
buildAffineLoopNest(builder, location, lower_bounds, upper_bounds, steps,
body_builder);
}
struct DotOpConverter : public OpRewritePattern<DotOp> {
using OpRewritePattern<DotOp>::OpRewritePattern;
// Supports only rank-2 tensors for LHS and RHS.
LogicalResult matchAndRewrite(DotOp op,
PatternRewriter& rewriter) const override {
Value lhs = op.lhs();
Value rhs = op.rhs();
MemRefType lhs_type = lhs.getType().cast<MemRefType>();
MemRefType rhs_type = rhs.getType().cast<MemRefType>();
Type element_type = lhs_type.getElementType();
ArrayRef<int64_t> shape_lhs = lhs_type.getShape();
ArrayRef<int64_t> shape_rhs = rhs_type.getShape();
if ((lhs_type.getRank() != 2) || (rhs_type.getRank() != 2)) {
return failure();
}
LogicalResult map_status = success();
auto body_builder = [&](OpBuilder& builder, Location loc, ValueRange ivs) {
SmallVector<Value, 2> lhs_indices{ivs[0], ivs[2]},
rhs_indices{ivs[2], ivs[1]}, result_indices{ivs[0], ivs[1]};
auto l = builder.create<AffineLoadOp>(loc, lhs, lhs_indices);
auto r = builder.create<AffineLoadOp>(loc, rhs, rhs_indices);
auto result =
rewriter.create<AffineLoadOp>(loc, op.output(), result_indices);
Value op_result = xla_lhlo::XlaOpToStdScalarOp::map<DotOp>(
op, element_type, {l, r, result}, &builder);
map_status = success(op_result != nullptr);
if (failed(map_status)) return;
builder.create<AffineStoreOp>(loc, op_result, op.output(),
result_indices);
};
BuildBoundedAffineLoopNest(rewriter, op.getLoc(),
{shape_lhs[0], shape_rhs[1], shape_rhs[0]},
body_builder);
if (failed(map_status)) return failure();
rewriter.eraseOp(op);
return success();
}
};
template <typename LhloOpTy>
struct BinaryOpConverter : public OpRewritePattern<LhloOpTy> {
using OpRewritePattern<LhloOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(LhloOpTy op,
PatternRewriter& rewriter) const override {
const auto& lhs = op.lhs();
const auto& rhs = op.rhs();
const auto& lhs_type = lhs.getType().template cast<MemRefType>();
const auto& rhs_type = rhs.getType().template cast<MemRefType>();
const auto& element_type = lhs_type.getElementType();
if (lhs_type.getShape() != rhs_type.getShape()) {
return failure();
}
LogicalResult map_status = success();
auto body_builder = [&](OpBuilder& builder, Location loc,
ValueRange induction_vars) {
auto l = builder.create<AffineLoadOp>(loc, lhs, induction_vars);
auto r = builder.create<AffineLoadOp>(loc, rhs, induction_vars);
Value op_result = xla_lhlo::XlaOpToStdScalarOp::map<LhloOpTy>(
op, element_type, {l, r}, &builder);
map_status = success(op_result != nullptr);
if (failed(map_status)) return;
rewriter.create<AffineStoreOp>(loc, op_result, op.out(), induction_vars);
};
BuildBoundedAffineLoopNest(rewriter, op.getLoc(), lhs_type.getShape(),
body_builder);
if (failed(map_status)) return failure();
rewriter.eraseOp(op);
return success();
}
};
void populateLHLOToAffineConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) {
// clang-format off
patterns->insert<
BinaryOpConverter<xla_lhlo::AddOp>,
BinaryOpConverter<xla_lhlo::AndOp>,
BinaryOpConverter<xla_lhlo::DivOp>,
BinaryOpConverter<xla_lhlo::MaxOp>,
BinaryOpConverter<xla_lhlo::MinOp>,
BinaryOpConverter<xla_lhlo::MulOp>,
BinaryOpConverter<xla_lhlo::SubOp>,
DotOpConverter>(context);
// clang-format on
}
struct LhloLegalizeToAffine
: public PassWrapper<LhloLegalizeToAffine, FunctionPass> {
void runOnFunction() override {
OwningRewritePatternList patterns;
auto func = getFunction();
populateLHLOToAffineConversionPattern(func.getContext(), &patterns);
applyPatternsAndFoldGreedily(func, patterns);
}
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToAffinePass() {
return absl::make_unique<LhloLegalizeToAffine>();
}
static PassRegistration<LhloLegalizeToAffine> legalize_pass(
"lhlo-legalize-to-affine", "Legalize from LHLO dialect to affine dialect");
} // namespace xla_lhlo
} // namespace mlir

View File

@ -0,0 +1,196 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements logic for lowering LHLO dialect to GPU dialect.
#include <cstdint>
#include "third_party/absl/memory/memory.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/GPU/GPUDialect.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SCF/SCF.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BlockAndValueMapping.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h"
namespace mlir {
namespace xla_lhlo {
namespace {
// A simple translation of LHLO reduce operations to a corresponding gpu
// launch operation. The transformation does no tiling and also only supports
// 1d results.
class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
ReduceOp reduce_op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
auto loc = reduce_op.getLoc();
// Only support 1d reductions for now.
int64_t size = 0;
for (auto result : reduce_op.out()) {
auto shaped_type = result.getType().dyn_cast<ShapedType>();
if (!shaped_type || shaped_type.getRank() != 1) {
return failure();
}
auto dim_size = shaped_type.getDimSize(0);
if (size && size != dim_size) {
return failure();
}
size = dim_size;
}
auto reducing_dimension = *reduce_op.dimensions().int_value_begin();
// Require all inputs to have the same shape.
int64_t reduce_dim_size = 0;
for (auto input : reduce_op.operands()) {
auto shaped_type = input.getType().dyn_cast<ShapedType>();
if (!shaped_type || !shaped_type.hasStaticShape()) {
return failure();
}
reduce_dim_size =
shaped_type.getDimSize(reducing_dimension.getSExtValue());
}
// Create a launch that is parallel in the result dimension.
auto block_size_x = rewriter.create<mlir::ConstantOp>(
loc, rewriter.getIndexType(),
rewriter.getIntegerAttr(rewriter.getIndexType(), size));
auto one = rewriter.create<mlir::ConstantOp>(
loc, rewriter.getIndexType(),
rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
auto launch_op = rewriter.create<mlir::gpu::LaunchOp>(
loc, one, one, one, block_size_x, one, one);
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToEnd(&launch_op.body().front());
auto index = launch_op.getThreadIds().x;
// Load the initial value and store it to the output.
for (auto pair : llvm::zip(reduce_op.init_values(), reduce_op.out())) {
auto init_value = rewriter.create<mlir::LoadOp>(loc, std::get<0>(pair));
rewriter.create<mlir::StoreOp>(loc, init_value, std::get<1>(pair),
ArrayRef<Value>{index});
}
// Insert a loop into the body to compute the reduction. The loop ranges
// from [0.dim).
auto zero = rewriter.create<mlir::ConstantOp>(
loc, rewriter.getIndexType(),
rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
// TODO(b/137624192) Use dimOp to make it shape independent.
auto upper = rewriter.create<mlir::ConstantOp>(
loc, rewriter.getIndexType(),
rewriter.getIntegerAttr(rewriter.getIndexType(), reduce_dim_size));
auto step = rewriter.create<mlir::ConstantOp>(
loc, rewriter.getIndexType(),
rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
auto loop = rewriter.create<mlir::scf::ForOp>(loc, zero, upper, step);
rewriter.setInsertionPointToStart(loop.getBody());
// Compute memrefs for the value to reduce. This makes it easier to just
// inline the body.
auto output = *reduce_op.out().begin();
// TODO(herhut) Move this to the SliceOp builder.
auto resType = MemRefType::get(
llvm::None, output.getType().cast<MemRefType>().getElementType(),
makeStridedLinearLayoutMap(llvm::None,
MemRefType::getDynamicStrideOrOffset(),
rewriter.getContext()));
auto accumulator = rewriter.create<mlir::linalg::SliceOp>(
loc, resType, output, ArrayRef<Value>{launch_op.getThreadIds().x});
llvm::SmallVector<Value, 4> indexings;
auto input_buffer = *reduce_op.operands().begin();
auto input_type = input_buffer.getType().cast<MemRefType>();
for (int64_t dim = 0; dim < input_type.getRank(); ++dim) {
indexings.push_back(dim == reducing_dimension
? loop.getInductionVar()
: launch_op.getThreadIds().x);
}
// TODO(herhut) Move this to the SliceOp builder.
auto input = *reduce_op.operand_begin();
auto rhs = rewriter.create<mlir::linalg::SliceOp>(
loc,
MemRefType::get(
llvm::None, input_type.getElementType(),
makeStridedLinearLayoutMap(llvm::None,
MemRefType::getDynamicStrideOrOffset(),
rewriter.getContext())),
input, indexings);
// Now copy over the actual body of the reduction, leaving out the
// terminator.
BlockAndValueMapping mapping;
mapping.map(reduce_op.body().front().getArgument(0), accumulator);
mapping.map(reduce_op.body().front().getArgument(1), rhs);
mapping.map(reduce_op.body().front().getArgument(2), accumulator);
for (auto& nested : reduce_op.body().front().without_terminator()) {
auto clone = rewriter.clone(nested, mapping);
for (auto pair : llvm::zip(nested.getResults(), clone->getResults())) {
mapping.map(std::get<0>(pair), std::get<1>(pair));
}
}
// Finally, insert the terminator for the launchOp.
rewriter.setInsertionPointToEnd(&launch_op.body().front());
rewriter.create<mlir::gpu::TerminatorOp>(loc);
}
rewriter.eraseOp(reduce_op);
return success();
};
};
struct LhloLegalizeToGpu : public PassWrapper<LhloLegalizeToGpu, FunctionPass> {
void runOnFunction() override {
OwningRewritePatternList patterns;
ConversionTarget target(getContext());
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
gpu::GPUDialect, scf::SCFDialect, XlaLhloDialect>();
target.addIllegalOp<ReduceOp>();
auto func = getFunction();
patterns.insert<LhloReduceToGPULaunchConverter>(func.getContext());
if (failed(applyPartialConversion(func, target, patterns))) {
signalPassFailure();
}
}
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToGpuPass() {
return absl::make_unique<LhloLegalizeToGpu>();
}
static PassRegistration<LhloLegalizeToGpu> legalize_pass(
"lhlo-legalize-to-gpu", "Legalize from LHLO dialect to GPU dialect");
} // namespace xla_lhlo
} // namespace mlir

View File

@ -0,0 +1,136 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/llvm/llvm-project/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
namespace mlir {
namespace xla_lhlo {
namespace {
struct StaticMemRefCastOpConverter
: public ConvertOpToLLVMPattern<StaticMemRefCastOp> {
using ConvertOpToLLVMPattern<StaticMemRefCastOp>::ConvertOpToLLVMPattern;
LogicalResult matchAndRewrite(
Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto cast_op = cast<StaticMemRefCastOp>(op);
StaticMemRefCastOp::Adaptor operands_adaptor(operands);
MemRefDescriptor sourceMemRef(operands_adaptor.operand());
MemRefType targetMemRefType =
cast_op.getResult().getType().cast<MemRefType>();
auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
.dyn_cast_or_null<LLVM::LLVMType>();
if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
return failure();
// Create descriptor.
auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
Type llvmTargetElementTy = desc.getElementType();
// Set allocated ptr.
Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
allocated =
rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
desc.setAllocatedPtr(rewriter, loc, allocated);
// Set aligned ptr.
Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
desc.setAlignedPtr(rewriter, loc, ptr);
// Fill size and stride descriptors in memref.
auto target_sizes = targetMemRefType.getShape();
int64_t target_offset;
llvm::SmallVector<int64_t, 4> target_strides;
if (failed((getStridesAndOffset(targetMemRefType, target_strides,
target_offset))))
return failure();
// Copy offset of `targetMemRef`.
desc.setConstantOffset(rewriter, loc, target_offset);
for (int i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
desc.setConstantSize(rewriter, loc, i, target_sizes[i]);
desc.setConstantStride(rewriter, loc, i, target_strides[i]);
}
rewriter.replaceOp(op, {desc});
return success();
}
};
struct DynamicMemRefCastOpConverter
: public ConvertOpToLLVMPattern<DynamicMemRefCastOp> {
using ConvertOpToLLVMPattern<DynamicMemRefCastOp>::ConvertOpToLLVMPattern;
LogicalResult matchAndRewrite(
Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto cast_op = cast<DynamicMemRefCastOp>(op);
DynamicMemRefCastOp::Adaptor operands_adaptor(operands);
MemRefDescriptor sourceMemRef(operands_adaptor.operand());
MemRefType targetMemRefType =
cast_op.getResult().getType().cast<MemRefType>();
auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
.dyn_cast_or_null<LLVM::LLVMType>();
if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
return failure();
// Create descriptor.
auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
Type llvmTargetElementTy = desc.getElementType();
// Set allocated ptr.
Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
allocated =
rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
desc.setAllocatedPtr(rewriter, loc, allocated);
// Set aligned ptr.
Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
desc.setAlignedPtr(rewriter, loc, ptr);
// Copy offset of `sourceMemRef`.
desc.setOffset(rewriter, loc, sourceMemRef.offset(rewriter, loc));
// Fill size and stride descriptors in memref.
if (!cast_op.sizes().empty()) {
auto sizes = operands_adaptor.sizes();
auto strides = operands_adaptor.strides();
for (int i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
desc.setSize(rewriter, loc, i, sizes[i]);
desc.setStride(rewriter, loc, i, strides[i]);
}
}
rewriter.replaceOp(op, {desc});
return success();
}
};
} // namespace
void PopulateLhloToLLVMConversionPatterns(const LowerToLLVMOptions &options,
LLVMTypeConverter *converter,
OwningRewritePatternList *patterns) {
patterns->insert<DynamicMemRefCastOpConverter, StaticMemRefCastOpConverter>(
*converter, options);
}
} // namespace xla_lhlo
} // namespace mlir

View File

@ -0,0 +1,59 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/llvm/llvm-project/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir {
namespace xla_lhlo {
namespace {
class TestLhloToLLVMPass
: public ::mlir::PassWrapper<TestLhloToLLVMPass,
::mlir::OperationPass<::mlir::ModuleOp>> {
public:
void runOnOperation() override {
ModuleOp m = getOperation();
OwningRewritePatternList patterns;
LLVMTypeConverter converter(m.getContext());
populateStdToLLVMConversionPatterns(converter, patterns);
PopulateLhloToLLVMConversionPatterns(
LowerToLLVMOptions::getDefaultOptions(), &converter, &patterns);
ConversionTarget target(getContext());
target.addLegalDialect<LLVM::LLVMDialect>();
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
target.addIllegalDialect<XlaLhloDialect>();
if (failed(applyFullConversion(m, target, patterns))) {
signalPassFailure();
}
}
};
} // namespace
static PassRegistration<TestLhloToLLVMPass> legalize_lhlo_pass(
"test-lhlo-legalize-to-llvm", "Legalize from LHLO dialect to LLVM.");
} // namespace xla_lhlo
} // namespace mlir

View File

@ -0,0 +1,731 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/absl/memory/memory.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/SmallVector.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SCF/SCF.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
namespace mlir {
namespace xla_lhlo {
namespace {
// Clones and adapts the code in `lhlo_block` that works on buffers and has a
// single output buffer to make it compatible with `operands` that have element
// types of the respective buffers. Returns the computed value.
//
// Example. For `operands` with (f32, i32) types and a block with LHLO ops and
// with signature:
// ^bb(%lhs: memref<f32>, %rhs: memref<i32>, %res: memref<i1>):
// <LHLO_ops>
//
// inserts necessary alloc and store ops to compute and return result that has
// `i1` type.
Value ApplySingleResultLhloCode(Location loc, ValueRange operands,
Block* lhlo_block, OpBuilder* b) {
SmallVector<Value, 2> arg_bufs;
for (auto arg_type : lhlo_block->getArgumentTypes()) {
arg_bufs.push_back(b->create<AllocOp>(loc, arg_type.cast<MemRefType>()));
}
for (auto operand : llvm::enumerate(operands)) {
b->create<StoreOp>(loc, operand.value(), arg_bufs[operand.index()]);
}
// Clone the ops from `lhlo_block`.
BlockAndValueMapping mapping;
mapping.map(lhlo_block->getArguments(), arg_bufs);
for (auto& nested : lhlo_block->without_terminator()) {
auto clone = b->clone(nested, mapping);
mapping.map(nested.getResults(), clone->getResults());
}
return b->create<LoadOp>(loc, arg_bufs.back());
}
// Converts a block with LHLO ops and with signature:
// ^bb(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
// into a reduction operator of scf.reduce by doing buffer allocation for
// scalar arguments and the result of `scf.reduce` to make it compatible with
// LHLO ops.
void ConvertToReductionOperator(Location loc, scf::ReduceOp reduce_op,
Block* lhlo_block, OpBuilder* b) {
Block& loop_reduce_op_body = reduce_op.reductionOperator().front();
OpBuilder::InsertionGuard guard(*b);
b->setInsertionPointToStart(&loop_reduce_op_body);
b->create<scf::ReduceReturnOp>(
loc, ApplySingleResultLhloCode(loc, loop_reduce_op_body.getArguments(),
lhlo_block, b));
}
// Returns result of ConstantOp if `dim` is static, otherwise uses DimOp to
// extract dimension at runtime.
Value GetStaticOrDynamicDim(mlir::Location loc, Value shaped_value,
size_t dim_index, int64_t dim, OpBuilder* b) {
return dim == ShapedType::kDynamicSize
? b->create<DimOp>(loc, shaped_value, dim_index).getResult()
: b->create<ConstantIndexOp>(loc, dim);
}
struct MappedIvs {
// False if the mapped indices are in the padding area, true otherwise.
Value in_bounds;
// Mapped indices.
SmallVector<Value, 2> ivs;
};
template <typename OpTy>
MappedIvs MapWindowIvsToInput(OpTy op, ValueRange ivs, ValueRange window_ivs,
OpBuilder* b) {
MappedIvs mapped_ivs;
if (!op.window_strides().hasValue()) {
op.emitOpError("No window strides specified.");
}
auto window_strides = op.window_strides().getValue();
if (!op.padding().hasValue()) {
op.emitOpError("No padding specified.");
}
auto padding = op.padding().getValue();
auto loc = op.getLoc();
auto operand = op.operand();
auto operand_shape = operand.getType().template cast<MemRefType>().getShape();
// `in_bounds` is false when the mapped indices are in the padding area.
mapped_ivs.in_bounds = b->create<mlir::ConstantOp>(
loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 1));
for (unsigned i = 0, e = ivs.size(); i < e; ++i) {
auto stride = window_strides.template getValue<llvm::APInt>(i);
auto pad_low = padding.template getValue<llvm::APInt>({i, 0});
Value stride_val = b->create<ConstantIndexOp>(loc, stride.getSExtValue());
Value pad_low_val = b->create<ConstantIndexOp>(loc, pad_low.getSExtValue());
Value center = b->create<MulIOp>(loc, ivs[i], stride_val);
Value offset = b->create<SubIOp>(loc, window_ivs[i], pad_low_val);
Value index = b->create<AddIOp>(loc, center, offset);
Value upper_bound =
GetStaticOrDynamicDim(loc, operand, i, operand_shape[i], b);
// We must check whether 0 <= index_i < shape_i, as otherwise we are in
// the pad and then we have to use the neutral element for reduction.
// Equivalently, it can be computed as the unsigned comparison index_i <
// shape_i, since a negative value wraps to a large positive value.
mapped_ivs.in_bounds = b->create<mlir::AndOp>(
loc, mapped_ivs.in_bounds,
b->create<CmpIOp>(loc, CmpIPredicate::ult, index, upper_bound));
mapped_ivs.ivs.push_back(index);
}
return mapped_ivs;
}
// Returns scf::Parallel over a shaped value with static or dynamic shape.
scf::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value,
OpBuilder* b) {
Value zero = b->create<ConstantIndexOp>(loc, 0);
Value one = b->create<ConstantIndexOp>(loc, 1);
ArrayRef<int64_t> shape =
shaped_value.getType().cast<ShapedType>().getShape();
SmallVector<Value, 2> lower, upper, step;
for (auto dim : llvm::enumerate(shape)) {
upper.push_back(
GetStaticOrDynamicDim(loc, shaped_value, dim.index(), dim.value(), b));
lower.push_back(zero);
step.push_back(one);
}
return b->create<scf::ParallelOp>(loc, lower, upper, step);
}
// Converts `xla_lhlo.ReduceOp` into two scf::ParallelOp and a scf::ReduceOp.
// The outper `ParallelOp` refers to the parallel loops if there are
// any. The inner `ParalleOp` refers to the reduction loops and `ReduceOp`
// contains the reduction operator.
//
// Example:
//
// "xla_lhlo.reduce"(%buffer, %init_buf, %result) ( {
// ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
// <LHLO ops>
// } ) {dimensions = dense<[1]> : tensor<1xi64>}
// : (memref<100x10x5xf32>, memref<f32>, memref<100x5xf32>) -> ()
//
// is roughly converted into:
//
// %init = load %init_buf[] : memref<f32>
// scf.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) {
// %result = scf.parallel (%j) = (%c0) to (%c10) step (%c1) init (%init) {
// %elem_to_reduce = load %buffer[%i, %j, %k] : memref<100x10x5xf32>
// scf.reduce(%elem_to_reduce) {
// ^bb0(%elem: f32, %acc: f32): // no predecessors
// elem_buf = alloc() : memref<f32>
// store %elem, elem_buf[] : memref<f32>
// acc_buf = alloc() : memref<f32>
// store %acc, acc_buf[] : memref<f32>
// <LHLO_ops>
// %acc_result = load acc_buf[] : memref<f32>
// scf.reduce.return %acc_result : f32
// } : f32
// scf.yield
// } : f32
// scf.yield
// }
class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
public:
using OpConversionPattern<xla_lhlo::ReduceOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
xla_lhlo::ReduceOp xla_reduce_op, ArrayRef<Value> /*args*/,
ConversionPatternRewriter& rewriter) const final {
// TODO(b/137624192) Implement variadic reduce.
if (xla_reduce_op.out().size() != 1) return failure();
scf::ReduceOp reduce_op =
CreateReduceOpInNestedParallelLoops(xla_reduce_op, &rewriter);
ConvertToReductionOperator(xla_reduce_op.getLoc(), reduce_op,
&xla_reduce_op.body().front(), &rewriter);
rewriter.replaceOp(xla_reduce_op, llvm::None);
return success();
}
private:
// Creates nested `scf.parallel` ops with `scf.reduce`. The outer ParallelOp
// refers to the parallel dimensions of `xla_reduce_op` if any and the inner
// ParallelOp refers to the reduction dimensions. The scf.reduce op is
// returned.
//
// If the reduction argument is a memref<100x10x5xf32> and the
// reduction is performed along dimension 1 then this method will generate
//
// %init = load %init_buf[] : memref<f32>
// scf.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) {
// %result = scf.parallel (%j) = (%c0) to (%c10) step (%c1) init (%init) {
// %elem_to_reduce = load %buffer[%i, %j, %k] : memref<100x10x5xf32>
// scf.reduce(%elem_to_reduce) {
// <THE BLOCK PTR TO BE RETURNED>
// } : f32
// scf.yield
// } : f32
// scf.yield
// }
scf::ReduceOp CreateReduceOpInNestedParallelLoops(
xla_lhlo::ReduceOp xla_reduce_op,
ConversionPatternRewriter* rewriter) const {
auto loc = xla_reduce_op.getLoc();
DenseSet<int> reducing_dims;
for (const auto& rdim : xla_reduce_op.dimensions().getIntValues()) {
reducing_dims.insert(rdim.getSExtValue());
}
Value operand = *xla_reduce_op.operands().begin();
Value out = *xla_reduce_op.out().begin();
SmallVector<Value, 2> parallel_lower, parallel_upper, parallel_step;
SmallVector<Value, 2> reduce_lower, reduce_upper, reduce_step;
auto operand_shape = operand.getType().cast<MemRefType>().getShape();
for (auto dim : llvm::enumerate(operand_shape)) {
const bool is_reducing_dim = reducing_dims.count(dim.index());
Value ub = GetStaticOrDynamicDim(loc, operand, dim.index(), dim.value(),
rewriter);
Value lb = rewriter->create<ConstantIndexOp>(loc, 0);
Value step = rewriter->create<ConstantIndexOp>(loc, 1);
(is_reducing_dim ? reduce_lower : parallel_lower).push_back(lb);
(is_reducing_dim ? reduce_upper : parallel_upper).push_back(ub);
(is_reducing_dim ? reduce_step : parallel_step).push_back(step);
}
// Load initial value from memref<element_type>.
SmallVector<Value, 1> init_value = {
rewriter->create<LoadOp>(loc, *xla_reduce_op.init_values().begin())};
// Outer ParallelOp is not needed if it is a reduction across all dims.
scf::ParallelOp outer;
if (!parallel_lower.empty()) {
outer = rewriter->create<scf::ParallelOp>(loc, parallel_lower,
parallel_upper, parallel_step);
rewriter->setInsertionPointToStart(outer.getBody());
}
scf::ParallelOp inner = rewriter->create<scf::ParallelOp>(
loc, reduce_lower, reduce_upper, reduce_step, ValueRange(init_value));
Value reduction_result = *inner.getResults().begin();
SmallVector<Value, 1> out_indices;
if (outer != nullptr) {
out_indices.reserve(outer.getNumLoops());
for (Value iv : outer.getInductionVars()) {
out_indices.push_back(iv);
}
} else {
out_indices.push_back(rewriter->create<ConstantIndexOp>(loc, 0));
}
rewriter->create<StoreOp>(loc, reduction_result, out, out_indices);
// Load the element to reduce.
SmallVector<Value, 2> indices;
indices.reserve(operand_shape.size());
if (outer) {
auto inner_ivs_it = inner.getInductionVars().begin();
auto outer_ivs_it = outer.getInductionVars().begin();
for (unsigned i = 0, e = operand_shape.size(); i < e; ++i) {
indices.push_back(reducing_dims.count(i) ? *inner_ivs_it++
: *outer_ivs_it++);
}
} else {
indices = inner.getInductionVars();
}
rewriter->setInsertionPointToStart(inner.getBody());
Value elem = rewriter->create<mlir::LoadOp>(
loc, *xla_reduce_op.operands().begin(), indices);
return rewriter->create<scf::ReduceOp>(loc, elem);
}
};
// Pseudocode:
// for each index O in output
// accumulator = neutral_value
// in_bounds = true
// for each index W in window
// for each dimension i from 0 to rank - 1
// index = O[i] * stride[i] + W[i] - pad_low[i]
// in_bounds = inbounds && (index `ult` shape[i])
// I[i] = index
// if (in_bounds)
// value = input[I]
// else
// value = neutral_value
// accumulator = reduction_operator(output[O], value)
// output[O] = accumulator
//
// Converts `xla_lhlo.ReduceWindowOp` into two scf::ParallelOp and a
// scf::ReduceOp.
// The outper `ParallelOp` refers to the parallel loops that traverese output
// buffer. The inner `ParalleOp` refers to the reduction loops that traverse
// reduction windows and `ReduceOp` contains the reduction operator.
//
// Example:
//
// func @reduce_window(%arg: memref<112x112xf32>,
// %init: memref<f32>,
// %result: memref<56x56xf32>) {
// "xla_lhlo.reduce_window"(%arg, %init, %result) ( {
// ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
// "xla_lhlo.maximum"(%lhs, %rhs, %res)
// : (memref<f32>, memref<f32>, memref<f32>) -> ()
// "xla_lhlo.terminator"() : () -> ()
// }) {
// padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
// window_dimensions = dense<[3, 3]> : tensor<2xi64>,
// window_strides = dense<[2, 2]> : tensor<2xi64>
// } : (memref<112x112xf32>, memref<f32>, memref<56x56xf32>) -> ()
// return
// }
//
// is roughly converted into:
//
// %neutral_elem = load %init_buf[] : memref<f32>
// scf.parallel (%i, %j) = (%c0, %c0) to (%c56, %c56) step (%c1, %c1) {
// %result = scf.parallel (%iw, %jw) = (%c0, %c0)
// to (%c3, %c3) step (%c1, %c1) neutral_elem (%0) -> f32 {
// %in_bounds = <COMPUTE IF INDEX IS IN OPERAND'S pad>
// %elem = load %operand[%computed_i, %computed_j]
// %elem_or_neutral = select %in_bounds, %elem, %neutral_elem : f32
// scf.reduce(%elem_to_reduce) : f32 {
// ^bb0(%arg7: f32, %arg8: f32):
// <LHLO ops>
// }
// scf.yield
// }
// store %result, %output_buffer[%i, %j] : memref<56x56xf32>
// scf.yield
// }
// return
// }
class ReduceWindowOpConverter
: public OpConversionPattern<xla_lhlo::ReduceWindowOp> {
public:
using OpConversionPattern<xla_lhlo::ReduceWindowOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
xla_lhlo::ReduceWindowOp xla_reduce_window_op, ArrayRef<Value> /*args*/,
ConversionPatternRewriter& rewriter) const final {
scf::ParallelOp output_loop, window_loop;
std::tie(output_loop, window_loop) =
CreateParallelLoopsToTraverseOutputAndWindow(xla_reduce_window_op,
&rewriter);
scf::ReduceOp reduce_op = CreateReduceOpInNestedParallelLoops(
xla_reduce_window_op, output_loop, window_loop, &rewriter);
ConvertToReductionOperator(xla_reduce_window_op.getLoc(), reduce_op,
&xla_reduce_window_op.body().front(), &rewriter);
rewriter.replaceOp(xla_reduce_window_op, llvm::None);
return success();
}
private:
std::pair<scf::ParallelOp, scf::ParallelOp>
CreateParallelLoopsToTraverseOutputAndWindow(
xla_lhlo::ReduceWindowOp xla_reduce_window_op,
ConversionPatternRewriter* rewriter) const {
auto loc = xla_reduce_window_op.getLoc();
Value init_value =
rewriter->create<LoadOp>(loc, xla_reduce_window_op.init_value());
Value zero = rewriter->create<ConstantIndexOp>(loc, 0);
Value one = rewriter->create<ConstantIndexOp>(loc, 1);
// Create an outer parallel loop that spans the output of ReduceWindowOp.
Value xla_output = xla_reduce_window_op.out();
auto output_loop = MakeLoopOverShape(loc, xla_output, rewriter);
// Create a nested loop that traverses the window.
SmallVector<Value, 2> window_lower, window_upper, window_step;
rewriter->setInsertionPointToStart(output_loop.getBody());
for (const auto& window_dim : xla_reduce_window_op.window_dimensions()) {
window_step.push_back(one);
window_lower.push_back(zero);
window_upper.push_back(
rewriter->create<ConstantIndexOp>(loc, window_dim.getSExtValue()));
}
auto window_loop = rewriter->create<scf::ParallelOp>(
loc, window_lower, window_upper, window_step, ValueRange(init_value));
Value reduction_result = *window_loop.getResults().begin();
auto output_ivs = output_loop.getInductionVars();
rewriter->create<StoreOp>(loc, reduction_result, xla_output, output_ivs);
return std::make_pair(output_loop, window_loop);
}
scf::ReduceOp CreateReduceOpInNestedParallelLoops(
xla_lhlo::ReduceWindowOp xla_reduce_window_op,
scf::ParallelOp output_loop, scf::ParallelOp window_loop,
ConversionPatternRewriter* rewriter) const {
rewriter->setInsertionPointToStart(window_loop.getBody());
auto loc = xla_reduce_window_op.getLoc();
if (xla_reduce_window_op.base_dilations().hasValue() ||
xla_reduce_window_op.window_dilations().hasValue()) {
xla_reduce_window_op.emitRemark(
"Lowering to parallel loops does not support `base_dilations` or "
"`window_dilations` attributes yet. The attributes will be ignored.");
}
Value xla_operand = xla_reduce_window_op.operand();
auto xla_operand_type = xla_operand.getType().cast<MemRefType>();
// Compute ivs in 'arg' buffer and whether these ivs are in pad area or not.
MappedIvs mapped_ivs = MapWindowIvsToInput(
xla_reduce_window_op, output_loop.getInductionVars(),
window_loop.getInductionVars(), rewriter);
auto elem_or_init = rewriter->create<scf::IfOp>(
loc, xla_operand_type.getElementType(), mapped_ivs.in_bounds,
/*withElseRegion=*/true);
OpBuilder then_builder = elem_or_init.getThenBodyBuilder();
Value elem = then_builder.create<mlir::LoadOp>(
loc, xla_reduce_window_op.operand(), mapped_ivs.ivs);
then_builder.create<scf::YieldOp>(loc, elem);
OpBuilder else_builder = elem_or_init.getElseBodyBuilder();
else_builder.create<scf::YieldOp>(loc, *window_loop.initVals().begin());
return rewriter->create<scf::ReduceOp>(loc,
*elem_or_init.results().begin());
}
};
// See the operation semantics in
// https://www.tensorflow.org/xla/operation_semantics#selectandscatter
//
// Pseudocode:
// scf.parallel(coordinates O in the output):
// output[O] = init
// scf.parallel(coordinates S in the source):
// selected_ivs = 0
// selected_val = 0
// initialized_flag = false
// scf.for (first dim W_1 in the window)
// iter_args (selected_ivs, selected_val, initialized_flag):
// ...
// scf.for (last dim W_N in the window):
// iter_args (selected_ivs, selected_val, initialized_flag):
// I = S * stride + W - pad_low
// if I within bounds of operand:
// if (initialized_flag):
// pred = select(selected_value, operand(I))):
// if (pred)
// selected_value = operand(I)
// selected_index = I
// else
// selected_value = operand(I)
// selected_index = I
// initialized_flag = true
// output(selected_index) = scatter(output(selected_index), source(S))
class SelectAndScatterOpConverter
: public OpConversionPattern<xla_lhlo::SelectAndScatterOp> {
public:
using OpConversionPattern<xla_lhlo::SelectAndScatterOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
xla_lhlo::SelectAndScatterOp s_and_s_op, ArrayRef<Value> /*args*/,
ConversionPatternRewriter& rewriter) const final {
auto loc = s_and_s_op.getLoc();
InitializeOutput(s_and_s_op, &rewriter);
scf::ParallelOp loop_over_src =
MakeLoopOverShape(loc, s_and_s_op.source(), &rewriter);
rewriter.setInsertionPointToStart(loop_over_src.getBody());
// Compute indices of the selected element in the window.
auto selected_ivs = SelectIvs(s_and_s_op, loop_over_src, &rewriter);
// Load `source[selected_ivs]`.
auto src_elem = rewriter.create<LoadOp>(loc, s_and_s_op.source(),
loop_over_src.getInductionVars());
// Compute `out[selected_ivs]` = scatter(out[selected_ivs], src_element)`.
auto rmw = rewriter.create<GenericAtomicRMWOp>(loc, s_and_s_op.out(),
selected_ivs);
OpBuilder rmw_builder = OpBuilder::atBlockEnd(rmw.getBody());
auto acc_result =
ApplySingleResultLhloCode(loc, {src_elem, rmw.getCurrentValue()},
&s_and_s_op.scatter().front(), &rmw_builder);
rmw_builder.create<AtomicYieldOp>(loc, acc_result);
rewriter.replaceOp(s_and_s_op, llvm::None);
return success();
}
private:
void InitializeOutput(xla_lhlo::SelectAndScatterOp s_and_s_op,
OpBuilder* b) const {
auto loc = s_and_s_op.getLoc();
Value init_value = b->create<LoadOp>(loc, s_and_s_op.init_value());
scf::ParallelOp loop_over_output =
MakeLoopOverShape(loc, s_and_s_op.out(), b);
OpBuilder::InsertionGuard guard(*b);
b->setInsertionPointToStart(loop_over_output.getBody());
b->create<StoreOp>(loc, init_value, s_and_s_op.out(),
loop_over_output.getInductionVars());
}
struct WindowLoops {
SmallVector<Value, 2> selected_ivs;
SmallVector<Value, 2> window_ivs;
scf::ForOp inner_loop;
};
WindowLoops InsertWindowLoops(xla_lhlo::SelectAndScatterOp s_and_s_op,
scf::ParallelOp loop_over_src,
OpBuilder* b) const {
auto loc = s_and_s_op.getLoc();
Value zero = b->create<ConstantIndexOp>(loc, 0);
Value one = b->create<ConstantIndexOp>(loc, 1);
auto element_type =
s_and_s_op.out().getType().cast<MemRefType>().getElementType();
auto rank = loop_over_src.getNumLoops();
// `iter_args` = [iv_1, ..., iv_N, selected_value, is_initialized]
SmallVector<Value, 4> iter_args(rank, zero);
iter_args.push_back(b->create<mlir::ConstantOp>(
loc, element_type, b->getFloatAttr(element_type, 0)));
iter_args.push_back(b->create<mlir::ConstantOp>(
loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 0)));
// Create a nested loop that traverses the window.
OpBuilder::InsertPoint ip;
WindowLoops result;
for (const auto& window_dim :
s_and_s_op.window_dimensions()->getIntValues()) {
Value upper = b->create<ConstantIndexOp>(loc, window_dim.getSExtValue());
result.inner_loop =
b->create<scf::ForOp>(loc, zero, upper, one, iter_args);
if (b->getInsertionBlock() == loop_over_src.getBody()) {
ip = b->saveInsertionPoint();
result.selected_ivs = result.inner_loop.getResults().take_front(rank);
} else {
b->create<scf::YieldOp>(loc, result.inner_loop.getResults());
}
b->setInsertionPointToStart(result.inner_loop.getBody());
iter_args = ValueRange{result.inner_loop.getRegionIterArgs()};
result.window_ivs.push_back(result.inner_loop.getInductionVar());
}
b->restoreInsertionPoint(ip);
return result;
}
// Adapter to store iteration arguments of sequential loops that perform
// select in a window.
class IterArgs {
public:
explicit IterArgs(ValueRange ivs_val_flag) : ivs_val_flag_(ivs_val_flag) {}
IterArgs(ValueRange ivs, Value value, Value flag) {
ivs_val_flag_ = ivs;
ivs_val_flag_.push_back(value);
ivs_val_flag_.push_back(flag);
}
ArrayRef<Value> to_vector() const { return ivs_val_flag_; }
// Indices of the currently selected value.
ArrayRef<Value> ivs() const { return to_vector().drop_back(2); }
// Currently selected value w.r.t. select() function.
Value value() const { return ivs_val_flag_.end()[-2]; }
// i1 flag if value() and ivs() were initialized.
Value is_init() const { return ivs_val_flag_.back(); }
private:
// Vector that stores iv_1, ..., iv_N, value, init.
SmallVector<Value, 4> ivs_val_flag_;
};
SmallVector<Value, 2> SelectIvs(xla_lhlo::SelectAndScatterOp s_and_s_op,
scf::ParallelOp loop_over_src,
OpBuilder* b) const {
auto loc = s_and_s_op.getLoc();
WindowLoops window_loops = InsertWindowLoops(s_and_s_op, loop_over_src, b);
auto inner_loop_b =
OpBuilder::atBlockEnd(window_loops.inner_loop.getBody());
// Compute ivs in 'arg' buffer and whether these ivs are in the pad area.
MappedIvs mapped_ivs =
MapWindowIvsToInput(s_and_s_op, loop_over_src.getInductionVars(),
window_loops.window_ivs, &inner_loop_b);
IterArgs ivs_val_flag(window_loops.inner_loop.getRegionIterArgs());
auto if_in_bounds = inner_loop_b.create<scf::IfOp>(
loc, window_loops.inner_loop.getResultTypes(), mapped_ivs.in_bounds,
/*withElseRegion=*/true);
// Case when we are inside boundaries of 'arg' and not in the pad area.
{
OpBuilder in_bounds_then_b = if_in_bounds.getThenBodyBuilder();
auto select_or_init_results = SelectOrInitialize(
s_and_s_op, mapped_ivs.ivs, &ivs_val_flag, &in_bounds_then_b);
in_bounds_then_b.create<scf::YieldOp>(loc, select_or_init_results);
}
// Case when we are in the pad.
{
OpBuilder in_bounds_else_b = if_in_bounds.getElseBodyBuilder();
in_bounds_else_b.create<scf::YieldOp>(loc, ivs_val_flag.to_vector());
}
inner_loop_b.create<scf::YieldOp>(loc, if_in_bounds.getResults());
return window_loops.selected_ivs;
}
SmallVector<Value, 4> SelectOrInitialize(
xla_lhlo::SelectAndScatterOp s_and_s_op, ArrayRef<Value> operand_ivs,
IterArgs* ivs_val_flag, OpBuilder* b) const {
auto loc = s_and_s_op.getLoc();
Value true_i1 = b->create<mlir::ConstantOp>(
loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 1));
TypeRange iter_arg_types{ivs_val_flag->to_vector()};
Value operand_elem =
b->create<LoadOp>(loc, s_and_s_op.operand(), operand_ivs);
auto if_init =
b->create<scf::IfOp>(loc, iter_arg_types, ivs_val_flag->is_init(),
/*withElseRegion=*/true);
// Init == true, i.e. iter args are already initialized with a selected
// element in boundaries of the operand. Select function has to be computed
// here.
{
OpBuilder if_init_then_b = if_init.getThenBodyBuilder();
auto& lhlo_select = s_and_s_op.select().front();
Value pred =
ApplySingleResultLhloCode(loc, {operand_elem, ivs_val_flag->value()},
&lhlo_select, &if_init_then_b);
auto if_pred = if_init_then_b.create<scf::IfOp>(loc, iter_arg_types, pred,
/*withElseRegion=*/true);
// Pred == true, therefore pack newly selected ivs, val and init flag back
// to iter_args and return.
{
OpBuilder if_pred_then_b = if_pred.getThenBodyBuilder();
if_pred_then_b.create<scf::YieldOp>(
loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector());
}
// Pred == false, therefore return old iter_args.
{
OpBuilder if_pred_else_b = if_pred.getElseBodyBuilder();
if_pred_else_b.create<scf::YieldOp>(loc, ivs_val_flag->to_vector());
}
if_init_then_b.create<scf::YieldOp>(loc, if_pred.getResults());
}
// Init == false, i.e. only pad was visited before and this is the first
// element in the boundaries of the operand.
{
OpBuilder if_init_else_b = if_init.getElseBodyBuilder();
if_init_else_b.create<scf::YieldOp>(
loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector());
}
return if_init.getResults();
}
};
struct LhloLegalizeToParallelLoops
: public PassWrapper<LhloLegalizeToParallelLoops, FunctionPass> {
void runOnFunction() override {
auto func = getFunction();
OwningRewritePatternList patterns;
// clang-format off
patterns.insert<
ReduceOpConverter,
ReduceWindowOpConverter,
SelectAndScatterOpConverter
>(func.getContext());
// clang-format on
ConversionTarget target(getContext());
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
scf::SCFDialect, XlaLhloDialect>();
target.addIllegalOp<xla_lhlo::ReduceOp, xla_lhlo::ReduceWindowOp,
xla_lhlo::SelectAndScatterOp>();
if (failed(applyPartialConversion(func, target, patterns))) {
signalPassFailure();
}
}
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass() {
return absl::make_unique<LhloLegalizeToParallelLoops>();
}
static PassRegistration<LhloLegalizeToParallelLoops> legalize_lhlo_pass(
"lhlo-legalize-to-parallel-loops",
"Legalize from LHLO dialect to parallel loops.");
} // namespace xla_lhlo
} // namespace mlir

View File

@ -0,0 +1,79 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Thsi file implements passes to convert complex operations to equivalent real
// value operations. This does not include removing complex values from function
// argument or return types.
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <numeric>
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/PassRegistry.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"
using mlir::FunctionPass;
using mlir::OwningRewritePatternList;
using mlir::PassRegistration;
using mlir::PassWrapper;
namespace {
class LowerComplex : public PassWrapper<LowerComplex, FunctionPass> {
public:
explicit LowerComplex() : PassWrapper<LowerComplex, FunctionPass>() {}
/// Performs the lowering to XLA dialect.
void runOnFunction() override;
};
} // end anonymous namespace
namespace mlir {
namespace xla {
namespace {
#include "third_party/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_lower_complex.inc"
} // end anonymous namespace
void PopulateComplexLoweringPatterns(MLIRContext* context,
OwningRewritePatternList* patterns) {
populateWithGenerated(context, patterns);
}
} // end namespace xla
} // end namespace mlir
// Lowers the complex operations that can be represented using other operations.
void LowerComplex::runOnFunction() {
// Add lowering patterns to the list.
OwningRewritePatternList patterns;
mlir::xla::PopulateComplexLoweringPatterns(&getContext(), &patterns);
applyPatternsAndFoldGreedily(getFunction(), patterns);
}
static PassRegistration<LowerComplex> pass(
"test-xla-lower-complex",
"Lower complex operations into non-complex operations");

View File

@ -0,0 +1,109 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the legalization pattern that converts complex operations into
// equivalent real value operations.
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td"
include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td"
include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
//===----------------------------------------------------------------------===//
// Binary op patterns.
//===----------------------------------------------------------------------===//
// Add and subtraction are elementwise and can be distributed across the real
// and imaginary components.
foreach elementwiseOp = [HLO_AddOp, HLO_SubOp] in
def : Pat<(elementwiseOp HLO_ComplexTensor:$lhs,
HLO_ComplexTensor:$rhs),
(HLO_ComplexOp
(elementwiseOp (HLO_RealOp $lhs), (HLO_RealOp $rhs)),
(elementwiseOp (HLO_ImagOp $lhs), (HLO_ImagOp $rhs)))>;
// Complex multiplication results in a cross product multiplication between the
// real and imaginary components such that:
// result.real = lhs.real * rhs.real - lhs.imag * rhs.imag
// result.imag = lhs.imag * rhs.real + lhs.real * rhs.imag
def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs,
HLO_ComplexTensor:$rhs),
(HLO_ComplexOp
(HLO_SubOp
(HLO_MulOp
(HLO_RealOp:$lhs_real $lhs),
(HLO_RealOp:$rhs_real $rhs)),
(HLO_MulOp
(HLO_ImagOp:$lhs_imag $lhs),
(HLO_ImagOp:$rhs_imag $rhs))),
(HLO_AddOp
(HLO_MulOp $lhs_real, $rhs_imag),
(HLO_MulOp $lhs_imag, $rhs_real)))>;
// Multiplication between a complex and real tensor can be distributed by
// applying the real multiplicant to both the real and complex component.
//
// Note that the sourcep pattern is not legal according to the HLO dialect but
// instead handle intermediates generated by other patterns.
def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs),
(HLO_ComplexOp
(HLO_MulOp (HLO_RealOp $lhs), $rhs),
(HLO_MulOp (HLO_ImagOp $lhs), $rhs))>;
def : Pat<(HLO_MulOp HLO_IntOrFpTensor:$lhs, HLO_ComplexTensor:$rhs),
(HLO_ComplexOp
(HLO_MulOp $lhs, (HLO_RealOp $rhs)),
(HLO_MulOp $lhs, (HLO_ImagOp $rhs)))>;
// Division is performed by normalizing the denominator by multiplying by the
// conjugate of the rhs.
// numerator = lhs * conj(rhs)
// denominator = rhs * conj(rhs)
def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_ComplexTensor:$rhs),
(HLO_DivOp
(HLO_MulOp:$num $lhs,
(HLO_ComplexOp:$conj
(HLO_RealOp $rhs),
(HLO_NegOp (HLO_ImagOp $rhs)))),
(HLO_RealOp:$den (HLO_MulOp $rhs, $conj)))>;
def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs),
(HLO_ComplexOp
(HLO_DivOp (HLO_RealOp $lhs), $rhs),
(HLO_DivOp (HLO_ImagOp $lhs), $rhs))>;
// Absolute value is evaluated as:
// result = sqrt(val.real * val.real + val.imag * val.imag)
def : Pat<(HLO_AbsOp HLO_ComplexTensor:$val),
(HLO_ComplexOp
(HLO_SqrtOp
(HLO_AddOp
(HLO_MulOp (HLO_RealOp:$real $val), $real),
(HLO_MulOp (HLO_ImagOp:$imag $val), $imag))),
(HLO_ConstOp (ConstantSplat<"0"> $real)))>;
// Exponential can be lowered to an exponential on the real component and a
// sum of sinusoids of the imaginary component, which equates to a normal
// exponential operator multiplied by Euler's formula.
//
// Exp(a + ib) = Exp(a) * Exp(ib) = Exp(a) * (Cos(b) + iSin(b))
def : Pat<(HLO_ExpOp HLO_ComplexTensor:$val),
(HLO_MulOp
(HLO_ExpOp (HLO_RealOp $val)),
(HLO_ComplexOp
(HLO_CosOp (HLO_ImagOp:$imag $val)),
(HLO_SinOp $imag)))>;

View File

@ -0,0 +1,194 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements logic for lowering XLA general dot to a regular dot.
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
using mlir::DenseIntElementsAttr;
using mlir::ElementsAttr;
using mlir::failure;
using mlir::FunctionPass;
using mlir::LogicalResult;
using mlir::MLIRContext;
using mlir::OpRewritePattern;
using mlir::OwningRewritePatternList;
using mlir::PassRegistration;
using mlir::PassWrapper;
using mlir::PatternRewriter;
using mlir::RankedTensorType;
using mlir::success;
using mlir::Value;
namespace {
Value TransposeReshape(Value arg, mlir::Location loc,
llvm::ArrayRef<int64_t> left_dims,
llvm::ArrayRef<int64_t> right_dims,
llvm::ArrayRef<int64_t> arg_shape,
PatternRewriter *rewriter) {
auto element_type = mlir::getElementTypeOrSelf(arg.getType());
int64_t left_size = 1;
for (auto dim : left_dims) {
left_size *= arg_shape[dim];
}
int64_t right_size = 1;
for (auto dim : right_dims) {
right_size *= arg_shape[dim];
}
// Generate the transpose permutation attribute.
llvm::SmallVector<int64_t, 5> transpose_permutation(left_dims.begin(),
left_dims.end());
transpose_permutation.append(right_dims.begin(), right_dims.end());
mlir::TensorType transpose_permutation_type = RankedTensorType::get(
{static_cast<int64_t>(transpose_permutation.size())},
rewriter->getIntegerType(64));
auto transpose_permutation_attr =
DenseIntElementsAttr::get(transpose_permutation_type,
llvm::makeArrayRef(transpose_permutation))
.cast<DenseIntElementsAttr>();
// Compute the resulting shape.
llvm::SmallVector<int64_t, 5> transposed_shape;
for (auto val : transpose_permutation) {
transposed_shape.push_back(arg_shape[val]);
}
auto transpose_type = RankedTensorType::get(transposed_shape, element_type);
auto transpose_result = rewriter->create<mlir::xla_hlo::TransposeOp>(
loc, transpose_type, arg, transpose_permutation_attr);
// Return the final result.
auto reshaped_type =
RankedTensorType::get({left_size, right_size}, element_type);
return rewriter->create<mlir::xla_hlo::ReshapeOp>(loc, reshaped_type,
transpose_result);
}
Value ProcessDotArg(Value arg, mlir::Location loc,
ElementsAttr contract_dims_attr, bool outer_dims_first,
PatternRewriter *rewriter) {
auto shape = arg.getType().cast<mlir::ShapedType>().getShape();
llvm::SmallVector<bool, 5> is_outer_dim;
is_outer_dim.resize(shape.size(), true);
// Compute the contract dimension ordering.
llvm::SmallVector<int64_t, 5> contract_dims;
for (auto dim : contract_dims_attr.getValues<int64_t>()) {
contract_dims.push_back(dim);
is_outer_dim[dim] = false;
}
// Compute the outer dimension orderings.
llvm::SmallVector<int64_t, 5> outer_dims;
for (auto it : llvm::enumerate(is_outer_dim)) {
if (it.value()) {
outer_dims.push_back(it.index());
}
}
if (outer_dims_first) {
return TransposeReshape(arg, loc, outer_dims, contract_dims, shape,
rewriter);
}
return TransposeReshape(arg, loc, contract_dims, outer_dims, shape, rewriter);
}
struct GeneralDotConvert
: public OpRewritePattern<mlir::xla_hlo::DotGeneralOp> {
// Attempts to lower a General Dot operator to a standard Dot operator.
// General dots include batching dimensions and can have collapsing
// dimensions along any axis. Inserting correctly arrange transpose and
// reshape operators organizes the tensors and allows the General Dot to be
// replaced with the standard Dot operator.
//
// Note: This requires an empty list of batch dimensions.
explicit GeneralDotConvert(MLIRContext *context)
: OpRewritePattern(context) {}
LogicalResult matchAndRewrite(mlir::xla_hlo::DotGeneralOp op,
PatternRewriter &rewriter) const override {
auto dot_element_type = mlir::getElementTypeOrSelf(op);
auto dot_numbers = op.dot_dimension_numbers();
if (dot_numbers.lhs_batching_dimensions().getNumElements() != 0 ||
dot_numbers.rhs_batching_dimensions().getNumElements() != 0) {
return failure();
}
auto lhs = ProcessDotArg(op.lhs(), op.getLoc(),
dot_numbers.lhs_contracting_dimensions(),
/*outer_dims_first=*/true, &rewriter);
auto rhs = ProcessDotArg(op.rhs(), op.getLoc(),
dot_numbers.rhs_contracting_dimensions(),
/*outer_dims_first=*/false, &rewriter);
// Dot resulting shape.
auto lhs_shape = lhs.getType().cast<mlir::ShapedType>().getShape();
auto rhs_shape = rhs.getType().cast<mlir::ShapedType>().getShape();
auto new_dot_type =
RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, dot_element_type);
auto new_dot_op = rewriter.create<mlir::xla_hlo::DotOp>(
op.getLoc(), new_dot_type, lhs, rhs, *(op.precision_config()));
rewriter.replaceOpWithNewOp<mlir::xla_hlo::ReshapeOp>(op, op.getType(),
new_dot_op);
return success();
}
};
struct LegalizeGeneralDot
: public PassWrapper<LegalizeGeneralDot, FunctionPass> {
/// Lower all general dots that can be represented as a non-batched matmul.
void runOnFunction() override {
OwningRewritePatternList patterns;
mlir::xla_hlo::PopulateGeneralDotOpLoweringPatterns(&patterns,
&getContext());
applyPatternsAndFoldGreedily(getFunction(), patterns);
}
};
} // namespace
void mlir::xla_hlo::PopulateGeneralDotOpLoweringPatterns(
OwningRewritePatternList *patterns, MLIRContext *ctx) {
patterns->insert<GeneralDotConvert>(ctx);
}
static PassRegistration<LegalizeGeneralDot> legalize_pass(
"test-xla-lower-general-dot",
"Tests lowering general dot to a non-batched dot when possible");

View File

@ -0,0 +1,90 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <numeric>
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace xla_hlo {
namespace {
// Converts ClampOp with broadcast semantics. ClampOp requires "all three arrays
// must be the same shape. Alternatively, as a restricted form of broadcasting,
// min and/or max can be a scalar of type T."
struct ClampWithBroadcastConvert : public OpRewritePattern<ClampOp> {
explicit ClampWithBroadcastConvert(MLIRContext *context)
: OpRewritePattern<ClampOp>(context) {}
LogicalResult matchAndRewrite(ClampOp op,
PatternRewriter &rewriter) const override {
auto operand_type = op.operand().getType().dyn_cast<RankedTensorType>();
auto max_type = op.max().getType().dyn_cast<RankedTensorType>();
auto min_type = op.min().getType().dyn_cast<RankedTensorType>();
// Unrancked types are not supported.
if (!operand_type || !max_type || !min_type) return failure();
// Does not support operand with dynamic dimensions for now.
if (!operand_type.hasStaticShape()) return failure();
ArrayRef<int64_t> operand_shape = operand_type.getShape();
Value max_value = op.max();
if (max_type != operand_type) {
assert(max_type.getRank() == 0);
max_value = rewriter.createOrFold<BroadcastOp>(
op.getLoc(), operand_type, max_value,
rewriter.getI64TensorAttr(operand_shape));
}
Value min_value = op.min();
if (min_type != operand_type) {
assert(min_type.getRank() == 0);
min_value = rewriter.createOrFold<BroadcastOp>(
op.getLoc(), operand_type, min_value,
rewriter.getI64TensorAttr(operand_shape));
}
rewriter.replaceOpWithNewOp<ClampOp>(op, op.getType(), min_value,
op.operand(), max_value);
return success();
}
};
} // namespace
void SetupMaterializeBroadcastsLegality(MLIRContext *context,
ConversionTarget *conversionTarget) {
conversionTarget->addDynamicallyLegalOp<ClampOp>([](ClampOp op) {
return op.max().getType() == op.operand().getType() &&
op.min().getType() == op.operand().getType();
});
}
void PopulateMaterializeBroadcastsPatterns(MLIRContext *context,
OwningRewritePatternList *patterns) {
// ClampOp. This op has a special case where it accepts either same-shaped
// inputs or scalars (a restricted form of broadcasting). This makes the
// broadcast explicit.
patterns->insert<ClampWithBroadcastConvert>(context);
}
} // namespace xla_hlo
} // namespace mlir

View File

@ -0,0 +1,58 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir {
namespace xla_hlo {
namespace {
struct TestMaterializeBroadcastsPass
: public PassWrapper<TestMaterializeBroadcastsPass, FunctionPass> {
void runOnFunction() override {
ConversionTarget conversionTarget(getContext());
OwningRewritePatternList conversionPatterns;
// Consider the xla_hlo dialect legal for tests.
conversionTarget.addLegalDialect<XlaHloDialect>();
// The conversion uses helpers from the Standard dialect.
conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
SetupMaterializeBroadcastsLegality(&getContext(), &conversionTarget);
PopulateMaterializeBroadcastsPatterns(&getContext(), &conversionPatterns);
if (failed(applyPartialConversion(getFunction(), conversionTarget,
conversionPatterns))) {
return signalPassFailure();
}
}
};
} // namespace
} // namespace xla_hlo
} // namespace mlir
static mlir::PassRegistration<mlir::xla_hlo::TestMaterializeBroadcastsPass>
pass("test-xla-materialize-broadcasts",
"Test pass for materializing 'broadcast_dimensions' attributes");

View File

@ -0,0 +1,85 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/DenseMap.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/Casting.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/PassManager.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LLVM.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/RegionUtils.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace xla_hlo {
namespace {
// A pass that sinks constants implicitly captured in control flow regions. This
// is necessary to export to XLA.
class SinkConstantsToControlFlow
: public mlir::PassWrapper<SinkConstantsToControlFlow, FunctionPass> {
void runOnFunction() override {
getFunction().walk([](Operation* op) {
if (auto while_op = llvm::dyn_cast<WhileOp>(op)) {
SinkToRegion(&while_op.body());
SinkToRegion(&while_op.cond());
} else if (auto if_op = llvm::dyn_cast<IfOp>(op)) {
SinkToRegion(&if_op.true_branch());
SinkToRegion(&if_op.false_branch());
}
});
}
private:
// Performs constant sinking into a region.
static void SinkToRegion(Region* region) {
llvm::DenseMap<Value, ConstOp> sunk_constant;
visitUsedValuesDefinedAbove({*region}, [&](OpOperand* use) {
Value constant = use->get();
auto const_op = dyn_cast_or_null<ConstOp>(constant.getDefiningOp());
if (!const_op) return;
auto map_entry = sunk_constant.try_emplace(constant, nullptr);
if (!map_entry.second) {
// This constant has already been cloned into the region, reuse it.
use->set(map_entry.first->getSecond().getResult());
if (constant.use_empty()) const_op.erase();
return;
}
if (constant.hasOneUse()) {
const_op.getOperation()->moveBefore(&region->front().front());
return;
}
map_entry.first->getSecond() = const_op.clone();
region->front().getOperations().insert(region->front().begin(),
map_entry.first->getSecond());
use->set(map_entry.first->getSecond().getResult());
});
}
};
static mlir::PassRegistration<SinkConstantsToControlFlow> pass(
"xla-hlo-sink-constants-to-control-flow",
"Sink constants implicitly captured in control flow regions. This is "
"necessary to export to XLA.");
} // anonymous namespace
std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass() {
return std::make_unique<SinkConstantsToControlFlow>();
}
} // namespace xla_hlo
} // namespace mlir

View File

@ -0,0 +1,100 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Identifier.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OperationSupport.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
namespace mlir {
namespace xla {
namespace {
struct InferReturnTypeComponentsPattern : public RewritePattern {
InferReturnTypeComponentsPattern(MLIRContext *context)
: RewritePattern("xla_test.get_return_type_components", 1, context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
if (op->getNumOperands() != 1) return failure();
auto defining_op = op->getOperand(0).getDefiningOp();
auto defining_op_int =
llvm::dyn_cast_or_null<InferShapedTypeOpInterface>(defining_op);
if (!defining_op_int) return failure();
SmallVector<ShapedTypeComponents, 4> components;
if (failed(defining_op_int.inferReturnTypeComponents(
op->getContext(), op->getLoc(), defining_op->getOperands(),
defining_op->getAttrDictionary(), defining_op->getRegions(),
components))) {
return failure();
}
// Replace the op with another pass-through op with attributes added.
OperationState state(op->getLoc(), "xla_test.return_type_components",
op->getOperands(), op->getResultTypes(),
op->getAttrs());
auto new_op = rewriter.createOperation(state);
for (auto it : llvm::enumerate(components)) {
if (it.value().hasRank()) {
new_op->setAttr((StringRef("dims") + Twine(it.index())).str(),
rewriter.getI64ArrayAttr(it.value().getDims()));
}
if (it.value().getElementType()) {
new_op->setAttr((Twine("element_type") + Twine(it.index())).str(),
TypeAttr::get(it.value().getElementType()));
}
}
rewriter.replaceOp(op, {new_op->getResults()});
return success();
}
};
struct ReifyReturnTypeShapesPattern : public RewritePattern {
ReifyReturnTypeShapesPattern(MLIRContext *context)
: RewritePattern("xla_test.reify_return_type_shapes", 1, context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
if (op->getNumOperands() != 1) return failure();
auto defining_op = llvm::dyn_cast_or_null<InferShapedTypeOpInterface>(
op->getOperand(0).getDefiningOp());
if (!defining_op) return failure();
SmallVector<Value, 4> return_shapes;
if (failed(defining_op.reifyReturnTypeShapes(rewriter, return_shapes))) {
return failure();
}
rewriter.replaceOp(op, return_shapes);
return success();
}
};
struct TestInferShapedTypeMethodsPass
: public PassWrapper<TestInferShapedTypeMethodsPass, FunctionPass> {
void runOnFunction() override {
OwningRewritePatternList patterns;
patterns.insert<ReifyReturnTypeShapesPattern>(&getContext());
patterns.insert<InferReturnTypeComponentsPattern>(&getContext());
applyPatternsAndFoldGreedily(getFunction(), patterns);
}
};
} // namespace
} // namespace xla
} // namespace mlir
static mlir::PassRegistration<mlir::xla::TestInferShapedTypeMethodsPass> pass(
"test-xla-infer-shaped-type-methods",
"Uses test ops to invoke InferShapedTypeOpInterface methods");

View File

@ -0,0 +1,184 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/SmallVector.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir {
namespace xla_hlo {
namespace {
// Broadcasts the 1D value tensor 'value_1d' to the shape of 'result_type'. If
// 'shape_value' is initialized, creates a dynamic broadcast, otherwise creates
// a static broadcast.
Value BroadcastToFeatureDim(Location loc, RankedTensorType result_type,
Value value_1d, Value shape_value,
int64_t feature_dim,
PatternRewriter& rewriter) { // NOLINT
Builder b(rewriter.getContext());
auto dims_type = RankedTensorType::get({1}, b.getIntegerType(64));
auto dims = DenseIntElementsAttr::get(dims_type, {feature_dim});
if (shape_value) {
return rewriter.createOrFold<xla_hlo::DynamicBroadcastInDimOp>(
loc, result_type, value_1d, shape_value, dims);
}
assert(result_type.hasStaticShape());
return rewriter.create<xla_hlo::BroadcastInDimOp>(loc, result_type, value_1d,
dims);
}
// Calculate the shape value of operand, assuming it is a dynamic shape with
// static rank.
Value CalculateShapeValue(Location loc, Value operand,
PatternRewriter& rewriter) { // NOLINT
RankedTensorType result_type = operand.getType().dyn_cast<RankedTensorType>();
llvm::SmallVector<Value, 4> shape_values;
int64_t rank = result_type.getRank();
shape_values.reserve(rank);
for (int64_t i = 0; i < rank; ++i) {
shape_values.push_back(rewriter.create<mlir::DimOp>(loc, operand, i));
}
return rewriter.create<TensorFromElementsOp>(loc, shape_values);
}
Value MaterializeEpsilon(Operation* op, FloatAttr epsilon_attr,
FloatType fp_type, Value variance,
RankedTensorType broadcast_to_type,
PatternRewriter& rewriter) { // NOLINT
Builder b(rewriter.getContext());
if (epsilon_attr.getType() != fp_type) {
// Need to convert.
bool loses_info;
APFloat epsilon_float = epsilon_attr.getValue();
auto status = epsilon_float.convert(
fp_type.getFloatSemantics(), APFloat::rmNearestTiesToEven, &loses_info);
if ((status & (~APFloat::opInexact)) != APFloat::opOK) {
op->emitWarning() << "Could not convert batch_norm epsilon to target fp "
"type: opStatus = "
<< static_cast<int>(status);
return nullptr;
}
if (loses_info) {
op->emitWarning("Conversion of epsilon loses precision");
}
epsilon_attr = b.getFloatAttr(fp_type, epsilon_float);
}
auto scalar_type = RankedTensorType::get({}, fp_type);
auto epsilon_tensor_attr =
DenseElementsAttr::get(scalar_type, {epsilon_attr.cast<Attribute>()});
Value epsilon =
rewriter.create<xla_hlo::ConstOp>(op->getLoc(), epsilon_tensor_attr);
auto dims_type = RankedTensorType::get({0}, b.getIntegerType(64));
auto dims = DenseIntElementsAttr::get(dims_type, SmallVector<int64_t, 1>{});
if (broadcast_to_type.hasStaticShape()) {
return rewriter.create<xla_hlo::BroadcastInDimOp>(
op->getLoc(), broadcast_to_type, epsilon, /*broadcast_dims=*/dims);
}
Value shape_value = CalculateShapeValue(op->getLoc(), variance, rewriter);
return rewriter.createOrFold<xla_hlo::DynamicBroadcastInDimOp>(
op->getLoc(), broadcast_to_type, epsilon, shape_value,
/*broadcast_dims=*/dims);
}
class UnfuseBatchNormInferencePattern
: public OpRewritePattern<xla_hlo::BatchNormInferenceOp> {
public:
using OpRewritePattern<xla_hlo::BatchNormInferenceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(xla_hlo::BatchNormInferenceOp bn_op,
PatternRewriter& rewriter) const override {
// Enforce type invariants.
// Note that we deduce the actual element type from the variance,
// which should not be subject to quantization at a higher level.
auto input_type = bn_op.operand().getType().dyn_cast<RankedTensorType>();
auto variance_type =
bn_op.variance().getType().dyn_cast<RankedTensorType>();
if (!input_type || !variance_type) {
return failure();
}
auto fp_type = variance_type.getElementType().dyn_cast<FloatType>();
if (!fp_type) {
return failure();
}
int64_t feature_dim = bn_op.feature_index().getSExtValue();
// Add epsilon to the variance and sqrt to get stddev:
// stddev = sqrt(variance + epsilon)
auto epsilon =
MaterializeEpsilon(bn_op.getOperation(), bn_op.epsilonAttr(), fp_type,
bn_op.variance(), variance_type, rewriter);
if (!epsilon) {
return failure();
}
Value stddev = rewriter.create<xla_hlo::AddOp>(bn_op.getLoc(),
bn_op.variance(), epsilon);
stddev = rewriter.create<xla_hlo::SqrtOp>(bn_op.getLoc(), stddev);
// Broadcast all terms.
Value shape_value;
if (!input_type.hasStaticShape()) {
shape_value =
CalculateShapeValue(bn_op.getLoc(), bn_op.operand(), rewriter);
}
auto broadcast_scale =
BroadcastToFeatureDim(bn_op.getLoc(), input_type, bn_op.scale(),
shape_value, feature_dim, rewriter);
auto broadcast_offset =
BroadcastToFeatureDim(bn_op.getLoc(), input_type, bn_op.offset(),
shape_value, feature_dim, rewriter);
auto broadcast_mean =
BroadcastToFeatureDim(bn_op.getLoc(), input_type, bn_op.mean(),
shape_value, feature_dim, rewriter);
auto broadcast_stddev = BroadcastToFeatureDim(
bn_op.getLoc(), input_type, stddev, shape_value, feature_dim, rewriter);
// Compute:
// scale * (input - mean) / stddev + offset
Value result = rewriter.create<xla_hlo::SubOp>(
bn_op.getLoc(), bn_op.operand(), broadcast_mean);
result = rewriter.create<xla_hlo::MulOp>(bn_op.getLoc(), result,
broadcast_scale);
result = rewriter.create<xla_hlo::DivOp>(bn_op.getLoc(), result,
broadcast_stddev);
rewriter.replaceOpWithNewOp<xla_hlo::AddOp>(bn_op, result,
broadcast_offset);
return success();
}
};
} // namespace
// Populates conversion patterns to unfuse batch normalization operations.
// In combination with marking such ops as illegal, this allows backends that
// do not have special support for fused batchnorm to use simpler arithmetic
// primitives.
void PopulateUnfuseBatchNormPatterns(MLIRContext* context,
OwningRewritePatternList* patterns) {
patterns->insert<UnfuseBatchNormInferencePattern>(context);
}
} // namespace xla_hlo
} // namespace mlir

View File

@ -0,0 +1,46 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir {
namespace xla_hlo {
namespace {
struct TestUnfuseBatchNormPass
: public PassWrapper<TestUnfuseBatchNormPass, OperationPass<>> {
void runOnOperation() override {
OwningRewritePatternList patterns;
PopulateUnfuseBatchNormPatterns(&getContext(), &patterns);
applyPatternsAndFoldGreedily(getOperation(), patterns);
}
};
} // namespace
} // namespace xla_hlo
} // namespace mlir
static mlir::PassRegistration<mlir::xla_hlo::TestUnfuseBatchNormPass> pass(
"test-xla-unfuse-batch-norm",
"Test pass for materializing 'broadcast_dimensions' attributes");

View File

@ -0,0 +1,579 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/EquivalenceClasses.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/cycle_detector.h"
// This pass has similar functionality of the fusion pass in XLA stack.
// However, unlike XLA, it targets the fully dynamic shape scenario.
// Currently, it implements the kLoop and kInput fusion templates.
// During conversion, it tries to greedily find kLoop/kInput fusion
// patterns.
//
// Similar to XLA, this pass supports fusion pattern having multiple outputs
// if all the shape of outputs are consistent. Following are some examples.
//
// kLoop kInput
// +----+ +----+ +----+ +----+ +----+ +----+
// |elem| |elem| |elem| |elem<----+elem+---->elem+----+
// +-+--+ +-+--+ +-+--+ +-+--+ +----+ +-+--+ |
// | | | | | |
// | | | | |
// +-v--+ | +-v--+ +--v---+ +--v---+ |
// |elem+<---+----<+elem| |reduce| |reduce| |
// +-+--+ +-+--+ +--+---+ +--+---+ |
// | | | | |
// | | | | |
// v v v v v
//
// To this end, we also add an simple shape constraint analysis phase.
// For kLoop fusion template, it requires all the outputs of the fused
// pattern have the same shape. However, we don't know the actual value
// of the shape at the compile time in the dynamic shape world.
// Fortunately, we could still infer the relationship among different ops
// according to their shape constrain traits. Currently, We only consider
// shape equality propagation for elementwise ops (assuming that implicit
// shape broadcast is forbidden). The above process could be built on the
// shape dialect once it is ready.
namespace mlir {
namespace xla_hlo {
namespace {
using llvm::EquivalenceClasses;
using FusionPattern = std::vector<Operation*>;
using FusionPlan = std::vector<FusionPattern>;
// To support using EquivalenceClasses for Value
class ValueWrapper {
public:
explicit ValueWrapper(Value value) : value_(std::move(value)) {}
Value getValue() const { return value_; }
bool operator==(const ValueWrapper& rhs) const {
return getValue() == rhs.getValue();
}
private:
Value value_;
};
bool operator<(const ValueWrapper& lhs, const ValueWrapper& rhs) {
auto lhs_value = lhs.getValue().getAsOpaquePointer();
auto rhs_value = rhs.getValue().getAsOpaquePointer();
return lhs_value < rhs_value;
}
bool IsFusible(Operation* op) {
if (matchPattern(op, m_Constant())) {
return true;
}
auto op_fusibility = dyn_cast<InferFusibilityOpInterface>(op);
return op_fusibility && (op_fusibility.isFusibleWithOperand() ||
op_fusibility.isFusibleWithConsumer());
}
SmallVector<Value, 4> GetInputsOfFusionPattern(const FusionPattern& pattern) {
SmallVector<Value, 4> inputs;
DenseSet<Value> input_set;
DenseSet<Operation*> op_set;
for (Operation* op : pattern) {
bool inserted = op_set.insert(op).second;
(void)inserted;
assert(inserted && "FusionPattern contains duplicate operations");
}
for (Operation* op : pattern) {
for (Value operand : op->getOperands()) {
Operation* operand_op = operand.getDefiningOp();
if (op_set.find(operand_op) != op_set.end()) {
// skip if defining op is in the pattern
continue;
}
if (input_set.insert(operand).second) {
inputs.push_back(operand);
}
}
}
return inputs;
}
SmallVector<Value, 4> GetOutputsOfFusionPattern(const FusionPattern& pattern) {
SmallVector<Value, 4> outputs;
DenseSet<Operation*> op_set;
for (Operation* op : pattern) {
bool inserted = op_set.insert(op).second;
(void)inserted;
assert(inserted && "FusionPattern contains duplicate operations");
}
for (Operation* op : pattern) {
for (Value result : op->getResults()) {
bool has_external_user = llvm::any_of(
result.getUses(),
[&](OpOperand& use) { return !op_set.count(use.getOwner()); });
if (has_external_user) {
outputs.push_back(result);
}
}
}
return outputs;
}
FusionPattern MergeFusionPattern(const FusionPattern& lhs,
const FusionPattern& rhs) {
FusionPattern pattern(lhs);
pattern.insert(pattern.end(), rhs.begin(), rhs.end());
return pattern;
}
inline int EffectiveSize(const FusionPattern& pattern) {
return llvm::count_if(
pattern, [](Operation* op) { return !matchPattern(op, m_Constant()); });
}
// This is an simple shape constraint analysis, which is used to
// guide fusion decision (e.g. we only fuse shape-compatible ops).
//
// Currently, We only consider shape equality propagation based
// on the shape constrain traits of elementwise ops (assuming that
// implicit shape broadcast is forbidden).
class ShapeConstraintAnalysis {
public:
explicit ShapeConstraintAnalysis(const SmallVectorImpl<Operation*>& op_list) {
PropagateEquality(op_list);
}
// Returns true is `lhs` and `rhs` are supposed to have same shape.
bool HasSameShape(Value lhs, Value rhs) {
return impl_.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs));
}
private:
// shape equality propagation based on the shape constrains of
// elementwise ops.
void PropagateEquality(const SmallVectorImpl<Operation*>& op_list) {
bool converged = true;
do {
converged = true;
auto update = [&](Value lhs, Value rhs) {
if (!impl_.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs))) {
converged = false;
impl_.unionSets(ValueWrapper(lhs), ValueWrapper(rhs));
}
};
for (Operation* op : op_list) {
auto op_fusibility = dyn_cast<InferFusibilityOpInterface>(op);
if (!op_fusibility) continue;
int numInput = op->getNumOperands();
int numOutput = op->getNumResults();
// shape equality propagation between inputs.
for (int input1 = 0; input1 < numInput; ++input1)
for (int input2 = input1 + 1; input2 < numInput; ++input2)
if (op_fusibility.inferInputsShapeEquality(input1, input2))
update(op->getOperand(input1), op->getOperand(input2));
// shape equality propagation between outputs.
for (int output1 = 0; output1 < numOutput; ++output1)
for (int output2 = output1 + 1; output2 < numOutput; ++output2)
if (op_fusibility.inferOutputsShapeEquality(output1, output2))
update(op->getResult(output1), op->getResult(output2));
// shape equality propagation between input and output.
for (int input = 0; input < numInput; ++input)
for (int output = 0; output < numOutput; ++output)
if (op_fusibility.inferInputOutputShapeEquality(input, output))
update(op->getOperand(input), op->getResult(output));
}
} while (!converged);
}
// a UnionFind set
EquivalenceClasses<ValueWrapper> impl_;
};
// A fusion planner that can propose a fusion plan for a block of ops.
// The fusion plan is consisted of a group of fusion patterns.
//
// Currently all proposed patterns followed xla kLoop/kInput like fusion
// templates while are adapted to the fully dynamic shape world.
//
// kLoop fusion template satifies:
// - all ops in the fusion pattern are element-wise.
// - all the shapes of outputs of fusion pattern are same, and thus can
// fit into a same parallel loop.
//
// kInput fusion template satifies:
// - any op in the fusion pattern is either element-wise or a reduction.
// - if a op is a reduction, its output cannot be consumered by other
// ops in the same fusion pattern.
// - all the effective shapes of outputs of fusion pattern are same.
// - For element-wise op, its effective shape is its output shape.
// - For reduction op, its effective shape is its operand shape.
class FusionPlanner {
public:
explicit FusionPlanner(const SmallVectorImpl<Operation*>& op_list)
: op_list_(op_list),
shape_analysis_(op_list),
cycle_detector_(op_list.size()) {
BuildNodeMap();
}
// Returns a fusion plan if success, otherwise none.
llvm::Optional<FusionPlan> Run() {
// Greedily search connected fusible pattern, and ops belonging to
// a same fusion pattern are grouped into a cluster.
RunEdgeContractionLoop();
// After doing edge contraction, each unique cluster having size
// more than one represents a potential fusion pattern.
// We collect all these clusters and construct a fusion plan.
//
// Note that the ops in a fusion pattern are in topological ordering.
FusionPlan plan;
DenseMap<int, int> pattern_ids;
for (Operation* op : op_list_) {
Cluster* cluster = GetClusterForNode(op);
int node_id = cluster->cycles_graph_node_id();
if (!IsFusible(op_list_[node_id]) ||
EffectiveSize(GetClusterForNode(op)->fused_pattern()) <= 1) {
continue;
}
if (!pattern_ids.count(node_id)) {
int pattern_id = pattern_ids.size();
pattern_ids[node_id] = pattern_id;
plan.emplace_back();
}
plan[pattern_ids[node_id]].push_back(op);
}
return plan;
}
// Returns the op_list this planner operates on.
const SmallVectorImpl<Operation*>& op_list() const { return op_list_; }
private:
// Represent a (partial) fused pattern
class Cluster {
public:
Cluster(int node_id, FusionPlanner* planner) : node_id_(node_id) {
const SmallVectorImpl<Operation*>& op_list = planner->op_list();
pattern_.push_back(op_list[node_id]);
}
// Merges `other` into this cluster, and clears `other`.
void Merge(Cluster* other) {
pattern_.insert(pattern_.end(), other->pattern_.begin(),
other->pattern_.end());
other->pattern_.clear();
}
// The number of nodes in this cluster.
int cluster_size() const { return pattern_.size(); }
// The ID of the cluster as represented in `cycle_detector_`.
int cycles_graph_node_id() const { return node_id_; }
// Sets the ID of the cluster as represented in `cycle_detector_`.
void set_cycles_graph_node_id(int cycles_graph_node_id) {
node_id_ = cycles_graph_node_id;
}
// Currently the fused pattern this cluster holds.
const FusionPattern& fused_pattern() { return pattern_; }
private:
// ID of the representative node of this cluster.
int node_id_;
// the fused pattern this cluster holds.
FusionPattern pattern_;
};
private:
Cluster* MakeCluster(int cycles_graph_node_id) {
cluster_storage_.emplace_back(new Cluster(cycles_graph_node_id, this));
return cluster_storage_.back().get();
}
void BuildNodeMap() {
int num_nodes = op_list_.size();
for (int node_id = 0; node_id < num_nodes; ++node_id) {
Operation* op = op_list_[node_id];
MakeCluster(node_id);
op_to_node_id_[op] = node_id;
leader_for_node_.insert(node_id);
for (Value operand : op->getOperands()) {
Operation* operand_op = operand.getDefiningOp();
if (operand_op == nullptr) {
// skip block argument
continue;
}
auto iter = op_to_node_id_.find(operand_op);
assert(iter != op_to_node_id_.end());
cycle_detector_.InsertEdge(iter->second, node_id);
}
}
}
// Returns the cluster contains this op.
Cluster* GetClusterForNode(Operation* n) {
int id = op_to_node_id_.at(n);
id = leader_for_node_.getLeaderValue(id);
return cluster_storage_[id].get();
}
// Returns the cluster contains the op having `node_id`.
Cluster* GetClusterForCyclesGraphNode(int node_id) {
return cluster_storage_[leader_for_node_.getLeaderValue(node_id)].get();
}
// Merges the clusters `cluster_from` and `cluster_to`.
bool MergeClusters(Cluster* cluster_from, Cluster* cluster_to) {
int from = cluster_from->cycles_graph_node_id();
int to = cluster_to->cycles_graph_node_id();
auto optional_merged_node = cycle_detector_.ContractEdge(from, to);
if (!optional_merged_node.hasValue()) {
llvm::dbgs() << "Could not contract " << from << " -> " << to
<< " because contracting the edge would create a cycle.";
return false;
}
// Merge the clusters.
cluster_from->Merge(cluster_to);
cluster_from->set_cycles_graph_node_id(*optional_merged_node);
// Merge the UnionFind Set.
leader_for_node_.unionSets(from, to);
return true;
}
template <typename FnTy>
bool ForEachEdgeInPostOrder(FnTy fn) {
bool changed = false;
for (int32_t node : cycle_detector_.AllNodesInPostOrder()) {
Cluster* cluster_from = GetClusterForCyclesGraphNode(node);
// Make a copy of the set of successors because we may modify the graph in
// TryToContractEdge.
std::vector<int32_t> successors_copy =
cycle_detector_.SuccessorsCopy(cluster_from->cycles_graph_node_id());
for (int to : successors_copy) {
Cluster* cluster_to = GetClusterForCyclesGraphNode(to);
bool contracted_edge = fn(cluster_from, cluster_to);
changed |= contracted_edge;
}
}
return changed;
}
// returns the outputs if two cluster were merged
SmallVector<Value, 4> GetResultsOfFusedPattern(Cluster* from, Cluster* to) {
FusionPattern fused_pattern =
MergeFusionPattern(from->fused_pattern(), to->fused_pattern());
return GetOutputsOfFusionPattern(fused_pattern);
}
// This function check if fusing `from` with `to` is valid and if so perform
// the merge. The validity is based on the operations in the clusters and
// the compatibility of the shapes of the outputs of the would-be fused
// clusters.
// Returns true is the merge was performed.
bool TryToContractEdge(Cluster* from, Cluster* to) {
int node_to = to->cycles_graph_node_id();
int node_from = from->cycles_graph_node_id();
// Both node_to and node_from should be fusible
if (!IsFusible(op_list_[node_to]) || !IsFusible(op_list_[node_from])) {
return false;
}
auto op_from_fusibility =
dyn_cast<InferFusibilityOpInterface>(op_list_[node_from]);
if (op_from_fusibility && !op_from_fusibility.isFusibleWithConsumer()) {
// This op cannot be fused with its consumers.
return false;
}
auto op_to_fusibility =
dyn_cast<InferFusibilityOpInterface>(op_list_[node_to]);
if (op_to_fusibility && !op_to_fusibility.isFusibleWithOperand()) {
// This op cannot be fused with its operands.
return false;
}
// Output shapes of a fusion pattern should be compatible as described in
// the document of this class.
SmallVector<Value, 4> results = GetResultsOfFusedPattern(from, to);
auto get_workload_shape = [](Value v) {
Operation* op = v.getDefiningOp();
// Block argument
if (!op) return v;
auto op_fusibility = dyn_cast<InferFusibilityOpInterface>(op);
// Const value
if (!op_fusibility) return v;
llvm::Optional<Value> workload =
op_fusibility.inferEffectiveWorkloadShape();
return workload.hasValue() ? *workload : v;
};
Value ref = get_workload_shape(results[0]);
if (!llvm::all_of(results, [&](Value result) {
Value val = get_workload_shape(result);
return shape_analysis_.HasSameShape(ref, val);
})) {
return false;
}
return MergeClusters(from, to);
}
// Greedily fuse connected node.
bool RunEdgeContractionLoop() {
using std::placeholders::_1;
using std::placeholders::_2;
return ForEachEdgeInPostOrder(
std::bind(&FusionPlanner::TryToContractEdge, this, _1, _2));
}
const SmallVectorImpl<Operation*>& op_list_;
// Shape equality checker
ShapeConstraintAnalysis shape_analysis_;
// op -> node_id
std::unordered_map<Operation*, int> op_to_node_id_;
// make sure not introduce cycle after fusion
GraphCycles cycle_detector_;
std::vector<std::unique_ptr<Cluster>> cluster_storage_;
// a UnionFind set. Each set represents a (partial) fused pattern
// and has a leader as representation.
EquivalenceClasses<int32_t> leader_for_node_;
};
struct XlaHloFusion : public mlir::PassWrapper<XlaHloFusion, FunctionPass> {
void runOnFunction() override {
FuncOp func = getFunction();
if (!IsTargetFunc(func)) {
return;
}
// process each block and do fusion within a block.
for (Block& block : func) {
SmallVector<Operation*, 4> op_list;
for (Operation& op : block) {
op_list.push_back(&op);
}
FusionPlanner planner(op_list);
llvm::Optional<FusionPlan> plan = planner.Run();
if (!plan) {
emitError(func.getLoc(), "can't find a fusion plan");
signalPassFailure();
return;
}
if (!ApplyFusionPlan(*plan)) {
emitError(func.getLoc(), "apply fusion plan failed");
signalPassFailure();
return;
}
}
}
bool IsTargetFunc(FuncOp func) {
int num_fusible_ops = 0;
bool is_target_func = false;
// We only process the function having enough candidates
func.walk([&](Operation* op) {
num_fusible_ops +=
static_cast<int>(dyn_cast<InferFusibilityOpInterface>(op) != nullptr);
is_target_func = (num_fusible_ops > 1);
// early stop
if (is_target_func) return WalkResult::interrupt();
return WalkResult::advance();
});
return is_target_func;
}
bool ApplyFusionPlan(const FusionPlan& plan) {
for (const FusionPattern& pattern : plan) {
OpBuilder b(pattern.back());
SmallVector<Location, 4> locations;
locations.reserve(pattern.size());
for (Operation* op : pattern) {
locations.push_back(op->getLoc());
}
Location fused_loc =
FusedLoc::get(locations, pattern.back()->getContext());
SmallVector<Value, 4> inputs = GetInputsOfFusionPattern(pattern);
SmallVector<Value, 4> outputs = GetOutputsOfFusionPattern(pattern);
SmallVector<Type, 4> output_types;
output_types.reserve(outputs.size());
for (Value v : outputs) {
output_types.push_back(v.getType());
}
FusionOp fusion =
b.create<xla_hlo::FusionOp>(fused_loc, output_types, inputs);
Region& region = fusion.fused_computation();
region.push_back(new Block);
Block& block = region.front();
for (Operation* op : pattern) {
op->moveBefore(&block, block.end());
}
b.setInsertionPoint(&block, block.end());
b.create<xla_hlo::ReturnOp>(fused_loc, outputs);
for (auto output_and_result : llvm::zip(outputs, fusion.getResults())) {
Value output = std::get<0>(output_and_result);
Value fusion_result = std::get<1>(output_and_result);
for (OpOperand& use : llvm::make_early_inc_range(output.getUses())) {
if (use.getOwner()->getBlock() != &block) use.set(fusion_result);
}
}
}
return true;
}
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusion() {
return std::make_unique<XlaHloFusion>();
}
static PassRegistration<XlaHloFusion> xla_hlo_fusion_pass(
"xla-hlo-fusion", "fuse xla_hlo ops to kLoop/kInput fusion patterns.");
} // namespace xla_hlo
} // namespace mlir

View File

@ -0,0 +1,909 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements logic for lowering HLO/LHLO dialect to Linalg dialect.
#include "third_party/absl/memory/memory.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/AffineExpr.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir {
namespace {
SmallVector<StringRef, 3> GetNParallelLoopsAttrs(unsigned nParallelLoops) {
static constexpr StringRef kParallelIterType = "parallel";
return SmallVector<StringRef, 3>(nParallelLoops, kParallelIterType);
}
template <bool isLHLO = true>
Value getResultValue(Operation* op) {
return isLHLO ? op->getOperand(op->getNumOperands() - 1) : op->getResult(0);
}
template <bool isLHLO = true>
ShapedType getXLAOpResultType(Operation* op) {
return getResultValue<isLHLO>(op).getType().template cast<ShapedType>();
}
template <bool isLHLO = true>
bool verifyXLAOpBufferOrTensorSemantics(Operation* op) {
auto verifyType = [&](Value val) -> bool {
return (isLHLO && val.getType().isa<MemRefType>()) ||
(!isLHLO && val.getType().isa<RankedTensorType>());
};
if (!llvm::all_of(op->getOperands(), verifyType)) return false;
return isLHLO ? op->getResults().empty()
: llvm::all_of(op->getResults(), verifyType);
}
template <typename OpTy, bool isLHLO = true>
class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
public:
using OpConversionPattern<OpTy>::OpConversionPattern;
LogicalResult matchAndRewrite(
OpTy op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
auto loc = op.getLoc();
auto argType =
op.getOperation()->getOperand(0).getType().template cast<ShapedType>();
if (!argType.hasRank()) {
emitError(loc, "lhlo to linalg conversion expects ranked args");
return failure();
}
auto elemTy = argType.getElementType();
if (!elemTy.isSignlessIntOrFloat() && !elemTy.template isa<ComplexType>()) {
return failure();
}
// Construct the indexing maps needed for linalg.generic ops.
SmallVector<AffineMap, 2> indexing_maps;
SmallVector<Type, 4> bodyArgTypes, bodyResultTypes, opResultTypes;
// This doesnt account for implicit broadcast, but the working assumption
// here is that are broadcasts have been made explicit.
unsigned nloops = argType.getRank();
if (isLHLO && !nloops) return failure();
int operandCount = (isLHLO ? args.size() - 1 : args.size());
auto verifyArgOrResultType = [&](Value val) -> ShapedType {
auto shapedType = val.getType().dyn_cast<ShapedType>();
if (!shapedType ||
(!shapedType.isa<MemRefType>() &&
!shapedType.isa<RankedTensorType>()) ||
shapedType.getRank() != nloops)
return nullptr;
indexing_maps.emplace_back(
nloops ? rewriter.getMultiDimIdentityMap(nloops)
: AffineMap::get(nloops, 0, rewriter.getContext()));
return shapedType;
};
for (const auto& arg : llvm::enumerate(args)) {
auto shapedType = verifyArgOrResultType(arg.value());
if (!shapedType) return failure();
auto& result_or_body_arg =
arg.index() < operandCount ? bodyArgTypes : bodyResultTypes;
result_or_body_arg.emplace_back(shapedType.getElementType());
}
if (!isLHLO) {
// HLO operations have return as tensor types.
assert(bodyResultTypes.empty() &&
"When lowering HLO ops result can't be part of arguments");
Value result = op.getOperation()->getResult(0);
auto shapedType = verifyArgOrResultType(result);
if (!shapedType) return failure();
bodyResultTypes.push_back(shapedType.getElementType());
opResultTypes.push_back(shapedType);
}
int64_t args_count = bodyArgTypes.size();
int64_t results_count = bodyResultTypes.size();
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, opResultTypes, args, args_count, results_count, indexing_maps,
GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
// TODO(ravishankarm) : For now use the method in xla_lhlo namespace.
// That method needs to be moved out of there.
Value opResult = xla_lhlo::XlaOpToStdScalarOp::map<OpTy>(
op, bodyResultTypes,
llvm::to_vector<2>(args.take_front(args_count)), &rewriter);
nestedBuilder.create<linalg::YieldOp>(loc, opResult);
});
rewriter.replaceOp(op, linalgOp.getOperation()->getResults());
return success();
}
};
template <typename LhloOp>
class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
public:
using OpConversionPattern<LhloOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
LhloOp lhlo_op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
auto loc = lhlo_op.getLoc();
auto argType =
lhlo_op.getOperand(0).getType().template dyn_cast<ShapedType>();
if (!argType || !argType.getElementType().isSignlessIntOrFloat() ||
(argType.getRank() != 0)) {
return failure();
}
// Create two loads from the input.
auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
// TODO(ravishankarm) : Move this method out of xla_lhlo namespace.
Value opResult = xla_lhlo::XlaOpToStdScalarOp::map<LhloOp>(
lhlo_op, argType.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
&rewriter);
rewriter.create<StoreOp>(loc, opResult, lhlo_op.out());
rewriter.eraseOp(lhlo_op);
return success();
}
};
//===----------------------------------------------------------------------===//
// xla_lhlo.convolution conversion pattern.
//===----------------------------------------------------------------------===//
/// Converts xla_lhlo.convolution operation to a linalg.conv op.
struct ConvToLinalgConverter : public OpConversionPattern<xla_lhlo::ConvOp> {
public:
using OpConversionPattern<xla_lhlo::ConvOp>::OpConversionPattern;
// This code has been adapted from IREE's
// (https://github.com/google/iree/) xla_hlo -> linalg conversion.
LogicalResult matchAndRewrite(
xla_lhlo::ConvOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
// Check validity of dimension information.
if (const xla_lhlo::ConvDimensionNumbers& dimensionNumbers =
op.dimension_numbers()) {
const int inputSpatialRank =
llvm::size(dimensionNumbers.input_spatial_dimensions());
// The dimensions for input should follow the order of
// batch_count, spatial_dims..., input_feature_count.
if (dimensionNumbers.input_batch_dimension().getInt() != 0 ||
dimensionNumbers.input_feature_dimension().getInt() !=
(inputSpatialRank + 1))
return failure();
const int kernelSpatialRank =
llvm::size(dimensionNumbers.kernel_spatial_dimensions());
// The dimensions for filter should follow the order of
// spatial_dims..., input_feature_count, num_output_feature_count.
if (dimensionNumbers.kernel_input_feature_dimension().getInt() !=
kernelSpatialRank ||
dimensionNumbers.kernel_output_feature_dimension().getInt() !=
(kernelSpatialRank + 1))
return failure();
const int outputSpatialRank =
llvm::size(dimensionNumbers.output_spatial_dimensions());
// The dimensions for output should follow the order of
// batch_count, spatial_dims.., output_feature_count.
if (dimensionNumbers.output_batch_dimension().getInt() != 0 ||
dimensionNumbers.output_feature_dimension().getInt() !=
(outputSpatialRank + 1))
return failure();
if (inputSpatialRank != outputSpatialRank ||
inputSpatialRank != kernelSpatialRank)
return failure();
auto inputSpatialDim =
dimensionNumbers.input_spatial_dimensions().begin();
auto kernelSpatialDim =
dimensionNumbers.kernel_spatial_dimensions().begin();
auto outputSpatialDim =
dimensionNumbers.output_spatial_dimensions().begin();
// Check if spatial dims are ordered correctly.
for (int i = 0; i < inputSpatialRank; ++i) {
const int dim = i + 1;
if ((*inputSpatialDim++).getZExtValue() != dim ||
(*outputSpatialDim++).getZExtValue() != dim ||
(*kernelSpatialDim++).getZExtValue() != i)
return failure();
}
}
// TODO: LHS dilation for deconvolution not supported yet.
if (op.lhs_dilation()) {
return failure();
}
llvm::SmallVector<Attribute, 4> strides;
if (auto windowStrides = op.window_strides()) {
auto range = windowStrides->getAttributeValues();
strides.assign(range.begin(), range.end());
}
auto stridesArg = ArrayAttr::get(strides, op.getContext());
llvm::SmallVector<Attribute, 2> dilation;
if (auto rhsDilation = op.rhs_dilation()) {
auto range = rhsDilation->getAttributeValues();
dilation.assign(range.begin(), range.end());
} else {
// Default dilation of 1.
dilation.resize(2, IntegerAttr::get(rewriter.getIntegerType(64), 1));
}
auto dilationArg = ArrayAttr::get(dilation, op.getContext());
// Set padding only if it is non-zero.
DenseIntElementsAttr padding = op.paddingAttr();
if (!padding || !llvm::any_of(padding.getValues<APInt>(), [](APInt intVal) {
return !intVal.isNullValue();
})) {
padding = nullptr;
}
// The order of input and filter are switched with linalg.conv.
rewriter.replaceOpWithNewOp<linalg::ConvOp>(
op, args[1], args[0], args[2], stridesArg, dilationArg, padding);
return success();
}
};
/// Base class for lowering xla operations that have one operand and one result,
/// and are semantically equivalent to a copy of the input to the output (like
/// transpose, some reshape, etc.). The derived classes need to provide a method
/// `getIndexingMaps` that returns AffineMaps for the index maps of the input
/// and the output.
template <typename Derived, typename OpTy, bool isLHLO = true>
class DataMovementOpConverter : public OpConversionPattern<OpTy> {
public:
using OpConversionPattern<OpTy>::OpConversionPattern;
LogicalResult matchAndRewrite(
OpTy op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
if (!verifyXLAOpBufferOrTensorSemantics<isLHLO>(op)) return failure();
auto resultType = getXLAOpResultType<isLHLO>(op);
SmallVector<AffineMap, 2> indexing_maps =
Derived::getIndexingMaps(op, &rewriter);
if (indexing_maps.empty()) return failure();
auto nloops = resultType.getRank();
auto loc = op.getLoc();
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, isLHLO ? ArrayRef<Type>{} : resultType, args, /*inputCount=*/1,
/*outputCount=*/1, indexing_maps, GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
});
rewriter.replaceOp(op, linalgOp.getOperation()->getResults());
return success();
}
};
/// Pattern to convert BroadcastOp to Linalg ops.
template <typename OpTy, bool isLHLO = true>
class BroadcastConverter
: public DataMovementOpConverter<BroadcastConverter<OpTy, isLHLO>, OpTy,
isLHLO> {
public:
using DataMovementOpConverter<BroadcastConverter, OpTy,
isLHLO>::DataMovementOpConverter;
static SmallVector<AffineMap, 2> getIndexingMaps(OpTy broadcastOp,
Builder* b) {
ShapedType inputType =
broadcastOp.operand().getType().template cast<ShapedType>();
unsigned inputRank = inputType.getRank();
unsigned nloops = getXLAOpResultType<isLHLO>(broadcastOp).getRank();
// BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute to
// the input's dimensions.
unsigned numPrependedDims = llvm::size(broadcastOp.broadcast_sizes());
SmallVector<AffineExpr, 4> inputDimExprs;
inputDimExprs.reserve(inputRank);
for (int i = 0; i < inputRank; ++i) {
inputDimExprs.push_back(b->getAffineDimExpr(numPrependedDims + i));
}
AffineMap inputMap;
MLIRContext* context = b->getContext();
if (inputDimExprs.empty()) {
// The input is a scalar, i.e. this is a scalar broadcast op.
inputMap = AffineMap::get(nloops, /*symbolCount=*/0, context);
} else {
inputMap =
AffineMap::get(nloops, /*symbolCount=*/0, inputDimExprs, context);
}
return {inputMap, b->getMultiDimIdentityMap(nloops)};
}
};
class HloBroadcastInDimConverter
: public DataMovementOpConverter<HloBroadcastInDimConverter,
xla_hlo::BroadcastInDimOp, false> {
public:
using DataMovementOpConverter<HloBroadcastInDimConverter,
xla_hlo::BroadcastInDimOp,
false>::DataMovementOpConverter;
static SmallVector<AffineMap, 2> getIndexingMaps(
xla_hlo::BroadcastInDimOp broadcastOp, Builder* b) {
auto resultType = getXLAOpResultType<false>(broadcastOp);
auto operandType =
broadcastOp.operand().getType().template cast<ShapedType>();
unsigned nloops = resultType.getRank();
// The input is a scalar, i.e. this is a scalar broadcast op.
if (operandType.getRank() == 0) {
return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
b->getMultiDimIdentityMap(nloops)};
}
auto operandShape = operandType.getShape();
SmallVector<AffineExpr, 4> dimExprs;
dimExprs.reserve(nloops);
if (broadcastOp.broadcast_dimensions()) {
for (const auto& broadcastDim :
enumerate(broadcastOp.broadcast_dimensions().getIntValues())) {
int size = broadcastDim.value().getSExtValue();
bool expansion_needed = operandShape[broadcastDim.index()] == 1 &&
resultType.getShape()[size] != 1;
dimExprs.push_back(expansion_needed ? b->getAffineConstantExpr(0)
: b->getAffineDimExpr(size));
}
}
return {
AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()),
b->getMultiDimIdentityMap(nloops)};
}
};
class LhloBroadcastInDimConverter
: public OpConversionPattern<xla_lhlo::BroadcastInDimOp> {
public:
using OpConversionPattern<xla_lhlo::BroadcastInDimOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
xla_lhlo::BroadcastInDimOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
xla_lhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
auto result_type = operand_adaptor.output().getType().cast<MemRefType>();
auto result_shape = result_type.getShape();
auto operand_and_dims = InsertReshapeIfNecessary(op, args, rewriter);
Value operand = std::get<0>(operand_and_dims);
auto broadcast_dims = std::get<1>(operand_and_dims);
auto loc = op.getLoc();
auto nloops = result_type.getRank();
auto operand_type = operand.getType().cast<MemRefType>();
// For a degenerate case, i.e. broadcasting with expansion of
// memref<1xELEMENT_TYPE>, the operand is not passed to `linalg.generic`.
// Instead the value is loaded and used directly in `linalg.yield`.
if (operand_type.getRank() == 1 &&
operand_type.getDimSize(0) <
result_type.getDimSize(broadcast_dims.front())) {
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
Value val =
rewriter.create<LoadOp>(loc, operand, llvm::makeArrayRef({zero}));
rewriter.create<linalg::GenericOp>(
loc, llvm::None, llvm::makeArrayRef(operand_adaptor.output()),
/*inputCount=*/0, /*outputCount=*/1,
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
nestedBuilder.create<linalg::YieldOp>(loc, val);
});
} else {
auto indexing_maps = getIndexingMaps(op, broadcast_dims, result_shape,
operand_type, &rewriter);
rewriter.create<linalg::GenericOp>(
loc, llvm::None,
llvm::makeArrayRef({operand, operand_adaptor.output()}),
/*inputCount=*/1, /*outputCount=*/1, indexing_maps,
GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
});
}
rewriter.replaceOp(op, llvm::None);
return success();
}
// Inserts 'linalg.reshape' if there is a size-1 dim expansion.
std::pair<Value, SmallVector<int64_t, 2>> InsertReshapeIfNecessary(
xla_lhlo::BroadcastInDimOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const {
xla_lhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
Value operand = operand_adaptor.operand();
auto operand_type = operand_adaptor.operand().getType().cast<MemRefType>();
auto operand_shape = operand_type.getShape();
Value result = operand_adaptor.output();
auto result_type = result.getType().cast<MemRefType>();
auto result_shape = result_type.getShape();
SmallVector<int64_t, 2> operand_strides;
int64_t operand_offset;
if (failed(getStridesAndOffset(operand_type, operand_strides,
operand_offset))) {
op.emitOpError() << "Failed to get offset and strides.";
}
SmallVector<int64_t, 2> new_shape, new_strides, broadcast_dims;
SmallVector<linalg::ReassociationIndices, 4> collapsed_dims_list;
linalg::ReassociationIndices collapsed_dims;
for (const auto& item :
enumerate(op.broadcast_dimensions().getIntValues())) {
size_t index = item.index();
int dim = item.value().getSExtValue();
collapsed_dims.push_back(index);
bool expansion_needed =
operand_shape[index] == 1 && result_shape[dim] != 1;
if (expansion_needed) {
continue;
}
new_shape.push_back(operand_shape[index]);
new_strides.push_back(operand_strides[index]);
broadcast_dims.push_back(dim);
collapsed_dims_list.push_back(collapsed_dims);
collapsed_dims.clear();
}
// If `collapsed_dims_list` is empty, then the memref has shape [1, ..., 1]
// and all dimensions need expansion. Such memref will be reshaped to a 1D
// memref with a single element. New shape and strides needs to be updated
// accordingly.
if (collapsed_dims_list.empty()) {
collapsed_dims_list.push_back({});
new_shape.push_back(1);
new_strides.push_back(1);
broadcast_dims.push_back(0);
}
for (const auto& dims : collapsed_dims) {
collapsed_dims_list.back().push_back(dims);
}
// `linalg.reshape` is inserted only if necessary, i.e. when the rank can be
// reduced.
if (new_shape.size() < operand_shape.size()) {
auto new_memref_type = MemRefType::get(
new_shape, operand_type.getElementType(),
makeStridedLinearLayoutMap(new_strides, operand_offset,
rewriter.getContext()));
operand = rewriter.create<linalg::ReshapeOp>(op.getLoc(), new_memref_type,
operand_adaptor.operand(),
collapsed_dims_list);
}
return std::make_pair(operand, broadcast_dims);
}
SmallVector<AffineMap, 2> getIndexingMaps(xla_lhlo::BroadcastInDimOp op,
ArrayRef<int64_t> broadcastDims,
ArrayRef<int64_t> resultShape,
MemRefType operandType,
Builder* b) const {
unsigned nloops = resultShape.size();
// The input is a scalar, i.e. this is a scalar broadcast op.
if (operandType.getRank() == 0) {
return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
b->getMultiDimIdentityMap(nloops)};
}
auto operandShape = operandType.getShape();
SmallVector<AffineExpr, 4> dimExprs;
dimExprs.reserve(nloops);
for (const auto& broadcastDim : llvm::enumerate(broadcastDims)) {
int size = broadcastDim.value();
bool expansion_needed =
operandShape[broadcastDim.index()] == 1 && resultShape[size] != 1;
if (expansion_needed) {
op.emitOpError(
"BroadcastInDimOp lowering to Linalg does not support size-1 "
"dimensions expansion.");
}
dimExprs.push_back(b->getAffineDimExpr(size));
}
return {
AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()),
b->getMultiDimIdentityMap(nloops)};
}
};
template <typename OpTy, bool isLHLO = true>
class TransposeConverter
: public DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
isLHLO> {
public:
using DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
isLHLO>::DataMovementOpConverter;
static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
auto resultType =
getXLAOpResultType<isLHLO>(op).template cast<ShapedType>();
auto nloops = resultType.getRank();
SmallVector<AffineExpr, 2> inputExprs;
inputExprs.resize(resultType.getRank());
for (auto permutation : llvm::enumerate(op.permutation())) {
inputExprs[permutation.value().getZExtValue()] =
b->getAffineDimExpr(permutation.index());
}
return {
AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
b->getMultiDimIdentityMap(nloops)};
}
};
// Converts reshape ops that can be proven to be either a collapse of dimensions
// or expansion of dimensions of the operand.
template <typename OpTy, bool isLHLO = true>
class ReshapeOpConverter : public OpConversionPattern<OpTy> {
public:
using OpConversionPattern<OpTy>::OpConversionPattern;
LogicalResult matchAndRewrite(
OpTy reshapeOp, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
if (!verifyXLAOpBufferOrTensorSemantics<isLHLO>(reshapeOp))
return failure();
ShapedType operandType =
reshapeOp.operand().getType().template cast<ShapedType>();
ShapedType resultType = getXLAOpResultType<isLHLO>(reshapeOp);
if (!operandType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
// Compute the reassociation maps for the linalg operation.
ArrayRef<int64_t> srcShape =
(operandType.getRank() > resultType.getRank() ? operandType.getShape()
: resultType.getShape());
ArrayRef<int64_t> dstShape =
(operandType.getRank() > resultType.getRank() ? resultType.getShape()
: operandType.getShape());
unsigned currSrcDim = 0, currDstDim = 0;
SmallVector<linalg::ReassociationExprs, 4> reassociationMap(
dstShape.size());
while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
int64_t dstSize = dstShape[currDstDim];
int64_t srcSize = srcShape[currSrcDim];
while (srcSize < dstSize && currSrcDim < srcShape.size()) {
reassociationMap[currDstDim].push_back(
rewriter.getAffineDimExpr(currSrcDim++));
srcSize *= srcShape[currSrcDim];
}
if (srcSize == dstSize) {
reassociationMap[currDstDim].push_back(
rewriter.getAffineDimExpr(currSrcDim++));
// If the next dim in dstShape is not 1, treat subsequent dims in
// srcShape which are 1 to be collapsed.
if (currDstDim == dstShape.size() - 1 ||
dstShape[currDstDim + 1] != 1) {
while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
reassociationMap[currDstDim].push_back(
rewriter.getAffineDimExpr(currSrcDim++));
}
}
} else {
return failure();
}
currDstDim++;
}
if (currSrcDim != srcShape.size()) return failure();
if (isLHLO) {
Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>(
reshapeOp.getLoc(), resultType, args[0], reassociationMap);
rewriter.replaceOpWithNewOp<linalg::CopyOp>(
reshapeOp, reshapeBuffer, args[1], /*inputPermutation =*/nullptr,
/*outputPermutation =*/nullptr);
} else {
rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
reshapeOp, resultType, args[0], reassociationMap);
}
return success();
}
};
class IotaConverter : public OpConversionPattern<xla_lhlo::IotaOp> {
public:
using OpConversionPattern<xla_lhlo::IotaOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
xla_lhlo::IotaOp iotaOp, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
auto resultMemrefType =
iotaOp.getOperand().getType().dyn_cast<MemRefType>();
if (!resultMemrefType) return failure();
auto resultElementType = resultMemrefType.getElementType();
if (!resultElementType.isSignlessIntOrFloat()) return failure();
// Construct the indexing maps needed for linalg.generic ops.
unsigned nloops = resultMemrefType.getRank();
rewriter.create<linalg::IndexedGenericOp>(
iotaOp.getLoc(), ArrayRef<Type>{}, args,
0, // args_in
1, // args_out
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange ivs,
ValueRange args) {
Value castOp = nestedBuilder.create<IndexCastOp>(
nestedLoc, ivs[iotaOp.iota_dimension().getZExtValue()],
nestedBuilder.getIntegerType(
resultElementType.getIntOrFloatBitWidth()));
if (resultElementType.isa<FloatType>()) {
castOp = nestedBuilder.create<SIToFPOp>(nestedLoc, castOp,
resultElementType);
}
nestedBuilder.create<linalg::YieldOp>(nestedLoc, castOp);
});
rewriter.replaceOp(iotaOp, llvm::None);
return success();
}
};
class ConstConverter : public OpConversionPattern<xla_lhlo::ConstOp> {
public:
using OpConversionPattern<xla_lhlo::ConstOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
xla_lhlo::ConstOp constOp, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
auto loc = constOp.getLoc();
auto valueAttr = constOp.value().cast<DenseElementsAttr>();
if (valueAttr.getType().getRank() != 0) return failure();
auto stdConstOp =
rewriter.create<mlir::ConstantOp>(loc, valueAttr.getValue({}));
rewriter.create<mlir::StoreOp>(loc, stdConstOp, constOp.getOperand());
rewriter.eraseOp(constOp);
return success();
}
};
// TODO(b/156787842): Support the lowering for dynamic shapes.
template <typename OpTy, bool isLHLO = true>
class ReverseConverter
: public DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy,
isLHLO> {
public:
using DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy,
isLHLO>::DataMovementOpConverter;
static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
auto resultType =
getXLAOpResultType<isLHLO>(op).template cast<ShapedType>();
auto nloops = resultType.getRank();
SmallVector<AffineExpr, 2> inputExprs;
inputExprs.reserve(nloops);
for (int i = 0; i < nloops; ++i)
inputExprs.push_back(b->getAffineDimExpr(i));
for (auto dim : op.dimensions()) {
int i = dim.getZExtValue();
if (resultType.isDynamicDim(i)) return {};
int n = resultType.getShape()[i];
inputExprs[i] = b->getAffineConstantExpr(n - 1) - inputExprs[i];
}
return {
AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
b->getMultiDimIdentityMap(nloops)};
}
};
class SliceConverter : public OpConversionPattern<xla_lhlo::SliceOp> {
public:
using OpConversionPattern<xla_lhlo::SliceOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
xla_lhlo::SliceOp sliceOp, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
auto loc = sliceOp.getLoc();
auto argType =
sliceOp.getOperand(0).getType().template dyn_cast<ShapedType>();
if (!argType || !argType.hasRank()) {
emitError(loc, "lhlo to linalg conversion expects known-rank args");
return failure();
}
SmallVector<Value, 3> ranges;
for (int i = 0, e = argType.getRank(); i < e; ++i) {
Value start_index = rewriter.create<ConstantIndexOp>(
loc, sliceOp.start_indices().getValue<int64_t>(i));
Value limit_index = rewriter.create<ConstantIndexOp>(
loc, sliceOp.limit_indices().getValue<int64_t>(i));
Value stride = rewriter.create<ConstantIndexOp>(
loc, sliceOp.strides().getValue<int64_t>(i));
ranges.push_back(rewriter.create<linalg::RangeOp>(loc, start_index,
limit_index, stride));
}
auto linalg_slice =
rewriter.create<linalg::SliceOp>(loc, sliceOp.getOperand(0), ranges);
rewriter.create<linalg::CopyOp>(loc, linalg_slice, sliceOp.getOperand(1));
rewriter.eraseOp(sliceOp);
return success();
}
};
void populateLHLOToLinalgConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) {
// clang-format off
patterns->insert<BroadcastConverter<xla_lhlo::BroadcastOp>,
ConstConverter,
ConvToLinalgConverter,
IotaConverter,
LhloBroadcastInDimConverter,
PointwiseToLinalgConverter<xla_lhlo::AbsOp>,
PointwiseToLinalgConverter<xla_lhlo::AddOp>,
PointwiseToLinalgConverter<xla_lhlo::AndOp>,
PointwiseToLinalgConverter<xla_lhlo::CeilOp>,
PointwiseToLinalgConverter<xla_lhlo::CompareOp>,
PointwiseToLinalgConverter<xla_lhlo::ComplexOp>,
PointwiseToLinalgConverter<xla_lhlo::ConvertOp>,
// TODO(ataei): Remove this pattern, CopyOp is folded away.
PointwiseToLinalgConverter<xla_lhlo::CopyOp>,
PointwiseToLinalgConverter<xla_lhlo::CosOp>,
PointwiseToLinalgConverter<xla_lhlo::DivOp>,
PointwiseToLinalgConverter<xla_lhlo::ExpOp>,
PointwiseToLinalgConverter<xla_lhlo::ImagOp>,
PointwiseToLinalgConverter<xla_lhlo::LogOp>,
PointwiseToLinalgConverter<xla_lhlo::MaxOp>,
PointwiseToLinalgConverter<xla_lhlo::MinOp>,
PointwiseToLinalgConverter<xla_lhlo::MulOp>,
PointwiseToLinalgConverter<xla_lhlo::NegOp>,
PointwiseToLinalgConverter<xla_lhlo::RealOp>,
PointwiseToLinalgConverter<xla_lhlo::RemOp>,
PointwiseToLinalgConverter<xla_lhlo::RsqrtOp>,
PointwiseToLinalgConverter<xla_lhlo::SelectOp>,
PointwiseToLinalgConverter<xla_lhlo::SignOp>,
PointwiseToLinalgConverter<xla_lhlo::SinOp>,
PointwiseToLinalgConverter<xla_lhlo::SqrtOp>,
PointwiseToLinalgConverter<xla_lhlo::SubOp>,
PointwiseToLinalgConverter<xla_lhlo::TanhOp>,
ReshapeOpConverter<xla_lhlo::ReshapeOp>,
ReverseConverter<xla_lhlo::ReverseOp>,
ScalarPointwiseToStandardConverter<xla_lhlo::AddOp>,
SliceConverter
>(context);
// clang-format on
}
// Converts LHLO ops to Linalg generic.
// Sample result for xla_lhlo::AddOp.
//
// "xla_lhlo.add"(%arg1, %arg2, %out) :
// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
//
// will be converted to
//
// #map0 = (d0, d1) -> (d0, d1)
// "linalg.generic"(%arg1, %arg2, %out) ( {
// ^bb0(%arg4: f32, %arg5: f32):
// %0 = addf %arg4, %arg5 : f32
// "linalg.yield"(%0) : (f32) -> ()
// }) {
// args_in = 2,
// args_out = 1,
// indexing_maps = [#map0, #map0, #map0],
// iterator_types = ["parallel", "parallel"],
// } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
struct LhloLegalizeToLinalg
: public PassWrapper<LhloLegalizeToLinalg, FunctionPass> {
void runOnFunction() override {
OwningRewritePatternList patterns;
ConversionTarget target(getContext());
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();
auto func = getFunction();
populateLHLOToLinalgConversionPattern(func.getContext(), &patterns);
if (failed(applyPartialConversion(func, target, patterns, nullptr))) {
signalPassFailure();
}
}
};
struct HloLegalizeToLinalg
: public PassWrapper<HloLegalizeToLinalg, FunctionPass> {
void runOnFunction() override {
OwningRewritePatternList patterns;
ConversionTarget target(getContext());
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();
auto func = getFunction();
xla_hlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
if (failed(applyPartialConversion(func, target, patterns, nullptr))) {
signalPassFailure();
}
}
};
} // namespace
namespace xla_lhlo {
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToLinalgPass() {
return absl::make_unique<LhloLegalizeToLinalg>();
}
static PassRegistration<LhloLegalizeToLinalg> legalize_lhlo_pass(
"lhlo-legalize-to-linalg", "Legalize from LHLO dialect to Linalg dialect");
} // namespace xla_lhlo
namespace xla_hlo {
void populateHLOToLinalgConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) {
patterns->insert<BroadcastConverter<xla_hlo::BroadcastOp, false>,
HloBroadcastInDimConverter,
PointwiseToLinalgConverter<xla_hlo::AbsOp, false>,
PointwiseToLinalgConverter<xla_hlo::AddOp, false>,
PointwiseToLinalgConverter<xla_hlo::AndOp, false>,
PointwiseToLinalgConverter<xla_hlo::CeilOp, false>,
PointwiseToLinalgConverter<xla_hlo::CompareOp, false>,
PointwiseToLinalgConverter<xla_hlo::ComplexOp, false>,
PointwiseToLinalgConverter<xla_hlo::ConvertOp, false>,
PointwiseToLinalgConverter<xla_hlo::CopyOp, false>,
PointwiseToLinalgConverter<xla_hlo::CosOp, false>,
PointwiseToLinalgConverter<xla_hlo::DivOp, false>,
PointwiseToLinalgConverter<xla_hlo::ExpOp, false>,
PointwiseToLinalgConverter<xla_hlo::ImagOp, false>,
PointwiseToLinalgConverter<xla_hlo::LogOp, false>,
PointwiseToLinalgConverter<xla_hlo::MaxOp, false>,
PointwiseToLinalgConverter<xla_hlo::MinOp, false>,
PointwiseToLinalgConverter<xla_hlo::MulOp, false>,
PointwiseToLinalgConverter<xla_hlo::NegOp, false>,
PointwiseToLinalgConverter<xla_hlo::RealOp, false>,
PointwiseToLinalgConverter<xla_hlo::RemOp, false>,
PointwiseToLinalgConverter<xla_hlo::RsqrtOp, false>,
PointwiseToLinalgConverter<xla_hlo::SelectOp, false>,
PointwiseToLinalgConverter<xla_hlo::SinOp, false>,
PointwiseToLinalgConverter<xla_hlo::SqrtOp, false>,
PointwiseToLinalgConverter<xla_hlo::SubOp, false>,
PointwiseToLinalgConverter<xla_hlo::TanhOp, false>,
ReshapeOpConverter<xla_hlo::ReshapeOp, false>,
ReverseConverter<xla_hlo::ReverseOp, false>,
TransposeConverter<xla_hlo::TransposeOp, false>>(context);
}
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
return absl::make_unique<HloLegalizeToLinalg>();
}
static PassRegistration<HloLegalizeToLinalg> legalize_hlo_pass(
"hlo-legalize-to-linalg", "Legalize from HLO dialect to Linalg dialect");
} // namespace xla_hlo
} // namespace mlir

View File

@ -0,0 +1,188 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/absl/memory/memory.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir {
namespace xla_hlo {
namespace {
// TODO(frgossen): Make it variadic.
template <typename OpTy>
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
target->addDynamicallyLegalOp<OpTy>([](OpTy op) {
return llvm::all_of((op.getOperation())->getOperandTypes(),
[&](Type t) { return t.isa<RankedTensorType>(); });
});
}
/// Unary element-wise operations on unranked tensors can be applied to the
/// flattened tensor with the same effect.
/// This pattern rewrites every such operation to
/// (i) flatten the input tensor,
/// (ii) apply the unary operation, and
/// (iii) restore the original shape.
template <typename OpTy>
struct UnaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
explicit UnaryElementwiseOpConversion(MLIRContext *context)
: OpRewritePattern<OpTy>(context) {}
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
// Don't apply conversion to ops with statically shaped operands.
Value operand = op.getOperand();
auto operandTy = operand.getType().dyn_cast<TensorType>();
if (operandTy.hasRank()) return failure();
// Generate IR to flatten the operand.
auto loc = op.getLoc();
Value shape = rewriter.create<shape::ShapeOfOp>(loc, operand);
Value numElements = rewriter.create<shape::NumElementsOp>(
loc, rewriter.getType<shape::SizeType>(), shape);
Value numElementsAsIndex = rewriter.create<shape::SizeToIndexOp>(
loc, rewriter.getIndexType(), numElements);
Value flatShapeAsDimTensor =
rewriter.create<TensorFromElementsOp>(loc, numElementsAsIndex);
auto flatTensorTy = RankedTensorType::get({ShapedType::kDynamicSize},
operandTy.getElementType());
Value flatOperand = rewriter.create<xla_hlo::DynamicReshapeOp>(
loc, flatTensorTy, operand, flatShapeAsDimTensor);
// Generate IR for the actual operation.
Value flatResult = rewriter.create<OpTy>(loc, flatTensorTy, flatOperand);
// Generate IR to restore the original shape.
auto extentTensorTy = RankedTensorType::get({ShapedType::kDynamicSize},
rewriter.getIndexType());
Value shapeAsExtentTensor =
rewriter.create<shape::ToExtentTensorOp>(loc, extentTensorTy, shape);
Value result = rewriter.create<xla_hlo::DynamicReshapeOp>(
loc, operandTy, flatResult, shapeAsExtentTensor);
rewriter.replaceOp(op, result);
return success();
}
};
/// Binary element-wise operation on unranked tensors can be applied to the
/// flattened operand tensors with the same effect.
/// This pattern rewrites every such operation to
/// (i) flatten the operand tensors,
/// (ii) apply the binary operation, and
// (iii) restore the original shape.
template <typename OpTy>
struct BinaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
explicit BinaryElementwiseOpConversion(MLIRContext *context)
: OpRewritePattern<OpTy>(context) {}
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
// Don't apply conversion unless both operands are unranked.
if (op.lhs().getType().template isa<RankedTensorType>() ||
op.rhs().getType().template isa<RankedTensorType>()) {
return failure();
}
// Flatten operands.
Type shapeTy = shape::ShapeType::get(rewriter.getContext());
auto loc = op.getLoc();
Value shapeLhs = rewriter.create<shape::ShapeOfOp>(loc, op.lhs());
Value shapeRhs = rewriter.create<shape::ShapeOfOp>(loc, op.rhs());
Value shape = rewriter.create<shape::AnyOp>(loc, shapeTy,
ValueRange{shapeLhs, shapeRhs});
Value numElements = rewriter.create<shape::NumElementsOp>(loc, shape);
Value numElementsAsIndex =
rewriter.create<shape::SizeToIndexOp>(loc, numElements);
Value flatShape =
rewriter.create<TensorFromElementsOp>(loc, numElementsAsIndex);
TensorType lhsTy = op.lhs().getType().template cast<TensorType>();
Type flatLhsTy = RankedTensorType::get({ShapedType::kDynamicSize},
lhsTy.getElementType());
Value flatLhs =
rewriter.create<DynamicReshapeOp>(loc, flatLhsTy, op.lhs(), flatShape);
TensorType rhsTy = op.rhs().getType().template cast<TensorType>();
Type flatRhsTy = RankedTensorType::get({ShapedType::kDynamicSize},
rhsTy.getElementType());
Value flatRhs =
rewriter.create<DynamicReshapeOp>(loc, flatRhsTy, op.rhs(), flatShape);
// Apply actual operation to flattened operands.
Value flatResult = rewriter.create<OpTy>(loc, flatLhs, flatRhs);
// Restore original shape.
auto extentTensorTy = RankedTensorType::get({ShapedType::kDynamicSize},
rewriter.getIndexType());
Value shapeAsExtentTensor =
rewriter.create<shape::ToExtentTensorOp>(loc, extentTensorTy, shape);
Value result = rewriter.create<DynamicReshapeOp>(
loc, op.getType(), flatResult, shapeAsExtentTensor);
rewriter.replaceOp(op, result);
return success();
}
};
struct TransformUnrankedHloPass
: public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
void runOnFunction() override {
// Setup conversion target.
MLIRContext &ctx = getContext();
ConversionTarget target(ctx);
target.addLegalDialect<XlaHloDialect, StandardOpsDialect,
shape::ShapeDialect>();
target.addLegalOp<FuncOp>();
AddLegalOpOnRankedTensor<SqrtOp>(&target);
AddLegalOpOnRankedTensor<AddOp>(&target);
// Populate rewrite patterns.
OwningRewritePatternList patterns;
PopulateTransformUnrankedHloPatterns(&ctx, &patterns);
// Apply transformation.
if (failed(applyFullConversion(getFunction(), target, patterns)))
return signalPassFailure();
}
};
} // namespace
void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
OwningRewritePatternList *patterns) {
// TODO(frgossen): Populate all unary and binary operations.
// clang-format off
patterns->insert<
BinaryElementwiseOpConversion<AddOp>,
UnaryElementwiseOpConversion<SqrtOp>>(context);
// clang-format on
}
static PassRegistration<TransformUnrankedHloPass> transform_unranked_hlo_pass(
"transform-unranked-hlo",
"Realize element-wise operations on ranked tensors where possible");
} // namespace xla_hlo
} // namespace mlir

340
lib/utils/cycle_detector.cc Normal file
View File

@ -0,0 +1,340 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/cycle_detector.h"
#include <algorithm>
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/DenseSet.h"
namespace mlir {
namespace {
using NodeSet = llvm::DenseSet<int32_t>;
using OrderedNodeSet = OrderedSet<int32_t>;
template <typename T>
struct VecStruct {
typedef llvm::SmallVector<T, 4> type;
};
template <typename T>
using Vec = typename VecStruct<T>::type;
struct Node {
// rank number assigned by Pearce-Kelly algorithm
int32_t rank;
// Temporary marker used by depth-first-search
bool visited;
// User-supplied data
void* data;
// List of immediate predecessor nodes in graph
OrderedNodeSet in;
// List of immediate successor nodes in graph
OrderedNodeSet out;
};
} // namespace
struct GraphCycles::Rep {
Vec<Node*> nodes;
// Indices for unused entries in nodes
Vec<int32_t> free_nodes;
// Temporary state.
// Results of forward DFS
Vec<int32_t> deltaf;
// Results of backward DFS
Vec<int32_t> deltab;
// All nodes to reprocess
Vec<int32_t> list;
// Rank values to assign to list entries
Vec<int32_t> merged;
// Emulates recursion stack when doing depth first search
Vec<int32_t> stack;
};
GraphCycles::GraphCycles(int32_t num_nodes) : rep_(new Rep) {
rep_->nodes.reserve(num_nodes);
for (int32_t i = 0; i < num_nodes; ++i) {
Node* n = new Node;
n->visited = false;
n->data = nullptr;
n->rank = rep_->nodes.size();
rep_->nodes.push_back(n);
}
}
GraphCycles::~GraphCycles() {
for (Vec<Node*>::size_type i = 0, e = rep_->nodes.size(); i < e; ++i) {
delete rep_->nodes[i];
}
delete rep_;
}
bool GraphCycles::HasEdge(int32_t x, int32_t y) const {
return rep_->nodes[x]->out.Contains(y);
}
void GraphCycles::RemoveEdge(int32_t x, int32_t y) {
rep_->nodes[x]->out.Erase(y);
rep_->nodes[y]->in.Erase(x);
// No need to update the rank assignment since a previous valid
// rank assignment remains valid after an edge deletion.
}
static bool ForwardDFS(GraphCycles::Rep* r, int32_t n, int32_t upper_bound);
static void BackwardDFS(GraphCycles::Rep* r, int32_t n, int32_t lower_bound);
static void Reorder(GraphCycles::Rep* r);
static void Sort(const Vec<Node*>&, Vec<int32_t>* delta);
static void MoveToList(GraphCycles::Rep* r, Vec<int32_t>* src,
Vec<int32_t>* dst);
static void ClearVisitedBits(GraphCycles::Rep* r, const Vec<int32_t>& nodes);
bool GraphCycles::InsertEdge(int32_t x, int32_t y) {
if (x == y) return false;
Rep* r = rep_;
Node* nx = r->nodes[x];
if (!nx->out.Insert(y)) {
// Edge already exists.
return true;
}
Node* ny = r->nodes[y];
ny->in.Insert(x);
if (nx->rank <= ny->rank) {
// New edge is consistent with existing rank assignment.
return true;
}
// Current rank assignments are incompatible with the new edge. Recompute.
// We only need to consider nodes that fall in the range [ny->rank,nx->rank].
if (ForwardDFS(r, y, nx->rank)) {
// Found a cycle. Undo the insertion and tell caller.
nx->out.Erase(y);
ny->in.Erase(x);
// Since we do not call Reorder() on this path, clear any visited
// markers left by ForwardDFS.
ClearVisitedBits(r, r->deltaf);
return false;
}
BackwardDFS(r, x, ny->rank);
Reorder(r);
return true;
}
// Follows the edges from producer to consumer and searchs if the node having
// rank `n` can reach the node having rank `upper_bound` using a DFS search.
// When doing DFS search, We only consider the pathes that satisfy the ranks
// of the nodes of the path are all smaller than `upper_bound`.
//
// Returns true if such path exists.
static bool ForwardDFS(GraphCycles::Rep* r, int32_t n, int32_t upper_bound) {
// Avoid recursion since stack space might be limited.
// We instead keep a stack of nodes to visit.
r->deltaf.clear();
r->stack.clear();
r->stack.push_back(n);
while (!r->stack.empty()) {
n = r->stack.back();
r->stack.pop_back();
Node* nn = r->nodes[n];
if (nn->visited) continue;
nn->visited = true;
r->deltaf.push_back(n);
for (auto w : nn->out.GetSequence()) {
Node* nw = r->nodes[w];
if (nw->rank == upper_bound) {
return true;
}
if (!nw->visited && nw->rank < upper_bound) {
r->stack.push_back(w);
}
}
}
return false;
}
// Follows the edges from consumer to producer and visit all the nodes that
// is reachable from node `n` and have rank larger than `lower_bound`.
static void BackwardDFS(GraphCycles::Rep* r, int32_t n, int32_t lower_bound) {
r->deltab.clear();
r->stack.clear();
r->stack.push_back(n);
while (!r->stack.empty()) {
n = r->stack.back();
r->stack.pop_back();
Node* nn = r->nodes[n];
if (nn->visited) continue;
nn->visited = true;
r->deltab.push_back(n);
for (auto w : nn->in.GetSequence()) {
Node* nw = r->nodes[w];
if (!nw->visited && lower_bound < nw->rank) {
r->stack.push_back(w);
}
}
}
}
// Recomputes rank assignments to make them compatible with the edges (producer
// has smaller rank than its consumer)
static void Reorder(GraphCycles::Rep* r) {
Sort(r->nodes, &r->deltab);
Sort(r->nodes, &r->deltaf);
// Adds contents of delta lists to list (backwards deltas first).
r->list.clear();
MoveToList(r, &r->deltab, &r->list);
MoveToList(r, &r->deltaf, &r->list);
// Produce sorted list of all ranks that will be reassigned.
r->merged.resize(r->deltab.size() + r->deltaf.size());
std::merge(r->deltab.begin(), r->deltab.end(), r->deltaf.begin(),
r->deltaf.end(), r->merged.begin());
// Assign the ranks in order to the collected list.
for (Vec<int32_t>::size_type i = 0, e = r->list.size(); i < e; ++i) {
r->nodes[r->list[i]]->rank = r->merged[i];
}
}
// Sorts nodes in the vector according to their ranks. Small rank first.
static void Sort(const Vec<Node*>& nodes, Vec<int32_t>* delta) {
struct ByRank {
const Vec<Node*>* nodes;
bool operator()(int32_t a, int32_t b) const {
return (*nodes)[a]->rank < (*nodes)[b]->rank;
}
};
ByRank cmp;
cmp.nodes = &nodes;
std::sort(delta->begin(), delta->end(), cmp);
}
// Collects ranks of nodes in vector `src` to vector `dst`
static void MoveToList(GraphCycles::Rep* r, Vec<int32_t>* src,
Vec<int32_t>* dst) {
for (Vec<int32_t>::size_type i = 0, e = src->size(); i < e; i++) {
int32_t w = (*src)[i];
// Replace src entry with its rank
(*src)[i] = r->nodes[w]->rank;
// Prepare for future DFS calls
r->nodes[w]->visited = false;
dst->push_back(w);
}
}
// Clears bookkeeping fileds used during the last DFS process.
static void ClearVisitedBits(GraphCycles::Rep* r, const Vec<int32_t>& nodes) {
for (Vec<int32_t>::size_type i = 0, e = nodes.size(); i < e; i++) {
r->nodes[nodes[i]]->visited = false;
}
}
bool GraphCycles::IsReachable(int32_t x, int32_t y) {
if (x == y) return true;
Rep* r = rep_;
Node* nx = r->nodes[x];
Node* ny = r->nodes[y];
if (nx->rank >= ny->rank) {
// x cannot reach y since it is after it in the topological ordering
return false;
}
// See if x can reach y using a DFS search that is limited to y's rank
bool reachable = ForwardDFS(r, x, ny->rank);
// Clear any visited markers left by ForwardDFS.
ClearVisitedBits(r, r->deltaf);
return reachable;
}
llvm::Optional<int32_t> GraphCycles::ContractEdge(int32_t a, int32_t b) {
assert(HasEdge(a, b));
RemoveEdge(a, b);
if (IsReachable(a, b)) {
// Restore the graph to its original state.
InsertEdge(a, b);
return {};
}
if (rep_->nodes[b]->in.Size() + rep_->nodes[b]->out.Size() >
rep_->nodes[a]->in.Size() + rep_->nodes[a]->out.Size()) {
// Swap "a" and "b" to minimize copying.
std::swap(a, b);
}
Node* nb = rep_->nodes[b];
OrderedNodeSet out = std::move(nb->out);
OrderedNodeSet in = std::move(nb->in);
for (int32_t y : out.GetSequence()) {
rep_->nodes[y]->in.Erase(b);
}
for (int32_t y : in.GetSequence()) {
rep_->nodes[y]->out.Erase(b);
}
rep_->free_nodes.push_back(b);
rep_->nodes[a]->out.Reserve(rep_->nodes[a]->out.Size() + out.Size());
for (int32_t y : out.GetSequence()) {
InsertEdge(a, y);
}
rep_->nodes[a]->in.Reserve(rep_->nodes[a]->in.Size() + in.Size());
for (int32_t y : in.GetSequence()) {
InsertEdge(y, a);
}
// Note, if the swap happened it might be what originally was called "b".
return a;
}
std::vector<int32_t> GraphCycles::SuccessorsCopy(int32_t node) const {
return rep_->nodes[node]->out.GetSequence();
}
namespace {
void SortInPostOrder(const Vec<Node*>& nodes, std::vector<int32_t>* to_sort) {
std::sort(to_sort->begin(), to_sort->end(), [&](int32_t a, int32_t b) {
return nodes[a]->rank > nodes[b]->rank;
});
}
} // namespace
std::vector<int32_t> GraphCycles::AllNodesInPostOrder() const {
llvm::DenseSet<int32_t> free_nodes_set;
for (int32_t n : rep_->free_nodes) free_nodes_set.insert(n);
std::vector<int32_t> all_nodes;
all_nodes.reserve(rep_->nodes.size() - free_nodes_set.size());
for (size_t i = 0, e = rep_->nodes.size(); i < e; i++) {
if (!free_nodes_set.count(i)) {
all_nodes.push_back(i);
}
}
SortInPostOrder(rep_->nodes, &all_nodes);
return all_nodes;
}
} // namespace mlir

View File

@ -0,0 +1,89 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/cycle_detector.h"
#include "third_party/tensorflow/compiler/xla/test.h"
class GraphCyclesTest : public ::testing::Test {
public:
GraphCyclesTest() : g_(100) {}
bool AddEdge(int x, int y) { return g_.InsertEdge(x, y); }
void AddMultiples() {
// For every node x > 0: add edge to 2*x, 3*x
for (int x = 1; x < 25; x++) {
EXPECT_TRUE(AddEdge(x, 2 * x)) << x;
EXPECT_TRUE(AddEdge(x, 3 * x)) << x;
}
}
mlir::GraphCycles g_;
};
TEST_F(GraphCyclesTest, NoCycle) { AddMultiples(); }
TEST_F(GraphCyclesTest, SimpleCycle) {
AddMultiples();
EXPECT_FALSE(AddEdge(8, 4));
}
TEST_F(GraphCyclesTest, IndirectCycle) {
AddMultiples();
EXPECT_TRUE(AddEdge(16, 9));
EXPECT_FALSE(AddEdge(9, 2));
}
TEST_F(GraphCyclesTest, RemoveEdge) {
EXPECT_TRUE(AddEdge(1, 2));
EXPECT_TRUE(AddEdge(2, 3));
EXPECT_TRUE(AddEdge(3, 4));
EXPECT_TRUE(AddEdge(4, 5));
g_.RemoveEdge(2, 3);
EXPECT_FALSE(g_.HasEdge(2, 3));
}
TEST_F(GraphCyclesTest, IsReachable) {
EXPECT_TRUE(AddEdge(1, 2));
EXPECT_TRUE(AddEdge(2, 3));
EXPECT_TRUE(AddEdge(3, 4));
EXPECT_TRUE(AddEdge(4, 5));
EXPECT_TRUE(g_.IsReachable(1, 5));
EXPECT_FALSE(g_.IsReachable(5, 1));
}
TEST_F(GraphCyclesTest, ContractEdge) {
ASSERT_TRUE(AddEdge(1, 2));
ASSERT_TRUE(AddEdge(1, 3));
ASSERT_TRUE(AddEdge(2, 3));
ASSERT_TRUE(AddEdge(2, 4));
ASSERT_TRUE(AddEdge(3, 4));
// It will introduce a cycle if the edge is contracted
EXPECT_FALSE(g_.ContractEdge(1, 3).hasValue());
EXPECT_TRUE(g_.HasEdge(1, 3));
// Node (2) has more edges.
EXPECT_EQ(*g_.ContractEdge(1, 2), 2);
EXPECT_TRUE(g_.HasEdge(2, 3));
EXPECT_TRUE(g_.HasEdge(2, 4));
EXPECT_TRUE(g_.HasEdge(3, 4));
// Node (2) has more edges.
EXPECT_EQ(*g_.ContractEdge(2, 3), 2);
EXPECT_TRUE(g_.HasEdge(2, 4));
}