Move XLA-independent transforms to the new MLIR-HLO directory
This is as straighforward as possible, more cleanup/rewrite to come. PiperOrigin-RevId: 319849713
This commit is contained in:
parent
72010faaa7
commit
31dc1b21eb
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_CHLO_OPS_H_
|
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_CHLO_OPS_H_
|
||||||
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_CHLO_OPS_H_
|
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_CHLO_OPS_H_
|
||||||
|
|
||||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
|
||||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h"
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h"
|
||||||
|
@ -42,4 +42,4 @@ class XlaHloClientDialect : public Dialect {
|
||||||
} // namespace xla_chlo
|
} // namespace xla_chlo
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_CHLO_OPS_H_
|
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_CHLO_OPS_H_
|
||||||
|
|
|
@ -15,8 +15,8 @@ limitations under the License.
|
||||||
|
|
||||||
// This file defines the operations used in the XLA dialect.
|
// This file defines the operations used in the XLA dialect.
|
||||||
|
|
||||||
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_OPS_H_
|
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_
|
||||||
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_OPS_H_
|
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_
|
||||||
|
|
||||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
|
||||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||||
|
@ -96,4 +96,4 @@ LogicalResult deriveShapeFromFirstOperand(
|
||||||
} // end namespace xla_hlo
|
} // end namespace xla_hlo
|
||||||
} // end namespace mlir
|
} // end namespace mlir
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_OPS_H_
|
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_
|
||||||
|
|
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_INFER_FUSIBILITY_OP_INTERFACE_H_
|
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_INFER_FUSIBILITY_OP_INTERFACE_H_
|
||||||
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_INFER_FUSIBILITY_OP_INTERFACE_H_
|
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_INFER_FUSIBILITY_OP_INTERFACE_H_
|
||||||
|
|
||||||
#include "mlir/IR/OpDefinition.h"
|
#include "mlir/IR/OpDefinition.h"
|
||||||
#include "mlir/IR/StandardTypes.h"
|
#include "mlir/IR/StandardTypes.h"
|
||||||
|
@ -25,4 +25,4 @@ namespace mlir {
|
||||||
|
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_INFER_FUSIBILITY_OP_INTERFACE_H_
|
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_INFER_FUSIBILITY_OP_INTERFACE_H_
|
||||||
|
|
|
@ -15,8 +15,8 @@ limitations under the License.
|
||||||
|
|
||||||
// This file defines the operations used in the LXLA dialect.
|
// This file defines the operations used in the LXLA dialect.
|
||||||
|
|
||||||
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_LHLO_OPS_H_
|
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_
|
||||||
#define TENSORFLOW_COMPILER_MLIR_XLA_IR_LHLO_OPS_H_
|
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_
|
||||||
|
|
||||||
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
|
||||||
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||||
|
@ -49,4 +49,4 @@ class XlaLhloDialect : public Dialect {
|
||||||
} // namespace xla_lhlo
|
} // namespace xla_lhlo
|
||||||
} // end namespace mlir
|
} // end namespace mlir
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_LHLO_OPS_H_
|
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_
|
||||||
|
|
|
@ -0,0 +1,80 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_
|
||||||
|
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_
|
||||||
|
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace xla_hlo {
|
||||||
|
|
||||||
|
template <typename HloOpTy>
|
||||||
|
struct HloToLhloOpImpl {
|
||||||
|
using Type = std::false_type;
|
||||||
|
};
|
||||||
|
template <typename HloOpTy>
|
||||||
|
using HloToLhloOp = typename HloToLhloOpImpl<HloOpTy>::Type;
|
||||||
|
|
||||||
|
#define MAP_HLO_TO_LHLO(OpName) \
|
||||||
|
template <> \
|
||||||
|
struct HloToLhloOpImpl<xla_hlo::OpName> { \
|
||||||
|
using Type = xla_lhlo::OpName; \
|
||||||
|
}
|
||||||
|
|
||||||
|
MAP_HLO_TO_LHLO(AbsOp);
|
||||||
|
MAP_HLO_TO_LHLO(AddOp);
|
||||||
|
MAP_HLO_TO_LHLO(AndOp);
|
||||||
|
MAP_HLO_TO_LHLO(BroadcastInDimOp);
|
||||||
|
MAP_HLO_TO_LHLO(CeilOp);
|
||||||
|
MAP_HLO_TO_LHLO(ConstOp);
|
||||||
|
MAP_HLO_TO_LHLO(CompareOp);
|
||||||
|
MAP_HLO_TO_LHLO(ComplexOp);
|
||||||
|
MAP_HLO_TO_LHLO(ConvOp);
|
||||||
|
MAP_HLO_TO_LHLO(ConvertOp);
|
||||||
|
MAP_HLO_TO_LHLO(CopyOp);
|
||||||
|
MAP_HLO_TO_LHLO(CosOp);
|
||||||
|
MAP_HLO_TO_LHLO(DivOp);
|
||||||
|
MAP_HLO_TO_LHLO(DotOp);
|
||||||
|
MAP_HLO_TO_LHLO(ExpOp);
|
||||||
|
MAP_HLO_TO_LHLO(GatherOp);
|
||||||
|
MAP_HLO_TO_LHLO(ImagOp);
|
||||||
|
MAP_HLO_TO_LHLO(IotaOp);
|
||||||
|
MAP_HLO_TO_LHLO(LogOp);
|
||||||
|
MAP_HLO_TO_LHLO(MaxOp);
|
||||||
|
MAP_HLO_TO_LHLO(MinOp);
|
||||||
|
MAP_HLO_TO_LHLO(MulOp);
|
||||||
|
MAP_HLO_TO_LHLO(NegOp);
|
||||||
|
MAP_HLO_TO_LHLO(RealOp);
|
||||||
|
MAP_HLO_TO_LHLO(ReduceOp);
|
||||||
|
MAP_HLO_TO_LHLO(ReshapeOp);
|
||||||
|
MAP_HLO_TO_LHLO(RemOp);
|
||||||
|
MAP_HLO_TO_LHLO(RsqrtOp);
|
||||||
|
MAP_HLO_TO_LHLO(SelectOp);
|
||||||
|
MAP_HLO_TO_LHLO(SignOp);
|
||||||
|
MAP_HLO_TO_LHLO(SinOp);
|
||||||
|
MAP_HLO_TO_LHLO(SqrtOp);
|
||||||
|
MAP_HLO_TO_LHLO(SubOp);
|
||||||
|
MAP_HLO_TO_LHLO(TanhOp);
|
||||||
|
|
||||||
|
#undef MAP_HLO_TO_LHLO
|
||||||
|
|
||||||
|
} // namespace xla_hlo
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_
|
|
@ -0,0 +1,510 @@
|
||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_XLA_TO_SCALAR_OP_H_
|
||||||
|
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_XLA_TO_SCALAR_OP_H_
|
||||||
|
|
||||||
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
|
||||||
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace xla_lhlo {
|
||||||
|
namespace impl {
|
||||||
|
|
||||||
|
// A struct to map LhloBinaryOpTy type to the corresponding floating-point and
|
||||||
|
// integer scalar operation types.
|
||||||
|
template <typename LhloBinaryOpTy>
|
||||||
|
struct LhloToScalarOp;
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct LhloToScalarOp<xla_lhlo::AddOp> {
|
||||||
|
using FOp = ::mlir::AddFOp;
|
||||||
|
using IOp = ::mlir::AddIOp;
|
||||||
|
};
|
||||||
|
template <>
|
||||||
|
struct LhloToScalarOp<xla_lhlo::CompareOp> {
|
||||||
|
using FOp = ::mlir::CmpFOp;
|
||||||
|
using IOp = ::mlir::CmpIOp;
|
||||||
|
};
|
||||||
|
template <>
|
||||||
|
struct LhloToScalarOp<xla_lhlo::DivOp> {
|
||||||
|
using FOp = ::mlir::DivFOp;
|
||||||
|
using IOp = ::mlir::SignedDivIOp;
|
||||||
|
};
|
||||||
|
template <>
|
||||||
|
struct LhloToScalarOp<xla_lhlo::MulOp> {
|
||||||
|
using FOp = ::mlir::MulFOp;
|
||||||
|
using IOp = ::mlir::MulIOp;
|
||||||
|
};
|
||||||
|
template <>
|
||||||
|
struct LhloToScalarOp<xla_lhlo::RemOp> {
|
||||||
|
using FOp = ::mlir::RemFOp;
|
||||||
|
using IOp = ::mlir::SignedRemIOp;
|
||||||
|
};
|
||||||
|
template <>
|
||||||
|
struct LhloToScalarOp<xla_lhlo::SubOp> {
|
||||||
|
using FOp = ::mlir::SubFOp;
|
||||||
|
using IOp = ::mlir::SubIOp;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename LhloBinaryOpTy>
|
||||||
|
struct ScalarOp {
|
||||||
|
using FOp = typename LhloToScalarOp<LhloBinaryOpTy>::FOp;
|
||||||
|
using IOp = typename LhloToScalarOp<LhloBinaryOpTy>::IOp;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Alias for the map from LHLO binary op type to STD floating-point op type.
|
||||||
|
template <typename LhloOp>
|
||||||
|
using ScalarFOp = typename ScalarOp<LhloOp>::FOp;
|
||||||
|
// Alias for the map from LHLO binary op type to STD integer op type.
|
||||||
|
template <typename LhloOp>
|
||||||
|
using ScalarIOp = typename ScalarOp<LhloOp>::IOp;
|
||||||
|
|
||||||
|
template <typename... Args>
|
||||||
|
struct MapLhloOpToStdScalarOpImpl {
|
||||||
|
Value operator()(Location loc, ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Value> args, OpBuilder* b) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename StdScalarOp>
|
||||||
|
struct MapLhloOpToStdScalarOpImpl<StdScalarOp> {
|
||||||
|
Value operator()(Location loc, ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Value> args, OpBuilder* b) {
|
||||||
|
return b->template create<StdScalarOp>(loc, result_types, args, mlir::None);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename SupportedType, typename StdScalarOp, typename... Args>
|
||||||
|
struct MapLhloOpToStdScalarOpImpl<SupportedType, StdScalarOp, Args...> {
|
||||||
|
Value operator()(Location loc, ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Value> args, OpBuilder* b) {
|
||||||
|
Type element_type = args.front().getType();
|
||||||
|
if (element_type.isa<SupportedType>()) {
|
||||||
|
return b->template create<StdScalarOp>(loc, result_types, args,
|
||||||
|
mlir::None);
|
||||||
|
}
|
||||||
|
return MapLhloOpToStdScalarOpImpl<Args...>{}(loc, result_types, args, b);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Inserts the computation that corresponds to the body of the loop for lowered
|
||||||
|
// LHLO unary/binary op. Returns the value for the result.
|
||||||
|
template <typename LhloOpTy>
|
||||||
|
inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Value> args, OpBuilder* b) {
|
||||||
|
return MapLhloOpToStdScalarOpImpl<IntegerType, ScalarIOp<LhloOpTy>, FloatType,
|
||||||
|
ScalarFOp<LhloOpTy>>{}(loc, result_types,
|
||||||
|
args, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::AbsOp>(
|
||||||
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
Type element_type = args.front().getType();
|
||||||
|
if (element_type.isa<FloatType>()) {
|
||||||
|
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::AbsFOp>{}(
|
||||||
|
loc, result_types, args, b);
|
||||||
|
}
|
||||||
|
if (element_type.isa<IntegerType>()) {
|
||||||
|
// xla_lhlo.abs(x, result) -> result = select((x > 0), x, sub(0, x))
|
||||||
|
Value lhs = args[0];
|
||||||
|
auto integer_type = element_type.dyn_cast<IntegerType>();
|
||||||
|
|
||||||
|
auto zero_intval =
|
||||||
|
b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
|
||||||
|
auto lhs_gt_zero = b->create<ScalarIOp<CompareOp>>(loc, CmpIPredicate::sge,
|
||||||
|
lhs, zero_intval);
|
||||||
|
auto neg_val = b->create<ScalarIOp<xla_lhlo::SubOp>>(loc, zero_intval, lhs);
|
||||||
|
return b->create<::mlir::SelectOp>(loc, lhs_gt_zero, lhs, neg_val);
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::AndOp>(
|
||||||
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::AndOp>{}(
|
||||||
|
loc, result_types, args, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename PredicateType>
|
||||||
|
inline Optional<PredicateType> getCmpPredicate(
|
||||||
|
StringRef xla_comparison_direction) {
|
||||||
|
return llvm::None;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Optional<CmpFPredicate> getCmpPredicate<CmpFPredicate>(
|
||||||
|
StringRef xla_comparison_direction) {
|
||||||
|
return llvm::StringSwitch<Optional<CmpFPredicate>>(xla_comparison_direction)
|
||||||
|
.Case("EQ", CmpFPredicate::OEQ)
|
||||||
|
.Case("NE", CmpFPredicate::ONE)
|
||||||
|
.Case("GE", CmpFPredicate::OGE)
|
||||||
|
.Case("GT", CmpFPredicate::OGT)
|
||||||
|
.Case("LE", CmpFPredicate::OLE)
|
||||||
|
.Case("LT", CmpFPredicate::OLT)
|
||||||
|
.Default(llvm::None);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Optional<CmpIPredicate> getCmpPredicate<CmpIPredicate>(
|
||||||
|
StringRef xla_comparison_direction) {
|
||||||
|
return llvm::StringSwitch<Optional<CmpIPredicate>>(xla_comparison_direction)
|
||||||
|
.Case("EQ", CmpIPredicate::eq)
|
||||||
|
.Case("NE", CmpIPredicate::ne)
|
||||||
|
.Case("GE", CmpIPredicate::sge)
|
||||||
|
.Case("GT", CmpIPredicate::sgt)
|
||||||
|
.Case("LE", CmpIPredicate::sle)
|
||||||
|
.Case("LT", CmpIPredicate::slt)
|
||||||
|
.Default(llvm::None);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename XLACompareOpTy>
|
||||||
|
inline Value MapXlaCompareOpToStdScalarOp(Location loc,
|
||||||
|
StringRef comparison_direction,
|
||||||
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Value> args, OpBuilder* b) {
|
||||||
|
const auto& lhs = args[0];
|
||||||
|
const auto& rhs = args[1];
|
||||||
|
Type element_type = lhs.getType();
|
||||||
|
if (element_type.isSignlessInteger()) {
|
||||||
|
Optional<CmpIPredicate> predicate =
|
||||||
|
getCmpPredicate<CmpIPredicate>(comparison_direction);
|
||||||
|
assert(predicate.hasValue() && "expected valid comparison direction");
|
||||||
|
return b->create<ScalarIOp<XLACompareOpTy>>(loc, predicate.getValue(), lhs,
|
||||||
|
rhs);
|
||||||
|
}
|
||||||
|
if (element_type.isa<FloatType>()) {
|
||||||
|
Optional<CmpFPredicate> predicate =
|
||||||
|
getCmpPredicate<CmpFPredicate>(comparison_direction);
|
||||||
|
assert(predicate.hasValue() && "expected valid comparison direction");
|
||||||
|
return b->create<ScalarFOp<XLACompareOpTy>>(loc, predicate.getValue(), lhs,
|
||||||
|
rhs);
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CopyOp>(
|
||||||
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
return args.front();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ExpOp>(
|
||||||
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::ExpOp>{}(
|
||||||
|
loc, result_types, args, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CeilOp>(
|
||||||
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CeilFOp>{}(
|
||||||
|
loc, result_types, args, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ComplexOp>(
|
||||||
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
return MapLhloOpToStdScalarOpImpl<CreateComplexOp>{}(loc, result_types, args,
|
||||||
|
b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::RealOp>(
|
||||||
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
return MapLhloOpToStdScalarOpImpl<ReOp>{}(loc, result_types, args, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ImagOp>(
|
||||||
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
return MapLhloOpToStdScalarOpImpl<ImOp>{}(loc, result_types, args, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ConvertOp>(
|
||||||
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
Type sourceType = args.front().getType();
|
||||||
|
Type targetType = result_types.front();
|
||||||
|
|
||||||
|
if (mlir::SIToFPOp::areCastCompatible(sourceType, targetType)) {
|
||||||
|
return b->create<mlir::SIToFPOp>(loc, result_types, args, mlir::None);
|
||||||
|
} else if (sourceType.isa<FloatType>() && targetType.isa<FloatType>()) {
|
||||||
|
FloatType src = sourceType.cast<FloatType>();
|
||||||
|
FloatType res = targetType.cast<FloatType>();
|
||||||
|
if (src.getWidth() > res.getWidth()) {
|
||||||
|
return b->create<mlir::FPTruncOp>(loc, result_types, args, mlir::None);
|
||||||
|
} else if (src.getWidth() < res.getWidth()) {
|
||||||
|
return b->create<mlir::FPExtOp>(loc, result_types, args, mlir::None);
|
||||||
|
}
|
||||||
|
// No conversion is needed for the same width floats
|
||||||
|
return args.front();
|
||||||
|
}
|
||||||
|
if (sourceType.isSignlessInteger() && targetType.isSignlessInteger()) {
|
||||||
|
IntegerType src = sourceType.cast<IntegerType>();
|
||||||
|
IntegerType res = targetType.cast<IntegerType>();
|
||||||
|
if (src.getWidth() > res.getWidth()) {
|
||||||
|
return b->create<mlir::TruncateIOp>(loc, result_types, args, mlir::None);
|
||||||
|
} else if (src.getWidth() < res.getWidth()) {
|
||||||
|
return b->create<mlir::ZeroExtendIOp>(loc, result_types, args,
|
||||||
|
mlir::None);
|
||||||
|
}
|
||||||
|
// No conversion is needed for the same width integers
|
||||||
|
return args.front();
|
||||||
|
}
|
||||||
|
if (mlir::FPToSIOp::areCastCompatible(sourceType, targetType)) {
|
||||||
|
return b->create<mlir::FPToSIOp>(loc, result_types, args, mlir::None);
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::DotOp>(
|
||||||
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
// Dot Op converter from lhlo to affine only accepts float and integer types.
|
||||||
|
const auto& lhs = args[0];
|
||||||
|
const auto& rhs = args[1];
|
||||||
|
const auto& result = args[2];
|
||||||
|
Type element_type = lhs.getType();
|
||||||
|
if (element_type.isa<FloatType>()) {
|
||||||
|
Value float_mul = MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::MulFOp>{}(
|
||||||
|
loc, result_types, {lhs, rhs}, b);
|
||||||
|
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::AddFOp>{}(
|
||||||
|
loc, result_types, {float_mul, result}, b);
|
||||||
|
}
|
||||||
|
if (element_type.isa<IntegerType>()) {
|
||||||
|
Value int_mul = MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::MulIOp>{}(
|
||||||
|
loc, result_types, {lhs, rhs}, b);
|
||||||
|
return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::AddIOp>{}(
|
||||||
|
loc, result_types, {int_mul, result}, b);
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CosOp>(
|
||||||
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CosOp>{}(
|
||||||
|
loc, result_types, args, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SinOp>(
|
||||||
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::SinOp>{}(
|
||||||
|
loc, result_types, args, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Implements the conversion of XLA op to scalar op (to use within region of a
|
||||||
|
/// linalg.generic op) for compare-select style operations like min/max.
|
||||||
|
template <typename... Args>
|
||||||
|
struct XlaCompareSelectOpToStdScalarOp {
|
||||||
|
static Value map(Location loc, StringRef comparison_direction,
|
||||||
|
ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Specialization which allows converting to a comparison operation in standard
|
||||||
|
/// dialect with a given predicate based on the element type of the operand.
|
||||||
|
template <typename SupportedType, typename StdCompareOp, typename Predicate,
|
||||||
|
typename... Args>
|
||||||
|
struct XlaCompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate,
|
||||||
|
Args...> {
|
||||||
|
static Value map(Location loc, StringRef comparison_direction,
|
||||||
|
ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
Type element_type = args.front().getType();
|
||||||
|
if (element_type.isa<SupportedType>()) {
|
||||||
|
auto predicate = getCmpPredicate<Predicate>(comparison_direction);
|
||||||
|
assert(predicate.hasValue() && "expected valid comparison direction");
|
||||||
|
auto cmp = b->template create<StdCompareOp>(loc, predicate.getValue(),
|
||||||
|
args[0], args[1]);
|
||||||
|
return b->create<::mlir::SelectOp>(loc, cmp, args[0], args[1]);
|
||||||
|
}
|
||||||
|
return XlaCompareSelectOpToStdScalarOp<Args...>::map(
|
||||||
|
loc, comparison_direction, result_types, args, b);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::LogOp>(
|
||||||
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::LogOp>{}(
|
||||||
|
loc, result_types, args, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::MaxOp>(
|
||||||
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
return XlaCompareSelectOpToStdScalarOp<
|
||||||
|
IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType,
|
||||||
|
ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>::map(loc, "GT",
|
||||||
|
result_types, args,
|
||||||
|
b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::MinOp>(
|
||||||
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
return XlaCompareSelectOpToStdScalarOp<
|
||||||
|
IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType,
|
||||||
|
ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>::map(loc, "LT",
|
||||||
|
result_types, args,
|
||||||
|
b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::NegOp>(
|
||||||
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
Type element_type = args.front().getType();
|
||||||
|
if (element_type.isa<FloatType>()) {
|
||||||
|
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::NegFOp>{}(
|
||||||
|
loc, result_types, args, b);
|
||||||
|
}
|
||||||
|
if (element_type.isa<IntegerType>()) {
|
||||||
|
// xla_lhlo.neg(x, result) -> result = sub(0, x)
|
||||||
|
Value lhs = args[0];
|
||||||
|
auto integer_type = element_type.dyn_cast<IntegerType>();
|
||||||
|
|
||||||
|
auto zero_intval =
|
||||||
|
b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
|
||||||
|
return b->create<ScalarIOp<xla_lhlo::SubOp>>(loc, zero_intval, lhs);
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::RsqrtOp>(
|
||||||
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::RsqrtOp>{}(
|
||||||
|
loc, result_types, args, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SelectOp>(
|
||||||
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
return MapLhloOpToStdScalarOpImpl<::mlir::SelectOp>{}(loc, result_types, args,
|
||||||
|
b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SignOp>(
|
||||||
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
Type element_type = args.front().getType();
|
||||||
|
if (element_type.isa<FloatType>()) {
|
||||||
|
FloatType float_type = element_type.cast<FloatType>();
|
||||||
|
APFloat const_value = float_type.isF32() ? APFloat(1.0f) : APFloat(1.0);
|
||||||
|
Value one = b->create<mlir::ConstantFloatOp>(loc, const_value, float_type);
|
||||||
|
return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]);
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SqrtOp>(
|
||||||
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::SqrtOp>{}(
|
||||||
|
loc, result_types, args, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<xla_lhlo::TanhOp>(
|
||||||
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::TanhOp>{}(
|
||||||
|
loc, result_types, args, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace impl
|
||||||
|
|
||||||
|
struct XlaOpToStdScalarOp {
|
||||||
|
// Implementation for LHLO ops except xla_lhlo::CompareOp.
|
||||||
|
template <typename XlaOpTy, typename LhloOpTy = XlaOpTy,
|
||||||
|
typename = std::enable_if_t<
|
||||||
|
!std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
|
||||||
|
std::is_same<typename xla_hlo::HloToLhloOp<LhloOpTy>,
|
||||||
|
std::false_type>::value>>
|
||||||
|
static Value map(XlaOpTy op, ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) {
|
||||||
|
return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types,
|
||||||
|
args, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implementation for HLO ops except xla_hlo::CompareOp.
|
||||||
|
template <typename XlaOpTy, typename LhloOpTy = xla_hlo::HloToLhloOp<XlaOpTy>,
|
||||||
|
typename = std::enable_if_t<
|
||||||
|
!std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
|
||||||
|
!std::is_same<LhloOpTy, std::false_type>::value>>
|
||||||
|
static Value map(XlaOpTy op, ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Value> args, OpBuilder* b, int i = 0) {
|
||||||
|
return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types,
|
||||||
|
args, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implementation for xla_lhlo::CompareOp.
|
||||||
|
template <typename LhloOpTy, typename = std::enable_if_t<std::is_same<
|
||||||
|
LhloOpTy, xla_lhlo::CompareOp>::value>>
|
||||||
|
static Value map(xla_lhlo::CompareOp op, ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Value> args, OpBuilder* b) {
|
||||||
|
auto comparison_direction = op.comparison_direction();
|
||||||
|
return impl::MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(
|
||||||
|
op.getLoc(), comparison_direction, result_types, args, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implementation for xla_hlo::CompareOp.
|
||||||
|
template <typename HloOpTy, typename = std::enable_if_t<std::is_same<
|
||||||
|
HloOpTy, xla_hlo::CompareOp>::value>>
|
||||||
|
static Value map(xla_hlo::CompareOp op, ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Value> args, OpBuilder* b) {
|
||||||
|
auto comparison_direction = op.comparison_direction();
|
||||||
|
return impl::MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(
|
||||||
|
op.getLoc(), comparison_direction, result_types, args, b);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace xla_lhlo
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_XLA_TO_SCALAR_OP_H_
|
|
@ -0,0 +1,105 @@
|
||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSES_H_
|
||||||
|
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSES_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
|
||||||
|
class FuncOp;
|
||||||
|
class ModuleOp;
|
||||||
|
class Operation;
|
||||||
|
template <typename T>
|
||||||
|
class OperationPass;
|
||||||
|
class Pass;
|
||||||
|
|
||||||
|
namespace xla_hlo {
|
||||||
|
|
||||||
|
/// Lowers HLO control flow ops to the Standard dialect.
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> createLegalizeControlFlowPass();
|
||||||
|
|
||||||
|
/// Lowers from HLO dialect to Standard dialect.
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass();
|
||||||
|
|
||||||
|
/// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
|
||||||
|
/// buffers if necessary. If `results_escape_functions` is set to true,
|
||||||
|
/// allocated buffers for function results will be returned and escape the
|
||||||
|
/// function. Otherwise, the signature is rewritten with extra arguments for the
|
||||||
|
/// buffers that are to be used for results.
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
|
||||||
|
bool results_escape_functions = false);
|
||||||
|
|
||||||
|
// Lowers from HLO dialect to Linalg dialect.
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass();
|
||||||
|
|
||||||
|
// Transforms unranked HLO operations to ranked ones where possible.
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> createTransformUnrankedHloPass();
|
||||||
|
|
||||||
|
// Sinks constants implicitly captured in control flow regions. This is
|
||||||
|
// necessary to export to XLA.
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass();
|
||||||
|
|
||||||
|
// fuse xla_hlo ops to kLoop/kInput fusion patterns
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusionPass();
|
||||||
|
|
||||||
|
} // namespace xla_hlo
|
||||||
|
|
||||||
|
namespace xla_lhlo {
|
||||||
|
|
||||||
|
// Lowers from LHLO dialect to Affine dialect.
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToAffinePass();
|
||||||
|
|
||||||
|
// Lowers from LHLO dialect to Linalg dialect.
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToLinalgPass();
|
||||||
|
|
||||||
|
// Lowers from LHLO dialect to GPU dialect.
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToGpuPass();
|
||||||
|
|
||||||
|
// Fuses linalg ops obtained after LHLO lowering. To enable fusion,
|
||||||
|
// operations are first tiled.
|
||||||
|
//
|
||||||
|
// When 'use_parallel_loops' is set, the tiling will use scf.parallel
|
||||||
|
// operations. Otherwise, scf.for operations are used.
|
||||||
|
//
|
||||||
|
// 'tile_sizes' provides the tile sizes to use for tiling. If the linalg
|
||||||
|
// operation has more dimensions than tile sizes provided, 1 is used as
|
||||||
|
// default.
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> createLhloFuseLinalg(
|
||||||
|
bool use_parallel_loops = false, llvm::ArrayRef<unsigned> tile_sizes = {});
|
||||||
|
|
||||||
|
// Removes unnecessary LHLO copies which copy from the allocated buffers to the
|
||||||
|
// block arguments. The block arguments are used instead of all uses of these
|
||||||
|
// buffers. The buffers are freed. This pass only works in regions that contain
|
||||||
|
// a single block.
|
||||||
|
std::unique_ptr<Pass> createLhloCopyRemovalPass();
|
||||||
|
|
||||||
|
// Lowers from LHLO dialect to parallel loops.
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass();
|
||||||
|
|
||||||
|
} // namespace xla_lhlo
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
|
||||||
|
/// Lowers the standard TanhOp to an approximation that does not use intrinsics.
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTanhToApproximationPass();
|
||||||
|
|
||||||
|
} // namespace xla
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSES_H_
|
|
@ -0,0 +1,106 @@
|
||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H_
|
||||||
|
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
class LLVMTypeConverter;
|
||||||
|
class LowerToLLVMOptions;
|
||||||
|
class OwningRewritePatternList;
|
||||||
|
class BufferAssignmentPlacer;
|
||||||
|
namespace xla_hlo {
|
||||||
|
|
||||||
|
// Collection of rewrite patterns for lowering a general dot product.
|
||||||
|
void PopulateGeneralDotOpLoweringPatterns(OwningRewritePatternList *patterns,
|
||||||
|
MLIRContext *ctx);
|
||||||
|
|
||||||
|
// Collection of rewrite patterns for lowering complex operations to equivalent
|
||||||
|
// float operations.
|
||||||
|
void PopulateComplexLoweringPatterns(MLIRContext *context,
|
||||||
|
OwningRewritePatternList *patterns);
|
||||||
|
|
||||||
|
void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns,
|
||||||
|
MLIRContext *ctx);
|
||||||
|
|
||||||
|
// Collection of rewrite patterns for lowering of HLO to LHLO dialect.
|
||||||
|
void populateHLOToLHLOConversionPattern(
|
||||||
|
MLIRContext *context, BufferAssignmentPlacer *bufferAssignment,
|
||||||
|
TypeConverter *converter, OwningRewritePatternList *patterns);
|
||||||
|
// Collection of rewrite patterns for lowering of HLO to Linalg dialect.
|
||||||
|
void populateHLOToLinalgConversionPattern(MLIRContext *context,
|
||||||
|
OwningRewritePatternList *patterns);
|
||||||
|
|
||||||
|
// Sets up legality definitions for materializing broadcasts.
|
||||||
|
void SetupMaterializeBroadcastsLegality(MLIRContext *context,
|
||||||
|
ConversionTarget *conversionTarget);
|
||||||
|
|
||||||
|
// Populates a collection of rewrite patterns for materializing broadcast
|
||||||
|
// attributes to equivalent sequences of ops.
|
||||||
|
void PopulateMaterializeBroadcastsPatterns(MLIRContext *context,
|
||||||
|
OwningRewritePatternList *patterns);
|
||||||
|
|
||||||
|
// Sets up legality definitions for element-wise operations on ranked tensors.
|
||||||
|
void SetupTransformUnrankedHloLegality(MLIRContext *context,
|
||||||
|
ConversionTarget *conversionTarget);
|
||||||
|
|
||||||
|
// Populates a collection of rewrite patterns to realize element-wise operations
|
||||||
|
// on ranked tensors where possible.
|
||||||
|
void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
|
||||||
|
OwningRewritePatternList *patterns);
|
||||||
|
|
||||||
|
// Populate a collection of conversion patterns for un-fusing
|
||||||
|
// batch_norm_inference and batch_norm_training into constituent HLO ops.
|
||||||
|
// TODO(laurenzo): Implement un-fusing of batch_norm_training.
|
||||||
|
void PopulateUnfuseBatchNormPatterns(MLIRContext *context,
|
||||||
|
OwningRewritePatternList *patterns);
|
||||||
|
|
||||||
|
} // namespace xla_hlo
|
||||||
|
|
||||||
|
namespace xla_lhlo {
|
||||||
|
|
||||||
|
/// Collect a set of patterns to convert from the LHLO dialect to LLVM.
|
||||||
|
void PopulateLhloToLLVMConversionPatterns(const LowerToLLVMOptions &options,
|
||||||
|
LLVMTypeConverter *converter,
|
||||||
|
OwningRewritePatternList *patterns);
|
||||||
|
|
||||||
|
} // namespace xla_lhlo
|
||||||
|
|
||||||
|
namespace xla_chlo {
|
||||||
|
|
||||||
|
// Populates a collection of conversion patterns for legalizing client-HLO to
|
||||||
|
// HLO.
|
||||||
|
void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
|
||||||
|
OwningRewritePatternList *patterns);
|
||||||
|
|
||||||
|
} // namespace xla_chlo
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
|
||||||
|
// Populates a pattern that translates the standard TanhOp to an approximation
|
||||||
|
// that does not use intrinsics.
|
||||||
|
void PopulateTanhToApproximationPatterns(MLIRContext *context,
|
||||||
|
OwningRewritePatternList *patterns);
|
||||||
|
|
||||||
|
} // namespace xla
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H_
|
|
@ -0,0 +1,165 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_CYCLE_DETECTOR_H_
|
||||||
|
#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_CYCLE_DETECTOR_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/DenseMap.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
|
||||||
|
// -------------------------------------------------------------------
|
||||||
|
|
||||||
|
// This file contains a light version of GraphCycles implemented in
|
||||||
|
// tensorflow/compiler/jit/graphcycles/graphcycles.h
|
||||||
|
//
|
||||||
|
// We re-implement it here because we do not want to rely
|
||||||
|
// on TensorFlow data structures, and hence we can move
|
||||||
|
// corresponding passes to llvm repo. easily in case necessnary.
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------
|
||||||
|
|
||||||
|
// This is a set data structure that provides a deterministic iteration order.
|
||||||
|
// The iteration order of elements only depends on the sequence of
|
||||||
|
// inserts/deletes, so as long as the inserts/deletes happen in the same
|
||||||
|
// sequence, the set will have the same iteration order.
|
||||||
|
//
|
||||||
|
// Assumes that T can be cheaply copied for simplicity.
|
||||||
|
template <typename T>
|
||||||
|
class OrderedSet {
|
||||||
|
public:
|
||||||
|
// Inserts `value` into the ordered set. Returns true if the value was not
|
||||||
|
// present in the set before the insertion.
|
||||||
|
bool Insert(T value) {
|
||||||
|
bool new_insertion =
|
||||||
|
value_to_index_.insert({value, value_sequence_.size()}).second;
|
||||||
|
if (new_insertion) {
|
||||||
|
value_sequence_.push_back(value);
|
||||||
|
}
|
||||||
|
return new_insertion;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Removes `value` from the set. Assumes `value` is already present in the
|
||||||
|
// set.
|
||||||
|
void Erase(T value) {
|
||||||
|
auto it = value_to_index_.find(value);
|
||||||
|
|
||||||
|
// Since we don't want to move values around in `value_sequence_` we swap
|
||||||
|
// the value in the last position and with value to be deleted and then
|
||||||
|
// pop_back.
|
||||||
|
value_to_index_[value_sequence_.back()] = it->second;
|
||||||
|
std::swap(value_sequence_[it->second], value_sequence_.back());
|
||||||
|
value_sequence_.pop_back();
|
||||||
|
value_to_index_.erase(it);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Reserve(size_t new_size) {
|
||||||
|
value_to_index_.reserve(new_size);
|
||||||
|
value_sequence_.reserve(new_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Clear() {
|
||||||
|
value_to_index_.clear();
|
||||||
|
value_sequence_.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Contains(T value) const { return value_to_index_.count(value); }
|
||||||
|
size_t Size() const { return value_sequence_.size(); }
|
||||||
|
|
||||||
|
const std::vector<T>& GetSequence() const { return value_sequence_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
// The stable order that we maintain through insertions and deletions.
|
||||||
|
std::vector<T> value_sequence_;
|
||||||
|
|
||||||
|
// Maps values to their indices in `value_sequence_`.
|
||||||
|
llvm::DenseMap<T, int> value_to_index_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------
|
||||||
|
|
||||||
|
// GraphCycles detects the introduction of a cycle into a directed
|
||||||
|
// graph that is being built up incrementally.
|
||||||
|
//
|
||||||
|
// Nodes are identified by small integers. It is not possible to
|
||||||
|
// record multiple edges with the same (source, destination) pair;
|
||||||
|
// requests to add an edge where one already exists are silently
|
||||||
|
// ignored.
|
||||||
|
//
|
||||||
|
// It is also not possible to introduce a cycle; an attempt to insert
|
||||||
|
// an edge that would introduce a cycle fails and returns false.
|
||||||
|
//
|
||||||
|
// GraphCycles uses no internal locking; calls into it should be
|
||||||
|
// serialized externally.
|
||||||
|
|
||||||
|
// Performance considerations:
|
||||||
|
// Works well on sparse graphs, poorly on dense graphs.
|
||||||
|
// Extra information is maintained incrementally to detect cycles quickly.
|
||||||
|
// InsertEdge() is very fast when the edge already exists, and reasonably fast
|
||||||
|
// otherwise.
|
||||||
|
// FindPath() is linear in the size of the graph.
|
||||||
|
// The current implementation uses O(|V|+|E|) space.
|
||||||
|
|
||||||
|
class GraphCycles {
|
||||||
|
public:
|
||||||
|
explicit GraphCycles(int32_t num_nodes);
|
||||||
|
~GraphCycles();
|
||||||
|
|
||||||
|
// Attempt to insert an edge from x to y. If the
|
||||||
|
// edge would introduce a cycle, return false without making any
|
||||||
|
// changes. Otherwise add the edge and return true.
|
||||||
|
bool InsertEdge(int32_t x, int32_t y);
|
||||||
|
|
||||||
|
// Remove any edge that exists from x to y.
|
||||||
|
void RemoveEdge(int32_t x, int32_t y);
|
||||||
|
|
||||||
|
// Return whether there is an edge directly from x to y.
|
||||||
|
bool HasEdge(int32_t x, int32_t y) const;
|
||||||
|
|
||||||
|
// Contracts the edge from 'a' to node 'b', merging nodes 'a' and 'b'. One of
|
||||||
|
// the nodes is removed from the graph, and edges to/from it are added to
|
||||||
|
// the remaining one, which is returned. If contracting the edge would create
|
||||||
|
// a cycle, does nothing and return no value.
|
||||||
|
llvm::Optional<int32_t> ContractEdge(int32_t a, int32_t b);
|
||||||
|
|
||||||
|
// Return whether dest_node `y` is reachable from source_node `x`
|
||||||
|
// by following edges. This is non-thread-safe version.
|
||||||
|
bool IsReachable(int32_t x, int32_t y);
|
||||||
|
|
||||||
|
// Return a copy of the successors set. This is needed for code using the
|
||||||
|
// collection while modifying the GraphCycles.
|
||||||
|
std::vector<int32_t> SuccessorsCopy(int32_t node) const;
|
||||||
|
|
||||||
|
// Returns all nodes in post order.
|
||||||
|
//
|
||||||
|
// If there is a path from X to Y then X appears after Y in the
|
||||||
|
// returned vector.
|
||||||
|
std::vector<int32_t> AllNodesInPostOrder() const;
|
||||||
|
|
||||||
|
// ----------------------------------------------------
|
||||||
|
struct Rep;
|
||||||
|
|
||||||
|
private:
|
||||||
|
GraphCycles(const GraphCycles&) = delete;
|
||||||
|
GraphCycles& operator=(const GraphCycles&) = delete;
|
||||||
|
|
||||||
|
Rep* rep_; // opaque representation
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_CYCLE_DETECTOR_H_
|
|
@ -0,0 +1,242 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OperationSupport.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace xla_chlo {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Converts binary ops that statically are determined to not broadcast directly
|
||||||
|
// to the corresponding xla_hlo non-broadcasting op.
|
||||||
|
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
|
||||||
|
struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
|
||||||
|
using OpRewritePattern<ChloOpTy>::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(ChloOpTy op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
// Only rewrite for statically determinable non-broadcasting cases.
|
||||||
|
auto lhs_type = op.lhs().getType().template dyn_cast<RankedTensorType>();
|
||||||
|
auto rhs_type = op.rhs().getType().template dyn_cast<RankedTensorType>();
|
||||||
|
if (!lhs_type || !rhs_type) return failure();
|
||||||
|
|
||||||
|
// Requires rank broadcast.
|
||||||
|
if (lhs_type.getRank() != rhs_type.getRank()) return failure();
|
||||||
|
// Any dynamic dimension may require broadcasting and requires more
|
||||||
|
// analysis.
|
||||||
|
if (!lhs_type.hasStaticShape() || !rhs_type.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
for (auto extents : llvm::zip(lhs_type.getShape(), rhs_type.getShape())) {
|
||||||
|
auto lhs_extent = std::get<0>(extents);
|
||||||
|
auto rhs_extent = std::get<1>(extents);
|
||||||
|
if (lhs_extent != rhs_extent) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, {Adaptor::CreateOp(op, op.getResult().getType(),
|
||||||
|
op.lhs(), op.rhs(), rewriter)});
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Converts a binary op with ranked broadcasting operands to explicitly
|
||||||
|
// broadcast and invoke the corresponding xla_hlo non-broadcasting op.
|
||||||
|
// Note that dynamic broadcasting supported by this pattern is only valid for
|
||||||
|
// "numpy" broadcasting semantics as defined here:
|
||||||
|
// https://docs.scipy.org/doc/numpy/reference/ufuncs.html
|
||||||
|
// Specifically, this includes the following cases:
|
||||||
|
// - Same rank broadcast (operands have the same static rank).
|
||||||
|
// - Different-rank broadcast, either without a broadcast_dims attribte or
|
||||||
|
// with the broadcast_dims attribute set to map to a prefix padding.
|
||||||
|
// - Legal combinations of degenerate (1-dim) implicit broadcasting.
|
||||||
|
// The restriction on broadcast_dims derives from the definition of the
|
||||||
|
// `shape.broadcast` op, which only supports prefix-padding.
|
||||||
|
//
|
||||||
|
// It may be possible to expand this pattern to operate on unranked tensors in
|
||||||
|
// the future by emitting more code to dynamically differentiate based on rank.
|
||||||
|
// Whether that is of any practical benefit remains to be seen.
|
||||||
|
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
|
||||||
|
struct ConvertRankedDynamicBroadcastBinaryOp
|
||||||
|
: public OpRewritePattern<ChloOpTy> {
|
||||||
|
using OpRewritePattern<ChloOpTy>::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(ChloOpTy op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
// Only support ranked operands.
|
||||||
|
Value lhs = op.lhs();
|
||||||
|
Value rhs = op.rhs();
|
||||||
|
auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
|
||||||
|
auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
|
||||||
|
auto result_type =
|
||||||
|
op.getResult().getType().template dyn_cast<RankedTensorType>();
|
||||||
|
if (!lhs_type || !rhs_type || !result_type) return failure();
|
||||||
|
|
||||||
|
// Check for "numpy"-style rank broadcast.
|
||||||
|
auto broadcast_dimensions = op.broadcast_dimensions();
|
||||||
|
if (broadcast_dimensions &&
|
||||||
|
!xla::IsLegalNumpyRankedBroadcast(lhs, rhs, *broadcast_dimensions)) {
|
||||||
|
// Note: It is unclear whether the general specification of explicit
|
||||||
|
// broadcast_dimensions on binary ops is a feature we want to carry
|
||||||
|
// forward. While it can technically be implemented for ranked-dynamic,
|
||||||
|
// it is incompatible with unranked inputs. If this warning is emitted
|
||||||
|
// in real programs, it is an indication that the feature should be
|
||||||
|
// implemented versus just falling back on the more standard definition
|
||||||
|
// of numpy-like prefix-padding.
|
||||||
|
op.emitWarning() << "unsupported non prefix-padded dynamic rank "
|
||||||
|
<< "broadcast_dimensions = " << *broadcast_dimensions;
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute result shape.
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
|
||||||
|
// Insert a constraint on the shapes being broadcastable and insert all
|
||||||
|
// future code into an assuming block reliant on the constraint.
|
||||||
|
Value lhs_shape = rewriter.create<shape::ShapeOfOp>(loc, lhs);
|
||||||
|
Value rhs_shape = rewriter.create<shape::ShapeOfOp>(loc, rhs);
|
||||||
|
auto broadcastable_cstr =
|
||||||
|
rewriter.create<shape::CstrBroadcastableOp>(loc, lhs_shape, rhs_shape);
|
||||||
|
auto assuming_op = rewriter.create<shape::AssumingOp>(
|
||||||
|
loc, ArrayRef<Type>{result_type}, broadcastable_cstr.result());
|
||||||
|
|
||||||
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
|
rewriter.createBlock(&assuming_op.doRegion());
|
||||||
|
|
||||||
|
int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank());
|
||||||
|
Value result_extents =
|
||||||
|
xla::ComputeBinaryElementwiseBroadcastingResultExtents(loc, lhs, rhs,
|
||||||
|
rewriter);
|
||||||
|
|
||||||
|
// Note that we unconditionally emit DynamicBroadcastInDim ops and let
|
||||||
|
// downstream canonicalizations fold them away if possible. This is
|
||||||
|
// because, in the dynamic case, there are many corner cases regarding
|
||||||
|
// when it is safe to omit, and some of them require analysis to prove
|
||||||
|
// properly.
|
||||||
|
auto lhs_broadcast_dimensions = llvm::to_vector<4>(
|
||||||
|
llvm::seq<int64_t>(result_rank - lhs_type.getRank(), result_rank));
|
||||||
|
Value broadcasted_lhs = rewriter.create<xla_hlo::DynamicBroadcastInDimOp>(
|
||||||
|
loc,
|
||||||
|
RankedTensorType::get(result_type.getShape(),
|
||||||
|
lhs_type.getElementType()),
|
||||||
|
lhs, result_extents,
|
||||||
|
rewriter.getI64TensorAttr(lhs_broadcast_dimensions));
|
||||||
|
auto rhs_broadcast_dimensions = llvm::to_vector<4>(
|
||||||
|
llvm::seq<int64_t>(result_rank - rhs_type.getRank(), result_rank));
|
||||||
|
Value broadcasted_rhs = rewriter.create<xla_hlo::DynamicBroadcastInDimOp>(
|
||||||
|
loc,
|
||||||
|
RankedTensorType::get(result_type.getShape(),
|
||||||
|
rhs_type.getElementType()),
|
||||||
|
rhs, result_extents,
|
||||||
|
rewriter.getI64TensorAttr(rhs_broadcast_dimensions));
|
||||||
|
|
||||||
|
// And generate the final non-broadcasted binary op.
|
||||||
|
Value final_result = Adaptor::CreateOp(op, result_type, broadcasted_lhs,
|
||||||
|
broadcasted_rhs, rewriter);
|
||||||
|
rewriter.create<shape::AssumingYieldOp>(loc, final_result);
|
||||||
|
rewriter.replaceOp(op, {assuming_op.getResult(0)});
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
|
||||||
|
void PopulateForBinaryOp(MLIRContext *context,
|
||||||
|
OwningRewritePatternList *patterns) {
|
||||||
|
patterns
|
||||||
|
->insert<ConvertTrivialNonBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>(
|
||||||
|
context, 10);
|
||||||
|
patterns->insert<
|
||||||
|
ConvertRankedDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>(
|
||||||
|
context, 5);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename FromOpTy, typename ToOpTy>
|
||||||
|
struct HloBinaryElementwiseAdaptor {
|
||||||
|
static ToOpTy CreateOp(FromOpTy from_op, Type result_type,
|
||||||
|
Value broadcasted_lhs, Value broadcasted_rhs,
|
||||||
|
OpBuilder &builder) {
|
||||||
|
return builder.create<ToOpTy>(from_op.getLoc(), result_type,
|
||||||
|
broadcasted_lhs, broadcasted_rhs);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct HloComplexAdaptor {
|
||||||
|
static xla_hlo::ComplexOp CreateOp(BroadcastComplexOp from_op,
|
||||||
|
Type result_type, Value broadcasted_lhs,
|
||||||
|
Value broadcasted_rhs,
|
||||||
|
OpBuilder &builder) {
|
||||||
|
return builder.create<xla_hlo::ComplexOp>(from_op.getLoc(), result_type,
|
||||||
|
broadcasted_lhs, broadcasted_rhs);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct HloCompareAdaptor {
|
||||||
|
static xla_hlo::CompareOp CreateOp(BroadcastCompareOp from_op,
|
||||||
|
Type result_type, Value broadcasted_lhs,
|
||||||
|
Value broadcasted_rhs,
|
||||||
|
OpBuilder &builder) {
|
||||||
|
return builder.create<xla_hlo::CompareOp>(from_op.getLoc(), result_type,
|
||||||
|
broadcasted_lhs, broadcasted_rhs,
|
||||||
|
from_op.comparison_direction());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
|
||||||
|
OwningRewritePatternList *patterns) {
|
||||||
|
// Instantiate conversion templates for conforming binary elementwise ops
|
||||||
|
// that do not have different dtypes between operands and results and do
|
||||||
|
// not have special attributes that need to be preserved.
|
||||||
|
#define POPULATE_BCAST(ChloOp, HloOp) \
|
||||||
|
PopulateForBinaryOp<ChloOp, HloOp, \
|
||||||
|
HloBinaryElementwiseAdaptor<ChloOp, HloOp>>(context, \
|
||||||
|
patterns);
|
||||||
|
|
||||||
|
POPULATE_BCAST(BroadcastAddOp, xla_hlo::AddOp);
|
||||||
|
POPULATE_BCAST(BroadcastAndOp, xla_hlo::AndOp);
|
||||||
|
POPULATE_BCAST(BroadcastAtan2Op, xla_hlo::Atan2Op);
|
||||||
|
POPULATE_BCAST(BroadcastDivOp, xla_hlo::DivOp);
|
||||||
|
POPULATE_BCAST(BroadcastMaxOp, xla_hlo::MaxOp);
|
||||||
|
POPULATE_BCAST(BroadcastMinOp, xla_hlo::MinOp);
|
||||||
|
POPULATE_BCAST(BroadcastMulOp, xla_hlo::MulOp);
|
||||||
|
POPULATE_BCAST(BroadcastOrOp, xla_hlo::OrOp);
|
||||||
|
POPULATE_BCAST(BroadcastPowOp, xla_hlo::PowOp);
|
||||||
|
POPULATE_BCAST(BroadcastRemOp, xla_hlo::RemOp);
|
||||||
|
POPULATE_BCAST(BroadcastShiftLeftOp, xla_hlo::ShiftLeftOp);
|
||||||
|
POPULATE_BCAST(BroadcastShiftRightArithmeticOp,
|
||||||
|
xla_hlo::ShiftRightArithmeticOp);
|
||||||
|
POPULATE_BCAST(BroadcastShiftRightLogicalOp, xla_hlo::ShiftRightLogicalOp);
|
||||||
|
POPULATE_BCAST(BroadcastSubOp, xla_hlo::SubOp);
|
||||||
|
POPULATE_BCAST(BroadcastXorOp, xla_hlo::XorOp);
|
||||||
|
|
||||||
|
// Broadcasting ops requiring special construction.
|
||||||
|
PopulateForBinaryOp<BroadcastComplexOp, xla_hlo::ComplexOp,
|
||||||
|
HloComplexAdaptor>(context, patterns);
|
||||||
|
PopulateForBinaryOp<BroadcastCompareOp, xla_hlo::CompareOp,
|
||||||
|
HloCompareAdaptor>(context, patterns);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace xla_chlo
|
||||||
|
} // namespace mlir
|
|
@ -0,0 +1,57 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace xla_chlo {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct TestChloLegalizeToHloPass
|
||||||
|
: public PassWrapper<TestChloLegalizeToHloPass, FunctionPass> {
|
||||||
|
void runOnFunction() override {
|
||||||
|
ConversionTarget conversionTarget(getContext());
|
||||||
|
OwningRewritePatternList conversionPatterns;
|
||||||
|
|
||||||
|
conversionTarget.addIllegalDialect<XlaHloClientDialect>();
|
||||||
|
// Consider the xla_hlo dialect legal for tests.
|
||||||
|
conversionTarget.addLegalDialect<xla_hlo::XlaHloDialect>();
|
||||||
|
// The conversion uses helpers from the Standard dialect.
|
||||||
|
conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
|
||||||
|
conversionTarget.addLegalDialect<mlir::shape::ShapeDialect>();
|
||||||
|
|
||||||
|
PopulateLegalizeChloToHloPatterns(&getContext(), &conversionPatterns);
|
||||||
|
|
||||||
|
if (failed(applyPartialConversion(getFunction(), conversionTarget,
|
||||||
|
conversionPatterns))) {
|
||||||
|
return signalPassFailure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
} // namespace xla_chlo
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
static mlir::PassRegistration<mlir::xla_chlo::TestChloLegalizeToHloPass> pass(
|
||||||
|
"test-xla-chlo-legalize-to-hlo",
|
||||||
|
"Test pass for applying chlo -> hlo legalization patterns");
|
|
@ -0,0 +1,493 @@
|
||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// This file implements logic for lowering HLO dialect to LHLO dialect.
|
||||||
|
|
||||||
|
#include "third_party/absl/memory/memory.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/AffineMap.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BlockAndValueMapping.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/BufferPlacement.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace xla_hlo {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
using BaseOpConversion = BufferAssignmentOpConversionPattern<T>;
|
||||||
|
using StdReturnOpConverter =
|
||||||
|
detail::BufferAssignmentReturnOpConverter<mlir::ReturnOp, mlir::ReturnOp,
|
||||||
|
xla_lhlo::CopyOp, true>;
|
||||||
|
|
||||||
|
Value InsertDynamicAllocAndDealloc(Location loc, Value result,
|
||||||
|
Value shape_operand,
|
||||||
|
ConversionPatternRewriter* rewriter) {
|
||||||
|
auto result_type = result.getType().dyn_cast<ShapedType>();
|
||||||
|
if (!result_type) {
|
||||||
|
result.getDefiningOp()->emitOpError()
|
||||||
|
<< "tensor to buffer conversion expects ranked results";
|
||||||
|
}
|
||||||
|
auto memref_type =
|
||||||
|
MemRefType::get(result_type.getShape(), result_type.getElementType());
|
||||||
|
|
||||||
|
Operation* op = result.getDefiningOp();
|
||||||
|
|
||||||
|
// Extract the required element out of the vector.
|
||||||
|
SmallVector<Value, 4> dynamic_operands;
|
||||||
|
for (auto shape_element : llvm::enumerate(result_type.getShape())) {
|
||||||
|
if (shape_element.value() != ShapedType::kDynamicSize) continue;
|
||||||
|
Value index = rewriter->create<ConstantOp>(
|
||||||
|
loc, rewriter->getIntegerAttr(rewriter->getIndexType(),
|
||||||
|
shape_element.index()));
|
||||||
|
Value alloc_operand = rewriter->create<ExtractElementOp>(loc, shape_operand,
|
||||||
|
ValueRange{index});
|
||||||
|
if (!alloc_operand.getType().isIndex()) {
|
||||||
|
alloc_operand = rewriter->create<IndexCastOp>(loc, alloc_operand,
|
||||||
|
rewriter->getIndexType());
|
||||||
|
}
|
||||||
|
dynamic_operands.push_back(alloc_operand);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert in front of op to ensure sizes are available.
|
||||||
|
OpBuilder allocBuilder(op);
|
||||||
|
auto alloc = allocBuilder.create<AllocOp>(loc, memref_type, dynamic_operands);
|
||||||
|
return alloc;
|
||||||
|
}
|
||||||
|
|
||||||
|
Value InsertAlloc(Location loc, OpResult result,
|
||||||
|
BufferAssignmentPlacer* bufferAssignment,
|
||||||
|
ConversionPatternRewriter* rewriter) {
|
||||||
|
auto result_type = result.getType().dyn_cast<ShapedType>();
|
||||||
|
if (!result_type || !result_type.hasStaticShape()) {
|
||||||
|
result.getDefiningOp()->emitOpError()
|
||||||
|
<< "tensor to buffer conversion expects statically shaped results";
|
||||||
|
}
|
||||||
|
auto memref_type =
|
||||||
|
MemRefType::get(result_type.getShape(), result_type.getElementType());
|
||||||
|
OpBuilder::InsertionGuard guard(*rewriter);
|
||||||
|
rewriter->restoreInsertionPoint(
|
||||||
|
bufferAssignment->computeAllocPosition(result));
|
||||||
|
auto alloc = rewriter->create<AllocOp>(loc, memref_type);
|
||||||
|
return alloc;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename HloOpTy>
|
||||||
|
class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
|
||||||
|
public:
|
||||||
|
using BaseOpConversion<HloOpTy>::BaseOpConversion;
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
HloOpTy hloOp, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
|
Operation* op = hloOp.getOperation();
|
||||||
|
const auto& original_results = op->getResults();
|
||||||
|
SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
|
||||||
|
for (auto result : llvm::enumerate(original_results)) {
|
||||||
|
RankedTensorType resultType =
|
||||||
|
result.value().getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!resultType) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
if (resultType.hasStaticShape()) {
|
||||||
|
buffer_args.push_back(InsertAlloc(op->getLoc(), result.value(),
|
||||||
|
this->bufferAssignment, &rewriter));
|
||||||
|
} else {
|
||||||
|
SmallVector<Value, 1> results_shape;
|
||||||
|
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
|
||||||
|
if (!shape_type_op) return failure();
|
||||||
|
if (failed(
|
||||||
|
shape_type_op.reifyReturnTypeShapes(rewriter, results_shape)))
|
||||||
|
return failure();
|
||||||
|
buffer_args.push_back(InsertDynamicAllocAndDealloc(
|
||||||
|
op->getLoc(), result.value(), results_shape.front(), &rewriter));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rewriter.create<xla_hlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
|
||||||
|
buffer_args, op->getAttrs());
|
||||||
|
rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct HloToLhloDynamicBroadcastInDimOpConverter
|
||||||
|
: public BaseOpConversion<xla_hlo::DynamicBroadcastInDimOp> {
|
||||||
|
public:
|
||||||
|
using BaseOpConversion<xla_hlo::DynamicBroadcastInDimOp>::BaseOpConversion;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
xla_hlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
Value resultBuffer = InsertDynamicAllocAndDealloc(
|
||||||
|
loc, op.getResult(), op.output_dimensions(), &rewriter);
|
||||||
|
|
||||||
|
Value transformed_operand =
|
||||||
|
InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
|
||||||
|
rewriter.create<xla_lhlo::BroadcastInDimOp>(
|
||||||
|
loc, transformed_operand, resultBuffer, op.broadcast_dimensions());
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, {resultBuffer});
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Inserts dynamic memref to change the layout of the memref to put 0-stride
|
||||||
|
// and size of the target dimension if size-1 dimension expansion is
|
||||||
|
// necessary.
|
||||||
|
xla_lhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp(
|
||||||
|
xla_hlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
auto operand_type = operand.getType().cast<MemRefType>();
|
||||||
|
auto operand_shape = operand_type.getShape();
|
||||||
|
|
||||||
|
SmallVector<Value, 2> sizes, strides;
|
||||||
|
sizes.reserve(operand_shape.size());
|
||||||
|
strides.reserve(operand_shape.size());
|
||||||
|
|
||||||
|
Value zero = b->create<ConstantIndexOp>(loc, 0);
|
||||||
|
Value one = b->create<ConstantIndexOp>(loc, 1);
|
||||||
|
for (auto dim : llvm::enumerate(op.broadcast_dimensions())) {
|
||||||
|
Value broadcast_dim_value =
|
||||||
|
b->create<ConstantIndexOp>(loc, dim.value().getSExtValue());
|
||||||
|
Value result_dim_size = b->create<ExtractElementOp>(
|
||||||
|
loc, op.output_dimensions(), broadcast_dim_value);
|
||||||
|
Value operand_dim_size =
|
||||||
|
ShapedType::isDynamic(operand_shape[dim.index()])
|
||||||
|
? b->create<DimOp>(loc, operand, dim.index()).getResult()
|
||||||
|
: b->create<ConstantIndexOp>(loc, operand_shape[dim.index()])
|
||||||
|
.getResult();
|
||||||
|
|
||||||
|
// TODO(pifon): Revisit if this cast is needed. Maybe we can use
|
||||||
|
// tensor<index> for `output_dimensions` as well.
|
||||||
|
if (!result_dim_size.getType().isIndex()) {
|
||||||
|
result_dim_size =
|
||||||
|
b->create<IndexCastOp>(loc, result_dim_size, b->getIndexType());
|
||||||
|
}
|
||||||
|
|
||||||
|
// There can be two cases:
|
||||||
|
// 1) Operand dim == result dim => expansion is not needed => stride := 1.
|
||||||
|
// 2) Operand dim < result dim => expansion is needed => stride := 0.
|
||||||
|
Value is_expansion = b->create<CmpIOp>(loc, CmpIPredicate::slt,
|
||||||
|
operand_dim_size, result_dim_size);
|
||||||
|
strides.push_back(
|
||||||
|
b->create<mlir::SelectOp>(loc, is_expansion, zero, one));
|
||||||
|
|
||||||
|
// Size of input dim can be set to the size of the corresponding output
|
||||||
|
// dimension for both cases.
|
||||||
|
sizes.push_back(result_dim_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type-erased memref type with static rank, dynamic sizes and strides.
|
||||||
|
SmallVector<int64_t, 2> dynamic_layout(operand_shape.size(),
|
||||||
|
MemRefType::kDynamicStrideOrOffset);
|
||||||
|
SmallVector<int64_t, 2> dynamic_shape(operand_shape.size(),
|
||||||
|
MemRefType::kDynamicSize);
|
||||||
|
auto type_erased_memref_type = MemRefType::get(
|
||||||
|
dynamic_shape, operand_type.getElementType(),
|
||||||
|
makeStridedLinearLayoutMap(dynamic_layout,
|
||||||
|
/*offset=*/0, b->getContext()));
|
||||||
|
|
||||||
|
auto transformed_operand = b->create<xla_lhlo::DynamicMemRefCastOp>(
|
||||||
|
loc, type_erased_memref_type, operand, sizes, strides);
|
||||||
|
return transformed_operand;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct HloToLhloReduceOpConverter : public BaseOpConversion<xla_hlo::ReduceOp> {
|
||||||
|
public:
|
||||||
|
using BaseOpConversion<xla_hlo::ReduceOp>::BaseOpConversion;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
xla_hlo::ReduceOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
// TODO(b/137624192) Implement variadic reduce.
|
||||||
|
if (op.getNumResults() != 1) return failure();
|
||||||
|
if (!llvm::hasSingleElement(op.body())) {
|
||||||
|
return op.emitOpError()
|
||||||
|
<< "tensor to buffer conversion expects a single block "
|
||||||
|
"in the region containing the operation";
|
||||||
|
}
|
||||||
|
const auto& original_results = op.getResults();
|
||||||
|
SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
|
||||||
|
for (auto result : original_results) {
|
||||||
|
buffer_args.push_back(
|
||||||
|
InsertAlloc(loc, result, this->bufferAssignment, &rewriter));
|
||||||
|
}
|
||||||
|
auto new_op = rewriter.create<xla_lhlo::ReduceOp>(
|
||||||
|
loc, llvm::None, buffer_args, op.getAttrs());
|
||||||
|
|
||||||
|
// Copy over the operations inside the region.
|
||||||
|
rewriter.inlineRegionBefore(op.body(), new_op.body(), new_op.body().end());
|
||||||
|
|
||||||
|
// Create new block arguments with correct type.
|
||||||
|
auto& entry_block = new_op.body().front();
|
||||||
|
int original_arg_count = entry_block.getNumArguments();
|
||||||
|
for (int i = 0; i < original_arg_count; ++i) {
|
||||||
|
auto old_arg = entry_block.getArgument(i);
|
||||||
|
auto old_type = old_arg.getType().cast<TensorType>();
|
||||||
|
auto new_type =
|
||||||
|
MemRefType::get(old_type.getShape(), old_type.getElementType());
|
||||||
|
auto new_arg = entry_block.addArgument(new_type);
|
||||||
|
rewriter.replaceUsesOfBlockArgument(old_arg, new_arg);
|
||||||
|
}
|
||||||
|
// Add an argument for the result.
|
||||||
|
entry_block.addArgument(
|
||||||
|
entry_block.getArgument(original_arg_count).getType());
|
||||||
|
// Remove the old arguments.
|
||||||
|
for (int i = original_arg_count - 1; i >= 0; --i) {
|
||||||
|
entry_block.eraseArgument(i);
|
||||||
|
}
|
||||||
|
// Insert terminator at the end.
|
||||||
|
rewriter.setInsertionPointToEnd(&entry_block);
|
||||||
|
rewriter.create<xla_lhlo::TerminatorOp>(loc);
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class HloToLhloTensorLoadOpConverter
|
||||||
|
: public BaseOpConversion<mlir::TensorLoadOp> {
|
||||||
|
public:
|
||||||
|
using BaseOpConversion<mlir::TensorLoadOp>::BaseOpConversion;
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
mlir::TensorLoadOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
|
rewriter.replaceOp(op, operands);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// TODO(b/137624192): Rewrite into a copy and elide copy if possible.
|
||||||
|
class HloToLhloTensorStoreOpConverter
|
||||||
|
: public BaseOpConversion<mlir::TensorStoreOp> {
|
||||||
|
public:
|
||||||
|
using BaseOpConversion<mlir::TensorStoreOp>::BaseOpConversion;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
mlir::TensorStoreOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
|
rewriter.replaceOpWithNewOp<xla_lhlo::CopyOp>(
|
||||||
|
op, llvm::None, operands.front(), operands.back());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
|
||||||
|
// buffers if necessary.
|
||||||
|
//
|
||||||
|
// Example fusion with HLO ops.
|
||||||
|
//
|
||||||
|
// func @fusion(%arg0: memref<2x2xf32>,
|
||||||
|
// %arg1: memref<2x2xf32>,
|
||||||
|
// %arg2: memref<2x2xf32>,
|
||||||
|
// %arg3: memref<2x2xf32>) {
|
||||||
|
// "xla_lhlo.fusion"() ({
|
||||||
|
// %0 = tensor_load %arg1 : memref<2x2xf32>
|
||||||
|
// %1 = tensor_load %arg2 : memref<2x2xf32>
|
||||||
|
// %2 = "xla_hlo.add"(%0, %1) :
|
||||||
|
// (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||||
|
// %3 = tensor_load %arg0 : memref<2x2xf32>
|
||||||
|
// %4 = "xla_hlo.multiply"(%2, %3) :
|
||||||
|
// (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||||
|
// tensor_store %4, %arg3 : memref<2x2xf32>
|
||||||
|
// "xla_lhlo.terminator"() : () -> ()
|
||||||
|
// }) : () -> ()
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// Transformed fusion with LHLO ops.
|
||||||
|
// func @fusion(%arg0: memref<2x2xf32>,
|
||||||
|
// %arg1: memref<2x2xf32>,
|
||||||
|
// %arg2: memref<2x2xf32>,
|
||||||
|
// %arg3: memref<2x2xf32>) {
|
||||||
|
// "xla_lhlo.fusion"() ( {
|
||||||
|
// %0 = alloc() : memref<2x2xf32>
|
||||||
|
// "xla_lhlo.add"(%arg1, %arg2, %0) :
|
||||||
|
// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||||
|
// "xla_lhlo.multiply"(%0, %arg0, %arg3) :
|
||||||
|
// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||||
|
// "xla_lhlo.terminator"() : () -> ()
|
||||||
|
// }) : () -> ()
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// FuncOp signature conversion example:
|
||||||
|
//
|
||||||
|
// func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||||
|
// %0 = "xla_hlo.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) ->
|
||||||
|
// tensor<4xf32> %1 = "xla_hlo.add"(%arg0, %0) : (tensor<4xf32>,
|
||||||
|
// tensor<4xf32>) -> tensor<4xf32> return %1 : tensor<4xf32>
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// Transformed function with an extra argument for the result. The types have
|
||||||
|
// been converted from tensor to memref.
|
||||||
|
//
|
||||||
|
// func @func_op(%arg0: memref<4xf32>,
|
||||||
|
// %arg1: memref<4xf32>,
|
||||||
|
// %arg2: memref<4xf32>) {
|
||||||
|
// %0 = alloc() : memref<4xf32>
|
||||||
|
|
||||||
|
// "xla_lhlo.maximum"(%arg0, %arg1, %0) :
|
||||||
|
// (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
|
||||||
|
// %1 = alloc() : memref<4xf32>
|
||||||
|
// "xla_lhlo.add"(%arg0, %0, %1) :
|
||||||
|
// (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
|
||||||
|
// "xla_lhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
|
||||||
|
// "xla_lhlo.terminator"() : () -> ()
|
||||||
|
// }
|
||||||
|
|
||||||
|
struct HloLegalizeToLhlo
|
||||||
|
: public PassWrapper<HloLegalizeToLhlo, OperationPass<ModuleOp>> {
|
||||||
|
public:
|
||||||
|
HloLegalizeToLhlo() = default;
|
||||||
|
HloLegalizeToLhlo(const HloLegalizeToLhlo& o) {
|
||||||
|
this->results_escape_function = o.results_escape_function.getValue();
|
||||||
|
}
|
||||||
|
explicit HloLegalizeToLhlo(bool results_escape_function) {
|
||||||
|
this->results_escape_function.setValue(results_escape_function);
|
||||||
|
}
|
||||||
|
|
||||||
|
void runOnOperation() override {
|
||||||
|
OwningRewritePatternList patterns;
|
||||||
|
auto& context = getContext();
|
||||||
|
ConversionTarget target(context);
|
||||||
|
target.addLegalDialect<xla_lhlo::XlaLhloDialect>();
|
||||||
|
target.addLegalDialect<StandardOpsDialect>();
|
||||||
|
target.addLegalOp<ModuleOp>();
|
||||||
|
target.addIllegalOp<mlir::TensorLoadOp>();
|
||||||
|
target.addIllegalOp<mlir::TensorStoreOp>();
|
||||||
|
target.addLegalOp<ModuleTerminatorOp>();
|
||||||
|
target.addLegalOp<TensorFromElementsOp>();
|
||||||
|
target.addIllegalDialect<xla_hlo::XlaHloDialect>();
|
||||||
|
|
||||||
|
BufferAssignmentTypeConverter converter;
|
||||||
|
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
|
||||||
|
auto inputs = op.getType().getInputs();
|
||||||
|
return llvm::all_of(inputs,
|
||||||
|
[](Type input) { return input.isa<MemRefType>(); }) &&
|
||||||
|
converter.isLegal(&op.getBody());
|
||||||
|
});
|
||||||
|
target.addDynamicallyLegalOp<mlir::ReturnOp>([&](mlir::ReturnOp returnOp) {
|
||||||
|
return std::all_of(returnOp.operand_type_begin(),
|
||||||
|
returnOp.operand_type_end(),
|
||||||
|
[](Type type) { return type.isa<MemRefType>(); });
|
||||||
|
});
|
||||||
|
|
||||||
|
auto module = getOperation();
|
||||||
|
WalkResult result = module.walk([&](FuncOp func) -> WalkResult {
|
||||||
|
BufferAssignmentPlacer bufferAssignment(func);
|
||||||
|
OwningRewritePatternList patterns;
|
||||||
|
populateHLOToLHLOConversionPattern(func.getContext(), &bufferAssignment,
|
||||||
|
&converter, &patterns);
|
||||||
|
if (results_escape_function) {
|
||||||
|
populateWithBufferAssignmentOpConversionPatterns<
|
||||||
|
mlir::ReturnOp, mlir::ReturnOp, xla_lhlo::CopyOp,
|
||||||
|
/*allowMemrefFunctionResults=*/true>(&context, &bufferAssignment,
|
||||||
|
&converter, &patterns);
|
||||||
|
} else {
|
||||||
|
populateWithBufferAssignmentOpConversionPatterns<
|
||||||
|
mlir::ReturnOp, mlir::ReturnOp, xla_lhlo::CopyOp,
|
||||||
|
/*allowMemrefFunctionResults=*/false>(&context, &bufferAssignment,
|
||||||
|
&converter, &patterns);
|
||||||
|
}
|
||||||
|
return applyPartialConversion(func, target, patterns);
|
||||||
|
});
|
||||||
|
if (result.wasInterrupted()) {
|
||||||
|
signalPassFailure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
Option<bool> results_escape_function{
|
||||||
|
*this, "results-escape-function",
|
||||||
|
llvm::cl::desc(
|
||||||
|
"Allocate the results of functions within the functions body"),
|
||||||
|
llvm::cl::init(false)};
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void populateHLOToLHLOConversionPattern(
|
||||||
|
MLIRContext* context, BufferAssignmentPlacer* bufferAssignment,
|
||||||
|
TypeConverter* converter, OwningRewritePatternList* patterns) {
|
||||||
|
// clang-format off
|
||||||
|
patterns->insert<
|
||||||
|
HloToLhloDynamicBroadcastInDimOpConverter,
|
||||||
|
HloToLhloOpConverter<xla_hlo::AbsOp>,
|
||||||
|
HloToLhloOpConverter<xla_hlo::AddOp>,
|
||||||
|
HloToLhloOpConverter<xla_hlo::AndOp>,
|
||||||
|
HloToLhloOpConverter<xla_hlo::BroadcastInDimOp>,
|
||||||
|
HloToLhloOpConverter<xla_hlo::CeilOp>,
|
||||||
|
HloToLhloOpConverter<xla_hlo::CompareOp>,
|
||||||
|
HloToLhloOpConverter<xla_hlo::ComplexOp>,
|
||||||
|
HloToLhloOpConverter<xla_hlo::ConstOp>,
|
||||||
|
HloToLhloOpConverter<xla_hlo::ConvOp>,
|
||||||
|
HloToLhloOpConverter<xla_hlo::ConvertOp>,
|
||||||
|
HloToLhloOpConverter<xla_hlo::CopyOp>,
|
||||||
|
HloToLhloOpConverter<xla_hlo::CosOp>,
|
||||||
|
HloToLhloOpConverter<xla_hlo::DivOp>,
|
||||||
|
HloToLhloOpConverter<xla_hlo::DotOp>,
|
||||||
|
HloToLhloOpConverter<xla_hlo::ExpOp>,
|
||||||
|
HloToLhloOpConverter<xla_hlo::GatherOp>,
|
||||||
|
HloToLhloOpConverter<xla_hlo::ImagOp>,
|
||||||
|
HloToLhloOpConverter<xla_hlo::IotaOp>,
|
||||||
|
HloToLhloOpConverter<xla_hlo::LogOp>,
|
||||||
|
HloToLhloOpConverter<xla_hlo::MaxOp>,
|
||||||
|
HloToLhloOpConverter<xla_hlo::MinOp>,
|
||||||
|
HloToLhloOpConverter<xla_hlo::MulOp>,
|
||||||
|
HloToLhloOpConverter<xla_hlo::NegOp>,
|
||||||
|
HloToLhloOpConverter<xla_hlo::RealOp>,
|
||||||
|
HloToLhloOpConverter<xla_hlo::RemOp>,
|
||||||
|
HloToLhloOpConverter<xla_hlo::RsqrtOp>,
|
||||||
|
HloToLhloOpConverter<xla_hlo::ReshapeOp>,
|
||||||
|
HloToLhloOpConverter<xla_hlo::SelectOp>,
|
||||||
|
HloToLhloOpConverter<xla_hlo::SignOp>,
|
||||||
|
HloToLhloOpConverter<xla_hlo::SqrtOp>,
|
||||||
|
HloToLhloOpConverter<xla_hlo::SubOp>,
|
||||||
|
HloToLhloOpConverter<xla_hlo::TanhOp>,
|
||||||
|
HloToLhloReduceOpConverter,
|
||||||
|
HloToLhloTensorLoadOpConverter,
|
||||||
|
HloToLhloTensorStoreOpConverter
|
||||||
|
>(context, bufferAssignment, converter);
|
||||||
|
// clang-format on
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
|
||||||
|
bool results_escape_function) {
|
||||||
|
return absl::make_unique<HloLegalizeToLhlo>(results_escape_function);
|
||||||
|
}
|
||||||
|
|
||||||
|
static PassRegistration<HloLegalizeToLhlo> legalize_pass(
|
||||||
|
"hlo-legalize-to-lhlo", "Legalize from HLO dialect to LHLO dialect");
|
||||||
|
|
||||||
|
} // namespace xla_hlo
|
||||||
|
} // namespace mlir
|
|
@ -0,0 +1,237 @@
|
||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// This file implements logic for lowering XLA dialect to Standard dialect.
|
||||||
|
|
||||||
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
|
||||||
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h"
|
||||||
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/Casting.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Block.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BlockAndValueMapping.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/PassRegistry.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LogicalResult.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||||
|
|
||||||
|
using mlir::PassRegistration;
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace xla_hlo {
|
||||||
|
namespace {
|
||||||
|
struct LegalizeControlFlow
|
||||||
|
: public mlir::PassWrapper<LegalizeControlFlow, FunctionPass> {
|
||||||
|
// Perform the lowering to MLIR control flow.
|
||||||
|
void runOnFunction() override;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Replaces terminators for the newly created blocks from a targe region.
|
||||||
|
// These terminators are replaced with branch operations to a target block.
|
||||||
|
LogicalResult ReplaceTerminators(Region* region, Block* target_block,
|
||||||
|
Location loc,
|
||||||
|
const BlockAndValueMapping& mapper,
|
||||||
|
OpBuilder* builder) {
|
||||||
|
for (auto& old_block : region->getBlocks()) {
|
||||||
|
Block* block = mapper.lookup(&old_block);
|
||||||
|
auto return_op = dyn_cast<xla_hlo::ReturnOp>(block->getTerminator());
|
||||||
|
if (!return_op) continue;
|
||||||
|
builder->setInsertionPointToEnd(block);
|
||||||
|
builder->create<mlir::BranchOp>(loc, target_block, return_op.getOperands());
|
||||||
|
return_op.erase();
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult LowerIfOp(mlir::xla_hlo::IfOp if_op) {
|
||||||
|
Operation* op_inst = if_op.getOperation();
|
||||||
|
mlir::OpBuilder builder(if_op);
|
||||||
|
auto orig_block = op_inst->getBlock();
|
||||||
|
auto* tail_block = orig_block->splitBlock(op_inst);
|
||||||
|
auto loc = if_op.getLoc();
|
||||||
|
|
||||||
|
// Duplicate the true and false regions in the block between the sections
|
||||||
|
// before and after the conditional.
|
||||||
|
BlockAndValueMapping mapper;
|
||||||
|
if_op.true_branch().cloneInto(orig_block->getParent(),
|
||||||
|
Region::iterator(tail_block), mapper);
|
||||||
|
if_op.false_branch().cloneInto(orig_block->getParent(),
|
||||||
|
Region::iterator(tail_block), mapper);
|
||||||
|
|
||||||
|
// Determine the blocks for the start of the true and false regions.
|
||||||
|
Block* true_block = mapper.lookup(&if_op.true_branch().front());
|
||||||
|
Block* false_block = mapper.lookup(&if_op.false_branch().front());
|
||||||
|
|
||||||
|
// Perform the conditional branch into the true/false cases.
|
||||||
|
builder.setInsertionPointToEnd(orig_block);
|
||||||
|
|
||||||
|
// Extract the predicate for checking branching, then branch to the true and
|
||||||
|
// false regions appropriately.
|
||||||
|
auto cond_value = builder.create<mlir::ExtractElementOp>(loc, if_op.pred());
|
||||||
|
builder.create<mlir::CondBranchOp>(loc, cond_value, true_block,
|
||||||
|
if_op.true_arg(), false_block,
|
||||||
|
if_op.false_arg());
|
||||||
|
|
||||||
|
// Replace the true case's return operations with a branch to the tail of
|
||||||
|
// the condition.
|
||||||
|
if (failed(ReplaceTerminators(&if_op.true_branch(), tail_block, loc, mapper,
|
||||||
|
&builder)))
|
||||||
|
return failure();
|
||||||
|
if (failed(ReplaceTerminators(&if_op.false_branch(), tail_block, loc, mapper,
|
||||||
|
&builder)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
tail_block->addArguments(if_op.getResult().getType());
|
||||||
|
if_op.getResult().replaceAllUsesWith(tail_block->getArgument(0));
|
||||||
|
|
||||||
|
op_inst->erase();
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
|
||||||
|
// Converts an XLA while loop into control flow. This generates a set of MLIR
|
||||||
|
// blocks and branches, along with inlining the regions provided by the XLA
|
||||||
|
// while loop. The structure should be similar to below:
|
||||||
|
//
|
||||||
|
// <prior operations>
|
||||||
|
// %0 = "xla_hlo.while"(%arg0) {^cond(...){...}, ^body(...){...}}
|
||||||
|
// <post operations>
|
||||||
|
auto* op_inst = while_op.getOperation();
|
||||||
|
mlir::OpBuilder builder(while_op);
|
||||||
|
auto loc = while_op.getLoc();
|
||||||
|
|
||||||
|
// Break the block into four sections:
|
||||||
|
// orig_block - operations before the while and the branch into looping check.
|
||||||
|
// tail_block - operations after the while loop completes.
|
||||||
|
// cond_block - check the looping condition, then conditionally branch into
|
||||||
|
// the loop or, if condition is false, jump to the tail branch.
|
||||||
|
// body_block - inlined loop body, then jump back to the condition block.
|
||||||
|
auto* orig_block = op_inst->getBlock();
|
||||||
|
auto* tail_block = orig_block->splitBlock(op_inst);
|
||||||
|
|
||||||
|
BlockAndValueMapping mapper;
|
||||||
|
while_op.cond().cloneInto(orig_block->getParent(),
|
||||||
|
Region::iterator(tail_block), mapper);
|
||||||
|
while_op.body().cloneInto(orig_block->getParent(),
|
||||||
|
Region::iterator(tail_block), mapper);
|
||||||
|
|
||||||
|
// Lookup the entry blocks for both condition and body.
|
||||||
|
auto* cond_block = mapper.lookup(&while_op.cond().front());
|
||||||
|
auto* body_block = mapper.lookup(&while_op.body().front());
|
||||||
|
|
||||||
|
// Setup the end of the original block:
|
||||||
|
// <prior operations>
|
||||||
|
// br ^cond(%arg0) // Jumps to the condition statement.
|
||||||
|
builder.setInsertionPointToEnd(orig_block);
|
||||||
|
builder.create<mlir::BranchOp>(loc, cond_block, while_op.getOperand());
|
||||||
|
|
||||||
|
// Updates the inlined condition blocks by replacing the return op with an
|
||||||
|
// extract_element and conditional branch. This changes the block below:
|
||||||
|
// ^cond(%0):
|
||||||
|
// <inlined conditional region>
|
||||||
|
// "xla_hlo".return(%1)
|
||||||
|
//
|
||||||
|
// Into:
|
||||||
|
// ^cond(%0):
|
||||||
|
// <inlined conditional region>
|
||||||
|
// %2 = extract_element %1[] : tensor<i1> // Extract the condition value.
|
||||||
|
// cond_br %2, ^body(%0), ^tail(%0) // Branch.
|
||||||
|
builder.setInsertionPointToStart(cond_block);
|
||||||
|
|
||||||
|
// Replace the xla_hlo::ReturnOp with a branch back to the condition block.
|
||||||
|
// This is required as the xla_hlo::ReturnOp is used to mark the end of a
|
||||||
|
// block for regions nested inside of a operations (MLIR ReturnOp cannot be
|
||||||
|
// nested within an non-function region).
|
||||||
|
for (auto& block : while_op.cond()) {
|
||||||
|
auto new_block = mapper.lookup(&block);
|
||||||
|
|
||||||
|
auto return_op = dyn_cast<xla_hlo::ReturnOp>(new_block->getTerminator());
|
||||||
|
if (!return_op) continue;
|
||||||
|
builder.setInsertionPointToEnd(new_block);
|
||||||
|
|
||||||
|
auto return_value = return_op.getOperand(0);
|
||||||
|
auto cond_value = builder.create<mlir::ExtractElementOp>(loc, return_value);
|
||||||
|
|
||||||
|
// Get the body block arguments.
|
||||||
|
llvm::SmallVector<Value, 4> successor_args(cond_block->args_begin(),
|
||||||
|
cond_block->args_end());
|
||||||
|
builder.create<mlir::CondBranchOp>(loc, cond_value, body_block,
|
||||||
|
successor_args, tail_block,
|
||||||
|
successor_args);
|
||||||
|
return_op.erase();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Updates the body blocks by replace the return op with an branch to the
|
||||||
|
// conditional block. This changes the block below:
|
||||||
|
// ^body(%0):
|
||||||
|
// <inlined body block>
|
||||||
|
// "xla_hlo".return(%1)
|
||||||
|
//
|
||||||
|
// Into:
|
||||||
|
// ^body(%0):
|
||||||
|
// <inlined body block>
|
||||||
|
// br ^cond(%0) // Branch.
|
||||||
|
for (auto& block : while_op.body()) {
|
||||||
|
auto new_block = mapper.lookup(&block);
|
||||||
|
auto return_op =
|
||||||
|
dyn_cast<mlir::xla_hlo::ReturnOp>(new_block->getTerminator());
|
||||||
|
if (!return_op) continue;
|
||||||
|
builder.setInsertionPointToEnd(new_block);
|
||||||
|
builder.create<mlir::BranchOp>(loc, cond_block, return_op.getOperands());
|
||||||
|
return_op.erase();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Erase the original while loop.
|
||||||
|
tail_block->addArgument(while_op.getType());
|
||||||
|
while_op.getResult().replaceAllUsesWith(tail_block->getArgument(0));
|
||||||
|
op_inst->erase();
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
void LegalizeControlFlow::runOnFunction() {
|
||||||
|
auto func = getFunction();
|
||||||
|
llvm::SmallVector<IfOp, 4> if_ops;
|
||||||
|
func.walk([&](IfOp op) { if_ops.push_back(op); });
|
||||||
|
|
||||||
|
for (auto& op : if_ops) {
|
||||||
|
if (failed(LowerIfOp(op))) return signalPassFailure();
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::SmallVector<WhileOp, 4> while_ops;
|
||||||
|
func.walk([&](WhileOp op) { while_ops.push_back(op); });
|
||||||
|
|
||||||
|
for (auto& op : while_ops) {
|
||||||
|
if (failed(LowerWhileOp(op))) return signalPassFailure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
} // namespace xla_hlo
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
|
||||||
|
mlir::xla_hlo::createLegalizeControlFlowPass() {
|
||||||
|
return std::make_unique<LegalizeControlFlow>();
|
||||||
|
}
|
||||||
|
|
||||||
|
static PassRegistration<mlir::xla_hlo::LegalizeControlFlow> legalize_cf_pass(
|
||||||
|
"xla-legalize-control-flow",
|
||||||
|
"Legalize from XLA control flow to MLIR control flow");
|
|
@ -0,0 +1,156 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// This file implements logic for lowering the tanh standard ops to an
|
||||||
|
// approximation.
|
||||||
|
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace xla {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
/// Emits the fast tanh approximation that is also used by XLA.
|
||||||
|
Value EmitTanhApproximation(Value input, Location loc,
|
||||||
|
PatternRewriter &rewriter) {
|
||||||
|
// For small values of x, we can approximate tanh(x)=x. For extremely small
|
||||||
|
// values of x (|x| < 1e-37), the other approximation would evaluate
|
||||||
|
// tanh(x) = 0.
|
||||||
|
constexpr float kCanUseApprox = 0.0004;
|
||||||
|
Value abs_value = rewriter.create<AbsFOp>(loc, input);
|
||||||
|
Value can_use_approx =
|
||||||
|
rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(kCanUseApprox));
|
||||||
|
Value return_input = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT,
|
||||||
|
abs_value, can_use_approx);
|
||||||
|
// Clamp the input to [-c, c].
|
||||||
|
Value max_clamp = rewriter.create<ConstantOp>(
|
||||||
|
loc, rewriter.getF32FloatAttr(7.90531110763549805f));
|
||||||
|
Value smaller_than_max =
|
||||||
|
rewriter.create<CmpFOp>(loc, CmpFPredicate::ULE, input, max_clamp);
|
||||||
|
Value clamped_half =
|
||||||
|
rewriter.create<SelectOp>(loc, smaller_than_max, input, max_clamp);
|
||||||
|
Value min_clamp = rewriter.create<ConstantOp>(
|
||||||
|
loc, rewriter.getF32FloatAttr(-7.90531110763549805f));
|
||||||
|
Value larger_than_min =
|
||||||
|
rewriter.create<CmpFOp>(loc, CmpFPredicate::UGE, clamped_half, min_clamp);
|
||||||
|
Value input_clamped =
|
||||||
|
rewriter.create<SelectOp>(loc, larger_than_min, clamped_half, min_clamp);
|
||||||
|
|
||||||
|
static constexpr std::array<float, 7> numerator_coeffs{
|
||||||
|
-2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f,
|
||||||
|
5.12229709037114e-08f, 1.48572235717979e-05f, 6.37261928875436e-04f,
|
||||||
|
4.89352455891786e-03f};
|
||||||
|
|
||||||
|
static constexpr std::array<float, 4> denominator_coeffs{
|
||||||
|
1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f,
|
||||||
|
4.89352518554385e-03f};
|
||||||
|
|
||||||
|
Value input_squared =
|
||||||
|
rewriter.create<MulFOp>(loc, input_clamped, input_clamped);
|
||||||
|
Value numerator = rewriter.create<ConstantOp>(
|
||||||
|
loc, rewriter.getF32FloatAttr(numerator_coeffs[0]));
|
||||||
|
for (int i = 1; i < numerator_coeffs.size(); i++) {
|
||||||
|
numerator = rewriter.create<AddFOp>(
|
||||||
|
loc, rewriter.create<MulFOp>(loc, input_squared, numerator),
|
||||||
|
rewriter.create<ConstantOp>(
|
||||||
|
loc, rewriter.getF32FloatAttr(numerator_coeffs[i])));
|
||||||
|
}
|
||||||
|
|
||||||
|
numerator = rewriter.create<MulFOp>(loc, input_clamped, numerator);
|
||||||
|
|
||||||
|
Value denominator = rewriter.create<ConstantOp>(
|
||||||
|
loc, rewriter.getF32FloatAttr(denominator_coeffs[0]));
|
||||||
|
for (int i = 1; i < denominator_coeffs.size(); i++) {
|
||||||
|
denominator = rewriter.create<AddFOp>(
|
||||||
|
loc, rewriter.create<MulFOp>(loc, input_squared, denominator),
|
||||||
|
rewriter.create<ConstantOp>(
|
||||||
|
loc, rewriter.getF32FloatAttr(denominator_coeffs[i])));
|
||||||
|
}
|
||||||
|
|
||||||
|
Value approx = rewriter.create<DivFOp>(loc, numerator, denominator);
|
||||||
|
|
||||||
|
return rewriter.create<SelectOp>(loc, return_input, input, approx);
|
||||||
|
}
|
||||||
|
|
||||||
|
class ApproximateTanhLowering : public OpRewritePattern<TanhOp> {
|
||||||
|
public:
|
||||||
|
explicit ApproximateTanhLowering(MLIRContext *ctx)
|
||||||
|
: OpRewritePattern<TanhOp>(ctx, 100) {}
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(TanhOp tanhOp,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
Type operand_type = tanhOp.getType();
|
||||||
|
|
||||||
|
if (operand_type.isF64()) {
|
||||||
|
// Similar to XLA, do not rewrite f64 as precision might matter.
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
Location loc = tanhOp.getLoc();
|
||||||
|
Value input = tanhOp.operand();
|
||||||
|
if (operand_type.isF16()) {
|
||||||
|
input = rewriter.create<FPExtOp>(loc, input, rewriter.getF32Type());
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we still do not have f32, fail.
|
||||||
|
if (!input.getType().isF32()) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
Value result = EmitTanhApproximation(input, loc, rewriter);
|
||||||
|
|
||||||
|
// Truncate back if needed.
|
||||||
|
if (operand_type.isF16()) {
|
||||||
|
result = rewriter.create<FPTruncOp>(loc, result, rewriter.getF16Type());
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(tanhOp, {result});
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct LegalizeTanhToApproximation
|
||||||
|
: public PassWrapper<LegalizeTanhToApproximation, FunctionPass> {
|
||||||
|
/// Perform the lowering of standard dialect operations to approximations.
|
||||||
|
void runOnFunction() override {
|
||||||
|
OwningRewritePatternList patterns;
|
||||||
|
PopulateTanhToApproximationPatterns(&getContext(), &patterns);
|
||||||
|
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
|
std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
|
||||||
|
createLegalizeTanhToApproximationPass() {
|
||||||
|
return std::make_unique<LegalizeTanhToApproximation>();
|
||||||
|
}
|
||||||
|
|
||||||
|
void PopulateTanhToApproximationPatterns(mlir::MLIRContext *context,
|
||||||
|
OwningRewritePatternList *patterns) {
|
||||||
|
patterns->insert<ApproximateTanhLowering>(context);
|
||||||
|
}
|
||||||
|
|
||||||
|
static PassRegistration<LegalizeTanhToApproximation> legalize_pass(
|
||||||
|
"xla-legalize-tanh-to-approximation",
|
||||||
|
"Legalize tanh from standard dialect to an approximation");
|
||||||
|
|
||||||
|
} // namespace xla
|
||||||
|
} // namespace mlir
|
|
@ -0,0 +1,208 @@
|
||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// This file implements logic for lowering XLA dialect to Standard dialect.
|
||||||
|
|
||||||
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace {
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_legalize_to_standard.inc"
|
||||||
|
} // end anonymous namespace
|
||||||
|
namespace xla_hlo {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class CompareIConvert : public OpRewritePattern<xla_hlo::CompareOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(xla_hlo::CompareOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto lhs = op.lhs();
|
||||||
|
auto rhs = op.rhs();
|
||||||
|
auto lhs_type = lhs.getType().cast<TensorType>();
|
||||||
|
auto rhs_type = rhs.getType().cast<TensorType>();
|
||||||
|
|
||||||
|
// Broadcasting not supported by this rewrite.
|
||||||
|
if (lhs_type.getShape() != rhs_type.getShape()) return failure();
|
||||||
|
|
||||||
|
if (!lhs_type.getElementType().isSignlessInteger() ||
|
||||||
|
!rhs_type.getElementType().isSignlessInteger())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto comparison_direction = op.comparison_direction();
|
||||||
|
auto compare_predicate =
|
||||||
|
llvm::StringSwitch<Optional<CmpIPredicate>>(comparison_direction)
|
||||||
|
.Case("EQ", CmpIPredicate::eq)
|
||||||
|
.Case("NE", CmpIPredicate::ne)
|
||||||
|
.Case("LT", CmpIPredicate::slt)
|
||||||
|
.Case("LE", CmpIPredicate::sle)
|
||||||
|
.Case("GT", CmpIPredicate::sgt)
|
||||||
|
.Case("GE", CmpIPredicate::sge)
|
||||||
|
.Default(llvm::None);
|
||||||
|
|
||||||
|
if (!compare_predicate.hasValue()) return failure();
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<CmpIOp>(op, compare_predicate.getValue(), lhs,
|
||||||
|
rhs);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class CompareFConvert : public OpRewritePattern<xla_hlo::CompareOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(xla_hlo::CompareOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto lhs = op.lhs();
|
||||||
|
auto rhs = op.rhs();
|
||||||
|
auto lhs_type = lhs.getType().cast<TensorType>();
|
||||||
|
auto rhs_type = rhs.getType().cast<TensorType>();
|
||||||
|
|
||||||
|
// Broadcasting not supported by this rewrite.
|
||||||
|
if (lhs_type.getShape() != rhs_type.getShape()) return failure();
|
||||||
|
|
||||||
|
if (!lhs_type.getElementType().isa<FloatType>() ||
|
||||||
|
!rhs_type.getElementType().isa<FloatType>())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto comparison_direction = op.comparison_direction();
|
||||||
|
auto compare_predicate =
|
||||||
|
llvm::StringSwitch<Optional<CmpFPredicate>>(comparison_direction)
|
||||||
|
.Case("EQ", CmpFPredicate::OEQ)
|
||||||
|
.Case("NE", CmpFPredicate::UNE)
|
||||||
|
.Case("LT", CmpFPredicate::OLT)
|
||||||
|
.Case("LE", CmpFPredicate::OLE)
|
||||||
|
.Case("GT", CmpFPredicate::OGT)
|
||||||
|
.Case("GE", CmpFPredicate::OGE)
|
||||||
|
.Default(llvm::None);
|
||||||
|
|
||||||
|
if (!compare_predicate.hasValue()) return failure();
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<CmpFOp>(op, compare_predicate.getValue(), lhs,
|
||||||
|
rhs);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Replace IotaOp with an integer constant. A ConvertOp is added to
|
||||||
|
// convert the integer constant to iota result type. For complex types, the real
|
||||||
|
// part is replaced with the generated constant and the imaginary part is
|
||||||
|
// replaced with zero tensor.
|
||||||
|
class ConvertIotaOp : public OpRewritePattern<xla_hlo::IotaOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(xla_hlo::IotaOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto output_type = op.getType().cast<ShapedType>();
|
||||||
|
auto output_size = output_type.getNumElements();
|
||||||
|
auto dimension = op.iota_dimension().getSExtValue();
|
||||||
|
auto max_dim_size = output_type.getDimSize(dimension);
|
||||||
|
|
||||||
|
auto element_type = output_type.getElementType();
|
||||||
|
int bitwidth;
|
||||||
|
|
||||||
|
auto complex_ty = element_type.dyn_cast<ComplexType>();
|
||||||
|
Type int_or_float_ty = element_type;
|
||||||
|
if (complex_ty) int_or_float_ty = complex_ty.getElementType();
|
||||||
|
|
||||||
|
bitwidth = int_or_float_ty.getIntOrFloatBitWidth();
|
||||||
|
llvm::SmallVector<APInt, 10> values;
|
||||||
|
values.reserve(output_size);
|
||||||
|
|
||||||
|
int64_t increase_stride = output_size;
|
||||||
|
for (int i = 0; i <= dimension; i++) {
|
||||||
|
increase_stride /= output_type.getDimSize(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t current_value = 0;
|
||||||
|
for (int i = 0; i < output_size; i++) {
|
||||||
|
int64_t value = (current_value / increase_stride) % max_dim_size;
|
||||||
|
values.push_back(APInt(bitwidth, value));
|
||||||
|
++current_value;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto int_shape_type = RankedTensorType::get(
|
||||||
|
output_type.getShape(),
|
||||||
|
IntegerType::get(bitwidth, rewriter.getContext()));
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
auto integer_const = rewriter.create<mlir::ConstantOp>(
|
||||||
|
loc, DenseIntElementsAttr::get(int_shape_type, values));
|
||||||
|
|
||||||
|
auto int_or_float_shape_ty =
|
||||||
|
RankedTensorType::get(output_type.getShape(), int_or_float_ty);
|
||||||
|
|
||||||
|
auto iota_const =
|
||||||
|
rewriter.create<ConvertOp>(loc, int_or_float_shape_ty, integer_const);
|
||||||
|
|
||||||
|
// For int/float types we are done, replace op and return.
|
||||||
|
if (!complex_ty) {
|
||||||
|
rewriter.replaceOp(op, iota_const.getResult());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// For complex types, generate a constant tensor of zeroes for the imaginary
|
||||||
|
// part and use iota_const for real part.
|
||||||
|
auto zeroes = rewriter.create<mlir::ConstantOp>(
|
||||||
|
loc, DenseIntElementsAttr::get(int_shape_type, APInt(bitwidth, 0)));
|
||||||
|
auto imag_zeroes =
|
||||||
|
rewriter.create<ConvertOp>(loc, int_or_float_shape_ty, zeroes);
|
||||||
|
rewriter.replaceOpWithNewOp<xla_hlo::ComplexOp>(op, iota_const,
|
||||||
|
imag_zeroes);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // end anonymous namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
struct LegalizeToStandard
|
||||||
|
: public PassWrapper<LegalizeToStandard, FunctionPass> {
|
||||||
|
/// Perform the lowering to Standard dialect.
|
||||||
|
void runOnFunction() override;
|
||||||
|
};
|
||||||
|
} // end anonymous namespace
|
||||||
|
|
||||||
|
std::unique_ptr<mlir::OperationPass<mlir::FuncOp>> createLegalizeToStdPass() {
|
||||||
|
return std::make_unique<LegalizeToStandard>();
|
||||||
|
}
|
||||||
|
|
||||||
|
void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns,
|
||||||
|
mlir::MLIRContext *ctx) {
|
||||||
|
mlir::populateWithGenerated(ctx, patterns);
|
||||||
|
patterns->insert<CompareFConvert, CompareIConvert, ConvertIotaOp>(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Perform the lowering to standard dialect.
|
||||||
|
void LegalizeToStandard::runOnFunction() {
|
||||||
|
OwningRewritePatternList patterns;
|
||||||
|
mlir::xla_hlo::PopulateXlaToStdPatterns(&patterns, &getContext());
|
||||||
|
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
||||||
|
}
|
||||||
|
|
||||||
|
static PassRegistration<LegalizeToStandard> legalize_pass(
|
||||||
|
"xla-legalize-to-std", "Legalize from XLA dialect to standard dialect");
|
||||||
|
|
||||||
|
} // end namespace xla_hlo
|
||||||
|
} // end namespace mlir
|
|
@ -0,0 +1,71 @@
|
||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// This is the legalization pattern definition file for XLA to StandardOps.
|
||||||
|
|
||||||
|
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td"
|
||||||
|
include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td"
|
||||||
|
include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Nullary op patterns.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def : Pat<(HLO_ConstOp ElementsAttr:$value),
|
||||||
|
(ConstantOp $value)>;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Binary op patterns.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
def IsSameSizePred : CPred<
|
||||||
|
"$0.getType().cast<ShapedType>().getShape() "
|
||||||
|
"== $1.getType().cast<ShapedType>().getShape()">;
|
||||||
|
def IsSameSizeConstraint : Constraint<IsSameSizePred, "inputs are same size">;
|
||||||
|
|
||||||
|
|
||||||
|
def : Pat<(HLO_AndOp HLO_PredTensor:$l, HLO_PredTensor:$r),
|
||||||
|
(AndOp $l, $r),
|
||||||
|
[(IsSameSizeConstraint $l, $r)]>;
|
||||||
|
def : Pat<(HLO_AddOp HLO_FpTensor:$l, HLO_FpTensor:$r),
|
||||||
|
(AddFOp $l, $r),
|
||||||
|
[(IsSameSizeConstraint $l, $r)]>;
|
||||||
|
def : Pat<(HLO_SubOp HLO_FpTensor:$l, HLO_FpTensor:$r),
|
||||||
|
(SubFOp $l, $r),
|
||||||
|
[(IsSameSizeConstraint $l, $r)]>;
|
||||||
|
def : Pat<(HLO_MulOp HLO_FpTensor:$l, HLO_FpTensor:$r),
|
||||||
|
(MulFOp $l, $r),
|
||||||
|
[(IsSameSizeConstraint $l, $r)]>;
|
||||||
|
def : Pat<(HLO_DivOp HLO_FpTensor:$l, HLO_FpTensor:$r),
|
||||||
|
(DivFOp $l, $r),
|
||||||
|
[(IsSameSizeConstraint $l, $r)]>;
|
||||||
|
def : Pat<(HLO_RemOp HLO_FpTensor:$l, HLO_FpTensor:$r),
|
||||||
|
(RemFOp $l, $r),
|
||||||
|
[(IsSameSizeConstraint $l, $r)]>;
|
||||||
|
def : Pat<(HLO_AddOp HLO_IntTensor:$l, HLO_IntTensor:$r),
|
||||||
|
(AddIOp $l, $r),
|
||||||
|
[(IsSameSizeConstraint $l, $r)]>;
|
||||||
|
def : Pat<(HLO_SubOp HLO_IntTensor:$l, HLO_IntTensor:$r),
|
||||||
|
(SubIOp $l, $r),
|
||||||
|
[(IsSameSizeConstraint $l, $r)]>;
|
||||||
|
def : Pat<(HLO_MulOp HLO_IntTensor:$l, HLO_IntTensor:$r),
|
||||||
|
(MulIOp $l, $r),
|
||||||
|
[(IsSameSizeConstraint $l, $r)]>;
|
||||||
|
def : Pat<(HLO_DivOp HLO_IntTensor:$l, HLO_IntTensor:$r),
|
||||||
|
(SignedDivIOp $l, $r),
|
||||||
|
[(IsSameSizeConstraint $l, $r)]>;
|
||||||
|
def : Pat<(HLO_RemOp HLO_IntTensor:$l, HLO_IntTensor:$r),
|
||||||
|
(SignedRemIOp $l, $r),
|
||||||
|
[(IsSameSizeConstraint $l, $r)]>;
|
|
@ -0,0 +1,105 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// This file implements a pass to remove redundant LHLO copy operations.
|
||||||
|
|
||||||
|
#include "third_party/absl/memory/memory.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace xla_lhlo {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Removes LHLO copy operations that copy from allocated buffers to block
|
||||||
|
// arguments. All uses of each buffer are replaced with the corresponding block
|
||||||
|
// argument and the buffer is freed. Note that this pass only works in regions
|
||||||
|
// with a single block.
|
||||||
|
struct LhloCopyRemoval : mlir::PassWrapper<LhloCopyRemoval, OperationPass<>> {
|
||||||
|
void runOnOperation() override {
|
||||||
|
llvm::SmallVector<mlir::Operation*, 2> eraseList;
|
||||||
|
auto operation = getOperation();
|
||||||
|
operation->walk([&](mlir::xla_lhlo::CopyOp copyOp) {
|
||||||
|
// If this region contains more than one block, then ignore this copy
|
||||||
|
// operation.
|
||||||
|
if (copyOp.getParentRegion()->getBlocks().size() > 1) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::Value fromOperand = copyOp.operand();
|
||||||
|
mlir::Value toOperand = copyOp.output();
|
||||||
|
|
||||||
|
// If the fromOperand value is a block argument or the toOperand
|
||||||
|
// value is not a block argument, then ignore this copy operation.
|
||||||
|
if (!fromOperand.getDefiningOp() || toOperand.getDefiningOp()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// The copy operation removal is illegal if there is at least a single use
|
||||||
|
// of toOperand value that lies between the first use of fromOperand value
|
||||||
|
// and the copy operation.
|
||||||
|
auto fromOperandUsers = fromOperand.getUsers();
|
||||||
|
auto firstUser = *fromOperandUsers.begin();
|
||||||
|
for (auto op : fromOperandUsers) {
|
||||||
|
if (op->isBeforeInBlock(firstUser)) firstUser = op;
|
||||||
|
}
|
||||||
|
for (auto op : toOperand.getUsers()) {
|
||||||
|
if (op->isBeforeInBlock(copyOp) && firstUser->isBeforeInBlock(op)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(DFKI): Use live variable analysis to solve aliasing issues among
|
||||||
|
// block arguments.
|
||||||
|
|
||||||
|
// Remove the associated alloc operation.
|
||||||
|
auto allocOp = fromOperand.getDefiningOp();
|
||||||
|
eraseList.push_back(allocOp);
|
||||||
|
|
||||||
|
// Iterate over all uses of the fromOperand to find the associated
|
||||||
|
// deallocOp (if any).
|
||||||
|
for (auto op : fromOperandUsers) {
|
||||||
|
if (isa<mlir::DeallocOp>(op)) {
|
||||||
|
eraseList.push_back(op);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace all uses of the fromOperand with the toOperand. This rewires
|
||||||
|
// all references pointing to the original alloc operation to the new
|
||||||
|
// target operation in order to safely remove the copy op.
|
||||||
|
fromOperand.replaceAllUsesWith(toOperand);
|
||||||
|
copyOp.erase();
|
||||||
|
});
|
||||||
|
for (auto op : eraseList) {
|
||||||
|
op->erase();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<Pass> createLhloCopyRemovalPass() {
|
||||||
|
return absl::make_unique<LhloCopyRemoval>();
|
||||||
|
}
|
||||||
|
|
||||||
|
static PassRegistration<LhloCopyRemoval> copy_removal_pass(
|
||||||
|
"lhlo-copy-removal", "Removes redundant LHLO copy operations");
|
||||||
|
|
||||||
|
} // namespace xla_lhlo
|
||||||
|
} // namespace mlir
|
|
@ -0,0 +1,151 @@
|
||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// This file implements logic for fusing linalg ops obtained after LHLO
|
||||||
|
// lowering.
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
|
||||||
|
#include "third_party/absl/memory/memory.h"
|
||||||
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h"
|
||||||
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/FoldUtils.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace xla_lhlo {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using linalg::LinalgOp;
|
||||||
|
|
||||||
|
class LhloFuseLinalg : public PassWrapper<LhloFuseLinalg, FunctionPass> {
|
||||||
|
public:
|
||||||
|
LhloFuseLinalg() = default;
|
||||||
|
LhloFuseLinalg(const LhloFuseLinalg&) {}
|
||||||
|
LhloFuseLinalg(bool use_parallel_loops, llvm::ArrayRef<unsigned> tile_sizes) {
|
||||||
|
tile_sizes_ = tile_sizes;
|
||||||
|
use_parallel_loops_.setValue(use_parallel_loops);
|
||||||
|
}
|
||||||
|
|
||||||
|
void runOnFunction() override {
|
||||||
|
auto func = getFunction();
|
||||||
|
|
||||||
|
// TODO(pifon): Remove assumption that the function has a single block.
|
||||||
|
if (!llvm::hasSingleElement(func)) {
|
||||||
|
emitError(func.getLoc(), "The function needs to have a single block.");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// The fusion in Linalg is currently possible only when the consumer op is
|
||||||
|
// tiled. In order to greedily fuse the ops, we have to start from the tiled
|
||||||
|
// root linalg ops, i.e. linalg ops that write to output buffers of the
|
||||||
|
// function or are returned in case of escaping allocations.
|
||||||
|
llvm::SmallDenseSet<Value> result_buffers;
|
||||||
|
for (auto func_arg : func.getArguments()) {
|
||||||
|
result_buffers.insert(func_arg);
|
||||||
|
}
|
||||||
|
for (auto& block : func) {
|
||||||
|
auto returnOp = mlir::dyn_cast<mlir::ReturnOp>(block.getTerminator());
|
||||||
|
if (!returnOp) continue;
|
||||||
|
for (auto operand : returnOp.getOperands()) {
|
||||||
|
result_buffers.insert(operand);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MLIRContext* ctx = func.getContext();
|
||||||
|
OpBuilder b(func);
|
||||||
|
OperationFolder folder(ctx);
|
||||||
|
func.walk([&](linalg::GenericOp generic_op) {
|
||||||
|
SmallVector<int64_t, 2> tile_sizes(tile_sizes_.begin(),
|
||||||
|
tile_sizes_.end());
|
||||||
|
if (tile_sizes.empty()) {
|
||||||
|
tile_sizes = SmallVector<int64_t, 2>(generic_op.getNumLoops(), 1);
|
||||||
|
}
|
||||||
|
auto op = cast<LinalgOp>(generic_op.getOperation());
|
||||||
|
for (const Value result : op.getOutputBuffers()) {
|
||||||
|
if (!result_buffers.count(result)) continue;
|
||||||
|
if (tileGenericOp(op, tile_sizes, &b)) {
|
||||||
|
generic_op.erase();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
auto patterns = linalg::getLinalgTilingCanonicalizationPatterns(ctx);
|
||||||
|
applyPatternsAndFoldGreedily(func, patterns);
|
||||||
|
|
||||||
|
// Fuse producers of tiled linalg ops.
|
||||||
|
llvm::SmallDenseSet<Operation*> erase_set;
|
||||||
|
SmallVector<Operation*, 8> linalg_ops;
|
||||||
|
func.walk([&](LinalgOp op) { linalg_ops.push_back(op); });
|
||||||
|
for (auto* op : llvm::reverse(linalg_ops)) {
|
||||||
|
for (unsigned id = 0, e = LinalgOp(op).getNumInputs(); id < e; ++id) {
|
||||||
|
linalg::Aliases aliases;
|
||||||
|
linalg::LinalgDependenceGraph graph(aliases, linalg_ops);
|
||||||
|
if (auto info = fuseProducerOf(b, op, id, graph, &folder)) {
|
||||||
|
auto originalOp = info->originalProducer.getOperation();
|
||||||
|
erase_set.insert(originalOp);
|
||||||
|
auto originalOpInLinalgOpsVector = std::find_if(
|
||||||
|
linalg_ops.begin(), linalg_ops.end(),
|
||||||
|
[&](const Operation* op) { return op == originalOp; });
|
||||||
|
*originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto patterns = linalg::getLinalgTilingCanonicalizationPatterns(ctx);
|
||||||
|
applyPatternsAndFoldGreedily(func, patterns);
|
||||||
|
}
|
||||||
|
for (auto* e : erase_set) e->erase();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool tileGenericOp(LinalgOp op, ArrayRef<int64_t> tile_sizes, OpBuilder* b) {
|
||||||
|
auto loopType = use_parallel_loops_
|
||||||
|
? linalg::LinalgTilingLoopType::ParallelLoops
|
||||||
|
: linalg::LinalgTilingLoopType::Loops;
|
||||||
|
auto tiled_generic_op = linalg::tileLinalgOp(*b, op,
|
||||||
|
linalg::LinalgTilingOptions()
|
||||||
|
.setTileSizes(tile_sizes)
|
||||||
|
.setLoopType(loopType));
|
||||||
|
return tiled_generic_op.hasValue();
|
||||||
|
}
|
||||||
|
|
||||||
|
Option<bool> use_parallel_loops_{
|
||||||
|
*this, "use-parallel-loops",
|
||||||
|
llvm::cl::desc(
|
||||||
|
"Tiles GenericOp consumer to parallel loops before linalg fusion"),
|
||||||
|
llvm::cl::init(false)};
|
||||||
|
|
||||||
|
ListOption<unsigned> tile_sizes_{
|
||||||
|
*this, "tile-sizes",
|
||||||
|
llvm::cl::desc(
|
||||||
|
"Tile sizes by which to tile linalg generic before linalg fusion"),
|
||||||
|
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> createLhloFuseLinalg(
|
||||||
|
bool use_parallel_loops, ArrayRef<unsigned> tile_sizes) {
|
||||||
|
return absl::make_unique<LhloFuseLinalg>(use_parallel_loops, tile_sizes);
|
||||||
|
}
|
||||||
|
|
||||||
|
static PassRegistration<LhloFuseLinalg> legalize_pass(
|
||||||
|
"lhlo-fuse-linalg",
|
||||||
|
"Greedily fuse linalg ops obtained after LHLO lowering.");
|
||||||
|
|
||||||
|
} // namespace xla_lhlo
|
||||||
|
} // namespace mlir
|
|
@ -0,0 +1,161 @@
|
||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// This file implements logic for lowering LHLO dialect to Affine dialect.
|
||||||
|
|
||||||
|
#include "third_party/absl/memory/memory.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace xla_lhlo {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Builds an affine loop nest iterating from zeros to "upper_bounds" with unit
|
||||||
|
// steps, and populates the body of the innermost loop using "body_builder".
|
||||||
|
static void BuildBoundedAffineLoopNest(
|
||||||
|
OpBuilder& builder, Location location, ArrayRef<int64_t> upper_bounds,
|
||||||
|
function_ref<void(OpBuilder&, Location, ValueRange)> body_builder) {
|
||||||
|
SmallVector<int64_t, 3> lower_bounds(upper_bounds.size(), /*Value=*/0);
|
||||||
|
SmallVector<int64_t, 3> steps(upper_bounds.size(), /*Value=*/1);
|
||||||
|
buildAffineLoopNest(builder, location, lower_bounds, upper_bounds, steps,
|
||||||
|
body_builder);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct DotOpConverter : public OpRewritePattern<DotOp> {
|
||||||
|
using OpRewritePattern<DotOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
// Supports only rank-2 tensors for LHS and RHS.
|
||||||
|
LogicalResult matchAndRewrite(DotOp op,
|
||||||
|
PatternRewriter& rewriter) const override {
|
||||||
|
Value lhs = op.lhs();
|
||||||
|
Value rhs = op.rhs();
|
||||||
|
MemRefType lhs_type = lhs.getType().cast<MemRefType>();
|
||||||
|
MemRefType rhs_type = rhs.getType().cast<MemRefType>();
|
||||||
|
Type element_type = lhs_type.getElementType();
|
||||||
|
ArrayRef<int64_t> shape_lhs = lhs_type.getShape();
|
||||||
|
ArrayRef<int64_t> shape_rhs = rhs_type.getShape();
|
||||||
|
|
||||||
|
if ((lhs_type.getRank() != 2) || (rhs_type.getRank() != 2)) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult map_status = success();
|
||||||
|
auto body_builder = [&](OpBuilder& builder, Location loc, ValueRange ivs) {
|
||||||
|
SmallVector<Value, 2> lhs_indices{ivs[0], ivs[2]},
|
||||||
|
rhs_indices{ivs[2], ivs[1]}, result_indices{ivs[0], ivs[1]};
|
||||||
|
|
||||||
|
auto l = builder.create<AffineLoadOp>(loc, lhs, lhs_indices);
|
||||||
|
auto r = builder.create<AffineLoadOp>(loc, rhs, rhs_indices);
|
||||||
|
auto result =
|
||||||
|
rewriter.create<AffineLoadOp>(loc, op.output(), result_indices);
|
||||||
|
Value op_result = xla_lhlo::XlaOpToStdScalarOp::map<DotOp>(
|
||||||
|
op, element_type, {l, r, result}, &builder);
|
||||||
|
map_status = success(op_result != nullptr);
|
||||||
|
if (failed(map_status)) return;
|
||||||
|
builder.create<AffineStoreOp>(loc, op_result, op.output(),
|
||||||
|
result_indices);
|
||||||
|
};
|
||||||
|
|
||||||
|
BuildBoundedAffineLoopNest(rewriter, op.getLoc(),
|
||||||
|
{shape_lhs[0], shape_rhs[1], shape_rhs[0]},
|
||||||
|
body_builder);
|
||||||
|
if (failed(map_status)) return failure();
|
||||||
|
|
||||||
|
rewriter.eraseOp(op);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename LhloOpTy>
|
||||||
|
struct BinaryOpConverter : public OpRewritePattern<LhloOpTy> {
|
||||||
|
using OpRewritePattern<LhloOpTy>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(LhloOpTy op,
|
||||||
|
PatternRewriter& rewriter) const override {
|
||||||
|
const auto& lhs = op.lhs();
|
||||||
|
const auto& rhs = op.rhs();
|
||||||
|
const auto& lhs_type = lhs.getType().template cast<MemRefType>();
|
||||||
|
const auto& rhs_type = rhs.getType().template cast<MemRefType>();
|
||||||
|
const auto& element_type = lhs_type.getElementType();
|
||||||
|
|
||||||
|
if (lhs_type.getShape() != rhs_type.getShape()) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult map_status = success();
|
||||||
|
auto body_builder = [&](OpBuilder& builder, Location loc,
|
||||||
|
ValueRange induction_vars) {
|
||||||
|
auto l = builder.create<AffineLoadOp>(loc, lhs, induction_vars);
|
||||||
|
auto r = builder.create<AffineLoadOp>(loc, rhs, induction_vars);
|
||||||
|
Value op_result = xla_lhlo::XlaOpToStdScalarOp::map<LhloOpTy>(
|
||||||
|
op, element_type, {l, r}, &builder);
|
||||||
|
map_status = success(op_result != nullptr);
|
||||||
|
if (failed(map_status)) return;
|
||||||
|
rewriter.create<AffineStoreOp>(loc, op_result, op.out(), induction_vars);
|
||||||
|
};
|
||||||
|
|
||||||
|
BuildBoundedAffineLoopNest(rewriter, op.getLoc(), lhs_type.getShape(),
|
||||||
|
body_builder);
|
||||||
|
if (failed(map_status)) return failure();
|
||||||
|
rewriter.eraseOp(op);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void populateLHLOToAffineConversionPattern(MLIRContext* context,
|
||||||
|
OwningRewritePatternList* patterns) {
|
||||||
|
// clang-format off
|
||||||
|
patterns->insert<
|
||||||
|
BinaryOpConverter<xla_lhlo::AddOp>,
|
||||||
|
BinaryOpConverter<xla_lhlo::AndOp>,
|
||||||
|
BinaryOpConverter<xla_lhlo::DivOp>,
|
||||||
|
BinaryOpConverter<xla_lhlo::MaxOp>,
|
||||||
|
BinaryOpConverter<xla_lhlo::MinOp>,
|
||||||
|
BinaryOpConverter<xla_lhlo::MulOp>,
|
||||||
|
BinaryOpConverter<xla_lhlo::SubOp>,
|
||||||
|
DotOpConverter>(context);
|
||||||
|
// clang-format on
|
||||||
|
}
|
||||||
|
|
||||||
|
struct LhloLegalizeToAffine
|
||||||
|
: public PassWrapper<LhloLegalizeToAffine, FunctionPass> {
|
||||||
|
void runOnFunction() override {
|
||||||
|
OwningRewritePatternList patterns;
|
||||||
|
auto func = getFunction();
|
||||||
|
populateLHLOToAffineConversionPattern(func.getContext(), &patterns);
|
||||||
|
applyPatternsAndFoldGreedily(func, patterns);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToAffinePass() {
|
||||||
|
return absl::make_unique<LhloLegalizeToAffine>();
|
||||||
|
}
|
||||||
|
|
||||||
|
static PassRegistration<LhloLegalizeToAffine> legalize_pass(
|
||||||
|
"lhlo-legalize-to-affine", "Legalize from LHLO dialect to affine dialect");
|
||||||
|
|
||||||
|
} // namespace xla_lhlo
|
||||||
|
} // namespace mlir
|
|
@ -0,0 +1,196 @@
|
||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// This file implements logic for lowering LHLO dialect to GPU dialect.
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
|
||||||
|
#include "third_party/absl/memory/memory.h"
|
||||||
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/GPU/GPUDialect.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SCF/SCF.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BlockAndValueMapping.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace xla_lhlo {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// A simple translation of LHLO reduce operations to a corresponding gpu
|
||||||
|
// launch operation. The transformation does no tiling and also only supports
|
||||||
|
// 1d results.
|
||||||
|
class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
ReduceOp reduce_op, ArrayRef<Value> args,
|
||||||
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
|
auto loc = reduce_op.getLoc();
|
||||||
|
// Only support 1d reductions for now.
|
||||||
|
int64_t size = 0;
|
||||||
|
for (auto result : reduce_op.out()) {
|
||||||
|
auto shaped_type = result.getType().dyn_cast<ShapedType>();
|
||||||
|
if (!shaped_type || shaped_type.getRank() != 1) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
auto dim_size = shaped_type.getDimSize(0);
|
||||||
|
if (size && size != dim_size) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
size = dim_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto reducing_dimension = *reduce_op.dimensions().int_value_begin();
|
||||||
|
|
||||||
|
// Require all inputs to have the same shape.
|
||||||
|
int64_t reduce_dim_size = 0;
|
||||||
|
for (auto input : reduce_op.operands()) {
|
||||||
|
auto shaped_type = input.getType().dyn_cast<ShapedType>();
|
||||||
|
if (!shaped_type || !shaped_type.hasStaticShape()) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
reduce_dim_size =
|
||||||
|
shaped_type.getDimSize(reducing_dimension.getSExtValue());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a launch that is parallel in the result dimension.
|
||||||
|
auto block_size_x = rewriter.create<mlir::ConstantOp>(
|
||||||
|
loc, rewriter.getIndexType(),
|
||||||
|
rewriter.getIntegerAttr(rewriter.getIndexType(), size));
|
||||||
|
auto one = rewriter.create<mlir::ConstantOp>(
|
||||||
|
loc, rewriter.getIndexType(),
|
||||||
|
rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
|
||||||
|
auto launch_op = rewriter.create<mlir::gpu::LaunchOp>(
|
||||||
|
loc, one, one, one, block_size_x, one, one);
|
||||||
|
{
|
||||||
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
|
rewriter.setInsertionPointToEnd(&launch_op.body().front());
|
||||||
|
auto index = launch_op.getThreadIds().x;
|
||||||
|
|
||||||
|
// Load the initial value and store it to the output.
|
||||||
|
for (auto pair : llvm::zip(reduce_op.init_values(), reduce_op.out())) {
|
||||||
|
auto init_value = rewriter.create<mlir::LoadOp>(loc, std::get<0>(pair));
|
||||||
|
rewriter.create<mlir::StoreOp>(loc, init_value, std::get<1>(pair),
|
||||||
|
ArrayRef<Value>{index});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert a loop into the body to compute the reduction. The loop ranges
|
||||||
|
// from [0.dim).
|
||||||
|
auto zero = rewriter.create<mlir::ConstantOp>(
|
||||||
|
loc, rewriter.getIndexType(),
|
||||||
|
rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
|
||||||
|
// TODO(b/137624192) Use dimOp to make it shape independent.
|
||||||
|
auto upper = rewriter.create<mlir::ConstantOp>(
|
||||||
|
loc, rewriter.getIndexType(),
|
||||||
|
rewriter.getIntegerAttr(rewriter.getIndexType(), reduce_dim_size));
|
||||||
|
auto step = rewriter.create<mlir::ConstantOp>(
|
||||||
|
loc, rewriter.getIndexType(),
|
||||||
|
rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
|
||||||
|
auto loop = rewriter.create<mlir::scf::ForOp>(loc, zero, upper, step);
|
||||||
|
|
||||||
|
rewriter.setInsertionPointToStart(loop.getBody());
|
||||||
|
// Compute memrefs for the value to reduce. This makes it easier to just
|
||||||
|
// inline the body.
|
||||||
|
auto output = *reduce_op.out().begin();
|
||||||
|
// TODO(herhut) Move this to the SliceOp builder.
|
||||||
|
auto resType = MemRefType::get(
|
||||||
|
llvm::None, output.getType().cast<MemRefType>().getElementType(),
|
||||||
|
makeStridedLinearLayoutMap(llvm::None,
|
||||||
|
MemRefType::getDynamicStrideOrOffset(),
|
||||||
|
rewriter.getContext()));
|
||||||
|
auto accumulator = rewriter.create<mlir::linalg::SliceOp>(
|
||||||
|
loc, resType, output, ArrayRef<Value>{launch_op.getThreadIds().x});
|
||||||
|
llvm::SmallVector<Value, 4> indexings;
|
||||||
|
auto input_buffer = *reduce_op.operands().begin();
|
||||||
|
auto input_type = input_buffer.getType().cast<MemRefType>();
|
||||||
|
for (int64_t dim = 0; dim < input_type.getRank(); ++dim) {
|
||||||
|
indexings.push_back(dim == reducing_dimension
|
||||||
|
? loop.getInductionVar()
|
||||||
|
: launch_op.getThreadIds().x);
|
||||||
|
}
|
||||||
|
// TODO(herhut) Move this to the SliceOp builder.
|
||||||
|
auto input = *reduce_op.operand_begin();
|
||||||
|
auto rhs = rewriter.create<mlir::linalg::SliceOp>(
|
||||||
|
loc,
|
||||||
|
MemRefType::get(
|
||||||
|
llvm::None, input_type.getElementType(),
|
||||||
|
makeStridedLinearLayoutMap(llvm::None,
|
||||||
|
MemRefType::getDynamicStrideOrOffset(),
|
||||||
|
rewriter.getContext())),
|
||||||
|
input, indexings);
|
||||||
|
|
||||||
|
// Now copy over the actual body of the reduction, leaving out the
|
||||||
|
// terminator.
|
||||||
|
BlockAndValueMapping mapping;
|
||||||
|
mapping.map(reduce_op.body().front().getArgument(0), accumulator);
|
||||||
|
mapping.map(reduce_op.body().front().getArgument(1), rhs);
|
||||||
|
mapping.map(reduce_op.body().front().getArgument(2), accumulator);
|
||||||
|
for (auto& nested : reduce_op.body().front().without_terminator()) {
|
||||||
|
auto clone = rewriter.clone(nested, mapping);
|
||||||
|
for (auto pair : llvm::zip(nested.getResults(), clone->getResults())) {
|
||||||
|
mapping.map(std::get<0>(pair), std::get<1>(pair));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finally, insert the terminator for the launchOp.
|
||||||
|
rewriter.setInsertionPointToEnd(&launch_op.body().front());
|
||||||
|
rewriter.create<mlir::gpu::TerminatorOp>(loc);
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.eraseOp(reduce_op);
|
||||||
|
return success();
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
struct LhloLegalizeToGpu : public PassWrapper<LhloLegalizeToGpu, FunctionPass> {
|
||||||
|
void runOnFunction() override {
|
||||||
|
OwningRewritePatternList patterns;
|
||||||
|
ConversionTarget target(getContext());
|
||||||
|
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
|
||||||
|
gpu::GPUDialect, scf::SCFDialect, XlaLhloDialect>();
|
||||||
|
target.addIllegalOp<ReduceOp>();
|
||||||
|
auto func = getFunction();
|
||||||
|
patterns.insert<LhloReduceToGPULaunchConverter>(func.getContext());
|
||||||
|
if (failed(applyPartialConversion(func, target, patterns))) {
|
||||||
|
signalPassFailure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToGpuPass() {
|
||||||
|
return absl::make_unique<LhloLegalizeToGpu>();
|
||||||
|
}
|
||||||
|
|
||||||
|
static PassRegistration<LhloLegalizeToGpu> legalize_pass(
|
||||||
|
"lhlo-legalize-to-gpu", "Legalize from LHLO dialect to GPU dialect");
|
||||||
|
|
||||||
|
} // namespace xla_lhlo
|
||||||
|
} // namespace mlir
|
|
@ -0,0 +1,136 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace xla_lhlo {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct StaticMemRefCastOpConverter
|
||||||
|
: public ConvertOpToLLVMPattern<StaticMemRefCastOp> {
|
||||||
|
using ConvertOpToLLVMPattern<StaticMemRefCastOp>::ConvertOpToLLVMPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
Operation *op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
auto loc = op->getLoc();
|
||||||
|
auto cast_op = cast<StaticMemRefCastOp>(op);
|
||||||
|
|
||||||
|
StaticMemRefCastOp::Adaptor operands_adaptor(operands);
|
||||||
|
MemRefDescriptor sourceMemRef(operands_adaptor.operand());
|
||||||
|
|
||||||
|
MemRefType targetMemRefType =
|
||||||
|
cast_op.getResult().getType().cast<MemRefType>();
|
||||||
|
auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
|
||||||
|
.dyn_cast_or_null<LLVM::LLVMType>();
|
||||||
|
if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
|
||||||
|
return failure();
|
||||||
|
// Create descriptor.
|
||||||
|
auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
|
||||||
|
Type llvmTargetElementTy = desc.getElementType();
|
||||||
|
// Set allocated ptr.
|
||||||
|
Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
|
||||||
|
allocated =
|
||||||
|
rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
|
||||||
|
desc.setAllocatedPtr(rewriter, loc, allocated);
|
||||||
|
// Set aligned ptr.
|
||||||
|
Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
|
||||||
|
ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
|
||||||
|
desc.setAlignedPtr(rewriter, loc, ptr);
|
||||||
|
|
||||||
|
// Fill size and stride descriptors in memref.
|
||||||
|
auto target_sizes = targetMemRefType.getShape();
|
||||||
|
int64_t target_offset;
|
||||||
|
llvm::SmallVector<int64_t, 4> target_strides;
|
||||||
|
if (failed((getStridesAndOffset(targetMemRefType, target_strides,
|
||||||
|
target_offset))))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Copy offset of `targetMemRef`.
|
||||||
|
desc.setConstantOffset(rewriter, loc, target_offset);
|
||||||
|
for (int i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
|
||||||
|
desc.setConstantSize(rewriter, loc, i, target_sizes[i]);
|
||||||
|
desc.setConstantStride(rewriter, loc, i, target_strides[i]);
|
||||||
|
}
|
||||||
|
rewriter.replaceOp(op, {desc});
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct DynamicMemRefCastOpConverter
|
||||||
|
: public ConvertOpToLLVMPattern<DynamicMemRefCastOp> {
|
||||||
|
using ConvertOpToLLVMPattern<DynamicMemRefCastOp>::ConvertOpToLLVMPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
Operation *op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
auto loc = op->getLoc();
|
||||||
|
auto cast_op = cast<DynamicMemRefCastOp>(op);
|
||||||
|
|
||||||
|
DynamicMemRefCastOp::Adaptor operands_adaptor(operands);
|
||||||
|
MemRefDescriptor sourceMemRef(operands_adaptor.operand());
|
||||||
|
|
||||||
|
MemRefType targetMemRefType =
|
||||||
|
cast_op.getResult().getType().cast<MemRefType>();
|
||||||
|
auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
|
||||||
|
.dyn_cast_or_null<LLVM::LLVMType>();
|
||||||
|
if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
|
||||||
|
return failure();
|
||||||
|
// Create descriptor.
|
||||||
|
auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
|
||||||
|
Type llvmTargetElementTy = desc.getElementType();
|
||||||
|
// Set allocated ptr.
|
||||||
|
Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
|
||||||
|
allocated =
|
||||||
|
rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
|
||||||
|
desc.setAllocatedPtr(rewriter, loc, allocated);
|
||||||
|
// Set aligned ptr.
|
||||||
|
Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
|
||||||
|
ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
|
||||||
|
desc.setAlignedPtr(rewriter, loc, ptr);
|
||||||
|
// Copy offset of `sourceMemRef`.
|
||||||
|
desc.setOffset(rewriter, loc, sourceMemRef.offset(rewriter, loc));
|
||||||
|
|
||||||
|
// Fill size and stride descriptors in memref.
|
||||||
|
if (!cast_op.sizes().empty()) {
|
||||||
|
auto sizes = operands_adaptor.sizes();
|
||||||
|
auto strides = operands_adaptor.strides();
|
||||||
|
for (int i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
|
||||||
|
desc.setSize(rewriter, loc, i, sizes[i]);
|
||||||
|
desc.setStride(rewriter, loc, i, strides[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rewriter.replaceOp(op, {desc});
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void PopulateLhloToLLVMConversionPatterns(const LowerToLLVMOptions &options,
|
||||||
|
LLVMTypeConverter *converter,
|
||||||
|
OwningRewritePatternList *patterns) {
|
||||||
|
patterns->insert<DynamicMemRefCastOpConverter, StaticMemRefCastOpConverter>(
|
||||||
|
*converter, options);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace xla_lhlo
|
||||||
|
} // namespace mlir
|
|
@ -0,0 +1,59 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace xla_lhlo {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class TestLhloToLLVMPass
|
||||||
|
: public ::mlir::PassWrapper<TestLhloToLLVMPass,
|
||||||
|
::mlir::OperationPass<::mlir::ModuleOp>> {
|
||||||
|
public:
|
||||||
|
void runOnOperation() override {
|
||||||
|
ModuleOp m = getOperation();
|
||||||
|
|
||||||
|
OwningRewritePatternList patterns;
|
||||||
|
LLVMTypeConverter converter(m.getContext());
|
||||||
|
populateStdToLLVMConversionPatterns(converter, patterns);
|
||||||
|
PopulateLhloToLLVMConversionPatterns(
|
||||||
|
LowerToLLVMOptions::getDefaultOptions(), &converter, &patterns);
|
||||||
|
|
||||||
|
ConversionTarget target(getContext());
|
||||||
|
target.addLegalDialect<LLVM::LLVMDialect>();
|
||||||
|
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
||||||
|
target.addIllegalDialect<XlaLhloDialect>();
|
||||||
|
|
||||||
|
if (failed(applyFullConversion(m, target, patterns))) {
|
||||||
|
signalPassFailure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
static PassRegistration<TestLhloToLLVMPass> legalize_lhlo_pass(
|
||||||
|
"test-lhlo-legalize-to-llvm", "Legalize from LHLO dialect to LLVM.");
|
||||||
|
|
||||||
|
} // namespace xla_lhlo
|
||||||
|
} // namespace mlir
|
|
@ -0,0 +1,731 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "third_party/absl/memory/memory.h"
|
||||||
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h"
|
||||||
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
|
||||||
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/SmallVector.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SCF/SCF.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace xla_lhlo {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Clones and adapts the code in `lhlo_block` that works on buffers and has a
|
||||||
|
// single output buffer to make it compatible with `operands` that have element
|
||||||
|
// types of the respective buffers. Returns the computed value.
|
||||||
|
//
|
||||||
|
// Example. For `operands` with (f32, i32) types and a block with LHLO ops and
|
||||||
|
// with signature:
|
||||||
|
// ^bb(%lhs: memref<f32>, %rhs: memref<i32>, %res: memref<i1>):
|
||||||
|
// <LHLO_ops>
|
||||||
|
//
|
||||||
|
// inserts necessary alloc and store ops to compute and return result that has
|
||||||
|
// `i1` type.
|
||||||
|
Value ApplySingleResultLhloCode(Location loc, ValueRange operands,
|
||||||
|
Block* lhlo_block, OpBuilder* b) {
|
||||||
|
SmallVector<Value, 2> arg_bufs;
|
||||||
|
for (auto arg_type : lhlo_block->getArgumentTypes()) {
|
||||||
|
arg_bufs.push_back(b->create<AllocOp>(loc, arg_type.cast<MemRefType>()));
|
||||||
|
}
|
||||||
|
for (auto operand : llvm::enumerate(operands)) {
|
||||||
|
b->create<StoreOp>(loc, operand.value(), arg_bufs[operand.index()]);
|
||||||
|
}
|
||||||
|
// Clone the ops from `lhlo_block`.
|
||||||
|
BlockAndValueMapping mapping;
|
||||||
|
mapping.map(lhlo_block->getArguments(), arg_bufs);
|
||||||
|
for (auto& nested : lhlo_block->without_terminator()) {
|
||||||
|
auto clone = b->clone(nested, mapping);
|
||||||
|
mapping.map(nested.getResults(), clone->getResults());
|
||||||
|
}
|
||||||
|
return b->create<LoadOp>(loc, arg_bufs.back());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Converts a block with LHLO ops and with signature:
|
||||||
|
// ^bb(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
|
||||||
|
// into a reduction operator of scf.reduce by doing buffer allocation for
|
||||||
|
// scalar arguments and the result of `scf.reduce` to make it compatible with
|
||||||
|
// LHLO ops.
|
||||||
|
void ConvertToReductionOperator(Location loc, scf::ReduceOp reduce_op,
|
||||||
|
Block* lhlo_block, OpBuilder* b) {
|
||||||
|
Block& loop_reduce_op_body = reduce_op.reductionOperator().front();
|
||||||
|
OpBuilder::InsertionGuard guard(*b);
|
||||||
|
b->setInsertionPointToStart(&loop_reduce_op_body);
|
||||||
|
b->create<scf::ReduceReturnOp>(
|
||||||
|
loc, ApplySingleResultLhloCode(loc, loop_reduce_op_body.getArguments(),
|
||||||
|
lhlo_block, b));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns result of ConstantOp if `dim` is static, otherwise uses DimOp to
|
||||||
|
// extract dimension at runtime.
|
||||||
|
Value GetStaticOrDynamicDim(mlir::Location loc, Value shaped_value,
|
||||||
|
size_t dim_index, int64_t dim, OpBuilder* b) {
|
||||||
|
return dim == ShapedType::kDynamicSize
|
||||||
|
? b->create<DimOp>(loc, shaped_value, dim_index).getResult()
|
||||||
|
: b->create<ConstantIndexOp>(loc, dim);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct MappedIvs {
|
||||||
|
// False if the mapped indices are in the padding area, true otherwise.
|
||||||
|
Value in_bounds;
|
||||||
|
// Mapped indices.
|
||||||
|
SmallVector<Value, 2> ivs;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename OpTy>
|
||||||
|
MappedIvs MapWindowIvsToInput(OpTy op, ValueRange ivs, ValueRange window_ivs,
|
||||||
|
OpBuilder* b) {
|
||||||
|
MappedIvs mapped_ivs;
|
||||||
|
|
||||||
|
if (!op.window_strides().hasValue()) {
|
||||||
|
op.emitOpError("No window strides specified.");
|
||||||
|
}
|
||||||
|
auto window_strides = op.window_strides().getValue();
|
||||||
|
|
||||||
|
if (!op.padding().hasValue()) {
|
||||||
|
op.emitOpError("No padding specified.");
|
||||||
|
}
|
||||||
|
auto padding = op.padding().getValue();
|
||||||
|
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
auto operand = op.operand();
|
||||||
|
auto operand_shape = operand.getType().template cast<MemRefType>().getShape();
|
||||||
|
|
||||||
|
// `in_bounds` is false when the mapped indices are in the padding area.
|
||||||
|
mapped_ivs.in_bounds = b->create<mlir::ConstantOp>(
|
||||||
|
loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 1));
|
||||||
|
for (unsigned i = 0, e = ivs.size(); i < e; ++i) {
|
||||||
|
auto stride = window_strides.template getValue<llvm::APInt>(i);
|
||||||
|
auto pad_low = padding.template getValue<llvm::APInt>({i, 0});
|
||||||
|
|
||||||
|
Value stride_val = b->create<ConstantIndexOp>(loc, stride.getSExtValue());
|
||||||
|
Value pad_low_val = b->create<ConstantIndexOp>(loc, pad_low.getSExtValue());
|
||||||
|
|
||||||
|
Value center = b->create<MulIOp>(loc, ivs[i], stride_val);
|
||||||
|
Value offset = b->create<SubIOp>(loc, window_ivs[i], pad_low_val);
|
||||||
|
Value index = b->create<AddIOp>(loc, center, offset);
|
||||||
|
Value upper_bound =
|
||||||
|
GetStaticOrDynamicDim(loc, operand, i, operand_shape[i], b);
|
||||||
|
// We must check whether 0 <= index_i < shape_i, as otherwise we are in
|
||||||
|
// the pad and then we have to use the neutral element for reduction.
|
||||||
|
// Equivalently, it can be computed as the unsigned comparison index_i <
|
||||||
|
// shape_i, since a negative value wraps to a large positive value.
|
||||||
|
mapped_ivs.in_bounds = b->create<mlir::AndOp>(
|
||||||
|
loc, mapped_ivs.in_bounds,
|
||||||
|
b->create<CmpIOp>(loc, CmpIPredicate::ult, index, upper_bound));
|
||||||
|
mapped_ivs.ivs.push_back(index);
|
||||||
|
}
|
||||||
|
return mapped_ivs;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns scf::Parallel over a shaped value with static or dynamic shape.
|
||||||
|
scf::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value,
|
||||||
|
OpBuilder* b) {
|
||||||
|
Value zero = b->create<ConstantIndexOp>(loc, 0);
|
||||||
|
Value one = b->create<ConstantIndexOp>(loc, 1);
|
||||||
|
|
||||||
|
ArrayRef<int64_t> shape =
|
||||||
|
shaped_value.getType().cast<ShapedType>().getShape();
|
||||||
|
SmallVector<Value, 2> lower, upper, step;
|
||||||
|
for (auto dim : llvm::enumerate(shape)) {
|
||||||
|
upper.push_back(
|
||||||
|
GetStaticOrDynamicDim(loc, shaped_value, dim.index(), dim.value(), b));
|
||||||
|
lower.push_back(zero);
|
||||||
|
step.push_back(one);
|
||||||
|
}
|
||||||
|
return b->create<scf::ParallelOp>(loc, lower, upper, step);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Converts `xla_lhlo.ReduceOp` into two scf::ParallelOp and a scf::ReduceOp.
|
||||||
|
// The outper `ParallelOp` refers to the parallel loops if there are
|
||||||
|
// any. The inner `ParalleOp` refers to the reduction loops and `ReduceOp`
|
||||||
|
// contains the reduction operator.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// "xla_lhlo.reduce"(%buffer, %init_buf, %result) ( {
|
||||||
|
// ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
|
||||||
|
// <LHLO ops>
|
||||||
|
// } ) {dimensions = dense<[1]> : tensor<1xi64>}
|
||||||
|
// : (memref<100x10x5xf32>, memref<f32>, memref<100x5xf32>) -> ()
|
||||||
|
//
|
||||||
|
// is roughly converted into:
|
||||||
|
//
|
||||||
|
// %init = load %init_buf[] : memref<f32>
|
||||||
|
// scf.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) {
|
||||||
|
// %result = scf.parallel (%j) = (%c0) to (%c10) step (%c1) init (%init) {
|
||||||
|
// %elem_to_reduce = load %buffer[%i, %j, %k] : memref<100x10x5xf32>
|
||||||
|
// scf.reduce(%elem_to_reduce) {
|
||||||
|
// ^bb0(%elem: f32, %acc: f32): // no predecessors
|
||||||
|
// elem_buf = alloc() : memref<f32>
|
||||||
|
// store %elem, elem_buf[] : memref<f32>
|
||||||
|
// acc_buf = alloc() : memref<f32>
|
||||||
|
// store %acc, acc_buf[] : memref<f32>
|
||||||
|
// <LHLO_ops>
|
||||||
|
// %acc_result = load acc_buf[] : memref<f32>
|
||||||
|
// scf.reduce.return %acc_result : f32
|
||||||
|
// } : f32
|
||||||
|
// scf.yield
|
||||||
|
// } : f32
|
||||||
|
// scf.yield
|
||||||
|
// }
|
||||||
|
class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern<xla_lhlo::ReduceOp>::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
xla_lhlo::ReduceOp xla_reduce_op, ArrayRef<Value> /*args*/,
|
||||||
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
|
// TODO(b/137624192) Implement variadic reduce.
|
||||||
|
if (xla_reduce_op.out().size() != 1) return failure();
|
||||||
|
|
||||||
|
scf::ReduceOp reduce_op =
|
||||||
|
CreateReduceOpInNestedParallelLoops(xla_reduce_op, &rewriter);
|
||||||
|
ConvertToReductionOperator(xla_reduce_op.getLoc(), reduce_op,
|
||||||
|
&xla_reduce_op.body().front(), &rewriter);
|
||||||
|
rewriter.replaceOp(xla_reduce_op, llvm::None);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Creates nested `scf.parallel` ops with `scf.reduce`. The outer ParallelOp
|
||||||
|
// refers to the parallel dimensions of `xla_reduce_op` if any and the inner
|
||||||
|
// ParallelOp refers to the reduction dimensions. The scf.reduce op is
|
||||||
|
// returned.
|
||||||
|
//
|
||||||
|
// If the reduction argument is a memref<100x10x5xf32> and the
|
||||||
|
// reduction is performed along dimension 1 then this method will generate
|
||||||
|
//
|
||||||
|
// %init = load %init_buf[] : memref<f32>
|
||||||
|
// scf.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) {
|
||||||
|
// %result = scf.parallel (%j) = (%c0) to (%c10) step (%c1) init (%init) {
|
||||||
|
// %elem_to_reduce = load %buffer[%i, %j, %k] : memref<100x10x5xf32>
|
||||||
|
// scf.reduce(%elem_to_reduce) {
|
||||||
|
// <THE BLOCK PTR TO BE RETURNED>
|
||||||
|
// } : f32
|
||||||
|
// scf.yield
|
||||||
|
// } : f32
|
||||||
|
// scf.yield
|
||||||
|
// }
|
||||||
|
scf::ReduceOp CreateReduceOpInNestedParallelLoops(
|
||||||
|
xla_lhlo::ReduceOp xla_reduce_op,
|
||||||
|
ConversionPatternRewriter* rewriter) const {
|
||||||
|
auto loc = xla_reduce_op.getLoc();
|
||||||
|
DenseSet<int> reducing_dims;
|
||||||
|
for (const auto& rdim : xla_reduce_op.dimensions().getIntValues()) {
|
||||||
|
reducing_dims.insert(rdim.getSExtValue());
|
||||||
|
}
|
||||||
|
|
||||||
|
Value operand = *xla_reduce_op.operands().begin();
|
||||||
|
Value out = *xla_reduce_op.out().begin();
|
||||||
|
SmallVector<Value, 2> parallel_lower, parallel_upper, parallel_step;
|
||||||
|
SmallVector<Value, 2> reduce_lower, reduce_upper, reduce_step;
|
||||||
|
auto operand_shape = operand.getType().cast<MemRefType>().getShape();
|
||||||
|
for (auto dim : llvm::enumerate(operand_shape)) {
|
||||||
|
const bool is_reducing_dim = reducing_dims.count(dim.index());
|
||||||
|
|
||||||
|
Value ub = GetStaticOrDynamicDim(loc, operand, dim.index(), dim.value(),
|
||||||
|
rewriter);
|
||||||
|
Value lb = rewriter->create<ConstantIndexOp>(loc, 0);
|
||||||
|
Value step = rewriter->create<ConstantIndexOp>(loc, 1);
|
||||||
|
(is_reducing_dim ? reduce_lower : parallel_lower).push_back(lb);
|
||||||
|
(is_reducing_dim ? reduce_upper : parallel_upper).push_back(ub);
|
||||||
|
(is_reducing_dim ? reduce_step : parallel_step).push_back(step);
|
||||||
|
}
|
||||||
|
// Load initial value from memref<element_type>.
|
||||||
|
SmallVector<Value, 1> init_value = {
|
||||||
|
rewriter->create<LoadOp>(loc, *xla_reduce_op.init_values().begin())};
|
||||||
|
// Outer ParallelOp is not needed if it is a reduction across all dims.
|
||||||
|
scf::ParallelOp outer;
|
||||||
|
if (!parallel_lower.empty()) {
|
||||||
|
outer = rewriter->create<scf::ParallelOp>(loc, parallel_lower,
|
||||||
|
parallel_upper, parallel_step);
|
||||||
|
rewriter->setInsertionPointToStart(outer.getBody());
|
||||||
|
}
|
||||||
|
scf::ParallelOp inner = rewriter->create<scf::ParallelOp>(
|
||||||
|
loc, reduce_lower, reduce_upper, reduce_step, ValueRange(init_value));
|
||||||
|
Value reduction_result = *inner.getResults().begin();
|
||||||
|
|
||||||
|
SmallVector<Value, 1> out_indices;
|
||||||
|
if (outer != nullptr) {
|
||||||
|
out_indices.reserve(outer.getNumLoops());
|
||||||
|
for (Value iv : outer.getInductionVars()) {
|
||||||
|
out_indices.push_back(iv);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
out_indices.push_back(rewriter->create<ConstantIndexOp>(loc, 0));
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter->create<StoreOp>(loc, reduction_result, out, out_indices);
|
||||||
|
|
||||||
|
// Load the element to reduce.
|
||||||
|
SmallVector<Value, 2> indices;
|
||||||
|
indices.reserve(operand_shape.size());
|
||||||
|
|
||||||
|
if (outer) {
|
||||||
|
auto inner_ivs_it = inner.getInductionVars().begin();
|
||||||
|
auto outer_ivs_it = outer.getInductionVars().begin();
|
||||||
|
for (unsigned i = 0, e = operand_shape.size(); i < e; ++i) {
|
||||||
|
indices.push_back(reducing_dims.count(i) ? *inner_ivs_it++
|
||||||
|
: *outer_ivs_it++);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
indices = inner.getInductionVars();
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter->setInsertionPointToStart(inner.getBody());
|
||||||
|
Value elem = rewriter->create<mlir::LoadOp>(
|
||||||
|
loc, *xla_reduce_op.operands().begin(), indices);
|
||||||
|
return rewriter->create<scf::ReduceOp>(loc, elem);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Pseudocode:
|
||||||
|
// for each index O in output
|
||||||
|
// accumulator = neutral_value
|
||||||
|
// in_bounds = true
|
||||||
|
// for each index W in window
|
||||||
|
// for each dimension i from 0 to rank - 1
|
||||||
|
// index = O[i] * stride[i] + W[i] - pad_low[i]
|
||||||
|
// in_bounds = inbounds && (index `ult` shape[i])
|
||||||
|
// I[i] = index
|
||||||
|
// if (in_bounds)
|
||||||
|
// value = input[I]
|
||||||
|
// else
|
||||||
|
// value = neutral_value
|
||||||
|
// accumulator = reduction_operator(output[O], value)
|
||||||
|
// output[O] = accumulator
|
||||||
|
//
|
||||||
|
// Converts `xla_lhlo.ReduceWindowOp` into two scf::ParallelOp and a
|
||||||
|
// scf::ReduceOp.
|
||||||
|
// The outper `ParallelOp` refers to the parallel loops that traverese output
|
||||||
|
// buffer. The inner `ParalleOp` refers to the reduction loops that traverse
|
||||||
|
// reduction windows and `ReduceOp` contains the reduction operator.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// func @reduce_window(%arg: memref<112x112xf32>,
|
||||||
|
// %init: memref<f32>,
|
||||||
|
// %result: memref<56x56xf32>) {
|
||||||
|
// "xla_lhlo.reduce_window"(%arg, %init, %result) ( {
|
||||||
|
// ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
|
||||||
|
// "xla_lhlo.maximum"(%lhs, %rhs, %res)
|
||||||
|
// : (memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||||
|
// "xla_lhlo.terminator"() : () -> ()
|
||||||
|
// }) {
|
||||||
|
// padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
|
||||||
|
// window_dimensions = dense<[3, 3]> : tensor<2xi64>,
|
||||||
|
// window_strides = dense<[2, 2]> : tensor<2xi64>
|
||||||
|
// } : (memref<112x112xf32>, memref<f32>, memref<56x56xf32>) -> ()
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// is roughly converted into:
|
||||||
|
//
|
||||||
|
// %neutral_elem = load %init_buf[] : memref<f32>
|
||||||
|
// scf.parallel (%i, %j) = (%c0, %c0) to (%c56, %c56) step (%c1, %c1) {
|
||||||
|
// %result = scf.parallel (%iw, %jw) = (%c0, %c0)
|
||||||
|
// to (%c3, %c3) step (%c1, %c1) neutral_elem (%0) -> f32 {
|
||||||
|
// %in_bounds = <COMPUTE IF INDEX IS IN OPERAND'S pad>
|
||||||
|
// %elem = load %operand[%computed_i, %computed_j]
|
||||||
|
// %elem_or_neutral = select %in_bounds, %elem, %neutral_elem : f32
|
||||||
|
// scf.reduce(%elem_to_reduce) : f32 {
|
||||||
|
// ^bb0(%arg7: f32, %arg8: f32):
|
||||||
|
// <LHLO ops>
|
||||||
|
// }
|
||||||
|
// scf.yield
|
||||||
|
// }
|
||||||
|
// store %result, %output_buffer[%i, %j] : memref<56x56xf32>
|
||||||
|
// scf.yield
|
||||||
|
// }
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
class ReduceWindowOpConverter
|
||||||
|
: public OpConversionPattern<xla_lhlo::ReduceWindowOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern<xla_lhlo::ReduceWindowOp>::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
xla_lhlo::ReduceWindowOp xla_reduce_window_op, ArrayRef<Value> /*args*/,
|
||||||
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
|
scf::ParallelOp output_loop, window_loop;
|
||||||
|
std::tie(output_loop, window_loop) =
|
||||||
|
CreateParallelLoopsToTraverseOutputAndWindow(xla_reduce_window_op,
|
||||||
|
&rewriter);
|
||||||
|
|
||||||
|
scf::ReduceOp reduce_op = CreateReduceOpInNestedParallelLoops(
|
||||||
|
xla_reduce_window_op, output_loop, window_loop, &rewriter);
|
||||||
|
|
||||||
|
ConvertToReductionOperator(xla_reduce_window_op.getLoc(), reduce_op,
|
||||||
|
&xla_reduce_window_op.body().front(), &rewriter);
|
||||||
|
rewriter.replaceOp(xla_reduce_window_op, llvm::None);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::pair<scf::ParallelOp, scf::ParallelOp>
|
||||||
|
CreateParallelLoopsToTraverseOutputAndWindow(
|
||||||
|
xla_lhlo::ReduceWindowOp xla_reduce_window_op,
|
||||||
|
ConversionPatternRewriter* rewriter) const {
|
||||||
|
auto loc = xla_reduce_window_op.getLoc();
|
||||||
|
Value init_value =
|
||||||
|
rewriter->create<LoadOp>(loc, xla_reduce_window_op.init_value());
|
||||||
|
|
||||||
|
Value zero = rewriter->create<ConstantIndexOp>(loc, 0);
|
||||||
|
Value one = rewriter->create<ConstantIndexOp>(loc, 1);
|
||||||
|
|
||||||
|
// Create an outer parallel loop that spans the output of ReduceWindowOp.
|
||||||
|
Value xla_output = xla_reduce_window_op.out();
|
||||||
|
auto output_loop = MakeLoopOverShape(loc, xla_output, rewriter);
|
||||||
|
|
||||||
|
// Create a nested loop that traverses the window.
|
||||||
|
SmallVector<Value, 2> window_lower, window_upper, window_step;
|
||||||
|
rewriter->setInsertionPointToStart(output_loop.getBody());
|
||||||
|
for (const auto& window_dim : xla_reduce_window_op.window_dimensions()) {
|
||||||
|
window_step.push_back(one);
|
||||||
|
window_lower.push_back(zero);
|
||||||
|
window_upper.push_back(
|
||||||
|
rewriter->create<ConstantIndexOp>(loc, window_dim.getSExtValue()));
|
||||||
|
}
|
||||||
|
auto window_loop = rewriter->create<scf::ParallelOp>(
|
||||||
|
loc, window_lower, window_upper, window_step, ValueRange(init_value));
|
||||||
|
|
||||||
|
Value reduction_result = *window_loop.getResults().begin();
|
||||||
|
auto output_ivs = output_loop.getInductionVars();
|
||||||
|
rewriter->create<StoreOp>(loc, reduction_result, xla_output, output_ivs);
|
||||||
|
return std::make_pair(output_loop, window_loop);
|
||||||
|
}
|
||||||
|
|
||||||
|
scf::ReduceOp CreateReduceOpInNestedParallelLoops(
|
||||||
|
xla_lhlo::ReduceWindowOp xla_reduce_window_op,
|
||||||
|
scf::ParallelOp output_loop, scf::ParallelOp window_loop,
|
||||||
|
ConversionPatternRewriter* rewriter) const {
|
||||||
|
rewriter->setInsertionPointToStart(window_loop.getBody());
|
||||||
|
auto loc = xla_reduce_window_op.getLoc();
|
||||||
|
|
||||||
|
if (xla_reduce_window_op.base_dilations().hasValue() ||
|
||||||
|
xla_reduce_window_op.window_dilations().hasValue()) {
|
||||||
|
xla_reduce_window_op.emitRemark(
|
||||||
|
"Lowering to parallel loops does not support `base_dilations` or "
|
||||||
|
"`window_dilations` attributes yet. The attributes will be ignored.");
|
||||||
|
}
|
||||||
|
|
||||||
|
Value xla_operand = xla_reduce_window_op.operand();
|
||||||
|
auto xla_operand_type = xla_operand.getType().cast<MemRefType>();
|
||||||
|
|
||||||
|
// Compute ivs in 'arg' buffer and whether these ivs are in pad area or not.
|
||||||
|
MappedIvs mapped_ivs = MapWindowIvsToInput(
|
||||||
|
xla_reduce_window_op, output_loop.getInductionVars(),
|
||||||
|
window_loop.getInductionVars(), rewriter);
|
||||||
|
|
||||||
|
auto elem_or_init = rewriter->create<scf::IfOp>(
|
||||||
|
loc, xla_operand_type.getElementType(), mapped_ivs.in_bounds,
|
||||||
|
/*withElseRegion=*/true);
|
||||||
|
|
||||||
|
OpBuilder then_builder = elem_or_init.getThenBodyBuilder();
|
||||||
|
Value elem = then_builder.create<mlir::LoadOp>(
|
||||||
|
loc, xla_reduce_window_op.operand(), mapped_ivs.ivs);
|
||||||
|
then_builder.create<scf::YieldOp>(loc, elem);
|
||||||
|
|
||||||
|
OpBuilder else_builder = elem_or_init.getElseBodyBuilder();
|
||||||
|
else_builder.create<scf::YieldOp>(loc, *window_loop.initVals().begin());
|
||||||
|
|
||||||
|
return rewriter->create<scf::ReduceOp>(loc,
|
||||||
|
*elem_or_init.results().begin());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// See the operation semantics in
|
||||||
|
// https://www.tensorflow.org/xla/operation_semantics#selectandscatter
|
||||||
|
//
|
||||||
|
// Pseudocode:
|
||||||
|
// scf.parallel(coordinates O in the output):
|
||||||
|
// output[O] = init
|
||||||
|
// scf.parallel(coordinates S in the source):
|
||||||
|
// selected_ivs = 0
|
||||||
|
// selected_val = 0
|
||||||
|
// initialized_flag = false
|
||||||
|
// scf.for (first dim W_1 in the window)
|
||||||
|
// iter_args (selected_ivs, selected_val, initialized_flag):
|
||||||
|
// ...
|
||||||
|
// scf.for (last dim W_N in the window):
|
||||||
|
// iter_args (selected_ivs, selected_val, initialized_flag):
|
||||||
|
// I = S * stride + W - pad_low
|
||||||
|
// if I within bounds of operand:
|
||||||
|
// if (initialized_flag):
|
||||||
|
// pred = select(selected_value, operand(I))):
|
||||||
|
// if (pred)
|
||||||
|
// selected_value = operand(I)
|
||||||
|
// selected_index = I
|
||||||
|
// else
|
||||||
|
// selected_value = operand(I)
|
||||||
|
// selected_index = I
|
||||||
|
// initialized_flag = true
|
||||||
|
// output(selected_index) = scatter(output(selected_index), source(S))
|
||||||
|
class SelectAndScatterOpConverter
|
||||||
|
: public OpConversionPattern<xla_lhlo::SelectAndScatterOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern<xla_lhlo::SelectAndScatterOp>::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
xla_lhlo::SelectAndScatterOp s_and_s_op, ArrayRef<Value> /*args*/,
|
||||||
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
|
auto loc = s_and_s_op.getLoc();
|
||||||
|
InitializeOutput(s_and_s_op, &rewriter);
|
||||||
|
scf::ParallelOp loop_over_src =
|
||||||
|
MakeLoopOverShape(loc, s_and_s_op.source(), &rewriter);
|
||||||
|
rewriter.setInsertionPointToStart(loop_over_src.getBody());
|
||||||
|
|
||||||
|
// Compute indices of the selected element in the window.
|
||||||
|
auto selected_ivs = SelectIvs(s_and_s_op, loop_over_src, &rewriter);
|
||||||
|
|
||||||
|
// Load `source[selected_ivs]`.
|
||||||
|
auto src_elem = rewriter.create<LoadOp>(loc, s_and_s_op.source(),
|
||||||
|
loop_over_src.getInductionVars());
|
||||||
|
|
||||||
|
// Compute `out[selected_ivs]` = scatter(out[selected_ivs], src_element)`.
|
||||||
|
auto rmw = rewriter.create<GenericAtomicRMWOp>(loc, s_and_s_op.out(),
|
||||||
|
selected_ivs);
|
||||||
|
OpBuilder rmw_builder = OpBuilder::atBlockEnd(rmw.getBody());
|
||||||
|
auto acc_result =
|
||||||
|
ApplySingleResultLhloCode(loc, {src_elem, rmw.getCurrentValue()},
|
||||||
|
&s_and_s_op.scatter().front(), &rmw_builder);
|
||||||
|
rmw_builder.create<AtomicYieldOp>(loc, acc_result);
|
||||||
|
|
||||||
|
rewriter.replaceOp(s_and_s_op, llvm::None);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void InitializeOutput(xla_lhlo::SelectAndScatterOp s_and_s_op,
|
||||||
|
OpBuilder* b) const {
|
||||||
|
auto loc = s_and_s_op.getLoc();
|
||||||
|
Value init_value = b->create<LoadOp>(loc, s_and_s_op.init_value());
|
||||||
|
|
||||||
|
scf::ParallelOp loop_over_output =
|
||||||
|
MakeLoopOverShape(loc, s_and_s_op.out(), b);
|
||||||
|
OpBuilder::InsertionGuard guard(*b);
|
||||||
|
b->setInsertionPointToStart(loop_over_output.getBody());
|
||||||
|
b->create<StoreOp>(loc, init_value, s_and_s_op.out(),
|
||||||
|
loop_over_output.getInductionVars());
|
||||||
|
}
|
||||||
|
|
||||||
|
struct WindowLoops {
|
||||||
|
SmallVector<Value, 2> selected_ivs;
|
||||||
|
SmallVector<Value, 2> window_ivs;
|
||||||
|
scf::ForOp inner_loop;
|
||||||
|
};
|
||||||
|
WindowLoops InsertWindowLoops(xla_lhlo::SelectAndScatterOp s_and_s_op,
|
||||||
|
scf::ParallelOp loop_over_src,
|
||||||
|
OpBuilder* b) const {
|
||||||
|
auto loc = s_and_s_op.getLoc();
|
||||||
|
Value zero = b->create<ConstantIndexOp>(loc, 0);
|
||||||
|
Value one = b->create<ConstantIndexOp>(loc, 1);
|
||||||
|
|
||||||
|
auto element_type =
|
||||||
|
s_and_s_op.out().getType().cast<MemRefType>().getElementType();
|
||||||
|
auto rank = loop_over_src.getNumLoops();
|
||||||
|
|
||||||
|
// `iter_args` = [iv_1, ..., iv_N, selected_value, is_initialized]
|
||||||
|
SmallVector<Value, 4> iter_args(rank, zero);
|
||||||
|
iter_args.push_back(b->create<mlir::ConstantOp>(
|
||||||
|
loc, element_type, b->getFloatAttr(element_type, 0)));
|
||||||
|
iter_args.push_back(b->create<mlir::ConstantOp>(
|
||||||
|
loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 0)));
|
||||||
|
|
||||||
|
// Create a nested loop that traverses the window.
|
||||||
|
OpBuilder::InsertPoint ip;
|
||||||
|
WindowLoops result;
|
||||||
|
for (const auto& window_dim :
|
||||||
|
s_and_s_op.window_dimensions()->getIntValues()) {
|
||||||
|
Value upper = b->create<ConstantIndexOp>(loc, window_dim.getSExtValue());
|
||||||
|
result.inner_loop =
|
||||||
|
b->create<scf::ForOp>(loc, zero, upper, one, iter_args);
|
||||||
|
if (b->getInsertionBlock() == loop_over_src.getBody()) {
|
||||||
|
ip = b->saveInsertionPoint();
|
||||||
|
result.selected_ivs = result.inner_loop.getResults().take_front(rank);
|
||||||
|
} else {
|
||||||
|
b->create<scf::YieldOp>(loc, result.inner_loop.getResults());
|
||||||
|
}
|
||||||
|
b->setInsertionPointToStart(result.inner_loop.getBody());
|
||||||
|
iter_args = ValueRange{result.inner_loop.getRegionIterArgs()};
|
||||||
|
result.window_ivs.push_back(result.inner_loop.getInductionVar());
|
||||||
|
}
|
||||||
|
b->restoreInsertionPoint(ip);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adapter to store iteration arguments of sequential loops that perform
|
||||||
|
// select in a window.
|
||||||
|
class IterArgs {
|
||||||
|
public:
|
||||||
|
explicit IterArgs(ValueRange ivs_val_flag) : ivs_val_flag_(ivs_val_flag) {}
|
||||||
|
IterArgs(ValueRange ivs, Value value, Value flag) {
|
||||||
|
ivs_val_flag_ = ivs;
|
||||||
|
ivs_val_flag_.push_back(value);
|
||||||
|
ivs_val_flag_.push_back(flag);
|
||||||
|
}
|
||||||
|
|
||||||
|
ArrayRef<Value> to_vector() const { return ivs_val_flag_; }
|
||||||
|
|
||||||
|
// Indices of the currently selected value.
|
||||||
|
ArrayRef<Value> ivs() const { return to_vector().drop_back(2); }
|
||||||
|
// Currently selected value w.r.t. select() function.
|
||||||
|
Value value() const { return ivs_val_flag_.end()[-2]; }
|
||||||
|
// i1 flag if value() and ivs() were initialized.
|
||||||
|
Value is_init() const { return ivs_val_flag_.back(); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Vector that stores iv_1, ..., iv_N, value, init.
|
||||||
|
SmallVector<Value, 4> ivs_val_flag_;
|
||||||
|
};
|
||||||
|
|
||||||
|
SmallVector<Value, 2> SelectIvs(xla_lhlo::SelectAndScatterOp s_and_s_op,
|
||||||
|
scf::ParallelOp loop_over_src,
|
||||||
|
OpBuilder* b) const {
|
||||||
|
auto loc = s_and_s_op.getLoc();
|
||||||
|
|
||||||
|
WindowLoops window_loops = InsertWindowLoops(s_and_s_op, loop_over_src, b);
|
||||||
|
auto inner_loop_b =
|
||||||
|
OpBuilder::atBlockEnd(window_loops.inner_loop.getBody());
|
||||||
|
|
||||||
|
// Compute ivs in 'arg' buffer and whether these ivs are in the pad area.
|
||||||
|
MappedIvs mapped_ivs =
|
||||||
|
MapWindowIvsToInput(s_and_s_op, loop_over_src.getInductionVars(),
|
||||||
|
window_loops.window_ivs, &inner_loop_b);
|
||||||
|
|
||||||
|
IterArgs ivs_val_flag(window_loops.inner_loop.getRegionIterArgs());
|
||||||
|
|
||||||
|
auto if_in_bounds = inner_loop_b.create<scf::IfOp>(
|
||||||
|
loc, window_loops.inner_loop.getResultTypes(), mapped_ivs.in_bounds,
|
||||||
|
/*withElseRegion=*/true);
|
||||||
|
|
||||||
|
// Case when we are inside boundaries of 'arg' and not in the pad area.
|
||||||
|
{
|
||||||
|
OpBuilder in_bounds_then_b = if_in_bounds.getThenBodyBuilder();
|
||||||
|
auto select_or_init_results = SelectOrInitialize(
|
||||||
|
s_and_s_op, mapped_ivs.ivs, &ivs_val_flag, &in_bounds_then_b);
|
||||||
|
in_bounds_then_b.create<scf::YieldOp>(loc, select_or_init_results);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Case when we are in the pad.
|
||||||
|
{
|
||||||
|
OpBuilder in_bounds_else_b = if_in_bounds.getElseBodyBuilder();
|
||||||
|
in_bounds_else_b.create<scf::YieldOp>(loc, ivs_val_flag.to_vector());
|
||||||
|
}
|
||||||
|
|
||||||
|
inner_loop_b.create<scf::YieldOp>(loc, if_in_bounds.getResults());
|
||||||
|
return window_loops.selected_ivs;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value, 4> SelectOrInitialize(
|
||||||
|
xla_lhlo::SelectAndScatterOp s_and_s_op, ArrayRef<Value> operand_ivs,
|
||||||
|
IterArgs* ivs_val_flag, OpBuilder* b) const {
|
||||||
|
auto loc = s_and_s_op.getLoc();
|
||||||
|
Value true_i1 = b->create<mlir::ConstantOp>(
|
||||||
|
loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 1));
|
||||||
|
|
||||||
|
TypeRange iter_arg_types{ivs_val_flag->to_vector()};
|
||||||
|
Value operand_elem =
|
||||||
|
b->create<LoadOp>(loc, s_and_s_op.operand(), operand_ivs);
|
||||||
|
auto if_init =
|
||||||
|
b->create<scf::IfOp>(loc, iter_arg_types, ivs_val_flag->is_init(),
|
||||||
|
/*withElseRegion=*/true);
|
||||||
|
// Init == true, i.e. iter args are already initialized with a selected
|
||||||
|
// element in boundaries of the operand. Select function has to be computed
|
||||||
|
// here.
|
||||||
|
{
|
||||||
|
OpBuilder if_init_then_b = if_init.getThenBodyBuilder();
|
||||||
|
|
||||||
|
auto& lhlo_select = s_and_s_op.select().front();
|
||||||
|
Value pred =
|
||||||
|
ApplySingleResultLhloCode(loc, {operand_elem, ivs_val_flag->value()},
|
||||||
|
&lhlo_select, &if_init_then_b);
|
||||||
|
|
||||||
|
auto if_pred = if_init_then_b.create<scf::IfOp>(loc, iter_arg_types, pred,
|
||||||
|
/*withElseRegion=*/true);
|
||||||
|
|
||||||
|
// Pred == true, therefore pack newly selected ivs, val and init flag back
|
||||||
|
// to iter_args and return.
|
||||||
|
{
|
||||||
|
OpBuilder if_pred_then_b = if_pred.getThenBodyBuilder();
|
||||||
|
if_pred_then_b.create<scf::YieldOp>(
|
||||||
|
loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pred == false, therefore return old iter_args.
|
||||||
|
{
|
||||||
|
OpBuilder if_pred_else_b = if_pred.getElseBodyBuilder();
|
||||||
|
if_pred_else_b.create<scf::YieldOp>(loc, ivs_val_flag->to_vector());
|
||||||
|
}
|
||||||
|
|
||||||
|
if_init_then_b.create<scf::YieldOp>(loc, if_pred.getResults());
|
||||||
|
}
|
||||||
|
// Init == false, i.e. only pad was visited before and this is the first
|
||||||
|
// element in the boundaries of the operand.
|
||||||
|
{
|
||||||
|
OpBuilder if_init_else_b = if_init.getElseBodyBuilder();
|
||||||
|
|
||||||
|
if_init_else_b.create<scf::YieldOp>(
|
||||||
|
loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector());
|
||||||
|
}
|
||||||
|
return if_init.getResults();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct LhloLegalizeToParallelLoops
|
||||||
|
: public PassWrapper<LhloLegalizeToParallelLoops, FunctionPass> {
|
||||||
|
void runOnFunction() override {
|
||||||
|
auto func = getFunction();
|
||||||
|
|
||||||
|
OwningRewritePatternList patterns;
|
||||||
|
// clang-format off
|
||||||
|
patterns.insert<
|
||||||
|
ReduceOpConverter,
|
||||||
|
ReduceWindowOpConverter,
|
||||||
|
SelectAndScatterOpConverter
|
||||||
|
>(func.getContext());
|
||||||
|
// clang-format on
|
||||||
|
|
||||||
|
ConversionTarget target(getContext());
|
||||||
|
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
|
||||||
|
scf::SCFDialect, XlaLhloDialect>();
|
||||||
|
target.addIllegalOp<xla_lhlo::ReduceOp, xla_lhlo::ReduceWindowOp,
|
||||||
|
xla_lhlo::SelectAndScatterOp>();
|
||||||
|
|
||||||
|
if (failed(applyPartialConversion(func, target, patterns))) {
|
||||||
|
signalPassFailure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass() {
|
||||||
|
return absl::make_unique<LhloLegalizeToParallelLoops>();
|
||||||
|
}
|
||||||
|
|
||||||
|
static PassRegistration<LhloLegalizeToParallelLoops> legalize_lhlo_pass(
|
||||||
|
"lhlo-legalize-to-parallel-loops",
|
||||||
|
"Legalize from LHLO dialect to parallel loops.");
|
||||||
|
|
||||||
|
} // namespace xla_lhlo
|
||||||
|
} // namespace mlir
|
|
@ -0,0 +1,79 @@
|
||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// Thsi file implements passes to convert complex operations to equivalent real
|
||||||
|
// value operations. This does not include removing complex values from function
|
||||||
|
// argument or return types.
|
||||||
|
|
||||||
|
#include <cstddef>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <iterator>
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/PassRegistry.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"
|
||||||
|
|
||||||
|
using mlir::FunctionPass;
|
||||||
|
using mlir::OwningRewritePatternList;
|
||||||
|
using mlir::PassRegistration;
|
||||||
|
using mlir::PassWrapper;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class LowerComplex : public PassWrapper<LowerComplex, FunctionPass> {
|
||||||
|
public:
|
||||||
|
explicit LowerComplex() : PassWrapper<LowerComplex, FunctionPass>() {}
|
||||||
|
|
||||||
|
/// Performs the lowering to XLA dialect.
|
||||||
|
void runOnFunction() override;
|
||||||
|
};
|
||||||
|
} // end anonymous namespace
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace xla {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_lower_complex.inc"
|
||||||
|
|
||||||
|
} // end anonymous namespace
|
||||||
|
|
||||||
|
void PopulateComplexLoweringPatterns(MLIRContext* context,
|
||||||
|
OwningRewritePatternList* patterns) {
|
||||||
|
populateWithGenerated(context, patterns);
|
||||||
|
}
|
||||||
|
} // end namespace xla
|
||||||
|
} // end namespace mlir
|
||||||
|
|
||||||
|
// Lowers the complex operations that can be represented using other operations.
|
||||||
|
void LowerComplex::runOnFunction() {
|
||||||
|
// Add lowering patterns to the list.
|
||||||
|
OwningRewritePatternList patterns;
|
||||||
|
mlir::xla::PopulateComplexLoweringPatterns(&getContext(), &patterns);
|
||||||
|
|
||||||
|
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
||||||
|
}
|
||||||
|
|
||||||
|
static PassRegistration<LowerComplex> pass(
|
||||||
|
"test-xla-lower-complex",
|
||||||
|
"Lower complex operations into non-complex operations");
|
|
@ -0,0 +1,109 @@
|
||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// This is the legalization pattern that converts complex operations into
|
||||||
|
// equivalent real value operations.
|
||||||
|
|
||||||
|
include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td"
|
||||||
|
include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td"
|
||||||
|
include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Binary op patterns.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Add and subtraction are elementwise and can be distributed across the real
|
||||||
|
// and imaginary components.
|
||||||
|
foreach elementwiseOp = [HLO_AddOp, HLO_SubOp] in
|
||||||
|
def : Pat<(elementwiseOp HLO_ComplexTensor:$lhs,
|
||||||
|
HLO_ComplexTensor:$rhs),
|
||||||
|
(HLO_ComplexOp
|
||||||
|
(elementwiseOp (HLO_RealOp $lhs), (HLO_RealOp $rhs)),
|
||||||
|
(elementwiseOp (HLO_ImagOp $lhs), (HLO_ImagOp $rhs)))>;
|
||||||
|
|
||||||
|
// Complex multiplication results in a cross product multiplication between the
|
||||||
|
// real and imaginary components such that:
|
||||||
|
// result.real = lhs.real * rhs.real - lhs.imag * rhs.imag
|
||||||
|
// result.imag = lhs.imag * rhs.real + lhs.real * rhs.imag
|
||||||
|
def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs,
|
||||||
|
HLO_ComplexTensor:$rhs),
|
||||||
|
(HLO_ComplexOp
|
||||||
|
(HLO_SubOp
|
||||||
|
(HLO_MulOp
|
||||||
|
(HLO_RealOp:$lhs_real $lhs),
|
||||||
|
(HLO_RealOp:$rhs_real $rhs)),
|
||||||
|
(HLO_MulOp
|
||||||
|
(HLO_ImagOp:$lhs_imag $lhs),
|
||||||
|
(HLO_ImagOp:$rhs_imag $rhs))),
|
||||||
|
(HLO_AddOp
|
||||||
|
(HLO_MulOp $lhs_real, $rhs_imag),
|
||||||
|
(HLO_MulOp $lhs_imag, $rhs_real)))>;
|
||||||
|
|
||||||
|
// Multiplication between a complex and real tensor can be distributed by
|
||||||
|
// applying the real multiplicant to both the real and complex component.
|
||||||
|
//
|
||||||
|
// Note that the sourcep pattern is not legal according to the HLO dialect but
|
||||||
|
// instead handle intermediates generated by other patterns.
|
||||||
|
def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs),
|
||||||
|
(HLO_ComplexOp
|
||||||
|
(HLO_MulOp (HLO_RealOp $lhs), $rhs),
|
||||||
|
(HLO_MulOp (HLO_ImagOp $lhs), $rhs))>;
|
||||||
|
|
||||||
|
def : Pat<(HLO_MulOp HLO_IntOrFpTensor:$lhs, HLO_ComplexTensor:$rhs),
|
||||||
|
(HLO_ComplexOp
|
||||||
|
(HLO_MulOp $lhs, (HLO_RealOp $rhs)),
|
||||||
|
(HLO_MulOp $lhs, (HLO_ImagOp $rhs)))>;
|
||||||
|
|
||||||
|
|
||||||
|
// Division is performed by normalizing the denominator by multiplying by the
|
||||||
|
// conjugate of the rhs.
|
||||||
|
// numerator = lhs * conj(rhs)
|
||||||
|
// denominator = rhs * conj(rhs)
|
||||||
|
def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_ComplexTensor:$rhs),
|
||||||
|
(HLO_DivOp
|
||||||
|
(HLO_MulOp:$num $lhs,
|
||||||
|
(HLO_ComplexOp:$conj
|
||||||
|
(HLO_RealOp $rhs),
|
||||||
|
(HLO_NegOp (HLO_ImagOp $rhs)))),
|
||||||
|
(HLO_RealOp:$den (HLO_MulOp $rhs, $conj)))>;
|
||||||
|
|
||||||
|
|
||||||
|
def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs),
|
||||||
|
(HLO_ComplexOp
|
||||||
|
(HLO_DivOp (HLO_RealOp $lhs), $rhs),
|
||||||
|
(HLO_DivOp (HLO_ImagOp $lhs), $rhs))>;
|
||||||
|
|
||||||
|
|
||||||
|
// Absolute value is evaluated as:
|
||||||
|
// result = sqrt(val.real * val.real + val.imag * val.imag)
|
||||||
|
def : Pat<(HLO_AbsOp HLO_ComplexTensor:$val),
|
||||||
|
(HLO_ComplexOp
|
||||||
|
(HLO_SqrtOp
|
||||||
|
(HLO_AddOp
|
||||||
|
(HLO_MulOp (HLO_RealOp:$real $val), $real),
|
||||||
|
(HLO_MulOp (HLO_ImagOp:$imag $val), $imag))),
|
||||||
|
(HLO_ConstOp (ConstantSplat<"0"> $real)))>;
|
||||||
|
|
||||||
|
// Exponential can be lowered to an exponential on the real component and a
|
||||||
|
// sum of sinusoids of the imaginary component, which equates to a normal
|
||||||
|
// exponential operator multiplied by Euler's formula.
|
||||||
|
//
|
||||||
|
// Exp(a + ib) = Exp(a) * Exp(ib) = Exp(a) * (Cos(b) + iSin(b))
|
||||||
|
def : Pat<(HLO_ExpOp HLO_ComplexTensor:$val),
|
||||||
|
(HLO_MulOp
|
||||||
|
(HLO_ExpOp (HLO_RealOp $val)),
|
||||||
|
(HLO_ComplexOp
|
||||||
|
(HLO_CosOp (HLO_ImagOp:$imag $val)),
|
||||||
|
(HLO_SinOp $imag)))>;
|
|
@ -0,0 +1,194 @@
|
||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// This file implements logic for lowering XLA general dot to a regular dot.
|
||||||
|
|
||||||
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
|
||||||
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||||
|
|
||||||
|
using mlir::DenseIntElementsAttr;
|
||||||
|
using mlir::ElementsAttr;
|
||||||
|
using mlir::failure;
|
||||||
|
using mlir::FunctionPass;
|
||||||
|
using mlir::LogicalResult;
|
||||||
|
using mlir::MLIRContext;
|
||||||
|
using mlir::OpRewritePattern;
|
||||||
|
using mlir::OwningRewritePatternList;
|
||||||
|
using mlir::PassRegistration;
|
||||||
|
using mlir::PassWrapper;
|
||||||
|
using mlir::PatternRewriter;
|
||||||
|
using mlir::RankedTensorType;
|
||||||
|
using mlir::success;
|
||||||
|
using mlir::Value;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
Value TransposeReshape(Value arg, mlir::Location loc,
|
||||||
|
llvm::ArrayRef<int64_t> left_dims,
|
||||||
|
llvm::ArrayRef<int64_t> right_dims,
|
||||||
|
llvm::ArrayRef<int64_t> arg_shape,
|
||||||
|
PatternRewriter *rewriter) {
|
||||||
|
auto element_type = mlir::getElementTypeOrSelf(arg.getType());
|
||||||
|
|
||||||
|
int64_t left_size = 1;
|
||||||
|
for (auto dim : left_dims) {
|
||||||
|
left_size *= arg_shape[dim];
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t right_size = 1;
|
||||||
|
for (auto dim : right_dims) {
|
||||||
|
right_size *= arg_shape[dim];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate the transpose permutation attribute.
|
||||||
|
llvm::SmallVector<int64_t, 5> transpose_permutation(left_dims.begin(),
|
||||||
|
left_dims.end());
|
||||||
|
transpose_permutation.append(right_dims.begin(), right_dims.end());
|
||||||
|
|
||||||
|
mlir::TensorType transpose_permutation_type = RankedTensorType::get(
|
||||||
|
{static_cast<int64_t>(transpose_permutation.size())},
|
||||||
|
rewriter->getIntegerType(64));
|
||||||
|
|
||||||
|
auto transpose_permutation_attr =
|
||||||
|
DenseIntElementsAttr::get(transpose_permutation_type,
|
||||||
|
llvm::makeArrayRef(transpose_permutation))
|
||||||
|
.cast<DenseIntElementsAttr>();
|
||||||
|
|
||||||
|
// Compute the resulting shape.
|
||||||
|
llvm::SmallVector<int64_t, 5> transposed_shape;
|
||||||
|
for (auto val : transpose_permutation) {
|
||||||
|
transposed_shape.push_back(arg_shape[val]);
|
||||||
|
}
|
||||||
|
auto transpose_type = RankedTensorType::get(transposed_shape, element_type);
|
||||||
|
auto transpose_result = rewriter->create<mlir::xla_hlo::TransposeOp>(
|
||||||
|
loc, transpose_type, arg, transpose_permutation_attr);
|
||||||
|
|
||||||
|
// Return the final result.
|
||||||
|
auto reshaped_type =
|
||||||
|
RankedTensorType::get({left_size, right_size}, element_type);
|
||||||
|
return rewriter->create<mlir::xla_hlo::ReshapeOp>(loc, reshaped_type,
|
||||||
|
transpose_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value ProcessDotArg(Value arg, mlir::Location loc,
|
||||||
|
ElementsAttr contract_dims_attr, bool outer_dims_first,
|
||||||
|
PatternRewriter *rewriter) {
|
||||||
|
auto shape = arg.getType().cast<mlir::ShapedType>().getShape();
|
||||||
|
|
||||||
|
llvm::SmallVector<bool, 5> is_outer_dim;
|
||||||
|
is_outer_dim.resize(shape.size(), true);
|
||||||
|
|
||||||
|
// Compute the contract dimension ordering.
|
||||||
|
llvm::SmallVector<int64_t, 5> contract_dims;
|
||||||
|
for (auto dim : contract_dims_attr.getValues<int64_t>()) {
|
||||||
|
contract_dims.push_back(dim);
|
||||||
|
is_outer_dim[dim] = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the outer dimension orderings.
|
||||||
|
llvm::SmallVector<int64_t, 5> outer_dims;
|
||||||
|
for (auto it : llvm::enumerate(is_outer_dim)) {
|
||||||
|
if (it.value()) {
|
||||||
|
outer_dims.push_back(it.index());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (outer_dims_first) {
|
||||||
|
return TransposeReshape(arg, loc, outer_dims, contract_dims, shape,
|
||||||
|
rewriter);
|
||||||
|
}
|
||||||
|
|
||||||
|
return TransposeReshape(arg, loc, contract_dims, outer_dims, shape, rewriter);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct GeneralDotConvert
|
||||||
|
: public OpRewritePattern<mlir::xla_hlo::DotGeneralOp> {
|
||||||
|
// Attempts to lower a General Dot operator to a standard Dot operator.
|
||||||
|
// General dots include batching dimensions and can have collapsing
|
||||||
|
// dimensions along any axis. Inserting correctly arrange transpose and
|
||||||
|
// reshape operators organizes the tensors and allows the General Dot to be
|
||||||
|
// replaced with the standard Dot operator.
|
||||||
|
//
|
||||||
|
// Note: This requires an empty list of batch dimensions.
|
||||||
|
|
||||||
|
explicit GeneralDotConvert(MLIRContext *context)
|
||||||
|
: OpRewritePattern(context) {}
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(mlir::xla_hlo::DotGeneralOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto dot_element_type = mlir::getElementTypeOrSelf(op);
|
||||||
|
|
||||||
|
auto dot_numbers = op.dot_dimension_numbers();
|
||||||
|
if (dot_numbers.lhs_batching_dimensions().getNumElements() != 0 ||
|
||||||
|
dot_numbers.rhs_batching_dimensions().getNumElements() != 0) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto lhs = ProcessDotArg(op.lhs(), op.getLoc(),
|
||||||
|
dot_numbers.lhs_contracting_dimensions(),
|
||||||
|
/*outer_dims_first=*/true, &rewriter);
|
||||||
|
|
||||||
|
auto rhs = ProcessDotArg(op.rhs(), op.getLoc(),
|
||||||
|
dot_numbers.rhs_contracting_dimensions(),
|
||||||
|
/*outer_dims_first=*/false, &rewriter);
|
||||||
|
|
||||||
|
// Dot resulting shape.
|
||||||
|
auto lhs_shape = lhs.getType().cast<mlir::ShapedType>().getShape();
|
||||||
|
auto rhs_shape = rhs.getType().cast<mlir::ShapedType>().getShape();
|
||||||
|
auto new_dot_type =
|
||||||
|
RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, dot_element_type);
|
||||||
|
|
||||||
|
auto new_dot_op = rewriter.create<mlir::xla_hlo::DotOp>(
|
||||||
|
op.getLoc(), new_dot_type, lhs, rhs, *(op.precision_config()));
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<mlir::xla_hlo::ReshapeOp>(op, op.getType(),
|
||||||
|
new_dot_op);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct LegalizeGeneralDot
|
||||||
|
: public PassWrapper<LegalizeGeneralDot, FunctionPass> {
|
||||||
|
/// Lower all general dots that can be represented as a non-batched matmul.
|
||||||
|
void runOnFunction() override {
|
||||||
|
OwningRewritePatternList patterns;
|
||||||
|
mlir::xla_hlo::PopulateGeneralDotOpLoweringPatterns(&patterns,
|
||||||
|
&getContext());
|
||||||
|
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void mlir::xla_hlo::PopulateGeneralDotOpLoweringPatterns(
|
||||||
|
OwningRewritePatternList *patterns, MLIRContext *ctx) {
|
||||||
|
patterns->insert<GeneralDotConvert>(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
static PassRegistration<LegalizeGeneralDot> legalize_pass(
|
||||||
|
"test-xla-lower-general-dot",
|
||||||
|
"Tests lowering general dot to a non-batched dot when possible");
|
|
@ -0,0 +1,90 @@
|
||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace xla_hlo {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Converts ClampOp with broadcast semantics. ClampOp requires "all three arrays
|
||||||
|
// must be the same shape. Alternatively, as a restricted form of broadcasting,
|
||||||
|
// min and/or max can be a scalar of type T."
|
||||||
|
struct ClampWithBroadcastConvert : public OpRewritePattern<ClampOp> {
|
||||||
|
explicit ClampWithBroadcastConvert(MLIRContext *context)
|
||||||
|
: OpRewritePattern<ClampOp>(context) {}
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(ClampOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto operand_type = op.operand().getType().dyn_cast<RankedTensorType>();
|
||||||
|
auto max_type = op.max().getType().dyn_cast<RankedTensorType>();
|
||||||
|
auto min_type = op.min().getType().dyn_cast<RankedTensorType>();
|
||||||
|
// Unrancked types are not supported.
|
||||||
|
if (!operand_type || !max_type || !min_type) return failure();
|
||||||
|
// Does not support operand with dynamic dimensions for now.
|
||||||
|
if (!operand_type.hasStaticShape()) return failure();
|
||||||
|
|
||||||
|
ArrayRef<int64_t> operand_shape = operand_type.getShape();
|
||||||
|
|
||||||
|
Value max_value = op.max();
|
||||||
|
if (max_type != operand_type) {
|
||||||
|
assert(max_type.getRank() == 0);
|
||||||
|
max_value = rewriter.createOrFold<BroadcastOp>(
|
||||||
|
op.getLoc(), operand_type, max_value,
|
||||||
|
rewriter.getI64TensorAttr(operand_shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
Value min_value = op.min();
|
||||||
|
if (min_type != operand_type) {
|
||||||
|
assert(min_type.getRank() == 0);
|
||||||
|
min_value = rewriter.createOrFold<BroadcastOp>(
|
||||||
|
op.getLoc(), operand_type, min_value,
|
||||||
|
rewriter.getI64TensorAttr(operand_shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<ClampOp>(op, op.getType(), min_value,
|
||||||
|
op.operand(), max_value);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void SetupMaterializeBroadcastsLegality(MLIRContext *context,
|
||||||
|
ConversionTarget *conversionTarget) {
|
||||||
|
conversionTarget->addDynamicallyLegalOp<ClampOp>([](ClampOp op) {
|
||||||
|
return op.max().getType() == op.operand().getType() &&
|
||||||
|
op.min().getType() == op.operand().getType();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void PopulateMaterializeBroadcastsPatterns(MLIRContext *context,
|
||||||
|
OwningRewritePatternList *patterns) {
|
||||||
|
// ClampOp. This op has a special case where it accepts either same-shaped
|
||||||
|
// inputs or scalars (a restricted form of broadcasting). This makes the
|
||||||
|
// broadcast explicit.
|
||||||
|
patterns->insert<ClampWithBroadcastConvert>(context);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace xla_hlo
|
||||||
|
} // namespace mlir
|
|
@ -0,0 +1,58 @@
|
||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace xla_hlo {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct TestMaterializeBroadcastsPass
|
||||||
|
: public PassWrapper<TestMaterializeBroadcastsPass, FunctionPass> {
|
||||||
|
void runOnFunction() override {
|
||||||
|
ConversionTarget conversionTarget(getContext());
|
||||||
|
OwningRewritePatternList conversionPatterns;
|
||||||
|
|
||||||
|
// Consider the xla_hlo dialect legal for tests.
|
||||||
|
conversionTarget.addLegalDialect<XlaHloDialect>();
|
||||||
|
// The conversion uses helpers from the Standard dialect.
|
||||||
|
conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
|
||||||
|
|
||||||
|
SetupMaterializeBroadcastsLegality(&getContext(), &conversionTarget);
|
||||||
|
PopulateMaterializeBroadcastsPatterns(&getContext(), &conversionPatterns);
|
||||||
|
|
||||||
|
if (failed(applyPartialConversion(getFunction(), conversionTarget,
|
||||||
|
conversionPatterns))) {
|
||||||
|
return signalPassFailure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
} // namespace xla_hlo
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
static mlir::PassRegistration<mlir::xla_hlo::TestMaterializeBroadcastsPass>
|
||||||
|
pass("test-xla-materialize-broadcasts",
|
||||||
|
"Test pass for materializing 'broadcast_dimensions' attributes");
|
|
@ -0,0 +1,85 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/DenseMap.h"
|
||||||
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/Casting.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/PassManager.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LLVM.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/RegionUtils.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace xla_hlo {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// A pass that sinks constants implicitly captured in control flow regions. This
|
||||||
|
// is necessary to export to XLA.
|
||||||
|
class SinkConstantsToControlFlow
|
||||||
|
: public mlir::PassWrapper<SinkConstantsToControlFlow, FunctionPass> {
|
||||||
|
void runOnFunction() override {
|
||||||
|
getFunction().walk([](Operation* op) {
|
||||||
|
if (auto while_op = llvm::dyn_cast<WhileOp>(op)) {
|
||||||
|
SinkToRegion(&while_op.body());
|
||||||
|
SinkToRegion(&while_op.cond());
|
||||||
|
} else if (auto if_op = llvm::dyn_cast<IfOp>(op)) {
|
||||||
|
SinkToRegion(&if_op.true_branch());
|
||||||
|
SinkToRegion(&if_op.false_branch());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Performs constant sinking into a region.
|
||||||
|
static void SinkToRegion(Region* region) {
|
||||||
|
llvm::DenseMap<Value, ConstOp> sunk_constant;
|
||||||
|
visitUsedValuesDefinedAbove({*region}, [&](OpOperand* use) {
|
||||||
|
Value constant = use->get();
|
||||||
|
auto const_op = dyn_cast_or_null<ConstOp>(constant.getDefiningOp());
|
||||||
|
if (!const_op) return;
|
||||||
|
auto map_entry = sunk_constant.try_emplace(constant, nullptr);
|
||||||
|
if (!map_entry.second) {
|
||||||
|
// This constant has already been cloned into the region, reuse it.
|
||||||
|
use->set(map_entry.first->getSecond().getResult());
|
||||||
|
if (constant.use_empty()) const_op.erase();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (constant.hasOneUse()) {
|
||||||
|
const_op.getOperation()->moveBefore(®ion->front().front());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
map_entry.first->getSecond() = const_op.clone();
|
||||||
|
region->front().getOperations().insert(region->front().begin(),
|
||||||
|
map_entry.first->getSecond());
|
||||||
|
use->set(map_entry.first->getSecond().getResult());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
static mlir::PassRegistration<SinkConstantsToControlFlow> pass(
|
||||||
|
"xla-hlo-sink-constants-to-control-flow",
|
||||||
|
"Sink constants implicitly captured in control flow regions. This is "
|
||||||
|
"necessary to export to XLA.");
|
||||||
|
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass() {
|
||||||
|
return std::make_unique<SinkConstantsToControlFlow>();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace xla_hlo
|
||||||
|
} // namespace mlir
|
|
@ -0,0 +1,100 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Identifier.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OperationSupport.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace xla {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct InferReturnTypeComponentsPattern : public RewritePattern {
|
||||||
|
InferReturnTypeComponentsPattern(MLIRContext *context)
|
||||||
|
: RewritePattern("xla_test.get_return_type_components", 1, context) {}
|
||||||
|
LogicalResult matchAndRewrite(Operation *op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
if (op->getNumOperands() != 1) return failure();
|
||||||
|
auto defining_op = op->getOperand(0).getDefiningOp();
|
||||||
|
auto defining_op_int =
|
||||||
|
llvm::dyn_cast_or_null<InferShapedTypeOpInterface>(defining_op);
|
||||||
|
if (!defining_op_int) return failure();
|
||||||
|
SmallVector<ShapedTypeComponents, 4> components;
|
||||||
|
if (failed(defining_op_int.inferReturnTypeComponents(
|
||||||
|
op->getContext(), op->getLoc(), defining_op->getOperands(),
|
||||||
|
defining_op->getAttrDictionary(), defining_op->getRegions(),
|
||||||
|
components))) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace the op with another pass-through op with attributes added.
|
||||||
|
OperationState state(op->getLoc(), "xla_test.return_type_components",
|
||||||
|
op->getOperands(), op->getResultTypes(),
|
||||||
|
op->getAttrs());
|
||||||
|
auto new_op = rewriter.createOperation(state);
|
||||||
|
for (auto it : llvm::enumerate(components)) {
|
||||||
|
if (it.value().hasRank()) {
|
||||||
|
new_op->setAttr((StringRef("dims") + Twine(it.index())).str(),
|
||||||
|
rewriter.getI64ArrayAttr(it.value().getDims()));
|
||||||
|
}
|
||||||
|
if (it.value().getElementType()) {
|
||||||
|
new_op->setAttr((Twine("element_type") + Twine(it.index())).str(),
|
||||||
|
TypeAttr::get(it.value().getElementType()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rewriter.replaceOp(op, {new_op->getResults()});
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ReifyReturnTypeShapesPattern : public RewritePattern {
|
||||||
|
ReifyReturnTypeShapesPattern(MLIRContext *context)
|
||||||
|
: RewritePattern("xla_test.reify_return_type_shapes", 1, context) {}
|
||||||
|
LogicalResult matchAndRewrite(Operation *op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
if (op->getNumOperands() != 1) return failure();
|
||||||
|
auto defining_op = llvm::dyn_cast_or_null<InferShapedTypeOpInterface>(
|
||||||
|
op->getOperand(0).getDefiningOp());
|
||||||
|
if (!defining_op) return failure();
|
||||||
|
SmallVector<Value, 4> return_shapes;
|
||||||
|
if (failed(defining_op.reifyReturnTypeShapes(rewriter, return_shapes))) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
rewriter.replaceOp(op, return_shapes);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TestInferShapedTypeMethodsPass
|
||||||
|
: public PassWrapper<TestInferShapedTypeMethodsPass, FunctionPass> {
|
||||||
|
void runOnFunction() override {
|
||||||
|
OwningRewritePatternList patterns;
|
||||||
|
patterns.insert<ReifyReturnTypeShapesPattern>(&getContext());
|
||||||
|
patterns.insert<InferReturnTypeComponentsPattern>(&getContext());
|
||||||
|
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace xla
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
static mlir::PassRegistration<mlir::xla::TestInferShapedTypeMethodsPass> pass(
|
||||||
|
"test-xla-infer-shaped-type-methods",
|
||||||
|
"Uses test ops to invoke InferShapedTypeOpInterface methods");
|
|
@ -0,0 +1,184 @@
|
||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/SmallVector.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace xla_hlo {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Broadcasts the 1D value tensor 'value_1d' to the shape of 'result_type'. If
|
||||||
|
// 'shape_value' is initialized, creates a dynamic broadcast, otherwise creates
|
||||||
|
// a static broadcast.
|
||||||
|
Value BroadcastToFeatureDim(Location loc, RankedTensorType result_type,
|
||||||
|
Value value_1d, Value shape_value,
|
||||||
|
int64_t feature_dim,
|
||||||
|
PatternRewriter& rewriter) { // NOLINT
|
||||||
|
Builder b(rewriter.getContext());
|
||||||
|
auto dims_type = RankedTensorType::get({1}, b.getIntegerType(64));
|
||||||
|
auto dims = DenseIntElementsAttr::get(dims_type, {feature_dim});
|
||||||
|
if (shape_value) {
|
||||||
|
return rewriter.createOrFold<xla_hlo::DynamicBroadcastInDimOp>(
|
||||||
|
loc, result_type, value_1d, shape_value, dims);
|
||||||
|
}
|
||||||
|
assert(result_type.hasStaticShape());
|
||||||
|
return rewriter.create<xla_hlo::BroadcastInDimOp>(loc, result_type, value_1d,
|
||||||
|
dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate the shape value of operand, assuming it is a dynamic shape with
|
||||||
|
// static rank.
|
||||||
|
Value CalculateShapeValue(Location loc, Value operand,
|
||||||
|
PatternRewriter& rewriter) { // NOLINT
|
||||||
|
RankedTensorType result_type = operand.getType().dyn_cast<RankedTensorType>();
|
||||||
|
llvm::SmallVector<Value, 4> shape_values;
|
||||||
|
int64_t rank = result_type.getRank();
|
||||||
|
shape_values.reserve(rank);
|
||||||
|
for (int64_t i = 0; i < rank; ++i) {
|
||||||
|
shape_values.push_back(rewriter.create<mlir::DimOp>(loc, operand, i));
|
||||||
|
}
|
||||||
|
return rewriter.create<TensorFromElementsOp>(loc, shape_values);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value MaterializeEpsilon(Operation* op, FloatAttr epsilon_attr,
|
||||||
|
FloatType fp_type, Value variance,
|
||||||
|
RankedTensorType broadcast_to_type,
|
||||||
|
PatternRewriter& rewriter) { // NOLINT
|
||||||
|
Builder b(rewriter.getContext());
|
||||||
|
if (epsilon_attr.getType() != fp_type) {
|
||||||
|
// Need to convert.
|
||||||
|
bool loses_info;
|
||||||
|
APFloat epsilon_float = epsilon_attr.getValue();
|
||||||
|
auto status = epsilon_float.convert(
|
||||||
|
fp_type.getFloatSemantics(), APFloat::rmNearestTiesToEven, &loses_info);
|
||||||
|
if ((status & (~APFloat::opInexact)) != APFloat::opOK) {
|
||||||
|
op->emitWarning() << "Could not convert batch_norm epsilon to target fp "
|
||||||
|
"type: opStatus = "
|
||||||
|
<< static_cast<int>(status);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
if (loses_info) {
|
||||||
|
op->emitWarning("Conversion of epsilon loses precision");
|
||||||
|
}
|
||||||
|
epsilon_attr = b.getFloatAttr(fp_type, epsilon_float);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto scalar_type = RankedTensorType::get({}, fp_type);
|
||||||
|
auto epsilon_tensor_attr =
|
||||||
|
DenseElementsAttr::get(scalar_type, {epsilon_attr.cast<Attribute>()});
|
||||||
|
Value epsilon =
|
||||||
|
rewriter.create<xla_hlo::ConstOp>(op->getLoc(), epsilon_tensor_attr);
|
||||||
|
auto dims_type = RankedTensorType::get({0}, b.getIntegerType(64));
|
||||||
|
auto dims = DenseIntElementsAttr::get(dims_type, SmallVector<int64_t, 1>{});
|
||||||
|
if (broadcast_to_type.hasStaticShape()) {
|
||||||
|
return rewriter.create<xla_hlo::BroadcastInDimOp>(
|
||||||
|
op->getLoc(), broadcast_to_type, epsilon, /*broadcast_dims=*/dims);
|
||||||
|
}
|
||||||
|
Value shape_value = CalculateShapeValue(op->getLoc(), variance, rewriter);
|
||||||
|
return rewriter.createOrFold<xla_hlo::DynamicBroadcastInDimOp>(
|
||||||
|
op->getLoc(), broadcast_to_type, epsilon, shape_value,
|
||||||
|
/*broadcast_dims=*/dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
class UnfuseBatchNormInferencePattern
|
||||||
|
: public OpRewritePattern<xla_hlo::BatchNormInferenceOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern<xla_hlo::BatchNormInferenceOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(xla_hlo::BatchNormInferenceOp bn_op,
|
||||||
|
PatternRewriter& rewriter) const override {
|
||||||
|
// Enforce type invariants.
|
||||||
|
// Note that we deduce the actual element type from the variance,
|
||||||
|
// which should not be subject to quantization at a higher level.
|
||||||
|
auto input_type = bn_op.operand().getType().dyn_cast<RankedTensorType>();
|
||||||
|
auto variance_type =
|
||||||
|
bn_op.variance().getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!input_type || !variance_type) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
auto fp_type = variance_type.getElementType().dyn_cast<FloatType>();
|
||||||
|
if (!fp_type) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
int64_t feature_dim = bn_op.feature_index().getSExtValue();
|
||||||
|
|
||||||
|
// Add epsilon to the variance and sqrt to get stddev:
|
||||||
|
// stddev = sqrt(variance + epsilon)
|
||||||
|
auto epsilon =
|
||||||
|
MaterializeEpsilon(bn_op.getOperation(), bn_op.epsilonAttr(), fp_type,
|
||||||
|
bn_op.variance(), variance_type, rewriter);
|
||||||
|
if (!epsilon) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
Value stddev = rewriter.create<xla_hlo::AddOp>(bn_op.getLoc(),
|
||||||
|
bn_op.variance(), epsilon);
|
||||||
|
stddev = rewriter.create<xla_hlo::SqrtOp>(bn_op.getLoc(), stddev);
|
||||||
|
|
||||||
|
// Broadcast all terms.
|
||||||
|
Value shape_value;
|
||||||
|
if (!input_type.hasStaticShape()) {
|
||||||
|
shape_value =
|
||||||
|
CalculateShapeValue(bn_op.getLoc(), bn_op.operand(), rewriter);
|
||||||
|
}
|
||||||
|
auto broadcast_scale =
|
||||||
|
BroadcastToFeatureDim(bn_op.getLoc(), input_type, bn_op.scale(),
|
||||||
|
shape_value, feature_dim, rewriter);
|
||||||
|
auto broadcast_offset =
|
||||||
|
BroadcastToFeatureDim(bn_op.getLoc(), input_type, bn_op.offset(),
|
||||||
|
shape_value, feature_dim, rewriter);
|
||||||
|
auto broadcast_mean =
|
||||||
|
BroadcastToFeatureDim(bn_op.getLoc(), input_type, bn_op.mean(),
|
||||||
|
shape_value, feature_dim, rewriter);
|
||||||
|
auto broadcast_stddev = BroadcastToFeatureDim(
|
||||||
|
bn_op.getLoc(), input_type, stddev, shape_value, feature_dim, rewriter);
|
||||||
|
|
||||||
|
// Compute:
|
||||||
|
// scale * (input - mean) / stddev + offset
|
||||||
|
Value result = rewriter.create<xla_hlo::SubOp>(
|
||||||
|
bn_op.getLoc(), bn_op.operand(), broadcast_mean);
|
||||||
|
result = rewriter.create<xla_hlo::MulOp>(bn_op.getLoc(), result,
|
||||||
|
broadcast_scale);
|
||||||
|
result = rewriter.create<xla_hlo::DivOp>(bn_op.getLoc(), result,
|
||||||
|
broadcast_stddev);
|
||||||
|
rewriter.replaceOpWithNewOp<xla_hlo::AddOp>(bn_op, result,
|
||||||
|
broadcast_offset);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Populates conversion patterns to unfuse batch normalization operations.
|
||||||
|
// In combination with marking such ops as illegal, this allows backends that
|
||||||
|
// do not have special support for fused batchnorm to use simpler arithmetic
|
||||||
|
// primitives.
|
||||||
|
void PopulateUnfuseBatchNormPatterns(MLIRContext* context,
|
||||||
|
OwningRewritePatternList* patterns) {
|
||||||
|
patterns->insert<UnfuseBatchNormInferencePattern>(context);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace xla_hlo
|
||||||
|
} // namespace mlir
|
|
@ -0,0 +1,46 @@
|
||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace xla_hlo {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct TestUnfuseBatchNormPass
|
||||||
|
: public PassWrapper<TestUnfuseBatchNormPass, OperationPass<>> {
|
||||||
|
void runOnOperation() override {
|
||||||
|
OwningRewritePatternList patterns;
|
||||||
|
PopulateUnfuseBatchNormPatterns(&getContext(), &patterns);
|
||||||
|
applyPatternsAndFoldGreedily(getOperation(), patterns);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
} // namespace xla_hlo
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
static mlir::PassRegistration<mlir::xla_hlo::TestUnfuseBatchNormPass> pass(
|
||||||
|
"test-xla-unfuse-batch-norm",
|
||||||
|
"Test pass for materializing 'broadcast_dimensions' attributes");
|
|
@ -0,0 +1,579 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <unordered_set>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
|
||||||
|
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
|
||||||
|
#include "mlir/IR/Matchers.h"
|
||||||
|
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
|
||||||
|
#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project
|
||||||
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/EquivalenceClasses.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/cycle_detector.h"
|
||||||
|
|
||||||
|
// This pass has similar functionality of the fusion pass in XLA stack.
|
||||||
|
// However, unlike XLA, it targets the fully dynamic shape scenario.
|
||||||
|
// Currently, it implements the kLoop and kInput fusion templates.
|
||||||
|
// During conversion, it tries to greedily find kLoop/kInput fusion
|
||||||
|
// patterns.
|
||||||
|
//
|
||||||
|
// Similar to XLA, this pass supports fusion pattern having multiple outputs
|
||||||
|
// if all the shape of outputs are consistent. Following are some examples.
|
||||||
|
//
|
||||||
|
// kLoop kInput
|
||||||
|
// +----+ +----+ +----+ +----+ +----+ +----+
|
||||||
|
// |elem| |elem| |elem| |elem<----+elem+---->elem+----+
|
||||||
|
// +-+--+ +-+--+ +-+--+ +-+--+ +----+ +-+--+ |
|
||||||
|
// | | | | | |
|
||||||
|
// | | | | |
|
||||||
|
// +-v--+ | +-v--+ +--v---+ +--v---+ |
|
||||||
|
// |elem+<---+----<+elem| |reduce| |reduce| |
|
||||||
|
// +-+--+ +-+--+ +--+---+ +--+---+ |
|
||||||
|
// | | | | |
|
||||||
|
// | | | | |
|
||||||
|
// v v v v v
|
||||||
|
//
|
||||||
|
// To this end, we also add an simple shape constraint analysis phase.
|
||||||
|
// For kLoop fusion template, it requires all the outputs of the fused
|
||||||
|
// pattern have the same shape. However, we don't know the actual value
|
||||||
|
// of the shape at the compile time in the dynamic shape world.
|
||||||
|
// Fortunately, we could still infer the relationship among different ops
|
||||||
|
// according to their shape constrain traits. Currently, We only consider
|
||||||
|
// shape equality propagation for elementwise ops (assuming that implicit
|
||||||
|
// shape broadcast is forbidden). The above process could be built on the
|
||||||
|
// shape dialect once it is ready.
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace xla_hlo {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using llvm::EquivalenceClasses;
|
||||||
|
using FusionPattern = std::vector<Operation*>;
|
||||||
|
using FusionPlan = std::vector<FusionPattern>;
|
||||||
|
|
||||||
|
// To support using EquivalenceClasses for Value
|
||||||
|
class ValueWrapper {
|
||||||
|
public:
|
||||||
|
explicit ValueWrapper(Value value) : value_(std::move(value)) {}
|
||||||
|
|
||||||
|
Value getValue() const { return value_; }
|
||||||
|
|
||||||
|
bool operator==(const ValueWrapper& rhs) const {
|
||||||
|
return getValue() == rhs.getValue();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
Value value_;
|
||||||
|
};
|
||||||
|
|
||||||
|
bool operator<(const ValueWrapper& lhs, const ValueWrapper& rhs) {
|
||||||
|
auto lhs_value = lhs.getValue().getAsOpaquePointer();
|
||||||
|
auto rhs_value = rhs.getValue().getAsOpaquePointer();
|
||||||
|
return lhs_value < rhs_value;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IsFusible(Operation* op) {
|
||||||
|
if (matchPattern(op, m_Constant())) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
auto op_fusibility = dyn_cast<InferFusibilityOpInterface>(op);
|
||||||
|
return op_fusibility && (op_fusibility.isFusibleWithOperand() ||
|
||||||
|
op_fusibility.isFusibleWithConsumer());
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value, 4> GetInputsOfFusionPattern(const FusionPattern& pattern) {
|
||||||
|
SmallVector<Value, 4> inputs;
|
||||||
|
DenseSet<Value> input_set;
|
||||||
|
DenseSet<Operation*> op_set;
|
||||||
|
for (Operation* op : pattern) {
|
||||||
|
bool inserted = op_set.insert(op).second;
|
||||||
|
(void)inserted;
|
||||||
|
assert(inserted && "FusionPattern contains duplicate operations");
|
||||||
|
}
|
||||||
|
|
||||||
|
for (Operation* op : pattern) {
|
||||||
|
for (Value operand : op->getOperands()) {
|
||||||
|
Operation* operand_op = operand.getDefiningOp();
|
||||||
|
if (op_set.find(operand_op) != op_set.end()) {
|
||||||
|
// skip if defining op is in the pattern
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (input_set.insert(operand).second) {
|
||||||
|
inputs.push_back(operand);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return inputs;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value, 4> GetOutputsOfFusionPattern(const FusionPattern& pattern) {
|
||||||
|
SmallVector<Value, 4> outputs;
|
||||||
|
DenseSet<Operation*> op_set;
|
||||||
|
for (Operation* op : pattern) {
|
||||||
|
bool inserted = op_set.insert(op).second;
|
||||||
|
(void)inserted;
|
||||||
|
assert(inserted && "FusionPattern contains duplicate operations");
|
||||||
|
}
|
||||||
|
|
||||||
|
for (Operation* op : pattern) {
|
||||||
|
for (Value result : op->getResults()) {
|
||||||
|
bool has_external_user = llvm::any_of(
|
||||||
|
result.getUses(),
|
||||||
|
[&](OpOperand& use) { return !op_set.count(use.getOwner()); });
|
||||||
|
if (has_external_user) {
|
||||||
|
outputs.push_back(result);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return outputs;
|
||||||
|
}
|
||||||
|
|
||||||
|
FusionPattern MergeFusionPattern(const FusionPattern& lhs,
|
||||||
|
const FusionPattern& rhs) {
|
||||||
|
FusionPattern pattern(lhs);
|
||||||
|
pattern.insert(pattern.end(), rhs.begin(), rhs.end());
|
||||||
|
return pattern;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline int EffectiveSize(const FusionPattern& pattern) {
|
||||||
|
return llvm::count_if(
|
||||||
|
pattern, [](Operation* op) { return !matchPattern(op, m_Constant()); });
|
||||||
|
}
|
||||||
|
|
||||||
|
// This is an simple shape constraint analysis, which is used to
|
||||||
|
// guide fusion decision (e.g. we only fuse shape-compatible ops).
|
||||||
|
//
|
||||||
|
// Currently, We only consider shape equality propagation based
|
||||||
|
// on the shape constrain traits of elementwise ops (assuming that
|
||||||
|
// implicit shape broadcast is forbidden).
|
||||||
|
class ShapeConstraintAnalysis {
|
||||||
|
public:
|
||||||
|
explicit ShapeConstraintAnalysis(const SmallVectorImpl<Operation*>& op_list) {
|
||||||
|
PropagateEquality(op_list);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns true is `lhs` and `rhs` are supposed to have same shape.
|
||||||
|
bool HasSameShape(Value lhs, Value rhs) {
|
||||||
|
return impl_.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs));
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// shape equality propagation based on the shape constrains of
|
||||||
|
// elementwise ops.
|
||||||
|
void PropagateEquality(const SmallVectorImpl<Operation*>& op_list) {
|
||||||
|
bool converged = true;
|
||||||
|
do {
|
||||||
|
converged = true;
|
||||||
|
auto update = [&](Value lhs, Value rhs) {
|
||||||
|
if (!impl_.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs))) {
|
||||||
|
converged = false;
|
||||||
|
impl_.unionSets(ValueWrapper(lhs), ValueWrapper(rhs));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
for (Operation* op : op_list) {
|
||||||
|
auto op_fusibility = dyn_cast<InferFusibilityOpInterface>(op);
|
||||||
|
if (!op_fusibility) continue;
|
||||||
|
int numInput = op->getNumOperands();
|
||||||
|
int numOutput = op->getNumResults();
|
||||||
|
// shape equality propagation between inputs.
|
||||||
|
for (int input1 = 0; input1 < numInput; ++input1)
|
||||||
|
for (int input2 = input1 + 1; input2 < numInput; ++input2)
|
||||||
|
if (op_fusibility.inferInputsShapeEquality(input1, input2))
|
||||||
|
update(op->getOperand(input1), op->getOperand(input2));
|
||||||
|
|
||||||
|
// shape equality propagation between outputs.
|
||||||
|
for (int output1 = 0; output1 < numOutput; ++output1)
|
||||||
|
for (int output2 = output1 + 1; output2 < numOutput; ++output2)
|
||||||
|
if (op_fusibility.inferOutputsShapeEquality(output1, output2))
|
||||||
|
update(op->getResult(output1), op->getResult(output2));
|
||||||
|
|
||||||
|
// shape equality propagation between input and output.
|
||||||
|
for (int input = 0; input < numInput; ++input)
|
||||||
|
for (int output = 0; output < numOutput; ++output)
|
||||||
|
if (op_fusibility.inferInputOutputShapeEquality(input, output))
|
||||||
|
update(op->getOperand(input), op->getResult(output));
|
||||||
|
}
|
||||||
|
} while (!converged);
|
||||||
|
}
|
||||||
|
|
||||||
|
// a UnionFind set
|
||||||
|
EquivalenceClasses<ValueWrapper> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// A fusion planner that can propose a fusion plan for a block of ops.
|
||||||
|
// The fusion plan is consisted of a group of fusion patterns.
|
||||||
|
//
|
||||||
|
// Currently all proposed patterns followed xla kLoop/kInput like fusion
|
||||||
|
// templates while are adapted to the fully dynamic shape world.
|
||||||
|
//
|
||||||
|
// kLoop fusion template satifies:
|
||||||
|
// - all ops in the fusion pattern are element-wise.
|
||||||
|
// - all the shapes of outputs of fusion pattern are same, and thus can
|
||||||
|
// fit into a same parallel loop.
|
||||||
|
//
|
||||||
|
// kInput fusion template satifies:
|
||||||
|
// - any op in the fusion pattern is either element-wise or a reduction.
|
||||||
|
// - if a op is a reduction, its output cannot be consumered by other
|
||||||
|
// ops in the same fusion pattern.
|
||||||
|
// - all the effective shapes of outputs of fusion pattern are same.
|
||||||
|
// - For element-wise op, its effective shape is its output shape.
|
||||||
|
// - For reduction op, its effective shape is its operand shape.
|
||||||
|
class FusionPlanner {
|
||||||
|
public:
|
||||||
|
explicit FusionPlanner(const SmallVectorImpl<Operation*>& op_list)
|
||||||
|
: op_list_(op_list),
|
||||||
|
shape_analysis_(op_list),
|
||||||
|
cycle_detector_(op_list.size()) {
|
||||||
|
BuildNodeMap();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns a fusion plan if success, otherwise none.
|
||||||
|
llvm::Optional<FusionPlan> Run() {
|
||||||
|
// Greedily search connected fusible pattern, and ops belonging to
|
||||||
|
// a same fusion pattern are grouped into a cluster.
|
||||||
|
RunEdgeContractionLoop();
|
||||||
|
|
||||||
|
// After doing edge contraction, each unique cluster having size
|
||||||
|
// more than one represents a potential fusion pattern.
|
||||||
|
// We collect all these clusters and construct a fusion plan.
|
||||||
|
//
|
||||||
|
// Note that the ops in a fusion pattern are in topological ordering.
|
||||||
|
FusionPlan plan;
|
||||||
|
DenseMap<int, int> pattern_ids;
|
||||||
|
for (Operation* op : op_list_) {
|
||||||
|
Cluster* cluster = GetClusterForNode(op);
|
||||||
|
int node_id = cluster->cycles_graph_node_id();
|
||||||
|
if (!IsFusible(op_list_[node_id]) ||
|
||||||
|
EffectiveSize(GetClusterForNode(op)->fused_pattern()) <= 1) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (!pattern_ids.count(node_id)) {
|
||||||
|
int pattern_id = pattern_ids.size();
|
||||||
|
pattern_ids[node_id] = pattern_id;
|
||||||
|
plan.emplace_back();
|
||||||
|
}
|
||||||
|
plan[pattern_ids[node_id]].push_back(op);
|
||||||
|
}
|
||||||
|
return plan;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the op_list this planner operates on.
|
||||||
|
const SmallVectorImpl<Operation*>& op_list() const { return op_list_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Represent a (partial) fused pattern
|
||||||
|
class Cluster {
|
||||||
|
public:
|
||||||
|
Cluster(int node_id, FusionPlanner* planner) : node_id_(node_id) {
|
||||||
|
const SmallVectorImpl<Operation*>& op_list = planner->op_list();
|
||||||
|
pattern_.push_back(op_list[node_id]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merges `other` into this cluster, and clears `other`.
|
||||||
|
void Merge(Cluster* other) {
|
||||||
|
pattern_.insert(pattern_.end(), other->pattern_.begin(),
|
||||||
|
other->pattern_.end());
|
||||||
|
other->pattern_.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
// The number of nodes in this cluster.
|
||||||
|
int cluster_size() const { return pattern_.size(); }
|
||||||
|
|
||||||
|
// The ID of the cluster as represented in `cycle_detector_`.
|
||||||
|
int cycles_graph_node_id() const { return node_id_; }
|
||||||
|
|
||||||
|
// Sets the ID of the cluster as represented in `cycle_detector_`.
|
||||||
|
void set_cycles_graph_node_id(int cycles_graph_node_id) {
|
||||||
|
node_id_ = cycles_graph_node_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Currently the fused pattern this cluster holds.
|
||||||
|
const FusionPattern& fused_pattern() { return pattern_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
// ID of the representative node of this cluster.
|
||||||
|
int node_id_;
|
||||||
|
|
||||||
|
// the fused pattern this cluster holds.
|
||||||
|
FusionPattern pattern_;
|
||||||
|
};
|
||||||
|
|
||||||
|
private:
|
||||||
|
Cluster* MakeCluster(int cycles_graph_node_id) {
|
||||||
|
cluster_storage_.emplace_back(new Cluster(cycles_graph_node_id, this));
|
||||||
|
return cluster_storage_.back().get();
|
||||||
|
}
|
||||||
|
|
||||||
|
void BuildNodeMap() {
|
||||||
|
int num_nodes = op_list_.size();
|
||||||
|
for (int node_id = 0; node_id < num_nodes; ++node_id) {
|
||||||
|
Operation* op = op_list_[node_id];
|
||||||
|
MakeCluster(node_id);
|
||||||
|
op_to_node_id_[op] = node_id;
|
||||||
|
leader_for_node_.insert(node_id);
|
||||||
|
for (Value operand : op->getOperands()) {
|
||||||
|
Operation* operand_op = operand.getDefiningOp();
|
||||||
|
if (operand_op == nullptr) {
|
||||||
|
// skip block argument
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto iter = op_to_node_id_.find(operand_op);
|
||||||
|
assert(iter != op_to_node_id_.end());
|
||||||
|
cycle_detector_.InsertEdge(iter->second, node_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the cluster contains this op.
|
||||||
|
Cluster* GetClusterForNode(Operation* n) {
|
||||||
|
int id = op_to_node_id_.at(n);
|
||||||
|
id = leader_for_node_.getLeaderValue(id);
|
||||||
|
return cluster_storage_[id].get();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the cluster contains the op having `node_id`.
|
||||||
|
Cluster* GetClusterForCyclesGraphNode(int node_id) {
|
||||||
|
return cluster_storage_[leader_for_node_.getLeaderValue(node_id)].get();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merges the clusters `cluster_from` and `cluster_to`.
|
||||||
|
bool MergeClusters(Cluster* cluster_from, Cluster* cluster_to) {
|
||||||
|
int from = cluster_from->cycles_graph_node_id();
|
||||||
|
int to = cluster_to->cycles_graph_node_id();
|
||||||
|
|
||||||
|
auto optional_merged_node = cycle_detector_.ContractEdge(from, to);
|
||||||
|
if (!optional_merged_node.hasValue()) {
|
||||||
|
llvm::dbgs() << "Could not contract " << from << " -> " << to
|
||||||
|
<< " because contracting the edge would create a cycle.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge the clusters.
|
||||||
|
cluster_from->Merge(cluster_to);
|
||||||
|
cluster_from->set_cycles_graph_node_id(*optional_merged_node);
|
||||||
|
|
||||||
|
// Merge the UnionFind Set.
|
||||||
|
leader_for_node_.unionSets(from, to);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename FnTy>
|
||||||
|
bool ForEachEdgeInPostOrder(FnTy fn) {
|
||||||
|
bool changed = false;
|
||||||
|
for (int32_t node : cycle_detector_.AllNodesInPostOrder()) {
|
||||||
|
Cluster* cluster_from = GetClusterForCyclesGraphNode(node);
|
||||||
|
// Make a copy of the set of successors because we may modify the graph in
|
||||||
|
// TryToContractEdge.
|
||||||
|
std::vector<int32_t> successors_copy =
|
||||||
|
cycle_detector_.SuccessorsCopy(cluster_from->cycles_graph_node_id());
|
||||||
|
|
||||||
|
for (int to : successors_copy) {
|
||||||
|
Cluster* cluster_to = GetClusterForCyclesGraphNode(to);
|
||||||
|
bool contracted_edge = fn(cluster_from, cluster_to);
|
||||||
|
changed |= contracted_edge;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return changed;
|
||||||
|
}
|
||||||
|
|
||||||
|
// returns the outputs if two cluster were merged
|
||||||
|
SmallVector<Value, 4> GetResultsOfFusedPattern(Cluster* from, Cluster* to) {
|
||||||
|
FusionPattern fused_pattern =
|
||||||
|
MergeFusionPattern(from->fused_pattern(), to->fused_pattern());
|
||||||
|
return GetOutputsOfFusionPattern(fused_pattern);
|
||||||
|
}
|
||||||
|
|
||||||
|
// This function check if fusing `from` with `to` is valid and if so perform
|
||||||
|
// the merge. The validity is based on the operations in the clusters and
|
||||||
|
// the compatibility of the shapes of the outputs of the would-be fused
|
||||||
|
// clusters.
|
||||||
|
// Returns true is the merge was performed.
|
||||||
|
bool TryToContractEdge(Cluster* from, Cluster* to) {
|
||||||
|
int node_to = to->cycles_graph_node_id();
|
||||||
|
int node_from = from->cycles_graph_node_id();
|
||||||
|
|
||||||
|
// Both node_to and node_from should be fusible
|
||||||
|
if (!IsFusible(op_list_[node_to]) || !IsFusible(op_list_[node_from])) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto op_from_fusibility =
|
||||||
|
dyn_cast<InferFusibilityOpInterface>(op_list_[node_from]);
|
||||||
|
if (op_from_fusibility && !op_from_fusibility.isFusibleWithConsumer()) {
|
||||||
|
// This op cannot be fused with its consumers.
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto op_to_fusibility =
|
||||||
|
dyn_cast<InferFusibilityOpInterface>(op_list_[node_to]);
|
||||||
|
if (op_to_fusibility && !op_to_fusibility.isFusibleWithOperand()) {
|
||||||
|
// This op cannot be fused with its operands.
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Output shapes of a fusion pattern should be compatible as described in
|
||||||
|
// the document of this class.
|
||||||
|
SmallVector<Value, 4> results = GetResultsOfFusedPattern(from, to);
|
||||||
|
auto get_workload_shape = [](Value v) {
|
||||||
|
Operation* op = v.getDefiningOp();
|
||||||
|
// Block argument
|
||||||
|
if (!op) return v;
|
||||||
|
auto op_fusibility = dyn_cast<InferFusibilityOpInterface>(op);
|
||||||
|
// Const value
|
||||||
|
if (!op_fusibility) return v;
|
||||||
|
llvm::Optional<Value> workload =
|
||||||
|
op_fusibility.inferEffectiveWorkloadShape();
|
||||||
|
return workload.hasValue() ? *workload : v;
|
||||||
|
};
|
||||||
|
|
||||||
|
Value ref = get_workload_shape(results[0]);
|
||||||
|
if (!llvm::all_of(results, [&](Value result) {
|
||||||
|
Value val = get_workload_shape(result);
|
||||||
|
return shape_analysis_.HasSameShape(ref, val);
|
||||||
|
})) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return MergeClusters(from, to);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Greedily fuse connected node.
|
||||||
|
bool RunEdgeContractionLoop() {
|
||||||
|
using std::placeholders::_1;
|
||||||
|
using std::placeholders::_2;
|
||||||
|
return ForEachEdgeInPostOrder(
|
||||||
|
std::bind(&FusionPlanner::TryToContractEdge, this, _1, _2));
|
||||||
|
}
|
||||||
|
|
||||||
|
const SmallVectorImpl<Operation*>& op_list_;
|
||||||
|
|
||||||
|
// Shape equality checker
|
||||||
|
ShapeConstraintAnalysis shape_analysis_;
|
||||||
|
|
||||||
|
// op -> node_id
|
||||||
|
std::unordered_map<Operation*, int> op_to_node_id_;
|
||||||
|
|
||||||
|
// make sure not introduce cycle after fusion
|
||||||
|
GraphCycles cycle_detector_;
|
||||||
|
std::vector<std::unique_ptr<Cluster>> cluster_storage_;
|
||||||
|
|
||||||
|
// a UnionFind set. Each set represents a (partial) fused pattern
|
||||||
|
// and has a leader as representation.
|
||||||
|
EquivalenceClasses<int32_t> leader_for_node_;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct XlaHloFusion : public mlir::PassWrapper<XlaHloFusion, FunctionPass> {
|
||||||
|
void runOnFunction() override {
|
||||||
|
FuncOp func = getFunction();
|
||||||
|
if (!IsTargetFunc(func)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// process each block and do fusion within a block.
|
||||||
|
for (Block& block : func) {
|
||||||
|
SmallVector<Operation*, 4> op_list;
|
||||||
|
for (Operation& op : block) {
|
||||||
|
op_list.push_back(&op);
|
||||||
|
}
|
||||||
|
|
||||||
|
FusionPlanner planner(op_list);
|
||||||
|
llvm::Optional<FusionPlan> plan = planner.Run();
|
||||||
|
if (!plan) {
|
||||||
|
emitError(func.getLoc(), "can't find a fusion plan");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!ApplyFusionPlan(*plan)) {
|
||||||
|
emitError(func.getLoc(), "apply fusion plan failed");
|
||||||
|
signalPassFailure();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IsTargetFunc(FuncOp func) {
|
||||||
|
int num_fusible_ops = 0;
|
||||||
|
bool is_target_func = false;
|
||||||
|
// We only process the function having enough candidates
|
||||||
|
func.walk([&](Operation* op) {
|
||||||
|
num_fusible_ops +=
|
||||||
|
static_cast<int>(dyn_cast<InferFusibilityOpInterface>(op) != nullptr);
|
||||||
|
is_target_func = (num_fusible_ops > 1);
|
||||||
|
// early stop
|
||||||
|
if (is_target_func) return WalkResult::interrupt();
|
||||||
|
return WalkResult::advance();
|
||||||
|
});
|
||||||
|
return is_target_func;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ApplyFusionPlan(const FusionPlan& plan) {
|
||||||
|
for (const FusionPattern& pattern : plan) {
|
||||||
|
OpBuilder b(pattern.back());
|
||||||
|
|
||||||
|
SmallVector<Location, 4> locations;
|
||||||
|
locations.reserve(pattern.size());
|
||||||
|
for (Operation* op : pattern) {
|
||||||
|
locations.push_back(op->getLoc());
|
||||||
|
}
|
||||||
|
Location fused_loc =
|
||||||
|
FusedLoc::get(locations, pattern.back()->getContext());
|
||||||
|
|
||||||
|
SmallVector<Value, 4> inputs = GetInputsOfFusionPattern(pattern);
|
||||||
|
SmallVector<Value, 4> outputs = GetOutputsOfFusionPattern(pattern);
|
||||||
|
SmallVector<Type, 4> output_types;
|
||||||
|
output_types.reserve(outputs.size());
|
||||||
|
for (Value v : outputs) {
|
||||||
|
output_types.push_back(v.getType());
|
||||||
|
}
|
||||||
|
|
||||||
|
FusionOp fusion =
|
||||||
|
b.create<xla_hlo::FusionOp>(fused_loc, output_types, inputs);
|
||||||
|
Region& region = fusion.fused_computation();
|
||||||
|
region.push_back(new Block);
|
||||||
|
Block& block = region.front();
|
||||||
|
for (Operation* op : pattern) {
|
||||||
|
op->moveBefore(&block, block.end());
|
||||||
|
}
|
||||||
|
b.setInsertionPoint(&block, block.end());
|
||||||
|
b.create<xla_hlo::ReturnOp>(fused_loc, outputs);
|
||||||
|
|
||||||
|
for (auto output_and_result : llvm::zip(outputs, fusion.getResults())) {
|
||||||
|
Value output = std::get<0>(output_and_result);
|
||||||
|
Value fusion_result = std::get<1>(output_and_result);
|
||||||
|
for (OpOperand& use : llvm::make_early_inc_range(output.getUses())) {
|
||||||
|
if (use.getOwner()->getBlock() != &block) use.set(fusion_result);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusion() {
|
||||||
|
return std::make_unique<XlaHloFusion>();
|
||||||
|
}
|
||||||
|
|
||||||
|
static PassRegistration<XlaHloFusion> xla_hlo_fusion_pass(
|
||||||
|
"xla-hlo-fusion", "fuse xla_hlo ops to kLoop/kInput fusion patterns.");
|
||||||
|
|
||||||
|
} // namespace xla_hlo
|
||||||
|
} // namespace mlir
|
|
@ -0,0 +1,909 @@
|
||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
// This file implements logic for lowering HLO/LHLO dialect to Linalg dialect.
|
||||||
|
|
||||||
|
#include "third_party/absl/memory/memory.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/AffineExpr.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
SmallVector<StringRef, 3> GetNParallelLoopsAttrs(unsigned nParallelLoops) {
|
||||||
|
static constexpr StringRef kParallelIterType = "parallel";
|
||||||
|
return SmallVector<StringRef, 3>(nParallelLoops, kParallelIterType);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <bool isLHLO = true>
|
||||||
|
Value getResultValue(Operation* op) {
|
||||||
|
return isLHLO ? op->getOperand(op->getNumOperands() - 1) : op->getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <bool isLHLO = true>
|
||||||
|
ShapedType getXLAOpResultType(Operation* op) {
|
||||||
|
return getResultValue<isLHLO>(op).getType().template cast<ShapedType>();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <bool isLHLO = true>
|
||||||
|
bool verifyXLAOpBufferOrTensorSemantics(Operation* op) {
|
||||||
|
auto verifyType = [&](Value val) -> bool {
|
||||||
|
return (isLHLO && val.getType().isa<MemRefType>()) ||
|
||||||
|
(!isLHLO && val.getType().isa<RankedTensorType>());
|
||||||
|
};
|
||||||
|
if (!llvm::all_of(op->getOperands(), verifyType)) return false;
|
||||||
|
return isLHLO ? op->getResults().empty()
|
||||||
|
: llvm::all_of(op->getResults(), verifyType);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename OpTy, bool isLHLO = true>
|
||||||
|
class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern<OpTy>::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
OpTy op, ArrayRef<Value> args,
|
||||||
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
auto argType =
|
||||||
|
op.getOperation()->getOperand(0).getType().template cast<ShapedType>();
|
||||||
|
if (!argType.hasRank()) {
|
||||||
|
emitError(loc, "lhlo to linalg conversion expects ranked args");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
auto elemTy = argType.getElementType();
|
||||||
|
if (!elemTy.isSignlessIntOrFloat() && !elemTy.template isa<ComplexType>()) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Construct the indexing maps needed for linalg.generic ops.
|
||||||
|
SmallVector<AffineMap, 2> indexing_maps;
|
||||||
|
SmallVector<Type, 4> bodyArgTypes, bodyResultTypes, opResultTypes;
|
||||||
|
|
||||||
|
// This doesnt account for implicit broadcast, but the working assumption
|
||||||
|
// here is that are broadcasts have been made explicit.
|
||||||
|
unsigned nloops = argType.getRank();
|
||||||
|
|
||||||
|
if (isLHLO && !nloops) return failure();
|
||||||
|
|
||||||
|
int operandCount = (isLHLO ? args.size() - 1 : args.size());
|
||||||
|
auto verifyArgOrResultType = [&](Value val) -> ShapedType {
|
||||||
|
auto shapedType = val.getType().dyn_cast<ShapedType>();
|
||||||
|
if (!shapedType ||
|
||||||
|
(!shapedType.isa<MemRefType>() &&
|
||||||
|
!shapedType.isa<RankedTensorType>()) ||
|
||||||
|
shapedType.getRank() != nloops)
|
||||||
|
return nullptr;
|
||||||
|
indexing_maps.emplace_back(
|
||||||
|
nloops ? rewriter.getMultiDimIdentityMap(nloops)
|
||||||
|
: AffineMap::get(nloops, 0, rewriter.getContext()));
|
||||||
|
return shapedType;
|
||||||
|
};
|
||||||
|
for (const auto& arg : llvm::enumerate(args)) {
|
||||||
|
auto shapedType = verifyArgOrResultType(arg.value());
|
||||||
|
if (!shapedType) return failure();
|
||||||
|
auto& result_or_body_arg =
|
||||||
|
arg.index() < operandCount ? bodyArgTypes : bodyResultTypes;
|
||||||
|
result_or_body_arg.emplace_back(shapedType.getElementType());
|
||||||
|
}
|
||||||
|
if (!isLHLO) {
|
||||||
|
// HLO operations have return as tensor types.
|
||||||
|
assert(bodyResultTypes.empty() &&
|
||||||
|
"When lowering HLO ops result can't be part of arguments");
|
||||||
|
Value result = op.getOperation()->getResult(0);
|
||||||
|
auto shapedType = verifyArgOrResultType(result);
|
||||||
|
if (!shapedType) return failure();
|
||||||
|
bodyResultTypes.push_back(shapedType.getElementType());
|
||||||
|
opResultTypes.push_back(shapedType);
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t args_count = bodyArgTypes.size();
|
||||||
|
int64_t results_count = bodyResultTypes.size();
|
||||||
|
auto linalgOp = rewriter.create<linalg::GenericOp>(
|
||||||
|
loc, opResultTypes, args, args_count, results_count, indexing_maps,
|
||||||
|
GetNParallelLoopsAttrs(nloops),
|
||||||
|
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
|
||||||
|
// TODO(ravishankarm) : For now use the method in xla_lhlo namespace.
|
||||||
|
// That method needs to be moved out of there.
|
||||||
|
Value opResult = xla_lhlo::XlaOpToStdScalarOp::map<OpTy>(
|
||||||
|
op, bodyResultTypes,
|
||||||
|
llvm::to_vector<2>(args.take_front(args_count)), &rewriter);
|
||||||
|
nestedBuilder.create<linalg::YieldOp>(loc, opResult);
|
||||||
|
});
|
||||||
|
rewriter.replaceOp(op, linalgOp.getOperation()->getResults());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename LhloOp>
|
||||||
|
class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern<LhloOp>::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
LhloOp lhlo_op, ArrayRef<Value> args,
|
||||||
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
|
auto loc = lhlo_op.getLoc();
|
||||||
|
auto argType =
|
||||||
|
lhlo_op.getOperand(0).getType().template dyn_cast<ShapedType>();
|
||||||
|
if (!argType || !argType.getElementType().isSignlessIntOrFloat() ||
|
||||||
|
(argType.getRank() != 0)) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create two loads from the input.
|
||||||
|
auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
|
||||||
|
auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
|
||||||
|
// TODO(ravishankarm) : Move this method out of xla_lhlo namespace.
|
||||||
|
Value opResult = xla_lhlo::XlaOpToStdScalarOp::map<LhloOp>(
|
||||||
|
lhlo_op, argType.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
|
||||||
|
&rewriter);
|
||||||
|
rewriter.create<StoreOp>(loc, opResult, lhlo_op.out());
|
||||||
|
rewriter.eraseOp(lhlo_op);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// xla_lhlo.convolution conversion pattern.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// Converts xla_lhlo.convolution operation to a linalg.conv op.
|
||||||
|
struct ConvToLinalgConverter : public OpConversionPattern<xla_lhlo::ConvOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern<xla_lhlo::ConvOp>::OpConversionPattern;
|
||||||
|
|
||||||
|
// This code has been adapted from IREE's
|
||||||
|
// (https://github.com/google/iree/) xla_hlo -> linalg conversion.
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
xla_lhlo::ConvOp op, ArrayRef<Value> args,
|
||||||
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
|
// Check validity of dimension information.
|
||||||
|
if (const xla_lhlo::ConvDimensionNumbers& dimensionNumbers =
|
||||||
|
op.dimension_numbers()) {
|
||||||
|
const int inputSpatialRank =
|
||||||
|
llvm::size(dimensionNumbers.input_spatial_dimensions());
|
||||||
|
// The dimensions for input should follow the order of
|
||||||
|
// batch_count, spatial_dims..., input_feature_count.
|
||||||
|
if (dimensionNumbers.input_batch_dimension().getInt() != 0 ||
|
||||||
|
dimensionNumbers.input_feature_dimension().getInt() !=
|
||||||
|
(inputSpatialRank + 1))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
const int kernelSpatialRank =
|
||||||
|
llvm::size(dimensionNumbers.kernel_spatial_dimensions());
|
||||||
|
// The dimensions for filter should follow the order of
|
||||||
|
// spatial_dims..., input_feature_count, num_output_feature_count.
|
||||||
|
if (dimensionNumbers.kernel_input_feature_dimension().getInt() !=
|
||||||
|
kernelSpatialRank ||
|
||||||
|
dimensionNumbers.kernel_output_feature_dimension().getInt() !=
|
||||||
|
(kernelSpatialRank + 1))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
const int outputSpatialRank =
|
||||||
|
llvm::size(dimensionNumbers.output_spatial_dimensions());
|
||||||
|
// The dimensions for output should follow the order of
|
||||||
|
// batch_count, spatial_dims.., output_feature_count.
|
||||||
|
if (dimensionNumbers.output_batch_dimension().getInt() != 0 ||
|
||||||
|
dimensionNumbers.output_feature_dimension().getInt() !=
|
||||||
|
(outputSpatialRank + 1))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (inputSpatialRank != outputSpatialRank ||
|
||||||
|
inputSpatialRank != kernelSpatialRank)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto inputSpatialDim =
|
||||||
|
dimensionNumbers.input_spatial_dimensions().begin();
|
||||||
|
auto kernelSpatialDim =
|
||||||
|
dimensionNumbers.kernel_spatial_dimensions().begin();
|
||||||
|
auto outputSpatialDim =
|
||||||
|
dimensionNumbers.output_spatial_dimensions().begin();
|
||||||
|
// Check if spatial dims are ordered correctly.
|
||||||
|
for (int i = 0; i < inputSpatialRank; ++i) {
|
||||||
|
const int dim = i + 1;
|
||||||
|
if ((*inputSpatialDim++).getZExtValue() != dim ||
|
||||||
|
(*outputSpatialDim++).getZExtValue() != dim ||
|
||||||
|
(*kernelSpatialDim++).getZExtValue() != i)
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: LHS dilation for deconvolution not supported yet.
|
||||||
|
if (op.lhs_dilation()) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::SmallVector<Attribute, 4> strides;
|
||||||
|
if (auto windowStrides = op.window_strides()) {
|
||||||
|
auto range = windowStrides->getAttributeValues();
|
||||||
|
strides.assign(range.begin(), range.end());
|
||||||
|
}
|
||||||
|
auto stridesArg = ArrayAttr::get(strides, op.getContext());
|
||||||
|
|
||||||
|
llvm::SmallVector<Attribute, 2> dilation;
|
||||||
|
if (auto rhsDilation = op.rhs_dilation()) {
|
||||||
|
auto range = rhsDilation->getAttributeValues();
|
||||||
|
dilation.assign(range.begin(), range.end());
|
||||||
|
} else {
|
||||||
|
// Default dilation of 1.
|
||||||
|
dilation.resize(2, IntegerAttr::get(rewriter.getIntegerType(64), 1));
|
||||||
|
}
|
||||||
|
auto dilationArg = ArrayAttr::get(dilation, op.getContext());
|
||||||
|
|
||||||
|
// Set padding only if it is non-zero.
|
||||||
|
DenseIntElementsAttr padding = op.paddingAttr();
|
||||||
|
if (!padding || !llvm::any_of(padding.getValues<APInt>(), [](APInt intVal) {
|
||||||
|
return !intVal.isNullValue();
|
||||||
|
})) {
|
||||||
|
padding = nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// The order of input and filter are switched with linalg.conv.
|
||||||
|
rewriter.replaceOpWithNewOp<linalg::ConvOp>(
|
||||||
|
op, args[1], args[0], args[2], stridesArg, dilationArg, padding);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Base class for lowering xla operations that have one operand and one result,
|
||||||
|
/// and are semantically equivalent to a copy of the input to the output (like
|
||||||
|
/// transpose, some reshape, etc.). The derived classes need to provide a method
|
||||||
|
/// `getIndexingMaps` that returns AffineMaps for the index maps of the input
|
||||||
|
/// and the output.
|
||||||
|
template <typename Derived, typename OpTy, bool isLHLO = true>
|
||||||
|
class DataMovementOpConverter : public OpConversionPattern<OpTy> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern<OpTy>::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
OpTy op, ArrayRef<Value> args,
|
||||||
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
|
if (!verifyXLAOpBufferOrTensorSemantics<isLHLO>(op)) return failure();
|
||||||
|
auto resultType = getXLAOpResultType<isLHLO>(op);
|
||||||
|
|
||||||
|
SmallVector<AffineMap, 2> indexing_maps =
|
||||||
|
Derived::getIndexingMaps(op, &rewriter);
|
||||||
|
if (indexing_maps.empty()) return failure();
|
||||||
|
|
||||||
|
auto nloops = resultType.getRank();
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
auto linalgOp = rewriter.create<linalg::GenericOp>(
|
||||||
|
loc, isLHLO ? ArrayRef<Type>{} : resultType, args, /*inputCount=*/1,
|
||||||
|
/*outputCount=*/1, indexing_maps, GetNParallelLoopsAttrs(nloops),
|
||||||
|
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
|
||||||
|
nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
|
||||||
|
});
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, linalgOp.getOperation()->getResults());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Pattern to convert BroadcastOp to Linalg ops.
|
||||||
|
template <typename OpTy, bool isLHLO = true>
|
||||||
|
class BroadcastConverter
|
||||||
|
: public DataMovementOpConverter<BroadcastConverter<OpTy, isLHLO>, OpTy,
|
||||||
|
isLHLO> {
|
||||||
|
public:
|
||||||
|
using DataMovementOpConverter<BroadcastConverter, OpTy,
|
||||||
|
isLHLO>::DataMovementOpConverter;
|
||||||
|
|
||||||
|
static SmallVector<AffineMap, 2> getIndexingMaps(OpTy broadcastOp,
|
||||||
|
Builder* b) {
|
||||||
|
ShapedType inputType =
|
||||||
|
broadcastOp.operand().getType().template cast<ShapedType>();
|
||||||
|
unsigned inputRank = inputType.getRank();
|
||||||
|
unsigned nloops = getXLAOpResultType<isLHLO>(broadcastOp).getRank();
|
||||||
|
|
||||||
|
// BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute to
|
||||||
|
// the input's dimensions.
|
||||||
|
unsigned numPrependedDims = llvm::size(broadcastOp.broadcast_sizes());
|
||||||
|
SmallVector<AffineExpr, 4> inputDimExprs;
|
||||||
|
inputDimExprs.reserve(inputRank);
|
||||||
|
for (int i = 0; i < inputRank; ++i) {
|
||||||
|
inputDimExprs.push_back(b->getAffineDimExpr(numPrependedDims + i));
|
||||||
|
}
|
||||||
|
|
||||||
|
AffineMap inputMap;
|
||||||
|
MLIRContext* context = b->getContext();
|
||||||
|
if (inputDimExprs.empty()) {
|
||||||
|
// The input is a scalar, i.e. this is a scalar broadcast op.
|
||||||
|
inputMap = AffineMap::get(nloops, /*symbolCount=*/0, context);
|
||||||
|
} else {
|
||||||
|
inputMap =
|
||||||
|
AffineMap::get(nloops, /*symbolCount=*/0, inputDimExprs, context);
|
||||||
|
}
|
||||||
|
return {inputMap, b->getMultiDimIdentityMap(nloops)};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class HloBroadcastInDimConverter
|
||||||
|
: public DataMovementOpConverter<HloBroadcastInDimConverter,
|
||||||
|
xla_hlo::BroadcastInDimOp, false> {
|
||||||
|
public:
|
||||||
|
using DataMovementOpConverter<HloBroadcastInDimConverter,
|
||||||
|
xla_hlo::BroadcastInDimOp,
|
||||||
|
false>::DataMovementOpConverter;
|
||||||
|
|
||||||
|
static SmallVector<AffineMap, 2> getIndexingMaps(
|
||||||
|
xla_hlo::BroadcastInDimOp broadcastOp, Builder* b) {
|
||||||
|
auto resultType = getXLAOpResultType<false>(broadcastOp);
|
||||||
|
auto operandType =
|
||||||
|
broadcastOp.operand().getType().template cast<ShapedType>();
|
||||||
|
unsigned nloops = resultType.getRank();
|
||||||
|
|
||||||
|
// The input is a scalar, i.e. this is a scalar broadcast op.
|
||||||
|
if (operandType.getRank() == 0) {
|
||||||
|
return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
|
||||||
|
b->getMultiDimIdentityMap(nloops)};
|
||||||
|
}
|
||||||
|
|
||||||
|
auto operandShape = operandType.getShape();
|
||||||
|
SmallVector<AffineExpr, 4> dimExprs;
|
||||||
|
dimExprs.reserve(nloops);
|
||||||
|
|
||||||
|
if (broadcastOp.broadcast_dimensions()) {
|
||||||
|
for (const auto& broadcastDim :
|
||||||
|
enumerate(broadcastOp.broadcast_dimensions().getIntValues())) {
|
||||||
|
int size = broadcastDim.value().getSExtValue();
|
||||||
|
bool expansion_needed = operandShape[broadcastDim.index()] == 1 &&
|
||||||
|
resultType.getShape()[size] != 1;
|
||||||
|
dimExprs.push_back(expansion_needed ? b->getAffineConstantExpr(0)
|
||||||
|
: b->getAffineDimExpr(size));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()),
|
||||||
|
b->getMultiDimIdentityMap(nloops)};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class LhloBroadcastInDimConverter
|
||||||
|
: public OpConversionPattern<xla_lhlo::BroadcastInDimOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern<xla_lhlo::BroadcastInDimOp>::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
xla_lhlo::BroadcastInDimOp op, ArrayRef<Value> args,
|
||||||
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
|
xla_lhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
|
||||||
|
auto result_type = operand_adaptor.output().getType().cast<MemRefType>();
|
||||||
|
auto result_shape = result_type.getShape();
|
||||||
|
|
||||||
|
auto operand_and_dims = InsertReshapeIfNecessary(op, args, rewriter);
|
||||||
|
|
||||||
|
Value operand = std::get<0>(operand_and_dims);
|
||||||
|
auto broadcast_dims = std::get<1>(operand_and_dims);
|
||||||
|
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
auto nloops = result_type.getRank();
|
||||||
|
auto operand_type = operand.getType().cast<MemRefType>();
|
||||||
|
|
||||||
|
// For a degenerate case, i.e. broadcasting with expansion of
|
||||||
|
// memref<1xELEMENT_TYPE>, the operand is not passed to `linalg.generic`.
|
||||||
|
// Instead the value is loaded and used directly in `linalg.yield`.
|
||||||
|
if (operand_type.getRank() == 1 &&
|
||||||
|
operand_type.getDimSize(0) <
|
||||||
|
result_type.getDimSize(broadcast_dims.front())) {
|
||||||
|
Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
|
||||||
|
Value val =
|
||||||
|
rewriter.create<LoadOp>(loc, operand, llvm::makeArrayRef({zero}));
|
||||||
|
rewriter.create<linalg::GenericOp>(
|
||||||
|
loc, llvm::None, llvm::makeArrayRef(operand_adaptor.output()),
|
||||||
|
/*inputCount=*/0, /*outputCount=*/1,
|
||||||
|
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
|
||||||
|
GetNParallelLoopsAttrs(nloops),
|
||||||
|
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
|
||||||
|
nestedBuilder.create<linalg::YieldOp>(loc, val);
|
||||||
|
});
|
||||||
|
|
||||||
|
} else {
|
||||||
|
auto indexing_maps = getIndexingMaps(op, broadcast_dims, result_shape,
|
||||||
|
operand_type, &rewriter);
|
||||||
|
rewriter.create<linalg::GenericOp>(
|
||||||
|
loc, llvm::None,
|
||||||
|
llvm::makeArrayRef({operand, operand_adaptor.output()}),
|
||||||
|
/*inputCount=*/1, /*outputCount=*/1, indexing_maps,
|
||||||
|
GetNParallelLoopsAttrs(nloops),
|
||||||
|
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
|
||||||
|
nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
rewriter.replaceOp(op, llvm::None);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inserts 'linalg.reshape' if there is a size-1 dim expansion.
|
||||||
|
std::pair<Value, SmallVector<int64_t, 2>> InsertReshapeIfNecessary(
|
||||||
|
xla_lhlo::BroadcastInDimOp op, ArrayRef<Value> args,
|
||||||
|
ConversionPatternRewriter& rewriter) const {
|
||||||
|
xla_lhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
|
||||||
|
Value operand = operand_adaptor.operand();
|
||||||
|
auto operand_type = operand_adaptor.operand().getType().cast<MemRefType>();
|
||||||
|
auto operand_shape = operand_type.getShape();
|
||||||
|
|
||||||
|
Value result = operand_adaptor.output();
|
||||||
|
auto result_type = result.getType().cast<MemRefType>();
|
||||||
|
auto result_shape = result_type.getShape();
|
||||||
|
|
||||||
|
SmallVector<int64_t, 2> operand_strides;
|
||||||
|
int64_t operand_offset;
|
||||||
|
if (failed(getStridesAndOffset(operand_type, operand_strides,
|
||||||
|
operand_offset))) {
|
||||||
|
op.emitOpError() << "Failed to get offset and strides.";
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t, 2> new_shape, new_strides, broadcast_dims;
|
||||||
|
SmallVector<linalg::ReassociationIndices, 4> collapsed_dims_list;
|
||||||
|
linalg::ReassociationIndices collapsed_dims;
|
||||||
|
for (const auto& item :
|
||||||
|
enumerate(op.broadcast_dimensions().getIntValues())) {
|
||||||
|
size_t index = item.index();
|
||||||
|
int dim = item.value().getSExtValue();
|
||||||
|
|
||||||
|
collapsed_dims.push_back(index);
|
||||||
|
|
||||||
|
bool expansion_needed =
|
||||||
|
operand_shape[index] == 1 && result_shape[dim] != 1;
|
||||||
|
if (expansion_needed) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
new_shape.push_back(operand_shape[index]);
|
||||||
|
new_strides.push_back(operand_strides[index]);
|
||||||
|
broadcast_dims.push_back(dim);
|
||||||
|
|
||||||
|
collapsed_dims_list.push_back(collapsed_dims);
|
||||||
|
collapsed_dims.clear();
|
||||||
|
}
|
||||||
|
// If `collapsed_dims_list` is empty, then the memref has shape [1, ..., 1]
|
||||||
|
// and all dimensions need expansion. Such memref will be reshaped to a 1D
|
||||||
|
// memref with a single element. New shape and strides needs to be updated
|
||||||
|
// accordingly.
|
||||||
|
if (collapsed_dims_list.empty()) {
|
||||||
|
collapsed_dims_list.push_back({});
|
||||||
|
new_shape.push_back(1);
|
||||||
|
new_strides.push_back(1);
|
||||||
|
broadcast_dims.push_back(0);
|
||||||
|
}
|
||||||
|
for (const auto& dims : collapsed_dims) {
|
||||||
|
collapsed_dims_list.back().push_back(dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
// `linalg.reshape` is inserted only if necessary, i.e. when the rank can be
|
||||||
|
// reduced.
|
||||||
|
if (new_shape.size() < operand_shape.size()) {
|
||||||
|
auto new_memref_type = MemRefType::get(
|
||||||
|
new_shape, operand_type.getElementType(),
|
||||||
|
makeStridedLinearLayoutMap(new_strides, operand_offset,
|
||||||
|
rewriter.getContext()));
|
||||||
|
operand = rewriter.create<linalg::ReshapeOp>(op.getLoc(), new_memref_type,
|
||||||
|
operand_adaptor.operand(),
|
||||||
|
collapsed_dims_list);
|
||||||
|
}
|
||||||
|
return std::make_pair(operand, broadcast_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<AffineMap, 2> getIndexingMaps(xla_lhlo::BroadcastInDimOp op,
|
||||||
|
ArrayRef<int64_t> broadcastDims,
|
||||||
|
ArrayRef<int64_t> resultShape,
|
||||||
|
MemRefType operandType,
|
||||||
|
Builder* b) const {
|
||||||
|
unsigned nloops = resultShape.size();
|
||||||
|
|
||||||
|
// The input is a scalar, i.e. this is a scalar broadcast op.
|
||||||
|
if (operandType.getRank() == 0) {
|
||||||
|
return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
|
||||||
|
b->getMultiDimIdentityMap(nloops)};
|
||||||
|
}
|
||||||
|
|
||||||
|
auto operandShape = operandType.getShape();
|
||||||
|
SmallVector<AffineExpr, 4> dimExprs;
|
||||||
|
dimExprs.reserve(nloops);
|
||||||
|
|
||||||
|
for (const auto& broadcastDim : llvm::enumerate(broadcastDims)) {
|
||||||
|
int size = broadcastDim.value();
|
||||||
|
bool expansion_needed =
|
||||||
|
operandShape[broadcastDim.index()] == 1 && resultShape[size] != 1;
|
||||||
|
if (expansion_needed) {
|
||||||
|
op.emitOpError(
|
||||||
|
"BroadcastInDimOp lowering to Linalg does not support size-1 "
|
||||||
|
"dimensions expansion.");
|
||||||
|
}
|
||||||
|
dimExprs.push_back(b->getAffineDimExpr(size));
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()),
|
||||||
|
b->getMultiDimIdentityMap(nloops)};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename OpTy, bool isLHLO = true>
|
||||||
|
class TransposeConverter
|
||||||
|
: public DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
|
||||||
|
isLHLO> {
|
||||||
|
public:
|
||||||
|
using DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
|
||||||
|
isLHLO>::DataMovementOpConverter;
|
||||||
|
static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
|
||||||
|
auto resultType =
|
||||||
|
getXLAOpResultType<isLHLO>(op).template cast<ShapedType>();
|
||||||
|
auto nloops = resultType.getRank();
|
||||||
|
SmallVector<AffineExpr, 2> inputExprs;
|
||||||
|
inputExprs.resize(resultType.getRank());
|
||||||
|
for (auto permutation : llvm::enumerate(op.permutation())) {
|
||||||
|
inputExprs[permutation.value().getZExtValue()] =
|
||||||
|
b->getAffineDimExpr(permutation.index());
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
|
||||||
|
b->getMultiDimIdentityMap(nloops)};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Converts reshape ops that can be proven to be either a collapse of dimensions
|
||||||
|
// or expansion of dimensions of the operand.
|
||||||
|
template <typename OpTy, bool isLHLO = true>
|
||||||
|
class ReshapeOpConverter : public OpConversionPattern<OpTy> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern<OpTy>::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
OpTy reshapeOp, ArrayRef<Value> args,
|
||||||
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
|
if (!verifyXLAOpBufferOrTensorSemantics<isLHLO>(reshapeOp))
|
||||||
|
return failure();
|
||||||
|
ShapedType operandType =
|
||||||
|
reshapeOp.operand().getType().template cast<ShapedType>();
|
||||||
|
ShapedType resultType = getXLAOpResultType<isLHLO>(reshapeOp);
|
||||||
|
|
||||||
|
if (!operandType.hasStaticShape() || !resultType.hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Compute the reassociation maps for the linalg operation.
|
||||||
|
ArrayRef<int64_t> srcShape =
|
||||||
|
(operandType.getRank() > resultType.getRank() ? operandType.getShape()
|
||||||
|
: resultType.getShape());
|
||||||
|
ArrayRef<int64_t> dstShape =
|
||||||
|
(operandType.getRank() > resultType.getRank() ? resultType.getShape()
|
||||||
|
: operandType.getShape());
|
||||||
|
unsigned currSrcDim = 0, currDstDim = 0;
|
||||||
|
SmallVector<linalg::ReassociationExprs, 4> reassociationMap(
|
||||||
|
dstShape.size());
|
||||||
|
while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
|
||||||
|
int64_t dstSize = dstShape[currDstDim];
|
||||||
|
int64_t srcSize = srcShape[currSrcDim];
|
||||||
|
while (srcSize < dstSize && currSrcDim < srcShape.size()) {
|
||||||
|
reassociationMap[currDstDim].push_back(
|
||||||
|
rewriter.getAffineDimExpr(currSrcDim++));
|
||||||
|
srcSize *= srcShape[currSrcDim];
|
||||||
|
}
|
||||||
|
if (srcSize == dstSize) {
|
||||||
|
reassociationMap[currDstDim].push_back(
|
||||||
|
rewriter.getAffineDimExpr(currSrcDim++));
|
||||||
|
// If the next dim in dstShape is not 1, treat subsequent dims in
|
||||||
|
// srcShape which are 1 to be collapsed.
|
||||||
|
if (currDstDim == dstShape.size() - 1 ||
|
||||||
|
dstShape[currDstDim + 1] != 1) {
|
||||||
|
while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
|
||||||
|
reassociationMap[currDstDim].push_back(
|
||||||
|
rewriter.getAffineDimExpr(currSrcDim++));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
currDstDim++;
|
||||||
|
}
|
||||||
|
if (currSrcDim != srcShape.size()) return failure();
|
||||||
|
|
||||||
|
if (isLHLO) {
|
||||||
|
Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>(
|
||||||
|
reshapeOp.getLoc(), resultType, args[0], reassociationMap);
|
||||||
|
rewriter.replaceOpWithNewOp<linalg::CopyOp>(
|
||||||
|
reshapeOp, reshapeBuffer, args[1], /*inputPermutation =*/nullptr,
|
||||||
|
/*outputPermutation =*/nullptr);
|
||||||
|
} else {
|
||||||
|
rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
|
||||||
|
reshapeOp, resultType, args[0], reassociationMap);
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class IotaConverter : public OpConversionPattern<xla_lhlo::IotaOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern<xla_lhlo::IotaOp>::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
xla_lhlo::IotaOp iotaOp, ArrayRef<Value> args,
|
||||||
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
|
auto resultMemrefType =
|
||||||
|
iotaOp.getOperand().getType().dyn_cast<MemRefType>();
|
||||||
|
if (!resultMemrefType) return failure();
|
||||||
|
|
||||||
|
auto resultElementType = resultMemrefType.getElementType();
|
||||||
|
if (!resultElementType.isSignlessIntOrFloat()) return failure();
|
||||||
|
|
||||||
|
// Construct the indexing maps needed for linalg.generic ops.
|
||||||
|
unsigned nloops = resultMemrefType.getRank();
|
||||||
|
|
||||||
|
rewriter.create<linalg::IndexedGenericOp>(
|
||||||
|
iotaOp.getLoc(), ArrayRef<Type>{}, args,
|
||||||
|
0, // args_in
|
||||||
|
1, // args_out
|
||||||
|
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
|
||||||
|
GetNParallelLoopsAttrs(nloops),
|
||||||
|
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange ivs,
|
||||||
|
ValueRange args) {
|
||||||
|
Value castOp = nestedBuilder.create<IndexCastOp>(
|
||||||
|
nestedLoc, ivs[iotaOp.iota_dimension().getZExtValue()],
|
||||||
|
nestedBuilder.getIntegerType(
|
||||||
|
resultElementType.getIntOrFloatBitWidth()));
|
||||||
|
if (resultElementType.isa<FloatType>()) {
|
||||||
|
castOp = nestedBuilder.create<SIToFPOp>(nestedLoc, castOp,
|
||||||
|
resultElementType);
|
||||||
|
}
|
||||||
|
nestedBuilder.create<linalg::YieldOp>(nestedLoc, castOp);
|
||||||
|
});
|
||||||
|
|
||||||
|
rewriter.replaceOp(iotaOp, llvm::None);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class ConstConverter : public OpConversionPattern<xla_lhlo::ConstOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern<xla_lhlo::ConstOp>::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
xla_lhlo::ConstOp constOp, ArrayRef<Value> args,
|
||||||
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
|
auto loc = constOp.getLoc();
|
||||||
|
auto valueAttr = constOp.value().cast<DenseElementsAttr>();
|
||||||
|
if (valueAttr.getType().getRank() != 0) return failure();
|
||||||
|
auto stdConstOp =
|
||||||
|
rewriter.create<mlir::ConstantOp>(loc, valueAttr.getValue({}));
|
||||||
|
rewriter.create<mlir::StoreOp>(loc, stdConstOp, constOp.getOperand());
|
||||||
|
rewriter.eraseOp(constOp);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// TODO(b/156787842): Support the lowering for dynamic shapes.
|
||||||
|
template <typename OpTy, bool isLHLO = true>
|
||||||
|
class ReverseConverter
|
||||||
|
: public DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy,
|
||||||
|
isLHLO> {
|
||||||
|
public:
|
||||||
|
using DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy,
|
||||||
|
isLHLO>::DataMovementOpConverter;
|
||||||
|
static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
|
||||||
|
auto resultType =
|
||||||
|
getXLAOpResultType<isLHLO>(op).template cast<ShapedType>();
|
||||||
|
auto nloops = resultType.getRank();
|
||||||
|
SmallVector<AffineExpr, 2> inputExprs;
|
||||||
|
inputExprs.reserve(nloops);
|
||||||
|
for (int i = 0; i < nloops; ++i)
|
||||||
|
inputExprs.push_back(b->getAffineDimExpr(i));
|
||||||
|
for (auto dim : op.dimensions()) {
|
||||||
|
int i = dim.getZExtValue();
|
||||||
|
if (resultType.isDynamicDim(i)) return {};
|
||||||
|
int n = resultType.getShape()[i];
|
||||||
|
inputExprs[i] = b->getAffineConstantExpr(n - 1) - inputExprs[i];
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
|
||||||
|
b->getMultiDimIdentityMap(nloops)};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class SliceConverter : public OpConversionPattern<xla_lhlo::SliceOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern<xla_lhlo::SliceOp>::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(
|
||||||
|
xla_lhlo::SliceOp sliceOp, ArrayRef<Value> args,
|
||||||
|
ConversionPatternRewriter& rewriter) const final {
|
||||||
|
auto loc = sliceOp.getLoc();
|
||||||
|
auto argType =
|
||||||
|
sliceOp.getOperand(0).getType().template dyn_cast<ShapedType>();
|
||||||
|
if (!argType || !argType.hasRank()) {
|
||||||
|
emitError(loc, "lhlo to linalg conversion expects known-rank args");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value, 3> ranges;
|
||||||
|
for (int i = 0, e = argType.getRank(); i < e; ++i) {
|
||||||
|
Value start_index = rewriter.create<ConstantIndexOp>(
|
||||||
|
loc, sliceOp.start_indices().getValue<int64_t>(i));
|
||||||
|
Value limit_index = rewriter.create<ConstantIndexOp>(
|
||||||
|
loc, sliceOp.limit_indices().getValue<int64_t>(i));
|
||||||
|
Value stride = rewriter.create<ConstantIndexOp>(
|
||||||
|
loc, sliceOp.strides().getValue<int64_t>(i));
|
||||||
|
ranges.push_back(rewriter.create<linalg::RangeOp>(loc, start_index,
|
||||||
|
limit_index, stride));
|
||||||
|
}
|
||||||
|
auto linalg_slice =
|
||||||
|
rewriter.create<linalg::SliceOp>(loc, sliceOp.getOperand(0), ranges);
|
||||||
|
rewriter.create<linalg::CopyOp>(loc, linalg_slice, sliceOp.getOperand(1));
|
||||||
|
rewriter.eraseOp(sliceOp);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
|
OwningRewritePatternList* patterns) {
|
||||||
|
// clang-format off
|
||||||
|
patterns->insert<BroadcastConverter<xla_lhlo::BroadcastOp>,
|
||||||
|
ConstConverter,
|
||||||
|
ConvToLinalgConverter,
|
||||||
|
IotaConverter,
|
||||||
|
LhloBroadcastInDimConverter,
|
||||||
|
PointwiseToLinalgConverter<xla_lhlo::AbsOp>,
|
||||||
|
PointwiseToLinalgConverter<xla_lhlo::AddOp>,
|
||||||
|
PointwiseToLinalgConverter<xla_lhlo::AndOp>,
|
||||||
|
PointwiseToLinalgConverter<xla_lhlo::CeilOp>,
|
||||||
|
PointwiseToLinalgConverter<xla_lhlo::CompareOp>,
|
||||||
|
PointwiseToLinalgConverter<xla_lhlo::ComplexOp>,
|
||||||
|
PointwiseToLinalgConverter<xla_lhlo::ConvertOp>,
|
||||||
|
// TODO(ataei): Remove this pattern, CopyOp is folded away.
|
||||||
|
PointwiseToLinalgConverter<xla_lhlo::CopyOp>,
|
||||||
|
PointwiseToLinalgConverter<xla_lhlo::CosOp>,
|
||||||
|
PointwiseToLinalgConverter<xla_lhlo::DivOp>,
|
||||||
|
PointwiseToLinalgConverter<xla_lhlo::ExpOp>,
|
||||||
|
PointwiseToLinalgConverter<xla_lhlo::ImagOp>,
|
||||||
|
PointwiseToLinalgConverter<xla_lhlo::LogOp>,
|
||||||
|
PointwiseToLinalgConverter<xla_lhlo::MaxOp>,
|
||||||
|
PointwiseToLinalgConverter<xla_lhlo::MinOp>,
|
||||||
|
PointwiseToLinalgConverter<xla_lhlo::MulOp>,
|
||||||
|
PointwiseToLinalgConverter<xla_lhlo::NegOp>,
|
||||||
|
PointwiseToLinalgConverter<xla_lhlo::RealOp>,
|
||||||
|
PointwiseToLinalgConverter<xla_lhlo::RemOp>,
|
||||||
|
PointwiseToLinalgConverter<xla_lhlo::RsqrtOp>,
|
||||||
|
PointwiseToLinalgConverter<xla_lhlo::SelectOp>,
|
||||||
|
PointwiseToLinalgConverter<xla_lhlo::SignOp>,
|
||||||
|
PointwiseToLinalgConverter<xla_lhlo::SinOp>,
|
||||||
|
PointwiseToLinalgConverter<xla_lhlo::SqrtOp>,
|
||||||
|
PointwiseToLinalgConverter<xla_lhlo::SubOp>,
|
||||||
|
PointwiseToLinalgConverter<xla_lhlo::TanhOp>,
|
||||||
|
ReshapeOpConverter<xla_lhlo::ReshapeOp>,
|
||||||
|
ReverseConverter<xla_lhlo::ReverseOp>,
|
||||||
|
ScalarPointwiseToStandardConverter<xla_lhlo::AddOp>,
|
||||||
|
SliceConverter
|
||||||
|
>(context);
|
||||||
|
// clang-format on
|
||||||
|
}
|
||||||
|
|
||||||
|
// Converts LHLO ops to Linalg generic.
|
||||||
|
// Sample result for xla_lhlo::AddOp.
|
||||||
|
//
|
||||||
|
// "xla_lhlo.add"(%arg1, %arg2, %out) :
|
||||||
|
// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||||
|
//
|
||||||
|
// will be converted to
|
||||||
|
//
|
||||||
|
// #map0 = (d0, d1) -> (d0, d1)
|
||||||
|
// "linalg.generic"(%arg1, %arg2, %out) ( {
|
||||||
|
// ^bb0(%arg4: f32, %arg5: f32):
|
||||||
|
// %0 = addf %arg4, %arg5 : f32
|
||||||
|
// "linalg.yield"(%0) : (f32) -> ()
|
||||||
|
// }) {
|
||||||
|
// args_in = 2,
|
||||||
|
// args_out = 1,
|
||||||
|
// indexing_maps = [#map0, #map0, #map0],
|
||||||
|
// iterator_types = ["parallel", "parallel"],
|
||||||
|
// } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||||
|
struct LhloLegalizeToLinalg
|
||||||
|
: public PassWrapper<LhloLegalizeToLinalg, FunctionPass> {
|
||||||
|
void runOnFunction() override {
|
||||||
|
OwningRewritePatternList patterns;
|
||||||
|
ConversionTarget target(getContext());
|
||||||
|
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();
|
||||||
|
|
||||||
|
auto func = getFunction();
|
||||||
|
populateLHLOToLinalgConversionPattern(func.getContext(), &patterns);
|
||||||
|
if (failed(applyPartialConversion(func, target, patterns, nullptr))) {
|
||||||
|
signalPassFailure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct HloLegalizeToLinalg
|
||||||
|
: public PassWrapper<HloLegalizeToLinalg, FunctionPass> {
|
||||||
|
void runOnFunction() override {
|
||||||
|
OwningRewritePatternList patterns;
|
||||||
|
ConversionTarget target(getContext());
|
||||||
|
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();
|
||||||
|
|
||||||
|
auto func = getFunction();
|
||||||
|
xla_hlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
|
||||||
|
if (failed(applyPartialConversion(func, target, patterns, nullptr))) {
|
||||||
|
signalPassFailure();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace xla_lhlo {
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToLinalgPass() {
|
||||||
|
return absl::make_unique<LhloLegalizeToLinalg>();
|
||||||
|
}
|
||||||
|
|
||||||
|
static PassRegistration<LhloLegalizeToLinalg> legalize_lhlo_pass(
|
||||||
|
"lhlo-legalize-to-linalg", "Legalize from LHLO dialect to Linalg dialect");
|
||||||
|
} // namespace xla_lhlo
|
||||||
|
|
||||||
|
namespace xla_hlo {
|
||||||
|
|
||||||
|
void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
|
OwningRewritePatternList* patterns) {
|
||||||
|
patterns->insert<BroadcastConverter<xla_hlo::BroadcastOp, false>,
|
||||||
|
HloBroadcastInDimConverter,
|
||||||
|
PointwiseToLinalgConverter<xla_hlo::AbsOp, false>,
|
||||||
|
PointwiseToLinalgConverter<xla_hlo::AddOp, false>,
|
||||||
|
PointwiseToLinalgConverter<xla_hlo::AndOp, false>,
|
||||||
|
PointwiseToLinalgConverter<xla_hlo::CeilOp, false>,
|
||||||
|
PointwiseToLinalgConverter<xla_hlo::CompareOp, false>,
|
||||||
|
PointwiseToLinalgConverter<xla_hlo::ComplexOp, false>,
|
||||||
|
PointwiseToLinalgConverter<xla_hlo::ConvertOp, false>,
|
||||||
|
PointwiseToLinalgConverter<xla_hlo::CopyOp, false>,
|
||||||
|
PointwiseToLinalgConverter<xla_hlo::CosOp, false>,
|
||||||
|
PointwiseToLinalgConverter<xla_hlo::DivOp, false>,
|
||||||
|
PointwiseToLinalgConverter<xla_hlo::ExpOp, false>,
|
||||||
|
PointwiseToLinalgConverter<xla_hlo::ImagOp, false>,
|
||||||
|
PointwiseToLinalgConverter<xla_hlo::LogOp, false>,
|
||||||
|
PointwiseToLinalgConverter<xla_hlo::MaxOp, false>,
|
||||||
|
PointwiseToLinalgConverter<xla_hlo::MinOp, false>,
|
||||||
|
PointwiseToLinalgConverter<xla_hlo::MulOp, false>,
|
||||||
|
PointwiseToLinalgConverter<xla_hlo::NegOp, false>,
|
||||||
|
PointwiseToLinalgConverter<xla_hlo::RealOp, false>,
|
||||||
|
PointwiseToLinalgConverter<xla_hlo::RemOp, false>,
|
||||||
|
PointwiseToLinalgConverter<xla_hlo::RsqrtOp, false>,
|
||||||
|
PointwiseToLinalgConverter<xla_hlo::SelectOp, false>,
|
||||||
|
PointwiseToLinalgConverter<xla_hlo::SinOp, false>,
|
||||||
|
PointwiseToLinalgConverter<xla_hlo::SqrtOp, false>,
|
||||||
|
PointwiseToLinalgConverter<xla_hlo::SubOp, false>,
|
||||||
|
PointwiseToLinalgConverter<xla_hlo::TanhOp, false>,
|
||||||
|
ReshapeOpConverter<xla_hlo::ReshapeOp, false>,
|
||||||
|
ReverseConverter<xla_hlo::ReverseOp, false>,
|
||||||
|
TransposeConverter<xla_hlo::TransposeOp, false>>(context);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
|
||||||
|
return absl::make_unique<HloLegalizeToLinalg>();
|
||||||
|
}
|
||||||
|
|
||||||
|
static PassRegistration<HloLegalizeToLinalg> legalize_hlo_pass(
|
||||||
|
"hlo-legalize-to-linalg", "Legalize from HLO dialect to Linalg dialect");
|
||||||
|
} // namespace xla_hlo
|
||||||
|
} // namespace mlir
|
|
@ -0,0 +1,188 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "third_party/absl/memory/memory.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
|
||||||
|
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace xla_hlo {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// TODO(frgossen): Make it variadic.
|
||||||
|
template <typename OpTy>
|
||||||
|
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
|
||||||
|
target->addDynamicallyLegalOp<OpTy>([](OpTy op) {
|
||||||
|
return llvm::all_of((op.getOperation())->getOperandTypes(),
|
||||||
|
[&](Type t) { return t.isa<RankedTensorType>(); });
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Unary element-wise operations on unranked tensors can be applied to the
|
||||||
|
/// flattened tensor with the same effect.
|
||||||
|
/// This pattern rewrites every such operation to
|
||||||
|
/// (i) flatten the input tensor,
|
||||||
|
/// (ii) apply the unary operation, and
|
||||||
|
/// (iii) restore the original shape.
|
||||||
|
template <typename OpTy>
|
||||||
|
struct UnaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
|
||||||
|
explicit UnaryElementwiseOpConversion(MLIRContext *context)
|
||||||
|
: OpRewritePattern<OpTy>(context) {}
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(OpTy op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
// Don't apply conversion to ops with statically shaped operands.
|
||||||
|
Value operand = op.getOperand();
|
||||||
|
auto operandTy = operand.getType().dyn_cast<TensorType>();
|
||||||
|
if (operandTy.hasRank()) return failure();
|
||||||
|
|
||||||
|
// Generate IR to flatten the operand.
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
Value shape = rewriter.create<shape::ShapeOfOp>(loc, operand);
|
||||||
|
Value numElements = rewriter.create<shape::NumElementsOp>(
|
||||||
|
loc, rewriter.getType<shape::SizeType>(), shape);
|
||||||
|
Value numElementsAsIndex = rewriter.create<shape::SizeToIndexOp>(
|
||||||
|
loc, rewriter.getIndexType(), numElements);
|
||||||
|
Value flatShapeAsDimTensor =
|
||||||
|
rewriter.create<TensorFromElementsOp>(loc, numElementsAsIndex);
|
||||||
|
auto flatTensorTy = RankedTensorType::get({ShapedType::kDynamicSize},
|
||||||
|
operandTy.getElementType());
|
||||||
|
Value flatOperand = rewriter.create<xla_hlo::DynamicReshapeOp>(
|
||||||
|
loc, flatTensorTy, operand, flatShapeAsDimTensor);
|
||||||
|
|
||||||
|
// Generate IR for the actual operation.
|
||||||
|
Value flatResult = rewriter.create<OpTy>(loc, flatTensorTy, flatOperand);
|
||||||
|
|
||||||
|
// Generate IR to restore the original shape.
|
||||||
|
auto extentTensorTy = RankedTensorType::get({ShapedType::kDynamicSize},
|
||||||
|
rewriter.getIndexType());
|
||||||
|
Value shapeAsExtentTensor =
|
||||||
|
rewriter.create<shape::ToExtentTensorOp>(loc, extentTensorTy, shape);
|
||||||
|
Value result = rewriter.create<xla_hlo::DynamicReshapeOp>(
|
||||||
|
loc, operandTy, flatResult, shapeAsExtentTensor);
|
||||||
|
rewriter.replaceOp(op, result);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Binary element-wise operation on unranked tensors can be applied to the
|
||||||
|
/// flattened operand tensors with the same effect.
|
||||||
|
/// This pattern rewrites every such operation to
|
||||||
|
/// (i) flatten the operand tensors,
|
||||||
|
/// (ii) apply the binary operation, and
|
||||||
|
// (iii) restore the original shape.
|
||||||
|
template <typename OpTy>
|
||||||
|
struct BinaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
|
||||||
|
explicit BinaryElementwiseOpConversion(MLIRContext *context)
|
||||||
|
: OpRewritePattern<OpTy>(context) {}
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(OpTy op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
// Don't apply conversion unless both operands are unranked.
|
||||||
|
if (op.lhs().getType().template isa<RankedTensorType>() ||
|
||||||
|
op.rhs().getType().template isa<RankedTensorType>()) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flatten operands.
|
||||||
|
Type shapeTy = shape::ShapeType::get(rewriter.getContext());
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
Value shapeLhs = rewriter.create<shape::ShapeOfOp>(loc, op.lhs());
|
||||||
|
Value shapeRhs = rewriter.create<shape::ShapeOfOp>(loc, op.rhs());
|
||||||
|
Value shape = rewriter.create<shape::AnyOp>(loc, shapeTy,
|
||||||
|
ValueRange{shapeLhs, shapeRhs});
|
||||||
|
Value numElements = rewriter.create<shape::NumElementsOp>(loc, shape);
|
||||||
|
Value numElementsAsIndex =
|
||||||
|
rewriter.create<shape::SizeToIndexOp>(loc, numElements);
|
||||||
|
Value flatShape =
|
||||||
|
rewriter.create<TensorFromElementsOp>(loc, numElementsAsIndex);
|
||||||
|
TensorType lhsTy = op.lhs().getType().template cast<TensorType>();
|
||||||
|
Type flatLhsTy = RankedTensorType::get({ShapedType::kDynamicSize},
|
||||||
|
lhsTy.getElementType());
|
||||||
|
Value flatLhs =
|
||||||
|
rewriter.create<DynamicReshapeOp>(loc, flatLhsTy, op.lhs(), flatShape);
|
||||||
|
TensorType rhsTy = op.rhs().getType().template cast<TensorType>();
|
||||||
|
Type flatRhsTy = RankedTensorType::get({ShapedType::kDynamicSize},
|
||||||
|
rhsTy.getElementType());
|
||||||
|
Value flatRhs =
|
||||||
|
rewriter.create<DynamicReshapeOp>(loc, flatRhsTy, op.rhs(), flatShape);
|
||||||
|
|
||||||
|
// Apply actual operation to flattened operands.
|
||||||
|
Value flatResult = rewriter.create<OpTy>(loc, flatLhs, flatRhs);
|
||||||
|
|
||||||
|
// Restore original shape.
|
||||||
|
auto extentTensorTy = RankedTensorType::get({ShapedType::kDynamicSize},
|
||||||
|
rewriter.getIndexType());
|
||||||
|
Value shapeAsExtentTensor =
|
||||||
|
rewriter.create<shape::ToExtentTensorOp>(loc, extentTensorTy, shape);
|
||||||
|
Value result = rewriter.create<DynamicReshapeOp>(
|
||||||
|
loc, op.getType(), flatResult, shapeAsExtentTensor);
|
||||||
|
rewriter.replaceOp(op, result);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TransformUnrankedHloPass
|
||||||
|
: public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
|
||||||
|
void runOnFunction() override {
|
||||||
|
// Setup conversion target.
|
||||||
|
MLIRContext &ctx = getContext();
|
||||||
|
ConversionTarget target(ctx);
|
||||||
|
target.addLegalDialect<XlaHloDialect, StandardOpsDialect,
|
||||||
|
shape::ShapeDialect>();
|
||||||
|
target.addLegalOp<FuncOp>();
|
||||||
|
AddLegalOpOnRankedTensor<SqrtOp>(&target);
|
||||||
|
AddLegalOpOnRankedTensor<AddOp>(&target);
|
||||||
|
|
||||||
|
// Populate rewrite patterns.
|
||||||
|
OwningRewritePatternList patterns;
|
||||||
|
PopulateTransformUnrankedHloPatterns(&ctx, &patterns);
|
||||||
|
|
||||||
|
// Apply transformation.
|
||||||
|
if (failed(applyFullConversion(getFunction(), target, patterns)))
|
||||||
|
return signalPassFailure();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
|
||||||
|
OwningRewritePatternList *patterns) {
|
||||||
|
// TODO(frgossen): Populate all unary and binary operations.
|
||||||
|
// clang-format off
|
||||||
|
patterns->insert<
|
||||||
|
BinaryElementwiseOpConversion<AddOp>,
|
||||||
|
UnaryElementwiseOpConversion<SqrtOp>>(context);
|
||||||
|
// clang-format on
|
||||||
|
}
|
||||||
|
|
||||||
|
static PassRegistration<TransformUnrankedHloPass> transform_unranked_hlo_pass(
|
||||||
|
"transform-unranked-hlo",
|
||||||
|
"Realize element-wise operations on ranked tensors where possible");
|
||||||
|
|
||||||
|
} // namespace xla_hlo
|
||||||
|
} // namespace mlir
|
|
@ -0,0 +1,340 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/cycle_detector.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/DenseSet.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
using NodeSet = llvm::DenseSet<int32_t>;
|
||||||
|
using OrderedNodeSet = OrderedSet<int32_t>;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct VecStruct {
|
||||||
|
typedef llvm::SmallVector<T, 4> type;
|
||||||
|
};
|
||||||
|
template <typename T>
|
||||||
|
using Vec = typename VecStruct<T>::type;
|
||||||
|
|
||||||
|
struct Node {
|
||||||
|
// rank number assigned by Pearce-Kelly algorithm
|
||||||
|
int32_t rank;
|
||||||
|
// Temporary marker used by depth-first-search
|
||||||
|
bool visited;
|
||||||
|
// User-supplied data
|
||||||
|
void* data;
|
||||||
|
// List of immediate predecessor nodes in graph
|
||||||
|
OrderedNodeSet in;
|
||||||
|
// List of immediate successor nodes in graph
|
||||||
|
OrderedNodeSet out;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
struct GraphCycles::Rep {
|
||||||
|
Vec<Node*> nodes;
|
||||||
|
// Indices for unused entries in nodes
|
||||||
|
Vec<int32_t> free_nodes;
|
||||||
|
|
||||||
|
// Temporary state.
|
||||||
|
// Results of forward DFS
|
||||||
|
Vec<int32_t> deltaf;
|
||||||
|
// Results of backward DFS
|
||||||
|
Vec<int32_t> deltab;
|
||||||
|
// All nodes to reprocess
|
||||||
|
Vec<int32_t> list;
|
||||||
|
// Rank values to assign to list entries
|
||||||
|
Vec<int32_t> merged;
|
||||||
|
// Emulates recursion stack when doing depth first search
|
||||||
|
Vec<int32_t> stack;
|
||||||
|
};
|
||||||
|
|
||||||
|
GraphCycles::GraphCycles(int32_t num_nodes) : rep_(new Rep) {
|
||||||
|
rep_->nodes.reserve(num_nodes);
|
||||||
|
for (int32_t i = 0; i < num_nodes; ++i) {
|
||||||
|
Node* n = new Node;
|
||||||
|
n->visited = false;
|
||||||
|
n->data = nullptr;
|
||||||
|
n->rank = rep_->nodes.size();
|
||||||
|
rep_->nodes.push_back(n);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
GraphCycles::~GraphCycles() {
|
||||||
|
for (Vec<Node*>::size_type i = 0, e = rep_->nodes.size(); i < e; ++i) {
|
||||||
|
delete rep_->nodes[i];
|
||||||
|
}
|
||||||
|
delete rep_;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GraphCycles::HasEdge(int32_t x, int32_t y) const {
|
||||||
|
return rep_->nodes[x]->out.Contains(y);
|
||||||
|
}
|
||||||
|
|
||||||
|
void GraphCycles::RemoveEdge(int32_t x, int32_t y) {
|
||||||
|
rep_->nodes[x]->out.Erase(y);
|
||||||
|
rep_->nodes[y]->in.Erase(x);
|
||||||
|
// No need to update the rank assignment since a previous valid
|
||||||
|
// rank assignment remains valid after an edge deletion.
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool ForwardDFS(GraphCycles::Rep* r, int32_t n, int32_t upper_bound);
|
||||||
|
static void BackwardDFS(GraphCycles::Rep* r, int32_t n, int32_t lower_bound);
|
||||||
|
static void Reorder(GraphCycles::Rep* r);
|
||||||
|
static void Sort(const Vec<Node*>&, Vec<int32_t>* delta);
|
||||||
|
static void MoveToList(GraphCycles::Rep* r, Vec<int32_t>* src,
|
||||||
|
Vec<int32_t>* dst);
|
||||||
|
static void ClearVisitedBits(GraphCycles::Rep* r, const Vec<int32_t>& nodes);
|
||||||
|
|
||||||
|
bool GraphCycles::InsertEdge(int32_t x, int32_t y) {
|
||||||
|
if (x == y) return false;
|
||||||
|
Rep* r = rep_;
|
||||||
|
Node* nx = r->nodes[x];
|
||||||
|
if (!nx->out.Insert(y)) {
|
||||||
|
// Edge already exists.
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
Node* ny = r->nodes[y];
|
||||||
|
ny->in.Insert(x);
|
||||||
|
|
||||||
|
if (nx->rank <= ny->rank) {
|
||||||
|
// New edge is consistent with existing rank assignment.
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Current rank assignments are incompatible with the new edge. Recompute.
|
||||||
|
// We only need to consider nodes that fall in the range [ny->rank,nx->rank].
|
||||||
|
if (ForwardDFS(r, y, nx->rank)) {
|
||||||
|
// Found a cycle. Undo the insertion and tell caller.
|
||||||
|
nx->out.Erase(y);
|
||||||
|
ny->in.Erase(x);
|
||||||
|
// Since we do not call Reorder() on this path, clear any visited
|
||||||
|
// markers left by ForwardDFS.
|
||||||
|
ClearVisitedBits(r, r->deltaf);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
BackwardDFS(r, x, ny->rank);
|
||||||
|
Reorder(r);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Follows the edges from producer to consumer and searchs if the node having
|
||||||
|
// rank `n` can reach the node having rank `upper_bound` using a DFS search.
|
||||||
|
// When doing DFS search, We only consider the pathes that satisfy the ranks
|
||||||
|
// of the nodes of the path are all smaller than `upper_bound`.
|
||||||
|
//
|
||||||
|
// Returns true if such path exists.
|
||||||
|
static bool ForwardDFS(GraphCycles::Rep* r, int32_t n, int32_t upper_bound) {
|
||||||
|
// Avoid recursion since stack space might be limited.
|
||||||
|
// We instead keep a stack of nodes to visit.
|
||||||
|
r->deltaf.clear();
|
||||||
|
r->stack.clear();
|
||||||
|
r->stack.push_back(n);
|
||||||
|
while (!r->stack.empty()) {
|
||||||
|
n = r->stack.back();
|
||||||
|
r->stack.pop_back();
|
||||||
|
Node* nn = r->nodes[n];
|
||||||
|
if (nn->visited) continue;
|
||||||
|
|
||||||
|
nn->visited = true;
|
||||||
|
r->deltaf.push_back(n);
|
||||||
|
|
||||||
|
for (auto w : nn->out.GetSequence()) {
|
||||||
|
Node* nw = r->nodes[w];
|
||||||
|
if (nw->rank == upper_bound) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (!nw->visited && nw->rank < upper_bound) {
|
||||||
|
r->stack.push_back(w);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Follows the edges from consumer to producer and visit all the nodes that
|
||||||
|
// is reachable from node `n` and have rank larger than `lower_bound`.
|
||||||
|
static void BackwardDFS(GraphCycles::Rep* r, int32_t n, int32_t lower_bound) {
|
||||||
|
r->deltab.clear();
|
||||||
|
r->stack.clear();
|
||||||
|
r->stack.push_back(n);
|
||||||
|
while (!r->stack.empty()) {
|
||||||
|
n = r->stack.back();
|
||||||
|
r->stack.pop_back();
|
||||||
|
Node* nn = r->nodes[n];
|
||||||
|
if (nn->visited) continue;
|
||||||
|
|
||||||
|
nn->visited = true;
|
||||||
|
r->deltab.push_back(n);
|
||||||
|
|
||||||
|
for (auto w : nn->in.GetSequence()) {
|
||||||
|
Node* nw = r->nodes[w];
|
||||||
|
if (!nw->visited && lower_bound < nw->rank) {
|
||||||
|
r->stack.push_back(w);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recomputes rank assignments to make them compatible with the edges (producer
|
||||||
|
// has smaller rank than its consumer)
|
||||||
|
static void Reorder(GraphCycles::Rep* r) {
|
||||||
|
Sort(r->nodes, &r->deltab);
|
||||||
|
Sort(r->nodes, &r->deltaf);
|
||||||
|
|
||||||
|
// Adds contents of delta lists to list (backwards deltas first).
|
||||||
|
r->list.clear();
|
||||||
|
MoveToList(r, &r->deltab, &r->list);
|
||||||
|
MoveToList(r, &r->deltaf, &r->list);
|
||||||
|
|
||||||
|
// Produce sorted list of all ranks that will be reassigned.
|
||||||
|
r->merged.resize(r->deltab.size() + r->deltaf.size());
|
||||||
|
std::merge(r->deltab.begin(), r->deltab.end(), r->deltaf.begin(),
|
||||||
|
r->deltaf.end(), r->merged.begin());
|
||||||
|
|
||||||
|
// Assign the ranks in order to the collected list.
|
||||||
|
for (Vec<int32_t>::size_type i = 0, e = r->list.size(); i < e; ++i) {
|
||||||
|
r->nodes[r->list[i]]->rank = r->merged[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sorts nodes in the vector according to their ranks. Small rank first.
|
||||||
|
static void Sort(const Vec<Node*>& nodes, Vec<int32_t>* delta) {
|
||||||
|
struct ByRank {
|
||||||
|
const Vec<Node*>* nodes;
|
||||||
|
bool operator()(int32_t a, int32_t b) const {
|
||||||
|
return (*nodes)[a]->rank < (*nodes)[b]->rank;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
ByRank cmp;
|
||||||
|
cmp.nodes = &nodes;
|
||||||
|
std::sort(delta->begin(), delta->end(), cmp);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collects ranks of nodes in vector `src` to vector `dst`
|
||||||
|
static void MoveToList(GraphCycles::Rep* r, Vec<int32_t>* src,
|
||||||
|
Vec<int32_t>* dst) {
|
||||||
|
for (Vec<int32_t>::size_type i = 0, e = src->size(); i < e; i++) {
|
||||||
|
int32_t w = (*src)[i];
|
||||||
|
// Replace src entry with its rank
|
||||||
|
(*src)[i] = r->nodes[w]->rank;
|
||||||
|
// Prepare for future DFS calls
|
||||||
|
r->nodes[w]->visited = false;
|
||||||
|
dst->push_back(w);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clears bookkeeping fileds used during the last DFS process.
|
||||||
|
static void ClearVisitedBits(GraphCycles::Rep* r, const Vec<int32_t>& nodes) {
|
||||||
|
for (Vec<int32_t>::size_type i = 0, e = nodes.size(); i < e; i++) {
|
||||||
|
r->nodes[nodes[i]]->visited = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GraphCycles::IsReachable(int32_t x, int32_t y) {
|
||||||
|
if (x == y) return true;
|
||||||
|
Rep* r = rep_;
|
||||||
|
Node* nx = r->nodes[x];
|
||||||
|
Node* ny = r->nodes[y];
|
||||||
|
|
||||||
|
if (nx->rank >= ny->rank) {
|
||||||
|
// x cannot reach y since it is after it in the topological ordering
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// See if x can reach y using a DFS search that is limited to y's rank
|
||||||
|
bool reachable = ForwardDFS(r, x, ny->rank);
|
||||||
|
|
||||||
|
// Clear any visited markers left by ForwardDFS.
|
||||||
|
ClearVisitedBits(r, r->deltaf);
|
||||||
|
return reachable;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::Optional<int32_t> GraphCycles::ContractEdge(int32_t a, int32_t b) {
|
||||||
|
assert(HasEdge(a, b));
|
||||||
|
RemoveEdge(a, b);
|
||||||
|
|
||||||
|
if (IsReachable(a, b)) {
|
||||||
|
// Restore the graph to its original state.
|
||||||
|
InsertEdge(a, b);
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (rep_->nodes[b]->in.Size() + rep_->nodes[b]->out.Size() >
|
||||||
|
rep_->nodes[a]->in.Size() + rep_->nodes[a]->out.Size()) {
|
||||||
|
// Swap "a" and "b" to minimize copying.
|
||||||
|
std::swap(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
Node* nb = rep_->nodes[b];
|
||||||
|
OrderedNodeSet out = std::move(nb->out);
|
||||||
|
OrderedNodeSet in = std::move(nb->in);
|
||||||
|
for (int32_t y : out.GetSequence()) {
|
||||||
|
rep_->nodes[y]->in.Erase(b);
|
||||||
|
}
|
||||||
|
for (int32_t y : in.GetSequence()) {
|
||||||
|
rep_->nodes[y]->out.Erase(b);
|
||||||
|
}
|
||||||
|
rep_->free_nodes.push_back(b);
|
||||||
|
|
||||||
|
rep_->nodes[a]->out.Reserve(rep_->nodes[a]->out.Size() + out.Size());
|
||||||
|
for (int32_t y : out.GetSequence()) {
|
||||||
|
InsertEdge(a, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
rep_->nodes[a]->in.Reserve(rep_->nodes[a]->in.Size() + in.Size());
|
||||||
|
for (int32_t y : in.GetSequence()) {
|
||||||
|
InsertEdge(y, a);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note, if the swap happened it might be what originally was called "b".
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int32_t> GraphCycles::SuccessorsCopy(int32_t node) const {
|
||||||
|
return rep_->nodes[node]->out.GetSequence();
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
void SortInPostOrder(const Vec<Node*>& nodes, std::vector<int32_t>* to_sort) {
|
||||||
|
std::sort(to_sort->begin(), to_sort->end(), [&](int32_t a, int32_t b) {
|
||||||
|
return nodes[a]->rank > nodes[b]->rank;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::vector<int32_t> GraphCycles::AllNodesInPostOrder() const {
|
||||||
|
llvm::DenseSet<int32_t> free_nodes_set;
|
||||||
|
for (int32_t n : rep_->free_nodes) free_nodes_set.insert(n);
|
||||||
|
|
||||||
|
std::vector<int32_t> all_nodes;
|
||||||
|
all_nodes.reserve(rep_->nodes.size() - free_nodes_set.size());
|
||||||
|
for (size_t i = 0, e = rep_->nodes.size(); i < e; i++) {
|
||||||
|
if (!free_nodes_set.count(i)) {
|
||||||
|
all_nodes.push_back(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
SortInPostOrder(rep_->nodes, &all_nodes);
|
||||||
|
return all_nodes;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace mlir
|
|
@ -0,0 +1,89 @@
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/cycle_detector.h"
|
||||||
|
|
||||||
|
#include "third_party/tensorflow/compiler/xla/test.h"
|
||||||
|
|
||||||
|
class GraphCyclesTest : public ::testing::Test {
|
||||||
|
public:
|
||||||
|
GraphCyclesTest() : g_(100) {}
|
||||||
|
|
||||||
|
bool AddEdge(int x, int y) { return g_.InsertEdge(x, y); }
|
||||||
|
|
||||||
|
void AddMultiples() {
|
||||||
|
// For every node x > 0: add edge to 2*x, 3*x
|
||||||
|
for (int x = 1; x < 25; x++) {
|
||||||
|
EXPECT_TRUE(AddEdge(x, 2 * x)) << x;
|
||||||
|
EXPECT_TRUE(AddEdge(x, 3 * x)) << x;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::GraphCycles g_;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(GraphCyclesTest, NoCycle) { AddMultiples(); }
|
||||||
|
|
||||||
|
TEST_F(GraphCyclesTest, SimpleCycle) {
|
||||||
|
AddMultiples();
|
||||||
|
EXPECT_FALSE(AddEdge(8, 4));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(GraphCyclesTest, IndirectCycle) {
|
||||||
|
AddMultiples();
|
||||||
|
EXPECT_TRUE(AddEdge(16, 9));
|
||||||
|
EXPECT_FALSE(AddEdge(9, 2));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(GraphCyclesTest, RemoveEdge) {
|
||||||
|
EXPECT_TRUE(AddEdge(1, 2));
|
||||||
|
EXPECT_TRUE(AddEdge(2, 3));
|
||||||
|
EXPECT_TRUE(AddEdge(3, 4));
|
||||||
|
EXPECT_TRUE(AddEdge(4, 5));
|
||||||
|
g_.RemoveEdge(2, 3);
|
||||||
|
EXPECT_FALSE(g_.HasEdge(2, 3));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(GraphCyclesTest, IsReachable) {
|
||||||
|
EXPECT_TRUE(AddEdge(1, 2));
|
||||||
|
EXPECT_TRUE(AddEdge(2, 3));
|
||||||
|
EXPECT_TRUE(AddEdge(3, 4));
|
||||||
|
EXPECT_TRUE(AddEdge(4, 5));
|
||||||
|
|
||||||
|
EXPECT_TRUE(g_.IsReachable(1, 5));
|
||||||
|
EXPECT_FALSE(g_.IsReachable(5, 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(GraphCyclesTest, ContractEdge) {
|
||||||
|
ASSERT_TRUE(AddEdge(1, 2));
|
||||||
|
ASSERT_TRUE(AddEdge(1, 3));
|
||||||
|
ASSERT_TRUE(AddEdge(2, 3));
|
||||||
|
ASSERT_TRUE(AddEdge(2, 4));
|
||||||
|
ASSERT_TRUE(AddEdge(3, 4));
|
||||||
|
|
||||||
|
// It will introduce a cycle if the edge is contracted
|
||||||
|
EXPECT_FALSE(g_.ContractEdge(1, 3).hasValue());
|
||||||
|
EXPECT_TRUE(g_.HasEdge(1, 3));
|
||||||
|
|
||||||
|
// Node (2) has more edges.
|
||||||
|
EXPECT_EQ(*g_.ContractEdge(1, 2), 2);
|
||||||
|
EXPECT_TRUE(g_.HasEdge(2, 3));
|
||||||
|
EXPECT_TRUE(g_.HasEdge(2, 4));
|
||||||
|
EXPECT_TRUE(g_.HasEdge(3, 4));
|
||||||
|
|
||||||
|
// Node (2) has more edges.
|
||||||
|
EXPECT_EQ(*g_.ContractEdge(2, 3), 2);
|
||||||
|
EXPECT_TRUE(g_.HasEdge(2, 4));
|
||||||
|
}
|
Loading…
Reference in New Issue