Move XLA-independent transforms to the new MLIR-HLO directory

This is as straightforward as possible; more cleanup/rewrite to come.

PiperOrigin-RevId: 319849713
parent 72010faaa7
commit 31dc1b21eb
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_CHLO_OPS_H_
-#define TENSORFLOW_COMPILER_MLIR_XLA_IR_CHLO_OPS_H_
+#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_CHLO_OPS_H_
+#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_CHLO_OPS_H_
 
 #include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
 #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Dialect.h"
@@ -42,4 +42,4 @@ class XlaHloClientDialect : public Dialect {
 }  // namespace xla_chlo
 }  // namespace mlir
 
-#endif  // TENSORFLOW_COMPILER_MLIR_XLA_IR_CHLO_OPS_H_
+#endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_CHLO_OPS_H_
@@ -15,8 +15,8 @@ limitations under the License.
 
 // This file defines the operations used in the XLA dialect.
 
-#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_OPS_H_
-#define TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_OPS_H_
+#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_
+#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_
 
 #include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
 #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
@@ -96,4 +96,4 @@ LogicalResult deriveShapeFromFirstOperand(
 }  // end namespace xla_hlo
 }  // end namespace mlir
 
-#endif  //  TENSORFLOW_COMPILER_MLIR_XLA_IR_HLO_OPS_H_
+#endif  //  TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_INFER_FUSIBILITY_OP_INTERFACE_H_
-#define TENSORFLOW_COMPILER_MLIR_XLA_IR_INFER_FUSIBILITY_OP_INTERFACE_H_
+#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_INFER_FUSIBILITY_OP_INTERFACE_H_
+#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_INFER_FUSIBILITY_OP_INTERFACE_H_
 
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/StandardTypes.h"
@@ -25,4 +25,4 @@ namespace mlir {
 
 }  // namespace mlir
 
-#endif  // TENSORFLOW_COMPILER_MLIR_XLA_IR_INFER_FUSIBILITY_OP_INTERFACE_H_
+#endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_INFER_FUSIBILITY_OP_INTERFACE_H_
@@ -15,8 +15,8 @@ limitations under the License.
 
 // This file defines the operations used in the LXLA dialect.
 
-#ifndef TENSORFLOW_COMPILER_MLIR_XLA_IR_LHLO_OPS_H_
-#define TENSORFLOW_COMPILER_MLIR_XLA_IR_LHLO_OPS_H_
+#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_
+#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_
 
 #include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
 #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
@@ -49,4 +49,4 @@ class XlaLhloDialect : public Dialect {
 }  // namespace xla_lhlo
 }  // end namespace mlir
 
-#endif  // TENSORFLOW_COMPILER_MLIR_XLA_IR_LHLO_OPS_H_
+#endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_
@@ -0,0 +1,80 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_

#include <type_traits>

#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"

namespace mlir {
namespace xla_hlo {

template <typename HloOpTy>
struct HloToLhloOpImpl {
  using Type = std::false_type;
};
template <typename HloOpTy>
using HloToLhloOp = typename HloToLhloOpImpl<HloOpTy>::Type;

#define MAP_HLO_TO_LHLO(OpName)             \
  template <>                               \
  struct HloToLhloOpImpl<xla_hlo::OpName> { \
    using Type = xla_lhlo::OpName;          \
  }

MAP_HLO_TO_LHLO(AbsOp);
MAP_HLO_TO_LHLO(AddOp);
MAP_HLO_TO_LHLO(AndOp);
MAP_HLO_TO_LHLO(BroadcastInDimOp);
MAP_HLO_TO_LHLO(CeilOp);
MAP_HLO_TO_LHLO(ConstOp);
MAP_HLO_TO_LHLO(CompareOp);
MAP_HLO_TO_LHLO(ComplexOp);
MAP_HLO_TO_LHLO(ConvOp);
MAP_HLO_TO_LHLO(ConvertOp);
MAP_HLO_TO_LHLO(CopyOp);
MAP_HLO_TO_LHLO(CosOp);
MAP_HLO_TO_LHLO(DivOp);
MAP_HLO_TO_LHLO(DotOp);
MAP_HLO_TO_LHLO(ExpOp);
MAP_HLO_TO_LHLO(GatherOp);
MAP_HLO_TO_LHLO(ImagOp);
MAP_HLO_TO_LHLO(IotaOp);
MAP_HLO_TO_LHLO(LogOp);
MAP_HLO_TO_LHLO(MaxOp);
MAP_HLO_TO_LHLO(MinOp);
MAP_HLO_TO_LHLO(MulOp);
MAP_HLO_TO_LHLO(NegOp);
MAP_HLO_TO_LHLO(RealOp);
MAP_HLO_TO_LHLO(ReduceOp);
MAP_HLO_TO_LHLO(ReshapeOp);
MAP_HLO_TO_LHLO(RemOp);
MAP_HLO_TO_LHLO(RsqrtOp);
MAP_HLO_TO_LHLO(SelectOp);
MAP_HLO_TO_LHLO(SignOp);
MAP_HLO_TO_LHLO(SinOp);
MAP_HLO_TO_LHLO(SqrtOp);
MAP_HLO_TO_LHLO(SubOp);
MAP_HLO_TO_LHLO(TanhOp);

#undef MAP_HLO_TO_LHLO

}  // namespace xla_hlo
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_
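The trait map above lets one generic lowering pattern cover every HLO/LHLO op pair. A minimal sketch of how a lowering might consume it; the LowerToLhlo helper is hypothetical and not part of this commit:

// Hypothetical helper: resolve the LHLO counterpart of an HLO op at compile
// time via the HloToLhloOp trait defined in map_hlo_to_lhlo_op.h.
template <typename HloOpTy>
void LowerToLhlo(HloOpTy op) {
  using LhloOpTy = mlir::xla_hlo::HloToLhloOp<HloOpTy>;
  // Ops without a MAP_HLO_TO_LHLO entry resolve to std::false_type, which a
  // pattern can reject with enable_if or a static_assert like this one.
  static_assert(!std::is_same<LhloOpTy, std::false_type>::value,
                "no LHLO counterpart registered for this HLO op");
  // ... allocate output buffers and create an LhloOpTy over them ...
}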
@@ -0,0 +1,510 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_XLA_TO_SCALAR_OP_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_XLA_TO_SCALAR_OP_H_

#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"

namespace mlir {
namespace xla_lhlo {
namespace impl {

// A struct to map LhloBinaryOpTy type to the corresponding floating-point and
// integer scalar operation types.
template <typename LhloBinaryOpTy>
struct LhloToScalarOp;

template <>
struct LhloToScalarOp<xla_lhlo::AddOp> {
  using FOp = ::mlir::AddFOp;
  using IOp = ::mlir::AddIOp;
};
template <>
struct LhloToScalarOp<xla_lhlo::CompareOp> {
  using FOp = ::mlir::CmpFOp;
  using IOp = ::mlir::CmpIOp;
};
template <>
struct LhloToScalarOp<xla_lhlo::DivOp> {
  using FOp = ::mlir::DivFOp;
  using IOp = ::mlir::SignedDivIOp;
};
template <>
struct LhloToScalarOp<xla_lhlo::MulOp> {
  using FOp = ::mlir::MulFOp;
  using IOp = ::mlir::MulIOp;
};
template <>
struct LhloToScalarOp<xla_lhlo::RemOp> {
  using FOp = ::mlir::RemFOp;
  using IOp = ::mlir::SignedRemIOp;
};
template <>
struct LhloToScalarOp<xla_lhlo::SubOp> {
  using FOp = ::mlir::SubFOp;
  using IOp = ::mlir::SubIOp;
};

template <typename LhloBinaryOpTy>
struct ScalarOp {
  using FOp = typename LhloToScalarOp<LhloBinaryOpTy>::FOp;
  using IOp = typename LhloToScalarOp<LhloBinaryOpTy>::IOp;
};

// Alias for the map from LHLO binary op type to STD floating-point op type.
template <typename LhloOp>
using ScalarFOp = typename ScalarOp<LhloOp>::FOp;
// Alias for the map from LHLO binary op type to STD integer op type.
template <typename LhloOp>
using ScalarIOp = typename ScalarOp<LhloOp>::IOp;

template <typename... Args>
struct MapLhloOpToStdScalarOpImpl {
  Value operator()(Location loc, ArrayRef<Type> result_types,
                   ArrayRef<Value> args, OpBuilder* b) {
    return nullptr;
  }
};

template <typename StdScalarOp>
struct MapLhloOpToStdScalarOpImpl<StdScalarOp> {
  Value operator()(Location loc, ArrayRef<Type> result_types,
                   ArrayRef<Value> args, OpBuilder* b) {
    return b->template create<StdScalarOp>(loc, result_types, args, mlir::None);
  }
};

template <typename SupportedType, typename StdScalarOp, typename... Args>
struct MapLhloOpToStdScalarOpImpl<SupportedType, StdScalarOp, Args...> {
  Value operator()(Location loc, ArrayRef<Type> result_types,
                   ArrayRef<Value> args, OpBuilder* b) {
    Type element_type = args.front().getType();
    if (element_type.isa<SupportedType>()) {
      return b->template create<StdScalarOp>(loc, result_types, args,
                                             mlir::None);
    }
    return MapLhloOpToStdScalarOpImpl<Args...>{}(loc, result_types, args, b);
  }
};

// Inserts the computation that corresponds to the body of the loop for lowered
// LHLO unary/binary op. Returns the value for the result.
template <typename LhloOpTy>
inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef<Type> result_types,
                                    ArrayRef<Value> args, OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<IntegerType, ScalarIOp<LhloOpTy>, FloatType,
                                    ScalarFOp<LhloOpTy>>{}(loc, result_types,
                                                           args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::AbsOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
  Type element_type = args.front().getType();
  if (element_type.isa<FloatType>()) {
    return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::AbsFOp>{}(
        loc, result_types, args, b);
  }
  if (element_type.isa<IntegerType>()) {
    // xla_lhlo.abs(x, result) ->  result = select((x > 0), x, sub(0, x))
    Value lhs = args[0];
    auto integer_type = element_type.dyn_cast<IntegerType>();

    auto zero_intval =
        b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
    auto lhs_gt_zero = b->create<ScalarIOp<CompareOp>>(loc, CmpIPredicate::sge,
                                                       lhs, zero_intval);
    auto neg_val = b->create<ScalarIOp<xla_lhlo::SubOp>>(loc, zero_intval, lhs);
    return b->create<::mlir::SelectOp>(loc, lhs_gt_zero, lhs, neg_val);
  }
  return nullptr;
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::AndOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::AndOp>{}(
      loc, result_types, args, b);
}

template <typename PredicateType>
inline Optional<PredicateType> getCmpPredicate(
    StringRef xla_comparison_direction) {
  return llvm::None;
}

template <>
inline Optional<CmpFPredicate> getCmpPredicate<CmpFPredicate>(
    StringRef xla_comparison_direction) {
  return llvm::StringSwitch<Optional<CmpFPredicate>>(xla_comparison_direction)
      .Case("EQ", CmpFPredicate::OEQ)
      .Case("NE", CmpFPredicate::ONE)
      .Case("GE", CmpFPredicate::OGE)
      .Case("GT", CmpFPredicate::OGT)
      .Case("LE", CmpFPredicate::OLE)
      .Case("LT", CmpFPredicate::OLT)
      .Default(llvm::None);
}

template <>
inline Optional<CmpIPredicate> getCmpPredicate<CmpIPredicate>(
    StringRef xla_comparison_direction) {
  return llvm::StringSwitch<Optional<CmpIPredicate>>(xla_comparison_direction)
      .Case("EQ", CmpIPredicate::eq)
      .Case("NE", CmpIPredicate::ne)
      .Case("GE", CmpIPredicate::sge)
      .Case("GT", CmpIPredicate::sgt)
      .Case("LE", CmpIPredicate::sle)
      .Case("LT", CmpIPredicate::slt)
      .Default(llvm::None);
}

template <typename XLACompareOpTy>
inline Value MapXlaCompareOpToStdScalarOp(Location loc,
                                          StringRef comparison_direction,
                                          ArrayRef<Type> result_types,
                                          ArrayRef<Value> args, OpBuilder* b) {
  const auto& lhs = args[0];
  const auto& rhs = args[1];
  Type element_type = lhs.getType();
  if (element_type.isSignlessInteger()) {
    Optional<CmpIPredicate> predicate =
        getCmpPredicate<CmpIPredicate>(comparison_direction);
    assert(predicate.hasValue() && "expected valid comparison direction");
    return b->create<ScalarIOp<XLACompareOpTy>>(loc, predicate.getValue(), lhs,
                                                rhs);
  }
  if (element_type.isa<FloatType>()) {
    Optional<CmpFPredicate> predicate =
        getCmpPredicate<CmpFPredicate>(comparison_direction);
    assert(predicate.hasValue() && "expected valid comparison direction");
    return b->create<ScalarFOp<XLACompareOpTy>>(loc, predicate.getValue(), lhs,
                                                rhs);
  }
  return nullptr;
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CopyOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
  return args.front();
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ExpOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::ExpOp>{}(
      loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CeilOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CeilFOp>{}(
      loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ComplexOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<CreateComplexOp>{}(loc, result_types, args,
                                                       b);
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::RealOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<ReOp>{}(loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ImagOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<ImOp>{}(loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ConvertOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
  Type sourceType = args.front().getType();
  Type targetType = result_types.front();

  if (mlir::SIToFPOp::areCastCompatible(sourceType, targetType)) {
    return b->create<mlir::SIToFPOp>(loc, result_types, args, mlir::None);
  } else if (sourceType.isa<FloatType>() && targetType.isa<FloatType>()) {
    FloatType src = sourceType.cast<FloatType>();
    FloatType res = targetType.cast<FloatType>();
    if (src.getWidth() > res.getWidth()) {
      return b->create<mlir::FPTruncOp>(loc, result_types, args, mlir::None);
    } else if (src.getWidth() < res.getWidth()) {
      return b->create<mlir::FPExtOp>(loc, result_types, args, mlir::None);
    }
    // No conversion is needed for the same width floats
    return args.front();
  }
  if (sourceType.isSignlessInteger() && targetType.isSignlessInteger()) {
    IntegerType src = sourceType.cast<IntegerType>();
    IntegerType res = targetType.cast<IntegerType>();
    if (src.getWidth() > res.getWidth()) {
      return b->create<mlir::TruncateIOp>(loc, result_types, args, mlir::None);
    } else if (src.getWidth() < res.getWidth()) {
      return b->create<mlir::ZeroExtendIOp>(loc, result_types, args,
                                            mlir::None);
    }
    // No conversion is needed for the same width integers
    return args.front();
  }
  if (mlir::FPToSIOp::areCastCompatible(sourceType, targetType)) {
    return b->create<mlir::FPToSIOp>(loc, result_types, args, mlir::None);
  }
  return nullptr;
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::DotOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
  // Dot Op converter from lhlo to affine only accepts float and integer types.
  const auto& lhs = args[0];
  const auto& rhs = args[1];
  const auto& result = args[2];
  Type element_type = lhs.getType();
  if (element_type.isa<FloatType>()) {
    Value float_mul = MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::MulFOp>{}(
        loc, result_types, {lhs, rhs}, b);
    return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::AddFOp>{}(
        loc, result_types, {float_mul, result}, b);
  }
  if (element_type.isa<IntegerType>()) {
    Value int_mul = MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::MulIOp>{}(
        loc, result_types, {lhs, rhs}, b);
    return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::AddIOp>{}(
        loc, result_types, {int_mul, result}, b);
  }
  return nullptr;
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CosOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CosOp>{}(
      loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SinOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::SinOp>{}(
      loc, result_types, args, b);
}

/// Implements the conversion of XLA op to scalar op (to use within region of a
/// linalg.generic op) for compare-select style operations like min/max.
template <typename... Args>
struct XlaCompareSelectOpToStdScalarOp {
  static Value map(Location loc, StringRef comparison_direction,
                   ArrayRef<Type> result_types, ArrayRef<Value> args,
                   OpBuilder* b) {
    return nullptr;
  }
};

/// Specialization which allows converting to a comparison operation in standard
/// dialect with a given predicate based on the element type of the operand.
template <typename SupportedType, typename StdCompareOp, typename Predicate,
          typename... Args>
struct XlaCompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate,
                                       Args...> {
  static Value map(Location loc, StringRef comparison_direction,
                   ArrayRef<Type> result_types, ArrayRef<Value> args,
                   OpBuilder* b) {
    Type element_type = args.front().getType();
    if (element_type.isa<SupportedType>()) {
      auto predicate = getCmpPredicate<Predicate>(comparison_direction);
      assert(predicate.hasValue() && "expected valid comparison direction");
      auto cmp = b->template create<StdCompareOp>(loc, predicate.getValue(),
                                                  args[0], args[1]);
      return b->create<::mlir::SelectOp>(loc, cmp, args[0], args[1]);
    }
    return XlaCompareSelectOpToStdScalarOp<Args...>::map(
        loc, comparison_direction, result_types, args, b);
  }
};

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::LogOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::LogOp>{}(
      loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::MaxOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
  return XlaCompareSelectOpToStdScalarOp<
      IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType,
      ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>::map(loc, "GT",
                                                          result_types, args,
                                                          b);
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::MinOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
  return XlaCompareSelectOpToStdScalarOp<
      IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType,
      ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>::map(loc, "LT",
                                                          result_types, args,
                                                          b);
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::NegOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
  Type element_type = args.front().getType();
  if (element_type.isa<FloatType>()) {
    return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::NegFOp>{}(
        loc, result_types, args, b);
  }
  if (element_type.isa<IntegerType>()) {
    // xla_lhlo.neg(x, result) -> result = sub(0, x)
    Value lhs = args[0];
    auto integer_type = element_type.dyn_cast<IntegerType>();

    auto zero_intval =
        b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
    return b->create<ScalarIOp<xla_lhlo::SubOp>>(loc, zero_intval, lhs);
  }
  return nullptr;
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::RsqrtOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::RsqrtOp>{}(
      loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SelectOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<::mlir::SelectOp>{}(loc, result_types, args,
                                                        b);
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SignOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
  Type element_type = args.front().getType();
  if (element_type.isa<FloatType>()) {
    FloatType float_type = element_type.cast<FloatType>();
    APFloat const_value = float_type.isF32() ? APFloat(1.0f) : APFloat(1.0);
    Value one = b->create<mlir::ConstantFloatOp>(loc, const_value, float_type);
    return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]);
  }
  return nullptr;
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SqrtOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::SqrtOp>{}(
      loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::TanhOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::TanhOp>{}(
      loc, result_types, args, b);
}

}  // namespace impl

struct XlaOpToStdScalarOp {
  // Implementation for LHLO ops except xla_lhlo::CompareOp.
  template <typename XlaOpTy, typename LhloOpTy = XlaOpTy,
            typename = std::enable_if_t<
                !std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
                std::is_same<typename xla_hlo::HloToLhloOp<LhloOpTy>,
                             std::false_type>::value>>
  static Value map(XlaOpTy op, ArrayRef<Type> result_types,
                   ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) {
    return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types,
                                                  args, b);
  }

  // Implementation for HLO ops except xla_hlo::CompareOp.
  template <typename XlaOpTy, typename LhloOpTy = xla_hlo::HloToLhloOp<XlaOpTy>,
            typename = std::enable_if_t<
                !std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
                !std::is_same<LhloOpTy, std::false_type>::value>>
  static Value map(XlaOpTy op, ArrayRef<Type> result_types,
                   ArrayRef<Value> args, OpBuilder* b, int i = 0) {
    return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types,
                                                  args, b);
  }

  // Implementation for xla_lhlo::CompareOp.
  template <typename LhloOpTy, typename = std::enable_if_t<std::is_same<
                                   LhloOpTy, xla_lhlo::CompareOp>::value>>
  static Value map(xla_lhlo::CompareOp op, ArrayRef<Type> result_types,
                   ArrayRef<Value> args, OpBuilder* b) {
    auto comparison_direction = op.comparison_direction();
    return impl::MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(
        op.getLoc(), comparison_direction, result_types, args, b);
  }

  // Implementation for xla_hlo::CompareOp.
  template <typename HloOpTy, typename = std::enable_if_t<std::is_same<
                                  HloOpTy, xla_hlo::CompareOp>::value>>
  static Value map(xla_hlo::CompareOp op, ArrayRef<Type> result_types,
                   ArrayRef<Value> args, OpBuilder* b) {
    auto comparison_direction = op.comparison_direction();
    return impl::MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(
        op.getLoc(), comparison_direction, result_types, args, b);
  }
};

}  // namespace xla_lhlo
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_XLA_TO_SCALAR_OP_H_
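XlaOpToStdScalarOp is the single entry point the Linalg/loop lowerings dispatch through; the enable_if overloads route HLO ops via the HloToLhloOp trait and give CompareOp its own path for the comparison direction. A hedged sketch of a call site, assuming a rewriter and scalar values add_op, lhs, and rhs already in scope (those names are illustrative, not from this commit):

// For f32 operands this emits std.addf; for i32 it emits std.addi, following
// the LhloToScalarOp specializations above.
Value scalar = mlir::xla_lhlo::XlaOpToStdScalarOp::map<mlir::xla_lhlo::AddOp>(
    add_op, /*result_types=*/{lhs.getType()}, /*args=*/{lhs, rhs}, &rewriter);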
@@ -0,0 +1,105 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSES_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSES_H_

#include <memory>

#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h"

namespace mlir {

class FuncOp;
class ModuleOp;
class Operation;
template <typename T>
class OperationPass;
class Pass;

namespace xla_hlo {

/// Lowers HLO control flow ops to the Standard dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeControlFlowPass();

/// Lowers from HLO dialect to Standard dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass();

/// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
/// buffers if necessary. If `results_escape_functions` is set to true,
/// allocated buffers for function results will be returned and escape the
/// function. Otherwise, the signature is rewritten with extra arguments for the
/// buffers that are to be used for results.
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
    bool results_escape_functions = false);

// Lowers from HLO dialect to Linalg dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass();

// Transforms unranked HLO operations to ranked ones where possible.
std::unique_ptr<OperationPass<FuncOp>> createTransformUnrankedHloPass();

// Sinks constants implicitly captured in control flow regions. This is
// necessary to export to XLA.
std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass();

// Fuses xla_hlo ops to kLoop/kInput fusion patterns.
std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusionPass();

}  // namespace xla_hlo

namespace xla_lhlo {

// Lowers from LHLO dialect to Affine dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToAffinePass();

// Lowers from LHLO dialect to Linalg dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToLinalgPass();

// Lowers from LHLO dialect to GPU dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToGpuPass();

// Fuses linalg ops obtained after LHLO lowering. To enable fusion,
// operations are first tiled.
//
// When 'use_parallel_loops' is set, the tiling will use scf.parallel
// operations. Otherwise, scf.for operations are used.
//
// 'tile_sizes' provides the tile sizes to use for tiling. If the linalg
// operation has more dimensions than tile sizes provided, 1 is used as
// default.
std::unique_ptr<OperationPass<FuncOp>> createLhloFuseLinalg(
    bool use_parallel_loops = false, llvm::ArrayRef<unsigned> tile_sizes = {});

// Removes unnecessary LHLO copies which copy from the allocated buffers to the
// block arguments. The block arguments are used instead of all uses of these
// buffers. The buffers are freed. This pass only works in regions that contain
// a single block.
std::unique_ptr<Pass> createLhloCopyRemovalPass();

// Lowers from LHLO dialect to parallel loops.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass();

}  // namespace xla_lhlo

namespace xla {

/// Lowers the standard TanhOp to an approximation that does not use intrinsics.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTanhToApproximationPass();

}  // namespace xla
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_PASSES_H_
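These factory functions compose into an ordinary MLIR pass pipeline. A minimal sketch, assuming a PassManager built on the module's MLIRContext and the 2020-era pass-manager API; the tile sizes are illustrative:

mlir::PassManager pm(&context);
// Rewrite HLO tensors into explicitly allocated LHLO buffers.
pm.addPass(mlir::xla_hlo::createLegalizeToLhloPass());
// Lower the buffer ops to linalg.generic, then tile and fuse them.
pm.addNestedPass<mlir::FuncOp>(
    mlir::xla_lhlo::createLegalizeLhloToLinalgPass());
pm.addNestedPass<mlir::FuncOp>(mlir::xla_lhlo::createLhloFuseLinalg(
    /*use_parallel_loops=*/true, /*tile_sizes=*/{16, 16}));
if (mlir::failed(pm.run(module))) {
  // handle pipeline failure
}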
@@ -0,0 +1,106 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H_

#include <memory>

#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"

namespace mlir {
class LLVMTypeConverter;
class LowerToLLVMOptions;
class OwningRewritePatternList;
class BufferAssignmentPlacer;
namespace xla_hlo {

// Collection of rewrite patterns for lowering a general dot product.
void PopulateGeneralDotOpLoweringPatterns(OwningRewritePatternList *patterns,
                                          MLIRContext *ctx);

// Collection of rewrite patterns for lowering complex operations to equivalent
// float operations.
void PopulateComplexLoweringPatterns(MLIRContext *context,
                                     OwningRewritePatternList *patterns);

void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns,
                              MLIRContext *ctx);

// Collection of rewrite patterns for lowering of HLO to LHLO dialect.
void populateHLOToLHLOConversionPattern(
    MLIRContext *context, BufferAssignmentPlacer *bufferAssignment,
    TypeConverter *converter, OwningRewritePatternList *patterns);
// Collection of rewrite patterns for lowering of HLO to Linalg dialect.
void populateHLOToLinalgConversionPattern(MLIRContext *context,
                                          OwningRewritePatternList *patterns);

// Sets up legality definitions for materializing broadcasts.
void SetupMaterializeBroadcastsLegality(MLIRContext *context,
                                        ConversionTarget *conversionTarget);

// Populates a collection of rewrite patterns for materializing broadcast
// attributes to equivalent sequences of ops.
void PopulateMaterializeBroadcastsPatterns(MLIRContext *context,
                                           OwningRewritePatternList *patterns);

// Sets up legality definitions for element-wise operations on ranked tensors.
void SetupTransformUnrankedHloLegality(MLIRContext *context,
                                       ConversionTarget *conversionTarget);

// Populates a collection of rewrite patterns to realize element-wise operations
// on ranked tensors where possible.
void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
                                          OwningRewritePatternList *patterns);

// Populate a collection of conversion patterns for un-fusing
// batch_norm_inference and batch_norm_training into constituent HLO ops.
// TODO(laurenzo): Implement un-fusing of batch_norm_training.
void PopulateUnfuseBatchNormPatterns(MLIRContext *context,
                                     OwningRewritePatternList *patterns);

}  // namespace xla_hlo

namespace xla_lhlo {

/// Collect a set of patterns to convert from the LHLO dialect to LLVM.
void PopulateLhloToLLVMConversionPatterns(const LowerToLLVMOptions &options,
                                          LLVMTypeConverter *converter,
                                          OwningRewritePatternList *patterns);

}  // namespace xla_lhlo

namespace xla_chlo {

// Populates a collection of conversion patterns for legalizing client-HLO to
// HLO.
void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
                                       OwningRewritePatternList *patterns);

}  // namespace xla_chlo

namespace xla {

// Populates a pattern that translates the standard TanhOp to an approximation
// that does not use intrinsics.
void PopulateTanhToApproximationPatterns(MLIRContext *context,
                                         OwningRewritePatternList *patterns);

}  // namespace xla
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_REWRITERS_H_
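Each Populate* function only fills a pattern list; the caller still owns the ConversionTarget and drives the conversion. A sketch of legalizing client HLO, assuming the 2020-era conversion API and the XlaHloClientDialect class from chlo_ops.h above; `func` is an illustrative FuncOp:

mlir::OwningRewritePatternList patterns;
mlir::xla_chlo::PopulateLegalizeChloToHloPatterns(&context, &patterns);

mlir::ConversionTarget target(context);
// CHLO must be rewritten away entirely; everything else is left untouched.
target.addIllegalDialect<mlir::xla_chlo::XlaHloClientDialect>();
if (mlir::failed(mlir::applyPartialConversion(func, target, patterns))) {
  // handle conversion failure
}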
@@ -0,0 +1,165 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_CYCLE_DETECTOR_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_CYCLE_DETECTOR_H_

#include <vector>

#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/DenseMap.h"

namespace mlir {

// -------------------------------------------------------------------

// This file contains a light version of GraphCycles implemented in
// tensorflow/compiler/jit/graphcycles/graphcycles.h
//
// We reimplement it here because we do not want to rely on TensorFlow data
// structures; this makes it easy to move the corresponding passes to the
// LLVM repository if necessary.

// --------------------------------------------------------------------

// This is a set data structure that provides a deterministic iteration order.
// The iteration order of elements only depends on the sequence of
// inserts/deletes, so as long as the inserts/deletes happen in the same
// sequence, the set will have the same iteration order.
//
// Assumes that T can be cheaply copied for simplicity.
template <typename T>
class OrderedSet {
 public:
  // Inserts `value` into the ordered set.  Returns true if the value was not
  // present in the set before the insertion.
  bool Insert(T value) {
    bool new_insertion =
        value_to_index_.insert({value, value_sequence_.size()}).second;
    if (new_insertion) {
      value_sequence_.push_back(value);
    }
    return new_insertion;
  }

  // Removes `value` from the set.  Assumes `value` is already present in the
  // set.
  void Erase(T value) {
    auto it = value_to_index_.find(value);

    // Since we don't want to move values around in `value_sequence_`, we swap
    // the value in the last position with the value to be deleted and then
    // pop_back.
    value_to_index_[value_sequence_.back()] = it->second;
    std::swap(value_sequence_[it->second], value_sequence_.back());
    value_sequence_.pop_back();
    value_to_index_.erase(it);
  }

  void Reserve(size_t new_size) {
    value_to_index_.reserve(new_size);
    value_sequence_.reserve(new_size);
  }

  void Clear() {
    value_to_index_.clear();
    value_sequence_.clear();
  }

  bool Contains(T value) const { return value_to_index_.count(value); }
  size_t Size() const { return value_sequence_.size(); }

  const std::vector<T>& GetSequence() const { return value_sequence_; }

 private:
  // The stable order that we maintain through insertions and deletions.
  std::vector<T> value_sequence_;

  // Maps values to their indices in `value_sequence_`.
  llvm::DenseMap<T, int> value_to_index_;
};

// ---------------------------------------------------------------------

// GraphCycles detects the introduction of a cycle into a directed
// graph that is being built up incrementally.
//
// Nodes are identified by small integers.  It is not possible to
// record multiple edges with the same (source, destination) pair;
// requests to add an edge where one already exists are silently
// ignored.
//
// It is also not possible to introduce a cycle; an attempt to insert
// an edge that would introduce a cycle fails and returns false.
//
// GraphCycles uses no internal locking; calls into it should be
// serialized externally.

// Performance considerations:
//   Works well on sparse graphs, poorly on dense graphs.
//   Extra information is maintained incrementally to detect cycles quickly.
//   InsertEdge() is very fast when the edge already exists, and reasonably fast
//   otherwise.
//   FindPath() is linear in the size of the graph.
// The current implementation uses O(|V|+|E|) space.

class GraphCycles {
 public:
  explicit GraphCycles(int32_t num_nodes);
  ~GraphCycles();

  // Attempt to insert an edge from x to y.  If the
  // edge would introduce a cycle, return false without making any
  // changes. Otherwise add the edge and return true.
  bool InsertEdge(int32_t x, int32_t y);

  // Remove any edge that exists from x to y.
  void RemoveEdge(int32_t x, int32_t y);

  // Return whether there is an edge directly from x to y.
  bool HasEdge(int32_t x, int32_t y) const;

  // Contracts the edge from node 'a' to node 'b', merging nodes 'a' and 'b'.
  // One of the nodes is removed from the graph, and edges to/from it are added
  // to the remaining one, which is returned. If contracting the edge would
  // create a cycle, does nothing and returns no value.
  llvm::Optional<int32_t> ContractEdge(int32_t a, int32_t b);

  // Return whether dest_node `y` is reachable from source_node `x`
  // by following edges. This is the non-thread-safe version.
  bool IsReachable(int32_t x, int32_t y);

  // Return a copy of the successors set. This is needed for code using the
  // collection while modifying the GraphCycles.
  std::vector<int32_t> SuccessorsCopy(int32_t node) const;

  // Returns all nodes in post order.
  //
  // If there is a path from X to Y then X appears after Y in the
  // returned vector.
  std::vector<int32_t> AllNodesInPostOrder() const;

  // ----------------------------------------------------
  struct Rep;

 private:
  GraphCycles(const GraphCycles&) = delete;
  GraphCycles& operator=(const GraphCycles&) = delete;

  Rep* rep_;  // opaque representation
};

}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_CYCLE_DETECTOR_H_
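A small usage sketch of the two classes, showing the fusion-time invariant that cycle-creating edges and merges are refused; the node numbering is illustrative:

mlir::GraphCycles cycles(/*num_nodes=*/3);
cycles.InsertEdge(0, 1);            // true: edge added
cycles.InsertEdge(1, 2);            // true: edge added
bool ok = cycles.InsertEdge(2, 0);  // false: 0 -> 1 -> 2 -> 0 would be a cycle
if (llvm::Optional<int32_t> merged = cycles.ContractEdge(0, 1)) {
  // Nodes 0 and 1 now behave as the single node *merged.
}
std::vector<int32_t> post_order = cycles.AllNodesInPostOrder();

mlir::OrderedSet<int32_t> worklist;
worklist.Insert(7);
worklist.Insert(3);
// Deterministic iteration in insertion order: 7, then 3.
for (int32_t v : worklist.GetSequence()) {
  (void)v;  // process v
}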
			@ -0,0 +1,242 @@
 | 
			
		|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OperationSupport.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h"

namespace mlir {
namespace xla_chlo {

namespace {

// Converts binary ops that are statically determined not to require
// broadcasting directly to the corresponding xla_hlo non-broadcasting op.
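// For example (an illustrative sketch, assuming the xla_chlo.broadcast_add
// assembly mnemonic):
//
//   %0 = "xla_chlo.broadcast_add"(%lhs, %rhs)
//       : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
//
// is rewritten to the non-broadcasting form:
//
//   %0 = "xla_hlo.add"(%lhs, %rhs)
//       : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>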
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
  using OpRewritePattern<ChloOpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(ChloOpTy op,
                                PatternRewriter &rewriter) const override {
    // Only rewrite for statically determinable non-broadcasting cases.
    auto lhs_type = op.lhs().getType().template dyn_cast<RankedTensorType>();
    auto rhs_type = op.rhs().getType().template dyn_cast<RankedTensorType>();
    if (!lhs_type || !rhs_type) return failure();

    // Mismatched ranks would require rank broadcast.
    if (lhs_type.getRank() != rhs_type.getRank()) return failure();
    // Any dynamic dimension may require broadcasting and requires more
    // analysis.
    if (!lhs_type.hasStaticShape() || !rhs_type.hasStaticShape())
      return failure();

    for (auto extents : llvm::zip(lhs_type.getShape(), rhs_type.getShape())) {
      auto lhs_extent = std::get<0>(extents);
      auto rhs_extent = std::get<1>(extents);
      if (lhs_extent != rhs_extent) {
        return failure();
      }
    }

    rewriter.replaceOp(op, {Adaptor::CreateOp(op, op.getResult().getType(),
                                              op.lhs(), op.rhs(), rewriter)});
    return success();
  }
};

// Converts a binary op with ranked broadcasting operands to explicit
// broadcasts followed by the corresponding xla_hlo non-broadcasting op.
// Note that dynamic broadcasting supported by this pattern is only valid for
// "numpy" broadcasting semantics as defined here:
//   https://docs.scipy.org/doc/numpy/reference/ufuncs.html
// Specifically, this includes the following cases:
//   - Same rank broadcast (operands have the same static rank).
//   - Different-rank broadcast, either without a broadcast_dims attribute or
//     with the broadcast_dims attribute set to map to a prefix padding.
//   - Legal combinations of degenerate (1-dim) implicit broadcasting.
// The restriction on broadcast_dims derives from the definition of the
// `shape.broadcast` op, which only supports prefix-padding.
//
// It may be possible to expand this pattern to operate on unranked tensors in
// the future by emitting more code to dynamically differentiate based on rank.
// Whether that is of any practical benefit remains to be seen.
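// For example (an illustrative sketch, assuming the xla_chlo.broadcast_add
// mnemonic and eliding the operand broadcasts inside the region):
//
//   %0 = "xla_chlo.broadcast_add"(%lhs, %rhs)
//       : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
//
// lowers to roughly:
//
//   %lhs_shape = shape.shape_of %lhs
//   %rhs_shape = shape.shape_of %rhs
//   %witness = shape.cstr_broadcastable %lhs_shape, %rhs_shape
//   %0 = shape.assuming %witness -> (tensor<?x?xf32>) {
//     ... xla_hlo.dynamic_broadcast_in_dim of %lhs and %rhs ...
//     %sum = "xla_hlo.add"(%blhs, %brhs)
//         : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
//     shape.assuming_yield %sum : tensor<?x?xf32>
//   }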
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertRankedDynamicBroadcastBinaryOp
    : public OpRewritePattern<ChloOpTy> {
  using OpRewritePattern<ChloOpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(ChloOpTy op,
                                PatternRewriter &rewriter) const override {
    // Only support ranked operands.
    Value lhs = op.lhs();
    Value rhs = op.rhs();
    auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>();
    auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>();
    auto result_type =
        op.getResult().getType().template dyn_cast<RankedTensorType>();
    if (!lhs_type || !rhs_type || !result_type) return failure();

    // Check for "numpy"-style rank broadcast.
    auto broadcast_dimensions = op.broadcast_dimensions();
    if (broadcast_dimensions &&
        !xla::IsLegalNumpyRankedBroadcast(lhs, rhs, *broadcast_dimensions)) {
      // Note: It is unclear whether the general specification of explicit
      // broadcast_dimensions on binary ops is a feature we want to carry
      // forward. While it can technically be implemented for ranked-dynamic,
      // it is incompatible with unranked inputs. If this warning is emitted
      // in real programs, it is an indication that the feature should be
      // implemented versus just falling back on the more standard definition
      // of numpy-like prefix-padding.
      op.emitWarning() << "unsupported non-prefix-padded dynamic rank "
                       << "broadcast_dimensions = " << *broadcast_dimensions;
      return failure();
    }

    // Compute result shape.
    auto loc = op.getLoc();

    // Insert a constraint on the shapes being broadcastable and insert all
    // future code into an assuming block reliant on the constraint.
    Value lhs_shape = rewriter.create<shape::ShapeOfOp>(loc, lhs);
    Value rhs_shape = rewriter.create<shape::ShapeOfOp>(loc, rhs);
    auto broadcastable_cstr =
        rewriter.create<shape::CstrBroadcastableOp>(loc, lhs_shape, rhs_shape);
    auto assuming_op = rewriter.create<shape::AssumingOp>(
        loc, ArrayRef<Type>{result_type}, broadcastable_cstr.result());

    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.createBlock(&assuming_op.doRegion());

    int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank());
    Value result_extents =
        xla::ComputeBinaryElementwiseBroadcastingResultExtents(loc, lhs, rhs,
                                                               rewriter);

    // Note that we unconditionally emit DynamicBroadcastInDim ops and let
    // downstream canonicalizations fold them away if possible. This is
    // because, in the dynamic case, there are many corner cases regarding
    // when it is safe to omit, and some of them require analysis to prove
    // properly.
    auto lhs_broadcast_dimensions = llvm::to_vector<4>(
        llvm::seq<int64_t>(result_rank - lhs_type.getRank(), result_rank));
    Value broadcasted_lhs = rewriter.create<xla_hlo::DynamicBroadcastInDimOp>(
        loc,
        RankedTensorType::get(result_type.getShape(),
                              lhs_type.getElementType()),
        lhs, result_extents,
        rewriter.getI64TensorAttr(lhs_broadcast_dimensions));
    auto rhs_broadcast_dimensions = llvm::to_vector<4>(
        llvm::seq<int64_t>(result_rank - rhs_type.getRank(), result_rank));
    Value broadcasted_rhs = rewriter.create<xla_hlo::DynamicBroadcastInDimOp>(
        loc,
        RankedTensorType::get(result_type.getShape(),
                              rhs_type.getElementType()),
        rhs, result_extents,
        rewriter.getI64TensorAttr(rhs_broadcast_dimensions));

    // And generate the final non-broadcasted binary op.
    Value final_result = Adaptor::CreateOp(op, result_type, broadcasted_lhs,
                                           broadcasted_rhs, rewriter);
    rewriter.create<shape::AssumingYieldOp>(loc, final_result);
    rewriter.replaceOp(op, {assuming_op.getResult(0)});
    return success();
  }
};

template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
void PopulateForBinaryOp(MLIRContext *context,
                         OwningRewritePatternList *patterns) {
  patterns
      ->insert<ConvertTrivialNonBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>(
          context, 10);
  patterns->insert<
      ConvertRankedDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>(
      context, 5);
}
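// Note: the integer passed alongside the context above is the pattern
// benefit, so the trivial non-broadcasting pattern (benefit 10) is preferred
// over the general dynamic-broadcast pattern (benefit 5) whenever both match.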

template <typename FromOpTy, typename ToOpTy>
struct HloBinaryElementwiseAdaptor {
  static ToOpTy CreateOp(FromOpTy from_op, Type result_type,
                         Value broadcasted_lhs, Value broadcasted_rhs,
                         OpBuilder &builder) {
    return builder.create<ToOpTy>(from_op.getLoc(), result_type,
                                  broadcasted_lhs, broadcasted_rhs);
  }
};

struct HloComplexAdaptor {
  static xla_hlo::ComplexOp CreateOp(BroadcastComplexOp from_op,
                                     Type result_type, Value broadcasted_lhs,
                                     Value broadcasted_rhs,
                                     OpBuilder &builder) {
    return builder.create<xla_hlo::ComplexOp>(from_op.getLoc(), result_type,
                                              broadcasted_lhs, broadcasted_rhs);
  }
};

struct HloCompareAdaptor {
  static xla_hlo::CompareOp CreateOp(BroadcastCompareOp from_op,
                                     Type result_type, Value broadcasted_lhs,
                                     Value broadcasted_rhs,
                                     OpBuilder &builder) {
    return builder.create<xla_hlo::CompareOp>(from_op.getLoc(), result_type,
                                              broadcasted_lhs, broadcasted_rhs,
                                              from_op.comparison_direction());
  }
};

}  // namespace

void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
                                       OwningRewritePatternList *patterns) {
  // Instantiate conversion templates for conforming binary elementwise ops
  // that do not have different dtypes between operands and results and do
  // not have special attributes that need to be preserved.
#define POPULATE_BCAST(ChloOp, HloOp)                                      \
  PopulateForBinaryOp<ChloOp, HloOp,                                       \
                      HloBinaryElementwiseAdaptor<ChloOp, HloOp>>(context, \
                                                                  patterns);

  POPULATE_BCAST(BroadcastAddOp, xla_hlo::AddOp);
  POPULATE_BCAST(BroadcastAndOp, xla_hlo::AndOp);
  POPULATE_BCAST(BroadcastAtan2Op, xla_hlo::Atan2Op);
  POPULATE_BCAST(BroadcastDivOp, xla_hlo::DivOp);
  POPULATE_BCAST(BroadcastMaxOp, xla_hlo::MaxOp);
  POPULATE_BCAST(BroadcastMinOp, xla_hlo::MinOp);
  POPULATE_BCAST(BroadcastMulOp, xla_hlo::MulOp);
  POPULATE_BCAST(BroadcastOrOp, xla_hlo::OrOp);
  POPULATE_BCAST(BroadcastPowOp, xla_hlo::PowOp);
  POPULATE_BCAST(BroadcastRemOp, xla_hlo::RemOp);
  POPULATE_BCAST(BroadcastShiftLeftOp, xla_hlo::ShiftLeftOp);
  POPULATE_BCAST(BroadcastShiftRightArithmeticOp,
                 xla_hlo::ShiftRightArithmeticOp);
  POPULATE_BCAST(BroadcastShiftRightLogicalOp, xla_hlo::ShiftRightLogicalOp);
  POPULATE_BCAST(BroadcastSubOp, xla_hlo::SubOp);
  POPULATE_BCAST(BroadcastXorOp, xla_hlo::XorOp);

  // Broadcasting ops requiring special construction.
  PopulateForBinaryOp<BroadcastComplexOp, xla_hlo::ComplexOp,
                      HloComplexAdaptor>(context, patterns);
  PopulateForBinaryOp<BroadcastCompareOp, xla_hlo::CompareOp,
                      HloCompareAdaptor>(context, patterns);
}

}  // namespace xla_chlo
}  // namespace mlir
@@ -0,0 +1,57 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"

namespace mlir {
namespace xla_chlo {

namespace {

struct TestChloLegalizeToHloPass
    : public PassWrapper<TestChloLegalizeToHloPass, FunctionPass> {
  void runOnFunction() override {
    ConversionTarget conversionTarget(getContext());
    OwningRewritePatternList conversionPatterns;

    conversionTarget.addIllegalDialect<XlaHloClientDialect>();
    // Consider the xla_hlo dialect legal for tests.
    conversionTarget.addLegalDialect<xla_hlo::XlaHloDialect>();
    // The conversion uses helpers from the Standard dialect.
    conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
    conversionTarget.addLegalDialect<mlir::shape::ShapeDialect>();

    PopulateLegalizeChloToHloPatterns(&getContext(), &conversionPatterns);

    if (failed(applyPartialConversion(getFunction(), conversionTarget,
                                      conversionPatterns))) {
      return signalPassFailure();
    }
  }
};

}  // namespace

}  // namespace xla_chlo
}  // namespace mlir

static mlir::PassRegistration<mlir::xla_chlo::TestChloLegalizeToHloPass> pass(
    "test-xla-chlo-legalize-to-hlo",
    "Test pass for applying chlo -> hlo legalization patterns");
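// The registration above exposes the pass under the
// -test-xla-chlo-legalize-to-hlo flag; an illustrative invocation (the
// binary name depends on the build) would be:
//   mlir-opt -test-xla-chlo-legalize-to-hlo input.mlir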
@@ -0,0 +1,493 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering HLO dialect to LHLO dialect.

#include "third_party/absl/memory/memory.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/AffineMap.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BlockAndValueMapping.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/BufferPlacement.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"

namespace mlir {
namespace xla_hlo {
namespace {

template <typename T>
using BaseOpConversion = BufferAssignmentOpConversionPattern<T>;
using StdReturnOpConverter =
    detail::BufferAssignmentReturnOpConverter<mlir::ReturnOp, mlir::ReturnOp,
                                              xla_lhlo::CopyOp, true>;

Value InsertDynamicAllocAndDealloc(Location loc, Value result,
                                   Value shape_operand,
                                   ConversionPatternRewriter* rewriter) {
  auto result_type = result.getType().dyn_cast<ShapedType>();
  if (!result_type) {
    result.getDefiningOp()->emitOpError()
        << "tensor to buffer conversion expects ranked results";
  }
  auto memref_type =
      MemRefType::get(result_type.getShape(), result_type.getElementType());

  Operation* op = result.getDefiningOp();

  // Extract the required dynamic extents from the shape operand.
  SmallVector<Value, 4> dynamic_operands;
  for (auto shape_element : llvm::enumerate(result_type.getShape())) {
    if (shape_element.value() != ShapedType::kDynamicSize) continue;
    Value index = rewriter->create<ConstantOp>(
        loc, rewriter->getIntegerAttr(rewriter->getIndexType(),
                                      shape_element.index()));
    Value alloc_operand = rewriter->create<ExtractElementOp>(loc, shape_operand,
                                                             ValueRange{index});
    if (!alloc_operand.getType().isIndex()) {
      alloc_operand = rewriter->create<IndexCastOp>(loc, alloc_operand,
                                                    rewriter->getIndexType());
    }
    dynamic_operands.push_back(alloc_operand);
  }

  // Insert in front of op to ensure sizes are available.
  OpBuilder allocBuilder(op);
  auto alloc = allocBuilder.create<AllocOp>(loc, memref_type, dynamic_operands);
  return alloc;
}
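// For a result of type tensor<?x4xf32> with a shape operand %shape, the
// helper above emits roughly (an illustrative sketch; the value names are
// made up):
//
//   %c0 = constant 0 : index
//   %dim0 = extract_element %shape[%c0]   // index_cast'ed to index if needed
//   %buf = alloc(%dim0) : memref<?x4xf32>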

Value InsertAlloc(Location loc, OpResult result,
                  BufferAssignmentPlacer* bufferAssignment,
                  ConversionPatternRewriter* rewriter) {
  auto result_type = result.getType().dyn_cast<ShapedType>();
  if (!result_type || !result_type.hasStaticShape()) {
    result.getDefiningOp()->emitOpError()
        << "tensor to buffer conversion expects statically shaped results";
  }
  auto memref_type =
      MemRefType::get(result_type.getShape(), result_type.getElementType());
  OpBuilder::InsertionGuard guard(*rewriter);
  rewriter->restoreInsertionPoint(
      bufferAssignment->computeAllocPosition(result));
  auto alloc = rewriter->create<AllocOp>(loc, memref_type);
  return alloc;
}

template <typename HloOpTy>
class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
 public:
  using BaseOpConversion<HloOpTy>::BaseOpConversion;
  LogicalResult matchAndRewrite(
      HloOpTy hloOp, ArrayRef<Value> operands,
      ConversionPatternRewriter& rewriter) const final {
    Operation* op = hloOp.getOperation();
    const auto& original_results = op->getResults();
    SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
    for (auto result : llvm::enumerate(original_results)) {
      RankedTensorType resultType =
          result.value().getType().dyn_cast<RankedTensorType>();
      if (!resultType) {
        return failure();
      }
      if (resultType.hasStaticShape()) {
        buffer_args.push_back(InsertAlloc(op->getLoc(), result.value(),
                                          this->bufferAssignment, &rewriter));
      } else {
        SmallVector<Value, 1> results_shape;
        auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
        if (!shape_type_op) return failure();
        if (failed(
                shape_type_op.reifyReturnTypeShapes(rewriter, results_shape)))
          return failure();
        buffer_args.push_back(InsertDynamicAllocAndDealloc(
            op->getLoc(), result.value(), results_shape.front(), &rewriter));
      }
    }
    rewriter.create<xla_hlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
                                                   buffer_args, op->getAttrs());
    rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
    return success();
  }
};

struct HloToLhloDynamicBroadcastInDimOpConverter
    : public BaseOpConversion<xla_hlo::DynamicBroadcastInDimOp> {
 public:
  using BaseOpConversion<xla_hlo::DynamicBroadcastInDimOp>::BaseOpConversion;

  LogicalResult matchAndRewrite(
      xla_hlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = op.getLoc();
    Value resultBuffer = InsertDynamicAllocAndDealloc(
        loc, op.getResult(), op.output_dimensions(), &rewriter);

    Value transformed_operand =
        InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
    rewriter.create<xla_lhlo::BroadcastInDimOp>(
        loc, transformed_operand, resultBuffer, op.broadcast_dimensions());

    rewriter.replaceOp(op, {resultBuffer});

    return success();
  }

 private:
  // Inserts a dynamic memref cast that changes the layout of the operand
  // memref, using a 0-stride and the size of the target dimension wherever
  // size-1 dimension expansion is necessary.
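  //
  // For each broadcast dimension the stride is selected at runtime (an
  // illustrative summary of the code below):
  //   stride = (operand_dim < result_dim) ? 0 : 1
  // and the size is always taken from the corresponding result dimension, so
  // an expanded size-1 input dimension aliases the same elements across the
  // whole output dimension.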
  xla_lhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp(
      xla_hlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
    auto loc = op.getLoc();
    auto operand_type = operand.getType().cast<MemRefType>();
    auto operand_shape = operand_type.getShape();

    SmallVector<Value, 2> sizes, strides;
    sizes.reserve(operand_shape.size());
    strides.reserve(operand_shape.size());

    Value zero = b->create<ConstantIndexOp>(loc, 0);
    Value one = b->create<ConstantIndexOp>(loc, 1);
    for (auto dim : llvm::enumerate(op.broadcast_dimensions())) {
      Value broadcast_dim_value =
          b->create<ConstantIndexOp>(loc, dim.value().getSExtValue());
      Value result_dim_size = b->create<ExtractElementOp>(
          loc, op.output_dimensions(), broadcast_dim_value);
      Value operand_dim_size =
          ShapedType::isDynamic(operand_shape[dim.index()])
              ? b->create<DimOp>(loc, operand, dim.index()).getResult()
              : b->create<ConstantIndexOp>(loc, operand_shape[dim.index()])
                    .getResult();

      // TODO(pifon): Revisit if this cast is needed. Maybe we can use
      // tensor<index> for `output_dimensions` as well.
      if (!result_dim_size.getType().isIndex()) {
        result_dim_size =
            b->create<IndexCastOp>(loc, result_dim_size, b->getIndexType());
      }

      // There can be two cases:
      // 1) Operand dim == result dim => expansion is not needed => stride := 1.
      // 2) Operand dim < result dim => expansion is needed => stride := 0.
      Value is_expansion = b->create<CmpIOp>(loc, CmpIPredicate::slt,
                                             operand_dim_size, result_dim_size);
      strides.push_back(
          b->create<mlir::SelectOp>(loc, is_expansion, zero, one));

      // Size of input dim can be set to the size of the corresponding output
      // dimension for both cases.
      sizes.push_back(result_dim_size);
    }

    // Type-erased memref type with static rank, dynamic sizes and strides.
    SmallVector<int64_t, 2> dynamic_layout(operand_shape.size(),
                                           MemRefType::kDynamicStrideOrOffset);
    SmallVector<int64_t, 2> dynamic_shape(operand_shape.size(),
                                          MemRefType::kDynamicSize);
    auto type_erased_memref_type = MemRefType::get(
        dynamic_shape, operand_type.getElementType(),
        makeStridedLinearLayoutMap(dynamic_layout,
                                   /*offset=*/0, b->getContext()));

    auto transformed_operand = b->create<xla_lhlo::DynamicMemRefCastOp>(
        loc, type_erased_memref_type, operand, sizes, strides);
    return transformed_operand;
  }
};

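// Converts xla_hlo.reduce to xla_lhlo.reduce: the region is inlined into the
// new op, its tensor block arguments are replaced by memref arguments, an
// extra memref argument is appended for the result, and an
// xla_lhlo.terminator is inserted at the end of the entry block. An
// illustrative sketch of the block signature rewrite:
//
//   ^bb0(%a: tensor<f32>, %b: tensor<f32>):
//
// becomes roughly:
//
//   ^bb0(%a: memref<f32>, %b: memref<f32>, %res: memref<f32>):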
struct HloToLhloReduceOpConverter : public BaseOpConversion<xla_hlo::ReduceOp> {
 public:
  using BaseOpConversion<xla_hlo::ReduceOp>::BaseOpConversion;

  LogicalResult matchAndRewrite(
      xla_hlo::ReduceOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = op.getLoc();
    // TODO(b/137624192) Implement variadic reduce.
    if (op.getNumResults() != 1) return failure();
    if (!llvm::hasSingleElement(op.body())) {
      return op.emitOpError()
             << "tensor to buffer conversion expects a single block "
                "in the region containing the operation";
    }
    const auto& original_results = op.getResults();
    SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
    for (auto result : original_results) {
      buffer_args.push_back(
          InsertAlloc(loc, result, this->bufferAssignment, &rewriter));
    }
    auto new_op = rewriter.create<xla_lhlo::ReduceOp>(
        loc, llvm::None, buffer_args, op.getAttrs());

    // Copy over the operations inside the region.
    rewriter.inlineRegionBefore(op.body(), new_op.body(), new_op.body().end());

    // Create new block arguments with correct type.
    auto& entry_block = new_op.body().front();
    int original_arg_count = entry_block.getNumArguments();
    for (int i = 0; i < original_arg_count; ++i) {
      auto old_arg = entry_block.getArgument(i);
      auto old_type = old_arg.getType().cast<TensorType>();
      auto new_type =
          MemRefType::get(old_type.getShape(), old_type.getElementType());
      auto new_arg = entry_block.addArgument(new_type);
      rewriter.replaceUsesOfBlockArgument(old_arg, new_arg);
    }
    // Add an argument for the result.
    entry_block.addArgument(
        entry_block.getArgument(original_arg_count).getType());
    // Remove the old arguments.
    for (int i = original_arg_count - 1; i >= 0; --i) {
      entry_block.eraseArgument(i);
    }
    // Insert terminator at the end.
    rewriter.setInsertionPointToEnd(&entry_block);
    rewriter.create<xla_lhlo::TerminatorOp>(loc);

    rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));

    return success();
  }
};

class HloToLhloTensorLoadOpConverter
    : public BaseOpConversion<mlir::TensorLoadOp> {
 public:
  using BaseOpConversion<mlir::TensorLoadOp>::BaseOpConversion;
  LogicalResult matchAndRewrite(
      mlir::TensorLoadOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter& rewriter) const final {
    rewriter.replaceOp(op, operands);
    return success();
  }
};

// TODO(b/137624192): Rewrite into a copy and elide copy if possible.
class HloToLhloTensorStoreOpConverter
    : public BaseOpConversion<mlir::TensorStoreOp> {
 public:
  using BaseOpConversion<mlir::TensorStoreOp>::BaseOpConversion;

  LogicalResult matchAndRewrite(
      mlir::TensorStoreOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter& rewriter) const final {
    rewriter.replaceOpWithNewOp<xla_lhlo::CopyOp>(
        op, llvm::None, operands.front(), operands.back());
    return success();
  }
};

// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
// buffers if necessary.
//
// Example fusion with HLO ops.
//
// func @fusion(%arg0: memref<2x2xf32>,
//              %arg1: memref<2x2xf32>,
//              %arg2: memref<2x2xf32>,
//              %arg3: memref<2x2xf32>) {
//   "xla_lhlo.fusion"() ({
//     %0 = tensor_load %arg1 : memref<2x2xf32>
//     %1 = tensor_load %arg2 : memref<2x2xf32>
//     %2 = "xla_hlo.add"(%0, %1) :
//         (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
//     %3 = tensor_load %arg0 : memref<2x2xf32>
//     %4 = "xla_hlo.multiply"(%2, %3) :
//         (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
//     tensor_store %4, %arg3 : memref<2x2xf32>
//     "xla_lhlo.terminator"() : () -> ()
//   }) : () -> ()
//   return
// }
//
// Transformed fusion with LHLO ops.
// func @fusion(%arg0: memref<2x2xf32>,
//              %arg1: memref<2x2xf32>,
//              %arg2: memref<2x2xf32>,
//              %arg3: memref<2x2xf32>) {
//   "xla_lhlo.fusion"() ( {
//     %0 = alloc() : memref<2x2xf32>
//     "xla_lhlo.add"(%arg1, %arg2, %0) :
//         (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
//     "xla_lhlo.multiply"(%0, %arg0, %arg3) :
//         (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
//     "xla_lhlo.terminator"() : () -> ()
//   }) : () -> ()
//   return
// }
//
// FuncOp signature conversion example:
//
// func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
//   %0 = "xla_hlo.maximum"(%arg0, %arg1) :
//       (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
//   %1 = "xla_hlo.add"(%arg0, %0) :
//       (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
//   return %1 : tensor<4xf32>
// }
//
// Transformed function with an extra argument for the result. The types have
// been converted from tensor to memref.
//
// func @func_op(%arg0: memref<4xf32>,
//               %arg1: memref<4xf32>,
//               %arg2: memref<4xf32>) {
//   %0 = alloc() : memref<4xf32>
//   "xla_lhlo.maximum"(%arg0, %arg1, %0) :
//         (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
//   %1 = alloc() : memref<4xf32>
//   "xla_lhlo.add"(%arg0, %0, %1) :
//         (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
//   "xla_lhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
//   "xla_lhlo.terminator"() : () -> ()
// }

struct HloLegalizeToLhlo
    : public PassWrapper<HloLegalizeToLhlo, OperationPass<ModuleOp>> {
 public:
  HloLegalizeToLhlo() = default;
  HloLegalizeToLhlo(const HloLegalizeToLhlo& o) {
    this->results_escape_function = o.results_escape_function.getValue();
  }
  explicit HloLegalizeToLhlo(bool results_escape_function) {
    this->results_escape_function.setValue(results_escape_function);
  }

  void runOnOperation() override {
    OwningRewritePatternList patterns;
    auto& context = getContext();
    ConversionTarget target(context);
    target.addLegalDialect<xla_lhlo::XlaLhloDialect>();
    target.addLegalDialect<StandardOpsDialect>();
    target.addLegalOp<ModuleOp>();
    target.addIllegalOp<mlir::TensorLoadOp>();
    target.addIllegalOp<mlir::TensorStoreOp>();
    target.addLegalOp<ModuleTerminatorOp>();
    target.addLegalOp<TensorFromElementsOp>();
    target.addIllegalDialect<xla_hlo::XlaHloDialect>();

    BufferAssignmentTypeConverter converter;
    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
      auto inputs = op.getType().getInputs();
      return llvm::all_of(inputs,
                          [](Type input) { return input.isa<MemRefType>(); }) &&
             converter.isLegal(&op.getBody());
    });
    target.addDynamicallyLegalOp<mlir::ReturnOp>([&](mlir::ReturnOp returnOp) {
      return std::all_of(returnOp.operand_type_begin(),
                         returnOp.operand_type_end(),
                         [](Type type) { return type.isa<MemRefType>(); });
    });

    auto module = getOperation();
    WalkResult result = module.walk([&](FuncOp func) -> WalkResult {
      BufferAssignmentPlacer bufferAssignment(func);
      OwningRewritePatternList patterns;
      populateHLOToLHLOConversionPattern(func.getContext(), &bufferAssignment,
                                         &converter, &patterns);
      if (results_escape_function) {
        populateWithBufferAssignmentOpConversionPatterns<
            mlir::ReturnOp, mlir::ReturnOp, xla_lhlo::CopyOp,
            /*allowMemrefFunctionResults=*/true>(&context, &bufferAssignment,
                                                 &converter, &patterns);
      } else {
        populateWithBufferAssignmentOpConversionPatterns<
            mlir::ReturnOp, mlir::ReturnOp, xla_lhlo::CopyOp,
            /*allowMemrefFunctionResults=*/false>(&context, &bufferAssignment,
                                                  &converter, &patterns);
      }
      return applyPartialConversion(func, target, patterns);
    });
    if (result.wasInterrupted()) {
      signalPassFailure();
    }
  }

 private:
  Option<bool> results_escape_function{
      *this, "results-escape-function",
      llvm::cl::desc(
          "Allocate the results of functions within the function's body"),
      llvm::cl::init(false)};
};
}  // namespace

void populateHLOToLHLOConversionPattern(
    MLIRContext* context, BufferAssignmentPlacer* bufferAssignment,
    TypeConverter* converter, OwningRewritePatternList* patterns) {
  // clang-format off
  patterns->insert<
      HloToLhloDynamicBroadcastInDimOpConverter,
      HloToLhloOpConverter<xla_hlo::AbsOp>,
      HloToLhloOpConverter<xla_hlo::AddOp>,
      HloToLhloOpConverter<xla_hlo::AndOp>,
      HloToLhloOpConverter<xla_hlo::BroadcastInDimOp>,
      HloToLhloOpConverter<xla_hlo::CeilOp>,
      HloToLhloOpConverter<xla_hlo::CompareOp>,
      HloToLhloOpConverter<xla_hlo::ComplexOp>,
      HloToLhloOpConverter<xla_hlo::ConstOp>,
      HloToLhloOpConverter<xla_hlo::ConvOp>,
      HloToLhloOpConverter<xla_hlo::ConvertOp>,
      HloToLhloOpConverter<xla_hlo::CopyOp>,
      HloToLhloOpConverter<xla_hlo::CosOp>,
      HloToLhloOpConverter<xla_hlo::DivOp>,
      HloToLhloOpConverter<xla_hlo::DotOp>,
      HloToLhloOpConverter<xla_hlo::ExpOp>,
      HloToLhloOpConverter<xla_hlo::GatherOp>,
      HloToLhloOpConverter<xla_hlo::ImagOp>,
      HloToLhloOpConverter<xla_hlo::IotaOp>,
      HloToLhloOpConverter<xla_hlo::LogOp>,
      HloToLhloOpConverter<xla_hlo::MaxOp>,
      HloToLhloOpConverter<xla_hlo::MinOp>,
      HloToLhloOpConverter<xla_hlo::MulOp>,
      HloToLhloOpConverter<xla_hlo::NegOp>,
      HloToLhloOpConverter<xla_hlo::RealOp>,
      HloToLhloOpConverter<xla_hlo::RemOp>,
      HloToLhloOpConverter<xla_hlo::RsqrtOp>,
      HloToLhloOpConverter<xla_hlo::ReshapeOp>,
      HloToLhloOpConverter<xla_hlo::SelectOp>,
      HloToLhloOpConverter<xla_hlo::SignOp>,
      HloToLhloOpConverter<xla_hlo::SqrtOp>,
      HloToLhloOpConverter<xla_hlo::SubOp>,
      HloToLhloOpConverter<xla_hlo::TanhOp>,
      HloToLhloReduceOpConverter,
      HloToLhloTensorLoadOpConverter,
      HloToLhloTensorStoreOpConverter
  >(context, bufferAssignment, converter);
  // clang-format on
}

std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
    bool results_escape_function) {
  return absl::make_unique<HloLegalizeToLhlo>(results_escape_function);
}

static PassRegistration<HloLegalizeToLhlo> legalize_pass(
    "hlo-legalize-to-lhlo", "Legalize from HLO dialect to LHLO dialect");
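// The pass above is registered under the -hlo-legalize-to-lhlo flag and
// exposes the "results-escape-function" option; an illustrative standalone
// invocation (the binary name depends on the build) would be:
//   mlir-opt -hlo-legalize-to-lhlo input.mlir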

}  // namespace xla_hlo
}  // namespace mlir
@@ -0,0 +1,237 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering XLA dialect to Standard dialect.

#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/Casting.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Block.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BlockAndValueMapping.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/PassRegistry.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LogicalResult.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"

using mlir::PassRegistration;

namespace mlir {
namespace xla_hlo {
namespace {
struct LegalizeControlFlow
    : public mlir::PassWrapper<LegalizeControlFlow, FunctionPass> {
  // Perform the lowering to MLIR control flow.
  void runOnFunction() override;
};

// Replaces the terminators of the newly created blocks cloned from a target
// region. These terminators are replaced with branch operations to a target
// block.
LogicalResult ReplaceTerminators(Region* region, Block* target_block,
                                 Location loc,
                                 const BlockAndValueMapping& mapper,
                                 OpBuilder* builder) {
  for (auto& old_block : region->getBlocks()) {
    Block* block = mapper.lookup(&old_block);
    auto return_op = dyn_cast<xla_hlo::ReturnOp>(block->getTerminator());
    if (!return_op) continue;
    builder->setInsertionPointToEnd(block);
    builder->create<mlir::BranchOp>(loc, target_block, return_op.getOperands());
    return_op.erase();
  }

  return success();
}
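// Lowers an xla_hlo.if op into explicit control flow. An illustrative sketch
// (block names and values are made up for exposition):
//
//   %0 = "xla_hlo.if"(%pred, %true_arg, %false_arg) ({...}, {...})
//
// becomes roughly:
//
//   %p = extract_element %pred[] : tensor<i1>
//   cond_br %p, ^true_entry(%true_arg), ^false_entry(%false_arg)
//   ... inlined true/false regions, each return rewritten to br ^tail(...) ...
// ^tail(%0: tensor<...>):  // carries the result of the conditional
//   <post operations>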
LogicalResult LowerIfOp(mlir::xla_hlo::IfOp if_op) {
  Operation* op_inst = if_op.getOperation();
  mlir::OpBuilder builder(if_op);
  auto orig_block = op_inst->getBlock();
  auto* tail_block = orig_block->splitBlock(op_inst);
  auto loc = if_op.getLoc();

  // Duplicate the true and false regions in the block between the sections
  // before and after the conditional.
  BlockAndValueMapping mapper;
  if_op.true_branch().cloneInto(orig_block->getParent(),
                                Region::iterator(tail_block), mapper);
  if_op.false_branch().cloneInto(orig_block->getParent(),
                                 Region::iterator(tail_block), mapper);

  // Determine the blocks for the start of the true and false regions.
  Block* true_block = mapper.lookup(&if_op.true_branch().front());
  Block* false_block = mapper.lookup(&if_op.false_branch().front());

  // Perform the conditional branch into the true/false cases.
  builder.setInsertionPointToEnd(orig_block);

  // Extract the predicate for checking branching, then branch to the true and
  // false regions appropriately.
  auto cond_value = builder.create<mlir::ExtractElementOp>(loc, if_op.pred());
  builder.create<mlir::CondBranchOp>(loc, cond_value, true_block,
                                     if_op.true_arg(), false_block,
                                     if_op.false_arg());

  // Replace the true and false cases' return operations with a branch to the
  // tail of the condition.
  if (failed(ReplaceTerminators(&if_op.true_branch(), tail_block, loc, mapper,
                                &builder)))
    return failure();
  if (failed(ReplaceTerminators(&if_op.false_branch(), tail_block, loc, mapper,
                                &builder)))
    return failure();

  tail_block->addArguments(if_op.getResult().getType());
  if_op.getResult().replaceAllUsesWith(tail_block->getArgument(0));

  op_inst->erase();
  return success();
}
 | 
			
		||||
LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
 | 
			
		||||
  // Converts an XLA while loop into control flow. This generates a set of MLIR
 | 
			
		||||
  // blocks and branches, along with inlining the regions provided by the XLA
 | 
			
		||||
  // while loop. The structure should be similar to below:
 | 
			
		||||
  //
 | 
			
		||||
  //   <prior operations>
 | 
			
		||||
  //   %0 = "xla_hlo.while"(%arg0) {^cond(...){...}, ^body(...){...}}
 | 
			
		||||
  //   <post operations>
 | 
			
		||||
  auto* op_inst = while_op.getOperation();
 | 
			
		||||
  mlir::OpBuilder builder(while_op);
 | 
			
		||||
  auto loc = while_op.getLoc();
 | 
			
		||||
 | 
			
		||||
  // Break the block into four sections:
 | 
			
		||||
  // orig_block - operations before the while and the branch into looping check.
 | 
			
		||||
  // tail_block - operations after the while loop completes.
 | 
			
		||||
  // cond_block - check the looping condition, then conditionally branch into
 | 
			
		||||
  //              the loop or, if condition is false, jump to the tail branch.
 | 
			
		||||
  // body_block - inlined loop body, then jump back to the condition block.
 | 
			
		||||
  auto* orig_block = op_inst->getBlock();
 | 
			
		||||
  auto* tail_block = orig_block->splitBlock(op_inst);
 | 
			
		||||
 | 
			
		||||
  BlockAndValueMapping mapper;
 | 
			
		||||
  while_op.cond().cloneInto(orig_block->getParent(),
 | 
			
		||||
                            Region::iterator(tail_block), mapper);
 | 
			
		||||
  while_op.body().cloneInto(orig_block->getParent(),
 | 
			
		||||
                            Region::iterator(tail_block), mapper);
 | 
			
		||||
 | 
			
		||||
  // Lookup the entry blocks for both condition and body.
 | 
			
		||||
  auto* cond_block = mapper.lookup(&while_op.cond().front());
 | 
			
		||||
  auto* body_block = mapper.lookup(&while_op.body().front());
 | 
			
		||||
 | 
			
		||||
  // Setup the end of the original block:
 | 
			
		||||
  //     <prior operations>
 | 
			
		||||
  //     br ^cond(%arg0) // Jumps to the condition statement.
 | 
			
		||||
  builder.setInsertionPointToEnd(orig_block);
 | 
			
		||||
  builder.create<mlir::BranchOp>(loc, cond_block, while_op.getOperand());
 | 
			
		||||
 | 
			
		||||
  // Updates the inlined condition blocks by replacing the return op with an
 | 
			
		||||
  // extract_element and conditional branch. This changes the block below:
 | 
			
		||||
  //   ^cond(%0):
 | 
			
		||||
  //     <inlined conditional region>
 | 
			
		||||
  //    "xla_hlo".return(%1)
 | 
			
		||||
  //
 | 
			
		||||
  //  Into:
 | 
			
		||||
  //   ^cond(%0):
 | 
			
		||||
  //     <inlined conditional region>
 | 
			
		||||
  //     %2 = extract_element %1[] : tensor<i1> // Extract the condition value.
 | 
			
		||||
  //     cond_br %2, ^body(%0), ^tail(%0) // Branch.
 | 
			
		||||
  builder.setInsertionPointToStart(cond_block);
 | 
			
		||||
 | 
			
		||||
  // Replace the xla_hlo::ReturnOp with a branch back to the condition block.
 | 
			
		||||
  // This is required as the xla_hlo::ReturnOp is used to mark the end of a
 | 
			
		||||
  // block for regions nested inside of a operations (MLIR ReturnOp cannot be
 | 
			
		||||
  // nested within an non-function region).
 | 
			
		||||
  for (auto& block : while_op.cond()) {
 | 
			
		||||
    auto new_block = mapper.lookup(&block);
 | 
			
		||||
 | 
			
		||||
    auto return_op = dyn_cast<xla_hlo::ReturnOp>(new_block->getTerminator());
 | 
			
		||||
    if (!return_op) continue;
 | 
			
		||||
    builder.setInsertionPointToEnd(new_block);
 | 
			
		||||
 | 
			
		||||
    auto return_value = return_op.getOperand(0);
 | 
			
		||||
    auto cond_value = builder.create<mlir::ExtractElementOp>(loc, return_value);
 | 
			
		||||
 | 
			
		||||
    // Get the body block arguments.
 | 
			
		||||
    llvm::SmallVector<Value, 4> successor_args(cond_block->args_begin(),
 | 
			
		||||
                                               cond_block->args_end());
 | 
			
		||||
    builder.create<mlir::CondBranchOp>(loc, cond_value, body_block,
 | 
			
		||||
                                       successor_args, tail_block,
 | 
			
		||||
                                       successor_args);
 | 
			
		||||
    return_op.erase();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Updates the body blocks by replace the return op with an branch to the
 | 
			
		||||
  // conditional block. This changes the block below:
 | 
			
		||||
  //   ^body(%0):
 | 
			
		||||
  //     <inlined body block>
 | 
			
		||||
  //    "xla_hlo".return(%1)
 | 
			
		||||
  //
 | 
			
		||||
  //  Into:
 | 
			
		||||
  //   ^body(%0):
 | 
			
		||||
  //     <inlined body block>
 | 
			
		||||
  //     br ^cond(%0) // Branch.
 | 
			
		||||
  for (auto& block : while_op.body()) {
 | 
			
		||||
    auto new_block = mapper.lookup(&block);
 | 
			
		||||
    auto return_op =
 | 
			
		||||
        dyn_cast<mlir::xla_hlo::ReturnOp>(new_block->getTerminator());
 | 
			
		||||
    if (!return_op) continue;
 | 
			
		||||
    builder.setInsertionPointToEnd(new_block);
 | 
			
		||||
    builder.create<mlir::BranchOp>(loc, cond_block, return_op.getOperands());
 | 
			
		||||
    return_op.erase();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Erase the original while loop.
 | 
			
		||||
  tail_block->addArgument(while_op.getType());
 | 
			
		||||
  while_op.getResult().replaceAllUsesWith(tail_block->getArgument(0));
 | 
			
		||||
  op_inst->erase();
 | 
			
		||||
 | 
			
		||||
  return success();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
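
// For intuition, the overall CFG produced by LowerWhileOp for a one-operand
// while looks roughly as follows (a sketch; block and value names are
// illustrative):
//
//   <prior operations>
//   br ^cond(%arg0)
// ^cond(%0):
//   <inlined condition region computing %pred : tensor<i1>>
//   %1 = extract_element %pred[] : tensor<i1>
//   cond_br %1, ^body(%0), ^tail(%0)
// ^body(%2):
//   <inlined body region computing %next>
//   br ^cond(%next)
// ^tail(%3):
//   <post operations, with %3 standing in for the original while result>
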
void LegalizeControlFlow::runOnFunction() {
  auto func = getFunction();
  llvm::SmallVector<IfOp, 4> if_ops;
  func.walk([&](IfOp op) { if_ops.push_back(op); });

  for (auto& op : if_ops) {
    if (failed(LowerIfOp(op))) return signalPassFailure();
  }

  llvm::SmallVector<WhileOp, 4> while_ops;
  func.walk([&](WhileOp op) { while_ops.push_back(op); });

  for (auto& op : while_ops) {
    if (failed(LowerWhileOp(op))) return signalPassFailure();
  }
}
}  // namespace
}  // namespace xla_hlo
}  // namespace mlir

std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
mlir::xla_hlo::createLegalizeControlFlowPass() {
  return std::make_unique<LegalizeControlFlow>();
}

static PassRegistration<mlir::xla_hlo::LegalizeControlFlow> legalize_cf_pass(
    "xla-legalize-control-flow",
    "Legalize from XLA control flow to MLIR control flow");
			@ -0,0 +1,156 @@
 | 
			
		|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering the standard tanh op to an
// approximation.

#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"

namespace mlir {
namespace xla {
namespace {

/// Emits the fast tanh approximation that is also used by XLA.
Value EmitTanhApproximation(Value input, Location loc,
                            PatternRewriter &rewriter) {
  // For small values of x, we can approximate tanh(x) = x. For extremely small
  // values of x (|x| < 1e-37), the other approximation would evaluate
  // tanh(x) = 0.
  constexpr float kCanUseApprox = 0.0004;
  Value abs_value = rewriter.create<AbsFOp>(loc, input);
  Value can_use_approx =
      rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(kCanUseApprox));
  Value return_input = rewriter.create<CmpFOp>(loc, CmpFPredicate::OLT,
                                               abs_value, can_use_approx);
  // Clamp the input to [-c, c].
  Value max_clamp = rewriter.create<ConstantOp>(
      loc, rewriter.getF32FloatAttr(7.90531110763549805f));
  Value smaller_than_max =
      rewriter.create<CmpFOp>(loc, CmpFPredicate::ULE, input, max_clamp);
  Value clamped_half =
      rewriter.create<SelectOp>(loc, smaller_than_max, input, max_clamp);
  Value min_clamp = rewriter.create<ConstantOp>(
      loc, rewriter.getF32FloatAttr(-7.90531110763549805f));
  Value larger_than_min =
      rewriter.create<CmpFOp>(loc, CmpFPredicate::UGE, clamped_half, min_clamp);
  Value input_clamped =
      rewriter.create<SelectOp>(loc, larger_than_min, clamped_half, min_clamp);

  static constexpr std::array<float, 7> numerator_coeffs{
      -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f,
      5.12229709037114e-08f,  1.48572235717979e-05f, 6.37261928875436e-04f,
      4.89352455891786e-03f};

  static constexpr std::array<float, 4> denominator_coeffs{
      1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f,
      4.89352518554385e-03f};

  Value input_squared =
      rewriter.create<MulFOp>(loc, input_clamped, input_clamped);
  Value numerator = rewriter.create<ConstantOp>(
      loc, rewriter.getF32FloatAttr(numerator_coeffs[0]));
  for (int i = 1; i < numerator_coeffs.size(); i++) {
    numerator = rewriter.create<AddFOp>(
        loc, rewriter.create<MulFOp>(loc, input_squared, numerator),
        rewriter.create<ConstantOp>(
            loc, rewriter.getF32FloatAttr(numerator_coeffs[i])));
  }

  numerator = rewriter.create<MulFOp>(loc, input_clamped, numerator);

  Value denominator = rewriter.create<ConstantOp>(
      loc, rewriter.getF32FloatAttr(denominator_coeffs[0]));
  for (int i = 1; i < denominator_coeffs.size(); i++) {
    denominator = rewriter.create<AddFOp>(
        loc, rewriter.create<MulFOp>(loc, input_squared, denominator),
        rewriter.create<ConstantOp>(
            loc, rewriter.getF32FloatAttr(denominator_coeffs[i])));
  }

  Value approx = rewriter.create<DivFOp>(loc, numerator, denominator);

  return rewriter.create<SelectOp>(loc, return_input, input, approx);
}
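
// For intuition, the scalar computation emitted above is roughly (a sketch,
// with p = numerator_coeffs, q = denominator_coeffs, and poly(c, s) denoting
// the Horner evaluation c[0]*s^(n-1) + ... + c[n-1]):
//   x' = clamp(x, -7.90531110763549805, 7.90531110763549805)
//   s  = x' * x'
//   tanh(x) ~= (|x| < 0.0004) ? x : (x' * poly(p, s)) / poly(q, s)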

class ApproximateTanhLowering : public OpRewritePattern<TanhOp> {
 public:
  explicit ApproximateTanhLowering(MLIRContext *ctx)
      : OpRewritePattern<TanhOp>(ctx, 100) {}

  LogicalResult matchAndRewrite(TanhOp tanhOp,
                                PatternRewriter &rewriter) const override {
    Type operand_type = tanhOp.getType();

    if (operand_type.isF64()) {
      // Similar to XLA, do not rewrite f64 as precision might matter.
      return failure();
    }

    Location loc = tanhOp.getLoc();
    Value input = tanhOp.operand();
    if (operand_type.isF16()) {
      input = rewriter.create<FPExtOp>(loc, input, rewriter.getF32Type());
    }

    // If we still do not have f32, fail.
    if (!input.getType().isF32()) {
      return failure();
    }

    Value result = EmitTanhApproximation(input, loc, rewriter);

    // Truncate back if needed.
    if (operand_type.isF16()) {
      result = rewriter.create<FPTruncOp>(loc, result, rewriter.getF16Type());
    }

    rewriter.replaceOp(tanhOp, {result});
    return success();
  }
};

struct LegalizeTanhToApproximation
    : public PassWrapper<LegalizeTanhToApproximation, FunctionPass> {
  /// Perform the lowering of standard dialect operations to approximations.
  void runOnFunction() override {
    OwningRewritePatternList patterns;
    PopulateTanhToApproximationPatterns(&getContext(), &patterns);
    applyPatternsAndFoldGreedily(getFunction(), patterns);
  }
};

}  // anonymous namespace

std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
createLegalizeTanhToApproximationPass() {
  return std::make_unique<LegalizeTanhToApproximation>();
}

void PopulateTanhToApproximationPatterns(mlir::MLIRContext *context,
                                         OwningRewritePatternList *patterns) {
  patterns->insert<ApproximateTanhLowering>(context);
}

static PassRegistration<LegalizeTanhToApproximation> legalize_pass(
    "xla-legalize-tanh-to-approximation",
    "Legalize tanh from standard dialect to an approximation");

}  // namespace xla
}  // namespace mlir
			@ -0,0 +1,208 @@
 | 
			
		|||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering XLA dialect to Standard dialect.

#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"

namespace mlir {
namespace {
#include "third_party/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_legalize_to_standard.inc"
}  // end anonymous namespace
namespace xla_hlo {
namespace {

class CompareIConvert : public OpRewritePattern<xla_hlo::CompareOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(xla_hlo::CompareOp op,
                                PatternRewriter &rewriter) const override {
    auto lhs = op.lhs();
    auto rhs = op.rhs();
    auto lhs_type = lhs.getType().cast<TensorType>();
    auto rhs_type = rhs.getType().cast<TensorType>();

    // Broadcasting not supported by this rewrite.
    if (lhs_type.getShape() != rhs_type.getShape()) return failure();

    if (!lhs_type.getElementType().isSignlessInteger() ||
        !rhs_type.getElementType().isSignlessInteger())
      return failure();

    auto comparison_direction = op.comparison_direction();
    auto compare_predicate =
        llvm::StringSwitch<Optional<CmpIPredicate>>(comparison_direction)
            .Case("EQ", CmpIPredicate::eq)
            .Case("NE", CmpIPredicate::ne)
            .Case("LT", CmpIPredicate::slt)
            .Case("LE", CmpIPredicate::sle)
            .Case("GT", CmpIPredicate::sgt)
            .Case("GE", CmpIPredicate::sge)
            .Default(llvm::None);

    if (!compare_predicate.hasValue()) return failure();

    rewriter.replaceOpWithNewOp<CmpIOp>(op, compare_predicate.getValue(), lhs,
                                        rhs);
    return success();
  }
};

class CompareFConvert : public OpRewritePattern<xla_hlo::CompareOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(xla_hlo::CompareOp op,
                                PatternRewriter &rewriter) const override {
    auto lhs = op.lhs();
    auto rhs = op.rhs();
    auto lhs_type = lhs.getType().cast<TensorType>();
    auto rhs_type = rhs.getType().cast<TensorType>();

    // Broadcasting not supported by this rewrite.
    if (lhs_type.getShape() != rhs_type.getShape()) return failure();

    if (!lhs_type.getElementType().isa<FloatType>() ||
        !rhs_type.getElementType().isa<FloatType>())
      return failure();

    auto comparison_direction = op.comparison_direction();
    auto compare_predicate =
        llvm::StringSwitch<Optional<CmpFPredicate>>(comparison_direction)
            .Case("EQ", CmpFPredicate::OEQ)
            .Case("NE", CmpFPredicate::UNE)
            .Case("LT", CmpFPredicate::OLT)
            .Case("LE", CmpFPredicate::OLE)
            .Case("GT", CmpFPredicate::OGT)
            .Case("GE", CmpFPredicate::OGE)
            .Default(llvm::None);

    if (!compare_predicate.hasValue()) return failure();

    rewriter.replaceOpWithNewOp<CmpFOp>(op, compare_predicate.getValue(), lhs,
                                        rhs);
    return success();
  }
};
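
// For intuition, these two patterns rewrite, e.g. (illustrative IR):
//   %0 = "xla_hlo.compare"(%a, %b) {comparison_direction = "LT"}
//       : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
// into
//   %0 = cmpi "slt", %a, %b : tensor<4xi32>
// and the analogous f32 comparison into cmpf "olt". Note the asymmetry in the
// float mapping above: NE uses an unordered predicate (UNE) while the other
// directions use ordered predicates, matching NaN-propagating semantics.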

// Replace IotaOp with an integer constant. A ConvertOp is added to convert the
// integer constant to the iota result type. For complex types, the real part
// is replaced with the generated constant and the imaginary part is replaced
// with a zero tensor.
class ConvertIotaOp : public OpRewritePattern<xla_hlo::IotaOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(xla_hlo::IotaOp op,
                                PatternRewriter &rewriter) const override {
    auto output_type = op.getType().cast<ShapedType>();
    auto output_size = output_type.getNumElements();
    auto dimension = op.iota_dimension().getSExtValue();
    auto max_dim_size = output_type.getDimSize(dimension);

    auto element_type = output_type.getElementType();
    int bitwidth;

    auto complex_ty = element_type.dyn_cast<ComplexType>();
    Type int_or_float_ty = element_type;
    if (complex_ty) int_or_float_ty = complex_ty.getElementType();

    bitwidth = int_or_float_ty.getIntOrFloatBitWidth();
    llvm::SmallVector<APInt, 10> values;
    values.reserve(output_size);

    int64_t increase_stride = output_size;
    for (int i = 0; i <= dimension; i++) {
      increase_stride /= output_type.getDimSize(i);
    }

    int64_t current_value = 0;
    for (int i = 0; i < output_size; i++) {
      int64_t value = (current_value / increase_stride) % max_dim_size;
      values.push_back(APInt(bitwidth, value));
      ++current_value;
    }

    auto int_shape_type = RankedTensorType::get(
        output_type.getShape(),
        IntegerType::get(bitwidth, rewriter.getContext()));
    auto loc = op.getLoc();
    auto integer_const = rewriter.create<mlir::ConstantOp>(
        loc, DenseIntElementsAttr::get(int_shape_type, values));

    auto int_or_float_shape_ty =
        RankedTensorType::get(output_type.getShape(), int_or_float_ty);

    auto iota_const =
        rewriter.create<ConvertOp>(loc, int_or_float_shape_ty, integer_const);

    // For int/float types we are done; replace the op and return.
    if (!complex_ty) {
      rewriter.replaceOp(op, iota_const.getResult());
      return success();
    }

    // For complex types, generate a constant tensor of zeroes for the
    // imaginary part and use iota_const for the real part.
    auto zeroes = rewriter.create<mlir::ConstantOp>(
        loc, DenseIntElementsAttr::get(int_shape_type, APInt(bitwidth, 0)));
    auto imag_zeroes =
        rewriter.create<ConvertOp>(loc, int_or_float_shape_ty, zeroes);
    rewriter.replaceOpWithNewOp<xla_hlo::ComplexOp>(op, iota_const,
                                                    imag_zeroes);
    return success();
  }
};
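
// Worked example for the index arithmetic above (a sketch): for an iota of
// type tensor<2x3xi32> with iota_dimension = 1, output_size = 6,
// increase_stride = 6 / 2 / 3 = 1, and max_dim_size = 3, so element i gets
// (i / 1) % 3, producing the row-major values
//   [[0, 1, 2], [0, 1, 2]].
// With iota_dimension = 0, increase_stride = 3 and max_dim_size = 2, giving
//   [[0, 0, 0], [1, 1, 1]].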

}  // end anonymous namespace

namespace {
struct LegalizeToStandard
    : public PassWrapper<LegalizeToStandard, FunctionPass> {
  /// Perform the lowering to Standard dialect.
  void runOnFunction() override;
};
}  // end anonymous namespace

std::unique_ptr<mlir::OperationPass<mlir::FuncOp>> createLegalizeToStdPass() {
  return std::make_unique<LegalizeToStandard>();
}

void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns,
                              mlir::MLIRContext *ctx) {
  mlir::populateWithGenerated(ctx, patterns);
  patterns->insert<CompareFConvert, CompareIConvert, ConvertIotaOp>(ctx);
}

/// Perform the lowering to standard dialect.
void LegalizeToStandard::runOnFunction() {
  OwningRewritePatternList patterns;
  mlir::xla_hlo::PopulateXlaToStdPatterns(&patterns, &getContext());
  applyPatternsAndFoldGreedily(getFunction(), patterns);
}

static PassRegistration<LegalizeToStandard> legalize_pass(
    "xla-legalize-to-std", "Legalize from XLA dialect to standard dialect");

}  // end namespace xla_hlo
}  // end namespace mlir
			@ -0,0 +1,71 @@
 | 
			
		|||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is the legalization pattern definition file for XLA to StandardOps.

include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td"
include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td"
include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"

//===----------------------------------------------------------------------===//
// Nullary op patterns.
//===----------------------------------------------------------------------===//

def : Pat<(HLO_ConstOp ElementsAttr:$value),
          (ConstantOp $value)>;

//===----------------------------------------------------------------------===//
// Binary op patterns.
//===----------------------------------------------------------------------===//

def IsSameSizePred : CPred<
    "$0.getType().cast<ShapedType>().getShape() "
    "== $1.getType().cast<ShapedType>().getShape()">;
def IsSameSizeConstraint : Constraint<IsSameSizePred, "inputs are same size">;
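
// For intuition, each pattern below rewrites an element-wise HLO op on
// same-shaped operands directly into its standard-dialect counterpart, e.g.
// (illustrative IR):
//   %0 = "xla_hlo.add"(%l, %r) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
//     -->  %0 = addf %l, %r : tensor<4xf32>
//   %1 = "xla_hlo.add"(%l, %r) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
//     -->  %1 = addi %l, %r : tensor<4xi32>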

def : Pat<(HLO_AndOp HLO_PredTensor:$l, HLO_PredTensor:$r),
          (AndOp $l, $r),
          [(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_AddOp HLO_FpTensor:$l, HLO_FpTensor:$r),
          (AddFOp $l, $r),
          [(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_SubOp HLO_FpTensor:$l, HLO_FpTensor:$r),
          (SubFOp $l, $r),
          [(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_MulOp HLO_FpTensor:$l, HLO_FpTensor:$r),
          (MulFOp $l, $r),
          [(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_DivOp HLO_FpTensor:$l, HLO_FpTensor:$r),
          (DivFOp $l, $r),
          [(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_RemOp HLO_FpTensor:$l, HLO_FpTensor:$r),
          (RemFOp $l, $r),
          [(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_AddOp HLO_IntTensor:$l, HLO_IntTensor:$r),
          (AddIOp $l, $r),
          [(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_SubOp HLO_IntTensor:$l, HLO_IntTensor:$r),
          (SubIOp $l, $r),
          [(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_MulOp HLO_IntTensor:$l, HLO_IntTensor:$r),
          (MulIOp $l, $r),
          [(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_DivOp HLO_IntTensor:$l, HLO_IntTensor:$r),
          (SignedDivIOp $l, $r),
          [(IsSameSizeConstraint $l, $r)]>;
def : Pat<(HLO_RemOp HLO_IntTensor:$l, HLO_IntTensor:$r),
          (SignedRemIOp $l, $r),
          [(IsSameSizeConstraint $l, $r)]>;
			@ -0,0 +1,105 @@
 | 
			
		|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements a pass to remove redundant LHLO copy operations.

#include "third_party/absl/memory/memory.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"

namespace mlir {
namespace xla_lhlo {
namespace {

// Removes LHLO copy operations that copy from allocated buffers to block
// arguments. All uses of each buffer are replaced with the corresponding block
// argument and the buffer is freed. Note that this pass only works in regions
// with a single block.
struct LhloCopyRemoval : mlir::PassWrapper<LhloCopyRemoval, OperationPass<>> {
  void runOnOperation() override {
    llvm::SmallVector<mlir::Operation*, 2> eraseList;
    auto operation = getOperation();
    operation->walk([&](mlir::xla_lhlo::CopyOp copyOp) {
      // If this region contains more than one block, then ignore this copy
      // operation.
      if (copyOp.getParentRegion()->getBlocks().size() > 1) {
        return;
      }

      mlir::Value fromOperand = copyOp.operand();
      mlir::Value toOperand = copyOp.output();

      // If the fromOperand value is a block argument or the toOperand value
      // is not a block argument, then ignore this copy operation.
      if (!fromOperand.getDefiningOp() || toOperand.getDefiningOp()) {
        return;
      }

      // Removing the copy operation is illegal if there is at least one use
      // of the toOperand value that lies between the first use of the
      // fromOperand value and the copy operation.
      auto fromOperandUsers = fromOperand.getUsers();
      auto firstUser = *fromOperandUsers.begin();
      for (auto op : fromOperandUsers) {
        if (op->isBeforeInBlock(firstUser)) firstUser = op;
      }
      for (auto op : toOperand.getUsers()) {
        if (op->isBeforeInBlock(copyOp) && firstUser->isBeforeInBlock(op)) {
          return;
        }
      }

      // TODO(DFKI): Use live variable analysis to solve aliasing issues among
      // block arguments.

      // Remove the associated alloc operation.
      auto allocOp = fromOperand.getDefiningOp();
      eraseList.push_back(allocOp);

      // Iterate over all uses of the fromOperand to find the associated
      // deallocOp (if any).
      for (auto op : fromOperandUsers) {
        if (isa<mlir::DeallocOp>(op)) {
          eraseList.push_back(op);
          break;
        }
      }

      // Replace all uses of the fromOperand with the toOperand. This rewires
      // all references pointing to the original alloc operation to the new
      // target operation in order to safely remove the copy op.
      fromOperand.replaceAllUsesWith(toOperand);
      copyOp.erase();
    });
    for (auto op : eraseList) {
      op->erase();
    }
  }
};
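
// For intuition, a before/after sketch of the rewrite (illustrative IR):
//   %0 = alloc() : memref<4xf32>
//   "xla_lhlo.exp"(%arg0, %0) : (memref<4xf32>, memref<4xf32>) -> ()
//   "xla_lhlo.copy"(%0, %arg1) : (memref<4xf32>, memref<4xf32>) -> ()
//   dealloc %0 : memref<4xf32>
// becomes
//   "xla_lhlo.exp"(%arg0, %arg1) : (memref<4xf32>, memref<4xf32>) -> ()
// with the alloc, copy, and dealloc all removed.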

}  // namespace

std::unique_ptr<Pass> createLhloCopyRemovalPass() {
  return absl::make_unique<LhloCopyRemoval>();
}

static PassRegistration<LhloCopyRemoval> copy_removal_pass(
    "lhlo-copy-removal", "Removes redundant LHLO copy operations");

}  // namespace xla_lhlo
}  // namespace mlir
			@ -0,0 +1,151 @@
 | 
			
		|||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for fusing linalg ops obtained after LHLO
// lowering.

#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "third_party/absl/memory/memory.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/FoldUtils.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"

namespace mlir {
namespace xla_lhlo {
namespace {

using linalg::LinalgOp;

class LhloFuseLinalg : public PassWrapper<LhloFuseLinalg, FunctionPass> {
 public:
  LhloFuseLinalg() = default;
  LhloFuseLinalg(const LhloFuseLinalg&) {}
  LhloFuseLinalg(bool use_parallel_loops, llvm::ArrayRef<unsigned> tile_sizes) {
    tile_sizes_ = tile_sizes;
    use_parallel_loops_.setValue(use_parallel_loops);
  }

  void runOnFunction() override {
    auto func = getFunction();

    // TODO(pifon): Remove assumption that the function has a single block.
    if (!llvm::hasSingleElement(func)) {
      emitError(func.getLoc(), "The function needs to have a single block.");
      signalPassFailure();
      return;
    }

    // Fusion in Linalg is currently possible only when the consumer op is
    // tiled. In order to greedily fuse the ops, we have to start from the
    // tiled root linalg ops, i.e. linalg ops that write to output buffers of
    // the function or are returned in case of escaping allocations.
    llvm::SmallDenseSet<Value> result_buffers;
    for (auto func_arg : func.getArguments()) {
      result_buffers.insert(func_arg);
    }
    for (auto& block : func) {
      auto returnOp = mlir::dyn_cast<mlir::ReturnOp>(block.getTerminator());
      if (!returnOp) continue;
      for (auto operand : returnOp.getOperands()) {
        result_buffers.insert(operand);
      }
    }
    MLIRContext* ctx = func.getContext();
    OpBuilder b(func);
    OperationFolder folder(ctx);
    func.walk([&](linalg::GenericOp generic_op) {
      SmallVector<int64_t, 2> tile_sizes(tile_sizes_.begin(),
                                         tile_sizes_.end());
      if (tile_sizes.empty()) {
        tile_sizes = SmallVector<int64_t, 2>(generic_op.getNumLoops(), 1);
      }
      auto op = cast<LinalgOp>(generic_op.getOperation());
      for (const Value result : op.getOutputBuffers()) {
        if (!result_buffers.count(result)) continue;
        if (tileGenericOp(op, tile_sizes, &b)) {
          generic_op.erase();
          return;
        }
      }
    });
    auto patterns = linalg::getLinalgTilingCanonicalizationPatterns(ctx);
    applyPatternsAndFoldGreedily(func, patterns);

    // Fuse producers of tiled linalg ops.
    llvm::SmallDenseSet<Operation*> erase_set;
    SmallVector<Operation*, 8> linalg_ops;
    func.walk([&](LinalgOp op) { linalg_ops.push_back(op); });
    for (auto* op : llvm::reverse(linalg_ops)) {
      for (unsigned id = 0, e = LinalgOp(op).getNumInputs(); id < e; ++id) {
        linalg::Aliases aliases;
        linalg::LinalgDependenceGraph graph(aliases, linalg_ops);
        if (auto info = fuseProducerOf(b, op, id, graph, &folder)) {
          auto originalOp = info->originalProducer.getOperation();
          erase_set.insert(originalOp);
          auto originalOpInLinalgOpsVector = std::find_if(
              linalg_ops.begin(), linalg_ops.end(),
              [&](const Operation* op) { return op == originalOp; });
          *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
        }
      }

      auto patterns = linalg::getLinalgTilingCanonicalizationPatterns(ctx);
      applyPatternsAndFoldGreedily(func, patterns);
    }
    for (auto* e : erase_set) e->erase();
  }

 private:
  bool tileGenericOp(LinalgOp op, ArrayRef<int64_t> tile_sizes, OpBuilder* b) {
    auto loopType = use_parallel_loops_
                        ? linalg::LinalgTilingLoopType::ParallelLoops
                        : linalg::LinalgTilingLoopType::Loops;
    auto tiled_generic_op = linalg::tileLinalgOp(*b, op,
                                                 linalg::LinalgTilingOptions()
                                                     .setTileSizes(tile_sizes)
                                                     .setLoopType(loopType));
    return tiled_generic_op.hasValue();
  }

  Option<bool> use_parallel_loops_{
      *this, "use-parallel-loops",
      llvm::cl::desc(
          "Tiles GenericOp consumer to parallel loops before linalg fusion"),
      llvm::cl::init(false)};

  ListOption<unsigned> tile_sizes_{
      *this, "tile-sizes",
      llvm::cl::desc(
          "Tile sizes by which to tile linalg generic before linalg fusion"),
      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
};
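
// For intuition, the pass would typically be driven from the command line
// roughly like this (an illustrative invocation; the exact option syntax
// follows the Option/ListOption declarations above):
//   mlir-opt -lhlo-fuse-linalg="tile-sizes=2,3 use-parallel-loops=true" in.mlir
// With no options, every loop dimension is tiled by 1 and sequential loops
// are emitted.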

}  // namespace

std::unique_ptr<OperationPass<FuncOp>> createLhloFuseLinalg(
    bool use_parallel_loops, ArrayRef<unsigned> tile_sizes) {
  return absl::make_unique<LhloFuseLinalg>(use_parallel_loops, tile_sizes);
}

static PassRegistration<LhloFuseLinalg> legalize_pass(
    "lhlo-fuse-linalg",
    "Greedily fuse linalg ops obtained after LHLO lowering.");

}  // namespace xla_lhlo
}  // namespace mlir
			@ -0,0 +1,161 @@
 | 
			
		|||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering LHLO dialect to Affine dialect.

#include "third_party/absl/memory/memory.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h"

namespace mlir {
namespace xla_lhlo {
namespace {

// Builds an affine loop nest iterating from zeros to "upper_bounds" with unit
// steps, and populates the body of the innermost loop using "body_builder".
static void BuildBoundedAffineLoopNest(
    OpBuilder& builder, Location location, ArrayRef<int64_t> upper_bounds,
    function_ref<void(OpBuilder&, Location, ValueRange)> body_builder) {
  SmallVector<int64_t, 3> lower_bounds(upper_bounds.size(), /*Value=*/0);
  SmallVector<int64_t, 3> steps(upper_bounds.size(), /*Value=*/1);
  buildAffineLoopNest(builder, location, lower_bounds, upper_bounds, steps,
                      body_builder);
}

struct DotOpConverter : public OpRewritePattern<DotOp> {
  using OpRewritePattern<DotOp>::OpRewritePattern;

  // Supports only rank-2 tensors for LHS and RHS.
  LogicalResult matchAndRewrite(DotOp op,
                                PatternRewriter& rewriter) const override {
    Value lhs = op.lhs();
    Value rhs = op.rhs();
    MemRefType lhs_type = lhs.getType().cast<MemRefType>();
    MemRefType rhs_type = rhs.getType().cast<MemRefType>();
    Type element_type = lhs_type.getElementType();
    ArrayRef<int64_t> shape_lhs = lhs_type.getShape();
    ArrayRef<int64_t> shape_rhs = rhs_type.getShape();

    if ((lhs_type.getRank() != 2) || (rhs_type.getRank() != 2)) {
      return failure();
    }

    LogicalResult map_status = success();
    auto body_builder = [&](OpBuilder& builder, Location loc, ValueRange ivs) {
      SmallVector<Value, 2> lhs_indices{ivs[0], ivs[2]},
          rhs_indices{ivs[2], ivs[1]}, result_indices{ivs[0], ivs[1]};

      auto l = builder.create<AffineLoadOp>(loc, lhs, lhs_indices);
      auto r = builder.create<AffineLoadOp>(loc, rhs, rhs_indices);
      auto result =
          rewriter.create<AffineLoadOp>(loc, op.output(), result_indices);
      Value op_result = xla_lhlo::XlaOpToStdScalarOp::map<DotOp>(
          op, element_type, {l, r, result}, &builder);
      map_status = success(op_result != nullptr);
      if (failed(map_status)) return;
      builder.create<AffineStoreOp>(loc, op_result, op.output(),
                                    result_indices);
    };

    BuildBoundedAffineLoopNest(rewriter, op.getLoc(),
                               {shape_lhs[0], shape_rhs[1], shape_rhs[0]},
                               body_builder);
    if (failed(map_status)) return failure();

    rewriter.eraseOp(op);
    return success();
  }
};
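
// For intuition, with %lhs : memref<MxKxf32>, %rhs : memref<KxNxf32>, and
// %out : memref<MxNxf32>, the pattern above emits roughly (a sketch):
//   affine.for %i = 0 to M {
//     affine.for %j = 0 to N {
//       affine.for %k = 0 to K {
//         %l = affine.load %lhs[%i, %k]
//         %r = affine.load %rhs[%k, %j]
//         %acc = affine.load %out[%i, %j]
//         %new = <scalar dot update, i.e. %acc + %l * %r>
//         affine.store %new, %out[%i, %j]
//       }
//     }
//   }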

template <typename LhloOpTy>
struct BinaryOpConverter : public OpRewritePattern<LhloOpTy> {
  using OpRewritePattern<LhloOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(LhloOpTy op,
                                PatternRewriter& rewriter) const override {
    const auto& lhs = op.lhs();
    const auto& rhs = op.rhs();
    const auto& lhs_type = lhs.getType().template cast<MemRefType>();
    const auto& rhs_type = rhs.getType().template cast<MemRefType>();
    const auto& element_type = lhs_type.getElementType();

    if (lhs_type.getShape() != rhs_type.getShape()) {
      return failure();
    }

    LogicalResult map_status = success();
    auto body_builder = [&](OpBuilder& builder, Location loc,
                            ValueRange induction_vars) {
      auto l = builder.create<AffineLoadOp>(loc, lhs, induction_vars);
      auto r = builder.create<AffineLoadOp>(loc, rhs, induction_vars);
      Value op_result = xla_lhlo::XlaOpToStdScalarOp::map<LhloOpTy>(
          op, element_type, {l, r}, &builder);
      map_status = success(op_result != nullptr);
      if (failed(map_status)) return;
      rewriter.create<AffineStoreOp>(loc, op_result, op.out(), induction_vars);
    };

    BuildBoundedAffineLoopNest(rewriter, op.getLoc(), lhs_type.getShape(),
                               body_builder);
    if (failed(map_status)) return failure();
    rewriter.eraseOp(op);
    return success();
  }
};

void populateLHLOToAffineConversionPattern(MLIRContext* context,
                                           OwningRewritePatternList* patterns) {
  // clang-format off
  patterns->insert<
      BinaryOpConverter<xla_lhlo::AddOp>,
      BinaryOpConverter<xla_lhlo::AndOp>,
      BinaryOpConverter<xla_lhlo::DivOp>,
      BinaryOpConverter<xla_lhlo::MaxOp>,
      BinaryOpConverter<xla_lhlo::MinOp>,
      BinaryOpConverter<xla_lhlo::MulOp>,
      BinaryOpConverter<xla_lhlo::SubOp>,
      DotOpConverter>(context);
  // clang-format on
}

struct LhloLegalizeToAffine
    : public PassWrapper<LhloLegalizeToAffine, FunctionPass> {
  void runOnFunction() override {
    OwningRewritePatternList patterns;
    auto func = getFunction();
    populateLHLOToAffineConversionPattern(func.getContext(), &patterns);
    applyPatternsAndFoldGreedily(func, patterns);
  }
};

}  // namespace

std::unique_ptr<OperationPass<FuncOp>> createLegalizeToAffinePass() {
  return absl::make_unique<LhloLegalizeToAffine>();
}

static PassRegistration<LhloLegalizeToAffine> legalize_pass(
    "lhlo-legalize-to-affine", "Legalize from LHLO dialect to affine dialect");

}  // namespace xla_lhlo
}  // namespace mlir
			@ -0,0 +1,196 @@
 | 
			
		|||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering LHLO dialect to GPU dialect.

#include <cstdint>

#include "third_party/absl/memory/memory.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/GPU/GPUDialect.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SCF/SCF.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BlockAndValueMapping.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h"

namespace mlir {
namespace xla_lhlo {
namespace {

// A simple translation of LHLO reduce operations to a corresponding gpu
// launch operation. The transformation does no tiling and also only supports
// 1d results.
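// For intuition, the generated structure for a reduction with a 1d result of
// size %size looks roughly like this (a sketch; names are illustrative):
//   gpu.launch blocks(...) in (1, 1, 1) threads(...) in (%size, 1, 1) {
//     store %init, %out[%tid_x]
//     scf.for %i = %c0 to %reduce_dim step %c1 {
//       <inlined reduction body applied to slices of the input and %out>
//     }
//     gpu.terminator
//   }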
class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      ReduceOp reduce_op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = reduce_op.getLoc();
    // Only support 1d reductions for now.
    int64_t size = 0;
    for (auto result : reduce_op.out()) {
      auto shaped_type = result.getType().dyn_cast<ShapedType>();
      if (!shaped_type || shaped_type.getRank() != 1) {
        return failure();
      }
      auto dim_size = shaped_type.getDimSize(0);
      if (size && size != dim_size) {
        return failure();
      }
      size = dim_size;
    }

    auto reducing_dimension = *reduce_op.dimensions().int_value_begin();

    // Require all inputs to have the same shape.
    int64_t reduce_dim_size = 0;
    for (auto input : reduce_op.operands()) {
      auto shaped_type = input.getType().dyn_cast<ShapedType>();
      if (!shaped_type || !shaped_type.hasStaticShape()) {
        return failure();
      }
      reduce_dim_size =
          shaped_type.getDimSize(reducing_dimension.getSExtValue());
    }

    // Create a launch that is parallel in the result dimension.
    auto block_size_x = rewriter.create<mlir::ConstantOp>(
        loc, rewriter.getIndexType(),
        rewriter.getIntegerAttr(rewriter.getIndexType(), size));
    auto one = rewriter.create<mlir::ConstantOp>(
        loc, rewriter.getIndexType(),
        rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
    auto launch_op = rewriter.create<mlir::gpu::LaunchOp>(
        loc, one, one, one, block_size_x, one, one);
    {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToEnd(&launch_op.body().front());
      auto index = launch_op.getThreadIds().x;

      // Load the initial value and store it to the output.
      for (auto pair : llvm::zip(reduce_op.init_values(), reduce_op.out())) {
        auto init_value = rewriter.create<mlir::LoadOp>(loc, std::get<0>(pair));
        rewriter.create<mlir::StoreOp>(loc, init_value, std::get<1>(pair),
                                       ArrayRef<Value>{index});
      }

      // Insert a loop into the body to compute the reduction. The loop ranges
      // over [0, dim).
      auto zero = rewriter.create<mlir::ConstantOp>(
          loc, rewriter.getIndexType(),
          rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
      // TODO(b/137624192) Use dimOp to make it shape independent.
      auto upper = rewriter.create<mlir::ConstantOp>(
          loc, rewriter.getIndexType(),
          rewriter.getIntegerAttr(rewriter.getIndexType(), reduce_dim_size));
      auto step = rewriter.create<mlir::ConstantOp>(
          loc, rewriter.getIndexType(),
          rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
      auto loop = rewriter.create<mlir::scf::ForOp>(loc, zero, upper, step);

      rewriter.setInsertionPointToStart(loop.getBody());
      // Compute memrefs for the value to reduce. This makes it easier to just
      // inline the body.
      auto output = *reduce_op.out().begin();
      // TODO(herhut) Move this to the SliceOp builder.
      auto resType = MemRefType::get(
          llvm::None, output.getType().cast<MemRefType>().getElementType(),
          makeStridedLinearLayoutMap(llvm::None,
                                     MemRefType::getDynamicStrideOrOffset(),
                                     rewriter.getContext()));
      auto accumulator = rewriter.create<mlir::linalg::SliceOp>(
          loc, resType, output, ArrayRef<Value>{launch_op.getThreadIds().x});
      llvm::SmallVector<Value, 4> indexings;
      auto input_buffer = *reduce_op.operands().begin();
      auto input_type = input_buffer.getType().cast<MemRefType>();
      for (int64_t dim = 0; dim < input_type.getRank(); ++dim) {
        indexings.push_back(dim == reducing_dimension
                                ? loop.getInductionVar()
                                : launch_op.getThreadIds().x);
      }
      // TODO(herhut) Move this to the SliceOp builder.
      auto input = *reduce_op.operand_begin();
      auto rhs = rewriter.create<mlir::linalg::SliceOp>(
          loc,
          MemRefType::get(
              llvm::None, input_type.getElementType(),
              makeStridedLinearLayoutMap(llvm::None,
                                         MemRefType::getDynamicStrideOrOffset(),
                                         rewriter.getContext())),
          input, indexings);

      // Now copy over the actual body of the reduction, leaving out the
      // terminator.
      BlockAndValueMapping mapping;
      mapping.map(reduce_op.body().front().getArgument(0), accumulator);
      mapping.map(reduce_op.body().front().getArgument(1), rhs);
      mapping.map(reduce_op.body().front().getArgument(2), accumulator);
 | 
			
		||||
      for (auto& nested : reduce_op.body().front().without_terminator()) {
 | 
			
		||||
        auto clone = rewriter.clone(nested, mapping);
 | 
			
		||||
        for (auto pair : llvm::zip(nested.getResults(), clone->getResults())) {
 | 
			
		||||
          mapping.map(std::get<0>(pair), std::get<1>(pair));
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      // Finally, insert the terminator for the launchOp.
 | 
			
		||||
      rewriter.setInsertionPointToEnd(&launch_op.body().front());
 | 
			
		||||
      rewriter.create<mlir::gpu::TerminatorOp>(loc);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    rewriter.eraseOp(reduce_op);
 | 
			
		||||
    return success();
 | 
			
		||||
  };
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct LhloLegalizeToGpu : public PassWrapper<LhloLegalizeToGpu, FunctionPass> {
 | 
			
		||||
  void runOnFunction() override {
 | 
			
		||||
    OwningRewritePatternList patterns;
 | 
			
		||||
    ConversionTarget target(getContext());
 | 
			
		||||
    target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
 | 
			
		||||
                           gpu::GPUDialect, scf::SCFDialect, XlaLhloDialect>();
 | 
			
		||||
    target.addIllegalOp<ReduceOp>();
 | 
			
		||||
    auto func = getFunction();
 | 
			
		||||
    patterns.insert<LhloReduceToGPULaunchConverter>(func.getContext());
 | 
			
		||||
    if (failed(applyPartialConversion(func, target, patterns))) {
 | 
			
		||||
      signalPassFailure();
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
 | 
			
		||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToGpuPass() {
 | 
			
		||||
  return absl::make_unique<LhloLegalizeToGpu>();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static PassRegistration<LhloLegalizeToGpu> legalize_pass(
 | 
			
		||||
    "lhlo-legalize-to-gpu", "Legalize from LHLO dialect to GPU dialect");
 | 
			
		||||
 | 
			
		||||
}  // namespace xla_lhlo
 | 
			
		||||
}  // namespace mlir

@@ -0,0 +1,136 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "third_party/llvm/llvm-project/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"

namespace mlir {
namespace xla_lhlo {
namespace {

struct StaticMemRefCastOpConverter
    : public ConvertOpToLLVMPattern<StaticMemRefCastOp> {
  using ConvertOpToLLVMPattern<StaticMemRefCastOp>::ConvertOpToLLVMPattern;

  LogicalResult matchAndRewrite(
      Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto cast_op = cast<StaticMemRefCastOp>(op);

    StaticMemRefCastOp::Adaptor operands_adaptor(operands);
    MemRefDescriptor sourceMemRef(operands_adaptor.operand());

    MemRefType targetMemRefType =
        cast_op.getResult().getType().cast<MemRefType>();
    auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMType>();
    if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
      return failure();
    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);

    // Fill size and stride descriptors in memref.
    auto target_sizes = targetMemRefType.getShape();
    int64_t target_offset;
    llvm::SmallVector<int64_t, 4> target_strides;
    if (failed((getStridesAndOffset(targetMemRefType, target_strides,
                                    target_offset))))
      return failure();

    // Copy offset of `targetMemRef`.
    desc.setConstantOffset(rewriter, loc, target_offset);
    for (int i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
      desc.setConstantSize(rewriter, loc, i, target_sizes[i]);
      desc.setConstantStride(rewriter, loc, i, target_strides[i]);
    }
    rewriter.replaceOp(op, {desc});
    return success();
  }
};
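// Reader's note (a standard MLIR-to-LLVM lowering detail, not specific to this
// change): the memref descriptor populated above is an LLVM struct of the form
//   {T* allocated, T* aligned, i64 offset, [rank x i64] sizes,
//    [rank x i64] strides}.
// The static variant above writes compile-time constants into the
// offset/size/stride fields, while the dynamic variant below forwards SSA
// values.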

struct DynamicMemRefCastOpConverter
    : public ConvertOpToLLVMPattern<DynamicMemRefCastOp> {
  using ConvertOpToLLVMPattern<DynamicMemRefCastOp>::ConvertOpToLLVMPattern;

  LogicalResult matchAndRewrite(
      Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto cast_op = cast<DynamicMemRefCastOp>(op);

    DynamicMemRefCastOp::Adaptor operands_adaptor(operands);
    MemRefDescriptor sourceMemRef(operands_adaptor.operand());

    MemRefType targetMemRefType =
        cast_op.getResult().getType().cast<MemRefType>();
    auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMType>();
    if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
      return failure();
    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Copy offset of `sourceMemRef`.
    desc.setOffset(rewriter, loc, sourceMemRef.offset(rewriter, loc));

    // Fill size and stride descriptors in memref.
    if (!cast_op.sizes().empty()) {
      auto sizes = operands_adaptor.sizes();
      auto strides = operands_adaptor.strides();
      for (int i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
        desc.setSize(rewriter, loc, i, sizes[i]);
        desc.setStride(rewriter, loc, i, strides[i]);
      }
    }
    rewriter.replaceOp(op, {desc});
    return success();
  }
};

}  // namespace

void PopulateLhloToLLVMConversionPatterns(const LowerToLLVMOptions &options,
                                          LLVMTypeConverter *converter,
                                          OwningRewritePatternList *patterns) {
  patterns->insert<DynamicMemRefCastOpConverter, StaticMemRefCastOpConverter>(
      *converter, options);
}

}  // namespace xla_lhlo
}  // namespace mlir

@@ -0,0 +1,59 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "third_party/llvm/llvm-project/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"

namespace mlir {
namespace xla_lhlo {
namespace {

class TestLhloToLLVMPass
    : public ::mlir::PassWrapper<TestLhloToLLVMPass,
                                 ::mlir::OperationPass<::mlir::ModuleOp>> {
 public:
  void runOnOperation() override {
    ModuleOp m = getOperation();

    OwningRewritePatternList patterns;
    LLVMTypeConverter converter(m.getContext());
    populateStdToLLVMConversionPatterns(converter, patterns);
    PopulateLhloToLLVMConversionPatterns(
        LowerToLLVMOptions::getDefaultOptions(), &converter, &patterns);

    ConversionTarget target(getContext());
    target.addLegalDialect<LLVM::LLVMDialect>();
    target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
    target.addIllegalDialect<XlaLhloDialect>();

    if (failed(applyFullConversion(m, target, patterns))) {
      signalPassFailure();
    }
  }
};
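// The registration below exposes the pass to opt-style test drivers; an
// illustrative invocation (driver and file names assumed, not part of this
// change) would be:
//   <opt-driver> -test-lhlo-legalize-to-llvm ops.mlir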

}  // namespace

static PassRegistration<TestLhloToLLVMPass> legalize_lhlo_pass(
    "test-lhlo-legalize-to-llvm", "Legalize from LHLO dialect to LLVM.");

}  // namespace xla_lhlo
}  // namespace mlir

@@ -0,0 +1,731 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "third_party/absl/memory/memory.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/SmallVector.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SCF/SCF.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"

namespace mlir {
namespace xla_lhlo {
namespace {

// Clones and adapts the code in `lhlo_block`, which works on buffers and has
// a single output buffer, so that it can be applied to `operands` whose types
// are the element types of the respective buffers. Returns the computed value.
//
// Example. For `operands` of (f32, i32) types and a block of LHLO ops with
// signature:
//   ^bb(%lhs: memref<f32>, %rhs: memref<i32>, %res: memref<i1>):
//     <LHLO_ops>
//
// this inserts the necessary alloc and store ops to compute and return a
// result of `i1` type.
Value ApplySingleResultLhloCode(Location loc, ValueRange operands,
                                Block* lhlo_block, OpBuilder* b) {
  SmallVector<Value, 2> arg_bufs;
  for (auto arg_type : lhlo_block->getArgumentTypes()) {
    arg_bufs.push_back(b->create<AllocOp>(loc, arg_type.cast<MemRefType>()));
  }
  for (auto operand : llvm::enumerate(operands)) {
    b->create<StoreOp>(loc, operand.value(), arg_bufs[operand.index()]);
  }
  // Clone the ops from `lhlo_block`.
  BlockAndValueMapping mapping;
  mapping.map(lhlo_block->getArguments(), arg_bufs);
  for (auto& nested : lhlo_block->without_terminator()) {
    auto clone = b->clone(nested, mapping);
    mapping.map(nested.getResults(), clone->getResults());
  }
  return b->create<LoadOp>(loc, arg_bufs.back());
}
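// For instance (types chosen for illustration), applied to two f32 operands
// and a block containing a single xla_lhlo.maximum, this allocates three
// memref<f32> buffers, stores both scalars, clones the maximum op onto the
// buffers, and loads the f32 result back out of the last buffer.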

// Converts a block with LHLO ops and with signature:
//   ^bb(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
// into a reduction operator of scf.reduce by doing buffer allocation for
// scalar arguments and the result of `scf.reduce` to make it compatible with
// LHLO ops.
void ConvertToReductionOperator(Location loc, scf::ReduceOp reduce_op,
                                Block* lhlo_block, OpBuilder* b) {
  Block& loop_reduce_op_body = reduce_op.reductionOperator().front();
  OpBuilder::InsertionGuard guard(*b);
  b->setInsertionPointToStart(&loop_reduce_op_body);
  b->create<scf::ReduceReturnOp>(
      loc, ApplySingleResultLhloCode(loc, loop_reduce_op_body.getArguments(),
                                     lhlo_block, b));
}

// Returns the result of a ConstantOp if `dim` is static; otherwise uses a
// DimOp to extract the dimension at runtime.
Value GetStaticOrDynamicDim(mlir::Location loc, Value shaped_value,
                            size_t dim_index, int64_t dim, OpBuilder* b) {
  return dim == ShapedType::kDynamicSize
             ? b->create<DimOp>(loc, shaped_value, dim_index).getResult()
             : b->create<ConstantIndexOp>(loc, dim);
}
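// E.g. (illustrative shape): for a memref<?x10xf32> value, dimension 0 is
// materialized with a DimOp at runtime, while dimension 1 folds to a
// `constant 10 : index`.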

struct MappedIvs {
  // False if the mapped indices are in the padding area, true otherwise.
  Value in_bounds;
  // Mapped indices.
  SmallVector<Value, 2> ivs;
};

template <typename OpTy>
MappedIvs MapWindowIvsToInput(OpTy op, ValueRange ivs, ValueRange window_ivs,
                              OpBuilder* b) {
  MappedIvs mapped_ivs;

  if (!op.window_strides().hasValue()) {
    op.emitOpError("No window strides specified.");
  }
  auto window_strides = op.window_strides().getValue();

  if (!op.padding().hasValue()) {
    op.emitOpError("No padding specified.");
  }
  auto padding = op.padding().getValue();

  auto loc = op.getLoc();
  auto operand = op.operand();
  auto operand_shape = operand.getType().template cast<MemRefType>().getShape();

  // `in_bounds` is false when the mapped indices are in the padding area.
  mapped_ivs.in_bounds = b->create<mlir::ConstantOp>(
      loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 1));
  for (unsigned i = 0, e = ivs.size(); i < e; ++i) {
    auto stride = window_strides.template getValue<llvm::APInt>(i);
    auto pad_low = padding.template getValue<llvm::APInt>({i, 0});

    Value stride_val = b->create<ConstantIndexOp>(loc, stride.getSExtValue());
    Value pad_low_val = b->create<ConstantIndexOp>(loc, pad_low.getSExtValue());

    Value center = b->create<MulIOp>(loc, ivs[i], stride_val);
    Value offset = b->create<SubIOp>(loc, window_ivs[i], pad_low_val);
    Value index = b->create<AddIOp>(loc, center, offset);
    Value upper_bound =
        GetStaticOrDynamicDim(loc, operand, i, operand_shape[i], b);
    // We must check whether 0 <= index_i < shape_i, as otherwise we are in
    // the pad and then we have to use the neutral element for reduction.
    // Equivalently, it can be computed as the unsigned comparison index_i <
    // shape_i, since a negative value wraps to a large positive value.
    mapped_ivs.in_bounds = b->create<mlir::AndOp>(
        loc, mapped_ivs.in_bounds,
        b->create<CmpIOp>(loc, CmpIPredicate::ult, index, upper_bound));
    mapped_ivs.ivs.push_back(index);
  }
  return mapped_ivs;
}
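// Worked example (numbers assumed for illustration): with stride 2, pad_low 1,
// output iv 5 and window iv 0, the mapped input index is 5 * 2 + 0 - 1 = 9;
// the element participates only if 9 `ult` shape[i], i.e. if it does not fall
// into the padding.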

// Returns an scf::ParallelOp over a shaped value with static or dynamic shape.
scf::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value,
                                  OpBuilder* b) {
  Value zero = b->create<ConstantIndexOp>(loc, 0);
  Value one = b->create<ConstantIndexOp>(loc, 1);

  ArrayRef<int64_t> shape =
      shaped_value.getType().cast<ShapedType>().getShape();
  SmallVector<Value, 2> lower, upper, step;
  for (auto dim : llvm::enumerate(shape)) {
    upper.push_back(
        GetStaticOrDynamicDim(loc, shaped_value, dim.index(), dim.value(), b));
    lower.push_back(zero);
    step.push_back(one);
  }
  return b->create<scf::ParallelOp>(loc, lower, upper, step);
}
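// For example (illustrative shape), for a value of type memref<2x?xf32> this
// builds
//   scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %dyn_dim) step (%c1, %c1)
// where %dyn_dim is the dynamically extracted size of dimension 1.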

// Converts `xla_lhlo.ReduceOp` into two scf::ParallelOp and a scf::ReduceOp.
// The outer `ParallelOp` refers to the parallel loops if there are
// any. The inner `ParallelOp` refers to the reduction loops and `ReduceOp`
// contains the reduction operator.
//
// Example:
//
//  "xla_lhlo.reduce"(%buffer, %init_buf, %result) ( {
//    ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
//      <LHLO ops>
//    } ) {dimensions = dense<[1]> : tensor<1xi64>}
//      : (memref<100x10x5xf32>, memref<f32>, memref<100x5xf32>) -> ()
//
//  is roughly converted into:
//
//  %init = load %init_buf[] : memref<f32>
//  scf.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) {
//    %result = scf.parallel (%j) = (%c0) to (%c10) step (%c1) init (%init) {
//      %elem_to_reduce = load %buffer[%i, %j, %k] : memref<100x10x5xf32>
//      scf.reduce(%elem_to_reduce)  {
//        ^bb0(%elem: f32, %acc: f32):   // no predecessors
//          elem_buf = alloc() : memref<f32>
//          store %elem, elem_buf[] : memref<f32>
//          acc_buf = alloc() : memref<f32>
//          store %acc, acc_buf[] : memref<f32>
//          <LHLO_ops>
//          %acc_result = load acc_buf[] : memref<f32>
//          scf.reduce.return %acc_result : f32
//      } : f32
//      scf.yield
//    } : f32
//    scf.yield
//  }
class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
 public:
  using OpConversionPattern<xla_lhlo::ReduceOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      xla_lhlo::ReduceOp xla_reduce_op, ArrayRef<Value> /*args*/,
      ConversionPatternRewriter& rewriter) const final {
    // TODO(b/137624192) Implement variadic reduce.
    if (xla_reduce_op.out().size() != 1) return failure();

    scf::ReduceOp reduce_op =
        CreateReduceOpInNestedParallelLoops(xla_reduce_op, &rewriter);
    ConvertToReductionOperator(xla_reduce_op.getLoc(), reduce_op,
                               &xla_reduce_op.body().front(), &rewriter);
    rewriter.replaceOp(xla_reduce_op, llvm::None);
    return success();
  }

 private:
  // Creates nested `scf.parallel` ops with `scf.reduce`. The outer ParallelOp
  // refers to the parallel dimensions of `xla_reduce_op` if any and the inner
  // ParallelOp refers to the reduction dimensions. The scf.reduce op is
  // returned.
  //
  // If the reduction argument is a memref<100x10x5xf32> and the
  // reduction is performed along dimension 1 then this method will generate
  //
  //  %init = load %init_buf[] : memref<f32>
  //  scf.parallel (%i, %k) = (%c0, %c0) to (%c100, %c5) step (%c1, %c1) {
  //    %result = scf.parallel (%j) = (%c0) to (%c10) step (%c1) init (%init) {
  //      %elem_to_reduce = load %buffer[%i, %j, %k] : memref<100x10x5xf32>
  //      scf.reduce(%elem_to_reduce)  {
  //        <THE BLOCK PTR TO BE RETURNED>
  //      } : f32
  //      scf.yield
  //    } : f32
  //    scf.yield
  //  }
  scf::ReduceOp CreateReduceOpInNestedParallelLoops(
      xla_lhlo::ReduceOp xla_reduce_op,
      ConversionPatternRewriter* rewriter) const {
    auto loc = xla_reduce_op.getLoc();
    DenseSet<int> reducing_dims;
    for (const auto& rdim : xla_reduce_op.dimensions().getIntValues()) {
      reducing_dims.insert(rdim.getSExtValue());
    }

    Value operand = *xla_reduce_op.operands().begin();
    Value out = *xla_reduce_op.out().begin();
    SmallVector<Value, 2> parallel_lower, parallel_upper, parallel_step;
    SmallVector<Value, 2> reduce_lower, reduce_upper, reduce_step;
    auto operand_shape = operand.getType().cast<MemRefType>().getShape();
    for (auto dim : llvm::enumerate(operand_shape)) {
      const bool is_reducing_dim = reducing_dims.count(dim.index());

      Value ub = GetStaticOrDynamicDim(loc, operand, dim.index(), dim.value(),
                                       rewriter);
      Value lb = rewriter->create<ConstantIndexOp>(loc, 0);
      Value step = rewriter->create<ConstantIndexOp>(loc, 1);
      (is_reducing_dim ? reduce_lower : parallel_lower).push_back(lb);
      (is_reducing_dim ? reduce_upper : parallel_upper).push_back(ub);
      (is_reducing_dim ? reduce_step : parallel_step).push_back(step);
    }
    // Load initial value from memref<element_type>.
    SmallVector<Value, 1> init_value = {
        rewriter->create<LoadOp>(loc, *xla_reduce_op.init_values().begin())};
    // Outer ParallelOp is not needed if it is a reduction across all dims.
    scf::ParallelOp outer;
    if (!parallel_lower.empty()) {
      outer = rewriter->create<scf::ParallelOp>(loc, parallel_lower,
                                                parallel_upper, parallel_step);
      rewriter->setInsertionPointToStart(outer.getBody());
    }
    scf::ParallelOp inner = rewriter->create<scf::ParallelOp>(
        loc, reduce_lower, reduce_upper, reduce_step, ValueRange(init_value));
    Value reduction_result = *inner.getResults().begin();

    SmallVector<Value, 1> out_indices;
    if (outer != nullptr) {
      out_indices.reserve(outer.getNumLoops());
      for (Value iv : outer.getInductionVars()) {
        out_indices.push_back(iv);
      }
    } else {
      out_indices.push_back(rewriter->create<ConstantIndexOp>(loc, 0));
    }

    rewriter->create<StoreOp>(loc, reduction_result, out, out_indices);

    // Load the element to reduce.
    SmallVector<Value, 2> indices;
    indices.reserve(operand_shape.size());

    if (outer) {
      auto inner_ivs_it = inner.getInductionVars().begin();
      auto outer_ivs_it = outer.getInductionVars().begin();
      for (unsigned i = 0, e = operand_shape.size(); i < e; ++i) {
        indices.push_back(reducing_dims.count(i) ? *inner_ivs_it++
                                                 : *outer_ivs_it++);
      }
    } else {
      indices = inner.getInductionVars();
    }

    rewriter->setInsertionPointToStart(inner.getBody());
    Value elem = rewriter->create<mlir::LoadOp>(
        loc, *xla_reduce_op.operands().begin(), indices);
    return rewriter->create<scf::ReduceOp>(loc, elem);
  }
};

// Pseudocode:
// for each index O in output
//   accumulator = neutral_value
//   in_bounds = true
//   for each index W in window
//     for each dimension i from 0 to rank - 1
//       index = O[i] * stride[i] + W[i] - pad_low[i]
//       in_bounds = in_bounds && (index `ult` shape[i])
//       I[i] = index
//     if (in_bounds)
//       value = input[I]
//     else
//       value = neutral_value
//     accumulator = reduction_operator(accumulator, value)
//   output[O] = accumulator
//
// Converts `xla_lhlo.ReduceWindowOp` into two scf::ParallelOp and a
// scf::ReduceOp.
// The outer `ParallelOp` refers to the parallel loops that traverse the output
// buffer. The inner `ParallelOp` refers to the reduction loops that traverse
// reduction windows and `ReduceOp` contains the reduction operator.
//
// Example:
//
// func @reduce_window(%arg: memref<112x112xf32>,
//              %init: memref<f32>,
//              %result: memref<56x56xf32>) {
//   "xla_lhlo.reduce_window"(%arg, %init, %result) ( {
//     ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
//       "xla_lhlo.maximum"(%lhs, %rhs, %res)
//         : (memref<f32>, memref<f32>, memref<f32>) -> ()
//       "xla_lhlo.terminator"() : () -> ()
//     }) {
//       padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
//       window_dimensions = dense<[3, 3]> : tensor<2xi64>,
//       window_strides = dense<[2, 2]> : tensor<2xi64>
//     } : (memref<112x112xf32>, memref<f32>, memref<56x56xf32>) -> ()
//   return
// }
//
// is roughly converted into:
//
//    %neutral_elem = load %init_buf[] : memref<f32>
//    scf.parallel (%i, %j) = (%c0, %c0) to (%c56, %c56) step (%c1, %c1) {
//      %result = scf.parallel (%iw, %jw) = (%c0, %c0)
//                  to (%c3, %c3) step (%c1, %c1) init (%neutral_elem) -> f32 {
//        %in_bounds = <COMPUTE IF INDEX IS IN OPERAND'S pad>
//        %elem = load %operand[%computed_i, %computed_j]
//        %elem_or_neutral = select %in_bounds, %elem, %neutral_elem : f32
//        scf.reduce(%elem_or_neutral)  : f32 {
//          ^bb0(%arg7: f32, %arg8: f32):
//            <LHLO ops>
//        }
//        scf.yield
//      }
//      store %result, %output_buffer[%i, %j] : memref<56x56xf32>
//      scf.yield
//    }
//    return
//  }
class ReduceWindowOpConverter
    : public OpConversionPattern<xla_lhlo::ReduceWindowOp> {
 public:
  using OpConversionPattern<xla_lhlo::ReduceWindowOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      xla_lhlo::ReduceWindowOp xla_reduce_window_op, ArrayRef<Value> /*args*/,
      ConversionPatternRewriter& rewriter) const final {
    scf::ParallelOp output_loop, window_loop;
    std::tie(output_loop, window_loop) =
        CreateParallelLoopsToTraverseOutputAndWindow(xla_reduce_window_op,
                                                     &rewriter);

    scf::ReduceOp reduce_op = CreateReduceOpInNestedParallelLoops(
        xla_reduce_window_op, output_loop, window_loop, &rewriter);

    ConvertToReductionOperator(xla_reduce_window_op.getLoc(), reduce_op,
                               &xla_reduce_window_op.body().front(), &rewriter);
    rewriter.replaceOp(xla_reduce_window_op, llvm::None);
    return success();
  }

 private:
  std::pair<scf::ParallelOp, scf::ParallelOp>
  CreateParallelLoopsToTraverseOutputAndWindow(
      xla_lhlo::ReduceWindowOp xla_reduce_window_op,
      ConversionPatternRewriter* rewriter) const {
    auto loc = xla_reduce_window_op.getLoc();
    Value init_value =
        rewriter->create<LoadOp>(loc, xla_reduce_window_op.init_value());

    Value zero = rewriter->create<ConstantIndexOp>(loc, 0);
    Value one = rewriter->create<ConstantIndexOp>(loc, 1);

    // Create an outer parallel loop that spans the output of ReduceWindowOp.
    Value xla_output = xla_reduce_window_op.out();
    auto output_loop = MakeLoopOverShape(loc, xla_output, rewriter);

    // Create a nested loop that traverses the window.
    SmallVector<Value, 2> window_lower, window_upper, window_step;
    rewriter->setInsertionPointToStart(output_loop.getBody());
    for (const auto& window_dim : xla_reduce_window_op.window_dimensions()) {
      window_step.push_back(one);
      window_lower.push_back(zero);
      window_upper.push_back(
          rewriter->create<ConstantIndexOp>(loc, window_dim.getSExtValue()));
    }
    auto window_loop = rewriter->create<scf::ParallelOp>(
        loc, window_lower, window_upper, window_step, ValueRange(init_value));

    Value reduction_result = *window_loop.getResults().begin();
    auto output_ivs = output_loop.getInductionVars();
    rewriter->create<StoreOp>(loc, reduction_result, xla_output, output_ivs);
    return std::make_pair(output_loop, window_loop);
  }

  scf::ReduceOp CreateReduceOpInNestedParallelLoops(
      xla_lhlo::ReduceWindowOp xla_reduce_window_op,
      scf::ParallelOp output_loop, scf::ParallelOp window_loop,
      ConversionPatternRewriter* rewriter) const {
    rewriter->setInsertionPointToStart(window_loop.getBody());
    auto loc = xla_reduce_window_op.getLoc();

    if (xla_reduce_window_op.base_dilations().hasValue() ||
        xla_reduce_window_op.window_dilations().hasValue()) {
      xla_reduce_window_op.emitRemark(
          "Lowering to parallel loops does not support `base_dilations` or "
          "`window_dilations` attributes yet. The attributes will be ignored.");
    }

    Value xla_operand = xla_reduce_window_op.operand();
    auto xla_operand_type = xla_operand.getType().cast<MemRefType>();

    // Compute ivs in 'arg' buffer and whether these ivs are in pad area or not.
    MappedIvs mapped_ivs = MapWindowIvsToInput(
        xla_reduce_window_op, output_loop.getInductionVars(),
        window_loop.getInductionVars(), rewriter);

    auto elem_or_init = rewriter->create<scf::IfOp>(
        loc, xla_operand_type.getElementType(), mapped_ivs.in_bounds,
        /*withElseRegion=*/true);

    OpBuilder then_builder = elem_or_init.getThenBodyBuilder();
    Value elem = then_builder.create<mlir::LoadOp>(
        loc, xla_reduce_window_op.operand(), mapped_ivs.ivs);
    then_builder.create<scf::YieldOp>(loc, elem);

    OpBuilder else_builder = elem_or_init.getElseBodyBuilder();
    else_builder.create<scf::YieldOp>(loc, *window_loop.initVals().begin());

    return rewriter->create<scf::ReduceOp>(loc,
                                           *elem_or_init.results().begin());
  }
};

// See the operation semantics in
// https://www.tensorflow.org/xla/operation_semantics#selectandscatter
//
// Pseudocode:
//  scf.parallel(coordinates O in the output):
//    output[O] = init
//  scf.parallel(coordinates S in the source):
//    selected_ivs = 0
//    selected_val = 0
//    initialized_flag = false
//    scf.for (first dim W_1 in the window)
//         iter_args (selected_ivs, selected_val, initialized_flag):
//    ...
//      scf.for (last dim W_N in the window):
//           iter_args (selected_ivs, selected_val, initialized_flag):
//        I = S * stride + W - pad_low
//        if I within bounds of operand:
//          if (initialized_flag):
//            pred = select(selected_value, operand(I))
//            if (pred)
//              selected_value = operand(I)
//              selected_index = I
//          else
//            selected_value = operand(I)
//            selected_index = I
//            initialized_flag = true
//    output(selected_index) = scatter(output(selected_index), source(S))
class SelectAndScatterOpConverter
    : public OpConversionPattern<xla_lhlo::SelectAndScatterOp> {
 public:
  using OpConversionPattern<xla_lhlo::SelectAndScatterOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      xla_lhlo::SelectAndScatterOp s_and_s_op, ArrayRef<Value> /*args*/,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = s_and_s_op.getLoc();
    InitializeOutput(s_and_s_op, &rewriter);
    scf::ParallelOp loop_over_src =
        MakeLoopOverShape(loc, s_and_s_op.source(), &rewriter);
    rewriter.setInsertionPointToStart(loop_over_src.getBody());

    // Compute indices of the selected element in the window.
    auto selected_ivs = SelectIvs(s_and_s_op, loop_over_src, &rewriter);

    // Load the source element at the current source coordinates S.
    auto src_elem = rewriter.create<LoadOp>(loc, s_and_s_op.source(),
                                            loop_over_src.getInductionVars());

    // Compute `out[selected_ivs] = scatter(out[selected_ivs], src_elem)`.
    auto rmw = rewriter.create<GenericAtomicRMWOp>(loc, s_and_s_op.out(),
                                                   selected_ivs);
    OpBuilder rmw_builder = OpBuilder::atBlockEnd(rmw.getBody());
    auto acc_result =
        ApplySingleResultLhloCode(loc, {src_elem, rmw.getCurrentValue()},
                                  &s_and_s_op.scatter().front(), &rmw_builder);
    rmw_builder.create<AtomicYieldOp>(loc, acc_result);

    rewriter.replaceOp(s_and_s_op, llvm::None);
    return success();
  }
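  // Note on the atomic update above (added explanation): overlapping windows
  // can select the same output element for different source points, and the
  // surrounding scf.parallel may execute those scatter updates concurrently,
  // so the read-modify-write goes through GenericAtomicRMWOp.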

 private:
  void InitializeOutput(xla_lhlo::SelectAndScatterOp s_and_s_op,
                        OpBuilder* b) const {
    auto loc = s_and_s_op.getLoc();
    Value init_value = b->create<LoadOp>(loc, s_and_s_op.init_value());

    scf::ParallelOp loop_over_output =
        MakeLoopOverShape(loc, s_and_s_op.out(), b);
    OpBuilder::InsertionGuard guard(*b);
    b->setInsertionPointToStart(loop_over_output.getBody());
    b->create<StoreOp>(loc, init_value, s_and_s_op.out(),
                       loop_over_output.getInductionVars());
  }

  struct WindowLoops {
    SmallVector<Value, 2> selected_ivs;
    SmallVector<Value, 2> window_ivs;
    scf::ForOp inner_loop;
  };
  WindowLoops InsertWindowLoops(xla_lhlo::SelectAndScatterOp s_and_s_op,
                                scf::ParallelOp loop_over_src,
                                OpBuilder* b) const {
    auto loc = s_and_s_op.getLoc();
    Value zero = b->create<ConstantIndexOp>(loc, 0);
    Value one = b->create<ConstantIndexOp>(loc, 1);

    auto element_type =
        s_and_s_op.out().getType().cast<MemRefType>().getElementType();
    auto rank = loop_over_src.getNumLoops();

    // `iter_args` = [iv_1, ..., iv_N, selected_value, is_initialized]
    SmallVector<Value, 4> iter_args(rank, zero);
    iter_args.push_back(b->create<mlir::ConstantOp>(
        loc, element_type, b->getFloatAttr(element_type, 0)));
    iter_args.push_back(b->create<mlir::ConstantOp>(
        loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 0)));

    // Create a nested loop that traverses the window.
    OpBuilder::InsertPoint ip;
    WindowLoops result;
    for (const auto& window_dim :
         s_and_s_op.window_dimensions()->getIntValues()) {
      Value upper = b->create<ConstantIndexOp>(loc, window_dim.getSExtValue());
      result.inner_loop =
          b->create<scf::ForOp>(loc, zero, upper, one, iter_args);
      if (b->getInsertionBlock() == loop_over_src.getBody()) {
        ip = b->saveInsertionPoint();
        result.selected_ivs = result.inner_loop.getResults().take_front(rank);
      } else {
        b->create<scf::YieldOp>(loc, result.inner_loop.getResults());
      }
      b->setInsertionPointToStart(result.inner_loop.getBody());
      iter_args = ValueRange{result.inner_loop.getRegionIterArgs()};
      result.window_ivs.push_back(result.inner_loop.getInductionVar());
    }
    b->restoreInsertionPoint(ip);
    return result;
  }

  // Adapter to store iteration arguments of sequential loops that perform
  // select in a window.
  class IterArgs {
   public:
    explicit IterArgs(ValueRange ivs_val_flag) : ivs_val_flag_(ivs_val_flag) {}
    IterArgs(ValueRange ivs, Value value, Value flag) {
      ivs_val_flag_ = ivs;
      ivs_val_flag_.push_back(value);
      ivs_val_flag_.push_back(flag);
    }

    ArrayRef<Value> to_vector() const { return ivs_val_flag_; }

    // Indices of the currently selected value.
    ArrayRef<Value> ivs() const { return to_vector().drop_back(2); }
    // Currently selected value w.r.t. select() function.
    Value value() const { return ivs_val_flag_.end()[-2]; }
    // i1 flag indicating whether value() and ivs() have been initialized.
    Value is_init() const { return ivs_val_flag_.back(); }

   private:
    // Vector that stores iv_1, ..., iv_N, value, init.
    SmallVector<Value, 4> ivs_val_flag_;
  };

  SmallVector<Value, 2> SelectIvs(xla_lhlo::SelectAndScatterOp s_and_s_op,
                                  scf::ParallelOp loop_over_src,
                                  OpBuilder* b) const {
    auto loc = s_and_s_op.getLoc();

    WindowLoops window_loops = InsertWindowLoops(s_and_s_op, loop_over_src, b);
    auto inner_loop_b =
        OpBuilder::atBlockEnd(window_loops.inner_loop.getBody());

    // Compute ivs in 'arg' buffer and whether these ivs are in the pad area.
    MappedIvs mapped_ivs =
        MapWindowIvsToInput(s_and_s_op, loop_over_src.getInductionVars(),
                            window_loops.window_ivs, &inner_loop_b);

    IterArgs ivs_val_flag(window_loops.inner_loop.getRegionIterArgs());

    auto if_in_bounds = inner_loop_b.create<scf::IfOp>(
        loc, window_loops.inner_loop.getResultTypes(), mapped_ivs.in_bounds,
        /*withElseRegion=*/true);

    // Case when we are inside boundaries of 'arg' and not in the pad area.
    {
      OpBuilder in_bounds_then_b = if_in_bounds.getThenBodyBuilder();
      auto select_or_init_results = SelectOrInitialize(
          s_and_s_op, mapped_ivs.ivs, &ivs_val_flag, &in_bounds_then_b);
      in_bounds_then_b.create<scf::YieldOp>(loc, select_or_init_results);
    }

    // Case when we are in the pad.
    {
      OpBuilder in_bounds_else_b = if_in_bounds.getElseBodyBuilder();
      in_bounds_else_b.create<scf::YieldOp>(loc, ivs_val_flag.to_vector());
    }

    inner_loop_b.create<scf::YieldOp>(loc, if_in_bounds.getResults());
    return window_loops.selected_ivs;
  }

  SmallVector<Value, 4> SelectOrInitialize(
      xla_lhlo::SelectAndScatterOp s_and_s_op, ArrayRef<Value> operand_ivs,
      IterArgs* ivs_val_flag, OpBuilder* b) const {
    auto loc = s_and_s_op.getLoc();
    Value true_i1 = b->create<mlir::ConstantOp>(
        loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 1));

    TypeRange iter_arg_types{ivs_val_flag->to_vector()};
    Value operand_elem =
        b->create<LoadOp>(loc, s_and_s_op.operand(), operand_ivs);
    auto if_init =
        b->create<scf::IfOp>(loc, iter_arg_types, ivs_val_flag->is_init(),
                             /*withElseRegion=*/true);
    // Init == true, i.e. the iter args are already initialized with a selected
    // element within the boundaries of the operand, so the select function has
    // to be computed here.
    {
      OpBuilder if_init_then_b = if_init.getThenBodyBuilder();

      auto& lhlo_select = s_and_s_op.select().front();
      Value pred =
          ApplySingleResultLhloCode(loc, {operand_elem, ivs_val_flag->value()},
                                    &lhlo_select, &if_init_then_b);

      auto if_pred = if_init_then_b.create<scf::IfOp>(loc, iter_arg_types, pred,
                                                      /*withElseRegion=*/true);

      // Pred == true, therefore pack newly selected ivs, val and init flag back
      // into iter_args and return.
      {
        OpBuilder if_pred_then_b = if_pred.getThenBodyBuilder();
        if_pred_then_b.create<scf::YieldOp>(
            loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector());
      }

      // Pred == false, therefore return old iter_args.
      {
        OpBuilder if_pred_else_b = if_pred.getElseBodyBuilder();
        if_pred_else_b.create<scf::YieldOp>(loc, ivs_val_flag->to_vector());
      }

      if_init_then_b.create<scf::YieldOp>(loc, if_pred.getResults());
    }
    // Init == false, i.e. only padding was visited before and this is the
    // first element within the boundaries of the operand.
    {
      OpBuilder if_init_else_b = if_init.getElseBodyBuilder();

      if_init_else_b.create<scf::YieldOp>(
          loc, IterArgs{operand_ivs, operand_elem, true_i1}.to_vector());
    }
    return if_init.getResults();
  }
};

struct LhloLegalizeToParallelLoops
    : public PassWrapper<LhloLegalizeToParallelLoops, FunctionPass> {
  void runOnFunction() override {
    auto func = getFunction();

    OwningRewritePatternList patterns;
    // clang-format off
    patterns.insert<
        ReduceOpConverter,
        ReduceWindowOpConverter,
        SelectAndScatterOpConverter
      >(func.getContext());
    // clang-format on

    ConversionTarget target(getContext());
    target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
                           scf::SCFDialect, XlaLhloDialect>();
    target.addIllegalOp<xla_lhlo::ReduceOp, xla_lhlo::ReduceWindowOp,
                        xla_lhlo::SelectAndScatterOp>();

    if (failed(applyPartialConversion(func, target, patterns))) {
      signalPassFailure();
    }
  }
};

}  // namespace

std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass() {
  return absl::make_unique<LhloLegalizeToParallelLoops>();
}

static PassRegistration<LhloLegalizeToParallelLoops> legalize_lhlo_pass(
    "lhlo-legalize-to-parallel-loops",
    "Legalize from LHLO dialect to parallel loops.");

}  // namespace xla_lhlo
}  // namespace mlir
 | 
			
		||||
@@ -0,0 +1,79 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements passes to convert complex operations to equivalent real
// value operations. This does not include removing complex values from
// function arguments or return types.

#include <cstddef>
#include <cstdint>
#include <iterator>
#include <numeric>

#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/PassRegistry.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"

using mlir::FunctionPass;
using mlir::OwningRewritePatternList;
using mlir::PassRegistration;
using mlir::PassWrapper;

namespace {
class LowerComplex : public PassWrapper<LowerComplex, FunctionPass> {
 public:
  explicit LowerComplex() : PassWrapper<LowerComplex, FunctionPass>() {}

  /// Performs the lowering to XLA dialect.
  void runOnFunction() override;
};
}  // end anonymous namespace

namespace mlir {
namespace xla {
namespace {

#include "third_party/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_lower_complex.inc"

}  // end anonymous namespace
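// Adds the TableGen-generated lowering patterns (from the generated include
// above) to the given pattern list.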
void PopulateComplexLoweringPatterns(MLIRContext* context,
                                     OwningRewritePatternList* patterns) {
  populateWithGenerated(context, patterns);
}
}  // end namespace xla
}  // end namespace mlir

// Lowers the complex operations that can be represented using other operations.
void LowerComplex::runOnFunction() {
  // Add lowering patterns to the list.
  OwningRewritePatternList patterns;
  mlir::xla::PopulateComplexLoweringPatterns(&getContext(), &patterns);

  applyPatternsAndFoldGreedily(getFunction(), patterns);
}

static PassRegistration<LowerComplex> pass(
    "test-xla-lower-complex",
    "Lower complex operations into non-complex operations");
@@ -0,0 +1,109 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// These are the legalization patterns that convert complex operations into
// equivalent real value operations.

include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OpBase.td"
include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td"
include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"

//===----------------------------------------------------------------------===//
// Binary op patterns.
//===----------------------------------------------------------------------===//

// Addition and subtraction are elementwise and can be distributed across the
// real and imaginary components.
foreach elementwiseOp = [HLO_AddOp, HLO_SubOp] in
  def : Pat<(elementwiseOp HLO_ComplexTensor:$lhs,
             HLO_ComplexTensor:$rhs),
            (HLO_ComplexOp
              (elementwiseOp (HLO_RealOp $lhs), (HLO_RealOp $rhs)),
              (elementwiseOp (HLO_ImagOp $lhs), (HLO_ImagOp $rhs)))>;

// Complex multiplication results in a cross product multiplication between the
// real and imaginary components such that:
//   result.real = lhs.real * rhs.real - lhs.imag * rhs.imag
//   result.imag = lhs.imag * rhs.real + lhs.real * rhs.imag
def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs,
           HLO_ComplexTensor:$rhs),
          (HLO_ComplexOp
           (HLO_SubOp
            (HLO_MulOp
             (HLO_RealOp:$lhs_real $lhs),
             (HLO_RealOp:$rhs_real $rhs)),
            (HLO_MulOp
             (HLO_ImagOp:$lhs_imag $lhs),
             (HLO_ImagOp:$rhs_imag $rhs))),
           (HLO_AddOp
            (HLO_MulOp $lhs_real, $rhs_imag),
            (HLO_MulOp $lhs_imag, $rhs_real)))>;

// Multiplication between a complex and a real tensor can be distributed by
// applying the real multiplicand to both the real and imaginary components.
//
// Note that the source pattern is not legal according to the HLO dialect;
// it instead handles intermediates generated by other patterns.
def : Pat<(HLO_MulOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs),
          (HLO_ComplexOp
           (HLO_MulOp (HLO_RealOp $lhs), $rhs),
           (HLO_MulOp (HLO_ImagOp $lhs), $rhs))>;

def : Pat<(HLO_MulOp HLO_IntOrFpTensor:$lhs, HLO_ComplexTensor:$rhs),
          (HLO_ComplexOp
           (HLO_MulOp $lhs, (HLO_RealOp $rhs)),
           (HLO_MulOp $lhs, (HLO_ImagOp $rhs)))>;


// Division is performed by normalizing the denominator by multiplying by the
// conjugate of the rhs.
//   numerator = lhs * conj(rhs)
//   denominator = rhs * conj(rhs)
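// In scalar terms, with lhs = a + ib and rhs = c + id (a sketch of the
// algebra, not an additional pattern):
//   lhs / rhs = (a + ib)(c - id) / ((c + id)(c - id))
//             = ((ac + bd) + i(bc - ad)) / (c^2 + d^2)
// The complex-times-complex numerator and the resulting complex-by-real
// division are lowered further by the multiplication patterns above and the
// division pattern below.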
def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_ComplexTensor:$rhs),
            (HLO_DivOp
             (HLO_MulOp:$num $lhs,
              (HLO_ComplexOp:$conj
               (HLO_RealOp $rhs),
               (HLO_NegOp (HLO_ImagOp $rhs)))),
             (HLO_RealOp:$den (HLO_MulOp $rhs, $conj)))>;


def : Pat<(HLO_DivOp HLO_ComplexTensor:$lhs, HLO_IntOrFpTensor:$rhs),
          (HLO_ComplexOp
           (HLO_DivOp (HLO_RealOp $lhs), $rhs),
           (HLO_DivOp (HLO_ImagOp $lhs), $rhs))>;


// Absolute value is evaluated as:
//   result = sqrt(val.real * val.real + val.imag * val.imag)
def : Pat<(HLO_AbsOp HLO_ComplexTensor:$val),
          (HLO_ComplexOp
           (HLO_SqrtOp
             (HLO_AddOp
              (HLO_MulOp (HLO_RealOp:$real $val), $real),
              (HLO_MulOp (HLO_ImagOp:$imag $val), $imag))),
           (HLO_ConstOp (ConstantSplat<"0"> $real)))>;

// Exponential can be lowered to an exponential on the real component and a
// sum of sinusoids of the imaginary component, which equates to a normal
// exponential operator multiplied by Euler's formula.
//
// Exp(a + ib) = Exp(a) * Exp(ib) = Exp(a) * (Cos(b) + iSin(b))
def : Pat<(HLO_ExpOp HLO_ComplexTensor:$val),
          (HLO_MulOp
           (HLO_ExpOp (HLO_RealOp $val)),
           (HLO_ComplexOp
            (HLO_CosOp (HLO_ImagOp:$imag $val)),
            (HLO_SinOp $imag)))>;
@@ -0,0 +1,194 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering XLA general dot to a regular dot.

#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringSwitch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/TypeUtilities.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"

using mlir::DenseIntElementsAttr;
using mlir::ElementsAttr;
using mlir::failure;
using mlir::FunctionPass;
using mlir::LogicalResult;
using mlir::MLIRContext;
using mlir::OpRewritePattern;
using mlir::OwningRewritePatternList;
using mlir::PassRegistration;
using mlir::PassWrapper;
using mlir::PatternRewriter;
using mlir::RankedTensorType;
using mlir::success;
using mlir::Value;

namespace {

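// Collapses the given dimensions of 'arg' into a rank-2 tensor: 'left_dims'
// form the rows and 'right_dims' form the columns. For example (a sketch): an
// argument of shape [2, 3, 4] with left_dims = [0, 2] and right_dims = [1] is
// transposed with permutation [0, 2, 1] to shape [2, 4, 3] and then reshaped
// to [8, 3].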
Value TransposeReshape(Value arg, mlir::Location loc,
                       llvm::ArrayRef<int64_t> left_dims,
                       llvm::ArrayRef<int64_t> right_dims,
                       llvm::ArrayRef<int64_t> arg_shape,
                       PatternRewriter *rewriter) {
  auto element_type = mlir::getElementTypeOrSelf(arg.getType());

  int64_t left_size = 1;
  for (auto dim : left_dims) {
    left_size *= arg_shape[dim];
  }

  int64_t right_size = 1;
  for (auto dim : right_dims) {
    right_size *= arg_shape[dim];
  }

  // Generate the transpose permutation attribute.
  llvm::SmallVector<int64_t, 5> transpose_permutation(left_dims.begin(),
                                                      left_dims.end());
  transpose_permutation.append(right_dims.begin(), right_dims.end());

  mlir::TensorType transpose_permutation_type = RankedTensorType::get(
      {static_cast<int64_t>(transpose_permutation.size())},
      rewriter->getIntegerType(64));

  auto transpose_permutation_attr =
      DenseIntElementsAttr::get(transpose_permutation_type,
                                llvm::makeArrayRef(transpose_permutation))
          .cast<DenseIntElementsAttr>();

  // Compute the resulting shape.
  llvm::SmallVector<int64_t, 5> transposed_shape;
  for (auto val : transpose_permutation) {
    transposed_shape.push_back(arg_shape[val]);
  }
  auto transpose_type = RankedTensorType::get(transposed_shape, element_type);
  auto transpose_result = rewriter->create<mlir::xla_hlo::TransposeOp>(
      loc, transpose_type, arg, transpose_permutation_attr);

  // Return the final result.
  auto reshaped_type =
      RankedTensorType::get({left_size, right_size}, element_type);
  return rewriter->create<mlir::xla_hlo::ReshapeOp>(loc, reshaped_type,
                                                    transpose_result);
}

Value ProcessDotArg(Value arg, mlir::Location loc,
                    ElementsAttr contract_dims_attr, bool outer_dims_first,
                    PatternRewriter *rewriter) {
  auto shape = arg.getType().cast<mlir::ShapedType>().getShape();

  llvm::SmallVector<bool, 5> is_outer_dim;
  is_outer_dim.resize(shape.size(), true);

  // Compute the contract dimension ordering.
  llvm::SmallVector<int64_t, 5> contract_dims;
  for (auto dim : contract_dims_attr.getValues<int64_t>()) {
    contract_dims.push_back(dim);
    is_outer_dim[dim] = false;
  }

  // Compute the outer dimension orderings.
  llvm::SmallVector<int64_t, 5> outer_dims;
  for (auto it : llvm::enumerate(is_outer_dim)) {
    if (it.value()) {
      outer_dims.push_back(it.index());
    }
  }

  if (outer_dims_first) {
    return TransposeReshape(arg, loc, outer_dims, contract_dims, shape,
                            rewriter);
  }

  return TransposeReshape(arg, loc, contract_dims, outer_dims, shape, rewriter);
}

struct GeneralDotConvert
    : public OpRewritePattern<mlir::xla_hlo::DotGeneralOp> {
  // Attempts to lower a General Dot operator to a standard Dot operator.
  // General dots include batching dimensions and can have collapsing
  // dimensions along any axis. Inserting correctly arranged transpose and
  // reshape operators organizes the tensors and allows the General Dot to be
  // replaced with the standard Dot operator.
  //
  // Note: This requires an empty list of batch dimensions.
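  //
  // For example (a sketch): contracting lhs dimension 1 of a tensor<5x6x7xf32>
  // with rhs dimension 0 of a tensor<6x8xf32> becomes
  //   lhs: transpose([0, 2, 1]) then reshape -> tensor<35x6xf32>
  //   dot(tensor<35x6xf32>, tensor<6x8xf32>) -> tensor<35x8xf32>
  //   reshape back -> tensor<5x7x8xf32>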

  explicit GeneralDotConvert(MLIRContext *context)
      : OpRewritePattern(context) {}

  LogicalResult matchAndRewrite(mlir::xla_hlo::DotGeneralOp op,
                                PatternRewriter &rewriter) const override {
    auto dot_element_type = mlir::getElementTypeOrSelf(op);

    auto dot_numbers = op.dot_dimension_numbers();
    if (dot_numbers.lhs_batching_dimensions().getNumElements() != 0 ||
        dot_numbers.rhs_batching_dimensions().getNumElements() != 0) {
      return failure();
    }

    auto lhs = ProcessDotArg(op.lhs(), op.getLoc(),
                             dot_numbers.lhs_contracting_dimensions(),
                             /*outer_dims_first=*/true, &rewriter);

    auto rhs = ProcessDotArg(op.rhs(), op.getLoc(),
                             dot_numbers.rhs_contracting_dimensions(),
                             /*outer_dims_first=*/false, &rewriter);

    // Dot resulting shape.
    auto lhs_shape = lhs.getType().cast<mlir::ShapedType>().getShape();
    auto rhs_shape = rhs.getType().cast<mlir::ShapedType>().getShape();
    auto new_dot_type =
        RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, dot_element_type);

    auto new_dot_op = rewriter.create<mlir::xla_hlo::DotOp>(
        op.getLoc(), new_dot_type, lhs, rhs, *(op.precision_config()));

    rewriter.replaceOpWithNewOp<mlir::xla_hlo::ReshapeOp>(op, op.getType(),
                                                          new_dot_op);
    return success();
  }
};

struct LegalizeGeneralDot
    : public PassWrapper<LegalizeGeneralDot, FunctionPass> {
  /// Lower all general dots that can be represented as a non-batched matmul.
  void runOnFunction() override {
    OwningRewritePatternList patterns;
    mlir::xla_hlo::PopulateGeneralDotOpLoweringPatterns(&patterns,
                                                        &getContext());
    applyPatternsAndFoldGreedily(getFunction(), patterns);
  }
};

}  // namespace

void mlir::xla_hlo::PopulateGeneralDotOpLoweringPatterns(
    OwningRewritePatternList *patterns, MLIRContext *ctx) {
  patterns->insert<GeneralDotConvert>(ctx);
}

static PassRegistration<LegalizeGeneralDot> legalize_pass(
    "test-xla-lower-general-dot",
    "Tests lowering general dot to a non-batched dot when possible");
@@ -0,0 +1,90 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <numeric>

#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"

namespace mlir {
namespace xla_hlo {

namespace {

// Converts ClampOp with broadcast semantics. ClampOp requires "all three arrays
// must be the same shape. Alternatively, as a restricted form of broadcasting,
// min and/or max can be a scalar of type T."
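// For example (a sketch): clamp(%min : tensor<f32>, %x : tensor<4xf32>,
// %max : tensor<f32>) is rewritten so that %min and %max are first broadcast
// to tensor<4xf32>, leaving a clamp whose operands all have the same shape.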
struct ClampWithBroadcastConvert : public OpRewritePattern<ClampOp> {
  explicit ClampWithBroadcastConvert(MLIRContext *context)
      : OpRewritePattern<ClampOp>(context) {}

  LogicalResult matchAndRewrite(ClampOp op,
                                PatternRewriter &rewriter) const override {
    auto operand_type = op.operand().getType().dyn_cast<RankedTensorType>();
    auto max_type = op.max().getType().dyn_cast<RankedTensorType>();
    auto min_type = op.min().getType().dyn_cast<RankedTensorType>();
    // Unranked types are not supported.
    if (!operand_type || !max_type || !min_type) return failure();
    // Operands with dynamic dimensions are not supported for now.
    if (!operand_type.hasStaticShape()) return failure();

    ArrayRef<int64_t> operand_shape = operand_type.getShape();

    Value max_value = op.max();
    if (max_type != operand_type) {
      assert(max_type.getRank() == 0);
      max_value = rewriter.createOrFold<BroadcastOp>(
          op.getLoc(), operand_type, max_value,
          rewriter.getI64TensorAttr(operand_shape));
    }

    Value min_value = op.min();
    if (min_type != operand_type) {
      assert(min_type.getRank() == 0);
      min_value = rewriter.createOrFold<BroadcastOp>(
          op.getLoc(), operand_type, min_value,
          rewriter.getI64TensorAttr(operand_shape));
    }

    rewriter.replaceOpWithNewOp<ClampOp>(op, op.getType(), min_value,
                                         op.operand(), max_value);
    return success();
  }
};

}  // namespace

void SetupMaterializeBroadcastsLegality(MLIRContext *context,
                                        ConversionTarget *conversionTarget) {
  conversionTarget->addDynamicallyLegalOp<ClampOp>([](ClampOp op) {
    return op.max().getType() == op.operand().getType() &&
           op.min().getType() == op.operand().getType();
  });
}

void PopulateMaterializeBroadcastsPatterns(MLIRContext *context,
                                           OwningRewritePatternList *patterns) {
  // ClampOp. This op has a special case where it accepts either same-shaped
  // inputs or scalars (a restricted form of broadcasting). This makes the
  // broadcast explicit.
  patterns->insert<ClampWithBroadcastConvert>(context);
}

}  // namespace xla_hlo
}  // namespace mlir
@@ -0,0 +1,58 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"

namespace mlir {
namespace xla_hlo {

namespace {

struct TestMaterializeBroadcastsPass
    : public PassWrapper<TestMaterializeBroadcastsPass, FunctionPass> {
  void runOnFunction() override {
    ConversionTarget conversionTarget(getContext());
    OwningRewritePatternList conversionPatterns;

    // Consider the xla_hlo dialect legal for tests.
    conversionTarget.addLegalDialect<XlaHloDialect>();
    // The conversion uses helpers from the Standard dialect.
    conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();

    SetupMaterializeBroadcastsLegality(&getContext(), &conversionTarget);
    PopulateMaterializeBroadcastsPatterns(&getContext(), &conversionPatterns);

    if (failed(applyPartialConversion(getFunction(), conversionTarget,
                                      conversionPatterns))) {
      return signalPassFailure();
    }
  }
};

}  // namespace

}  // namespace xla_hlo
}  // namespace mlir

static mlir::PassRegistration<mlir::xla_hlo::TestMaterializeBroadcastsPass>
    pass("test-xla-materialize-broadcasts",
         "Test pass for materializing 'broadcast_dimensions' attributes");
@@ -0,0 +1,85 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/DenseMap.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/Casting.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/PassManager.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LLVM.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/RegionUtils.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"

namespace mlir {
namespace xla_hlo {

namespace {

// A pass that sinks constants implicitly captured in control flow regions. This
// is necessary to export to XLA.
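// For example (a sketch): a constant defined above a WhileOp but used inside
// its body is cloned into the entry block of that body (or simply moved there
// if it has a single use), so the region no longer captures it from above.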
class SinkConstantsToControlFlow
    : public mlir::PassWrapper<SinkConstantsToControlFlow, FunctionPass> {
  void runOnFunction() override {
    getFunction().walk([](Operation* op) {
      if (auto while_op = llvm::dyn_cast<WhileOp>(op)) {
        SinkToRegion(&while_op.body());
        SinkToRegion(&while_op.cond());
      } else if (auto if_op = llvm::dyn_cast<IfOp>(op)) {
        SinkToRegion(&if_op.true_branch());
        SinkToRegion(&if_op.false_branch());
      }
    });
  }

 private:
  // Performs constant sinking into a region.
  static void SinkToRegion(Region* region) {
    llvm::DenseMap<Value, ConstOp> sunk_constant;
    visitUsedValuesDefinedAbove({*region}, [&](OpOperand* use) {
      Value constant = use->get();
      auto const_op = dyn_cast_or_null<ConstOp>(constant.getDefiningOp());
      if (!const_op) return;
      auto map_entry = sunk_constant.try_emplace(constant, nullptr);
      if (!map_entry.second) {
        // This constant has already been cloned into the region, reuse it.
        use->set(map_entry.first->getSecond().getResult());
        if (constant.use_empty()) const_op.erase();
        return;
      }
      if (constant.hasOneUse()) {
        const_op.getOperation()->moveBefore(&region->front().front());
        return;
      }
      map_entry.first->getSecond() = const_op.clone();
      region->front().getOperations().insert(region->front().begin(),
                                             map_entry.first->getSecond());
      use->set(map_entry.first->getSecond().getResult());
    });
  }
};

static mlir::PassRegistration<SinkConstantsToControlFlow> pass(
    "xla-hlo-sink-constants-to-control-flow",
    "Sink constants implicitly captured in control flow regions. This is "
    "necessary to export to XLA.");

}  // anonymous namespace

std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass() {
  return std::make_unique<SinkConstantsToControlFlow>();
}

}  // namespace xla_hlo
}  // namespace mlir
@@ -0,0 +1,100 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Identifier.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/OperationSupport.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Interfaces/InferTypeOpInterface.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"

namespace mlir {
namespace xla {
namespace {

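// Rewrites 'xla_test.get_return_type_components' into a pass-through
// 'xla_test.return_type_components' op annotated with dimsN / element_typeN
// attributes recording what the defining op's InferShapedTypeOpInterface
// inferred for each result component.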
struct InferReturnTypeComponentsPattern : public RewritePattern {
  InferReturnTypeComponentsPattern(MLIRContext *context)
      : RewritePattern("xla_test.get_return_type_components", 1, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumOperands() != 1) return failure();
    auto defining_op = op->getOperand(0).getDefiningOp();
    auto defining_op_int =
        llvm::dyn_cast_or_null<InferShapedTypeOpInterface>(defining_op);
    if (!defining_op_int) return failure();
    SmallVector<ShapedTypeComponents, 4> components;
    if (failed(defining_op_int.inferReturnTypeComponents(
            op->getContext(), op->getLoc(), defining_op->getOperands(),
            defining_op->getAttrDictionary(), defining_op->getRegions(),
            components))) {
      return failure();
    }

    // Replace the op with another pass-through op with attributes added.
    OperationState state(op->getLoc(), "xla_test.return_type_components",
                         op->getOperands(), op->getResultTypes(),
                         op->getAttrs());
    auto new_op = rewriter.createOperation(state);
    for (auto it : llvm::enumerate(components)) {
      if (it.value().hasRank()) {
        new_op->setAttr((StringRef("dims") + Twine(it.index())).str(),
                        rewriter.getI64ArrayAttr(it.value().getDims()));
      }
      if (it.value().getElementType()) {
        new_op->setAttr((Twine("element_type") + Twine(it.index())).str(),
                        TypeAttr::get(it.value().getElementType()));
      }
    }
    rewriter.replaceOp(op, {new_op->getResults()});
    return success();
  }
};

struct ReifyReturnTypeShapesPattern : public RewritePattern {
  ReifyReturnTypeShapesPattern(MLIRContext *context)
      : RewritePattern("xla_test.reify_return_type_shapes", 1, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (op->getNumOperands() != 1) return failure();
    auto defining_op = llvm::dyn_cast_or_null<InferShapedTypeOpInterface>(
        op->getOperand(0).getDefiningOp());
    if (!defining_op) return failure();
    SmallVector<Value, 4> return_shapes;
    if (failed(defining_op.reifyReturnTypeShapes(rewriter, return_shapes))) {
      return failure();
    }
    rewriter.replaceOp(op, return_shapes);
    return success();
  }
};

struct TestInferShapedTypeMethodsPass
    : public PassWrapper<TestInferShapedTypeMethodsPass, FunctionPass> {
  void runOnFunction() override {
    OwningRewritePatternList patterns;
    patterns.insert<ReifyReturnTypeShapesPattern>(&getContext());
    patterns.insert<InferReturnTypeComponentsPattern>(&getContext());
    applyPatternsAndFoldGreedily(getFunction(), patterns);
  }
};

}  // namespace
}  // namespace xla
}  // namespace mlir

static mlir::PassRegistration<mlir::xla::TestInferShapedTypeMethodsPass> pass(
    "test-xla-infer-shaped-type-methods",
    "Uses test ops to invoke InferShapedTypeOpInterface methods");
@@ -0,0 +1,184 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/SmallVector.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"

namespace mlir {
namespace xla_hlo {

namespace {

// Broadcasts the 1D value tensor 'value_1d' to the shape of 'result_type'. If
// 'shape_value' is initialized, creates a dynamic broadcast, otherwise creates
// a static broadcast.
Value BroadcastToFeatureDim(Location loc, RankedTensorType result_type,
                            Value value_1d, Value shape_value,
                            int64_t feature_dim,
                            PatternRewriter& rewriter) {  // NOLINT
  Builder b(rewriter.getContext());
  auto dims_type = RankedTensorType::get({1}, b.getIntegerType(64));
  auto dims = DenseIntElementsAttr::get(dims_type, {feature_dim});
  if (shape_value) {
    return rewriter.createOrFold<xla_hlo::DynamicBroadcastInDimOp>(
        loc, result_type, value_1d, shape_value, dims);
  }
  assert(result_type.hasStaticShape());
  return rewriter.create<xla_hlo::BroadcastInDimOp>(loc, result_type, value_1d,
                                                    dims);
}

// Calculates the shape value of 'operand', assuming it has a dynamic shape
// but static rank.
Value CalculateShapeValue(Location loc, Value operand,
                          PatternRewriter& rewriter) {  // NOLINT
  RankedTensorType result_type = operand.getType().dyn_cast<RankedTensorType>();
  llvm::SmallVector<Value, 4> shape_values;
  int64_t rank = result_type.getRank();
  shape_values.reserve(rank);
  for (int64_t i = 0; i < rank; ++i) {
    shape_values.push_back(rewriter.create<mlir::DimOp>(loc, operand, i));
  }
  return rewriter.create<TensorFromElementsOp>(loc, shape_values);
}
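// Materializes 'epsilon_attr' as a constant broadcast to 'broadcast_to_type',
// first converting the attribute to 'fp_type' if their types differ (e.g. a
// 64-bit epsilon attribute when the variance element type is f32).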
Value MaterializeEpsilon(Operation* op, FloatAttr epsilon_attr,
                         FloatType fp_type, Value variance,
                         RankedTensorType broadcast_to_type,
                         PatternRewriter& rewriter) {  // NOLINT
  Builder b(rewriter.getContext());
  if (epsilon_attr.getType() != fp_type) {
    // Need to convert.
    bool loses_info;
    APFloat epsilon_float = epsilon_attr.getValue();
    auto status = epsilon_float.convert(
        fp_type.getFloatSemantics(), APFloat::rmNearestTiesToEven, &loses_info);
    if ((status & (~APFloat::opInexact)) != APFloat::opOK) {
      op->emitWarning() << "Could not convert batch_norm epsilon to target fp "
                           "type: opStatus = "
                        << static_cast<int>(status);
      return nullptr;
    }
    if (loses_info) {
      op->emitWarning("Conversion of epsilon loses precision");
    }
    epsilon_attr = b.getFloatAttr(fp_type, epsilon_float);
  }

  auto scalar_type = RankedTensorType::get({}, fp_type);
  auto epsilon_tensor_attr =
      DenseElementsAttr::get(scalar_type, {epsilon_attr.cast<Attribute>()});
  Value epsilon =
      rewriter.create<xla_hlo::ConstOp>(op->getLoc(), epsilon_tensor_attr);
  auto dims_type = RankedTensorType::get({0}, b.getIntegerType(64));
  auto dims = DenseIntElementsAttr::get(dims_type, SmallVector<int64_t, 1>{});
  if (broadcast_to_type.hasStaticShape()) {
    return rewriter.create<xla_hlo::BroadcastInDimOp>(
        op->getLoc(), broadcast_to_type, epsilon, /*broadcast_dims=*/dims);
  }
  Value shape_value = CalculateShapeValue(op->getLoc(), variance, rewriter);
  return rewriter.createOrFold<xla_hlo::DynamicBroadcastInDimOp>(
      op->getLoc(), broadcast_to_type, epsilon, shape_value,
      /*broadcast_dims=*/dims);
}
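// Rewrites xla_hlo.batch_norm_inference into primitive math ops. The algebra
// (a sketch) is:
//   output = (operand - broadcast(mean)) * broadcast(scale)
//              / broadcast(sqrt(variance + epsilon)) + broadcast(offset)
// where the 1-D parameters are broadcast along the feature_index dimension.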
class UnfuseBatchNormInferencePattern
    : public OpRewritePattern<xla_hlo::BatchNormInferenceOp> {
 public:
  using OpRewritePattern<xla_hlo::BatchNormInferenceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(xla_hlo::BatchNormInferenceOp bn_op,
                                PatternRewriter& rewriter) const override {
    // Enforce type invariants.
    // Note that we deduce the actual element type from the variance,
    // which should not be subject to quantization at a higher level.
    auto input_type = bn_op.operand().getType().dyn_cast<RankedTensorType>();
    auto variance_type =
        bn_op.variance().getType().dyn_cast<RankedTensorType>();
    if (!input_type || !variance_type) {
      return failure();
    }
    auto fp_type = variance_type.getElementType().dyn_cast<FloatType>();
    if (!fp_type) {
      return failure();
    }
    int64_t feature_dim = bn_op.feature_index().getSExtValue();

    // Add epsilon to the variance and sqrt to get stddev:
    // stddev = sqrt(variance + epsilon)
    auto epsilon =
        MaterializeEpsilon(bn_op.getOperation(), bn_op.epsilonAttr(), fp_type,
                           bn_op.variance(), variance_type, rewriter);
    if (!epsilon) {
      return failure();
    }
    Value stddev = rewriter.create<xla_hlo::AddOp>(bn_op.getLoc(),
                                                   bn_op.variance(), epsilon);
    stddev = rewriter.create<xla_hlo::SqrtOp>(bn_op.getLoc(), stddev);

    // Broadcast all terms.
    Value shape_value;
    if (!input_type.hasStaticShape()) {
      shape_value =
          CalculateShapeValue(bn_op.getLoc(), bn_op.operand(), rewriter);
    }
    auto broadcast_scale =
        BroadcastToFeatureDim(bn_op.getLoc(), input_type, bn_op.scale(),
                              shape_value, feature_dim, rewriter);
    auto broadcast_offset =
        BroadcastToFeatureDim(bn_op.getLoc(), input_type, bn_op.offset(),
                              shape_value, feature_dim, rewriter);
    auto broadcast_mean =
        BroadcastToFeatureDim(bn_op.getLoc(), input_type, bn_op.mean(),
                              shape_value, feature_dim, rewriter);
    auto broadcast_stddev = BroadcastToFeatureDim(
        bn_op.getLoc(), input_type, stddev, shape_value, feature_dim, rewriter);

    // Compute:
    // scale * (input - mean) / stddev + offset
    Value result = rewriter.create<xla_hlo::SubOp>(
        bn_op.getLoc(), bn_op.operand(), broadcast_mean);
    result = rewriter.create<xla_hlo::MulOp>(bn_op.getLoc(), result,
                                             broadcast_scale);
    result = rewriter.create<xla_hlo::DivOp>(bn_op.getLoc(), result,
                                             broadcast_stddev);
    rewriter.replaceOpWithNewOp<xla_hlo::AddOp>(bn_op, result,
                                                broadcast_offset);

    return success();
  }
};

}  // namespace

// Populates conversion patterns to unfuse batch normalization operations.
// In combination with marking such ops as illegal, this allows backends that
// do not have special support for fused batchnorm to use simpler arithmetic
// primitives.
void PopulateUnfuseBatchNormPatterns(MLIRContext* context,
                                     OwningRewritePatternList* patterns) {
  patterns->insert<UnfuseBatchNormInferencePattern>(context);
}

}  // namespace xla_hlo
}  // namespace mlir
@@ -0,0 +1,46 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"

namespace mlir {
namespace xla_hlo {

namespace {

struct TestUnfuseBatchNormPass
    : public PassWrapper<TestUnfuseBatchNormPass, OperationPass<>> {
  void runOnOperation() override {
    OwningRewritePatternList patterns;
    PopulateUnfuseBatchNormPatterns(&getContext(), &patterns);
    applyPatternsAndFoldGreedily(getOperation(), patterns);
  }
};

}  // namespace

}  // namespace xla_hlo
}  // namespace mlir

static mlir::PassRegistration<mlir::xla_hlo::TestUnfuseBatchNormPass> pass(
    "test-xla-unfuse-batch-norm",
    "Test pass for unfusing batch norm operations");
@@ -0,0 +1,579 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
#include "mlir/IR/MLIRContext.h"              // TF:llvm-project
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"                   // TF:llvm-project
#include "mlir/Transforms/RegionUtils.h"      // TF:llvm-project
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/EquivalenceClasses.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/cycle_detector.h"

// This pass has similar functionality to the fusion pass in the XLA stack.
// However, unlike XLA, it targets the fully dynamic shape scenario.
// Currently, it implements the kLoop and kInput fusion templates.
// During conversion, it tries to greedily find kLoop/kInput fusion
// patterns.
//
// Similar to XLA, this pass supports fusion patterns that have multiple
// outputs, as long as all the output shapes are consistent. Below are some
// examples.
//
//        kLoop                          kInput
// +----+  +----+  +----+    +----+    +----+    +----+
// |elem|  |elem|  |elem|    |elem<----+elem+---->elem+----+
// +-+--+  +-+--+  +-+--+    +-+--+    +----+    +-+--+    |
//   |       |       |         |                   |       |
//   |               |         |                   |       |
// +-v--+    |     +-v--+   +--v---+            +--v---+   |
// |elem+<---+----<+elem|   |reduce|            |reduce|   |
// +-+--+          +-+--+   +--+---+            +--+---+   |
//   |               |         |                   |       |
//   |               |         |                   |       |
//   v               v         v                   v       v
//
// To this end, we also add a simple shape constraint analysis phase.
// The kLoop fusion template requires all the outputs of the fused
// pattern to have the same shape. However, in the dynamic shape world we
// don't know the actual value of the shape at compile time.
// Fortunately, we can still infer the relationship among different ops
// according to their shape constraint traits. Currently, we only consider
// shape equality propagation for elementwise ops (assuming that implicit
// shape broadcast is forbidden). The above process could be built on the
// shape dialect once it is ready.
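//
// For intuition, a sketch of the kind of IR this pass aims to produce
// (illustrative only; exact op spellings and types depend on the input
// module):
//
//   %r = "xla_hlo.fusion"(%a, %b) ( {
//     %0 = xla_hlo.add %a, %b : tensor<?xf32>
//     %1 = xla_hlo.multiply %0, %b : tensor<?xf32>
//     "xla_hlo.return"(%1) : (tensor<?xf32>) -> ()
//   }) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>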

namespace mlir {
namespace xla_hlo {
namespace {

using llvm::EquivalenceClasses;
using FusionPattern = std::vector<Operation*>;
using FusionPlan = std::vector<FusionPattern>;

// To support using EquivalenceClasses for Value.
class ValueWrapper {
 public:
  explicit ValueWrapper(Value value) : value_(std::move(value)) {}

  Value getValue() const { return value_; }

  bool operator==(const ValueWrapper& rhs) const {
    return getValue() == rhs.getValue();
  }

 private:
  Value value_;
};

bool operator<(const ValueWrapper& lhs, const ValueWrapper& rhs) {
  auto lhs_value = lhs.getValue().getAsOpaquePointer();
  auto rhs_value = rhs.getValue().getAsOpaquePointer();
  return lhs_value < rhs_value;
}

bool IsFusible(Operation* op) {
  if (matchPattern(op, m_Constant())) {
    return true;
  }
  auto op_fusibility = dyn_cast<InferFusibilityOpInterface>(op);
  return op_fusibility && (op_fusibility.isFusibleWithOperand() ||
                           op_fusibility.isFusibleWithConsumer());
}

SmallVector<Value, 4> GetInputsOfFusionPattern(const FusionPattern& pattern) {
  SmallVector<Value, 4> inputs;
  DenseSet<Value> input_set;
  DenseSet<Operation*> op_set;
  for (Operation* op : pattern) {
    bool inserted = op_set.insert(op).second;
    (void)inserted;
    assert(inserted && "FusionPattern contains duplicate operations");
  }

  for (Operation* op : pattern) {
    for (Value operand : op->getOperands()) {
      Operation* operand_op = operand.getDefiningOp();
      if (op_set.find(operand_op) != op_set.end()) {
        // Skip if the defining op is inside the pattern.
        continue;
      }
      if (input_set.insert(operand).second) {
        inputs.push_back(operand);
      }
    }
  }
  return inputs;
}

SmallVector<Value, 4> GetOutputsOfFusionPattern(const FusionPattern& pattern) {
  SmallVector<Value, 4> outputs;
  DenseSet<Operation*> op_set;
  for (Operation* op : pattern) {
    bool inserted = op_set.insert(op).second;
    (void)inserted;
    assert(inserted && "FusionPattern contains duplicate operations");
  }

  for (Operation* op : pattern) {
    for (Value result : op->getResults()) {
      bool has_external_user = llvm::any_of(
          result.getUses(),
          [&](OpOperand& use) { return !op_set.count(use.getOwner()); });
      if (has_external_user) {
        outputs.push_back(result);
      }
    }
  }
  return outputs;
}

FusionPattern MergeFusionPattern(const FusionPattern& lhs,
                                 const FusionPattern& rhs) {
  FusionPattern pattern(lhs);
  pattern.insert(pattern.end(), rhs.begin(), rhs.end());
  return pattern;
}

inline int EffectiveSize(const FusionPattern& pattern) {
  return llvm::count_if(
      pattern, [](Operation* op) { return !matchPattern(op, m_Constant()); });
}

// This is a simple shape constraint analysis, which is used to
// guide fusion decisions (e.g. we only fuse shape-compatible ops).
//
// Currently, we only consider shape equality propagation based
// on the shape constraint traits of elementwise ops (assuming that
// implicit shape broadcast is forbidden).
class ShapeConstraintAnalysis {
 public:
  explicit ShapeConstraintAnalysis(const SmallVectorImpl<Operation*>& op_list) {
    PropagateEquality(op_list);
  }

  // Returns true if `lhs` and `rhs` are supposed to have the same shape.
  bool HasSameShape(Value lhs, Value rhs) {
    return impl_.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs));
  }

 private:
  // Shape equality propagation based on the shape constraints of
  // elementwise ops.
  void PropagateEquality(const SmallVectorImpl<Operation*>& op_list) {
    bool converged = true;
    do {
      converged = true;
      auto update = [&](Value lhs, Value rhs) {
        if (!impl_.isEquivalent(ValueWrapper(lhs), ValueWrapper(rhs))) {
          converged = false;
          impl_.unionSets(ValueWrapper(lhs), ValueWrapper(rhs));
        }
      };
      for (Operation* op : op_list) {
        auto op_fusibility = dyn_cast<InferFusibilityOpInterface>(op);
        if (!op_fusibility) continue;
        int numInput = op->getNumOperands();
        int numOutput = op->getNumResults();
        // Shape equality propagation between inputs.
        for (int input1 = 0; input1 < numInput; ++input1)
          for (int input2 = input1 + 1; input2 < numInput; ++input2)
            if (op_fusibility.inferInputsShapeEquality(input1, input2))
              update(op->getOperand(input1), op->getOperand(input2));

        // Shape equality propagation between outputs.
        for (int output1 = 0; output1 < numOutput; ++output1)
          for (int output2 = output1 + 1; output2 < numOutput; ++output2)
            if (op_fusibility.inferOutputsShapeEquality(output1, output2))
              update(op->getResult(output1), op->getResult(output2));

        // Shape equality propagation between inputs and outputs.
        for (int input = 0; input < numInput; ++input)
          for (int output = 0; output < numOutput; ++output)
            if (op_fusibility.inferInputOutputShapeEquality(input, output))
              update(op->getOperand(input), op->getResult(output));
      }
    } while (!converged);
  }

  // A UnionFind set.
  EquivalenceClasses<ValueWrapper> impl_;
};

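// As a small worked example of the propagation above: for an elementwise op
// %c = xla_hlo.add %a, %b (assuming it implements
// InferFusibilityOpInterface), the interface reports input/input and
// input/output shape equality, so {%a, %b, %c} end up in one equivalence
// class and HasSameShape returns true for any pair of them.
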
// A fusion planner that can propose a fusion plan for a block of ops.
// The fusion plan consists of a group of fusion patterns.
//
// Currently all proposed patterns follow the XLA kLoop/kInput-style fusion
// templates, adapted to the fully dynamic shape world.
//
// The kLoop fusion template satisfies:
//   - all ops in the fusion pattern are element-wise.
//   - all the output shapes of the fusion pattern are the same, and thus can
//     fit into the same parallel loop.
//
// The kInput fusion template satisfies:
//   - any op in the fusion pattern is either element-wise or a reduction.
//   - if an op is a reduction, its output cannot be consumed by other
//     ops in the same fusion pattern.
//   - all the effective shapes of the outputs of the fusion pattern are the
//     same.
//     - For an element-wise op, its effective shape is its output shape.
//     - For a reduction op, its effective shape is its operand shape.
class FusionPlanner {
 public:
  explicit FusionPlanner(const SmallVectorImpl<Operation*>& op_list)
      : op_list_(op_list),
        shape_analysis_(op_list),
        cycle_detector_(op_list.size()) {
    BuildNodeMap();
  }

  // Returns a fusion plan on success, otherwise none.
  llvm::Optional<FusionPlan> Run() {
    // Greedily search for connected fusible patterns; ops belonging to
    // the same fusion pattern are grouped into a cluster.
    RunEdgeContractionLoop();

    // After edge contraction, each unique cluster with size greater
    // than one represents a potential fusion pattern.
    // We collect all these clusters and construct a fusion plan.
    //
    // Note that the ops in a fusion pattern are in topological order.
    FusionPlan plan;
    DenseMap<int, int> pattern_ids;
    for (Operation* op : op_list_) {
      Cluster* cluster = GetClusterForNode(op);
      int node_id = cluster->cycles_graph_node_id();
      if (!IsFusible(op_list_[node_id]) ||
          EffectiveSize(GetClusterForNode(op)->fused_pattern()) <= 1) {
        continue;
      }
      if (!pattern_ids.count(node_id)) {
        int pattern_id = pattern_ids.size();
        pattern_ids[node_id] = pattern_id;
        plan.emplace_back();
      }
      plan[pattern_ids[node_id]].push_back(op);
    }
    return plan;
  }

  // Returns the op_list this planner operates on.
  const SmallVectorImpl<Operation*>& op_list() const { return op_list_; }

 private:
  // Represents a (partial) fused pattern.
  class Cluster {
   public:
    Cluster(int node_id, FusionPlanner* planner) : node_id_(node_id) {
      const SmallVectorImpl<Operation*>& op_list = planner->op_list();
      pattern_.push_back(op_list[node_id]);
    }

    // Merges `other` into this cluster, and clears `other`.
    void Merge(Cluster* other) {
      pattern_.insert(pattern_.end(), other->pattern_.begin(),
                      other->pattern_.end());
      other->pattern_.clear();
    }

    // The number of nodes in this cluster.
    int cluster_size() const { return pattern_.size(); }

    // The ID of the cluster as represented in `cycle_detector_`.
    int cycles_graph_node_id() const { return node_id_; }

    // Sets the ID of the cluster as represented in `cycle_detector_`.
    void set_cycles_graph_node_id(int cycles_graph_node_id) {
      node_id_ = cycles_graph_node_id;
    }

    // The fused pattern this cluster currently holds.
    const FusionPattern& fused_pattern() { return pattern_; }

   private:
    // ID of the representative node of this cluster.
    int node_id_;

    // The fused pattern this cluster holds.
    FusionPattern pattern_;
  };

 private:
  Cluster* MakeCluster(int cycles_graph_node_id) {
    cluster_storage_.emplace_back(new Cluster(cycles_graph_node_id, this));
    return cluster_storage_.back().get();
  }

  void BuildNodeMap() {
    int num_nodes = op_list_.size();
    for (int node_id = 0; node_id < num_nodes; ++node_id) {
      Operation* op = op_list_[node_id];
      MakeCluster(node_id);
      op_to_node_id_[op] = node_id;
      leader_for_node_.insert(node_id);
      for (Value operand : op->getOperands()) {
        Operation* operand_op = operand.getDefiningOp();
        if (operand_op == nullptr) {
          // Skip block arguments.
          continue;
        }
        auto iter = op_to_node_id_.find(operand_op);
        assert(iter != op_to_node_id_.end());
        cycle_detector_.InsertEdge(iter->second, node_id);
      }
    }
  }

  // Returns the cluster that contains this op.
  Cluster* GetClusterForNode(Operation* n) {
    int id = op_to_node_id_.at(n);
    id = leader_for_node_.getLeaderValue(id);
    return cluster_storage_[id].get();
  }

  // Returns the cluster that contains the op with the given `node_id`.
  Cluster* GetClusterForCyclesGraphNode(int node_id) {
    return cluster_storage_[leader_for_node_.getLeaderValue(node_id)].get();
  }

  // Merges the clusters `cluster_from` and `cluster_to`.
  bool MergeClusters(Cluster* cluster_from, Cluster* cluster_to) {
    int from = cluster_from->cycles_graph_node_id();
    int to = cluster_to->cycles_graph_node_id();

    auto optional_merged_node = cycle_detector_.ContractEdge(from, to);
    if (!optional_merged_node.hasValue()) {
      llvm::dbgs() << "Could not contract " << from << " -> " << to
                   << " because contracting the edge would create a cycle.\n";
      return false;
    }

    // Merge the clusters.
    cluster_from->Merge(cluster_to);
    cluster_from->set_cycles_graph_node_id(*optional_merged_node);

    // Merge the UnionFind sets.
    leader_for_node_.unionSets(from, to);
    return true;
  }

  template <typename FnTy>
  bool ForEachEdgeInPostOrder(FnTy fn) {
    bool changed = false;
    for (int32_t node : cycle_detector_.AllNodesInPostOrder()) {
      Cluster* cluster_from = GetClusterForCyclesGraphNode(node);
      // Make a copy of the set of successors because we may modify the graph
      // in TryToContractEdge.
      std::vector<int32_t> successors_copy =
          cycle_detector_.SuccessorsCopy(cluster_from->cycles_graph_node_id());

      for (int to : successors_copy) {
        Cluster* cluster_to = GetClusterForCyclesGraphNode(to);
        bool contracted_edge = fn(cluster_from, cluster_to);
        changed |= contracted_edge;
      }
    }

    return changed;
  }

  // Returns the outputs of the pattern if the two clusters were merged.
  SmallVector<Value, 4> GetResultsOfFusedPattern(Cluster* from, Cluster* to) {
    FusionPattern fused_pattern =
        MergeFusionPattern(from->fused_pattern(), to->fused_pattern());
    return GetOutputsOfFusionPattern(fused_pattern);
  }

  // Checks whether fusing `from` with `to` is valid and, if so, performs
  // the merge. The validity is based on the operations in the clusters and
  // the compatibility of the shapes of the outputs of the would-be fused
  // clusters.
  // Returns true if the merge was performed.
  bool TryToContractEdge(Cluster* from, Cluster* to) {
    int node_to = to->cycles_graph_node_id();
    int node_from = from->cycles_graph_node_id();

    // Both node_to and node_from should be fusible.
    if (!IsFusible(op_list_[node_to]) || !IsFusible(op_list_[node_from])) {
      return false;
    }

    auto op_from_fusibility =
        dyn_cast<InferFusibilityOpInterface>(op_list_[node_from]);
    if (op_from_fusibility && !op_from_fusibility.isFusibleWithConsumer()) {
      // This op cannot be fused with its consumers.
      return false;
    }

    auto op_to_fusibility =
        dyn_cast<InferFusibilityOpInterface>(op_list_[node_to]);
    if (op_to_fusibility && !op_to_fusibility.isFusibleWithOperand()) {
      // This op cannot be fused with its operands.
      return false;
    }

    // Output shapes of a fusion pattern should be compatible as described in
    // the documentation of this class.
    SmallVector<Value, 4> results = GetResultsOfFusedPattern(from, to);
    auto get_workload_shape = [](Value v) {
      Operation* op = v.getDefiningOp();
      // Block argument.
      if (!op) return v;
      auto op_fusibility = dyn_cast<InferFusibilityOpInterface>(op);
      // Const value.
      if (!op_fusibility) return v;
      llvm::Optional<Value> workload =
          op_fusibility.inferEffectiveWorkloadShape();
      return workload.hasValue() ? *workload : v;
    };

    Value ref = get_workload_shape(results[0]);
    if (!llvm::all_of(results, [&](Value result) {
          Value val = get_workload_shape(result);
          return shape_analysis_.HasSameShape(ref, val);
        })) {
      return false;
    }

    return MergeClusters(from, to);
  }

  // Greedily fuses connected nodes.
  bool RunEdgeContractionLoop() {
    using std::placeholders::_1;
    using std::placeholders::_2;
    return ForEachEdgeInPostOrder(
        std::bind(&FusionPlanner::TryToContractEdge, this, _1, _2));
  }

  const SmallVectorImpl<Operation*>& op_list_;

  // Shape equality checker.
  ShapeConstraintAnalysis shape_analysis_;

  // op -> node_id
  std::unordered_map<Operation*, int> op_to_node_id_;

  // Used to make sure we do not introduce cycles after fusion.
  GraphCycles cycle_detector_;
  std::vector<std::unique_ptr<Cluster>> cluster_storage_;

  // A UnionFind set. Each set represents a (partial) fused pattern
  // and has a leader as its representative.
  EquivalenceClasses<int32_t> leader_for_node_;
};

struct XlaHloFusion : public mlir::PassWrapper<XlaHloFusion, FunctionPass> {
  void runOnFunction() override {
    FuncOp func = getFunction();
    if (!IsTargetFunc(func)) {
      return;
    }

    // Process each block and do fusion within each block.
    for (Block& block : func) {
      SmallVector<Operation*, 4> op_list;
      for (Operation& op : block) {
        op_list.push_back(&op);
      }

      FusionPlanner planner(op_list);
      llvm::Optional<FusionPlan> plan = planner.Run();
      if (!plan) {
        emitError(func.getLoc(), "can't find a fusion plan");
        signalPassFailure();
        return;
      }
      if (!ApplyFusionPlan(*plan)) {
        emitError(func.getLoc(), "failed to apply the fusion plan");
        signalPassFailure();
        return;
      }
    }
  }

  bool IsTargetFunc(FuncOp func) {
    int num_fusible_ops = 0;
    bool is_target_func = false;
    // We only process functions that have enough fusion candidates.
    func.walk([&](Operation* op) {
      num_fusible_ops +=
          static_cast<int>(dyn_cast<InferFusibilityOpInterface>(op) != nullptr);
      is_target_func = (num_fusible_ops > 1);
      // Early stop.
      if (is_target_func) return WalkResult::interrupt();
      return WalkResult::advance();
    });
    return is_target_func;
  }

  bool ApplyFusionPlan(const FusionPlan& plan) {
    for (const FusionPattern& pattern : plan) {
      OpBuilder b(pattern.back());

      SmallVector<Location, 4> locations;
      locations.reserve(pattern.size());
      for (Operation* op : pattern) {
        locations.push_back(op->getLoc());
      }
      Location fused_loc =
          FusedLoc::get(locations, pattern.back()->getContext());

      SmallVector<Value, 4> inputs = GetInputsOfFusionPattern(pattern);
      SmallVector<Value, 4> outputs = GetOutputsOfFusionPattern(pattern);
      SmallVector<Type, 4> output_types;
      output_types.reserve(outputs.size());
      for (Value v : outputs) {
        output_types.push_back(v.getType());
      }

      FusionOp fusion =
          b.create<xla_hlo::FusionOp>(fused_loc, output_types, inputs);
      Region& region = fusion.fused_computation();
      region.push_back(new Block);
      Block& block = region.front();
      for (Operation* op : pattern) {
        op->moveBefore(&block, block.end());
      }
      b.setInsertionPoint(&block, block.end());
      b.create<xla_hlo::ReturnOp>(fused_loc, outputs);

      for (auto output_and_result : llvm::zip(outputs, fusion.getResults())) {
        Value output = std::get<0>(output_and_result);
        Value fusion_result = std::get<1>(output_and_result);
        for (OpOperand& use : llvm::make_early_inc_range(output.getUses())) {
          if (use.getOwner()->getBlock() != &block) use.set(fusion_result);
        }
      }
    }
    return true;
  }
};

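// As an illustrative before/after for ApplyFusionPlan (hypothetical IR; the
// real result depends on the matched pattern):
//
//   %0 = xla_hlo.add %a, %b : tensor<?xf32>
//   %1 = xla_hlo.subtract %0, %b : tensor<?xf32>
//
// becomes a single xla_hlo.fusion op whose region contains the two ops and a
// trailing xla_hlo.return of %1, with all external uses of %1 rewired to the
// fusion result.
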
}  // namespace

std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusion() {
  return std::make_unique<XlaHloFusion>();
}

static PassRegistration<XlaHloFusion> xla_hlo_fusion_pass(
    "xla-hlo-fusion", "Fuse xla_hlo ops into kLoop/kInput fusion patterns.");

}  // namespace xla_hlo
}  // namespace mlir

@@ -0,0 +1,909 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for lowering the HLO/LHLO dialects to the Linalg
// dialect.

#include "third_party/absl/memory/memory.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/AffineExpr.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Attributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Location.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"

namespace mlir {
namespace {

SmallVector<StringRef, 3> GetNParallelLoopsAttrs(unsigned nParallelLoops) {
  static constexpr StringRef kParallelIterType = "parallel";
  return SmallVector<StringRef, 3>(nParallelLoops, kParallelIterType);
}

template <bool isLHLO = true>
Value getResultValue(Operation* op) {
  return isLHLO ? op->getOperand(op->getNumOperands() - 1) : op->getResult(0);
}

template <bool isLHLO = true>
ShapedType getXLAOpResultType(Operation* op) {
  return getResultValue<isLHLO>(op).getType().template cast<ShapedType>();
}

template <bool isLHLO = true>
bool verifyXLAOpBufferOrTensorSemantics(Operation* op) {
  auto verifyType = [&](Value val) -> bool {
    return (isLHLO && val.getType().isa<MemRefType>()) ||
           (!isLHLO && val.getType().isa<RankedTensorType>());
  };
  if (!llvm::all_of(op->getOperands(), verifyType)) return false;
  return isLHLO ? op->getResults().empty()
                : llvm::all_of(op->getResults(), verifyType);
}

template <typename OpTy, bool isLHLO = true>
class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
 public:
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      OpTy op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = op.getLoc();
    auto argType =
        op.getOperation()->getOperand(0).getType().template cast<ShapedType>();
    if (!argType.hasRank()) {
      emitError(loc, "lhlo to linalg conversion expects ranked args");
      return failure();
    }
    auto elemTy = argType.getElementType();
    if (!elemTy.isSignlessIntOrFloat() && !elemTy.template isa<ComplexType>()) {
      return failure();
    }

    // Construct the indexing maps needed for linalg.generic ops.
    SmallVector<AffineMap, 2> indexing_maps;
    SmallVector<Type, 4> bodyArgTypes, bodyResultTypes, opResultTypes;

    // This doesn't account for implicit broadcast, but the working assumption
    // here is that all broadcasts have been made explicit.
    unsigned nloops = argType.getRank();

    if (isLHLO && !nloops) return failure();

    int operandCount = (isLHLO ? args.size() - 1 : args.size());
    auto verifyArgOrResultType = [&](Value val) -> ShapedType {
      auto shapedType = val.getType().dyn_cast<ShapedType>();
      if (!shapedType ||
          (!shapedType.isa<MemRefType>() &&
           !shapedType.isa<RankedTensorType>()) ||
          shapedType.getRank() != nloops)
        return nullptr;
      indexing_maps.emplace_back(
          nloops ? rewriter.getMultiDimIdentityMap(nloops)
                 : AffineMap::get(nloops, 0, rewriter.getContext()));
      return shapedType;
    };
    for (const auto& arg : llvm::enumerate(args)) {
      auto shapedType = verifyArgOrResultType(arg.value());
      if (!shapedType) return failure();
      auto& result_or_body_arg =
          arg.index() < operandCount ? bodyArgTypes : bodyResultTypes;
      result_or_body_arg.emplace_back(shapedType.getElementType());
    }
    if (!isLHLO) {
      // HLO operations return tensor types.
      assert(bodyResultTypes.empty() &&
             "When lowering HLO ops the result can't be part of the arguments");
      Value result = op.getOperation()->getResult(0);
      auto shapedType = verifyArgOrResultType(result);
      if (!shapedType) return failure();
      bodyResultTypes.push_back(shapedType.getElementType());
      opResultTypes.push_back(shapedType);
    }

    int64_t args_count = bodyArgTypes.size();
    int64_t results_count = bodyResultTypes.size();
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, opResultTypes, args, args_count, results_count, indexing_maps,
        GetNParallelLoopsAttrs(nloops),
        [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
          // TODO(ravishankarm): For now use the method in the xla_lhlo
          // namespace. That method needs to be moved out of there.
          Value opResult = xla_lhlo::XlaOpToStdScalarOp::map<OpTy>(
              op, bodyResultTypes,
              llvm::to_vector<2>(args.take_front(args_count)), &rewriter);
          nestedBuilder.create<linalg::YieldOp>(loc, opResult);
        });
    rewriter.replaceOp(op, linalgOp.getOperation()->getResults());
    return success();
  }
};

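// For reference, a loose sketch of what this pattern produces for a
// tensor-level xla_hlo.add (illustrative only; the exact textual form of
// linalg.generic and its attributes may differ):
//
//   %0 = linalg.generic {identity indexing maps, all-"parallel" iterators}
//        %lhs, %rhs {
//     ^bb0(%l: f32, %r: f32):
//       %s = addf %l, %r : f32
//       linalg.yield %s : f32
//   } : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
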
template <typename LhloOp>
class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
 public:
  using OpConversionPattern<LhloOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      LhloOp lhlo_op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = lhlo_op.getLoc();
    auto argType =
        lhlo_op.getOperand(0).getType().template dyn_cast<ShapedType>();
    if (!argType || !argType.getElementType().isSignlessIntOrFloat() ||
        (argType.getRank() != 0)) {
      return failure();
    }

    // Create two loads from the input.
    auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
    auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
    // TODO(ravishankarm): Move this method out of the xla_lhlo namespace.
    Value opResult = xla_lhlo::XlaOpToStdScalarOp::map<LhloOp>(
        lhlo_op, argType.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
        &rewriter);
    rewriter.create<StoreOp>(loc, opResult, lhlo_op.out());
    rewriter.eraseOp(lhlo_op);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// xla_lhlo.convolution conversion pattern.
//===----------------------------------------------------------------------===//

/// Converts the xla_lhlo.convolution operation to a linalg.conv op.
struct ConvToLinalgConverter : public OpConversionPattern<xla_lhlo::ConvOp> {
 public:
  using OpConversionPattern<xla_lhlo::ConvOp>::OpConversionPattern;

  // This code has been adapted from IREE's
  // (https://github.com/google/iree/) xla_hlo -> linalg conversion.
  LogicalResult matchAndRewrite(
      xla_lhlo::ConvOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    // Check the validity of the dimension information.
    if (const xla_lhlo::ConvDimensionNumbers& dimensionNumbers =
            op.dimension_numbers()) {
      const int inputSpatialRank =
          llvm::size(dimensionNumbers.input_spatial_dimensions());
      // The dimensions for the input should follow the order of
      // batch_count, spatial_dims..., input_feature_count.
      if (dimensionNumbers.input_batch_dimension().getInt() != 0 ||
          dimensionNumbers.input_feature_dimension().getInt() !=
              (inputSpatialRank + 1))
        return failure();

      const int kernelSpatialRank =
          llvm::size(dimensionNumbers.kernel_spatial_dimensions());
      // The dimensions for the filter should follow the order of
      // spatial_dims..., input_feature_count, output_feature_count.
      if (dimensionNumbers.kernel_input_feature_dimension().getInt() !=
              kernelSpatialRank ||
          dimensionNumbers.kernel_output_feature_dimension().getInt() !=
              (kernelSpatialRank + 1))
        return failure();

      const int outputSpatialRank =
          llvm::size(dimensionNumbers.output_spatial_dimensions());
      // The dimensions for the output should follow the order of
      // batch_count, spatial_dims..., output_feature_count.
      if (dimensionNumbers.output_batch_dimension().getInt() != 0 ||
          dimensionNumbers.output_feature_dimension().getInt() !=
              (outputSpatialRank + 1))
        return failure();

      if (inputSpatialRank != outputSpatialRank ||
          inputSpatialRank != kernelSpatialRank)
        return failure();

      auto inputSpatialDim =
          dimensionNumbers.input_spatial_dimensions().begin();
      auto kernelSpatialDim =
          dimensionNumbers.kernel_spatial_dimensions().begin();
      auto outputSpatialDim =
          dimensionNumbers.output_spatial_dimensions().begin();
      // Check if the spatial dims are ordered correctly.
      for (int i = 0; i < inputSpatialRank; ++i) {
        const int dim = i + 1;
        if ((*inputSpatialDim++).getZExtValue() != dim ||
            (*outputSpatialDim++).getZExtValue() != dim ||
            (*kernelSpatialDim++).getZExtValue() != i)
          return failure();
      }
    }

    // TODO: LHS dilation for deconvolution is not supported yet.
    if (op.lhs_dilation()) {
      return failure();
    }

    llvm::SmallVector<Attribute, 4> strides;
    if (auto windowStrides = op.window_strides()) {
      auto range = windowStrides->getAttributeValues();
      strides.assign(range.begin(), range.end());
    }
    auto stridesArg = ArrayAttr::get(strides, op.getContext());

    llvm::SmallVector<Attribute, 2> dilation;
    if (auto rhsDilation = op.rhs_dilation()) {
      auto range = rhsDilation->getAttributeValues();
      dilation.assign(range.begin(), range.end());
    } else {
      // Default dilation of 1.
      dilation.resize(2, IntegerAttr::get(rewriter.getIntegerType(64), 1));
    }
    auto dilationArg = ArrayAttr::get(dilation, op.getContext());

    // Set padding only if it is non-zero.
    DenseIntElementsAttr padding = op.paddingAttr();
    if (!padding || !llvm::any_of(padding.getValues<APInt>(), [](APInt intVal) {
          return !intVal.isNullValue();
        })) {
      padding = nullptr;
    }

    // The order of input and filter is switched with linalg.conv.
    rewriter.replaceOpWithNewOp<linalg::ConvOp>(
        op, args[1], args[0], args[2], stridesArg, dilationArg, padding);
    return success();
  }
};

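// A minimal sketch of the rewrite above (illustrative): given operands
// (%input, %filter, %output), the op is replaced with
//
//   linalg.conv(%filter, %input, %output) {dilations = [...], strides = [...]}
//
// Note the swapped order of %input and %filter required by linalg.conv.
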
/// Base class for lowering XLA operations that have one operand and one
/// result, and are semantically equivalent to a copy of the input to the
/// output (like transpose, some reshapes, etc.). The derived classes need to
/// provide a method `getIndexingMaps` that returns AffineMaps for the index
/// maps of the input and the output.
template <typename Derived, typename OpTy, bool isLHLO = true>
class DataMovementOpConverter : public OpConversionPattern<OpTy> {
 public:
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      OpTy op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    if (!verifyXLAOpBufferOrTensorSemantics<isLHLO>(op)) return failure();
    auto resultType = getXLAOpResultType<isLHLO>(op);

    SmallVector<AffineMap, 2> indexing_maps =
        Derived::getIndexingMaps(op, &rewriter);
    if (indexing_maps.empty()) return failure();

    auto nloops = resultType.getRank();
    auto loc = op.getLoc();
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, isLHLO ? ArrayRef<Type>{} : resultType, args, /*inputCount=*/1,
        /*outputCount=*/1, indexing_maps, GetNParallelLoopsAttrs(nloops),
        [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
        });

    rewriter.replaceOp(op, linalgOp.getOperation()->getResults());
    return success();
  }
};

/// Pattern to convert BroadcastOp to Linalg ops.
template <typename OpTy, bool isLHLO = true>
class BroadcastConverter
    : public DataMovementOpConverter<BroadcastConverter<OpTy, isLHLO>, OpTy,
                                     isLHLO> {
 public:
  using DataMovementOpConverter<BroadcastConverter, OpTy,
                                isLHLO>::DataMovementOpConverter;

  static SmallVector<AffineMap, 2> getIndexingMaps(OpTy broadcastOp,
                                                   Builder* b) {
    ShapedType inputType =
        broadcastOp.operand().getType().template cast<ShapedType>();
    unsigned inputRank = inputType.getRank();
    unsigned nloops = getXLAOpResultType<isLHLO>(broadcastOp).getRank();

    // BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute
    // to the input's dimensions.
    unsigned numPrependedDims = llvm::size(broadcastOp.broadcast_sizes());
    SmallVector<AffineExpr, 4> inputDimExprs;
    inputDimExprs.reserve(inputRank);
    for (int i = 0; i < inputRank; ++i) {
      inputDimExprs.push_back(b->getAffineDimExpr(numPrependedDims + i));
    }

    AffineMap inputMap;
    MLIRContext* context = b->getContext();
    if (inputDimExprs.empty()) {
      // The input is a scalar, i.e. this is a scalar broadcast op.
      inputMap = AffineMap::get(nloops, /*symbolCount=*/0, context);
    } else {
      inputMap =
          AffineMap::get(nloops, /*symbolCount=*/0, inputDimExprs, context);
    }
    return {inputMap, b->getMultiDimIdentityMap(nloops)};
  }
};

class HloBroadcastInDimConverter
    : public DataMovementOpConverter<HloBroadcastInDimConverter,
                                     xla_hlo::BroadcastInDimOp, false> {
 public:
  using DataMovementOpConverter<HloBroadcastInDimConverter,
                                xla_hlo::BroadcastInDimOp,
                                false>::DataMovementOpConverter;

  static SmallVector<AffineMap, 2> getIndexingMaps(
      xla_hlo::BroadcastInDimOp broadcastOp, Builder* b) {
    auto resultType = getXLAOpResultType<false>(broadcastOp);
    auto operandType =
        broadcastOp.operand().getType().template cast<ShapedType>();
    unsigned nloops = resultType.getRank();

    // The input is a scalar, i.e. this is a scalar broadcast op.
    if (operandType.getRank() == 0) {
      return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
              b->getMultiDimIdentityMap(nloops)};
    }

    auto operandShape = operandType.getShape();
    SmallVector<AffineExpr, 4> dimExprs;
    dimExprs.reserve(nloops);

    if (broadcastOp.broadcast_dimensions()) {
      for (const auto& broadcastDim :
           enumerate(broadcastOp.broadcast_dimensions().getIntValues())) {
        int size = broadcastDim.value().getSExtValue();
        bool expansion_needed = operandShape[broadcastDim.index()] == 1 &&
                                resultType.getShape()[size] != 1;
        dimExprs.push_back(expansion_needed ? b->getAffineConstantExpr(0)
                                            : b->getAffineDimExpr(size));
      }
    }
    return {
        AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()),
        b->getMultiDimIdentityMap(nloops)};
  }
};

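// As a concrete example of the maps computed above: broadcasting
// tensor<4xf32> into tensor<5x4xf32> with broadcast_dimensions = [1] yields
//
//   input map:  (d0, d1) -> (d1)
//   output map: (d0, d1) -> (d0, d1)
//
// and a size-1 operand dimension that actually expands is pinned to the
// constant expression 0 instead of d1.
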
class LhloBroadcastInDimConverter
    : public OpConversionPattern<xla_lhlo::BroadcastInDimOp> {
 public:
  using OpConversionPattern<xla_lhlo::BroadcastInDimOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      xla_lhlo::BroadcastInDimOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    xla_lhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
    auto result_type = operand_adaptor.output().getType().cast<MemRefType>();
    auto result_shape = result_type.getShape();

    auto operand_and_dims = InsertReshapeIfNecessary(op, args, rewriter);

    Value operand = std::get<0>(operand_and_dims);
    auto broadcast_dims = std::get<1>(operand_and_dims);

    auto loc = op.getLoc();
    auto nloops = result_type.getRank();
    auto operand_type = operand.getType().cast<MemRefType>();

    // For a degenerate case, i.e. broadcasting with expansion of
    // memref<1xELEMENT_TYPE>, the operand is not passed to `linalg.generic`.
    // Instead the value is loaded and used directly in `linalg.yield`.
    if (operand_type.getRank() == 1 &&
        operand_type.getDimSize(0) <
            result_type.getDimSize(broadcast_dims.front())) {
      Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
      Value val =
          rewriter.create<LoadOp>(loc, operand, llvm::makeArrayRef({zero}));
      rewriter.create<linalg::GenericOp>(
          loc, llvm::None, llvm::makeArrayRef(operand_adaptor.output()),
          /*inputCount=*/0, /*outputCount=*/1,
          llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
          GetNParallelLoopsAttrs(nloops),
          [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
            nestedBuilder.create<linalg::YieldOp>(loc, val);
          });

    } else {
      auto indexing_maps = getIndexingMaps(op, broadcast_dims, result_shape,
                                           operand_type, &rewriter);
      rewriter.create<linalg::GenericOp>(
          loc, llvm::None,
          llvm::makeArrayRef({operand, operand_adaptor.output()}),
          /*inputCount=*/1, /*outputCount=*/1, indexing_maps,
          GetNParallelLoopsAttrs(nloops),
          [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
            nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
          });
    }
    rewriter.replaceOp(op, llvm::None);
    return success();
  }

  // Inserts a 'linalg.reshape' op if there is a size-1 dim expansion.
  std::pair<Value, SmallVector<int64_t, 2>> InsertReshapeIfNecessary(
      xla_lhlo::BroadcastInDimOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const {
    xla_lhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
    Value operand = operand_adaptor.operand();
    auto operand_type = operand_adaptor.operand().getType().cast<MemRefType>();
    auto operand_shape = operand_type.getShape();

    Value result = operand_adaptor.output();
    auto result_type = result.getType().cast<MemRefType>();
    auto result_shape = result_type.getShape();

    SmallVector<int64_t, 2> operand_strides;
    int64_t operand_offset;
    if (failed(getStridesAndOffset(operand_type, operand_strides,
                                   operand_offset))) {
      op.emitOpError() << "Failed to get offset and strides.";
    }

    SmallVector<int64_t, 2> new_shape, new_strides, broadcast_dims;
    SmallVector<linalg::ReassociationIndices, 4> collapsed_dims_list;
    linalg::ReassociationIndices collapsed_dims;
    for (const auto& item :
         enumerate(op.broadcast_dimensions().getIntValues())) {
      size_t index = item.index();
      int dim = item.value().getSExtValue();

      collapsed_dims.push_back(index);

      bool expansion_needed =
          operand_shape[index] == 1 && result_shape[dim] != 1;
      if (expansion_needed) {
        continue;
      }
      new_shape.push_back(operand_shape[index]);
      new_strides.push_back(operand_strides[index]);
      broadcast_dims.push_back(dim);

      collapsed_dims_list.push_back(collapsed_dims);
      collapsed_dims.clear();
    }
    // If `collapsed_dims_list` is empty, then the memref has shape
    // [1, ..., 1] and all dimensions need expansion. Such a memref will be
    // reshaped to a 1D memref with a single element. The new shape and
    // strides need to be updated accordingly.
    if (collapsed_dims_list.empty()) {
      collapsed_dims_list.push_back({});
      new_shape.push_back(1);
      new_strides.push_back(1);
      broadcast_dims.push_back(0);
    }
    for (const auto& dims : collapsed_dims) {
      collapsed_dims_list.back().push_back(dims);
    }

    // A `linalg.reshape` is inserted only if necessary, i.e. when the rank
    // can be reduced.
    if (new_shape.size() < operand_shape.size()) {
      auto new_memref_type = MemRefType::get(
          new_shape, operand_type.getElementType(),
          makeStridedLinearLayoutMap(new_strides, operand_offset,
                                     rewriter.getContext()));
      operand = rewriter.create<linalg::ReshapeOp>(op.getLoc(), new_memref_type,
                                                   operand_adaptor.operand(),
                                                   collapsed_dims_list);
    }
    return std::make_pair(operand, broadcast_dims);
  }

  SmallVector<AffineMap, 2> getIndexingMaps(xla_lhlo::BroadcastInDimOp op,
                                            ArrayRef<int64_t> broadcastDims,
                                            ArrayRef<int64_t> resultShape,
                                            MemRefType operandType,
                                            Builder* b) const {
    unsigned nloops = resultShape.size();

    // The input is a scalar, i.e. this is a scalar broadcast op.
    if (operandType.getRank() == 0) {
      return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
              b->getMultiDimIdentityMap(nloops)};
    }

    auto operandShape = operandType.getShape();
    SmallVector<AffineExpr, 4> dimExprs;
    dimExprs.reserve(nloops);

    for (const auto& broadcastDim : llvm::enumerate(broadcastDims)) {
      int size = broadcastDim.value();
      bool expansion_needed =
          operandShape[broadcastDim.index()] == 1 && resultShape[size] != 1;
      if (expansion_needed) {
        op.emitOpError(
            "BroadcastInDimOp lowering to Linalg does not support size-1 "
            "dimensions expansion.");
      }
      dimExprs.push_back(b->getAffineDimExpr(size));
    }
    return {
        AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()),
        b->getMultiDimIdentityMap(nloops)};
  }
};
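
// Example (illustrative): for
//   "xla_lhlo.broadcast_in_dim"(%operand, %out)
//       {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}
//       : (memref<1x2x1xf32>, memref<5x2x3xf32>) -> ()
// operand dims 0 and 2 are size-1 expansions, so InsertReshapeIfNecessary
// first collapses the operand to memref<2xf32> and the generated generic op
// then broadcasts it along result dims 0 and 2.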

template <typename OpTy, bool isLHLO = true>
class TransposeConverter
    : public DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
                                     isLHLO> {
 public:
  using DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
                                isLHLO>::DataMovementOpConverter;
  static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
    auto resultType =
        getXLAOpResultType<isLHLO>(op).template cast<ShapedType>();
    auto nloops = resultType.getRank();
    SmallVector<AffineExpr, 2> inputExprs;
    inputExprs.resize(resultType.getRank());
    for (auto permutation : llvm::enumerate(op.permutation())) {
      inputExprs[permutation.value().getZExtValue()] =
          b->getAffineDimExpr(permutation.index());
    }
    return {
        AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
        b->getMultiDimIdentityMap(nloops)};
  }
};
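
// Example (illustrative): for permutation = [1, 0, 2] the computed input map
// is (d0, d1, d2) -> (d1, d0, d2) (inputExprs[perm[i]] = d_i), paired with
// the identity map for the output.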

// Converts reshape ops that can be proven to be either a collapse or an
// expansion of the operand's dimensions.
template <typename OpTy, bool isLHLO = true>
class ReshapeOpConverter : public OpConversionPattern<OpTy> {
 public:
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      OpTy reshapeOp, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    if (!verifyXLAOpBufferOrTensorSemantics<isLHLO>(reshapeOp))
      return failure();
    ShapedType operandType =
        reshapeOp.operand().getType().template cast<ShapedType>();
    ShapedType resultType = getXLAOpResultType<isLHLO>(reshapeOp);

    if (!operandType.hasStaticShape() || !resultType.hasStaticShape())
      return failure();

    // Compute the reassociation maps for the linalg operation.
    ArrayRef<int64_t> srcShape =
        (operandType.getRank() > resultType.getRank() ? operandType.getShape()
                                                      : resultType.getShape());
    ArrayRef<int64_t> dstShape =
        (operandType.getRank() > resultType.getRank() ? resultType.getShape()
                                                      : operandType.getShape());
    unsigned currSrcDim = 0, currDstDim = 0;
    SmallVector<linalg::ReassociationExprs, 4> reassociationMap(
        dstShape.size());
    while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
      int64_t dstSize = dstShape[currDstDim];
      int64_t srcSize = srcShape[currSrcDim];
      while (srcSize < dstSize && currSrcDim < srcShape.size()) {
        reassociationMap[currDstDim].push_back(
            rewriter.getAffineDimExpr(currSrcDim++));
        srcSize *= srcShape[currSrcDim];
      }
      if (srcSize == dstSize) {
        reassociationMap[currDstDim].push_back(
            rewriter.getAffineDimExpr(currSrcDim++));
        // If the next dim in dstShape is not of size 1, treat all subsequent
        // size-1 dims in srcShape as collapsed into the current dst dim.
        if (currDstDim == dstShape.size() - 1 ||
            dstShape[currDstDim + 1] != 1) {
          while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
            reassociationMap[currDstDim].push_back(
                rewriter.getAffineDimExpr(currSrcDim++));
          }
        }
      } else {
        return failure();
      }
      currDstDim++;
    }
    if (currSrcDim != srcShape.size()) return failure();

    if (isLHLO) {
      Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>(
          reshapeOp.getLoc(), resultType, args[0], reassociationMap);
      rewriter.replaceOpWithNewOp<linalg::CopyOp>(
          reshapeOp, reshapeBuffer, args[1], /*inputPermutation =*/nullptr,
          /*outputPermutation =*/nullptr);
    } else {
      rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
          reshapeOp, resultType, args[0], reassociationMap);
    }
    return success();
  }
};
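
// Example (illustrative): reshaping memref<4x2xf32> into memref<8xf32> walks
// the larger-rank shape [4, 2] against [8]; srcSize grows 4 -> 8, so both
// source dims collapse into result dim 0 with reassociation map [(d0, d1)].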

class IotaConverter : public OpConversionPattern<xla_lhlo::IotaOp> {
 public:
  using OpConversionPattern<xla_lhlo::IotaOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      xla_lhlo::IotaOp iotaOp, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    auto resultMemrefType =
        iotaOp.getOperand().getType().dyn_cast<MemRefType>();
    if (!resultMemrefType) return failure();

    auto resultElementType = resultMemrefType.getElementType();
    if (!resultElementType.isSignlessIntOrFloat()) return failure();

    // Construct the indexing maps needed for linalg.generic ops.
    unsigned nloops = resultMemrefType.getRank();

    rewriter.create<linalg::IndexedGenericOp>(
        iotaOp.getLoc(), ArrayRef<Type>{}, args,
        0,  // args_in
        1,  // args_out
        llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
        GetNParallelLoopsAttrs(nloops),
        [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange ivs,
            ValueRange args) {
          Value castOp = nestedBuilder.create<IndexCastOp>(
              nestedLoc, ivs[iotaOp.iota_dimension().getZExtValue()],
              nestedBuilder.getIntegerType(
                  resultElementType.getIntOrFloatBitWidth()));
          if (resultElementType.isa<FloatType>()) {
            castOp = nestedBuilder.create<SIToFPOp>(nestedLoc, castOp,
                                                    resultElementType);
          }
          nestedBuilder.create<linalg::YieldOp>(nestedLoc, castOp);
        });

    rewriter.replaceOp(iotaOp, llvm::None);
    return success();
  }
};
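
// Sample result (illustrative) for "xla_lhlo.iota"(%out)
//     {iota_dimension = 0 : i64} : (memref<4xi32>) -> ():
// the body of the generated linalg.indexed_generic casts the induction
// variable of the iota dimension to the element type:
//   %0 = index_cast %i0 : index to i32
//   "linalg.yield"(%0) : (i32) -> ()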

class ConstConverter : public OpConversionPattern<xla_lhlo::ConstOp> {
 public:
  using OpConversionPattern<xla_lhlo::ConstOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      xla_lhlo::ConstOp constOp, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = constOp.getLoc();
    auto valueAttr = constOp.value().cast<DenseElementsAttr>();
    if (valueAttr.getType().getRank() != 0) return failure();
    auto stdConstOp =
        rewriter.create<mlir::ConstantOp>(loc, valueAttr.getValue({}));
    rewriter.create<mlir::StoreOp>(loc, stdConstOp, constOp.getOperand());
    rewriter.eraseOp(constOp);
    return success();
  }
};

// TODO(b/156787842): Support the lowering for dynamic shapes.
template <typename OpTy, bool isLHLO = true>
class ReverseConverter
    : public DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy,
                                     isLHLO> {
 public:
  using DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy,
                                isLHLO>::DataMovementOpConverter;
  static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
    auto resultType =
        getXLAOpResultType<isLHLO>(op).template cast<ShapedType>();
    auto nloops = resultType.getRank();
    SmallVector<AffineExpr, 2> inputExprs;
    inputExprs.reserve(nloops);
    for (int i = 0; i < nloops; ++i)
      inputExprs.push_back(b->getAffineDimExpr(i));
    for (auto dim : op.dimensions()) {
      int i = dim.getZExtValue();
      if (resultType.isDynamicDim(i)) return {};
      int n = resultType.getShape()[i];
      inputExprs[i] = b->getAffineConstantExpr(n - 1) - inputExprs[i];
    }
    return {
        AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
        b->getMultiDimIdentityMap(nloops)};
  }
};
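
// Example (illustrative): reversing dimension 0 of a memref<4x?xf32> yields
// the input map (d0, d1) -> (3 - d0, d1); only the reversed dimensions must
// be static, which is why a dynamic reversed dim returns {} above.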

class SliceConverter : public OpConversionPattern<xla_lhlo::SliceOp> {
 public:
  using OpConversionPattern<xla_lhlo::SliceOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      xla_lhlo::SliceOp sliceOp, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = sliceOp.getLoc();
    auto argType =
        sliceOp.getOperand(0).getType().template dyn_cast<ShapedType>();
    if (!argType || !argType.hasRank()) {
      emitError(loc, "lhlo to linalg conversion expects known-rank args");
      return failure();
    }

    SmallVector<Value, 3> ranges;
    for (int i = 0, e = argType.getRank(); i < e; ++i) {
      Value start_index = rewriter.create<ConstantIndexOp>(
          loc, sliceOp.start_indices().getValue<int64_t>(i));
      Value limit_index = rewriter.create<ConstantIndexOp>(
          loc, sliceOp.limit_indices().getValue<int64_t>(i));
      Value stride = rewriter.create<ConstantIndexOp>(
          loc, sliceOp.strides().getValue<int64_t>(i));
      ranges.push_back(rewriter.create<linalg::RangeOp>(loc, start_index,
                                                        limit_index, stride));
    }
    auto linalg_slice =
        rewriter.create<linalg::SliceOp>(loc, sliceOp.getOperand(0), ranges);
    rewriter.create<linalg::CopyOp>(loc, linalg_slice, sliceOp.getOperand(1));
    rewriter.eraseOp(sliceOp);
    return success();
  }
};

void populateLHLOToLinalgConversionPattern(MLIRContext* context,
                                           OwningRewritePatternList* patterns) {
  // clang-format off
  patterns->insert<BroadcastConverter<xla_lhlo::BroadcastOp>,
                   ConstConverter,
                   ConvToLinalgConverter,
                   IotaConverter,
                   LhloBroadcastInDimConverter,
                   PointwiseToLinalgConverter<xla_lhlo::AbsOp>,
                   PointwiseToLinalgConverter<xla_lhlo::AddOp>,
                   PointwiseToLinalgConverter<xla_lhlo::AndOp>,
                   PointwiseToLinalgConverter<xla_lhlo::CeilOp>,
                   PointwiseToLinalgConverter<xla_lhlo::CompareOp>,
                   PointwiseToLinalgConverter<xla_lhlo::ComplexOp>,
                   PointwiseToLinalgConverter<xla_lhlo::ConvertOp>,
                   // TODO(ataei): Remove this pattern, CopyOp is folded away.
                   PointwiseToLinalgConverter<xla_lhlo::CopyOp>,
                   PointwiseToLinalgConverter<xla_lhlo::CosOp>,
                   PointwiseToLinalgConverter<xla_lhlo::DivOp>,
                   PointwiseToLinalgConverter<xla_lhlo::ExpOp>,
                   PointwiseToLinalgConverter<xla_lhlo::ImagOp>,
                   PointwiseToLinalgConverter<xla_lhlo::LogOp>,
                   PointwiseToLinalgConverter<xla_lhlo::MaxOp>,
                   PointwiseToLinalgConverter<xla_lhlo::MinOp>,
                   PointwiseToLinalgConverter<xla_lhlo::MulOp>,
                   PointwiseToLinalgConverter<xla_lhlo::NegOp>,
                   PointwiseToLinalgConverter<xla_lhlo::RealOp>,
                   PointwiseToLinalgConverter<xla_lhlo::RemOp>,
                   PointwiseToLinalgConverter<xla_lhlo::RsqrtOp>,
                   PointwiseToLinalgConverter<xla_lhlo::SelectOp>,
                   PointwiseToLinalgConverter<xla_lhlo::SignOp>,
                   PointwiseToLinalgConverter<xla_lhlo::SinOp>,
                   PointwiseToLinalgConverter<xla_lhlo::SqrtOp>,
                   PointwiseToLinalgConverter<xla_lhlo::SubOp>,
                   PointwiseToLinalgConverter<xla_lhlo::TanhOp>,
                   ReshapeOpConverter<xla_lhlo::ReshapeOp>,
                   ReverseConverter<xla_lhlo::ReverseOp>,
                   ScalarPointwiseToStandardConverter<xla_lhlo::AddOp>,
                   SliceConverter
                  >(context);
  // clang-format on
}

// Converts LHLO ops to Linalg generic.
// Sample result for xla_lhlo::AddOp.
//
// "xla_lhlo.add"(%arg1, %arg2, %out) :
//      (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
//
// will be converted to
//
// #map0 = (d0, d1) -> (d0, d1)
// "linalg.generic"(%arg1, %arg2, %out) ( {
//   ^bb0(%arg4: f32, %arg5: f32):
//     %0 = addf %arg4, %arg5 : f32
//     "linalg.yield"(%0) : (f32) -> ()
// }) {
//     args_in = 2,
//     args_out = 1,
//     indexing_maps = [#map0, #map0, #map0],
//     iterator_types = ["parallel", "parallel"],
// } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
struct LhloLegalizeToLinalg
    : public PassWrapper<LhloLegalizeToLinalg, FunctionPass> {
  void runOnFunction() override {
    OwningRewritePatternList patterns;
    ConversionTarget target(getContext());
    target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();

    auto func = getFunction();
    populateLHLOToLinalgConversionPattern(func.getContext(), &patterns);
    if (failed(applyPartialConversion(func, target, patterns, nullptr))) {
      signalPassFailure();
    }
  }
};

struct HloLegalizeToLinalg
    : public PassWrapper<HloLegalizeToLinalg, FunctionPass> {
  void runOnFunction() override {
    OwningRewritePatternList patterns;
    ConversionTarget target(getContext());
    target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();

    auto func = getFunction();
    xla_hlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
    if (failed(applyPartialConversion(func, target, patterns, nullptr))) {
      signalPassFailure();
    }
  }
};

}  // namespace

namespace xla_lhlo {
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToLinalgPass() {
  return absl::make_unique<LhloLegalizeToLinalg>();
}

static PassRegistration<LhloLegalizeToLinalg> legalize_lhlo_pass(
    "lhlo-legalize-to-linalg", "Legalize from LHLO dialect to Linalg dialect");
}  // namespace xla_lhlo
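
// Usage (illustrative): the registered pass can be exercised through any
// opt-style tool that links it in, e.g.
//   some-opt -lhlo-legalize-to-linalg input.mlir
// (the tool and file names here are hypothetical).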

namespace xla_hlo {

void populateHLOToLinalgConversionPattern(MLIRContext* context,
                                          OwningRewritePatternList* patterns) {
  patterns->insert<BroadcastConverter<xla_hlo::BroadcastOp, false>,
                   HloBroadcastInDimConverter,
                   PointwiseToLinalgConverter<xla_hlo::AbsOp, false>,
                   PointwiseToLinalgConverter<xla_hlo::AddOp, false>,
                   PointwiseToLinalgConverter<xla_hlo::AndOp, false>,
                   PointwiseToLinalgConverter<xla_hlo::CeilOp, false>,
                   PointwiseToLinalgConverter<xla_hlo::CompareOp, false>,
                   PointwiseToLinalgConverter<xla_hlo::ComplexOp, false>,
                   PointwiseToLinalgConverter<xla_hlo::ConvertOp, false>,
                   PointwiseToLinalgConverter<xla_hlo::CopyOp, false>,
                   PointwiseToLinalgConverter<xla_hlo::CosOp, false>,
                   PointwiseToLinalgConverter<xla_hlo::DivOp, false>,
                   PointwiseToLinalgConverter<xla_hlo::ExpOp, false>,
                   PointwiseToLinalgConverter<xla_hlo::ImagOp, false>,
                   PointwiseToLinalgConverter<xla_hlo::LogOp, false>,
                   PointwiseToLinalgConverter<xla_hlo::MaxOp, false>,
                   PointwiseToLinalgConverter<xla_hlo::MinOp, false>,
                   PointwiseToLinalgConverter<xla_hlo::MulOp, false>,
                   PointwiseToLinalgConverter<xla_hlo::NegOp, false>,
                   PointwiseToLinalgConverter<xla_hlo::RealOp, false>,
                   PointwiseToLinalgConverter<xla_hlo::RemOp, false>,
                   PointwiseToLinalgConverter<xla_hlo::RsqrtOp, false>,
                   PointwiseToLinalgConverter<xla_hlo::SelectOp, false>,
                   PointwiseToLinalgConverter<xla_hlo::SinOp, false>,
                   PointwiseToLinalgConverter<xla_hlo::SqrtOp, false>,
                   PointwiseToLinalgConverter<xla_hlo::SubOp, false>,
                   PointwiseToLinalgConverter<xla_hlo::TanhOp, false>,
                   ReshapeOpConverter<xla_hlo::ReshapeOp, false>,
                   ReverseConverter<xla_hlo::ReverseOp, false>,
                   TransposeConverter<xla_hlo::TransposeOp, false>>(context);
}

std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
  return absl::make_unique<HloLegalizeToLinalg>();
}

static PassRegistration<HloLegalizeToLinalg> legalize_hlo_pass(
    "hlo-legalize-to-linalg", "Legalize from HLO dialect to Linalg dialect");
}  // namespace xla_hlo
}  // namespace mlir

@@ -0,0 +1,188 @@

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

==============================================================================*/

#include "third_party/absl/memory/memory.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Shape/IR/Shape.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Function.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/StandardTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Pass/Pass.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"

namespace mlir {
namespace xla_hlo {
namespace {

// TODO(frgossen): Make it variadic.
template <typename OpTy>
inline void AddLegalOpOnRankedTensor(ConversionTarget *target) {
  target->addDynamicallyLegalOp<OpTy>([](OpTy op) {
    return llvm::all_of((op.getOperation())->getOperandTypes(),
                        [&](Type t) { return t.isa<RankedTensorType>(); });
  });
}
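
// For example, AddLegalOpOnRankedTensor<SqrtOp>(&target) marks xla_hlo.sqrt
// legal only when all of its operand types are ranked tensors, so the
// patterns below only ever see the unranked cases.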

/// Unary element-wise operations on unranked tensors can be applied to the
/// flattened tensor with the same effect.
/// This pattern rewrites every such operation to
///   (i)   flatten the input tensor,
///   (ii)  apply the unary operation, and
///   (iii) restore the original shape.
template <typename OpTy>
struct UnaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
  explicit UnaryElementwiseOpConversion(MLIRContext *context)
      : OpRewritePattern<OpTy>(context) {}

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Don't apply conversion to ops with ranked operands.
    Value operand = op.getOperand();
    auto operandTy = operand.getType().dyn_cast<TensorType>();
    if (operandTy.hasRank()) return failure();

    // Generate IR to flatten the operand.
    auto loc = op.getLoc();
    Value shape = rewriter.create<shape::ShapeOfOp>(loc, operand);
    Value numElements = rewriter.create<shape::NumElementsOp>(
        loc, rewriter.getType<shape::SizeType>(), shape);
    Value numElementsAsIndex = rewriter.create<shape::SizeToIndexOp>(
        loc, rewriter.getIndexType(), numElements);
    Value flatShapeAsDimTensor =
        rewriter.create<TensorFromElementsOp>(loc, numElementsAsIndex);
    auto flatTensorTy = RankedTensorType::get({ShapedType::kDynamicSize},
                                              operandTy.getElementType());
    Value flatOperand = rewriter.create<xla_hlo::DynamicReshapeOp>(
        loc, flatTensorTy, operand, flatShapeAsDimTensor);

    // Generate IR for the actual operation.
    Value flatResult = rewriter.create<OpTy>(loc, flatTensorTy, flatOperand);

    // Generate IR to restore the original shape.
    auto extentTensorTy = RankedTensorType::get({ShapedType::kDynamicSize},
                                                rewriter.getIndexType());
    Value shapeAsExtentTensor =
        rewriter.create<shape::ToExtentTensorOp>(loc, extentTensorTy, shape);
    Value result = rewriter.create<xla_hlo::DynamicReshapeOp>(
        loc, operandTy, flatResult, shapeAsExtentTensor);
    rewriter.replaceOp(op, result);

    return success();
  }
};
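
// Example (illustrative, pseudo-IR): a sqrt on %t : tensor<*xf32> becomes
//   %shape = shape.shape_of %t
//   %n     = shape.num_elements %shape
//   %flat  = xla_hlo.dynamic_reshape of %t to tensor<?xf32>
//   %s     = xla_hlo.sqrt %flat : tensor<?xf32>
//   %r     = xla_hlo.dynamic_reshape of %s back to the shape of %t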

/// Binary element-wise operations on unranked tensors can be applied to the
/// flattened operand tensors with the same effect.
/// This pattern rewrites every such operation to
///   (i)   flatten the operand tensors,
///   (ii)  apply the binary operation, and
///   (iii) restore the original shape.
template <typename OpTy>
struct BinaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
  explicit BinaryElementwiseOpConversion(MLIRContext *context)
      : OpRewritePattern<OpTy>(context) {}

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Don't apply conversion unless both operands are unranked.
    if (op.lhs().getType().template isa<RankedTensorType>() ||
        op.rhs().getType().template isa<RankedTensorType>()) {
      return failure();
    }

    // Flatten operands.
    Type shapeTy = shape::ShapeType::get(rewriter.getContext());
    auto loc = op.getLoc();
    Value shapeLhs = rewriter.create<shape::ShapeOfOp>(loc, op.lhs());
    Value shapeRhs = rewriter.create<shape::ShapeOfOp>(loc, op.rhs());
    Value shape = rewriter.create<shape::AnyOp>(loc, shapeTy,
                                                ValueRange{shapeLhs, shapeRhs});
    Value numElements = rewriter.create<shape::NumElementsOp>(loc, shape);
    Value numElementsAsIndex =
        rewriter.create<shape::SizeToIndexOp>(loc, numElements);
    Value flatShape =
        rewriter.create<TensorFromElementsOp>(loc, numElementsAsIndex);
    TensorType lhsTy = op.lhs().getType().template cast<TensorType>();
    Type flatLhsTy = RankedTensorType::get({ShapedType::kDynamicSize},
                                           lhsTy.getElementType());
    Value flatLhs =
        rewriter.create<DynamicReshapeOp>(loc, flatLhsTy, op.lhs(), flatShape);
    TensorType rhsTy = op.rhs().getType().template cast<TensorType>();
    Type flatRhsTy = RankedTensorType::get({ShapedType::kDynamicSize},
                                           rhsTy.getElementType());
    Value flatRhs =
        rewriter.create<DynamicReshapeOp>(loc, flatRhsTy, op.rhs(), flatShape);

    // Apply actual operation to flattened operands.
    Value flatResult = rewriter.create<OpTy>(loc, flatLhs, flatRhs);

    // Restore original shape.
    auto extentTensorTy = RankedTensorType::get({ShapedType::kDynamicSize},
                                                rewriter.getIndexType());
    Value shapeAsExtentTensor =
        rewriter.create<shape::ToExtentTensorOp>(loc, extentTensorTy, shape);
    Value result = rewriter.create<DynamicReshapeOp>(
        loc, op.getType(), flatResult, shapeAsExtentTensor);
    rewriter.replaceOp(op, result);

    return success();
  }
};
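
// Example (illustrative): for %a, %b : tensor<*xf32>, the pattern joins the
// operand shapes with shape.any (they must agree at runtime), flattens both
// operands to tensor<?xf32>, applies the op, and reshapes the result back.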

struct TransformUnrankedHloPass
    : public PassWrapper<TransformUnrankedHloPass, FunctionPass> {
  void runOnFunction() override {
    // Setup conversion target.
    MLIRContext &ctx = getContext();
    ConversionTarget target(ctx);
    target.addLegalDialect<XlaHloDialect, StandardOpsDialect,
                           shape::ShapeDialect>();
    target.addLegalOp<FuncOp>();
    AddLegalOpOnRankedTensor<SqrtOp>(&target);
    AddLegalOpOnRankedTensor<AddOp>(&target);

    // Populate rewrite patterns.
    OwningRewritePatternList patterns;
    PopulateTransformUnrankedHloPatterns(&ctx, &patterns);

    // Apply transformation.
    if (failed(applyFullConversion(getFunction(), target, patterns)))
      return signalPassFailure();
  }
};
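
// Note: because this is a full conversion, any SqrtOp/AddOp still operating
// on unranked tensors after pattern application fails the pass; ops on ranked
// tensors are legal per AddLegalOpOnRankedTensor and are left untouched.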

}  // namespace

void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
                                          OwningRewritePatternList *patterns) {
  // TODO(frgossen): Populate all unary and binary operations.
  // clang-format off
  patterns->insert<
      BinaryElementwiseOpConversion<AddOp>,
      UnaryElementwiseOpConversion<SqrtOp>>(context);
  // clang-format on
}

static PassRegistration<TransformUnrankedHloPass> transform_unranked_hlo_pass(
    "transform-unranked-hlo",
    "Realize element-wise operations on ranked tensors where possible");

}  // namespace xla_hlo
}  // namespace mlir

@@ -0,0 +1,340 @@

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/cycle_detector.h"

#include <algorithm>

#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/DenseSet.h"

namespace mlir {

namespace {

using NodeSet = llvm::DenseSet<int32_t>;
using OrderedNodeSet = OrderedSet<int32_t>;

template <typename T>
struct VecStruct {
  typedef llvm::SmallVector<T, 4> type;
};
template <typename T>
using Vec = typename VecStruct<T>::type;

struct Node {
  // rank number assigned by Pearce-Kelly algorithm
  int32_t rank;
  // Temporary marker used by depth-first-search
  bool visited;
  // User-supplied data
  void* data;
  // List of immediate predecessor nodes in graph
  OrderedNodeSet in;
  // List of immediate successor nodes in graph
  OrderedNodeSet out;
};

}  // namespace

struct GraphCycles::Rep {
  Vec<Node*> nodes;
  // Indices for unused entries in nodes
  Vec<int32_t> free_nodes;

  // Temporary state.
  // Results of forward DFS
  Vec<int32_t> deltaf;
  // Results of backward DFS
  Vec<int32_t> deltab;
  // All nodes to reprocess
  Vec<int32_t> list;
  // Rank values to assign to list entries
  Vec<int32_t> merged;
  // Emulates recursion stack when doing depth first search
  Vec<int32_t> stack;
};

GraphCycles::GraphCycles(int32_t num_nodes) : rep_(new Rep) {
  rep_->nodes.reserve(num_nodes);
  for (int32_t i = 0; i < num_nodes; ++i) {
    Node* n = new Node;
    n->visited = false;
    n->data = nullptr;
    n->rank = rep_->nodes.size();
    rep_->nodes.push_back(n);
  }
}

GraphCycles::~GraphCycles() {
  for (Vec<Node*>::size_type i = 0, e = rep_->nodes.size(); i < e; ++i) {
    delete rep_->nodes[i];
  }
  delete rep_;
}

bool GraphCycles::HasEdge(int32_t x, int32_t y) const {
  return rep_->nodes[x]->out.Contains(y);
}

void GraphCycles::RemoveEdge(int32_t x, int32_t y) {
  rep_->nodes[x]->out.Erase(y);
  rep_->nodes[y]->in.Erase(x);
  // No need to update the rank assignment since a previous valid
  // rank assignment remains valid after an edge deletion.
}

static bool ForwardDFS(GraphCycles::Rep* r, int32_t n, int32_t upper_bound);
static void BackwardDFS(GraphCycles::Rep* r, int32_t n, int32_t lower_bound);
static void Reorder(GraphCycles::Rep* r);
static void Sort(const Vec<Node*>&, Vec<int32_t>* delta);
static void MoveToList(GraphCycles::Rep* r, Vec<int32_t>* src,
                       Vec<int32_t>* dst);
static void ClearVisitedBits(GraphCycles::Rep* r, const Vec<int32_t>& nodes);

bool GraphCycles::InsertEdge(int32_t x, int32_t y) {
  if (x == y) return false;
  Rep* r = rep_;
  Node* nx = r->nodes[x];
  if (!nx->out.Insert(y)) {
    // Edge already exists.
    return true;
  }

  Node* ny = r->nodes[y];
  ny->in.Insert(x);

  if (nx->rank <= ny->rank) {
    // New edge is consistent with existing rank assignment.
    return true;
  }

  // Current rank assignments are incompatible with the new edge.  Recompute.
  // We only need to consider nodes that fall in the range [ny->rank,nx->rank].
  if (ForwardDFS(r, y, nx->rank)) {
    // Found a cycle.  Undo the insertion and tell caller.
    nx->out.Erase(y);
    ny->in.Erase(x);
    // Since we do not call Reorder() on this path, clear any visited
    // markers left by ForwardDFS.
    ClearVisitedBits(r, r->deltaf);
    return false;
  }
  BackwardDFS(r, x, ny->rank);
  Reorder(r);
  return true;
}
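
// Example (illustrative): with edges 1->2->3 and ranks r(1) < r(2) < r(3),
// InsertEdge(3, 1) runs ForwardDFS from node 1 bounded by r(3), finds the
// existing path back to node 3, undoes the insertion, and returns false.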

// Follows the edges from producer to consumer and searches, via DFS, whether
// the node with rank `n` can reach the node with rank `upper_bound`. Only
// paths on which every node has a rank smaller than `upper_bound` are
// considered.
//
// Returns true if such a path exists.
static bool ForwardDFS(GraphCycles::Rep* r, int32_t n, int32_t upper_bound) {
  // Avoid recursion since stack space might be limited.
  // We instead keep a stack of nodes to visit.
  r->deltaf.clear();
  r->stack.clear();
  r->stack.push_back(n);
  while (!r->stack.empty()) {
    n = r->stack.back();
    r->stack.pop_back();
    Node* nn = r->nodes[n];
    if (nn->visited) continue;

    nn->visited = true;
    r->deltaf.push_back(n);

    for (auto w : nn->out.GetSequence()) {
      Node* nw = r->nodes[w];
      if (nw->rank == upper_bound) {
        return true;
      }
      if (!nw->visited && nw->rank < upper_bound) {
        r->stack.push_back(w);
      }
    }
  }
  return false;
}

// Follows the edges from consumer to producer and visits all the nodes that
// are reachable from node `n` and have rank larger than `lower_bound`.
static void BackwardDFS(GraphCycles::Rep* r, int32_t n, int32_t lower_bound) {
  r->deltab.clear();
  r->stack.clear();
  r->stack.push_back(n);
  while (!r->stack.empty()) {
    n = r->stack.back();
    r->stack.pop_back();
    Node* nn = r->nodes[n];
    if (nn->visited) continue;

    nn->visited = true;
    r->deltab.push_back(n);

    for (auto w : nn->in.GetSequence()) {
      Node* nw = r->nodes[w];
      if (!nw->visited && lower_bound < nw->rank) {
        r->stack.push_back(w);
      }
    }
  }
}

// Recomputes rank assignments to make them compatible with the edges (producer
// has smaller rank than its consumer).
static void Reorder(GraphCycles::Rep* r) {
  Sort(r->nodes, &r->deltab);
  Sort(r->nodes, &r->deltaf);

  // Adds contents of delta lists to list (backwards deltas first).
  r->list.clear();
  MoveToList(r, &r->deltab, &r->list);
  MoveToList(r, &r->deltaf, &r->list);

  // Produce sorted list of all ranks that will be reassigned.
  r->merged.resize(r->deltab.size() + r->deltaf.size());
  std::merge(r->deltab.begin(), r->deltab.end(), r->deltaf.begin(),
             r->deltaf.end(), r->merged.begin());

  // Assign the ranks in order to the collected list.
  for (Vec<int32_t>::size_type i = 0, e = r->list.size(); i < e; ++i) {
    r->nodes[r->list[i]]->rank = r->merged[i];
  }
}

// Sorts nodes in the vector according to their ranks. Small rank first.
static void Sort(const Vec<Node*>& nodes, Vec<int32_t>* delta) {
  struct ByRank {
    const Vec<Node*>* nodes;
    bool operator()(int32_t a, int32_t b) const {
      return (*nodes)[a]->rank < (*nodes)[b]->rank;
    }
  };
  ByRank cmp;
  cmp.nodes = &nodes;
  std::sort(delta->begin(), delta->end(), cmp);
}

// Collects ranks of nodes in vector `src` to vector `dst`.
static void MoveToList(GraphCycles::Rep* r, Vec<int32_t>* src,
                       Vec<int32_t>* dst) {
  for (Vec<int32_t>::size_type i = 0, e = src->size(); i < e; i++) {
    int32_t w = (*src)[i];
    // Replace src entry with its rank
    (*src)[i] = r->nodes[w]->rank;
    // Prepare for future DFS calls
    r->nodes[w]->visited = false;
    dst->push_back(w);
  }
}

// Clears bookkeeping fields used during the last DFS process.
static void ClearVisitedBits(GraphCycles::Rep* r, const Vec<int32_t>& nodes) {
  for (Vec<int32_t>::size_type i = 0, e = nodes.size(); i < e; i++) {
    r->nodes[nodes[i]]->visited = false;
  }
}

bool GraphCycles::IsReachable(int32_t x, int32_t y) {
  if (x == y) return true;
  Rep* r = rep_;
  Node* nx = r->nodes[x];
  Node* ny = r->nodes[y];

  if (nx->rank >= ny->rank) {
    // x cannot reach y since it is after it in the topological ordering.
    return false;
  }

  // See if x can reach y using a DFS search that is limited to y's rank.
  bool reachable = ForwardDFS(r, x, ny->rank);

  // Clear any visited markers left by ForwardDFS.
  ClearVisitedBits(r, r->deltaf);
  return reachable;
}
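
// Example: after inserting 1->2 and 2->3, the ranks satisfy r(3) > r(1), so
// IsReachable(3, 1) returns false immediately from the rank check, without
// running ForwardDFS.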

llvm::Optional<int32_t> GraphCycles::ContractEdge(int32_t a, int32_t b) {
  assert(HasEdge(a, b));
  RemoveEdge(a, b);

  if (IsReachable(a, b)) {
    // Restore the graph to its original state.
    InsertEdge(a, b);
    return {};
  }

  if (rep_->nodes[b]->in.Size() + rep_->nodes[b]->out.Size() >
      rep_->nodes[a]->in.Size() + rep_->nodes[a]->out.Size()) {
    // Swap "a" and "b" to minimize copying.
    std::swap(a, b);
  }

  Node* nb = rep_->nodes[b];
  OrderedNodeSet out = std::move(nb->out);
  OrderedNodeSet in = std::move(nb->in);
  for (int32_t y : out.GetSequence()) {
    rep_->nodes[y]->in.Erase(b);
  }
  for (int32_t y : in.GetSequence()) {
    rep_->nodes[y]->out.Erase(b);
  }
  rep_->free_nodes.push_back(b);

  rep_->nodes[a]->out.Reserve(rep_->nodes[a]->out.Size() + out.Size());
  for (int32_t y : out.GetSequence()) {
    InsertEdge(a, y);
  }

  rep_->nodes[a]->in.Reserve(rep_->nodes[a]->in.Size() + in.Size());
  for (int32_t y : in.GetSequence()) {
    InsertEdge(y, a);
  }

  // Note: if the swap happened, this may be the node originally called "b".
  return a;
}
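
// Typical use (illustrative): when fusing two ops into one cluster, contract
// the edge between their representative nodes; a returned llvm::None means
// the contraction would close a cycle and the fusion must be skipped.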

std::vector<int32_t> GraphCycles::SuccessorsCopy(int32_t node) const {
  return rep_->nodes[node]->out.GetSequence();
}

namespace {
void SortInPostOrder(const Vec<Node*>& nodes, std::vector<int32_t>* to_sort) {
  std::sort(to_sort->begin(), to_sort->end(), [&](int32_t a, int32_t b) {
    return nodes[a]->rank > nodes[b]->rank;
  });
}
}  // namespace

std::vector<int32_t> GraphCycles::AllNodesInPostOrder() const {
  llvm::DenseSet<int32_t> free_nodes_set;
  for (int32_t n : rep_->free_nodes) free_nodes_set.insert(n);

  std::vector<int32_t> all_nodes;
  all_nodes.reserve(rep_->nodes.size() - free_nodes_set.size());
  for (size_t i = 0, e = rep_->nodes.size(); i < e; i++) {
    if (!free_nodes_set.count(i)) {
      all_nodes.push_back(i);
    }
  }

  SortInPostOrder(rep_->nodes, &all_nodes);
  return all_nodes;
}

}  // namespace mlir

@@ -0,0 +1,89 @@

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "third_party/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/cycle_detector.h"

#include "third_party/tensorflow/compiler/xla/test.h"

class GraphCyclesTest : public ::testing::Test {
 public:
  GraphCyclesTest() : g_(100) {}

  bool AddEdge(int x, int y) { return g_.InsertEdge(x, y); }

  void AddMultiples() {
    // For every node x > 0: add edge to 2*x, 3*x
    for (int x = 1; x < 25; x++) {
      EXPECT_TRUE(AddEdge(x, 2 * x)) << x;
      EXPECT_TRUE(AddEdge(x, 3 * x)) << x;
    }
  }

  mlir::GraphCycles g_;
};

TEST_F(GraphCyclesTest, NoCycle) { AddMultiples(); }

TEST_F(GraphCyclesTest, SimpleCycle) {
  AddMultiples();
  EXPECT_FALSE(AddEdge(8, 4));
}

TEST_F(GraphCyclesTest, IndirectCycle) {
  AddMultiples();
  EXPECT_TRUE(AddEdge(16, 9));
  EXPECT_FALSE(AddEdge(9, 2));
}

TEST_F(GraphCyclesTest, RemoveEdge) {
  EXPECT_TRUE(AddEdge(1, 2));
  EXPECT_TRUE(AddEdge(2, 3));
  EXPECT_TRUE(AddEdge(3, 4));
  EXPECT_TRUE(AddEdge(4, 5));
  g_.RemoveEdge(2, 3);
  EXPECT_FALSE(g_.HasEdge(2, 3));
}

TEST_F(GraphCyclesTest, IsReachable) {
  EXPECT_TRUE(AddEdge(1, 2));
  EXPECT_TRUE(AddEdge(2, 3));
  EXPECT_TRUE(AddEdge(3, 4));
  EXPECT_TRUE(AddEdge(4, 5));

  EXPECT_TRUE(g_.IsReachable(1, 5));
  EXPECT_FALSE(g_.IsReachable(5, 1));
}

TEST_F(GraphCyclesTest, ContractEdge) {
  ASSERT_TRUE(AddEdge(1, 2));
  ASSERT_TRUE(AddEdge(1, 3));
  ASSERT_TRUE(AddEdge(2, 3));
  ASSERT_TRUE(AddEdge(2, 4));
  ASSERT_TRUE(AddEdge(3, 4));

  // Contracting edge (1, 3) here would introduce a cycle.
  EXPECT_FALSE(g_.ContractEdge(1, 3).hasValue());
  EXPECT_TRUE(g_.HasEdge(1, 3));

  // Node (2) has more edges.
  EXPECT_EQ(*g_.ContractEdge(1, 2), 2);
  EXPECT_TRUE(g_.HasEdge(2, 3));
  EXPECT_TRUE(g_.HasEdge(2, 4));
  EXPECT_TRUE(g_.HasEdge(3, 4));

  // Node (2) has more edges.
  EXPECT_EQ(*g_.ContractEdge(2, 3), 2);
  EXPECT_TRUE(g_.HasEdge(2, 4));
}