From 31dc1b21ebebcf0f6969d3d7b7fc71e3db834c33 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Mon, 6 Jul 2020 20:57:00 +0000 Subject: [PATCH] Move XLA-independent transforms to the new MLIR-HLO directory This is as straighforward as possible, more cleanup/rewrite to come. PiperOrigin-RevId: 319849713 --- include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h | 6 +- include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h | 6 +- .../mhlo/IR/infer_fusibility_op_interface.h | 6 +- include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h | 6 +- .../mhlo/transforms/map_hlo_to_lhlo_op.h | 80 ++ .../mhlo/transforms/map_xla_to_scalar_op.h | 510 ++++++++++ .../mlir-hlo/Dialect/mhlo/transforms/passes.h | 105 ++ .../Dialect/mhlo/transforms/rewriters.h | 106 ++ include/mlir-hlo/utils/cycle_detector.h | 165 ++++ .../mhlo/transforms/chlo_legalize_to_hlo.cc | 242 +++++ .../transforms/chlo_legalize_to_hlo_pass.cc | 57 ++ .../mhlo/transforms/hlo_legalize_to_lhlo.cc | 493 ++++++++++ .../mhlo/transforms/legalize_control_flow.cc | 237 +++++ .../legalize_tanh_to_approximation.cc | 156 +++ .../mhlo/transforms/legalize_to_standard.cc | 208 ++++ .../legalize_to_standard_patterns.td | 71 ++ .../mhlo/transforms/lhlo_copy_removal.cc | 105 ++ .../mhlo/transforms/lhlo_fuse_linalg.cc | 151 +++ .../transforms/lhlo_legalize_to_affine.cc | 161 ++++ .../mhlo/transforms/lhlo_legalize_to_gpu.cc | 196 ++++ .../mhlo/transforms/lhlo_legalize_to_llvm.cc | 136 +++ .../transforms/lhlo_legalize_to_llvm_pass.cc | 59 ++ .../lhlo_legalize_to_parallel_loops.cc | 731 ++++++++++++++ lib/Dialect/mhlo/transforms/lower_complex.cc | 79 ++ .../mhlo/transforms/lower_complex_patterns.td | 109 +++ .../mhlo/transforms/lower_general_dot.cc | 194 ++++ .../mhlo/transforms/materialize_broadcasts.cc | 90 ++ .../transforms/materialize_broadcasts_pass.cc | 58 ++ .../sink_constants_to_control_flow.cc | 85 ++ .../transforms/test_infer_shaped_type_pass.cc | 100 ++ .../mhlo/transforms/unfuse_batch_norm.cc | 184 ++++ .../mhlo/transforms/unfuse_batch_norm_pass.cc | 46 + lib/Dialect/mhlo/transforms/xla_hlo_fusion.cc | 579 +++++++++++ .../mhlo/transforms/xla_legalize_to_linalg.cc | 909 ++++++++++++++++++ .../transforms/xla_transform_unranked_hlo.cc | 188 ++++ lib/utils/cycle_detector.cc | 340 +++++++ lib/utils/cycle_detector_test.cc | 89 ++ 37 files changed, 7031 insertions(+), 12 deletions(-) create mode 100644 include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h create mode 100644 include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h create mode 100644 include/mlir-hlo/Dialect/mhlo/transforms/passes.h create mode 100644 include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h create mode 100644 include/mlir-hlo/utils/cycle_detector.h create mode 100644 lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc create mode 100644 lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc create mode 100644 lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc create mode 100644 lib/Dialect/mhlo/transforms/legalize_control_flow.cc create mode 100644 lib/Dialect/mhlo/transforms/legalize_tanh_to_approximation.cc create mode 100644 lib/Dialect/mhlo/transforms/legalize_to_standard.cc create mode 100644 lib/Dialect/mhlo/transforms/legalize_to_standard_patterns.td create mode 100644 lib/Dialect/mhlo/transforms/lhlo_copy_removal.cc create mode 100644 lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc create mode 100644 lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc create mode 100644 lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc create mode 100644 
lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc create mode 100644 lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc create mode 100644 lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc create mode 100644 lib/Dialect/mhlo/transforms/lower_complex.cc create mode 100644 lib/Dialect/mhlo/transforms/lower_complex_patterns.td create mode 100644 lib/Dialect/mhlo/transforms/lower_general_dot.cc create mode 100644 lib/Dialect/mhlo/transforms/materialize_broadcasts.cc create mode 100644 lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc create mode 100644 lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc create mode 100644 lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc create mode 100644 lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc create mode 100644 lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc create mode 100644 lib/Dialect/mhlo/transforms/xla_hlo_fusion.cc create mode 100644 lib/Dialect/mhlo/transforms/xla_legalize_to_linalg.cc create mode 100644 lib/Dialect/mhlo/transforms/xla_transform_unranked_hlo.cc create mode 100644 lib/utils/cycle_detector.cc create mode 100644 lib/utils/cycle_detector_test.cc diff --git a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h index 222b808..d7fc318 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h +++ b/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_CHLO_OPS_H_ -#define TENSORFLOW_COMPILER_MLIR_XLA_IR_CHLO_OPS_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_CHLO_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_CHLO_OPS_H_ #include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h" @@ -42,4 +42,4 @@ class XlaHloClientDialect : public Dialect { } // namespace xla_chlo } // namespace mlir -#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_CHLO_OPS_H_ +#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_CHLO_OPS_H_ diff --git a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h index b6360b7..d945900 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h +++ b/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h @@ -15,8 +15,8 @@ limitations under the License. // This file defines the operations used in the XLA dialect. 
-#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_OPS_H_ -#define TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_OPS_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_ #include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" @@ -96,4 +96,4 @@ LogicalResult deriveShapeFromFirstOperand( } // end namespace xla_hlo } // end namespace mlir -#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_OPS_H_ +#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_ diff --git a/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h b/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h index 412711b..76b4ad8 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h +++ b/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_INFER_FUSIBILITY_OP_INTERFACE_H_ -#define TENSORFLOW_COMPILER_MLIR_XLA_IR_INFER_FUSIBILITY_OP_INTERFACE_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_INFER_FUSIBILITY_OP_INTERFACE_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_INFER_FUSIBILITY_OP_INTERFACE_H_ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/StandardTypes.h" @@ -25,4 +25,4 @@ namespace mlir { } // namespace mlir -#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_INFER_FUSIBILITY_OP_INTERFACE_H_ +#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_INFER_FUSIBILITY_OP_INTERFACE_H_ diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h index 554252a..0ea62ba 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h @@ -15,8 +15,8 @@ limitations under the License. // This file defines the operations used in the LXLA dialect. -#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_LHLO_OPS_H_ -#define TENSORFLOW_COMPILER_MLIR_XLA_IR_LHLO_OPS_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_ #include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h" #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" @@ -49,4 +49,4 @@ class XlaLhloDialect : public Dialect { } // namespace xla_lhlo } // end namespace mlir -#endif // TENSORFLOW_COMPILER_MLIR_XLA_IR_LHLO_OPS_H_ +#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_ diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h new file mode 100644 index 0000000..5e826c2 --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h @@ -0,0 +1,80 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_ + +#include + +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" + +namespace mlir { +namespace xla_hlo { + +template +struct HloToLhloOpImpl { + using Type = std::false_type; +}; +template +using HloToLhloOp = typename HloToLhloOpImpl::Type; + +#define MAP_HLO_TO_LHLO(OpName) \ + template <> \ + struct HloToLhloOpImpl { \ + using Type = xla_lhlo::OpName; \ + } + +MAP_HLO_TO_LHLO(AbsOp); +MAP_HLO_TO_LHLO(AddOp); +MAP_HLO_TO_LHLO(AndOp); +MAP_HLO_TO_LHLO(BroadcastInDimOp); +MAP_HLO_TO_LHLO(CeilOp); +MAP_HLO_TO_LHLO(ConstOp); +MAP_HLO_TO_LHLO(CompareOp); +MAP_HLO_TO_LHLO(ComplexOp); +MAP_HLO_TO_LHLO(ConvOp); +MAP_HLO_TO_LHLO(ConvertOp); +MAP_HLO_TO_LHLO(CopyOp); +MAP_HLO_TO_LHLO(CosOp); +MAP_HLO_TO_LHLO(DivOp); +MAP_HLO_TO_LHLO(DotOp); +MAP_HLO_TO_LHLO(ExpOp); +MAP_HLO_TO_LHLO(GatherOp); +MAP_HLO_TO_LHLO(ImagOp); +MAP_HLO_TO_LHLO(IotaOp); +MAP_HLO_TO_LHLO(LogOp); +MAP_HLO_TO_LHLO(MaxOp); +MAP_HLO_TO_LHLO(MinOp); +MAP_HLO_TO_LHLO(MulOp); +MAP_HLO_TO_LHLO(NegOp); +MAP_HLO_TO_LHLO(RealOp); +MAP_HLO_TO_LHLO(ReduceOp); +MAP_HLO_TO_LHLO(ReshapeOp); +MAP_HLO_TO_LHLO(RemOp); +MAP_HLO_TO_LHLO(RsqrtOp); +MAP_HLO_TO_LHLO(SelectOp); +MAP_HLO_TO_LHLO(SignOp); +MAP_HLO_TO_LHLO(SinOp); +MAP_HLO_TO_LHLO(SqrtOp); +MAP_HLO_TO_LHLO(SubOp); +MAP_HLO_TO_LHLO(TanhOp); + +#undef MAP_HLO_TO_LHLO + +} // namespace xla_hlo +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_ diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h new file mode 100644 index 0000000..bb710a8 --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h @@ -0,0 +1,510 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_XLA_TO_SCALAR_OP_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_XLA_TO_SCALAR_OP_H_ + +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h" +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h" + +namespace mlir { +namespace xla_lhlo { +namespace impl { + +// A struct to map LhloBinaryOpTy type to the corresponding floating-point and +// integer scalar operation types. +template +struct LhloToScalarOp; + +template <> +struct LhloToScalarOp { + using FOp = ::mlir::AddFOp; + using IOp = ::mlir::AddIOp; +}; +template <> +struct LhloToScalarOp { + using FOp = ::mlir::CmpFOp; + using IOp = ::mlir::CmpIOp; +}; +template <> +struct LhloToScalarOp { + using FOp = ::mlir::DivFOp; + using IOp = ::mlir::SignedDivIOp; +}; +template <> +struct LhloToScalarOp { + using FOp = ::mlir::MulFOp; + using IOp = ::mlir::MulIOp; +}; +template <> +struct LhloToScalarOp { + using FOp = ::mlir::RemFOp; + using IOp = ::mlir::SignedRemIOp; +}; +template <> +struct LhloToScalarOp { + using FOp = ::mlir::SubFOp; + using IOp = ::mlir::SubIOp; +}; + +template +struct ScalarOp { + using FOp = typename LhloToScalarOp::FOp; + using IOp = typename LhloToScalarOp::IOp; +}; + +// Alias for the map from LHLO binary op type to STD floating-point op type. +template +using ScalarFOp = typename ScalarOp::FOp; +// Alias for the map from LHLO binary op type to STD integer op type. +template +using ScalarIOp = typename ScalarOp::IOp; + +template +struct MapLhloOpToStdScalarOpImpl { + Value operator()(Location loc, ArrayRef result_types, + ArrayRef args, OpBuilder* b) { + return nullptr; + } +}; + +template +struct MapLhloOpToStdScalarOpImpl { + Value operator()(Location loc, ArrayRef result_types, + ArrayRef args, OpBuilder* b) { + return b->template create(loc, result_types, args, mlir::None); + } +}; + +template +struct MapLhloOpToStdScalarOpImpl { + Value operator()(Location loc, ArrayRef result_types, + ArrayRef args, OpBuilder* b) { + Type element_type = args.front().getType(); + if (element_type.isa()) { + return b->template create(loc, result_types, args, + mlir::None); + } + return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, b); + } +}; + +// Inserts the computation that corresponds to the body of the loop for lowered +// LHLO unary/binary op. Returns the value for the result. 
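// For example (an illustrative sketch with placeholder names, not an
// exhaustive list): instantiating the traits above for xla_lhlo::AddOp yields
// ScalarFOp<xla_lhlo::AddOp> == ::mlir::AddFOp and ScalarIOp<xla_lhlo::AddOp>
// == ::mlir::AddIOp, and the included map_hlo_to_lhlo_op.h resolves
// xla_hlo::HloToLhloOp<xla_hlo::AddOp> to xla_lhlo::AddOp. A caller building
// the scalar body would write roughly:
//   Value res = MapLhloOpToStdScalarOp<xla_lhlo::AddOp>(
//       loc, /*result_types=*/{float_ty}, /*args=*/{lhs, rhs}, &builder);
// which emits std.addf for float operands and std.addi for signless integers.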
+template +inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, + ArrayRef args, OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl, FloatType, + ScalarFOp>{}(loc, result_types, + args, b); +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + Type element_type = args.front().getType(); + if (element_type.isa()) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); + } + if (element_type.isa()) { + // xla_lhlo.abs(x, result) -> result = select((x > 0), x, sub(0, x)) + Value lhs = args[0]; + auto integer_type = element_type.dyn_cast(); + + auto zero_intval = + b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); + auto lhs_gt_zero = b->create>(loc, CmpIPredicate::sge, + lhs, zero_intval); + auto neg_val = b->create>(loc, zero_intval, lhs); + return b->create<::mlir::SelectOp>(loc, lhs_gt_zero, lhs, neg_val); + } + return nullptr; +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); +} + +template +inline Optional getCmpPredicate( + StringRef xla_comparison_direction) { + return llvm::None; +} + +template <> +inline Optional getCmpPredicate( + StringRef xla_comparison_direction) { + return llvm::StringSwitch>(xla_comparison_direction) + .Case("EQ", CmpFPredicate::OEQ) + .Case("NE", CmpFPredicate::ONE) + .Case("GE", CmpFPredicate::OGE) + .Case("GT", CmpFPredicate::OGT) + .Case("LE", CmpFPredicate::OLE) + .Case("LT", CmpFPredicate::OLT) + .Default(llvm::None); +} + +template <> +inline Optional getCmpPredicate( + StringRef xla_comparison_direction) { + return llvm::StringSwitch>(xla_comparison_direction) + .Case("EQ", CmpIPredicate::eq) + .Case("NE", CmpIPredicate::ne) + .Case("GE", CmpIPredicate::sge) + .Case("GT", CmpIPredicate::sgt) + .Case("LE", CmpIPredicate::sle) + .Case("LT", CmpIPredicate::slt) + .Default(llvm::None); +} + +template +inline Value MapXlaCompareOpToStdScalarOp(Location loc, + StringRef comparison_direction, + ArrayRef result_types, + ArrayRef args, OpBuilder* b) { + const auto& lhs = args[0]; + const auto& rhs = args[1]; + Type element_type = lhs.getType(); + if (element_type.isSignlessInteger()) { + Optional predicate = + getCmpPredicate(comparison_direction); + assert(predicate.hasValue() && "expected valid comparison direction"); + return b->create>(loc, predicate.getValue(), lhs, + rhs); + } + if (element_type.isa()) { + Optional predicate = + getCmpPredicate(comparison_direction); + assert(predicate.hasValue() && "expected valid comparison direction"); + return b->create>(loc, predicate.getValue(), lhs, + rhs); + } + return nullptr; +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return args.front(); +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, + b); +} + +template <> +inline Value 
MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, b); +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, b); +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + Type sourceType = args.front().getType(); + Type targetType = result_types.front(); + + if (mlir::SIToFPOp::areCastCompatible(sourceType, targetType)) { + return b->create(loc, result_types, args, mlir::None); + } else if (sourceType.isa() && targetType.isa()) { + FloatType src = sourceType.cast(); + FloatType res = targetType.cast(); + if (src.getWidth() > res.getWidth()) { + return b->create(loc, result_types, args, mlir::None); + } else if (src.getWidth() < res.getWidth()) { + return b->create(loc, result_types, args, mlir::None); + } + // No conversion is needed for the same width floats + return args.front(); + } + if (sourceType.isSignlessInteger() && targetType.isSignlessInteger()) { + IntegerType src = sourceType.cast(); + IntegerType res = targetType.cast(); + if (src.getWidth() > res.getWidth()) { + return b->create(loc, result_types, args, mlir::None); + } else if (src.getWidth() < res.getWidth()) { + return b->create(loc, result_types, args, + mlir::None); + } + // No conversion is needed for the same width integers + return args.front(); + } + if (mlir::FPToSIOp::areCastCompatible(sourceType, targetType)) { + return b->create(loc, result_types, args, mlir::None); + } + return nullptr; +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + // Dot Op converter from lhlo to affine only accepts float and integer types. + const auto& lhs = args[0]; + const auto& rhs = args[1]; + const auto& result = args[2]; + Type element_type = lhs.getType(); + if (element_type.isa()) { + Value float_mul = MapLhloOpToStdScalarOpImpl{}( + loc, result_types, {lhs, rhs}, b); + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, {float_mul, result}, b); + } + if (element_type.isa()) { + Value int_mul = MapLhloOpToStdScalarOpImpl{}( + loc, result_types, {lhs, rhs}, b); + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, {int_mul, result}, b); + } + return nullptr; +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); +} + +/// Implements the conversion of XLA op to scalar op (to use within region of a +/// linalg.generic op) for compare-select style operations like min/max. +template +struct XlaCompareSelectOpToStdScalarOp { + static Value map(Location loc, StringRef comparison_direction, + ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return nullptr; + } +}; + +/// Specialization which allows converting to a comparison operation in standard +/// dialect with a given predicate based on the element type of the operand. 
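/// For example (a rough sketch of the emitted IR, with placeholder value
/// names), lowering xla_lhlo::MaxOp through this helper uses comparison
/// direction "GT"; for f32 operands this produces approximately:
///   %cmp = cmpf "ogt", %lhs, %rhs : f32
///   %res = select %cmp, %lhs, %rhs : f32
/// and for signless i32 operands a cmpi "sgt" followed by the same select.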
+template +struct XlaCompareSelectOpToStdScalarOp { + static Value map(Location loc, StringRef comparison_direction, + ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + Type element_type = args.front().getType(); + if (element_type.isa()) { + auto predicate = getCmpPredicate(comparison_direction); + assert(predicate.hasValue() && "expected valid comparison direction"); + auto cmp = b->template create(loc, predicate.getValue(), + args[0], args[1]); + return b->create<::mlir::SelectOp>(loc, cmp, args[0], args[1]); + } + return XlaCompareSelectOpToStdScalarOp::map( + loc, comparison_direction, result_types, args, b); + } +}; + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return XlaCompareSelectOpToStdScalarOp< + IntegerType, ScalarIOp, CmpIPredicate, FloatType, + ScalarFOp, CmpFPredicate>::map(loc, "GT", + result_types, args, + b); +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return XlaCompareSelectOpToStdScalarOp< + IntegerType, ScalarIOp, CmpIPredicate, FloatType, + ScalarFOp, CmpFPredicate>::map(loc, "LT", + result_types, args, + b); +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + Type element_type = args.front().getType(); + if (element_type.isa()) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); + } + if (element_type.isa()) { + // xla_lhlo.neg(x, result) -> result = sub(0, x) + Value lhs = args[0]; + auto integer_type = element_type.dyn_cast(); + + auto zero_intval = + b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); + return b->create>(loc, zero_intval, lhs); + } + return nullptr; +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl<::mlir::SelectOp>{}(loc, result_types, args, + b); +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + Type element_type = args.front().getType(); + if (element_type.isa()) { + FloatType float_type = element_type.cast(); + APFloat const_value = float_type.isF32() ? APFloat(1.0f) : APFloat(1.0); + Value one = b->create(loc, const_value, float_type); + return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]); + } + return nullptr; +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); +} + +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); +} + +} // namespace impl + +struct XlaOpToStdScalarOp { + // Implementation for LHLO ops except xla_lhlo::CompareOp. 
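  // For instance (a sketch with placeholder names), a lowering pattern that
  // has already produced scalar operands can dispatch through this single
  // entry point without naming the std op directly:
  //   Value res = XlaOpToStdScalarOp::map<xla_lhlo::AddOp>(
  //       add_op, /*result_types=*/{elem_ty}, /*args=*/{lhs, rhs}, &rewriter);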
+ template ::value && + std::is_same, + std::false_type>::value>> + static Value map(XlaOpTy op, ArrayRef result_types, + ArrayRef args, OpBuilder* b, unsigned i = 0) { + return impl::MapLhloOpToStdScalarOp(op.getLoc(), result_types, + args, b); + } + + // Implementation for HLO ops except xla_hlo::CompareOp. + template , + typename = std::enable_if_t< + !std::is_same::value && + !std::is_same::value>> + static Value map(XlaOpTy op, ArrayRef result_types, + ArrayRef args, OpBuilder* b, int i = 0) { + return impl::MapLhloOpToStdScalarOp(op.getLoc(), result_types, + args, b); + } + + // Implementation for xla_lhlo::CompareOp. + template ::value>> + static Value map(xla_lhlo::CompareOp op, ArrayRef result_types, + ArrayRef args, OpBuilder* b) { + auto comparison_direction = op.comparison_direction(); + return impl::MapXlaCompareOpToStdScalarOp( + op.getLoc(), comparison_direction, result_types, args, b); + } + + // Implementation for xla_hlo::CompareOp. + template ::value>> + static Value map(xla_hlo::CompareOp op, ArrayRef result_types, + ArrayRef args, OpBuilder* b) { + auto comparison_direction = op.comparison_direction(); + return impl::MapXlaCompareOpToStdScalarOp( + op.getLoc(), comparison_direction, result_types, args, b); + } +}; + +} // namespace xla_lhlo +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_XLA_TO_SCALAR_OP_H_ diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/passes.h b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h new file mode 100644 index 0000000..3471587 --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/transforms/passes.h @@ -0,0 +1,105 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSES_H_ + +#include + +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h" + +namespace mlir { + +class FuncOp; +class ModuleOp; +class Operation; +template +class OperationPass; +class Pass; + +namespace xla_hlo { + +/// Lowers HLO control flow ops to the Standard dialect. +std::unique_ptr> createLegalizeControlFlowPass(); + +/// Lowers from HLO dialect to Standard dialect. +std::unique_ptr> createLegalizeToStdPass(); + +/// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary +/// buffers if necessary. If `results_escape_functions` is set to true, +/// allocated buffers for function results will be returned and escape the +/// function. Otherwise, the signature is rewritten with extra arguments for the +/// buffers that are to be used for results. +std::unique_ptr> createLegalizeToLhloPass( + bool results_escape_functions = false); + +// Lowers from HLO dialect to Linalg dialect. 
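// As a sketch of the intended result: an element-wise op such as xla_hlo.add
// on a ranked tensor is rewritten into a linalg.generic op whose region
// contains the corresponding scalar std op (e.g. addf), selected via
// MapLhloOpToStdScalarOp.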
+std::unique_ptr> createLegalizeHloToLinalgPass(); + +// Transforms unranked HLO operations to ranked ones where possible. +std::unique_ptr> createTransformUnrankedHloPass(); + +// Sinks constants implicitly captured in control flow regions. This is +// necessary to export to XLA. +std::unique_ptr> createSinkConstantsToControlFlowPass(); + +// fuse xla_hlo ops to kLoop/kInput fusion patterns +std::unique_ptr> createXlaHloFusionPass(); + +} // namespace xla_hlo + +namespace xla_lhlo { + +// Lowers from LHLO dialect to Affine dialect. +std::unique_ptr> createLegalizeToAffinePass(); + +// Lowers from LHLO dialect to Linalg dialect. +std::unique_ptr> createLegalizeLhloToLinalgPass(); + +// Lowers from LHLO dialect to GPU dialect. +std::unique_ptr> createLegalizeToGpuPass(); + +// Fuses linalg ops obtained after LHLO lowering. To enable fusion, +// operations are first tiled. +// +// When 'use_parallel_loops' is set, the tiling will use scf.parallel +// operations. Otherwise, scf.for operations are used. +// +// 'tile_sizes' provides the tile sizes to use for tiling. If the linalg +// operation has more dimensions than tile sizes provided, 1 is used as +// default. +std::unique_ptr> createLhloFuseLinalg( + bool use_parallel_loops = false, llvm::ArrayRef tile_sizes = {}); + +// Removes unnecessary LHLO copies which copy from the allocated buffers to the +// block arguments. The block arguments are used instead of all uses of these +// buffers. The buffers are freed. This pass only works in regions that contain +// a single block. +std::unique_ptr createLhloCopyRemovalPass(); + +// Lowers from LHLO dialect to parallel loops. +std::unique_ptr> createLegalizeLhloToParallelLoopsPass(); + +} // namespace xla_lhlo + +namespace xla { + +/// Lowers the standard TanhOp to an approximation that does not use intrinsics. +std::unique_ptr> createLegalizeTanhToApproximationPass(); + +} // namespace xla +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSES_H_ diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h new file mode 100644 index 0000000..606d510 --- /dev/null +++ b/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h @@ -0,0 +1,106 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H_ +#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H_ + +#include + +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" + +namespace mlir { +class LLVMTypeConverter; +class LowerToLLVMOptions; +class OwningRewritePatternList; +class BufferAssignmentPlacer; +namespace xla_hlo { + +// Collection of rewrite patterns for lowering a general dot product. +void PopulateGeneralDotOpLoweringPatterns(OwningRewritePatternList *patterns, + MLIRContext *ctx); + +// Collection of rewrite patterns for lowering complex operations to equivalent +// float operations. +void PopulateComplexLoweringPatterns(MLIRContext *context, + OwningRewritePatternList *patterns); + +void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns, + MLIRContext *ctx); + +// Collection of rewrite patterns for lowering of HLO to LHLO dialect. +void populateHLOToLHLOConversionPattern( + MLIRContext *context, BufferAssignmentPlacer *bufferAssignment, + TypeConverter *converter, OwningRewritePatternList *patterns); +// Collection of rewrite patterns for lowering of HLO to Linalg dialect. +void populateHLOToLinalgConversionPattern(MLIRContext *context, + OwningRewritePatternList *patterns); + +// Sets up legality definitions for materializing broadcasts. +void SetupMaterializeBroadcastsLegality(MLIRContext *context, + ConversionTarget *conversionTarget); + +// Populates a collection of rewrite patterns for materializing broadcast +// attributes to equivalent sequences of ops. +void PopulateMaterializeBroadcastsPatterns(MLIRContext *context, + OwningRewritePatternList *patterns); + +// Sets up legality definitions for element-wise operations on ranked tensors. +void SetupTransformUnrankedHloLegality(MLIRContext *context, + ConversionTarget *conversionTarget); + +// Populates a collection of rewrite patterns to realize element-wise operations +// on ranked tensors where possible. +void PopulateTransformUnrankedHloPatterns(MLIRContext *context, + OwningRewritePatternList *patterns); + +// Populate a collection of conversion patterns for un-fusing +// batch_norm_inference and batch_norm_training into constituent HLO ops. +// TODO(laurenzo): Implement un-fusing of batch_norm_training. +void PopulateUnfuseBatchNormPatterns(MLIRContext *context, + OwningRewritePatternList *patterns); + +} // namespace xla_hlo + +namespace xla_lhlo { + +/// Collect a set of patterns to convert from the LHLO dialect to LLVM. +void PopulateLhloToLLVMConversionPatterns(const LowerToLLVMOptions &options, + LLVMTypeConverter *converter, + OwningRewritePatternList *patterns); + +} // namespace xla_lhlo + +namespace xla_chlo { + +// Populates a collection of conversion patterns for legalizing client-HLO to +// HLO. +void PopulateLegalizeChloToHloPatterns(MLIRContext *context, + OwningRewritePatternList *patterns); + +} // namespace xla_chlo + +namespace xla { + +// Populates a pattern that translates the standard TanhOp to an approximation +// that does not use intrinsics. 
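// Typical usage from inside a function pass (a minimal sketch; the exact
// rewrite driver may differ):
//   OwningRewritePatternList patterns;
//   PopulateTanhToApproximationPatterns(&getContext(), &patterns);
//   applyPatternsAndFoldGreedily(getFunction(), patterns);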
+void PopulateTanhToApproximationPatterns(MLIRContext *context, + OwningRewritePatternList *patterns); + +} // namespace xla +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H_ diff --git a/include/mlir-hlo/utils/cycle_detector.h b/include/mlir-hlo/utils/cycle_detector.h new file mode 100644 index 0000000..eea0f25 --- /dev/null +++ b/include/mlir-hlo/utils/cycle_detector.h @@ -0,0 +1,165 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_CYCLE_DETECTOR_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_CYCLE_DETECTOR_H_ + +#include + +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/DenseMap.h" + +namespace mlir { + +// ------------------------------------------------------------------- + +// This file contains a light version of GraphCycles implemented in +// tensorflow/compiler/jit/graphcycles/graphcycles.h +// +// We re-implement it here because we do not want to rely +// on TensorFlow data structures, and hence we can move +// corresponding passes to llvm repo. easily in case necessnary. + +// -------------------------------------------------------------------- + +// This is a set data structure that provides a deterministic iteration order. +// The iteration order of elements only depends on the sequence of +// inserts/deletes, so as long as the inserts/deletes happen in the same +// sequence, the set will have the same iteration order. +// +// Assumes that T can be cheaply copied for simplicity. +template +class OrderedSet { + public: + // Inserts `value` into the ordered set. Returns true if the value was not + // present in the set before the insertion. + bool Insert(T value) { + bool new_insertion = + value_to_index_.insert({value, value_sequence_.size()}).second; + if (new_insertion) { + value_sequence_.push_back(value); + } + return new_insertion; + } + + // Removes `value` from the set. Assumes `value` is already present in the + // set. + void Erase(T value) { + auto it = value_to_index_.find(value); + + // Since we don't want to move values around in `value_sequence_` we swap + // the value in the last position and with value to be deleted and then + // pop_back. + value_to_index_[value_sequence_.back()] = it->second; + std::swap(value_sequence_[it->second], value_sequence_.back()); + value_sequence_.pop_back(); + value_to_index_.erase(it); + } + + void Reserve(size_t new_size) { + value_to_index_.reserve(new_size); + value_sequence_.reserve(new_size); + } + + void Clear() { + value_to_index_.clear(); + value_sequence_.clear(); + } + + bool Contains(T value) const { return value_to_index_.count(value); } + size_t Size() const { return value_sequence_.size(); } + + const std::vector& GetSequence() const { return value_sequence_; } + + private: + // The stable order that we maintain through insertions and deletions. 
+ std::vector value_sequence_; + + // Maps values to their indices in `value_sequence_`. + llvm::DenseMap value_to_index_; +}; + +// --------------------------------------------------------------------- + +// GraphCycles detects the introduction of a cycle into a directed +// graph that is being built up incrementally. +// +// Nodes are identified by small integers. It is not possible to +// record multiple edges with the same (source, destination) pair; +// requests to add an edge where one already exists are silently +// ignored. +// +// It is also not possible to introduce a cycle; an attempt to insert +// an edge that would introduce a cycle fails and returns false. +// +// GraphCycles uses no internal locking; calls into it should be +// serialized externally. + +// Performance considerations: +// Works well on sparse graphs, poorly on dense graphs. +// Extra information is maintained incrementally to detect cycles quickly. +// InsertEdge() is very fast when the edge already exists, and reasonably fast +// otherwise. +// FindPath() is linear in the size of the graph. +// The current implementation uses O(|V|+|E|) space. + +class GraphCycles { + public: + explicit GraphCycles(int32_t num_nodes); + ~GraphCycles(); + + // Attempt to insert an edge from x to y. If the + // edge would introduce a cycle, return false without making any + // changes. Otherwise add the edge and return true. + bool InsertEdge(int32_t x, int32_t y); + + // Remove any edge that exists from x to y. + void RemoveEdge(int32_t x, int32_t y); + + // Return whether there is an edge directly from x to y. + bool HasEdge(int32_t x, int32_t y) const; + + // Contracts the edge from 'a' to node 'b', merging nodes 'a' and 'b'. One of + // the nodes is removed from the graph, and edges to/from it are added to + // the remaining one, which is returned. If contracting the edge would create + // a cycle, does nothing and return no value. + llvm::Optional ContractEdge(int32_t a, int32_t b); + + // Return whether dest_node `y` is reachable from source_node `x` + // by following edges. This is non-thread-safe version. + bool IsReachable(int32_t x, int32_t y); + + // Return a copy of the successors set. This is needed for code using the + // collection while modifying the GraphCycles. + std::vector SuccessorsCopy(int32_t node) const; + + // Returns all nodes in post order. + // + // If there is a path from X to Y then X appears after Y in the + // returned vector. + std::vector AllNodesInPostOrder() const; + + // ---------------------------------------------------- + struct Rep; + + private: + GraphCycles(const GraphCycles&) = delete; + GraphCycles& operator=(const GraphCycles&) = delete; + + Rep* rep_; // opaque representation +}; + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_CYCLE_DETECTOR_H_ diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc new file mode 100644 index 0000000..ed5282f --- /dev/null +++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc @@ -0,0 +1,242 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OperationSupport.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h" + +namespace mlir { +namespace xla_chlo { + +namespace { + +// Converts binary ops that statically are determined to not broadcast directly +// to the corresponding xla_hlo non-broadcasting op. +template +struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ChloOpTy op, + PatternRewriter &rewriter) const override { + // Only rewrite for statically determinable non-broadcasting cases. + auto lhs_type = op.lhs().getType().template dyn_cast(); + auto rhs_type = op.rhs().getType().template dyn_cast(); + if (!lhs_type || !rhs_type) return failure(); + + // Requires rank broadcast. + if (lhs_type.getRank() != rhs_type.getRank()) return failure(); + // Any dynamic dimension may require broadcasting and requires more + // analysis. + if (!lhs_type.hasStaticShape() || !rhs_type.hasStaticShape()) + return failure(); + + for (auto extents : llvm::zip(lhs_type.getShape(), rhs_type.getShape())) { + auto lhs_extent = std::get<0>(extents); + auto rhs_extent = std::get<1>(extents); + if (lhs_extent != rhs_extent) { + return failure(); + } + } + + rewriter.replaceOp(op, {Adaptor::CreateOp(op, op.getResult().getType(), + op.lhs(), op.rhs(), rewriter)}); + return success(); + } +}; + +// Converts a binary op with ranked broadcasting operands to explicitly +// broadcast and invoke the corresponding xla_hlo non-broadcasting op. +// Note that dynamic broadcasting supported by this pattern is only valid for +// "numpy" broadcasting semantics as defined here: +// https://docs.scipy.org/doc/numpy/reference/ufuncs.html +// Specifically, this includes the following cases: +// - Same rank broadcast (operands have the same static rank). +// - Different-rank broadcast, either without a broadcast_dims attribte or +// with the broadcast_dims attribute set to map to a prefix padding. +// - Legal combinations of degenerate (1-dim) implicit broadcasting. +// The restriction on broadcast_dims derives from the definition of the +// `shape.broadcast` op, which only supports prefix-padding. 
+// +// It may be possible to expand this pattern to operate on unranked tensors in +// the future by emitting more code to dynamically differentiate based on rank. +// Whether that is of any practical benefit remains to be seen. +template +struct ConvertRankedDynamicBroadcastBinaryOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ChloOpTy op, + PatternRewriter &rewriter) const override { + // Only support ranked operands. + Value lhs = op.lhs(); + Value rhs = op.rhs(); + auto lhs_type = lhs.getType().dyn_cast(); + auto rhs_type = rhs.getType().dyn_cast(); + auto result_type = + op.getResult().getType().template dyn_cast(); + if (!lhs_type || !rhs_type || !result_type) return failure(); + + // Check for "numpy"-style rank broadcast. + auto broadcast_dimensions = op.broadcast_dimensions(); + if (broadcast_dimensions && + !xla::IsLegalNumpyRankedBroadcast(lhs, rhs, *broadcast_dimensions)) { + // Note: It is unclear whether the general specification of explicit + // broadcast_dimensions on binary ops is a feature we want to carry + // forward. While it can technically be implemented for ranked-dynamic, + // it is incompatible with unranked inputs. If this warning is emitted + // in real programs, it is an indication that the feature should be + // implemented versus just falling back on the more standard definition + // of numpy-like prefix-padding. + op.emitWarning() << "unsupported non prefix-padded dynamic rank " + << "broadcast_dimensions = " << *broadcast_dimensions; + return failure(); + } + + // Compute result shape. + auto loc = op.getLoc(); + + // Insert a constraint on the shapes being broadcastable and insert all + // future code into an assuming block reliant on the constraint. + Value lhs_shape = rewriter.create(loc, lhs); + Value rhs_shape = rewriter.create(loc, rhs); + auto broadcastable_cstr = + rewriter.create(loc, lhs_shape, rhs_shape); + auto assuming_op = rewriter.create( + loc, ArrayRef{result_type}, broadcastable_cstr.result()); + + OpBuilder::InsertionGuard guard(rewriter); + rewriter.createBlock(&assuming_op.doRegion()); + + int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank()); + Value result_extents = + xla::ComputeBinaryElementwiseBroadcastingResultExtents(loc, lhs, rhs, + rewriter); + + // Note that we unconditionally emit DynamicBroadcastInDim ops and let + // downstream canonicalizations fold them away if possible. This is + // because, in the dynamic case, there are many corner cases regarding + // when it is safe to omit, and some of them require analysis to prove + // properly. + auto lhs_broadcast_dimensions = llvm::to_vector<4>( + llvm::seq(result_rank - lhs_type.getRank(), result_rank)); + Value broadcasted_lhs = rewriter.create( + loc, + RankedTensorType::get(result_type.getShape(), + lhs_type.getElementType()), + lhs, result_extents, + rewriter.getI64TensorAttr(lhs_broadcast_dimensions)); + auto rhs_broadcast_dimensions = llvm::to_vector<4>( + llvm::seq(result_rank - rhs_type.getRank(), result_rank)); + Value broadcasted_rhs = rewriter.create( + loc, + RankedTensorType::get(result_type.getShape(), + rhs_type.getElementType()), + rhs, result_extents, + rewriter.getI64TensorAttr(rhs_broadcast_dimensions)); + + // And generate the final non-broadcasted binary op. 
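    // (At this point the pattern has emitted, roughly: a shape.shape_of per
    // operand, a shape.cstr_broadcastable constraint, a shape.assuming region,
    // and one xla_hlo.dynamic_broadcast_in_dim per operand; the op created
    // below is the ordinary element-wise op on the broadcasted values.)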
+ Value final_result = Adaptor::CreateOp(op, result_type, broadcasted_lhs, + broadcasted_rhs, rewriter); + rewriter.create(loc, final_result); + rewriter.replaceOp(op, {assuming_op.getResult(0)}); + return success(); + } +}; + +template +void PopulateForBinaryOp(MLIRContext *context, + OwningRewritePatternList *patterns) { + patterns + ->insert>( + context, 10); + patterns->insert< + ConvertRankedDynamicBroadcastBinaryOp>( + context, 5); +} + +template +struct HloBinaryElementwiseAdaptor { + static ToOpTy CreateOp(FromOpTy from_op, Type result_type, + Value broadcasted_lhs, Value broadcasted_rhs, + OpBuilder &builder) { + return builder.create(from_op.getLoc(), result_type, + broadcasted_lhs, broadcasted_rhs); + } +}; + +struct HloComplexAdaptor { + static xla_hlo::ComplexOp CreateOp(BroadcastComplexOp from_op, + Type result_type, Value broadcasted_lhs, + Value broadcasted_rhs, + OpBuilder &builder) { + return builder.create(from_op.getLoc(), result_type, + broadcasted_lhs, broadcasted_rhs); + } +}; + +struct HloCompareAdaptor { + static xla_hlo::CompareOp CreateOp(BroadcastCompareOp from_op, + Type result_type, Value broadcasted_lhs, + Value broadcasted_rhs, + OpBuilder &builder) { + return builder.create(from_op.getLoc(), result_type, + broadcasted_lhs, broadcasted_rhs, + from_op.comparison_direction()); + } +}; + +} // namespace + +void PopulateLegalizeChloToHloPatterns(MLIRContext *context, + OwningRewritePatternList *patterns) { + // Instantiate conversion templates for conforming binary elementwise ops + // that do not have different dtypes between operands and results and do + // not have special attributes that need to be preserved. +#define POPULATE_BCAST(ChloOp, HloOp) \ + PopulateForBinaryOp>(context, \ + patterns); + + POPULATE_BCAST(BroadcastAddOp, xla_hlo::AddOp); + POPULATE_BCAST(BroadcastAndOp, xla_hlo::AndOp); + POPULATE_BCAST(BroadcastAtan2Op, xla_hlo::Atan2Op); + POPULATE_BCAST(BroadcastDivOp, xla_hlo::DivOp); + POPULATE_BCAST(BroadcastMaxOp, xla_hlo::MaxOp); + POPULATE_BCAST(BroadcastMinOp, xla_hlo::MinOp); + POPULATE_BCAST(BroadcastMulOp, xla_hlo::MulOp); + POPULATE_BCAST(BroadcastOrOp, xla_hlo::OrOp); + POPULATE_BCAST(BroadcastPowOp, xla_hlo::PowOp); + POPULATE_BCAST(BroadcastRemOp, xla_hlo::RemOp); + POPULATE_BCAST(BroadcastShiftLeftOp, xla_hlo::ShiftLeftOp); + POPULATE_BCAST(BroadcastShiftRightArithmeticOp, + xla_hlo::ShiftRightArithmeticOp); + POPULATE_BCAST(BroadcastShiftRightLogicalOp, xla_hlo::ShiftRightLogicalOp); + POPULATE_BCAST(BroadcastSubOp, xla_hlo::SubOp); + POPULATE_BCAST(BroadcastXorOp, xla_hlo::XorOp); + + // Broadcasting ops requiring special construction. + PopulateForBinaryOp(context, patterns); + PopulateForBinaryOp(context, patterns); +} + +} // namespace xla_chlo +} // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc new file mode 100644 index 0000000..0e5d5b1 --- /dev/null +++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc @@ -0,0 +1,57 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" + +namespace mlir { +namespace xla_chlo { + +namespace { + +struct TestChloLegalizeToHloPass + : public PassWrapper { + void runOnFunction() override { + ConversionTarget conversionTarget(getContext()); + OwningRewritePatternList conversionPatterns; + + conversionTarget.addIllegalDialect(); + // Consider the xla_hlo dialect legal for tests. + conversionTarget.addLegalDialect(); + // The conversion uses helpers from the Standard dialect. + conversionTarget.addLegalDialect(); + conversionTarget.addLegalDialect(); + + PopulateLegalizeChloToHloPatterns(&getContext(), &conversionPatterns); + + if (failed(applyPartialConversion(getFunction(), conversionTarget, + conversionPatterns))) { + return signalPassFailure(); + } + } +}; + +} // namespace + +} // namespace xla_chlo +} // namespace mlir + +static mlir::PassRegistration pass( + "test-xla-chlo-legalize-to-hlo", + "Test pass for applying chlo -> hlo legalization patterns"); diff --git a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc new file mode 100644 index 0000000..6fd5805 --- /dev/null +++ b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -0,0 +1,493 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements logic for lowering HLO dialect to LHLO dialect. 
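//
// For example (an illustrative sketch of the rewrite, not verbatim output),
// a statically shaped tensor op such as
//   %0 = "xla_hlo.add"(%a, %b)
//          : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// becomes an allocation plus its buffer counterpart:
//   %out = alloc() : memref<4xf32>
//   "xla_lhlo.add"(%a_buf, %b_buf, %out)
//          : (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()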
+ +#include "third_party/absl/memory/memory.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/AffineMap.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BlockAndValueMapping.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/BufferPlacement.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" + +namespace mlir { +namespace xla_hlo { +namespace { + +template +using BaseOpConversion = BufferAssignmentOpConversionPattern; +using StdReturnOpConverter = + detail::BufferAssignmentReturnOpConverter; + +Value InsertDynamicAllocAndDealloc(Location loc, Value result, + Value shape_operand, + ConversionPatternRewriter* rewriter) { + auto result_type = result.getType().dyn_cast(); + if (!result_type) { + result.getDefiningOp()->emitOpError() + << "tensor to buffer conversion expects ranked results"; + } + auto memref_type = + MemRefType::get(result_type.getShape(), result_type.getElementType()); + + Operation* op = result.getDefiningOp(); + + // Extract the required element out of the vector. + SmallVector dynamic_operands; + for (auto shape_element : llvm::enumerate(result_type.getShape())) { + if (shape_element.value() != ShapedType::kDynamicSize) continue; + Value index = rewriter->create( + loc, rewriter->getIntegerAttr(rewriter->getIndexType(), + shape_element.index())); + Value alloc_operand = rewriter->create(loc, shape_operand, + ValueRange{index}); + if (!alloc_operand.getType().isIndex()) { + alloc_operand = rewriter->create(loc, alloc_operand, + rewriter->getIndexType()); + } + dynamic_operands.push_back(alloc_operand); + } + + // Insert in front of op to ensure sizes are available. 
+ OpBuilder allocBuilder(op); + auto alloc = allocBuilder.create(loc, memref_type, dynamic_operands); + return alloc; +} + +Value InsertAlloc(Location loc, OpResult result, + BufferAssignmentPlacer* bufferAssignment, + ConversionPatternRewriter* rewriter) { + auto result_type = result.getType().dyn_cast(); + if (!result_type || !result_type.hasStaticShape()) { + result.getDefiningOp()->emitOpError() + << "tensor to buffer conversion expects statically shaped results"; + } + auto memref_type = + MemRefType::get(result_type.getShape(), result_type.getElementType()); + OpBuilder::InsertionGuard guard(*rewriter); + rewriter->restoreInsertionPoint( + bufferAssignment->computeAllocPosition(result)); + auto alloc = rewriter->create(loc, memref_type); + return alloc; +} + +template +class HloToLhloOpConverter : public BaseOpConversion { + public: + using BaseOpConversion::BaseOpConversion; + LogicalResult matchAndRewrite( + HloOpTy hloOp, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + Operation* op = hloOp.getOperation(); + const auto& original_results = op->getResults(); + SmallVector buffer_args(operands.begin(), operands.end()); + for (auto result : llvm::enumerate(original_results)) { + RankedTensorType resultType = + result.value().getType().dyn_cast(); + if (!resultType) { + return failure(); + } + if (resultType.hasStaticShape()) { + buffer_args.push_back(InsertAlloc(op->getLoc(), result.value(), + this->bufferAssignment, &rewriter)); + } else { + SmallVector results_shape; + auto shape_type_op = dyn_cast(op); + if (!shape_type_op) return failure(); + if (failed( + shape_type_op.reifyReturnTypeShapes(rewriter, results_shape))) + return failure(); + buffer_args.push_back(InsertDynamicAllocAndDealloc( + op->getLoc(), result.value(), results_shape.front(), &rewriter)); + } + } + rewriter.create>(op->getLoc(), llvm::None, + buffer_args, op->getAttrs()); + rewriter.replaceOp(op, ArrayRef(buffer_args).slice(operands.size())); + return success(); + } +}; + +struct HloToLhloDynamicBroadcastInDimOpConverter + : public BaseOpConversion { + public: + using BaseOpConversion::BaseOpConversion; + + LogicalResult matchAndRewrite( + xla_hlo::DynamicBroadcastInDimOp op, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + auto loc = op.getLoc(); + Value resultBuffer = InsertDynamicAllocAndDealloc( + loc, op.getResult(), op.output_dimensions(), &rewriter); + + Value transformed_operand = + InsertDynamicMemrefCastOp(op, operands.front(), &rewriter); + rewriter.create( + loc, transformed_operand, resultBuffer, op.broadcast_dimensions()); + + rewriter.replaceOp(op, {resultBuffer}); + + return success(); + } + + private: + // Inserts dynamic memref to change the layout of the memref to put 0-stride + // and size of the target dimension if size-1 dimension expansion is + // necessary. 
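+  // For example (illustrative shapes): broadcasting a memref<1x?xf32> operand
+  // into a ?x?xf32 result keeps the matching dynamic dimension with stride 1
+  // and gives the size-1 dimension the output size with stride 0, so every
+  // index along the expanded dimension aliases the same underlying element.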
+ xla_lhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp( + xla_hlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const { + auto loc = op.getLoc(); + auto operand_type = operand.getType().cast(); + auto operand_shape = operand_type.getShape(); + + SmallVector sizes, strides; + sizes.reserve(operand_shape.size()); + strides.reserve(operand_shape.size()); + + Value zero = b->create(loc, 0); + Value one = b->create(loc, 1); + for (auto dim : llvm::enumerate(op.broadcast_dimensions())) { + Value broadcast_dim_value = + b->create(loc, dim.value().getSExtValue()); + Value result_dim_size = b->create( + loc, op.output_dimensions(), broadcast_dim_value); + Value operand_dim_size = + ShapedType::isDynamic(operand_shape[dim.index()]) + ? b->create(loc, operand, dim.index()).getResult() + : b->create(loc, operand_shape[dim.index()]) + .getResult(); + + // TODO(pifon): Revisit if this cast is needed. Maybe we can use + // tensor for `output_dimensions` as well. + if (!result_dim_size.getType().isIndex()) { + result_dim_size = + b->create(loc, result_dim_size, b->getIndexType()); + } + + // There can be two cases: + // 1) Operand dim == result dim => expansion is not needed => stride := 1. + // 2) Operand dim < result dim => expansion is needed => stride := 0. + Value is_expansion = b->create(loc, CmpIPredicate::slt, + operand_dim_size, result_dim_size); + strides.push_back( + b->create(loc, is_expansion, zero, one)); + + // Size of input dim can be set to the size of the corresponding output + // dimension for both cases. + sizes.push_back(result_dim_size); + } + + // Type-erased memref type with static rank, dynamic sizes and strides. + SmallVector dynamic_layout(operand_shape.size(), + MemRefType::kDynamicStrideOrOffset); + SmallVector dynamic_shape(operand_shape.size(), + MemRefType::kDynamicSize); + auto type_erased_memref_type = MemRefType::get( + dynamic_shape, operand_type.getElementType(), + makeStridedLinearLayoutMap(dynamic_layout, + /*offset=*/0, b->getContext())); + + auto transformed_operand = b->create( + loc, type_erased_memref_type, operand, sizes, strides); + return transformed_operand; + } +}; + +struct HloToLhloReduceOpConverter : public BaseOpConversion { + public: + using BaseOpConversion::BaseOpConversion; + + LogicalResult matchAndRewrite( + xla_hlo::ReduceOp op, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + auto loc = op.getLoc(); + // TODO(b/137624192) Implement variadic reduce. + if (op.getNumResults() != 1) return failure(); + if (!llvm::hasSingleElement(op.body())) { + return op.emitOpError() + << "tensor to buffer conversion expects a single block " + "in the region containing the operation"; + } + const auto& original_results = op.getResults(); + SmallVector buffer_args(operands.begin(), operands.end()); + for (auto result : original_results) { + buffer_args.push_back( + InsertAlloc(loc, result, this->bufferAssignment, &rewriter)); + } + auto new_op = rewriter.create( + loc, llvm::None, buffer_args, op.getAttrs()); + + // Copy over the operations inside the region. + rewriter.inlineRegionBefore(op.body(), new_op.body(), new_op.body().end()); + + // Create new block arguments with correct type. 
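+    // For instance (a sketch), a reduction body with block arguments
+    // (%lhs: tensor<f32>, %rhs: tensor<f32>) is rewritten below to take
+    // (%lhs: memref<f32>, %rhs: memref<f32>, %out: memref<f32>), with an
+    // xla_lhlo.terminator appended at the end of the block.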
+ auto& entry_block = new_op.body().front(); + int original_arg_count = entry_block.getNumArguments(); + for (int i = 0; i < original_arg_count; ++i) { + auto old_arg = entry_block.getArgument(i); + auto old_type = old_arg.getType().cast(); + auto new_type = + MemRefType::get(old_type.getShape(), old_type.getElementType()); + auto new_arg = entry_block.addArgument(new_type); + rewriter.replaceUsesOfBlockArgument(old_arg, new_arg); + } + // Add an argument for the result. + entry_block.addArgument( + entry_block.getArgument(original_arg_count).getType()); + // Remove the old arguments. + for (int i = original_arg_count - 1; i >= 0; --i) { + entry_block.eraseArgument(i); + } + // Insert terminator at the end. + rewriter.setInsertionPointToEnd(&entry_block); + rewriter.create(loc); + + rewriter.replaceOp(op, ArrayRef(buffer_args).slice(operands.size())); + + return success(); + } +}; + +class HloToLhloTensorLoadOpConverter + : public BaseOpConversion { + public: + using BaseOpConversion::BaseOpConversion; + LogicalResult matchAndRewrite( + mlir::TensorLoadOp op, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + rewriter.replaceOp(op, operands); + return success(); + } +}; + +// TODO(b/137624192): Rewrite into a copy and elide copy if possible. +class HloToLhloTensorStoreOpConverter + : public BaseOpConversion { + public: + using BaseOpConversion::BaseOpConversion; + + LogicalResult matchAndRewrite( + mlir::TensorStoreOp op, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + rewriter.replaceOpWithNewOp( + op, llvm::None, operands.front(), operands.back()); + return success(); + } +}; + +// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary +// buffers if necessary. +// +// Example fusion with HLO ops. +// +// func @fusion(%arg0: memref<2x2xf32>, +// %arg1: memref<2x2xf32>, +// %arg2: memref<2x2xf32>, +// %arg3: memref<2x2xf32>) { +// "xla_lhlo.fusion"() ({ +// %0 = tensor_load %arg1 : memref<2x2xf32> +// %1 = tensor_load %arg2 : memref<2x2xf32> +// %2 = "xla_hlo.add"(%0, %1) : +// (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> +// %3 = tensor_load %arg0 : memref<2x2xf32> +// %4 = "xla_hlo.multiply"(%2, %3) : +// (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> +// tensor_store %4, %arg3 : memref<2x2xf32> +// "xla_lhlo.terminator"() : () -> () +// }) : () -> () +// return +// } +// +// Transformed fusion with LHLO ops. +// func @fusion(%arg0: memref<2x2xf32>, +// %arg1: memref<2x2xf32>, +// %arg2: memref<2x2xf32>, +// %arg3: memref<2x2xf32>) { +// "xla_lhlo.fusion"() ( { +// %0 = alloc() : memref<2x2xf32> +// "xla_lhlo.add"(%arg1, %arg2, %0) : +// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () +// "xla_lhlo.multiply"(%0, %arg0, %arg3) : +// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () +// "xla_lhlo.terminator"() : () -> () +// }) : () -> () +// return +// } +// +// FuncOp signature conversion example: +// +// func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { +// %0 = "xla_hlo.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> +// tensor<4xf32> %1 = "xla_hlo.add"(%arg0, %0) : (tensor<4xf32>, +// tensor<4xf32>) -> tensor<4xf32> return %1 : tensor<4xf32> +// } +// +// Transformed function with an extra argument for the result. The types have +// been converted from tensor to memref. 
+// +// func @func_op(%arg0: memref<4xf32>, +// %arg1: memref<4xf32>, +// %arg2: memref<4xf32>) { +// %0 = alloc() : memref<4xf32> + +// "xla_lhlo.maximum"(%arg0, %arg1, %0) : +// (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> () +// %1 = alloc() : memref<4xf32> +// "xla_lhlo.add"(%arg0, %0, %1) : +// (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> () +// "xla_lhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> () +// "xla_lhlo.terminator"() : () -> () +// } + +struct HloLegalizeToLhlo + : public PassWrapper> { + public: + HloLegalizeToLhlo() = default; + HloLegalizeToLhlo(const HloLegalizeToLhlo& o) { + this->results_escape_function = o.results_escape_function.getValue(); + } + explicit HloLegalizeToLhlo(bool results_escape_function) { + this->results_escape_function.setValue(results_escape_function); + } + + void runOnOperation() override { + OwningRewritePatternList patterns; + auto& context = getContext(); + ConversionTarget target(context); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addIllegalDialect(); + + BufferAssignmentTypeConverter converter; + target.addDynamicallyLegalOp([&](FuncOp op) { + auto inputs = op.getType().getInputs(); + return llvm::all_of(inputs, + [](Type input) { return input.isa(); }) && + converter.isLegal(&op.getBody()); + }); + target.addDynamicallyLegalOp([&](mlir::ReturnOp returnOp) { + return std::all_of(returnOp.operand_type_begin(), + returnOp.operand_type_end(), + [](Type type) { return type.isa(); }); + }); + + auto module = getOperation(); + WalkResult result = module.walk([&](FuncOp func) -> WalkResult { + BufferAssignmentPlacer bufferAssignment(func); + OwningRewritePatternList patterns; + populateHLOToLHLOConversionPattern(func.getContext(), &bufferAssignment, + &converter, &patterns); + if (results_escape_function) { + populateWithBufferAssignmentOpConversionPatterns< + mlir::ReturnOp, mlir::ReturnOp, xla_lhlo::CopyOp, + /*allowMemrefFunctionResults=*/true>(&context, &bufferAssignment, + &converter, &patterns); + } else { + populateWithBufferAssignmentOpConversionPatterns< + mlir::ReturnOp, mlir::ReturnOp, xla_lhlo::CopyOp, + /*allowMemrefFunctionResults=*/false>(&context, &bufferAssignment, + &converter, &patterns); + } + return applyPartialConversion(func, target, patterns); + }); + if (result.wasInterrupted()) { + signalPassFailure(); + } + } + + private: + Option results_escape_function{ + *this, "results-escape-function", + llvm::cl::desc( + "Allocate the results of functions within the functions body"), + llvm::cl::init(false)}; +}; +} // namespace + +void populateHLOToLHLOConversionPattern( + MLIRContext* context, BufferAssignmentPlacer* bufferAssignment, + TypeConverter* converter, OwningRewritePatternList* patterns) { + // clang-format off + patterns->insert< + HloToLhloDynamicBroadcastInDimOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, 
+ HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloOpConverter, + HloToLhloReduceOpConverter, + HloToLhloTensorLoadOpConverter, + HloToLhloTensorStoreOpConverter + >(context, bufferAssignment, converter); + // clang-format on +} + +std::unique_ptr> createLegalizeToLhloPass( + bool results_escape_function) { + return absl::make_unique(results_escape_function); +} + +static PassRegistration legalize_pass( + "hlo-legalize-to-lhlo", "Legalize from HLO dialect to LHLO dialect"); + +} // namespace xla_hlo +} // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/legalize_control_flow.cc b/lib/Dialect/mhlo/transforms/legalize_control_flow.cc new file mode 100644 index 0000000..87910af --- /dev/null +++ b/lib/Dialect/mhlo/transforms/legalize_control_flow.cc @@ -0,0 +1,237 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements logic for lowering XLA dialect to Standard dialect. + +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h" +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h" +#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/Casting.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Block.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BlockAndValueMapping.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/PassRegistry.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LogicalResult.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" + +using mlir::PassRegistration; + +namespace mlir { +namespace xla_hlo { +namespace { +struct LegalizeControlFlow + : public mlir::PassWrapper { + // Perform the lowering to MLIR control flow. + void runOnFunction() override; +}; + +// Replaces terminators for the newly created blocks from a targe region. +// These terminators are replaced with branch operations to a target block. 
+LogicalResult ReplaceTerminators(Region* region, Block* target_block, + Location loc, + const BlockAndValueMapping& mapper, + OpBuilder* builder) { + for (auto& old_block : region->getBlocks()) { + Block* block = mapper.lookup(&old_block); + auto return_op = dyn_cast(block->getTerminator()); + if (!return_op) continue; + builder->setInsertionPointToEnd(block); + builder->create(loc, target_block, return_op.getOperands()); + return_op.erase(); + } + + return success(); +} + +LogicalResult LowerIfOp(mlir::xla_hlo::IfOp if_op) { + Operation* op_inst = if_op.getOperation(); + mlir::OpBuilder builder(if_op); + auto orig_block = op_inst->getBlock(); + auto* tail_block = orig_block->splitBlock(op_inst); + auto loc = if_op.getLoc(); + + // Duplicate the true and false regions in the block between the sections + // before and after the conditional. + BlockAndValueMapping mapper; + if_op.true_branch().cloneInto(orig_block->getParent(), + Region::iterator(tail_block), mapper); + if_op.false_branch().cloneInto(orig_block->getParent(), + Region::iterator(tail_block), mapper); + + // Determine the blocks for the start of the true and false regions. + Block* true_block = mapper.lookup(&if_op.true_branch().front()); + Block* false_block = mapper.lookup(&if_op.false_branch().front()); + + // Perform the conditional branch into the true/false cases. + builder.setInsertionPointToEnd(orig_block); + + // Extract the predicate for checking branching, then branch to the true and + // false regions appropriately. + auto cond_value = builder.create(loc, if_op.pred()); + builder.create(loc, cond_value, true_block, + if_op.true_arg(), false_block, + if_op.false_arg()); + + // Replace the true case's return operations with a branch to the tail of + // the condition. + if (failed(ReplaceTerminators(&if_op.true_branch(), tail_block, loc, mapper, + &builder))) + return failure(); + if (failed(ReplaceTerminators(&if_op.false_branch(), tail_block, loc, mapper, + &builder))) + return failure(); + + tail_block->addArguments(if_op.getResult().getType()); + if_op.getResult().replaceAllUsesWith(tail_block->getArgument(0)); + + op_inst->erase(); + return success(); +} + +LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) { + // Converts an XLA while loop into control flow. This generates a set of MLIR + // blocks and branches, along with inlining the regions provided by the XLA + // while loop. The structure should be similar to below: + // + // + // %0 = "xla_hlo.while"(%arg0) {^cond(...){...}, ^body(...){...}} + // + auto* op_inst = while_op.getOperation(); + mlir::OpBuilder builder(while_op); + auto loc = while_op.getLoc(); + + // Break the block into four sections: + // orig_block - operations before the while and the branch into looping check. + // tail_block - operations after the while loop completes. + // cond_block - check the looping condition, then conditionally branch into + // the loop or, if condition is false, jump to the tail branch. + // body_block - inlined loop body, then jump back to the condition block. + auto* orig_block = op_inst->getBlock(); + auto* tail_block = orig_block->splitBlock(op_inst); + + BlockAndValueMapping mapper; + while_op.cond().cloneInto(orig_block->getParent(), + Region::iterator(tail_block), mapper); + while_op.body().cloneInto(orig_block->getParent(), + Region::iterator(tail_block), mapper); + + // Lookup the entry blocks for both condition and body. 
+ auto* cond_block = mapper.lookup(&while_op.cond().front()); + auto* body_block = mapper.lookup(&while_op.body().front()); + + // Setup the end of the original block: + // + // br ^cond(%arg0) // Jumps to the condition statement. + builder.setInsertionPointToEnd(orig_block); + builder.create(loc, cond_block, while_op.getOperand()); + + // Updates the inlined condition blocks by replacing the return op with an + // extract_element and conditional branch. This changes the block below: + // ^cond(%0): + // + // "xla_hlo".return(%1) + // + // Into: + // ^cond(%0): + // + // %2 = extract_element %1[] : tensor // Extract the condition value. + // cond_br %2, ^body(%0), ^tail(%0) // Branch. + builder.setInsertionPointToStart(cond_block); + + // Replace the xla_hlo::ReturnOp with a branch back to the condition block. + // This is required as the xla_hlo::ReturnOp is used to mark the end of a + // block for regions nested inside of a operations (MLIR ReturnOp cannot be + // nested within an non-function region). + for (auto& block : while_op.cond()) { + auto new_block = mapper.lookup(&block); + + auto return_op = dyn_cast(new_block->getTerminator()); + if (!return_op) continue; + builder.setInsertionPointToEnd(new_block); + + auto return_value = return_op.getOperand(0); + auto cond_value = builder.create(loc, return_value); + + // Get the body block arguments. + llvm::SmallVector successor_args(cond_block->args_begin(), + cond_block->args_end()); + builder.create(loc, cond_value, body_block, + successor_args, tail_block, + successor_args); + return_op.erase(); + } + + // Updates the body blocks by replace the return op with an branch to the + // conditional block. This changes the block below: + // ^body(%0): + // + // "xla_hlo".return(%1) + // + // Into: + // ^body(%0): + // + // br ^cond(%0) // Branch. + for (auto& block : while_op.body()) { + auto new_block = mapper.lookup(&block); + auto return_op = + dyn_cast(new_block->getTerminator()); + if (!return_op) continue; + builder.setInsertionPointToEnd(new_block); + builder.create(loc, cond_block, return_op.getOperands()); + return_op.erase(); + } + + // Erase the original while loop. + tail_block->addArgument(while_op.getType()); + while_op.getResult().replaceAllUsesWith(tail_block->getArgument(0)); + op_inst->erase(); + + return success(); +} + +void LegalizeControlFlow::runOnFunction() { + auto func = getFunction(); + llvm::SmallVector if_ops; + func.walk([&](IfOp op) { if_ops.push_back(op); }); + + for (auto& op : if_ops) { + if (failed(LowerIfOp(op))) return signalPassFailure(); + } + + llvm::SmallVector while_ops; + func.walk([&](WhileOp op) { while_ops.push_back(op); }); + + for (auto& op : while_ops) { + if (failed(LowerWhileOp(op))) return signalPassFailure(); + } +} +} // namespace +} // namespace xla_hlo +} // namespace mlir + +std::unique_ptr> +mlir::xla_hlo::createLegalizeControlFlowPass() { + return std::make_unique(); +} + +static PassRegistration legalize_cf_pass( + "xla-legalize-control-flow", + "Legalize from XLA control flow to MLIR control flow"); diff --git a/lib/Dialect/mhlo/transforms/legalize_tanh_to_approximation.cc b/lib/Dialect/mhlo/transforms/legalize_tanh_to_approximation.cc new file mode 100644 index 0000000..fbc5f50 --- /dev/null +++ b/lib/Dialect/mhlo/transforms/legalize_tanh_to_approximation.cc @@ -0,0 +1,156 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements logic for lowering the tanh standard ops to an +// approximation. + +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" + +namespace mlir { +namespace xla { +namespace { + +/// Emits the fast tanh approximation that is also used by XLA. +Value EmitTanhApproximation(Value input, Location loc, + PatternRewriter &rewriter) { + // For small values of x, we can approximate tanh(x)=x. For extremely small + // values of x (|x| < 1e-37), the other approximation would evaluate + // tanh(x) = 0. + constexpr float kCanUseApprox = 0.0004; + Value abs_value = rewriter.create(loc, input); + Value can_use_approx = + rewriter.create(loc, rewriter.getF32FloatAttr(kCanUseApprox)); + Value return_input = rewriter.create(loc, CmpFPredicate::OLT, + abs_value, can_use_approx); + // Clamp the input to [-c, c]. 
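+  // Here c = 7.90531110763549805. Taken together, the code below computes
+  //   tanh(x) ~= x                                            if |x| < 0.0004
+  //   tanh(x) ~= t * P(t*t) / Q(t*t), t = clamp(x, -c, c)     otherwise,
+  // where P and Q are the degree-6 and degree-3 polynomials (in t*t) defined
+  // by numerator_coeffs and denominator_coeffs below.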
+ Value max_clamp = rewriter.create( + loc, rewriter.getF32FloatAttr(7.90531110763549805f)); + Value smaller_than_max = + rewriter.create(loc, CmpFPredicate::ULE, input, max_clamp); + Value clamped_half = + rewriter.create(loc, smaller_than_max, input, max_clamp); + Value min_clamp = rewriter.create( + loc, rewriter.getF32FloatAttr(-7.90531110763549805f)); + Value larger_than_min = + rewriter.create(loc, CmpFPredicate::UGE, clamped_half, min_clamp); + Value input_clamped = + rewriter.create(loc, larger_than_min, clamped_half, min_clamp); + + static constexpr std::array numerator_coeffs{ + -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f, + 5.12229709037114e-08f, 1.48572235717979e-05f, 6.37261928875436e-04f, + 4.89352455891786e-03f}; + + static constexpr std::array denominator_coeffs{ + 1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f, + 4.89352518554385e-03f}; + + Value input_squared = + rewriter.create(loc, input_clamped, input_clamped); + Value numerator = rewriter.create( + loc, rewriter.getF32FloatAttr(numerator_coeffs[0])); + for (int i = 1; i < numerator_coeffs.size(); i++) { + numerator = rewriter.create( + loc, rewriter.create(loc, input_squared, numerator), + rewriter.create( + loc, rewriter.getF32FloatAttr(numerator_coeffs[i]))); + } + + numerator = rewriter.create(loc, input_clamped, numerator); + + Value denominator = rewriter.create( + loc, rewriter.getF32FloatAttr(denominator_coeffs[0])); + for (int i = 1; i < denominator_coeffs.size(); i++) { + denominator = rewriter.create( + loc, rewriter.create(loc, input_squared, denominator), + rewriter.create( + loc, rewriter.getF32FloatAttr(denominator_coeffs[i]))); + } + + Value approx = rewriter.create(loc, numerator, denominator); + + return rewriter.create(loc, return_input, input, approx); +} + +class ApproximateTanhLowering : public OpRewritePattern { + public: + explicit ApproximateTanhLowering(MLIRContext *ctx) + : OpRewritePattern(ctx, 100) {} + + LogicalResult matchAndRewrite(TanhOp tanhOp, + PatternRewriter &rewriter) const override { + Type operand_type = tanhOp.getType(); + + if (operand_type.isF64()) { + // Similar to XLA, do not rewrite f64 as precision might matter. + return failure(); + } + + Location loc = tanhOp.getLoc(); + Value input = tanhOp.operand(); + if (operand_type.isF16()) { + input = rewriter.create(loc, input, rewriter.getF32Type()); + } + + // If we still do not have f32, fail. + if (!input.getType().isF32()) { + return failure(); + } + + Value result = EmitTanhApproximation(input, loc, rewriter); + + // Truncate back if needed. + if (operand_type.isF16()) { + result = rewriter.create(loc, result, rewriter.getF16Type()); + } + + rewriter.replaceOp(tanhOp, {result}); + return success(); + } +}; + +struct LegalizeTanhToApproximation + : public PassWrapper { + /// Perform the lowering of standard dialect operations to approximations. 
+ void runOnFunction() override { + OwningRewritePatternList patterns; + PopulateTanhToApproximationPatterns(&getContext(), &patterns); + applyPatternsAndFoldGreedily(getFunction(), patterns); + } +}; + +} // anonymous namespace + +std::unique_ptr> +createLegalizeTanhToApproximationPass() { + return std::make_unique(); +} + +void PopulateTanhToApproximationPatterns(mlir::MLIRContext *context, + OwningRewritePatternList *patterns) { + patterns->insert(context); +} + +static PassRegistration legalize_pass( + "xla-legalize-tanh-to-approximation", + "Legalize tanh from standard dialect to an approximation"); + +} // namespace xla +} // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/legalize_to_standard.cc b/lib/Dialect/mhlo/transforms/legalize_to_standard.cc new file mode 100644 index 0000000..f4e7b49 --- /dev/null +++ b/lib/Dialect/mhlo/transforms/legalize_to_standard.cc @@ -0,0 +1,208 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements logic for lowering XLA dialect to Standard dialect. + +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" + +namespace mlir { +namespace { +#include "third_party/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_legalize_to_standard.inc" +} // end anonymous namespace +namespace xla_hlo { +namespace { + +class CompareIConvert : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(xla_hlo::CompareOp op, + PatternRewriter &rewriter) const override { + auto lhs = op.lhs(); + auto rhs = op.rhs(); + auto lhs_type = lhs.getType().cast(); + auto rhs_type = rhs.getType().cast(); + + // Broadcasting not supported by this rewrite. 
+ if (lhs_type.getShape() != rhs_type.getShape()) return failure(); + + if (!lhs_type.getElementType().isSignlessInteger() || + !rhs_type.getElementType().isSignlessInteger()) + return failure(); + + auto comparison_direction = op.comparison_direction(); + auto compare_predicate = + llvm::StringSwitch>(comparison_direction) + .Case("EQ", CmpIPredicate::eq) + .Case("NE", CmpIPredicate::ne) + .Case("LT", CmpIPredicate::slt) + .Case("LE", CmpIPredicate::sle) + .Case("GT", CmpIPredicate::sgt) + .Case("GE", CmpIPredicate::sge) + .Default(llvm::None); + + if (!compare_predicate.hasValue()) return failure(); + + rewriter.replaceOpWithNewOp(op, compare_predicate.getValue(), lhs, + rhs); + return success(); + } +}; + +class CompareFConvert : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(xla_hlo::CompareOp op, + PatternRewriter &rewriter) const override { + auto lhs = op.lhs(); + auto rhs = op.rhs(); + auto lhs_type = lhs.getType().cast(); + auto rhs_type = rhs.getType().cast(); + + // Broadcasting not supported by this rewrite. + if (lhs_type.getShape() != rhs_type.getShape()) return failure(); + + if (!lhs_type.getElementType().isa() || + !rhs_type.getElementType().isa()) + return failure(); + + auto comparison_direction = op.comparison_direction(); + auto compare_predicate = + llvm::StringSwitch>(comparison_direction) + .Case("EQ", CmpFPredicate::OEQ) + .Case("NE", CmpFPredicate::UNE) + .Case("LT", CmpFPredicate::OLT) + .Case("LE", CmpFPredicate::OLE) + .Case("GT", CmpFPredicate::OGT) + .Case("GE", CmpFPredicate::OGE) + .Default(llvm::None); + + if (!compare_predicate.hasValue()) return failure(); + + rewriter.replaceOpWithNewOp(op, compare_predicate.getValue(), lhs, + rhs); + return success(); + } +}; + +// Replace IotaOp with an integer constant. A ConvertOp is added to +// convert the integer constant to iota result type. For complex types, the real +// part is replaced with the generated constant and the imaginary part is +// replaced with zero tensor. 
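+// For instance (illustrative shapes): an xla_hlo.iota with iota_dimension = 0
+// producing a tensor<3x2xf32> becomes roughly
+//   %cst = constant dense<[[0, 0], [1, 1], [2, 2]]> : tensor<3x2xi32>
+//   %res = "xla_hlo.convert"(%cst) : (tensor<3x2xi32>) -> tensor<3x2xf32>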
+class ConvertIotaOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(xla_hlo::IotaOp op, + PatternRewriter &rewriter) const override { + auto output_type = op.getType().cast(); + auto output_size = output_type.getNumElements(); + auto dimension = op.iota_dimension().getSExtValue(); + auto max_dim_size = output_type.getDimSize(dimension); + + auto element_type = output_type.getElementType(); + int bitwidth; + + auto complex_ty = element_type.dyn_cast(); + Type int_or_float_ty = element_type; + if (complex_ty) int_or_float_ty = complex_ty.getElementType(); + + bitwidth = int_or_float_ty.getIntOrFloatBitWidth(); + llvm::SmallVector values; + values.reserve(output_size); + + int64_t increase_stride = output_size; + for (int i = 0; i <= dimension; i++) { + increase_stride /= output_type.getDimSize(i); + } + + int64_t current_value = 0; + for (int i = 0; i < output_size; i++) { + int64_t value = (current_value / increase_stride) % max_dim_size; + values.push_back(APInt(bitwidth, value)); + ++current_value; + } + + auto int_shape_type = RankedTensorType::get( + output_type.getShape(), + IntegerType::get(bitwidth, rewriter.getContext())); + auto loc = op.getLoc(); + auto integer_const = rewriter.create( + loc, DenseIntElementsAttr::get(int_shape_type, values)); + + auto int_or_float_shape_ty = + RankedTensorType::get(output_type.getShape(), int_or_float_ty); + + auto iota_const = + rewriter.create(loc, int_or_float_shape_ty, integer_const); + + // For int/float types we are done, replace op and return. + if (!complex_ty) { + rewriter.replaceOp(op, iota_const.getResult()); + return success(); + } + + // For complex types, generate a constant tensor of zeroes for the imaginary + // part and use iota_const for real part. + auto zeroes = rewriter.create( + loc, DenseIntElementsAttr::get(int_shape_type, APInt(bitwidth, 0))); + auto imag_zeroes = + rewriter.create(loc, int_or_float_shape_ty, zeroes); + rewriter.replaceOpWithNewOp(op, iota_const, + imag_zeroes); + return success(); + } +}; + +} // end anonymous namespace + +namespace { +struct LegalizeToStandard + : public PassWrapper { + /// Perform the lowering to Standard dialect. + void runOnFunction() override; +}; +} // end anonymous namespace + +std::unique_ptr> createLegalizeToStdPass() { + return std::make_unique(); +} + +void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns, + mlir::MLIRContext *ctx) { + mlir::populateWithGenerated(ctx, patterns); + patterns->insert(ctx); +} + +/// Perform the lowering to standard dialect. +void LegalizeToStandard::runOnFunction() { + OwningRewritePatternList patterns; + mlir::xla_hlo::PopulateXlaToStdPatterns(&patterns, &getContext()); + applyPatternsAndFoldGreedily(getFunction(), patterns); +} + +static PassRegistration legalize_pass( + "xla-legalize-to-std", "Legalize from XLA dialect to standard dialect"); + +} // end namespace xla_hlo +} // end namespace mlir diff --git a/lib/Dialect/mhlo/transforms/legalize_to_standard_patterns.td b/lib/Dialect/mhlo/transforms/legalize_to_standard_patterns.td new file mode 100644 index 0000000..2d238eb --- /dev/null +++ b/lib/Dialect/mhlo/transforms/legalize_to_standard_patterns.td @@ -0,0 +1,71 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is the legalization pattern definition file for XLA to StandardOps. + +include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td" +include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td" +include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" + +//===----------------------------------------------------------------------===// +// Nullary op patterns. +//===----------------------------------------------------------------------===// + +def : Pat<(HLO_ConstOp ElementsAttr:$value), + (ConstantOp $value)>; + +//===----------------------------------------------------------------------===// +// Binary op patterns. +//===----------------------------------------------------------------------===// + +def IsSameSizePred : CPred< + "$0.getType().cast().getShape() " + "== $1.getType().cast().getShape()">; +def IsSameSizeConstraint : Constraint; + + +def : Pat<(HLO_AndOp HLO_PredTensor:$l, HLO_PredTensor:$r), + (AndOp $l, $r), + [(IsSameSizeConstraint $l, $r)]>; +def : Pat<(HLO_AddOp HLO_FpTensor:$l, HLO_FpTensor:$r), + (AddFOp $l, $r), + [(IsSameSizeConstraint $l, $r)]>; +def : Pat<(HLO_SubOp HLO_FpTensor:$l, HLO_FpTensor:$r), + (SubFOp $l, $r), + [(IsSameSizeConstraint $l, $r)]>; +def : Pat<(HLO_MulOp HLO_FpTensor:$l, HLO_FpTensor:$r), + (MulFOp $l, $r), + [(IsSameSizeConstraint $l, $r)]>; +def : Pat<(HLO_DivOp HLO_FpTensor:$l, HLO_FpTensor:$r), + (DivFOp $l, $r), + [(IsSameSizeConstraint $l, $r)]>; +def : Pat<(HLO_RemOp HLO_FpTensor:$l, HLO_FpTensor:$r), + (RemFOp $l, $r), + [(IsSameSizeConstraint $l, $r)]>; +def : Pat<(HLO_AddOp HLO_IntTensor:$l, HLO_IntTensor:$r), + (AddIOp $l, $r), + [(IsSameSizeConstraint $l, $r)]>; +def : Pat<(HLO_SubOp HLO_IntTensor:$l, HLO_IntTensor:$r), + (SubIOp $l, $r), + [(IsSameSizeConstraint $l, $r)]>; +def : Pat<(HLO_MulOp HLO_IntTensor:$l, HLO_IntTensor:$r), + (MulIOp $l, $r), + [(IsSameSizeConstraint $l, $r)]>; +def : Pat<(HLO_DivOp HLO_IntTensor:$l, HLO_IntTensor:$r), + (SignedDivIOp $l, $r), + [(IsSameSizeConstraint $l, $r)]>; +def : Pat<(HLO_RemOp HLO_IntTensor:$l, HLO_IntTensor:$r), + (SignedRemIOp $l, $r), + [(IsSameSizeConstraint $l, $r)]>; diff --git a/lib/Dialect/mhlo/transforms/lhlo_copy_removal.cc b/lib/Dialect/mhlo/transforms/lhlo_copy_removal.cc new file mode 100644 index 0000000..4fbd774 --- /dev/null +++ b/lib/Dialect/mhlo/transforms/lhlo_copy_removal.cc @@ -0,0 +1,105 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// This file implements a pass to remove redundant LHLO copy operations. + +#include "third_party/absl/memory/memory.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" + +namespace mlir { +namespace xla_lhlo { +namespace { + +// Removes LHLO copy operations that copy from allocated buffers to block +// arguments. All uses of each buffer are replaced with the corresponding block +// argument and the buffer is freed. Note that this pass only works in regions +// with a single block. +struct LhloCopyRemoval : mlir::PassWrapper> { + void runOnOperation() override { + llvm::SmallVector eraseList; + auto operation = getOperation(); + operation->walk([&](mlir::xla_lhlo::CopyOp copyOp) { + // If this region contains more than one block, then ignore this copy + // operation. + if (copyOp.getParentRegion()->getBlocks().size() > 1) { + return; + } + + mlir::Value fromOperand = copyOp.operand(); + mlir::Value toOperand = copyOp.output(); + + // If the fromOperand value is a block argument or the toOperand + // value is not a block argument, then ignore this copy operation. + if (!fromOperand.getDefiningOp() || toOperand.getDefiningOp()) { + return; + } + + // The copy operation removal is illegal if there is at least a single use + // of toOperand value that lies between the first use of fromOperand value + // and the copy operation. + auto fromOperandUsers = fromOperand.getUsers(); + auto firstUser = *fromOperandUsers.begin(); + for (auto op : fromOperandUsers) { + if (op->isBeforeInBlock(firstUser)) firstUser = op; + } + for (auto op : toOperand.getUsers()) { + if (op->isBeforeInBlock(copyOp) && firstUser->isBeforeInBlock(op)) { + return; + } + } + + // TODO(DFKI): Use live variable analysis to solve aliasing issues among + // block arguments. + + // Remove the associated alloc operation. + auto allocOp = fromOperand.getDefiningOp(); + eraseList.push_back(allocOp); + + // Iterate over all uses of the fromOperand to find the associated + // deallocOp (if any). + for (auto op : fromOperandUsers) { + if (isa(op)) { + eraseList.push_back(op); + break; + } + } + + // Replace all uses of the fromOperand with the toOperand. This rewires + // all references pointing to the original alloc operation to the new + // target operation in order to safely remove the copy op. + fromOperand.replaceAllUsesWith(toOperand); + copyOp.erase(); + }); + for (auto op : eraseList) { + op->erase(); + } + }; +}; + +} // namespace + +std::unique_ptr createLhloCopyRemovalPass() { + return absl::make_unique(); +} + +static PassRegistration copy_removal_pass( + "lhlo-copy-removal", "Removes redundant LHLO copy operations"); + +} // namespace xla_lhlo +} // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc b/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc new file mode 100644 index 0000000..c5b81ec --- /dev/null +++ b/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc @@ -0,0 +1,151 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements logic for fusing linalg ops obtained after LHLO +// lowering. + +#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" +#include "third_party/absl/memory/memory.h" +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h" +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/FoldUtils.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" + +namespace mlir { +namespace xla_lhlo { +namespace { + +using linalg::LinalgOp; + +class LhloFuseLinalg : public PassWrapper { + public: + LhloFuseLinalg() = default; + LhloFuseLinalg(const LhloFuseLinalg&) {} + LhloFuseLinalg(bool use_parallel_loops, llvm::ArrayRef tile_sizes) { + tile_sizes_ = tile_sizes; + use_parallel_loops_.setValue(use_parallel_loops); + } + + void runOnFunction() override { + auto func = getFunction(); + + // TODO(pifon): Remove assumption that the function has a single block. + if (!llvm::hasSingleElement(func)) { + emitError(func.getLoc(), "The function needs to have a single block."); + signalPassFailure(); + return; + } + + // The fusion in Linalg is currently possible only when the consumer op is + // tiled. In order to greedily fuse the ops, we have to start from the tiled + // root linalg ops, i.e. linalg ops that write to output buffers of the + // function or are returned in case of escaping allocations. + llvm::SmallDenseSet result_buffers; + for (auto func_arg : func.getArguments()) { + result_buffers.insert(func_arg); + } + for (auto& block : func) { + auto returnOp = mlir::dyn_cast(block.getTerminator()); + if (!returnOp) continue; + for (auto operand : returnOp.getOperands()) { + result_buffers.insert(operand); + } + } + MLIRContext* ctx = func.getContext(); + OpBuilder b(func); + OperationFolder folder(ctx); + func.walk([&](linalg::GenericOp generic_op) { + SmallVector tile_sizes(tile_sizes_.begin(), + tile_sizes_.end()); + if (tile_sizes.empty()) { + tile_sizes = SmallVector(generic_op.getNumLoops(), 1); + } + auto op = cast(generic_op.getOperation()); + for (const Value result : op.getOutputBuffers()) { + if (!result_buffers.count(result)) continue; + if (tileGenericOp(op, tile_sizes, &b)) { + generic_op.erase(); + return; + } + } + }); + auto patterns = linalg::getLinalgTilingCanonicalizationPatterns(ctx); + applyPatternsAndFoldGreedily(func, patterns); + + // Fuse producers of tiled linalg ops. 
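+    // The loop below walks the tiled linalg ops in reverse and, for each
+    // input, tries to pull its producer into the tiled loop nest via
+    // fuseProducerOf. Fused producers are collected in erase_set and erased
+    // after the traversal; the tiling canonicalization patterns are
+    // re-applied after each op is processed.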
+ llvm::SmallDenseSet erase_set; + SmallVector linalg_ops; + func.walk([&](LinalgOp op) { linalg_ops.push_back(op); }); + for (auto* op : llvm::reverse(linalg_ops)) { + for (unsigned id = 0, e = LinalgOp(op).getNumInputs(); id < e; ++id) { + linalg::Aliases aliases; + linalg::LinalgDependenceGraph graph(aliases, linalg_ops); + if (auto info = fuseProducerOf(b, op, id, graph, &folder)) { + auto originalOp = info->originalProducer.getOperation(); + erase_set.insert(originalOp); + auto originalOpInLinalgOpsVector = std::find_if( + linalg_ops.begin(), linalg_ops.end(), + [&](const Operation* op) { return op == originalOp; }); + *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); + } + } + + auto patterns = linalg::getLinalgTilingCanonicalizationPatterns(ctx); + applyPatternsAndFoldGreedily(func, patterns); + } + for (auto* e : erase_set) e->erase(); + } + + private: + bool tileGenericOp(LinalgOp op, ArrayRef tile_sizes, OpBuilder* b) { + auto loopType = use_parallel_loops_ + ? linalg::LinalgTilingLoopType::ParallelLoops + : linalg::LinalgTilingLoopType::Loops; + auto tiled_generic_op = linalg::tileLinalgOp(*b, op, + linalg::LinalgTilingOptions() + .setTileSizes(tile_sizes) + .setLoopType(loopType)); + return tiled_generic_op.hasValue(); + } + + Option use_parallel_loops_{ + *this, "use-parallel-loops", + llvm::cl::desc( + "Tiles GenericOp consumer to parallel loops before linalg fusion"), + llvm::cl::init(false)}; + + ListOption tile_sizes_{ + *this, "tile-sizes", + llvm::cl::desc( + "Tile sizes by which to tile linalg generic before linalg fusion"), + llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; +}; + +} // namespace + +std::unique_ptr> createLhloFuseLinalg( + bool use_parallel_loops, ArrayRef tile_sizes) { + return absl::make_unique(use_parallel_loops, tile_sizes); +} + +static PassRegistration legalize_pass( + "lhlo-fuse-linalg", + "Greedily fuse linalg ops obtained after LHLO lowering."); + +} // namespace xla_lhlo +} // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc new file mode 100644 index 0000000..f4354d1 --- /dev/null +++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc @@ -0,0 +1,161 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements logic for lowering LHLO dialect to Affine dialect. 
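+//
+// As a rough sketch (illustrative 4x3 static shapes), an element-wise
+// "xla_lhlo.add"(%lhs, %rhs, %out) on memref<4x3xf32> operands is rewritten
+// into a bounded affine loop nest:
+//
+//   affine.for %i = 0 to 4 {
+//     affine.for %j = 0 to 3 {
+//       %a = affine.load %lhs[%i, %j] : memref<4x3xf32>
+//       %b = affine.load %rhs[%i, %j] : memref<4x3xf32>
+//       %sum = addf %a, %b : f32
+//       affine.store %sum, %out[%i, %j] : memref<4x3xf32>
+//     }
+//   }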
+ +#include "third_party/absl/memory/memory.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h" + +namespace mlir { +namespace xla_lhlo { +namespace { + +// Builds an affine loop nest iterating from zeros to "upper_bounds" with unit +// steps, and populates the body of the innermost loop using "body_builder". +static void BuildBoundedAffineLoopNest( + OpBuilder& builder, Location location, ArrayRef upper_bounds, + function_ref body_builder) { + SmallVector lower_bounds(upper_bounds.size(), /*Value=*/0); + SmallVector steps(upper_bounds.size(), /*Value=*/1); + buildAffineLoopNest(builder, location, lower_bounds, upper_bounds, steps, + body_builder); +} + +struct DotOpConverter : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + // Supports only rank-2 tensors for LHS and RHS. + LogicalResult matchAndRewrite(DotOp op, + PatternRewriter& rewriter) const override { + Value lhs = op.lhs(); + Value rhs = op.rhs(); + MemRefType lhs_type = lhs.getType().cast(); + MemRefType rhs_type = rhs.getType().cast(); + Type element_type = lhs_type.getElementType(); + ArrayRef shape_lhs = lhs_type.getShape(); + ArrayRef shape_rhs = rhs_type.getShape(); + + if ((lhs_type.getRank() != 2) || (rhs_type.getRank() != 2)) { + return failure(); + } + + LogicalResult map_status = success(); + auto body_builder = [&](OpBuilder& builder, Location loc, ValueRange ivs) { + SmallVector lhs_indices{ivs[0], ivs[2]}, + rhs_indices{ivs[2], ivs[1]}, result_indices{ivs[0], ivs[1]}; + + auto l = builder.create(loc, lhs, lhs_indices); + auto r = builder.create(loc, rhs, rhs_indices); + auto result = + rewriter.create(loc, op.output(), result_indices); + Value op_result = xla_lhlo::XlaOpToStdScalarOp::map( + op, element_type, {l, r, result}, &builder); + map_status = success(op_result != nullptr); + if (failed(map_status)) return; + builder.create(loc, op_result, op.output(), + result_indices); + }; + + BuildBoundedAffineLoopNest(rewriter, op.getLoc(), + {shape_lhs[0], shape_rhs[1], shape_rhs[0]}, + body_builder); + if (failed(map_status)) return failure(); + + rewriter.eraseOp(op); + return success(); + } +}; + +template +struct BinaryOpConverter : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(LhloOpTy op, + PatternRewriter& rewriter) const override { + const auto& lhs = op.lhs(); + const auto& rhs = op.rhs(); + const auto& lhs_type = lhs.getType().template cast(); + const auto& rhs_type = rhs.getType().template cast(); + const auto& element_type = lhs_type.getElementType(); + + if (lhs_type.getShape() != rhs_type.getShape()) { + return failure(); + } + + LogicalResult map_status = success(); + auto body_builder = [&](OpBuilder& builder, Location loc, + ValueRange 
induction_vars) { + auto l = builder.create(loc, lhs, induction_vars); + auto r = builder.create(loc, rhs, induction_vars); + Value op_result = xla_lhlo::XlaOpToStdScalarOp::map( + op, element_type, {l, r}, &builder); + map_status = success(op_result != nullptr); + if (failed(map_status)) return; + rewriter.create(loc, op_result, op.out(), induction_vars); + }; + + BuildBoundedAffineLoopNest(rewriter, op.getLoc(), lhs_type.getShape(), + body_builder); + if (failed(map_status)) return failure(); + rewriter.eraseOp(op); + return success(); + } +}; + +void populateLHLOToAffineConversionPattern(MLIRContext* context, + OwningRewritePatternList* patterns) { + // clang-format off + patterns->insert< + BinaryOpConverter, + BinaryOpConverter, + BinaryOpConverter, + BinaryOpConverter, + BinaryOpConverter, + BinaryOpConverter, + BinaryOpConverter, + DotOpConverter>(context); + // clang-format on +} + +struct LhloLegalizeToAffine + : public PassWrapper { + void runOnFunction() override { + OwningRewritePatternList patterns; + auto func = getFunction(); + populateLHLOToAffineConversionPattern(func.getContext(), &patterns); + applyPatternsAndFoldGreedily(func, patterns); + } +}; + +} // namespace + +std::unique_ptr> createLegalizeToAffinePass() { + return absl::make_unique(); +} + +static PassRegistration legalize_pass( + "lhlo-legalize-to-affine", "Legalize from LHLO dialect to affine dialect"); + +} // namespace xla_lhlo +} // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc new file mode 100644 index 0000000..bb502ad --- /dev/null +++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc @@ -0,0 +1,196 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements logic for lowering LHLO dialect to GPU dialect. 
+ +#include + +#include "third_party/absl/memory/memory.h" +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/GPU/GPUDialect.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SCF/SCF.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BlockAndValueMapping.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h" + +namespace mlir { +namespace xla_lhlo { +namespace { + +// A simple translation of LHLO reduce operations to a corresponding gpu +// launch operation. The transformation does no tiling and also only supports +// 1d results. +class LhloReduceToGPULaunchConverter : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + ReduceOp reduce_op, ArrayRef args, + ConversionPatternRewriter& rewriter) const final { + auto loc = reduce_op.getLoc(); + // Only support 1d reductions for now. + int64_t size = 0; + for (auto result : reduce_op.out()) { + auto shaped_type = result.getType().dyn_cast(); + if (!shaped_type || shaped_type.getRank() != 1) { + return failure(); + } + auto dim_size = shaped_type.getDimSize(0); + if (size && size != dim_size) { + return failure(); + } + size = dim_size; + } + + auto reducing_dimension = *reduce_op.dimensions().int_value_begin(); + + // Require all inputs to have the same shape. + int64_t reduce_dim_size = 0; + for (auto input : reduce_op.operands()) { + auto shaped_type = input.getType().dyn_cast(); + if (!shaped_type || !shaped_type.hasStaticShape()) { + return failure(); + } + reduce_dim_size = + shaped_type.getDimSize(reducing_dimension.getSExtValue()); + } + + // Create a launch that is parallel in the result dimension. + auto block_size_x = rewriter.create( + loc, rewriter.getIndexType(), + rewriter.getIntegerAttr(rewriter.getIndexType(), size)); + auto one = rewriter.create( + loc, rewriter.getIndexType(), + rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); + auto launch_op = rewriter.create( + loc, one, one, one, block_size_x, one, one); + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToEnd(&launch_op.body().front()); + auto index = launch_op.getThreadIds().x; + + // Load the initial value and store it to the output. 
+ for (auto pair : llvm::zip(reduce_op.init_values(), reduce_op.out())) { + auto init_value = rewriter.create(loc, std::get<0>(pair)); + rewriter.create(loc, init_value, std::get<1>(pair), + ArrayRef{index}); + } + + // Insert a loop into the body to compute the reduction. The loop ranges + // from [0, dim). + auto zero = rewriter.create( + loc, rewriter.getIndexType(), + rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); + // TODO(b/137624192) Use dimOp to make it shape independent. + auto upper = rewriter.create( + loc, rewriter.getIndexType(), + rewriter.getIntegerAttr(rewriter.getIndexType(), reduce_dim_size)); + auto step = rewriter.create( + loc, rewriter.getIndexType(), + rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); + auto loop = rewriter.create(loc, zero, upper, step); + + rewriter.setInsertionPointToStart(loop.getBody()); + // Compute memrefs for the value to reduce. This makes it easier to just + // inline the body. + auto output = *reduce_op.out().begin(); + // TODO(herhut) Move this to the SliceOp builder. + auto resType = MemRefType::get( + llvm::None, output.getType().cast().getElementType(), + makeStridedLinearLayoutMap(llvm::None, + MemRefType::getDynamicStrideOrOffset(), + rewriter.getContext())); + auto accumulator = rewriter.create( + loc, resType, output, ArrayRef{launch_op.getThreadIds().x}); + llvm::SmallVector indexings; + auto input_buffer = *reduce_op.operands().begin(); + auto input_type = input_buffer.getType().cast(); + for (int64_t dim = 0; dim < input_type.getRank(); ++dim) { + indexings.push_back(dim == reducing_dimension + ? loop.getInductionVar() + : launch_op.getThreadIds().x); + } + // TODO(herhut) Move this to the SliceOp builder. + auto input = *reduce_op.operand_begin(); + auto rhs = rewriter.create( + loc, + MemRefType::get( + llvm::None, input_type.getElementType(), + makeStridedLinearLayoutMap(llvm::None, + MemRefType::getDynamicStrideOrOffset(), + rewriter.getContext())), + input, indexings); + + // Now copy over the actual body of the reduction, leaving out the + // terminator. + BlockAndValueMapping mapping; + mapping.map(reduce_op.body().front().getArgument(0), accumulator); + mapping.map(reduce_op.body().front().getArgument(1), rhs); + mapping.map(reduce_op.body().front().getArgument(2), accumulator); + for (auto& nested : reduce_op.body().front().without_terminator()) { + auto clone = rewriter.clone(nested, mapping); + for (auto pair : llvm::zip(nested.getResults(), clone->getResults())) { + mapping.map(std::get<0>(pair), std::get<1>(pair)); + } + } + + // Finally, insert the terminator for the launchOp.
+ rewriter.setInsertionPointToEnd(&launch_op.body().front()); + rewriter.create(loc); + } + + rewriter.eraseOp(reduce_op); + return success(); + }; +}; + +struct LhloLegalizeToGpu : public PassWrapper { + void runOnFunction() override { + OwningRewritePatternList patterns; + ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addIllegalOp(); + auto func = getFunction(); + patterns.insert(func.getContext()); + if (failed(applyPartialConversion(func, target, patterns))) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr> createLegalizeToGpuPass() { + return absl::make_unique(); +} + +static PassRegistration legalize_pass( + "lhlo-legalize-to-gpu", "Legalize from LHLO dialect to GPU dialect"); + +} // namespace xla_lhlo +} // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc new file mode 100644 index 0000000..bfd0148 --- /dev/null +++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc @@ -0,0 +1,136 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "third_party/llvm/llvm-project/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" + +namespace mlir { +namespace xla_lhlo { +namespace { + +struct StaticMemRefCastOpConverter + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult matchAndRewrite( + Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto cast_op = cast(op); + + StaticMemRefCastOp::Adaptor operands_adaptor(operands); + MemRefDescriptor sourceMemRef(operands_adaptor.operand()); + + MemRefType targetMemRefType = + cast_op.getResult().getType().cast(); + auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType) + .dyn_cast_or_null(); + if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) + return failure(); + // Create descriptor. + auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); + Type llvmTargetElementTy = desc.getElementType(); + // Set allocated ptr. + Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); + allocated = + rewriter.create(loc, llvmTargetElementTy, allocated); + desc.setAllocatedPtr(rewriter, loc, allocated); + // Set aligned ptr. 
+ Value ptr = sourceMemRef.alignedPtr(rewriter, loc); + ptr = rewriter.create(loc, llvmTargetElementTy, ptr); + desc.setAlignedPtr(rewriter, loc, ptr); + + // Fill size and stride descriptors in memref. + auto target_sizes = targetMemRefType.getShape(); + int64_t target_offset; + llvm::SmallVector target_strides; + if (failed((getStridesAndOffset(targetMemRefType, target_strides, + target_offset)))) + return failure(); + + // Copy offset of `targetMemRef`. + desc.setConstantOffset(rewriter, loc, target_offset); + for (int i = 0, e = targetMemRefType.getRank(); i < e; ++i) { + desc.setConstantSize(rewriter, loc, i, target_sizes[i]); + desc.setConstantStride(rewriter, loc, i, target_strides[i]); + } + rewriter.replaceOp(op, {desc}); + return success(); + } +}; + +struct DynamicMemRefCastOpConverter + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult matchAndRewrite( + Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + auto cast_op = cast(op); + + DynamicMemRefCastOp::Adaptor operands_adaptor(operands); + MemRefDescriptor sourceMemRef(operands_adaptor.operand()); + + MemRefType targetMemRefType = + cast_op.getResult().getType().cast(); + auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType) + .dyn_cast_or_null(); + if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) + return failure(); + // Create descriptor. + auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); + Type llvmTargetElementTy = desc.getElementType(); + // Set allocated ptr. + Value allocated = sourceMemRef.allocatedPtr(rewriter, loc); + allocated = + rewriter.create(loc, llvmTargetElementTy, allocated); + desc.setAllocatedPtr(rewriter, loc, allocated); + // Set aligned ptr. + Value ptr = sourceMemRef.alignedPtr(rewriter, loc); + ptr = rewriter.create(loc, llvmTargetElementTy, ptr); + desc.setAlignedPtr(rewriter, loc, ptr); + // Copy offset of `sourceMemRef`. + desc.setOffset(rewriter, loc, sourceMemRef.offset(rewriter, loc)); + + // Fill size and stride descriptors in memref. + if (!cast_op.sizes().empty()) { + auto sizes = operands_adaptor.sizes(); + auto strides = operands_adaptor.strides(); + for (int i = 0, e = targetMemRefType.getRank(); i < e; ++i) { + desc.setSize(rewriter, loc, i, sizes[i]); + desc.setStride(rewriter, loc, i, strides[i]); + } + } + rewriter.replaceOp(op, {desc}); + return success(); + } +}; + +} // namespace + +void PopulateLhloToLLVMConversionPatterns(const LowerToLLVMOptions &options, + LLVMTypeConverter *converter, + OwningRewritePatternList *patterns) { + patterns->insert( + *converter, options); +} + +} // namespace xla_lhlo +} // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc new file mode 100644 index 0000000..0fa52c0 --- /dev/null +++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc @@ -0,0 +1,59 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "third_party/llvm/llvm-project/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" + +namespace mlir { +namespace xla_lhlo { +namespace { + +class TestLhloToLLVMPass + : public ::mlir::PassWrapper> { + public: + void runOnOperation() override { + ModuleOp m = getOperation(); + + OwningRewritePatternList patterns; + LLVMTypeConverter converter(m.getContext()); + populateStdToLLVMConversionPatterns(converter, patterns); + PopulateLhloToLLVMConversionPatterns( + LowerToLLVMOptions::getDefaultOptions(), &converter, &patterns); + + ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addLegalOp(); + target.addIllegalDialect(); + + if (failed(applyFullConversion(m, target, patterns))) { + signalPassFailure(); + } + } +}; + +} // namespace + +static PassRegistration legalize_lhlo_pass( + "test-lhlo-legalize-to-llvm", "Legalize from LHLO dialect to LLVM."); + +} // namespace xla_lhlo +} // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc new file mode 100644 index 0000000..cb2451f --- /dev/null +++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc @@ -0,0 +1,731 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "third_party/absl/memory/memory.h" +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h" +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h" +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/SmallVector.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SCF/SCF.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" + +namespace mlir { +namespace xla_lhlo { +namespace { + +// Clones and adapts the code in `lhlo_block` that works on buffers and has a +// single output buffer to make it compatible with `operands` that have element +// types of the respective buffers. Returns the computed value. +// +// Example. For `operands` with (f32, i32) types and a block with LHLO ops and +// with signature: +// ^bb(%lhs: memref, %rhs: memref, %res: memref): +// +// +// inserts necessary alloc and store ops to compute and return result that has +// `i1` type. +Value ApplySingleResultLhloCode(Location loc, ValueRange operands, + Block* lhlo_block, OpBuilder* b) { + SmallVector arg_bufs; + for (auto arg_type : lhlo_block->getArgumentTypes()) { + arg_bufs.push_back(b->create(loc, arg_type.cast())); + } + for (auto operand : llvm::enumerate(operands)) { + b->create(loc, operand.value(), arg_bufs[operand.index()]); + } + // Clone the ops from `lhlo_block`. + BlockAndValueMapping mapping; + mapping.map(lhlo_block->getArguments(), arg_bufs); + for (auto& nested : lhlo_block->without_terminator()) { + auto clone = b->clone(nested, mapping); + mapping.map(nested.getResults(), clone->getResults()); + } + return b->create(loc, arg_bufs.back()); +} + +// Converts a block with LHLO ops and with signature: +// ^bb(%lhs: memref, %rhs: memref, %res: memref): +// into a reduction operator of scf.reduce by doing buffer allocation for +// scalar arguments and the result of `scf.reduce` to make it compatible with +// LHLO ops. +void ConvertToReductionOperator(Location loc, scf::ReduceOp reduce_op, + Block* lhlo_block, OpBuilder* b) { + Block& loop_reduce_op_body = reduce_op.reductionOperator().front(); + OpBuilder::InsertionGuard guard(*b); + b->setInsertionPointToStart(&loop_reduce_op_body); + b->create( + loc, ApplySingleResultLhloCode(loc, loop_reduce_op_body.getArguments(), + lhlo_block, b)); +} + +// Returns result of ConstantOp if `dim` is static, otherwise uses DimOp to +// extract dimension at runtime. +Value GetStaticOrDynamicDim(mlir::Location loc, Value shaped_value, + size_t dim_index, int64_t dim, OpBuilder* b) { + return dim == ShapedType::kDynamicSize + ? b->create(loc, shaped_value, dim_index).getResult() + : b->create(loc, dim); +} + +struct MappedIvs { + // False if the mapped indices are in the padding area, true otherwise. + Value in_bounds; + // Mapped indices. 
+ SmallVector ivs; +}; + +template +MappedIvs MapWindowIvsToInput(OpTy op, ValueRange ivs, ValueRange window_ivs, + OpBuilder* b) { + MappedIvs mapped_ivs; + + if (!op.window_strides().hasValue()) { + op.emitOpError("No window strides specified."); + } + auto window_strides = op.window_strides().getValue(); + + if (!op.padding().hasValue()) { + op.emitOpError("No padding specified."); + } + auto padding = op.padding().getValue(); + + auto loc = op.getLoc(); + auto operand = op.operand(); + auto operand_shape = operand.getType().template cast().getShape(); + + // `in_bounds` is false when the mapped indices are in the padding area. + mapped_ivs.in_bounds = b->create( + loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 1)); + for (unsigned i = 0, e = ivs.size(); i < e; ++i) { + auto stride = window_strides.template getValue(i); + auto pad_low = padding.template getValue({i, 0}); + + Value stride_val = b->create(loc, stride.getSExtValue()); + Value pad_low_val = b->create(loc, pad_low.getSExtValue()); + + Value center = b->create(loc, ivs[i], stride_val); + Value offset = b->create(loc, window_ivs[i], pad_low_val); + Value index = b->create(loc, center, offset); + Value upper_bound = + GetStaticOrDynamicDim(loc, operand, i, operand_shape[i], b); + // We must check whether 0 <= index_i < shape_i, as otherwise we are in + // the pad and then we have to use the neutral element for reduction. + // Equivalently, it can be computed as the unsigned comparison index_i < + // shape_i, since a negative value wraps to a large positive value. + mapped_ivs.in_bounds = b->create( + loc, mapped_ivs.in_bounds, + b->create(loc, CmpIPredicate::ult, index, upper_bound)); + mapped_ivs.ivs.push_back(index); + } + return mapped_ivs; +} + +// Returns scf::Parallel over a shaped value with static or dynamic shape. +scf::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value, + OpBuilder* b) { + Value zero = b->create(loc, 0); + Value one = b->create(loc, 1); + + ArrayRef shape = + shaped_value.getType().cast().getShape(); + SmallVector lower, upper, step; + for (auto dim : llvm::enumerate(shape)) { + upper.push_back( + GetStaticOrDynamicDim(loc, shaped_value, dim.index(), dim.value(), b)); + lower.push_back(zero); + step.push_back(one); + } + return b->create(loc, lower, upper, step); +} + +// Converts `xla_lhlo.ReduceOp` into two scf::ParallelOp and a scf::ReduceOp. +// The outer `ParallelOp` refers to the parallel loops if there are +// any. The inner `ParallelOp` refers to the reduction loops and `ReduceOp` +// contains the reduction operator.
+// +// Example: +// +// "xla_lhlo.reduce"(%buffer, %init_buf, %result) ( { +// ^bb0(%lhs: memref, %rhs: memref, %res: memref): +// +// } ) {dimensions = dense<[1]> : tensor<1xi64>} +// : (memref<100x10x5xf32>, memref, memref<100x5xf32>) -> () +// +// is roughly converted into: +// +// %init = load %init_buf[] : memref +// scf.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) { +// %result = scf.parallel (%j) = (%c0) to (%c10) step (%c1) init (%init) { +// %elem_to_reduce = load %buffer[%i, %j, %k] : memref<100x10x5xf32> +// scf.reduce(%elem_to_reduce) { +// ^bb0(%elem: f32, %acc: f32): // no predecessors +// elem_buf = alloc() : memref +// store %elem, elem_buf[] : memref +// acc_buf = alloc() : memref +// store %acc, acc_buf[] : memref +// +// %acc_result = load acc_buf[] : memref +// scf.reduce.return %acc_result : f32 +// } : f32 +// scf.yield +// } : f32 +// scf.yield +// } +class ReduceOpConverter : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + xla_lhlo::ReduceOp xla_reduce_op, ArrayRef /*args*/, + ConversionPatternRewriter& rewriter) const final { + // TODO(b/137624192) Implement variadic reduce. + if (xla_reduce_op.out().size() != 1) return failure(); + + scf::ReduceOp reduce_op = + CreateReduceOpInNestedParallelLoops(xla_reduce_op, &rewriter); + ConvertToReductionOperator(xla_reduce_op.getLoc(), reduce_op, + &xla_reduce_op.body().front(), &rewriter); + rewriter.replaceOp(xla_reduce_op, llvm::None); + return success(); + } + + private: + // Creates nested `scf.parallel` ops with `scf.reduce`. The outer ParallelOp + // refers to the parallel dimensions of `xla_reduce_op` if any and the inner + // ParallelOp refers to the reduction dimensions. The scf.reduce op is + // returned. + // + // If the reduction argument is a memref<100x10x5xf32> and the + // reduction is performed along dimension 1 then this method will generate + // + // %init = load %init_buf[] : memref + // scf.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) { + // %result = scf.parallel (%j) = (%c0) to (%c10) step (%c1) init (%init) { + // %elem_to_reduce = load %buffer[%i, %j, %k] : memref<100x10x5xf32> + // scf.reduce(%elem_to_reduce) { + // + // } : f32 + // scf.yield + // } : f32 + // scf.yield + // } + scf::ReduceOp CreateReduceOpInNestedParallelLoops( + xla_lhlo::ReduceOp xla_reduce_op, + ConversionPatternRewriter* rewriter) const { + auto loc = xla_reduce_op.getLoc(); + DenseSet reducing_dims; + for (const auto& rdim : xla_reduce_op.dimensions().getIntValues()) { + reducing_dims.insert(rdim.getSExtValue()); + } + + Value operand = *xla_reduce_op.operands().begin(); + Value out = *xla_reduce_op.out().begin(); + SmallVector parallel_lower, parallel_upper, parallel_step; + SmallVector reduce_lower, reduce_upper, reduce_step; + auto operand_shape = operand.getType().cast().getShape(); + for (auto dim : llvm::enumerate(operand_shape)) { + const bool is_reducing_dim = reducing_dims.count(dim.index()); + + Value ub = GetStaticOrDynamicDim(loc, operand, dim.index(), dim.value(), + rewriter); + Value lb = rewriter->create(loc, 0); + Value step = rewriter->create(loc, 1); + (is_reducing_dim ? reduce_lower : parallel_lower).push_back(lb); + (is_reducing_dim ? reduce_upper : parallel_upper).push_back(ub); + (is_reducing_dim ? reduce_step : parallel_step).push_back(step); + } + // Load initial value from memref. 
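+ // (In LHLO the init value lives in a buffer rather than as a scalar SSA
+ // value, so it is loaded first to seed the `init` operand of the inner
+ // scf.parallel below.)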
+ SmallVector init_value = { + rewriter->create(loc, *xla_reduce_op.init_values().begin())}; + // Outer ParallelOp is not needed if it is a reduction across all dims. + scf::ParallelOp outer; + if (!parallel_lower.empty()) { + outer = rewriter->create(loc, parallel_lower, + parallel_upper, parallel_step); + rewriter->setInsertionPointToStart(outer.getBody()); + } + scf::ParallelOp inner = rewriter->create( + loc, reduce_lower, reduce_upper, reduce_step, ValueRange(init_value)); + Value reduction_result = *inner.getResults().begin(); + + SmallVector out_indices; + if (outer != nullptr) { + out_indices.reserve(outer.getNumLoops()); + for (Value iv : outer.getInductionVars()) { + out_indices.push_back(iv); + } + } else { + out_indices.push_back(rewriter->create(loc, 0)); + } + + rewriter->create(loc, reduction_result, out, out_indices); + + // Load the element to reduce. + SmallVector indices; + indices.reserve(operand_shape.size()); + + if (outer) { + auto inner_ivs_it = inner.getInductionVars().begin(); + auto outer_ivs_it = outer.getInductionVars().begin(); + for (unsigned i = 0, e = operand_shape.size(); i < e; ++i) { + indices.push_back(reducing_dims.count(i) ? *inner_ivs_it++ + : *outer_ivs_it++); + } + } else { + indices = inner.getInductionVars(); + } + + rewriter->setInsertionPointToStart(inner.getBody()); + Value elem = rewriter->create( + loc, *xla_reduce_op.operands().begin(), indices); + return rewriter->create(loc, elem); + } +}; + +// Pseudocode: +// for each index O in output +// accumulator = neutral_value +// in_bounds = true +// for each index W in window +// for each dimension i from 0 to rank - 1 +// index = O[i] * stride[i] + W[i] - pad_low[i] +// in_bounds = in_bounds && (index `ult` shape[i]) +// I[i] = index +// if (in_bounds) +// value = input[I] +// else +// value = neutral_value +// accumulator = reduction_operator(output[O], value) +// output[O] = accumulator +// +// Converts `xla_lhlo.ReduceWindowOp` into two scf::ParallelOp and a +// scf::ReduceOp. +// The outer `ParallelOp` refers to the parallel loops that traverse the output +// buffer. The inner `ParallelOp` refers to the reduction loops that traverse +// reduction windows and `ReduceOp` contains the reduction operator.
+// +// Example: +// +// func @reduce_window(%arg: memref<112x112xf32>, +// %init: memref, +// %result: memref<56x56xf32>) { +// "xla_lhlo.reduce_window"(%arg, %init, %result) ( { +// ^bb0(%lhs: memref, %rhs: memref, %res: memref): +// "xla_lhlo.maximum"(%lhs, %rhs, %res) +// : (memref, memref, memref) -> () +// "xla_lhlo.terminator"() : () -> () +// }) { +// padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, +// window_dimensions = dense<[3, 3]> : tensor<2xi64>, +// window_strides = dense<[2, 2]> : tensor<2xi64> +// } : (memref<112x112xf32>, memref, memref<56x56xf32>) -> () +// return +// } +// +// is roughly converted into: +// +// %neutral_elem = load %init_buf[] : memref +// scf.parallel (%i, %j) = (%c0, %c0) to (%c56, %c56) step (%c1, %c1) { +// %result = scf.parallel (%iw, %jw) = (%c0, %c0) +// to (%c3, %c3) step (%c1, %c1) neutral_elem (%0) -> f32 { +// %in_bounds = +// %elem = load %operand[%computed_i, %computed_j] +// %elem_or_neutral = select %in_bounds, %elem, %neutral_elem : f32 +// scf.reduce(%elem_to_reduce) : f32 { +// ^bb0(%arg7: f32, %arg8: f32): +// +// } +// scf.yield +// } +// store %result, %output_buffer[%i, %j] : memref<56x56xf32> +// scf.yield +// } +// return +// } +class ReduceWindowOpConverter + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + xla_lhlo::ReduceWindowOp xla_reduce_window_op, ArrayRef /*args*/, + ConversionPatternRewriter& rewriter) const final { + scf::ParallelOp output_loop, window_loop; + std::tie(output_loop, window_loop) = + CreateParallelLoopsToTraverseOutputAndWindow(xla_reduce_window_op, + &rewriter); + + scf::ReduceOp reduce_op = CreateReduceOpInNestedParallelLoops( + xla_reduce_window_op, output_loop, window_loop, &rewriter); + + ConvertToReductionOperator(xla_reduce_window_op.getLoc(), reduce_op, + &xla_reduce_window_op.body().front(), &rewriter); + rewriter.replaceOp(xla_reduce_window_op, llvm::None); + return success(); + } + + private: + std::pair + CreateParallelLoopsToTraverseOutputAndWindow( + xla_lhlo::ReduceWindowOp xla_reduce_window_op, + ConversionPatternRewriter* rewriter) const { + auto loc = xla_reduce_window_op.getLoc(); + Value init_value = + rewriter->create(loc, xla_reduce_window_op.init_value()); + + Value zero = rewriter->create(loc, 0); + Value one = rewriter->create(loc, 1); + + // Create an outer parallel loop that spans the output of ReduceWindowOp. + Value xla_output = xla_reduce_window_op.out(); + auto output_loop = MakeLoopOverShape(loc, xla_output, rewriter); + + // Create a nested loop that traverses the window. 
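+ // (The loaded init value is passed as the scf.parallel `init` operand, so
+ // the scf.reduce created later accumulates into it.)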
+ SmallVector window_lower, window_upper, window_step; + rewriter->setInsertionPointToStart(output_loop.getBody()); + for (const auto& window_dim : xla_reduce_window_op.window_dimensions()) { + window_step.push_back(one); + window_lower.push_back(zero); + window_upper.push_back( + rewriter->create(loc, window_dim.getSExtValue())); + } + auto window_loop = rewriter->create( + loc, window_lower, window_upper, window_step, ValueRange(init_value)); + + Value reduction_result = *window_loop.getResults().begin(); + auto output_ivs = output_loop.getInductionVars(); + rewriter->create(loc, reduction_result, xla_output, output_ivs); + return std::make_pair(output_loop, window_loop); + } + + scf::ReduceOp CreateReduceOpInNestedParallelLoops( + xla_lhlo::ReduceWindowOp xla_reduce_window_op, + scf::ParallelOp output_loop, scf::ParallelOp window_loop, + ConversionPatternRewriter* rewriter) const { + rewriter->setInsertionPointToStart(window_loop.getBody()); + auto loc = xla_reduce_window_op.getLoc(); + + if (xla_reduce_window_op.base_dilations().hasValue() || + xla_reduce_window_op.window_dilations().hasValue()) { + xla_reduce_window_op.emitRemark( + "Lowering to parallel loops does not support `base_dilations` or " + "`window_dilations` attributes yet. The attributes will be ignored."); + } + + Value xla_operand = xla_reduce_window_op.operand(); + auto xla_operand_type = xla_operand.getType().cast(); + + // Compute ivs in 'arg' buffer and whether these ivs are in pad area or not. + MappedIvs mapped_ivs = MapWindowIvsToInput( + xla_reduce_window_op, output_loop.getInductionVars(), + window_loop.getInductionVars(), rewriter); + + auto elem_or_init = rewriter->create( + loc, xla_operand_type.getElementType(), mapped_ivs.in_bounds, + /*withElseRegion=*/true); + + OpBuilder then_builder = elem_or_init.getThenBodyBuilder(); + Value elem = then_builder.create( + loc, xla_reduce_window_op.operand(), mapped_ivs.ivs); + then_builder.create(loc, elem); + + OpBuilder else_builder = elem_or_init.getElseBodyBuilder(); + else_builder.create(loc, *window_loop.initVals().begin()); + + return rewriter->create(loc, + *elem_or_init.results().begin()); + } +}; + +// See the operation semantics in +// https://www.tensorflow.org/xla/operation_semantics#selectandscatter +// +// Pseudocode: +// scf.parallel(coordinates O in the output): +// output[O] = init +// scf.parallel(coordinates S in the source): +// selected_ivs = 0 +// selected_val = 0 +// initialized_flag = false +// scf.for (first dim W_1 in the window) +// iter_args (selected_ivs, selected_val, initialized_flag): +// ... 
+// scf.for (last dim W_N in the window): +// iter_args (selected_ivs, selected_val, initialized_flag): +// I = S * stride + W - pad_low +// if I within bounds of operand: +// if (initialized_flag): +// pred = select(selected_value, operand(I))): +// if (pred) +// selected_value = operand(I) +// selected_index = I +// else +// selected_value = operand(I) +// selected_index = I +// initialized_flag = true +// output(selected_index) = scatter(output(selected_index), source(S)) +class SelectAndScatterOpConverter + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + xla_lhlo::SelectAndScatterOp s_and_s_op, ArrayRef /*args*/, + ConversionPatternRewriter& rewriter) const final { + auto loc = s_and_s_op.getLoc(); + InitializeOutput(s_and_s_op, &rewriter); + scf::ParallelOp loop_over_src = + MakeLoopOverShape(loc, s_and_s_op.source(), &rewriter); + rewriter.setInsertionPointToStart(loop_over_src.getBody()); + + // Compute indices of the selected element in the window. + auto selected_ivs = SelectIvs(s_and_s_op, loop_over_src, &rewriter); + + // Load `source[selected_ivs]`. + auto src_elem = rewriter.create(loc, s_and_s_op.source(), + loop_over_src.getInductionVars()); + + // Compute `out[selected_ivs]` = scatter(out[selected_ivs], src_element)`. + auto rmw = rewriter.create(loc, s_and_s_op.out(), + selected_ivs); + OpBuilder rmw_builder = OpBuilder::atBlockEnd(rmw.getBody()); + auto acc_result = + ApplySingleResultLhloCode(loc, {src_elem, rmw.getCurrentValue()}, + &s_and_s_op.scatter().front(), &rmw_builder); + rmw_builder.create(loc, acc_result); + + rewriter.replaceOp(s_and_s_op, llvm::None); + return success(); + } + + private: + void InitializeOutput(xla_lhlo::SelectAndScatterOp s_and_s_op, + OpBuilder* b) const { + auto loc = s_and_s_op.getLoc(); + Value init_value = b->create(loc, s_and_s_op.init_value()); + + scf::ParallelOp loop_over_output = + MakeLoopOverShape(loc, s_and_s_op.out(), b); + OpBuilder::InsertionGuard guard(*b); + b->setInsertionPointToStart(loop_over_output.getBody()); + b->create(loc, init_value, s_and_s_op.out(), + loop_over_output.getInductionVars()); + } + + struct WindowLoops { + SmallVector selected_ivs; + SmallVector window_ivs; + scf::ForOp inner_loop; + }; + WindowLoops InsertWindowLoops(xla_lhlo::SelectAndScatterOp s_and_s_op, + scf::ParallelOp loop_over_src, + OpBuilder* b) const { + auto loc = s_and_s_op.getLoc(); + Value zero = b->create(loc, 0); + Value one = b->create(loc, 1); + + auto element_type = + s_and_s_op.out().getType().cast().getElementType(); + auto rank = loop_over_src.getNumLoops(); + + // `iter_args` = [iv_1, ..., iv_N, selected_value, is_initialized] + SmallVector iter_args(rank, zero); + iter_args.push_back(b->create( + loc, element_type, b->getFloatAttr(element_type, 0))); + iter_args.push_back(b->create( + loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 0))); + + // Create a nested loop that traverses the window. 
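+ // (Each scf.for in this nest forwards iter_args = [selected indices...,
+ // selected value, initialized flag], so the selection state is threaded
+ // through the whole window traversal.)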
+ OpBuilder::InsertPoint ip; + WindowLoops result; + for (const auto& window_dim : + s_and_s_op.window_dimensions()->getIntValues()) { + Value upper = b->create(loc, window_dim.getSExtValue()); + result.inner_loop = + b->create(loc, zero, upper, one, iter_args); + if (b->getInsertionBlock() == loop_over_src.getBody()) { + ip = b->saveInsertionPoint(); + result.selected_ivs = result.inner_loop.getResults().take_front(rank); + } else { + b->create(loc, result.inner_loop.getResults()); + } + b->setInsertionPointToStart(result.inner_loop.getBody()); + iter_args = ValueRange{result.inner_loop.getRegionIterArgs()}; + result.window_ivs.push_back(result.inner_loop.getInductionVar()); + } + b->restoreInsertionPoint(ip); + return result; + } + + // Adapter to store iteration arguments of sequential loops that perform + // select in a window. + class IterArgs { + public: + explicit IterArgs(ValueRange ivs_val_flag) : ivs_val_flag_(ivs_val_flag) {} + IterArgs(ValueRange ivs, Value value, Value flag) { + ivs_val_flag_ = ivs; + ivs_val_flag_.push_back(value); + ivs_val_flag_.push_back(flag); + } + + ArrayRef to_vector() const { return ivs_val_flag_; } + + // Indices of the currently selected value. + ArrayRef ivs() const { return to_vector().drop_back(2); } + // Currently selected value w.r.t. select() function. + Value value() const { return ivs_val_flag_.end()[-2]; } + // i1 flag if value() and ivs() were initialized. + Value is_init() const { return ivs_val_flag_.back(); } + + private: + // Vector that stores iv_1, ..., iv_N, value, init. + SmallVector ivs_val_flag_; + }; + + SmallVector SelectIvs(xla_lhlo::SelectAndScatterOp s_and_s_op, + scf::ParallelOp loop_over_src, + OpBuilder* b) const { + auto loc = s_and_s_op.getLoc(); + + WindowLoops window_loops = InsertWindowLoops(s_and_s_op, loop_over_src, b); + auto inner_loop_b = + OpBuilder::atBlockEnd(window_loops.inner_loop.getBody()); + + // Compute ivs in 'arg' buffer and whether these ivs are in the pad area. + MappedIvs mapped_ivs = + MapWindowIvsToInput(s_and_s_op, loop_over_src.getInductionVars(), + window_loops.window_ivs, &inner_loop_b); + + IterArgs ivs_val_flag(window_loops.inner_loop.getRegionIterArgs()); + + auto if_in_bounds = inner_loop_b.create( + loc, window_loops.inner_loop.getResultTypes(), mapped_ivs.in_bounds, + /*withElseRegion=*/true); + + // Case when we are inside boundaries of 'arg' and not in the pad area. + { + OpBuilder in_bounds_then_b = if_in_bounds.getThenBodyBuilder(); + auto select_or_init_results = SelectOrInitialize( + s_and_s_op, mapped_ivs.ivs, &ivs_val_flag, &in_bounds_then_b); + in_bounds_then_b.create(loc, select_or_init_results); + } + + // Case when we are in the pad. + { + OpBuilder in_bounds_else_b = if_in_bounds.getElseBodyBuilder(); + in_bounds_else_b.create(loc, ivs_val_flag.to_vector()); + } + + inner_loop_b.create(loc, if_in_bounds.getResults()); + return window_loops.selected_ivs; + } + + SmallVector SelectOrInitialize( + xla_lhlo::SelectAndScatterOp s_and_s_op, ArrayRef operand_ivs, + IterArgs* ivs_val_flag, OpBuilder* b) const { + auto loc = s_and_s_op.getLoc(); + Value true_i1 = b->create( + loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 1)); + + TypeRange iter_arg_types{ivs_val_flag->to_vector()}; + Value operand_elem = + b->create(loc, s_and_s_op.operand(), operand_ivs); + auto if_init = + b->create(loc, iter_arg_types, ivs_val_flag->is_init(), + /*withElseRegion=*/true); + // Init == true, i.e. iter args are already initialized with a selected + // element in boundaries of the operand. 
Select function has to be computed + // here. + { + OpBuilder if_init_then_b = if_init.getThenBodyBuilder(); + + auto& lhlo_select = s_and_s_op.select().front(); + Value pred = + ApplySingleResultLhloCode(loc, {operand_elem, ivs_val_flag->value()}, + &lhlo_select, &if_init_then_b); + + auto if_pred = if_init_then_b.create(loc, iter_arg_types, pred, + /*withElseRegion=*/true); + + // Pred == true, therefore pack newly selected ivs, val and init flag back + // to iter_args and return. + { + OpBuilder if_pred_then_b = if_pred.getThenBodyBuilder(); + if_pred_then_b.create( + loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector()); + } + + // Pred == false, therefore return old iter_args. + { + OpBuilder if_pred_else_b = if_pred.getElseBodyBuilder(); + if_pred_else_b.create(loc, ivs_val_flag->to_vector()); + } + + if_init_then_b.create(loc, if_pred.getResults()); + } + // Init == false, i.e. only pad was visited before and this is the first + // element in the boundaries of the operand. + { + OpBuilder if_init_else_b = if_init.getElseBodyBuilder(); + + if_init_else_b.create( + loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector()); + } + return if_init.getResults(); + } +}; + +struct LhloLegalizeToParallelLoops + : public PassWrapper { + void runOnFunction() override { + auto func = getFunction(); + + OwningRewritePatternList patterns; + // clang-format off + patterns.insert< + ReduceOpConverter, + ReduceWindowOpConverter, + SelectAndScatterOpConverter + >(func.getContext()); + // clang-format on + + ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addIllegalOp(); + + if (failed(applyPartialConversion(func, target, patterns))) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr> createLegalizeLhloToParallelLoopsPass() { + return absl::make_unique(); +} + +static PassRegistration legalize_lhlo_pass( + "lhlo-legalize-to-parallel-loops", + "Legalize from LHLO dialect to parallel loops."); + +} // namespace xla_lhlo +} // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/lower_complex.cc b/lib/Dialect/mhlo/transforms/lower_complex.cc new file mode 100644 index 0000000..2b85ef4 --- /dev/null +++ b/lib/Dialect/mhlo/transforms/lower_complex.cc @@ -0,0 +1,79 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements passes to convert complex operations to equivalent real +// value operations. This does not include removing complex values from function +// arguments or return types.
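+//
+// For example (illustrative only), an addition of two complex tensors is
+// rewritten as
+// xla_hlo.complex(add(real(lhs), real(rhs)), add(imag(lhs), imag(rhs)))
+// using the patterns generated from lower_complex_patterns.td.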
+ +#include +#include +#include +#include + +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/PassRegistry.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h" + +using mlir::FunctionPass; +using mlir::OwningRewritePatternList; +using mlir::PassRegistration; +using mlir::PassWrapper; + +namespace { +class LowerComplex : public PassWrapper { + public: + explicit LowerComplex() : PassWrapper() {} + + /// Performs the lowering to XLA dialect. + void runOnFunction() override; +}; +} // end anonymous namespace + +namespace mlir { +namespace xla { +namespace { + +#include "third_party/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_lower_complex.inc" + +} // end anonymous namespace + +void PopulateComplexLoweringPatterns(MLIRContext* context, + OwningRewritePatternList* patterns) { + populateWithGenerated(context, patterns); +} +} // end namespace xla +} // end namespace mlir + +// Lowers the complex operations that can be represented using other operations. +void LowerComplex::runOnFunction() { + // Add lowering patterns to the list. + OwningRewritePatternList patterns; + mlir::xla::PopulateComplexLoweringPatterns(&getContext(), &patterns); + + applyPatternsAndFoldGreedily(getFunction(), patterns); +} + +static PassRegistration pass( + "test-xla-lower-complex", + "Lower complex operations into non-complex operations"); diff --git a/lib/Dialect/mhlo/transforms/lower_complex_patterns.td b/lib/Dialect/mhlo/transforms/lower_complex_patterns.td new file mode 100644 index 0000000..93b5065 --- /dev/null +++ b/lib/Dialect/mhlo/transforms/lower_complex_patterns.td @@ -0,0 +1,109 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is the legalization pattern that converts complex operations into +// equivalent real value operations. 
+ +include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td" +include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td" +include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td" + +//===----------------------------------------------------------------------===// +// Binary op patterns. +//===----------------------------------------------------------------------===// + +// Add and subtraction are elementwise and can be distributed across the real +// and imaginary components. +foreach elementwiseOp = [HLO_AddOp, HLO_SubOp] in + def : Pat<(elementwiseOp HLO_ComplexTensor:$lhs, + HLO_ComplexTensor:$rhs), + (HLO_ComplexOp + (elementwiseOp (HLO_RealOp $lhs), (HLO_RealOp $rhs)), + (elementwiseOp (HLO_ImagOp $lhs), (HLO_ImagOp $rhs)))>; + +// Complex multiplication results in a cross product multiplication between the +// real and imaginary components such that: +// result.real = lhs.real * rhs.real - lhs.imag * rhs.imag +// result.imag = lhs.imag * rhs.real + lhs.real * rhs.imag +def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs, + HLO_ComplexTensor:$rhs), + (HLO_ComplexOp + (HLO_SubOp + (HLO_MulOp + (HLO_RealOp:$lhs_real $lhs), + (HLO_RealOp:$rhs_real $rhs)), + (HLO_MulOp + (HLO_ImagOp:$lhs_imag $lhs), + (HLO_ImagOp:$rhs_imag $rhs))), + (HLO_AddOp + (HLO_MulOp $lhs_real, $rhs_imag), + (HLO_MulOp $lhs_imag, $rhs_real)))>; + +// Multiplication between a complex and real tensor can be distributed by +// applying the real multiplicant to both the real and complex component. +// +// Note that the sourcep pattern is not legal according to the HLO dialect but +// instead handle intermediates generated by other patterns. +def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs), + (HLO_ComplexOp + (HLO_MulOp (HLO_RealOp $lhs), $rhs), + (HLO_MulOp (HLO_ImagOp $lhs), $rhs))>; + +def : Pat<(HLO_MulOp HLO_IntOrFpTensor:$lhs, HLO_ComplexTensor:$rhs), + (HLO_ComplexOp + (HLO_MulOp $lhs, (HLO_RealOp $rhs)), + (HLO_MulOp $lhs, (HLO_ImagOp $rhs)))>; + + +// Division is performed by normalizing the denominator by multiplying by the +// conjugate of the rhs. +// numerator = lhs * conj(rhs) +// denominator = rhs * conj(rhs) +def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_ComplexTensor:$rhs), + (HLO_DivOp + (HLO_MulOp:$num $lhs, + (HLO_ComplexOp:$conj + (HLO_RealOp $rhs), + (HLO_NegOp (HLO_ImagOp $rhs)))), + (HLO_RealOp:$den (HLO_MulOp $rhs, $conj)))>; + + +def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs), + (HLO_ComplexOp + (HLO_DivOp (HLO_RealOp $lhs), $rhs), + (HLO_DivOp (HLO_ImagOp $lhs), $rhs))>; + + +// Absolute value is evaluated as: +// result = sqrt(val.real * val.real + val.imag * val.imag) +def : Pat<(HLO_AbsOp HLO_ComplexTensor:$val), + (HLO_ComplexOp + (HLO_SqrtOp + (HLO_AddOp + (HLO_MulOp (HLO_RealOp:$real $val), $real), + (HLO_MulOp (HLO_ImagOp:$imag $val), $imag))), + (HLO_ConstOp (ConstantSplat<"0"> $real)))>; + +// Exponential can be lowered to an exponential on the real component and a +// sum of sinusoids of the imaginary component, which equates to a normal +// exponential operator multiplied by Euler's formula. 
+// +// Exp(a + ib) = Exp(a) * Exp(ib) = Exp(a) * (Cos(b) + iSin(b)) +def : Pat<(HLO_ExpOp HLO_ComplexTensor:$val), + (HLO_MulOp + (HLO_ExpOp (HLO_RealOp $val)), + (HLO_ComplexOp + (HLO_CosOp (HLO_ImagOp:$imag $val)), + (HLO_SinOp $imag)))>; diff --git a/lib/Dialect/mhlo/transforms/lower_general_dot.cc b/lib/Dialect/mhlo/transforms/lower_general_dot.cc new file mode 100644 index 0000000..4b38c34 --- /dev/null +++ b/lib/Dialect/mhlo/transforms/lower_general_dot.cc @@ -0,0 +1,194 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements logic for lowering XLA general dot to a regular dot. + +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h" +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" + +using mlir::DenseIntElementsAttr; +using mlir::ElementsAttr; +using mlir::failure; +using mlir::FunctionPass; +using mlir::LogicalResult; +using mlir::MLIRContext; +using mlir::OpRewritePattern; +using mlir::OwningRewritePatternList; +using mlir::PassRegistration; +using mlir::PassWrapper; +using mlir::PatternRewriter; +using mlir::RankedTensorType; +using mlir::success; +using mlir::Value; + +namespace { + +Value TransposeReshape(Value arg, mlir::Location loc, + llvm::ArrayRef left_dims, + llvm::ArrayRef right_dims, + llvm::ArrayRef arg_shape, + PatternRewriter *rewriter) { + auto element_type = mlir::getElementTypeOrSelf(arg.getType()); + + int64_t left_size = 1; + for (auto dim : left_dims) { + left_size *= arg_shape[dim]; + } + + int64_t right_size = 1; + for (auto dim : right_dims) { + right_size *= arg_shape[dim]; + } + + // Generate the transpose permutation attribute. 
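+ // (Illustration: for an lhs of shape [5, 6, 7] contracting on dimension 1,
+ // left_dims = [0, 2] and right_dims = [1], so the permutation is [0, 2, 1]
+ // and the reshape below produces a [35, 6] operand.)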
+ llvm::SmallVector transpose_permutation(left_dims.begin(), + left_dims.end()); + transpose_permutation.append(right_dims.begin(), right_dims.end()); + + mlir::TensorType transpose_permutation_type = RankedTensorType::get( + {static_cast(transpose_permutation.size())}, + rewriter->getIntegerType(64)); + + auto transpose_permutation_attr = + DenseIntElementsAttr::get(transpose_permutation_type, + llvm::makeArrayRef(transpose_permutation)) + .cast(); + + // Compute the resulting shape. + llvm::SmallVector transposed_shape; + for (auto val : transpose_permutation) { + transposed_shape.push_back(arg_shape[val]); + } + auto transpose_type = RankedTensorType::get(transposed_shape, element_type); + auto transpose_result = rewriter->create( + loc, transpose_type, arg, transpose_permutation_attr); + + // Return the final result. + auto reshaped_type = + RankedTensorType::get({left_size, right_size}, element_type); + return rewriter->create(loc, reshaped_type, + transpose_result); +} + +Value ProcessDotArg(Value arg, mlir::Location loc, + ElementsAttr contract_dims_attr, bool outer_dims_first, + PatternRewriter *rewriter) { + auto shape = arg.getType().cast().getShape(); + + llvm::SmallVector is_outer_dim; + is_outer_dim.resize(shape.size(), true); + + // Compute the contract dimension ordering. + llvm::SmallVector contract_dims; + for (auto dim : contract_dims_attr.getValues()) { + contract_dims.push_back(dim); + is_outer_dim[dim] = false; + } + + // Compute the outer dimension orderings. + llvm::SmallVector outer_dims; + for (auto it : llvm::enumerate(is_outer_dim)) { + if (it.value()) { + outer_dims.push_back(it.index()); + } + } + + if (outer_dims_first) { + return TransposeReshape(arg, loc, outer_dims, contract_dims, shape, + rewriter); + } + + return TransposeReshape(arg, loc, contract_dims, outer_dims, shape, rewriter); +} + +struct GeneralDotConvert + : public OpRewritePattern { + // Attempts to lower a General Dot operator to a standard Dot operator. + // General dots include batching dimensions and can have collapsing + // dimensions along any axis. Inserting correctly arranged transpose and + // reshape operators organizes the tensors and allows the General Dot to be + // replaced with the standard Dot operator. + // + // Note: This requires an empty list of batch dimensions. + + explicit GeneralDotConvert(MLIRContext *context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(mlir::xla_hlo::DotGeneralOp op, + PatternRewriter &rewriter) const override { + auto dot_element_type = mlir::getElementTypeOrSelf(op); + + auto dot_numbers = op.dot_dimension_numbers(); + if (dot_numbers.lhs_batching_dimensions().getNumElements() != 0 || + dot_numbers.rhs_batching_dimensions().getNumElements() != 0) { + return failure(); + } + + auto lhs = ProcessDotArg(op.lhs(), op.getLoc(), + dot_numbers.lhs_contracting_dimensions(), + /*outer_dims_first=*/true, &rewriter); + + auto rhs = ProcessDotArg(op.rhs(), op.getLoc(), + dot_numbers.rhs_contracting_dimensions(), + /*outer_dims_first=*/false, &rewriter); + + // Dot resulting shape.
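+ // (After ProcessDotArg the lhs is reshaped to [outer, contract] and the rhs
+ // to [contract, outer], so the standard dot result is
+ // [lhs_shape[0], rhs_shape[1]].)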
+ auto lhs_shape = lhs.getType().cast().getShape(); + auto rhs_shape = rhs.getType().cast().getShape(); + auto new_dot_type = + RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, dot_element_type); + + auto new_dot_op = rewriter.create( + op.getLoc(), new_dot_type, lhs, rhs, *(op.precision_config())); + + rewriter.replaceOpWithNewOp(op, op.getType(), + new_dot_op); + return success(); + } +}; + +struct LegalizeGeneralDot + : public PassWrapper { + /// Lower all general dots that can be represented as a non-batched matmul. + void runOnFunction() override { + OwningRewritePatternList patterns; + mlir::xla_hlo::PopulateGeneralDotOpLoweringPatterns(&patterns, + &getContext()); + applyPatternsAndFoldGreedily(getFunction(), patterns); + } +}; + +} // namespace + +void mlir::xla_hlo::PopulateGeneralDotOpLoweringPatterns( + OwningRewritePatternList *patterns, MLIRContext *ctx) { + patterns->insert(ctx); +} + +static PassRegistration legalize_pass( + "test-xla-lower-general-dot", + "Tests lowering general dot to a non-batched dot when possible"); diff --git a/lib/Dialect/mhlo/transforms/materialize_broadcasts.cc b/lib/Dialect/mhlo/transforms/materialize_broadcasts.cc new file mode 100644 index 0000000..074f97c --- /dev/null +++ b/lib/Dialect/mhlo/transforms/materialize_broadcasts.cc @@ -0,0 +1,90 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" + +namespace mlir { +namespace xla_hlo { + +namespace { + +// Converts ClampOp with broadcast semantics. ClampOp requires "all three arrays +// must be the same shape. Alternatively, as a restricted form of broadcasting, +// min and/or max can be a scalar of type T." +struct ClampWithBroadcastConvert : public OpRewritePattern { + explicit ClampWithBroadcastConvert(MLIRContext *context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(ClampOp op, + PatternRewriter &rewriter) const override { + auto operand_type = op.operand().getType().dyn_cast(); + auto max_type = op.max().getType().dyn_cast(); + auto min_type = op.min().getType().dyn_cast(); + // Unranked types are not supported. + if (!operand_type || !max_type || !min_type) return failure(); + // Does not support operand with dynamic dimensions for now.
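+ // (The rewrites below broadcast a scalar min/max to the operand's static
+ // shape, so a static shape is required here.)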
+ if (!operand_type.hasStaticShape()) return failure(); + + ArrayRef operand_shape = operand_type.getShape(); + + Value max_value = op.max(); + if (max_type != operand_type) { + assert(max_type.getRank() == 0); + max_value = rewriter.createOrFold( + op.getLoc(), operand_type, max_value, + rewriter.getI64TensorAttr(operand_shape)); + } + + Value min_value = op.min(); + if (min_type != operand_type) { + assert(min_type.getRank() == 0); + min_value = rewriter.createOrFold( + op.getLoc(), operand_type, min_value, + rewriter.getI64TensorAttr(operand_shape)); + } + + rewriter.replaceOpWithNewOp(op, op.getType(), min_value, + op.operand(), max_value); + return success(); + } +}; + +} // namespace + +void SetupMaterializeBroadcastsLegality(MLIRContext *context, + ConversionTarget *conversionTarget) { + conversionTarget->addDynamicallyLegalOp([](ClampOp op) { + return op.max().getType() == op.operand().getType() && + op.min().getType() == op.operand().getType(); + }); +} + +void PopulateMaterializeBroadcastsPatterns(MLIRContext *context, + OwningRewritePatternList *patterns) { + // ClampOp. This op has a special case where it accepts either same-shaped + // inputs or scalars (a restricted form of broadcasting). This makes the + // broadcast explicit. + patterns->insert(context); +} + +} // namespace xla_hlo +} // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc b/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc new file mode 100644 index 0000000..2106ec3 --- /dev/null +++ b/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc @@ -0,0 +1,58 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" + +namespace mlir { +namespace xla_hlo { + +namespace { + +struct TestMaterializeBroadcastsPass + : public PassWrapper { + void runOnFunction() override { + ConversionTarget conversionTarget(getContext()); + OwningRewritePatternList conversionPatterns; + + // Consider the xla_hlo dialect legal for tests. + conversionTarget.addLegalDialect(); + // The conversion uses helpers from the Standard dialect. 
+ conversionTarget.addLegalDialect(); + + SetupMaterializeBroadcastsLegality(&getContext(), &conversionTarget); + PopulateMaterializeBroadcastsPatterns(&getContext(), &conversionPatterns); + + if (failed(applyPartialConversion(getFunction(), conversionTarget, + conversionPatterns))) { + return signalPassFailure(); + } + } +}; + +} // namespace + +} // namespace xla_hlo +} // namespace mlir + +static mlir::PassRegistration + pass("test-xla-materialize-broadcasts", + "Test pass for materializing 'broadcast_dimensions' attributes"); diff --git a/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc b/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc new file mode 100644 index 0000000..666ca53 --- /dev/null +++ b/lib/Dialect/mhlo/transforms/sink_constants_to_control_flow.cc @@ -0,0 +1,85 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/DenseMap.h" +#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/Casting.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/PassManager.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LLVM.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/RegionUtils.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" + +namespace mlir { +namespace xla_hlo { + +namespace { + +// A pass that sinks constants implicitly captured in control flow regions. This +// is necessary to export to XLA. +class SinkConstantsToControlFlow + : public mlir::PassWrapper { + void runOnFunction() override { + getFunction().walk([](Operation* op) { + if (auto while_op = llvm::dyn_cast(op)) { + SinkToRegion(&while_op.body()); + SinkToRegion(&while_op.cond()); + } else if (auto if_op = llvm::dyn_cast(op)) { + SinkToRegion(&if_op.true_branch()); + SinkToRegion(&if_op.false_branch()); + } + }); + } + + private: + // Performs constant sinking into a region. + static void SinkToRegion(Region* region) { + llvm::DenseMap sunk_constant; + visitUsedValuesDefinedAbove({*region}, [&](OpOperand* use) { + Value constant = use->get(); + auto const_op = dyn_cast_or_null(constant.getDefiningOp()); + if (!const_op) return; + auto map_entry = sunk_constant.try_emplace(constant, nullptr); + if (!map_entry.second) { + // This constant has already been cloned into the region, reuse it. 
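// The decision made by this walk, restated in isolation: a constant used only
// inside the region can simply be moved; a constant that also has users
// elsewhere is cloned once per region, and the clone is cached so later uses in
// the same region share it. Rough standalone sketch of that policy (the struct
// and enum are illustrative stand-ins, not MLIR types):
struct ConstantUseInfo {
  int total_uses = 0;        // uses across the whole function
  bool cloned_here = false;  // already cloned into this region?
};

enum class SinkAction { kReuseCachedClone, kMoveIntoRegion, kCloneAndCache };

SinkAction DecideSink(const ConstantUseInfo& info) {
  if (info.cloned_here) return SinkAction::kReuseCachedClone;
  if (info.total_uses == 1) return SinkAction::kMoveIntoRegion;
  return SinkAction::kCloneAndCache;
}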
+ use->set(map_entry.first->getSecond().getResult()); + if (constant.use_empty()) const_op.erase(); + return; + } + if (constant.hasOneUse()) { + const_op.getOperation()->moveBefore(®ion->front().front()); + return; + } + map_entry.first->getSecond() = const_op.clone(); + region->front().getOperations().insert(region->front().begin(), + map_entry.first->getSecond()); + use->set(map_entry.first->getSecond().getResult()); + }); + } +}; + +static mlir::PassRegistration pass( + "xla-hlo-sink-constants-to-control-flow", + "Sink constants implicitly captured in control flow regions. This is " + "necessary to export to XLA."); + +} // anonymous namespace + +std::unique_ptr> createSinkConstantsToControlFlowPass() { + return std::make_unique(); +} + +} // namespace xla_hlo +} // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc b/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc new file mode 100644 index 0000000..a7362f7 --- /dev/null +++ b/lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc @@ -0,0 +1,100 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Identifier.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OperationSupport.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" + +namespace mlir { +namespace xla { +namespace { + +struct InferReturnTypeComponentsPattern : public RewritePattern { + InferReturnTypeComponentsPattern(MLIRContext *context) + : RewritePattern("xla_test.get_return_type_components", 1, context) {} + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + if (op->getNumOperands() != 1) return failure(); + auto defining_op = op->getOperand(0).getDefiningOp(); + auto defining_op_int = + llvm::dyn_cast_or_null(defining_op); + if (!defining_op_int) return failure(); + SmallVector components; + if (failed(defining_op_int.inferReturnTypeComponents( + op->getContext(), op->getLoc(), defining_op->getOperands(), + defining_op->getAttrDictionary(), defining_op->getRegions(), + components))) { + return failure(); + } + + // Replace the op with another pass-through op with attributes added. 
+ OperationState state(op->getLoc(), "xla_test.return_type_components", + op->getOperands(), op->getResultTypes(), + op->getAttrs()); + auto new_op = rewriter.createOperation(state); + for (auto it : llvm::enumerate(components)) { + if (it.value().hasRank()) { + new_op->setAttr((StringRef("dims") + Twine(it.index())).str(), + rewriter.getI64ArrayAttr(it.value().getDims())); + } + if (it.value().getElementType()) { + new_op->setAttr((Twine("element_type") + Twine(it.index())).str(), + TypeAttr::get(it.value().getElementType())); + } + } + rewriter.replaceOp(op, {new_op->getResults()}); + return success(); + } +}; + +struct ReifyReturnTypeShapesPattern : public RewritePattern { + ReifyReturnTypeShapesPattern(MLIRContext *context) + : RewritePattern("xla_test.reify_return_type_shapes", 1, context) {} + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + if (op->getNumOperands() != 1) return failure(); + auto defining_op = llvm::dyn_cast_or_null( + op->getOperand(0).getDefiningOp()); + if (!defining_op) return failure(); + SmallVector return_shapes; + if (failed(defining_op.reifyReturnTypeShapes(rewriter, return_shapes))) { + return failure(); + } + rewriter.replaceOp(op, return_shapes); + return success(); + } +}; + +struct TestInferShapedTypeMethodsPass + : public PassWrapper { + void runOnFunction() override { + OwningRewritePatternList patterns; + patterns.insert(&getContext()); + patterns.insert(&getContext()); + applyPatternsAndFoldGreedily(getFunction(), patterns); + } +}; + +} // namespace +} // namespace xla +} // namespace mlir + +static mlir::PassRegistration pass( + "test-xla-infer-shaped-type-methods", + "Uses test ops to invoke InferShapedTypeOpInterface methods"); diff --git a/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc b/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc new file mode 100644 index 0000000..b0fc6a1 --- /dev/null +++ b/lib/Dialect/mhlo/transforms/unfuse_batch_norm.cc @@ -0,0 +1,184 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/SmallVector.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" + +namespace mlir { +namespace xla_hlo { + +namespace { + +// Broadcasts the 1D value tensor 'value_1d' to the shape of 'result_type'. If +// 'shape_value' is initialized, creates a dynamic broadcast, otherwise creates +// a static broadcast. +Value BroadcastToFeatureDim(Location loc, RankedTensorType result_type, + Value value_1d, Value shape_value, + int64_t feature_dim, + PatternRewriter& rewriter) { // NOLINT + Builder b(rewriter.getContext()); + auto dims_type = RankedTensorType::get({1}, b.getIntegerType(64)); + auto dims = DenseIntElementsAttr::get(dims_type, {feature_dim}); + if (shape_value) { + return rewriter.createOrFold( + loc, result_type, value_1d, shape_value, dims); + } + assert(result_type.hasStaticShape()); + return rewriter.create(loc, result_type, value_1d, + dims); +} + +// Calculate the shape value of operand, assuming it is a dynamic shape with +// static rank. +Value CalculateShapeValue(Location loc, Value operand, + PatternRewriter& rewriter) { // NOLINT + RankedTensorType result_type = operand.getType().dyn_cast(); + llvm::SmallVector shape_values; + int64_t rank = result_type.getRank(); + shape_values.reserve(rank); + for (int64_t i = 0; i < rank; ++i) { + shape_values.push_back(rewriter.create(loc, operand, i)); + } + return rewriter.create(loc, shape_values); +} + +Value MaterializeEpsilon(Operation* op, FloatAttr epsilon_attr, + FloatType fp_type, Value variance, + RankedTensorType broadcast_to_type, + PatternRewriter& rewriter) { // NOLINT + Builder b(rewriter.getContext()); + if (epsilon_attr.getType() != fp_type) { + // Need to convert. 
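// Standalone illustration of the llvm::APFloat conversion performed below: the
// epsilon attribute is re-rounded into the element type's float semantics, and
// the status/loses_info pair says whether the rounding merely lost precision
// (tolerated, with a warning) or failed outright (the rewrite gives up). The
// concrete value and the half-precision target here are made-up examples.
#include "llvm/ADT/APFloat.h"

bool EpsilonFitsTargetSemantics() {
  bool loses_info = false;
  llvm::APFloat eps(1e-5);  // initially held in double semantics
  llvm::APFloat::opStatus status = eps.convert(
      llvm::APFloat::IEEEhalf(), llvm::APFloat::rmNearestTiesToEven,
      &loses_info);
  // opInexact alone is acceptable; any other status bit (e.g. underflow to
  // zero in the narrower type) means the epsilon is unusable.
  return (status & ~llvm::APFloat::opInexact) == llvm::APFloat::opOK;
}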
+ bool loses_info; + APFloat epsilon_float = epsilon_attr.getValue(); + auto status = epsilon_float.convert( + fp_type.getFloatSemantics(), APFloat::rmNearestTiesToEven, &loses_info); + if ((status & (~APFloat::opInexact)) != APFloat::opOK) { + op->emitWarning() << "Could not convert batch_norm epsilon to target fp " + "type: opStatus = " + << static_cast(status); + return nullptr; + } + if (loses_info) { + op->emitWarning("Conversion of epsilon loses precision"); + } + epsilon_attr = b.getFloatAttr(fp_type, epsilon_float); + } + + auto scalar_type = RankedTensorType::get({}, fp_type); + auto epsilon_tensor_attr = + DenseElementsAttr::get(scalar_type, {epsilon_attr.cast()}); + Value epsilon = + rewriter.create(op->getLoc(), epsilon_tensor_attr); + auto dims_type = RankedTensorType::get({0}, b.getIntegerType(64)); + auto dims = DenseIntElementsAttr::get(dims_type, SmallVector{}); + if (broadcast_to_type.hasStaticShape()) { + return rewriter.create( + op->getLoc(), broadcast_to_type, epsilon, /*broadcast_dims=*/dims); + } + Value shape_value = CalculateShapeValue(op->getLoc(), variance, rewriter); + return rewriter.createOrFold( + op->getLoc(), broadcast_to_type, epsilon, shape_value, + /*broadcast_dims=*/dims); +} + +class UnfuseBatchNormInferencePattern + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(xla_hlo::BatchNormInferenceOp bn_op, + PatternRewriter& rewriter) const override { + // Enforce type invariants. + // Note that we deduce the actual element type from the variance, + // which should not be subject to quantization at a higher level. + auto input_type = bn_op.operand().getType().dyn_cast(); + auto variance_type = + bn_op.variance().getType().dyn_cast(); + if (!input_type || !variance_type) { + return failure(); + } + auto fp_type = variance_type.getElementType().dyn_cast(); + if (!fp_type) { + return failure(); + } + int64_t feature_dim = bn_op.feature_index().getSExtValue(); + + // Add epsilon to the variance and sqrt to get stddev: + // stddev = sqrt(variance + epsilon) + auto epsilon = + MaterializeEpsilon(bn_op.getOperation(), bn_op.epsilonAttr(), fp_type, + bn_op.variance(), variance_type, rewriter); + if (!epsilon) { + return failure(); + } + Value stddev = rewriter.create(bn_op.getLoc(), + bn_op.variance(), epsilon); + stddev = rewriter.create(bn_op.getLoc(), stddev); + + // Broadcast all terms. 
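// The arithmetic this pattern emits, written out on plain buffers: per-feature
// parameters are broadcast along the feature dimension and combined as
//   result = scale * (input - mean) / sqrt(variance + epsilon) + offset.
// Minimal standalone sketch assuming the feature dimension is minor-most; the
// function name and layout are illustrative assumptions, not part of this patch.
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> UnfusedBatchNormInference(
    const std::vector<float>& input, std::size_t num_features,
    const std::vector<float>& scale, const std::vector<float>& offset,
    const std::vector<float>& mean, const std::vector<float>& variance,
    float epsilon) {
  std::vector<float> result(input.size());
  for (std::size_t i = 0; i < input.size(); ++i) {
    const std::size_t f = i % num_features;  // broadcast along the feature dim
    const float stddev = std::sqrt(variance[f] + epsilon);
    result[i] = scale[f] * (input[i] - mean[f]) / stddev + offset[f];
  }
  return result;
}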
+ Value shape_value; + if (!input_type.hasStaticShape()) { + shape_value = + CalculateShapeValue(bn_op.getLoc(), bn_op.operand(), rewriter); + } + auto broadcast_scale = + BroadcastToFeatureDim(bn_op.getLoc(), input_type, bn_op.scale(), + shape_value, feature_dim, rewriter); + auto broadcast_offset = + BroadcastToFeatureDim(bn_op.getLoc(), input_type, bn_op.offset(), + shape_value, feature_dim, rewriter); + auto broadcast_mean = + BroadcastToFeatureDim(bn_op.getLoc(), input_type, bn_op.mean(), + shape_value, feature_dim, rewriter); + auto broadcast_stddev = BroadcastToFeatureDim( + bn_op.getLoc(), input_type, stddev, shape_value, feature_dim, rewriter); + + // Compute: + // scale * (input - mean) / stddev + offset + Value result = rewriter.create( + bn_op.getLoc(), bn_op.operand(), broadcast_mean); + result = rewriter.create(bn_op.getLoc(), result, + broadcast_scale); + result = rewriter.create(bn_op.getLoc(), result, + broadcast_stddev); + rewriter.replaceOpWithNewOp(bn_op, result, + broadcast_offset); + + return success(); + } +}; + +} // namespace + +// Populates conversion patterns to unfuse batch normalization operations. +// In combination with marking such ops as illegal, this allows backends that +// do not have special support for fused batchnorm to use simpler arithmetic +// primitives. +void PopulateUnfuseBatchNormPatterns(MLIRContext* context, + OwningRewritePatternList* patterns) { + patterns->insert(context); +} + +} // namespace xla_hlo +} // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc b/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc new file mode 100644 index 0000000..179b63c --- /dev/null +++ b/lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc @@ -0,0 +1,46 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" + +namespace mlir { +namespace xla_hlo { + +namespace { + +struct TestUnfuseBatchNormPass + : public PassWrapper> { + void runOnOperation() override { + OwningRewritePatternList patterns; + PopulateUnfuseBatchNormPatterns(&getContext(), &patterns); + applyPatternsAndFoldGreedily(getOperation(), patterns); + } +}; + +} // namespace + +} // namespace xla_hlo +} // namespace mlir + +static mlir::PassRegistration pass( + "test-xla-unfuse-batch-norm", + "Test pass for materializing 'broadcast_dimensions' attributes"); diff --git a/lib/Dialect/mhlo/transforms/xla_hlo_fusion.cc b/lib/Dialect/mhlo/transforms/xla_hlo_fusion.cc new file mode 100644 index 0000000..2cde14a --- /dev/null +++ b/lib/Dialect/mhlo/transforms/xla_hlo_fusion.cc @@ -0,0 +1,579 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project +#include "mlir/IR/MLIRContext.h" // TF:llvm-project +#include "mlir/IR/Matchers.h" +#include "mlir/Pass/Pass.h" // TF:local_config_mlir +#include "mlir/Transforms/RegionUtils.h" // TF:llvm-project +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/EquivalenceClasses.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/cycle_detector.h" + +// This pass has similar functionality of the fusion pass in XLA stack. +// However, unlike XLA, it targets the fully dynamic shape scenario. +// Currently, it implements the kLoop and kInput fusion templates. +// During conversion, it tries to greedily find kLoop/kInput fusion +// patterns. +// +// Similar to XLA, this pass supports fusion pattern having multiple outputs +// if all the shape of outputs are consistent. Following are some examples. 
+// +// kLoop kInput +// +----+ +----+ +----+ +----+ +----+ +----+ +// |elem| |elem| |elem| |elem<----+elem+---->elem+----+ +// +-+--+ +-+--+ +-+--+ +-+--+ +----+ +-+--+ | +// | | | | | | +// | | | | | +// +-v--+ | +-v--+ +--v---+ +--v---+ | +// |elem+<---+----<+elem| |reduce| |reduce| | +// +-+--+ +-+--+ +--+---+ +--+---+ | +// | | | | | +// | | | | | +// v v v v v +// +// To this end, we also add an simple shape constraint analysis phase. +// For kLoop fusion template, it requires all the outputs of the fused +// pattern have the same shape. However, we don't know the actual value +// of the shape at the compile time in the dynamic shape world. +// Fortunately, we could still infer the relationship among different ops +// according to their shape constrain traits. Currently, We only consider +// shape equality propagation for elementwise ops (assuming that implicit +// shape broadcast is forbidden). The above process could be built on the +// shape dialect once it is ready. + +namespace mlir { +namespace xla_hlo { +namespace { + +using llvm::EquivalenceClasses; +using FusionPattern = std::vector; +using FusionPlan = std::vector; + +// To support using EquivalenceClasses for Value +class ValueWrapper { + public: + explicit ValueWrapper(Value value) : value_(std::move(value)) {} + + Value getValue() const { return value_; } + + bool operator==(const ValueWrapper& rhs) const { + return getValue() == rhs.getValue(); + } + + private: + Value value_; +}; + +bool operator<(const ValueWrapper& lhs, const ValueWrapper& rhs) { + auto lhs_value = lhs.getValue().getAsOpaquePointer(); + auto rhs_value = rhs.getValue().getAsOpaquePointer(); + return lhs_value < rhs_value; +} + +bool IsFusible(Operation* op) { + if (matchPattern(op, m_Constant())) { + return true; + } + auto op_fusibility = dyn_cast(op); + return op_fusibility && (op_fusibility.isFusibleWithOperand() || + op_fusibility.isFusibleWithConsumer()); +} + +SmallVector GetInputsOfFusionPattern(const FusionPattern& pattern) { + SmallVector inputs; + DenseSet input_set; + DenseSet op_set; + for (Operation* op : pattern) { + bool inserted = op_set.insert(op).second; + (void)inserted; + assert(inserted && "FusionPattern contains duplicate operations"); + } + + for (Operation* op : pattern) { + for (Value operand : op->getOperands()) { + Operation* operand_op = operand.getDefiningOp(); + if (op_set.find(operand_op) != op_set.end()) { + // skip if defining op is in the pattern + continue; + } + if (input_set.insert(operand).second) { + inputs.push_back(operand); + } + } + } + return inputs; +} + +SmallVector GetOutputsOfFusionPattern(const FusionPattern& pattern) { + SmallVector outputs; + DenseSet op_set; + for (Operation* op : pattern) { + bool inserted = op_set.insert(op).second; + (void)inserted; + assert(inserted && "FusionPattern contains duplicate operations"); + } + + for (Operation* op : pattern) { + for (Value result : op->getResults()) { + bool has_external_user = llvm::any_of( + result.getUses(), + [&](OpOperand& use) { return !op_set.count(use.getOwner()); }); + if (has_external_user) { + outputs.push_back(result); + } + } + } + return outputs; +} + +FusionPattern MergeFusionPattern(const FusionPattern& lhs, + const FusionPattern& rhs) { + FusionPattern pattern(lhs); + pattern.insert(pattern.end(), rhs.begin(), rhs.end()); + return pattern; +} + +inline int EffectiveSize(const FusionPattern& pattern) { + return llvm::count_if( + pattern, [](Operation* op) { return !matchPattern(op, m_Constant()); }); +} + +// This is an simple shape 
constraint analysis, which is used to
+ // guide fusion decisions (e.g. we only fuse shape-compatible ops).
+ //
+ // Currently, we only consider shape equality propagation based
+ // on the shape constraint traits of elementwise ops (assuming that
+ // implicit shape broadcast is forbidden).
+ class ShapeConstraintAnalysis {
+  public:
+   explicit ShapeConstraintAnalysis(const SmallVectorImpl<Operation*>& op_list) {
+     PropagateEquality(op_list);
+   }
+
+   // Returns true if `lhs` and `rhs` are supposed to have the same shape.
+   bool HasSameShape(Value lhs, Value rhs) {
+     return impl_.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs));
+   }
+
+  private:
+   // Shape equality propagation based on the shape constraints of
+   // elementwise ops.
+   void PropagateEquality(const SmallVectorImpl<Operation*>& op_list) {
+     bool converged = true;
+     do {
+       converged = true;
+       auto update = [&](Value lhs, Value rhs) {
+         if (!impl_.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs))) {
+           converged = false;
+           impl_.unionSets(ValueWrapper(lhs), ValueWrapper(rhs));
+         }
+       };
+       for (Operation* op : op_list) {
+         auto op_fusibility = dyn_cast<InferFusibilityOpInterface>(op);
+         if (!op_fusibility) continue;
+         int numInput = op->getNumOperands();
+         int numOutput = op->getNumResults();
+         // Shape equality propagation between inputs.
+         for (int input1 = 0; input1 < numInput; ++input1)
+           for (int input2 = input1 + 1; input2 < numInput; ++input2)
+             if (op_fusibility.inferInputsShapeEquality(input1, input2))
+               update(op->getOperand(input1), op->getOperand(input2));
+
+         // Shape equality propagation between outputs.
+         for (int output1 = 0; output1 < numOutput; ++output1)
+           for (int output2 = output1 + 1; output2 < numOutput; ++output2)
+             if (op_fusibility.inferOutputsShapeEquality(output1, output2))
+               update(op->getResult(output1), op->getResult(output2));
+
+         // Shape equality propagation between inputs and outputs.
+         for (int input = 0; input < numInput; ++input)
+           for (int output = 0; output < numOutput; ++output)
+             if (op_fusibility.inferInputOutputShapeEquality(input, output))
+               update(op->getOperand(input), op->getResult(output));
+       }
+     } while (!converged);
+   }
+
+   // A union-find set.
+   EquivalenceClasses<ValueWrapper> impl_;
+ };
+
+ // A fusion planner that can propose a fusion plan for a block of ops.
+ // The fusion plan consists of a group of fusion patterns.
+ //
+ // Currently all proposed patterns follow XLA's kLoop/kInput-like fusion
+ // templates while being adapted to the fully dynamic shape world.
+ //
+ // The kLoop fusion template satisfies:
+ //   - all ops in the fusion pattern are element-wise.
+ //   - all the outputs of the fusion pattern have the same shape, and thus
+ //     can fit into the same parallel loop.
+ //
+ // The kInput fusion template satisfies:
+ //   - every op in the fusion pattern is either element-wise or a reduction.
+ //   - if an op is a reduction, its output cannot be consumed by other
+ //     ops in the same fusion pattern.
+ //   - all the outputs of the fusion pattern have the same effective shape.
+ //     - For an element-wise op, the effective shape is its output shape.
+ //     - For a reduction op, the effective shape is its operand shape.
+ class FusionPlanner {
+  public:
+   explicit FusionPlanner(const SmallVectorImpl<Operation*>& op_list)
+       : op_list_(op_list),
+         shape_analysis_(op_list),
+         cycle_detector_(op_list.size()) {
+     BuildNodeMap();
+   }
+
+   // Returns a fusion plan on success, otherwise none.
+   llvm::Optional<FusionPlan> Run() {
+     // Greedily search for connected fusible patterns; ops belonging to
+     // the same fusion pattern are grouped into a cluster.
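// The equality propagation above is a union-find computation: every equality
// hint reported by an op's fusibility interface merges two equivalence classes,
// and the loop simply re-applies the hints until nothing changes. Minimal
// standalone sketch with integer ids standing in for mlir::Value (the hint list
// is an assumed input for illustration):
#include <numeric>
#include <utility>
#include <vector>

struct UnionFind {
  std::vector<int> parent;
  explicit UnionFind(int n) : parent(n) {
    std::iota(parent.begin(), parent.end(), 0);
  }
  int Find(int x) { return parent[x] == x ? x : parent[x] = Find(parent[x]); }
  bool Union(int a, int b) {  // returns true if two classes were merged
    a = Find(a);
    b = Find(b);
    if (a == b) return false;
    parent[b] = a;
    return true;
  }
};

bool HaveSameShape(const std::vector<std::pair<int, int>>& equal_shape_hints,
                   int num_values, int lhs, int rhs) {
  UnionFind classes(num_values);
  bool changed = true;
  while (changed) {  // fixed-point loop, mirroring PropagateEquality above
    changed = false;
    for (const auto& hint : equal_shape_hints)
      changed |= classes.Union(hint.first, hint.second);
  }
  return classes.Find(lhs) == classes.Find(rhs);
}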
+ RunEdgeContractionLoop(); + + // After doing edge contraction, each unique cluster having size + // more than one represents a potential fusion pattern. + // We collect all these clusters and construct a fusion plan. + // + // Note that the ops in a fusion pattern are in topological ordering. + FusionPlan plan; + DenseMap pattern_ids; + for (Operation* op : op_list_) { + Cluster* cluster = GetClusterForNode(op); + int node_id = cluster->cycles_graph_node_id(); + if (!IsFusible(op_list_[node_id]) || + EffectiveSize(GetClusterForNode(op)->fused_pattern()) <= 1) { + continue; + } + if (!pattern_ids.count(node_id)) { + int pattern_id = pattern_ids.size(); + pattern_ids[node_id] = pattern_id; + plan.emplace_back(); + } + plan[pattern_ids[node_id]].push_back(op); + } + return plan; + } + + // Returns the op_list this planner operates on. + const SmallVectorImpl& op_list() const { return op_list_; } + + private: + // Represent a (partial) fused pattern + class Cluster { + public: + Cluster(int node_id, FusionPlanner* planner) : node_id_(node_id) { + const SmallVectorImpl& op_list = planner->op_list(); + pattern_.push_back(op_list[node_id]); + } + + // Merges `other` into this cluster, and clears `other`. + void Merge(Cluster* other) { + pattern_.insert(pattern_.end(), other->pattern_.begin(), + other->pattern_.end()); + other->pattern_.clear(); + } + + // The number of nodes in this cluster. + int cluster_size() const { return pattern_.size(); } + + // The ID of the cluster as represented in `cycle_detector_`. + int cycles_graph_node_id() const { return node_id_; } + + // Sets the ID of the cluster as represented in `cycle_detector_`. + void set_cycles_graph_node_id(int cycles_graph_node_id) { + node_id_ = cycles_graph_node_id; + } + + // Currently the fused pattern this cluster holds. + const FusionPattern& fused_pattern() { return pattern_; } + + private: + // ID of the representative node of this cluster. + int node_id_; + + // the fused pattern this cluster holds. + FusionPattern pattern_; + }; + + private: + Cluster* MakeCluster(int cycles_graph_node_id) { + cluster_storage_.emplace_back(new Cluster(cycles_graph_node_id, this)); + return cluster_storage_.back().get(); + } + + void BuildNodeMap() { + int num_nodes = op_list_.size(); + for (int node_id = 0; node_id < num_nodes; ++node_id) { + Operation* op = op_list_[node_id]; + MakeCluster(node_id); + op_to_node_id_[op] = node_id; + leader_for_node_.insert(node_id); + for (Value operand : op->getOperands()) { + Operation* operand_op = operand.getDefiningOp(); + if (operand_op == nullptr) { + // skip block argument + continue; + } + auto iter = op_to_node_id_.find(operand_op); + assert(iter != op_to_node_id_.end()); + cycle_detector_.InsertEdge(iter->second, node_id); + } + } + } + + // Returns the cluster contains this op. + Cluster* GetClusterForNode(Operation* n) { + int id = op_to_node_id_.at(n); + id = leader_for_node_.getLeaderValue(id); + return cluster_storage_[id].get(); + } + + // Returns the cluster contains the op having `node_id`. + Cluster* GetClusterForCyclesGraphNode(int node_id) { + return cluster_storage_[leader_for_node_.getLeaderValue(node_id)].get(); + } + + // Merges the clusters `cluster_from` and `cluster_to`. 
+ bool MergeClusters(Cluster* cluster_from, Cluster* cluster_to) { + int from = cluster_from->cycles_graph_node_id(); + int to = cluster_to->cycles_graph_node_id(); + + auto optional_merged_node = cycle_detector_.ContractEdge(from, to); + if (!optional_merged_node.hasValue()) { + llvm::dbgs() << "Could not contract " << from << " -> " << to + << " because contracting the edge would create a cycle."; + return false; + } + + // Merge the clusters. + cluster_from->Merge(cluster_to); + cluster_from->set_cycles_graph_node_id(*optional_merged_node); + + // Merge the UnionFind Set. + leader_for_node_.unionSets(from, to); + return true; + } + + template + bool ForEachEdgeInPostOrder(FnTy fn) { + bool changed = false; + for (int32_t node : cycle_detector_.AllNodesInPostOrder()) { + Cluster* cluster_from = GetClusterForCyclesGraphNode(node); + // Make a copy of the set of successors because we may modify the graph in + // TryToContractEdge. + std::vector successors_copy = + cycle_detector_.SuccessorsCopy(cluster_from->cycles_graph_node_id()); + + for (int to : successors_copy) { + Cluster* cluster_to = GetClusterForCyclesGraphNode(to); + bool contracted_edge = fn(cluster_from, cluster_to); + changed |= contracted_edge; + } + } + + return changed; + } + + // returns the outputs if two cluster were merged + SmallVector GetResultsOfFusedPattern(Cluster* from, Cluster* to) { + FusionPattern fused_pattern = + MergeFusionPattern(from->fused_pattern(), to->fused_pattern()); + return GetOutputsOfFusionPattern(fused_pattern); + } + + // This function check if fusing `from` with `to` is valid and if so perform + // the merge. The validity is based on the operations in the clusters and + // the compatibility of the shapes of the outputs of the would-be fused + // clusters. + // Returns true is the merge was performed. + bool TryToContractEdge(Cluster* from, Cluster* to) { + int node_to = to->cycles_graph_node_id(); + int node_from = from->cycles_graph_node_id(); + + // Both node_to and node_from should be fusible + if (!IsFusible(op_list_[node_to]) || !IsFusible(op_list_[node_from])) { + return false; + } + + auto op_from_fusibility = + dyn_cast(op_list_[node_from]); + if (op_from_fusibility && !op_from_fusibility.isFusibleWithConsumer()) { + // This op cannot be fused with its consumers. + return false; + } + + auto op_to_fusibility = + dyn_cast(op_list_[node_to]); + if (op_to_fusibility && !op_to_fusibility.isFusibleWithOperand()) { + // This op cannot be fused with its operands. + return false; + } + + // Output shapes of a fusion pattern should be compatible as described in + // the document of this class. + SmallVector results = GetResultsOfFusedPattern(from, to); + auto get_workload_shape = [](Value v) { + Operation* op = v.getDefiningOp(); + // Block argument + if (!op) return v; + auto op_fusibility = dyn_cast(op); + // Const value + if (!op_fusibility) return v; + llvm::Optional workload = + op_fusibility.inferEffectiveWorkloadShape(); + return workload.hasValue() ? *workload : v; + }; + + Value ref = get_workload_shape(results[0]); + if (!llvm::all_of(results, [&](Value result) { + Value val = get_workload_shape(result); + return shape_analysis_.HasSameShape(ref, val); + })) { + return false; + } + + return MergeClusters(from, to); + } + + // Greedily fuse connected node. 
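// Why the cycle detector matters here: contracting a producer/consumer edge
// u -> v is only legal when the merged node keeps the graph acyclic, i.e. when
// there is no *other* path from u to v. Standalone check over a plain adjacency
// list for intuition (the real pass uses the incremental GraphCycles structure
// from utils/cycle_detector.h instead of re-running a search per edge):
#include <queue>
#include <vector>

bool ContractionWouldCreateCycle(
    const std::vector<std::vector<int>>& successors, int u, int v) {
  std::vector<bool> visited(successors.size(), false);
  std::queue<int> worklist;
  // Seed with u's successors, skipping the direct edge(s) u -> v.
  for (int s : successors[u])
    if (s != v && !visited[s]) {
      visited[s] = true;
      worklist.push(s);
    }
  while (!worklist.empty()) {
    int node = worklist.front();
    worklist.pop();
    if (node == v) return true;  // an indirect u -> ... -> v path exists
    for (int s : successors[node])
      if (!visited[s]) {
        visited[s] = true;
        worklist.push(s);
      }
  }
  return false;
}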
+ bool RunEdgeContractionLoop() { + using std::placeholders::_1; + using std::placeholders::_2; + return ForEachEdgeInPostOrder( + std::bind(&FusionPlanner::TryToContractEdge, this, _1, _2)); + } + + const SmallVectorImpl& op_list_; + + // Shape equality checker + ShapeConstraintAnalysis shape_analysis_; + + // op -> node_id + std::unordered_map op_to_node_id_; + + // make sure not introduce cycle after fusion + GraphCycles cycle_detector_; + std::vector> cluster_storage_; + + // a UnionFind set. Each set represents a (partial) fused pattern + // and has a leader as representation. + EquivalenceClasses leader_for_node_; +}; + +struct XlaHloFusion : public mlir::PassWrapper { + void runOnFunction() override { + FuncOp func = getFunction(); + if (!IsTargetFunc(func)) { + return; + } + + // process each block and do fusion within a block. + for (Block& block : func) { + SmallVector op_list; + for (Operation& op : block) { + op_list.push_back(&op); + } + + FusionPlanner planner(op_list); + llvm::Optional plan = planner.Run(); + if (!plan) { + emitError(func.getLoc(), "can't find a fusion plan"); + signalPassFailure(); + return; + } + if (!ApplyFusionPlan(*plan)) { + emitError(func.getLoc(), "apply fusion plan failed"); + signalPassFailure(); + return; + } + } + } + + bool IsTargetFunc(FuncOp func) { + int num_fusible_ops = 0; + bool is_target_func = false; + // We only process the function having enough candidates + func.walk([&](Operation* op) { + num_fusible_ops += + static_cast(dyn_cast(op) != nullptr); + is_target_func = (num_fusible_ops > 1); + // early stop + if (is_target_func) return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return is_target_func; + } + + bool ApplyFusionPlan(const FusionPlan& plan) { + for (const FusionPattern& pattern : plan) { + OpBuilder b(pattern.back()); + + SmallVector locations; + locations.reserve(pattern.size()); + for (Operation* op : pattern) { + locations.push_back(op->getLoc()); + } + Location fused_loc = + FusedLoc::get(locations, pattern.back()->getContext()); + + SmallVector inputs = GetInputsOfFusionPattern(pattern); + SmallVector outputs = GetOutputsOfFusionPattern(pattern); + SmallVector output_types; + output_types.reserve(outputs.size()); + for (Value v : outputs) { + output_types.push_back(v.getType()); + } + + FusionOp fusion = + b.create(fused_loc, output_types, inputs); + Region& region = fusion.fused_computation(); + region.push_back(new Block); + Block& block = region.front(); + for (Operation* op : pattern) { + op->moveBefore(&block, block.end()); + } + b.setInsertionPoint(&block, block.end()); + b.create(fused_loc, outputs); + + for (auto output_and_result : llvm::zip(outputs, fusion.getResults())) { + Value output = std::get<0>(output_and_result); + Value fusion_result = std::get<1>(output_and_result); + for (OpOperand& use : llvm::make_early_inc_range(output.getUses())) { + if (use.getOwner()->getBlock() != &block) use.set(fusion_result); + } + } + } + return true; + } +}; + +} // namespace + +std::unique_ptr> createXlaHloFusion() { + return std::make_unique(); +} + +static PassRegistration xla_hlo_fusion_pass( + "xla-hlo-fusion", "fuse xla_hlo ops to kLoop/kInput fusion patterns."); + +} // namespace xla_hlo +} // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/xla_legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/xla_legalize_to_linalg.cc new file mode 100644 index 0000000..66a9aaa --- /dev/null +++ b/lib/Dialect/mhlo/transforms/xla_legalize_to_linalg.cc @@ -0,0 +1,909 @@ +/* Copyright 2019 The 
TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements logic for lowering HLO/LHLO dialect to Linalg dialect. + +#include "third_party/absl/memory/memory.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/AffineExpr.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" + +namespace mlir { +namespace { + +SmallVector GetNParallelLoopsAttrs(unsigned nParallelLoops) { + static constexpr StringRef kParallelIterType = "parallel"; + return SmallVector(nParallelLoops, kParallelIterType); +} + +template +Value getResultValue(Operation* op) { + return isLHLO ? op->getOperand(op->getNumOperands() - 1) : op->getResult(0); +} + +template +ShapedType getXLAOpResultType(Operation* op) { + return getResultValue(op).getType().template cast(); +} + +template +bool verifyXLAOpBufferOrTensorSemantics(Operation* op) { + auto verifyType = [&](Value val) -> bool { + return (isLHLO && val.getType().isa()) || + (!isLHLO && val.getType().isa()); + }; + if (!llvm::all_of(op->getOperands(), verifyType)) return false; + return isLHLO ? 
op->getResults().empty() + : llvm::all_of(op->getResults(), verifyType); +} + +template +class PointwiseToLinalgConverter : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + OpTy op, ArrayRef args, + ConversionPatternRewriter& rewriter) const final { + auto loc = op.getLoc(); + auto argType = + op.getOperation()->getOperand(0).getType().template cast(); + if (!argType.hasRank()) { + emitError(loc, "lhlo to linalg conversion expects ranked args"); + return failure(); + } + auto elemTy = argType.getElementType(); + if (!elemTy.isSignlessIntOrFloat() && !elemTy.template isa()) { + return failure(); + } + + // Construct the indexing maps needed for linalg.generic ops. + SmallVector indexing_maps; + SmallVector bodyArgTypes, bodyResultTypes, opResultTypes; + + // This doesnt account for implicit broadcast, but the working assumption + // here is that are broadcasts have been made explicit. + unsigned nloops = argType.getRank(); + + if (isLHLO && !nloops) return failure(); + + int operandCount = (isLHLO ? args.size() - 1 : args.size()); + auto verifyArgOrResultType = [&](Value val) -> ShapedType { + auto shapedType = val.getType().dyn_cast(); + if (!shapedType || + (!shapedType.isa() && + !shapedType.isa()) || + shapedType.getRank() != nloops) + return nullptr; + indexing_maps.emplace_back( + nloops ? rewriter.getMultiDimIdentityMap(nloops) + : AffineMap::get(nloops, 0, rewriter.getContext())); + return shapedType; + }; + for (const auto& arg : llvm::enumerate(args)) { + auto shapedType = verifyArgOrResultType(arg.value()); + if (!shapedType) return failure(); + auto& result_or_body_arg = + arg.index() < operandCount ? bodyArgTypes : bodyResultTypes; + result_or_body_arg.emplace_back(shapedType.getElementType()); + } + if (!isLHLO) { + // HLO operations have return as tensor types. + assert(bodyResultTypes.empty() && + "When lowering HLO ops result can't be part of arguments"); + Value result = op.getOperation()->getResult(0); + auto shapedType = verifyArgOrResultType(result); + if (!shapedType) return failure(); + bodyResultTypes.push_back(shapedType.getElementType()); + opResultTypes.push_back(shapedType); + } + + int64_t args_count = bodyArgTypes.size(); + int64_t results_count = bodyResultTypes.size(); + auto linalgOp = rewriter.create( + loc, opResultTypes, args, args_count, results_count, indexing_maps, + GetNParallelLoopsAttrs(nloops), + [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { + // TODO(ravishankarm) : For now use the method in xla_lhlo namespace. + // That method needs to be moved out of there. + Value opResult = xla_lhlo::XlaOpToStdScalarOp::map( + op, bodyResultTypes, + llvm::to_vector<2>(args.take_front(args_count)), &rewriter); + nestedBuilder.create(loc, opResult); + }); + rewriter.replaceOp(op, linalgOp.getOperation()->getResults()); + return success(); + } +}; + +template +class ScalarPointwiseToStandardConverter : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + LhloOp lhlo_op, ArrayRef args, + ConversionPatternRewriter& rewriter) const final { + auto loc = lhlo_op.getLoc(); + auto argType = + lhlo_op.getOperand(0).getType().template dyn_cast(); + if (!argType || !argType.getElementType().isSignlessIntOrFloat() || + (argType.getRank() != 0)) { + return failure(); + } + + // Create two loads from the input. 
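// For intuition, the linalg.generic built above for an elementwise op uses the
// identity indexing map for every operand and result and marks every loop
// "parallel", which for a rank-2 op is equivalent to the loop nest below.
// Standalone sketch; the buffer type and helper name are illustrative
// assumptions, not part of this patch.
#include <cstddef>
#include <functional>
#include <vector>

using Buffer2D = std::vector<std::vector<float>>;

Buffer2D ElementwiseGeneric(const Buffer2D& lhs, const Buffer2D& rhs,
                            const std::function<float(float, float)>& body) {
  Buffer2D out(lhs.size(), std::vector<float>(lhs.empty() ? 0 : lhs[0].size()));
  for (std::size_t i = 0; i < out.size(); ++i)       // parallel iterator d0
    for (std::size_t j = 0; j < out[i].size(); ++j)  // parallel iterator d1
      out[i][j] = body(lhs[i][j], rhs[i][j]);        // scalar body, e.g. addf
  return out;
}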
+ auto lhs = rewriter.create(loc, lhlo_op.lhs()); + auto rhs = rewriter.create(loc, lhlo_op.rhs()); + // TODO(ravishankarm) : Move this method out of xla_lhlo namespace. + Value opResult = xla_lhlo::XlaOpToStdScalarOp::map( + lhlo_op, argType.getElementType(), llvm::ArrayRef{lhs, rhs}, + &rewriter); + rewriter.create(loc, opResult, lhlo_op.out()); + rewriter.eraseOp(lhlo_op); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// xla_lhlo.convolution conversion pattern. +//===----------------------------------------------------------------------===// + +/// Converts xla_lhlo.convolution operation to a linalg.conv op. +struct ConvToLinalgConverter : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + // This code has been adapted from IREE's + // (https://github.com/google/iree/) xla_hlo -> linalg conversion. + LogicalResult matchAndRewrite( + xla_lhlo::ConvOp op, ArrayRef args, + ConversionPatternRewriter& rewriter) const final { + // Check validity of dimension information. + if (const xla_lhlo::ConvDimensionNumbers& dimensionNumbers = + op.dimension_numbers()) { + const int inputSpatialRank = + llvm::size(dimensionNumbers.input_spatial_dimensions()); + // The dimensions for input should follow the order of + // batch_count, spatial_dims..., input_feature_count. + if (dimensionNumbers.input_batch_dimension().getInt() != 0 || + dimensionNumbers.input_feature_dimension().getInt() != + (inputSpatialRank + 1)) + return failure(); + + const int kernelSpatialRank = + llvm::size(dimensionNumbers.kernel_spatial_dimensions()); + // The dimensions for filter should follow the order of + // spatial_dims..., input_feature_count, num_output_feature_count. + if (dimensionNumbers.kernel_input_feature_dimension().getInt() != + kernelSpatialRank || + dimensionNumbers.kernel_output_feature_dimension().getInt() != + (kernelSpatialRank + 1)) + return failure(); + + const int outputSpatialRank = + llvm::size(dimensionNumbers.output_spatial_dimensions()); + // The dimensions for output should follow the order of + // batch_count, spatial_dims.., output_feature_count. + if (dimensionNumbers.output_batch_dimension().getInt() != 0 || + dimensionNumbers.output_feature_dimension().getInt() != + (outputSpatialRank + 1)) + return failure(); + + if (inputSpatialRank != outputSpatialRank || + inputSpatialRank != kernelSpatialRank) + return failure(); + + auto inputSpatialDim = + dimensionNumbers.input_spatial_dimensions().begin(); + auto kernelSpatialDim = + dimensionNumbers.kernel_spatial_dimensions().begin(); + auto outputSpatialDim = + dimensionNumbers.output_spatial_dimensions().begin(); + // Check if spatial dims are ordered correctly. + for (int i = 0; i < inputSpatialRank; ++i) { + const int dim = i + 1; + if ((*inputSpatialDim++).getZExtValue() != dim || + (*outputSpatialDim++).getZExtValue() != dim || + (*kernelSpatialDim++).getZExtValue() != i) + return failure(); + } + } + + // TODO: LHS dilation for deconvolution not supported yet. 
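// The dimension_numbers checks above only admit the layout linalg.conv expects:
//   input  = [batch, spatial_0 ... spatial_{r-1}, input_feature]
//   kernel = [spatial_0 ... spatial_{r-1}, input_feature, output_feature]
//   output = [batch, spatial_0 ... spatial_{r-1}, output_feature]
// Restated as a standalone predicate over the dimension indices (an
// illustrative helper, not part of this patch; the per-dimension ordering of
// the spatial dims is omitted here):
bool IsCanonicalConvLayout(int spatial_rank, int input_batch_dim,
                           int input_feature_dim, int kernel_input_feature_dim,
                           int kernel_output_feature_dim, int output_batch_dim,
                           int output_feature_dim) {
  return input_batch_dim == 0 && input_feature_dim == spatial_rank + 1 &&
         kernel_input_feature_dim == spatial_rank &&
         kernel_output_feature_dim == spatial_rank + 1 &&
         output_batch_dim == 0 && output_feature_dim == spatial_rank + 1;
}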
+ if (op.lhs_dilation()) { + return failure(); + } + + llvm::SmallVector strides; + if (auto windowStrides = op.window_strides()) { + auto range = windowStrides->getAttributeValues(); + strides.assign(range.begin(), range.end()); + } + auto stridesArg = ArrayAttr::get(strides, op.getContext()); + + llvm::SmallVector dilation; + if (auto rhsDilation = op.rhs_dilation()) { + auto range = rhsDilation->getAttributeValues(); + dilation.assign(range.begin(), range.end()); + } else { + // Default dilation of 1. + dilation.resize(2, IntegerAttr::get(rewriter.getIntegerType(64), 1)); + } + auto dilationArg = ArrayAttr::get(dilation, op.getContext()); + + // Set padding only if it is non-zero. + DenseIntElementsAttr padding = op.paddingAttr(); + if (!padding || !llvm::any_of(padding.getValues(), [](APInt intVal) { + return !intVal.isNullValue(); + })) { + padding = nullptr; + } + + // The order of input and filter are switched with linalg.conv. + rewriter.replaceOpWithNewOp( + op, args[1], args[0], args[2], stridesArg, dilationArg, padding); + return success(); + } +}; + +/// Base class for lowering xla operations that have one operand and one result, +/// and are semantically equivalent to a copy of the input to the output (like +/// transpose, some reshape, etc.). The derived classes need to provide a method +/// `getIndexingMaps` that returns AffineMaps for the index maps of the input +/// and the output. +template +class DataMovementOpConverter : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + OpTy op, ArrayRef args, + ConversionPatternRewriter& rewriter) const final { + if (!verifyXLAOpBufferOrTensorSemantics(op)) return failure(); + auto resultType = getXLAOpResultType(op); + + SmallVector indexing_maps = + Derived::getIndexingMaps(op, &rewriter); + if (indexing_maps.empty()) return failure(); + + auto nloops = resultType.getRank(); + auto loc = op.getLoc(); + auto linalgOp = rewriter.create( + loc, isLHLO ? ArrayRef{} : resultType, args, /*inputCount=*/1, + /*outputCount=*/1, indexing_maps, GetNParallelLoopsAttrs(nloops), + [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { + nestedBuilder.create(loc, *args.begin()); + }); + + rewriter.replaceOp(op, linalgOp.getOperation()->getResults()); + return success(); + } +}; + +/// Pattern to convert BroadcastOp to Linalg ops. +template +class BroadcastConverter + : public DataMovementOpConverter, OpTy, + isLHLO> { + public: + using DataMovementOpConverter::DataMovementOpConverter; + + static SmallVector getIndexingMaps(OpTy broadcastOp, + Builder* b) { + ShapedType inputType = + broadcastOp.operand().getType().template cast(); + unsigned inputRank = inputType.getRank(); + unsigned nloops = getXLAOpResultType(broadcastOp).getRank(); + + // BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute to + // the input's dimensions. + unsigned numPrependedDims = llvm::size(broadcastOp.broadcast_sizes()); + SmallVector inputDimExprs; + inputDimExprs.reserve(inputRank); + for (int i = 0; i < inputRank; ++i) { + inputDimExprs.push_back(b->getAffineDimExpr(numPrependedDims + i)); + } + + AffineMap inputMap; + MLIRContext* context = b->getContext(); + if (inputDimExprs.empty()) { + // The input is a scalar, i.e. this is a scalar broadcast op. 
+ inputMap = AffineMap::get(nloops, /*symbolCount=*/0, context); + } else { + inputMap = + AffineMap::get(nloops, /*symbolCount=*/0, inputDimExprs, context); + } + return {inputMap, b->getMultiDimIdentityMap(nloops)}; + } +}; + +class HloBroadcastInDimConverter + : public DataMovementOpConverter { + public: + using DataMovementOpConverter::DataMovementOpConverter; + + static SmallVector getIndexingMaps( + xla_hlo::BroadcastInDimOp broadcastOp, Builder* b) { + auto resultType = getXLAOpResultType(broadcastOp); + auto operandType = + broadcastOp.operand().getType().template cast(); + unsigned nloops = resultType.getRank(); + + // The input is a scalar, i.e. this is a scalar broadcast op. + if (operandType.getRank() == 0) { + return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()), + b->getMultiDimIdentityMap(nloops)}; + } + + auto operandShape = operandType.getShape(); + SmallVector dimExprs; + dimExprs.reserve(nloops); + + if (broadcastOp.broadcast_dimensions()) { + for (const auto& broadcastDim : + enumerate(broadcastOp.broadcast_dimensions().getIntValues())) { + int size = broadcastDim.value().getSExtValue(); + bool expansion_needed = operandShape[broadcastDim.index()] == 1 && + resultType.getShape()[size] != 1; + dimExprs.push_back(expansion_needed ? b->getAffineConstantExpr(0) + : b->getAffineDimExpr(size)); + } + } + return { + AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()), + b->getMultiDimIdentityMap(nloops)}; + } +}; + +class LhloBroadcastInDimConverter + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + xla_lhlo::BroadcastInDimOp op, ArrayRef args, + ConversionPatternRewriter& rewriter) const final { + xla_lhlo::BroadcastInDimOp::Adaptor operand_adaptor(args); + auto result_type = operand_adaptor.output().getType().cast(); + auto result_shape = result_type.getShape(); + + auto operand_and_dims = InsertReshapeIfNecessary(op, args, rewriter); + + Value operand = std::get<0>(operand_and_dims); + auto broadcast_dims = std::get<1>(operand_and_dims); + + auto loc = op.getLoc(); + auto nloops = result_type.getRank(); + auto operand_type = operand.getType().cast(); + + // For a degenerate case, i.e. broadcasting with expansion of + // memref<1xELEMENT_TYPE>, the operand is not passed to `linalg.generic`. + // Instead the value is loaded and used directly in `linalg.yield`. + if (operand_type.getRank() == 1 && + operand_type.getDimSize(0) < + result_type.getDimSize(broadcast_dims.front())) { + Value zero = rewriter.create(loc, 0); + Value val = + rewriter.create(loc, operand, llvm::makeArrayRef({zero})); + rewriter.create( + loc, llvm::None, llvm::makeArrayRef(operand_adaptor.output()), + /*inputCount=*/0, /*outputCount=*/1, + llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)), + GetNParallelLoopsAttrs(nloops), + [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { + nestedBuilder.create(loc, val); + }); + + } else { + auto indexing_maps = getIndexingMaps(op, broadcast_dims, result_shape, + operand_type, &rewriter); + rewriter.create( + loc, llvm::None, + llvm::makeArrayRef({operand, operand_adaptor.output()}), + /*inputCount=*/1, /*outputCount=*/1, indexing_maps, + GetNParallelLoopsAttrs(nloops), + [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { + nestedBuilder.create(loc, *args.begin()); + }); + } + rewriter.replaceOp(op, llvm::None); + return success(); + } + + // Inserts 'linalg.reshape' if there is a size-1 dim expansion. 
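// The test used throughout the helper below: operand dimension i (mapped to
// result dimension broadcast_dims[i]) needs "expansion" when the operand extent
// is 1 but the result extent is not; such dimensions are collapsed away before
// lowering. Standalone sketch over plain shape vectors (illustrative only):
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<std::size_t> DimsNeedingExpansion(
    const std::vector<int64_t>& operand_shape,
    const std::vector<int64_t>& result_shape,
    const std::vector<int64_t>& broadcast_dims) {
  std::vector<std::size_t> expanded;
  for (std::size_t i = 0; i < operand_shape.size(); ++i) {
    const int64_t result_dim = broadcast_dims[i];
    if (operand_shape[i] == 1 && result_shape[result_dim] != 1)
      expanded.push_back(i);
  }
  return expanded;
}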
+ std::pair> InsertReshapeIfNecessary( + xla_lhlo::BroadcastInDimOp op, ArrayRef args, + ConversionPatternRewriter& rewriter) const { + xla_lhlo::BroadcastInDimOp::Adaptor operand_adaptor(args); + Value operand = operand_adaptor.operand(); + auto operand_type = operand_adaptor.operand().getType().cast(); + auto operand_shape = operand_type.getShape(); + + Value result = operand_adaptor.output(); + auto result_type = result.getType().cast(); + auto result_shape = result_type.getShape(); + + SmallVector operand_strides; + int64_t operand_offset; + if (failed(getStridesAndOffset(operand_type, operand_strides, + operand_offset))) { + op.emitOpError() << "Failed to get offset and strides."; + } + + SmallVector new_shape, new_strides, broadcast_dims; + SmallVector collapsed_dims_list; + linalg::ReassociationIndices collapsed_dims; + for (const auto& item : + enumerate(op.broadcast_dimensions().getIntValues())) { + size_t index = item.index(); + int dim = item.value().getSExtValue(); + + collapsed_dims.push_back(index); + + bool expansion_needed = + operand_shape[index] == 1 && result_shape[dim] != 1; + if (expansion_needed) { + continue; + } + new_shape.push_back(operand_shape[index]); + new_strides.push_back(operand_strides[index]); + broadcast_dims.push_back(dim); + + collapsed_dims_list.push_back(collapsed_dims); + collapsed_dims.clear(); + } + // If `collapsed_dims_list` is empty, then the memref has shape [1, ..., 1] + // and all dimensions need expansion. Such memref will be reshaped to a 1D + // memref with a single element. New shape and strides needs to be updated + // accordingly. + if (collapsed_dims_list.empty()) { + collapsed_dims_list.push_back({}); + new_shape.push_back(1); + new_strides.push_back(1); + broadcast_dims.push_back(0); + } + for (const auto& dims : collapsed_dims) { + collapsed_dims_list.back().push_back(dims); + } + + // `linalg.reshape` is inserted only if necessary, i.e. when the rank can be + // reduced. + if (new_shape.size() < operand_shape.size()) { + auto new_memref_type = MemRefType::get( + new_shape, operand_type.getElementType(), + makeStridedLinearLayoutMap(new_strides, operand_offset, + rewriter.getContext())); + operand = rewriter.create(op.getLoc(), new_memref_type, + operand_adaptor.operand(), + collapsed_dims_list); + } + return std::make_pair(operand, broadcast_dims); + } + + SmallVector getIndexingMaps(xla_lhlo::BroadcastInDimOp op, + ArrayRef broadcastDims, + ArrayRef resultShape, + MemRefType operandType, + Builder* b) const { + unsigned nloops = resultShape.size(); + + // The input is a scalar, i.e. this is a scalar broadcast op. 
+ if (operandType.getRank() == 0) { + return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()), + b->getMultiDimIdentityMap(nloops)}; + } + + auto operandShape = operandType.getShape(); + SmallVector dimExprs; + dimExprs.reserve(nloops); + + for (const auto& broadcastDim : llvm::enumerate(broadcastDims)) { + int size = broadcastDim.value(); + bool expansion_needed = + operandShape[broadcastDim.index()] == 1 && resultShape[size] != 1; + if (expansion_needed) { + op.emitOpError( + "BroadcastInDimOp lowering to Linalg does not support size-1 " + "dimensions expansion."); + } + dimExprs.push_back(b->getAffineDimExpr(size)); + } + return { + AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()), + b->getMultiDimIdentityMap(nloops)}; + } +}; + +template +class TransposeConverter + : public DataMovementOpConverter, OpTy, + isLHLO> { + public: + using DataMovementOpConverter, OpTy, + isLHLO>::DataMovementOpConverter; + static SmallVector getIndexingMaps(OpTy op, Builder* b) { + auto resultType = + getXLAOpResultType(op).template cast(); + auto nloops = resultType.getRank(); + SmallVector inputExprs; + inputExprs.resize(resultType.getRank()); + for (auto permutation : llvm::enumerate(op.permutation())) { + inputExprs[permutation.value().getZExtValue()] = + b->getAffineDimExpr(permutation.index()); + } + return { + AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()), + b->getMultiDimIdentityMap(nloops)}; + } +}; + +// Converts reshape ops that can be proven to be either a collapse of dimensions +// or expansion of dimensions of the operand. +template +class ReshapeOpConverter : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + OpTy reshapeOp, ArrayRef args, + ConversionPatternRewriter& rewriter) const final { + if (!verifyXLAOpBufferOrTensorSemantics(reshapeOp)) + return failure(); + ShapedType operandType = + reshapeOp.operand().getType().template cast(); + ShapedType resultType = getXLAOpResultType(reshapeOp); + + if (!operandType.hasStaticShape() || !resultType.hasStaticShape()) + return failure(); + + // Compute the reassociation maps for the linalg operation. + ArrayRef srcShape = + (operandType.getRank() > resultType.getRank() ? operandType.getShape() + : resultType.getShape()); + ArrayRef dstShape = + (operandType.getRank() > resultType.getRank() ? resultType.getShape() + : operandType.getShape()); + unsigned currSrcDim = 0, currDstDim = 0; + SmallVector reassociationMap( + dstShape.size()); + while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) { + int64_t dstSize = dstShape[currDstDim]; + int64_t srcSize = srcShape[currSrcDim]; + while (srcSize < dstSize && currSrcDim < srcShape.size()) { + reassociationMap[currDstDim].push_back( + rewriter.getAffineDimExpr(currSrcDim++)); + srcSize *= srcShape[currSrcDim]; + } + if (srcSize == dstSize) { + reassociationMap[currDstDim].push_back( + rewriter.getAffineDimExpr(currSrcDim++)); + // If the next dim in dstShape is not 1, treat subsequent dims in + // srcShape which are 1 to be collapsed. 
+ if (currDstDim == dstShape.size() - 1 || + dstShape[currDstDim + 1] != 1) { + while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) { + reassociationMap[currDstDim].push_back( + rewriter.getAffineDimExpr(currSrcDim++)); + } + } + } else { + return failure(); + } + currDstDim++; + } + if (currSrcDim != srcShape.size()) return failure(); + + if (isLHLO) { + Value reshapeBuffer = rewriter.create( + reshapeOp.getLoc(), resultType, args[0], reassociationMap); + rewriter.replaceOpWithNewOp( + reshapeOp, reshapeBuffer, args[1], /*inputPermutation =*/nullptr, + /*outputPermutation =*/nullptr); + } else { + rewriter.replaceOpWithNewOp( + reshapeOp, resultType, args[0], reassociationMap); + } + return success(); + } +}; + +class IotaConverter : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + xla_lhlo::IotaOp iotaOp, ArrayRef args, + ConversionPatternRewriter& rewriter) const final { + auto resultMemrefType = + iotaOp.getOperand().getType().dyn_cast(); + if (!resultMemrefType) return failure(); + + auto resultElementType = resultMemrefType.getElementType(); + if (!resultElementType.isSignlessIntOrFloat()) return failure(); + + // Construct the indexing maps needed for linalg.generic ops. + unsigned nloops = resultMemrefType.getRank(); + + rewriter.create( + iotaOp.getLoc(), ArrayRef{}, args, + 0, // args_in + 1, // args_out + llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)), + GetNParallelLoopsAttrs(nloops), + [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange ivs, + ValueRange args) { + Value castOp = nestedBuilder.create( + nestedLoc, ivs[iotaOp.iota_dimension().getZExtValue()], + nestedBuilder.getIntegerType( + resultElementType.getIntOrFloatBitWidth())); + if (resultElementType.isa()) { + castOp = nestedBuilder.create(nestedLoc, castOp, + resultElementType); + } + nestedBuilder.create(nestedLoc, castOp); + }); + + rewriter.replaceOp(iotaOp, llvm::None); + return success(); + } +}; + +class ConstConverter : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + xla_lhlo::ConstOp constOp, ArrayRef args, + ConversionPatternRewriter& rewriter) const final { + auto loc = constOp.getLoc(); + auto valueAttr = constOp.value().cast(); + if (valueAttr.getType().getRank() != 0) return failure(); + auto stdConstOp = + rewriter.create(loc, valueAttr.getValue({})); + rewriter.create(loc, stdConstOp, constOp.getOperand()); + rewriter.eraseOp(constOp); + return success(); + } +}; + +// TODO(b/156787842): Support the lowering for dynamic shapes. 
+template +class ReverseConverter + : public DataMovementOpConverter, OpTy, + isLHLO> { + public: + using DataMovementOpConverter, OpTy, + isLHLO>::DataMovementOpConverter; + static SmallVector getIndexingMaps(OpTy op, Builder* b) { + auto resultType = + getXLAOpResultType(op).template cast(); + auto nloops = resultType.getRank(); + SmallVector inputExprs; + inputExprs.reserve(nloops); + for (int i = 0; i < nloops; ++i) + inputExprs.push_back(b->getAffineDimExpr(i)); + for (auto dim : op.dimensions()) { + int i = dim.getZExtValue(); + if (resultType.isDynamicDim(i)) return {}; + int n = resultType.getShape()[i]; + inputExprs[i] = b->getAffineConstantExpr(n - 1) - inputExprs[i]; + } + return { + AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()), + b->getMultiDimIdentityMap(nloops)}; + } +}; + +class SliceConverter : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + xla_lhlo::SliceOp sliceOp, ArrayRef args, + ConversionPatternRewriter& rewriter) const final { + auto loc = sliceOp.getLoc(); + auto argType = + sliceOp.getOperand(0).getType().template dyn_cast(); + if (!argType || !argType.hasRank()) { + emitError(loc, "lhlo to linalg conversion expects known-rank args"); + return failure(); + } + + SmallVector ranges; + for (int i = 0, e = argType.getRank(); i < e; ++i) { + Value start_index = rewriter.create( + loc, sliceOp.start_indices().getValue(i)); + Value limit_index = rewriter.create( + loc, sliceOp.limit_indices().getValue(i)); + Value stride = rewriter.create( + loc, sliceOp.strides().getValue(i)); + ranges.push_back(rewriter.create(loc, start_index, + limit_index, stride)); + } + auto linalg_slice = + rewriter.create(loc, sliceOp.getOperand(0), ranges); + rewriter.create(loc, linalg_slice, sliceOp.getOperand(1)); + rewriter.eraseOp(sliceOp); + return success(); + } +}; + +void populateLHLOToLinalgConversionPattern(MLIRContext* context, + OwningRewritePatternList* patterns) { + // clang-format off + patterns->insert, + ConstConverter, + ConvToLinalgConverter, + IotaConverter, + LhloBroadcastInDimConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + // TODO(ataei): Remove this pattern, CopyOp is folded away. + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + ReshapeOpConverter, + ReverseConverter, + ScalarPointwiseToStandardConverter, + SliceConverter + >(context); + // clang-format on +} + +// Converts LHLO ops to Linalg generic. +// Sample result for xla_lhlo::AddOp. 
+// +// "xla_lhlo.add"(%arg1, %arg2, %out) : +// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () +// +// will be converted to +// +// #map0 = (d0, d1) -> (d0, d1) +// "linalg.generic"(%arg1, %arg2, %out) ( { +// ^bb0(%arg4: f32, %arg5: f32): +// %0 = addf %arg4, %arg5 : f32 +// "linalg.yield"(%0) : (f32) -> () +// }) { +// args_in = 2, +// args_out = 1, +// indexing_maps = [#map0, #map0, #map0], +// iterator_types = ["parallel", "parallel"], +// } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () +struct LhloLegalizeToLinalg + : public PassWrapper { + void runOnFunction() override { + OwningRewritePatternList patterns; + ConversionTarget target(getContext()); + target.addLegalDialect(); + + auto func = getFunction(); + populateLHLOToLinalgConversionPattern(func.getContext(), &patterns); + if (failed(applyPartialConversion(func, target, patterns, nullptr))) { + signalPassFailure(); + } + } +}; + +struct HloLegalizeToLinalg + : public PassWrapper { + void runOnFunction() override { + OwningRewritePatternList patterns; + ConversionTarget target(getContext()); + target.addLegalDialect(); + + auto func = getFunction(); + xla_hlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns); + if (failed(applyPartialConversion(func, target, patterns, nullptr))) { + signalPassFailure(); + } + } +}; + +} // namespace + +namespace xla_lhlo { +std::unique_ptr> createLegalizeLhloToLinalgPass() { + return absl::make_unique(); +} + +static PassRegistration legalize_lhlo_pass( + "lhlo-legalize-to-linalg", "Legalize from LHLO dialect to Linalg dialect"); +} // namespace xla_lhlo + +namespace xla_hlo { + +void populateHLOToLinalgConversionPattern(MLIRContext* context, + OwningRewritePatternList* patterns) { + patterns->insert, + HloBroadcastInDimConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + ReshapeOpConverter, + ReverseConverter, + TransposeConverter>(context); +} + +std::unique_ptr> createLegalizeHloToLinalgPass() { + return absl::make_unique(); +} + +static PassRegistration legalize_hlo_pass( + "hlo-legalize-to-linalg", "Legalize from HLO dialect to Linalg dialect"); +} // namespace xla_hlo +} // namespace mlir diff --git a/lib/Dialect/mhlo/transforms/xla_transform_unranked_hlo.cc b/lib/Dialect/mhlo/transforms/xla_transform_unranked_hlo.cc new file mode 100644 index 0000000..fde9cef --- /dev/null +++ b/lib/Dialect/mhlo/transforms/xla_transform_unranked_hlo.cc @@ -0,0 +1,188 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +==============================================================================*/ + +#include "third_party/absl/memory/memory.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h" +#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" + +namespace mlir { +namespace xla_hlo { +namespace { + +// TODO(frgossen): Make it variadic. +template +inline void AddLegalOpOnRankedTensor(ConversionTarget *target) { + target->addDynamicallyLegalOp([](OpTy op) { + return llvm::all_of((op.getOperation())->getOperandTypes(), + [&](Type t) { return t.isa(); }); + }); +} + +/// Unary element-wise operations on unranked tensors can be applied to the +/// flattened tensor with the same effect. +/// This pattern rewrites every such operation to +/// (i) flatten the input tensor, +/// (ii) apply the unary operation, and +/// (iii) restore the original shape. +template +struct UnaryElementwiseOpConversion : public OpRewritePattern { + explicit UnaryElementwiseOpConversion(MLIRContext *context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + // Don't apply conversion to ops with statically shaped operands. + Value operand = op.getOperand(); + auto operandTy = operand.getType().dyn_cast(); + if (operandTy.hasRank()) return failure(); + + // Generate IR to flatten the operand. + auto loc = op.getLoc(); + Value shape = rewriter.create(loc, operand); + Value numElements = rewriter.create( + loc, rewriter.getType(), shape); + Value numElementsAsIndex = rewriter.create( + loc, rewriter.getIndexType(), numElements); + Value flatShapeAsDimTensor = + rewriter.create(loc, numElementsAsIndex); + auto flatTensorTy = RankedTensorType::get({ShapedType::kDynamicSize}, + operandTy.getElementType()); + Value flatOperand = rewriter.create( + loc, flatTensorTy, operand, flatShapeAsDimTensor); + + // Generate IR for the actual operation. + Value flatResult = rewriter.create(loc, flatTensorTy, flatOperand); + + // Generate IR to restore the original shape. 
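+    // The shape captured before flattening is materialized as an extent tensor
+    // so the dynamic reshape below can give the result the operand's original
+    // (unranked) type.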
+ auto extentTensorTy = RankedTensorType::get({ShapedType::kDynamicSize}, + rewriter.getIndexType()); + Value shapeAsExtentTensor = + rewriter.create(loc, extentTensorTy, shape); + Value result = rewriter.create( + loc, operandTy, flatResult, shapeAsExtentTensor); + rewriter.replaceOp(op, result); + + return success(); + } +}; + +/// Binary element-wise operation on unranked tensors can be applied to the +/// flattened operand tensors with the same effect. +/// This pattern rewrites every such operation to +/// (i) flatten the operand tensors, +/// (ii) apply the binary operation, and +// (iii) restore the original shape. +template +struct BinaryElementwiseOpConversion : public OpRewritePattern { + explicit BinaryElementwiseOpConversion(MLIRContext *context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + // Don't apply conversion unless both operands are unranked. + if (op.lhs().getType().template isa() || + op.rhs().getType().template isa()) { + return failure(); + } + + // Flatten operands. + Type shapeTy = shape::ShapeType::get(rewriter.getContext()); + auto loc = op.getLoc(); + Value shapeLhs = rewriter.create(loc, op.lhs()); + Value shapeRhs = rewriter.create(loc, op.rhs()); + Value shape = rewriter.create(loc, shapeTy, + ValueRange{shapeLhs, shapeRhs}); + Value numElements = rewriter.create(loc, shape); + Value numElementsAsIndex = + rewriter.create(loc, numElements); + Value flatShape = + rewriter.create(loc, numElementsAsIndex); + TensorType lhsTy = op.lhs().getType().template cast(); + Type flatLhsTy = RankedTensorType::get({ShapedType::kDynamicSize}, + lhsTy.getElementType()); + Value flatLhs = + rewriter.create(loc, flatLhsTy, op.lhs(), flatShape); + TensorType rhsTy = op.rhs().getType().template cast(); + Type flatRhsTy = RankedTensorType::get({ShapedType::kDynamicSize}, + rhsTy.getElementType()); + Value flatRhs = + rewriter.create(loc, flatRhsTy, op.rhs(), flatShape); + + // Apply actual operation to flattened operands. + Value flatResult = rewriter.create(loc, flatLhs, flatRhs); + + // Restore original shape. + auto extentTensorTy = RankedTensorType::get({ShapedType::kDynamicSize}, + rewriter.getIndexType()); + Value shapeAsExtentTensor = + rewriter.create(loc, extentTensorTy, shape); + Value result = rewriter.create( + loc, op.getType(), flatResult, shapeAsExtentTensor); + rewriter.replaceOp(op, result); + + return success(); + } +}; + +struct TransformUnrankedHloPass + : public PassWrapper { + void runOnFunction() override { + // Setup conversion target. + MLIRContext &ctx = getContext(); + ConversionTarget target(ctx); + target.addLegalDialect(); + target.addLegalOp(); + AddLegalOpOnRankedTensor(&target); + AddLegalOpOnRankedTensor(&target); + + // Populate rewrite patterns. + OwningRewritePatternList patterns; + PopulateTransformUnrankedHloPatterns(&ctx, &patterns); + + // Apply transformation. + if (failed(applyFullConversion(getFunction(), target, patterns))) + return signalPassFailure(); + } +}; + +} // namespace + +void PopulateTransformUnrankedHloPatterns(MLIRContext *context, + OwningRewritePatternList *patterns) { + // TODO(frgossen): Populate all unary and binary operations. 
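+  // Only one unary and one binary element-wise pattern are registered below;
+  // additional ops can reuse the two conversion templates above unchanged.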
+ // clang-format off + patterns->insert< + BinaryElementwiseOpConversion, + UnaryElementwiseOpConversion>(context); + // clang-format on +} + +static PassRegistration transform_unranked_hlo_pass( + "transform-unranked-hlo", + "Realize element-wise operations on ranked tensors where possible"); + +} // namespace xla_hlo +} // namespace mlir diff --git a/lib/utils/cycle_detector.cc b/lib/utils/cycle_detector.cc new file mode 100644 index 0000000..b3b51dd --- /dev/null +++ b/lib/utils/cycle_detector.cc @@ -0,0 +1,340 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/cycle_detector.h" + +#include + +#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/DenseSet.h" + +namespace mlir { + +namespace { + +using NodeSet = llvm::DenseSet; +using OrderedNodeSet = OrderedSet; + +template +struct VecStruct { + typedef llvm::SmallVector type; +}; +template +using Vec = typename VecStruct::type; + +struct Node { + // rank number assigned by Pearce-Kelly algorithm + int32_t rank; + // Temporary marker used by depth-first-search + bool visited; + // User-supplied data + void* data; + // List of immediate predecessor nodes in graph + OrderedNodeSet in; + // List of immediate successor nodes in graph + OrderedNodeSet out; +}; + +} // namespace + +struct GraphCycles::Rep { + Vec nodes; + // Indices for unused entries in nodes + Vec free_nodes; + + // Temporary state. + // Results of forward DFS + Vec deltaf; + // Results of backward DFS + Vec deltab; + // All nodes to reprocess + Vec list; + // Rank values to assign to list entries + Vec merged; + // Emulates recursion stack when doing depth first search + Vec stack; +}; + +GraphCycles::GraphCycles(int32_t num_nodes) : rep_(new Rep) { + rep_->nodes.reserve(num_nodes); + for (int32_t i = 0; i < num_nodes; ++i) { + Node* n = new Node; + n->visited = false; + n->data = nullptr; + n->rank = rep_->nodes.size(); + rep_->nodes.push_back(n); + } +} + +GraphCycles::~GraphCycles() { + for (Vec::size_type i = 0, e = rep_->nodes.size(); i < e; ++i) { + delete rep_->nodes[i]; + } + delete rep_; +} + +bool GraphCycles::HasEdge(int32_t x, int32_t y) const { + return rep_->nodes[x]->out.Contains(y); +} + +void GraphCycles::RemoveEdge(int32_t x, int32_t y) { + rep_->nodes[x]->out.Erase(y); + rep_->nodes[y]->in.Erase(x); + // No need to update the rank assignment since a previous valid + // rank assignment remains valid after an edge deletion. 
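+  // (The ranks maintain an incremental topological order in the style of
+  // Pearce-Kelly: every edge x->y satisfies rank[x] < rank[y]. Removing an
+  // edge only drops a constraint, so the invariant still holds; only
+  // InsertEdge can violate it and has to reorder.)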
+}
+
+static bool ForwardDFS(GraphCycles::Rep* r, int32_t n, int32_t upper_bound);
+static void BackwardDFS(GraphCycles::Rep* r, int32_t n, int32_t lower_bound);
+static void Reorder(GraphCycles::Rep* r);
+static void Sort(const Vec<Node*>&, Vec<int32_t>* delta);
+static void MoveToList(GraphCycles::Rep* r, Vec<int32_t>* src,
+                       Vec<int32_t>* dst);
+static void ClearVisitedBits(GraphCycles::Rep* r, const Vec<int32_t>& nodes);
+
+bool GraphCycles::InsertEdge(int32_t x, int32_t y) {
+  if (x == y) return false;
+  Rep* r = rep_;
+  Node* nx = r->nodes[x];
+  if (!nx->out.Insert(y)) {
+    // Edge already exists.
+    return true;
+  }
+
+  Node* ny = r->nodes[y];
+  ny->in.Insert(x);
+
+  if (nx->rank <= ny->rank) {
+    // New edge is consistent with existing rank assignment.
+    return true;
+  }
+
+  // The current rank assignment is incompatible with the new edge. Recompute.
+  // We only need to consider nodes whose ranks fall in [ny->rank, nx->rank].
+  if (ForwardDFS(r, y, nx->rank)) {
+    // Found a cycle. Undo the insertion and tell the caller.
+    nx->out.Erase(y);
+    ny->in.Erase(x);
+    // Since we do not call Reorder() on this path, clear any visited
+    // markers left by ForwardDFS.
+    ClearVisitedBits(r, r->deltaf);
+    return false;
+  }
+  BackwardDFS(r, x, ny->rank);
+  Reorder(r);
+  return true;
+}
+
+// Follows edges from producers to consumers and searches, via an iterative
+// DFS starting at node `n`, whether a node with rank `upper_bound` can be
+// reached. Only paths whose nodes all have ranks smaller than `upper_bound`
+// are considered.
+//
+// Returns true if such a path exists.
+static bool ForwardDFS(GraphCycles::Rep* r, int32_t n, int32_t upper_bound) {
+  // Avoid recursion since stack space might be limited.
+  // We instead keep an explicit stack of nodes to visit.
+  r->deltaf.clear();
+  r->stack.clear();
+  r->stack.push_back(n);
+  while (!r->stack.empty()) {
+    n = r->stack.back();
+    r->stack.pop_back();
+    Node* nn = r->nodes[n];
+    if (nn->visited) continue;
+
+    nn->visited = true;
+    r->deltaf.push_back(n);
+
+    for (auto w : nn->out.GetSequence()) {
+      Node* nw = r->nodes[w];
+      if (nw->rank == upper_bound) {
+        return true;
+      }
+      if (!nw->visited && nw->rank < upper_bound) {
+        r->stack.push_back(w);
+      }
+    }
+  }
+  return false;
+}
+
+// Follows edges from consumers to producers and visits all nodes that are
+// reachable from node `n` and have a rank larger than `lower_bound`.
+static void BackwardDFS(GraphCycles::Rep* r, int32_t n, int32_t lower_bound) {
+  r->deltab.clear();
+  r->stack.clear();
+  r->stack.push_back(n);
+  while (!r->stack.empty()) {
+    n = r->stack.back();
+    r->stack.pop_back();
+    Node* nn = r->nodes[n];
+    if (nn->visited) continue;
+
+    nn->visited = true;
+    r->deltab.push_back(n);
+
+    for (auto w : nn->in.GetSequence()) {
+      Node* nw = r->nodes[w];
+      if (!nw->visited && lower_bound < nw->rank) {
+        r->stack.push_back(w);
+      }
+    }
+  }
+}
+
+// Recomputes the rank assignment so that it is compatible with the edges
+// again (every producer has a smaller rank than its consumers).
+static void Reorder(GraphCycles::Rep* r) {
+  Sort(r->nodes, &r->deltab);
+  Sort(r->nodes, &r->deltaf);
+
+  // Adds contents of delta lists to list (backwards deltas first).
+  r->list.clear();
+  MoveToList(r, &r->deltab, &r->list);
+  MoveToList(r, &r->deltaf, &r->list);
+
+  // Produce sorted list of all ranks that will be reassigned.
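+  // Sort() ordered both delta vectors by rank and MoveToList() replaced each
+  // entry with that rank, so deltab and deltaf are sorted rank sequences and
+  // std::merge yields the ordered pool of ranks handed back to the nodes in
+  // `list`.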
+  r->merged.resize(r->deltab.size() + r->deltaf.size());
+  std::merge(r->deltab.begin(), r->deltab.end(), r->deltaf.begin(),
+             r->deltaf.end(), r->merged.begin());
+
+  // Assign the merged ranks, in order, to the collected nodes.
+  for (Vec<int32_t>::size_type i = 0, e = r->list.size(); i < e; ++i) {
+    r->nodes[r->list[i]]->rank = r->merged[i];
+  }
+}
+
+// Sorts the node ids in `delta` by the ranks of the corresponding nodes,
+// smallest rank first.
+static void Sort(const Vec<Node*>& nodes, Vec<int32_t>* delta) {
+  struct ByRank {
+    const Vec<Node*>* nodes;
+    bool operator()(int32_t a, int32_t b) const {
+      return (*nodes)[a]->rank < (*nodes)[b]->rank;
+    }
+  };
+  ByRank cmp;
+  cmp.nodes = &nodes;
+  std::sort(delta->begin(), delta->end(), cmp);
+}
+
+// Moves the node ids in `src` to `dst`, replacing each `src` entry with the
+// corresponding node's rank along the way.
+static void MoveToList(GraphCycles::Rep* r, Vec<int32_t>* src,
+                       Vec<int32_t>* dst) {
+  for (Vec<int32_t>::size_type i = 0, e = src->size(); i < e; i++) {
+    int32_t w = (*src)[i];
+    // Replace the src entry with its rank.
+    (*src)[i] = r->nodes[w]->rank;
+    // Prepare for future DFS calls.
+    r->nodes[w]->visited = false;
+    dst->push_back(w);
+  }
+}
+
+// Clears the bookkeeping fields used during the last DFS.
+static void ClearVisitedBits(GraphCycles::Rep* r, const Vec<int32_t>& nodes) {
+  for (Vec<int32_t>::size_type i = 0, e = nodes.size(); i < e; i++) {
+    r->nodes[nodes[i]]->visited = false;
+  }
+}
+
+bool GraphCycles::IsReachable(int32_t x, int32_t y) {
+  if (x == y) return true;
+  Rep* r = rep_;
+  Node* nx = r->nodes[x];
+  Node* ny = r->nodes[y];
+
+  if (nx->rank >= ny->rank) {
+    // x cannot reach y since x does not come before y in the topological
+    // ordering.
+    return false;
+  }
+
+  // See if x can reach y using a DFS search that is limited to y's rank.
+  bool reachable = ForwardDFS(r, x, ny->rank);
+
+  // Clear any visited markers left by ForwardDFS.
+  ClearVisitedBits(r, r->deltaf);
+  return reachable;
+}
+
+llvm::Optional<int32_t> GraphCycles::ContractEdge(int32_t a, int32_t b) {
+  assert(HasEdge(a, b));
+  RemoveEdge(a, b);
+
+  if (IsReachable(a, b)) {
+    // Restore the graph to its original state.
+    InsertEdge(a, b);
+    return {};
+  }
+
+  if (rep_->nodes[b]->in.Size() + rep_->nodes[b]->out.Size() >
+      rep_->nodes[a]->in.Size() + rep_->nodes[a]->out.Size()) {
+    // Swap "a" and "b" to minimize copying.
+    std::swap(a, b);
+  }
+
+  Node* nb = rep_->nodes[b];
+  OrderedNodeSet out = std::move(nb->out);
+  OrderedNodeSet in = std::move(nb->in);
+  for (int32_t y : out.GetSequence()) {
+    rep_->nodes[y]->in.Erase(b);
+  }
+  for (int32_t y : in.GetSequence()) {
+    rep_->nodes[y]->out.Erase(b);
+  }
+  rep_->free_nodes.push_back(b);
+
+  rep_->nodes[a]->out.Reserve(rep_->nodes[a]->out.Size() + out.Size());
+  for (int32_t y : out.GetSequence()) {
+    InsertEdge(a, y);
+  }
+
+  rep_->nodes[a]->in.Reserve(rep_->nodes[a]->in.Size() + in.Size());
+  for (int32_t y : in.GetSequence()) {
+    InsertEdge(y, a);
+  }
+
+  // Note: if the swap above happened, the node returned here may be the one
+  // that was originally passed in as "b".
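+  // Callers must use the returned id in place of both original endpoints from
+  // here on.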
+ return a; +} + +std::vector GraphCycles::SuccessorsCopy(int32_t node) const { + return rep_->nodes[node]->out.GetSequence(); +} + +namespace { +void SortInPostOrder(const Vec& nodes, std::vector* to_sort) { + std::sort(to_sort->begin(), to_sort->end(), [&](int32_t a, int32_t b) { + return nodes[a]->rank > nodes[b]->rank; + }); +} +} // namespace + +std::vector GraphCycles::AllNodesInPostOrder() const { + llvm::DenseSet free_nodes_set; + for (int32_t n : rep_->free_nodes) free_nodes_set.insert(n); + + std::vector all_nodes; + all_nodes.reserve(rep_->nodes.size() - free_nodes_set.size()); + for (size_t i = 0, e = rep_->nodes.size(); i < e; i++) { + if (!free_nodes_set.count(i)) { + all_nodes.push_back(i); + } + } + + SortInPostOrder(rep_->nodes, &all_nodes); + return all_nodes; +} + +} // namespace mlir diff --git a/lib/utils/cycle_detector_test.cc b/lib/utils/cycle_detector_test.cc new file mode 100644 index 0000000..bee96d2 --- /dev/null +++ b/lib/utils/cycle_detector_test.cc @@ -0,0 +1,89 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/cycle_detector.h" + +#include "third_party/tensorflow/compiler/xla/test.h" + +class GraphCyclesTest : public ::testing::Test { + public: + GraphCyclesTest() : g_(100) {} + + bool AddEdge(int x, int y) { return g_.InsertEdge(x, y); } + + void AddMultiples() { + // For every node x > 0: add edge to 2*x, 3*x + for (int x = 1; x < 25; x++) { + EXPECT_TRUE(AddEdge(x, 2 * x)) << x; + EXPECT_TRUE(AddEdge(x, 3 * x)) << x; + } + } + + mlir::GraphCycles g_; +}; + +TEST_F(GraphCyclesTest, NoCycle) { AddMultiples(); } + +TEST_F(GraphCyclesTest, SimpleCycle) { + AddMultiples(); + EXPECT_FALSE(AddEdge(8, 4)); +} + +TEST_F(GraphCyclesTest, IndirectCycle) { + AddMultiples(); + EXPECT_TRUE(AddEdge(16, 9)); + EXPECT_FALSE(AddEdge(9, 2)); +} + +TEST_F(GraphCyclesTest, RemoveEdge) { + EXPECT_TRUE(AddEdge(1, 2)); + EXPECT_TRUE(AddEdge(2, 3)); + EXPECT_TRUE(AddEdge(3, 4)); + EXPECT_TRUE(AddEdge(4, 5)); + g_.RemoveEdge(2, 3); + EXPECT_FALSE(g_.HasEdge(2, 3)); +} + +TEST_F(GraphCyclesTest, IsReachable) { + EXPECT_TRUE(AddEdge(1, 2)); + EXPECT_TRUE(AddEdge(2, 3)); + EXPECT_TRUE(AddEdge(3, 4)); + EXPECT_TRUE(AddEdge(4, 5)); + + EXPECT_TRUE(g_.IsReachable(1, 5)); + EXPECT_FALSE(g_.IsReachable(5, 1)); +} + +TEST_F(GraphCyclesTest, ContractEdge) { + ASSERT_TRUE(AddEdge(1, 2)); + ASSERT_TRUE(AddEdge(1, 3)); + ASSERT_TRUE(AddEdge(2, 3)); + ASSERT_TRUE(AddEdge(2, 4)); + ASSERT_TRUE(AddEdge(3, 4)); + + // It will introduce a cycle if the edge is contracted + EXPECT_FALSE(g_.ContractEdge(1, 3).hasValue()); + EXPECT_TRUE(g_.HasEdge(1, 3)); + + // Node (2) has more edges. + EXPECT_EQ(*g_.ContractEdge(1, 2), 2); + EXPECT_TRUE(g_.HasEdge(2, 3)); + EXPECT_TRUE(g_.HasEdge(2, 4)); + EXPECT_TRUE(g_.HasEdge(3, 4)); + + // Node (2) has more edges. 
+ EXPECT_EQ(*g_.ContractEdge(2, 3), 2); + EXPECT_TRUE(g_.HasEdge(2, 4)); +}