From a42213b87076ab350dde957e022344a53ea31905 Mon Sep 17 00:00:00 2001 From: Tres Popp Date: Mon, 21 Dec 2020 15:26:38 -0800 Subject: [PATCH] Define lowering of [l]mhlo.pow. For floating point operations, this uses std.pow. For integer operations, this lowers to a loop. This adds a dependency on scf. PiperOrigin-RevId: 348537232 --- BUILD | 2 + .../mhlo/transforms/map_hlo_to_lhlo_op.h | 1 + .../mhlo/transforms/map_lmhlo_to_scalar_op.h | 38 ++++++++++++++++++ .../mhlo/transforms/hlo_legalize_to_lhlo.cc | 1 + .../mhlo/transforms/legalize_to_linalg.cc | 8 +++- tests/hlo-legalize-to-linalg.mlir | 39 +++++++++++++++++++ tests/lhlo-legalize-to-linalg.mlir | 15 +++++++ 7 files changed, 102 insertions(+), 2 deletions(-) diff --git a/BUILD b/BUILD index a6317ea..44e369e 100644 --- a/BUILD +++ b/BUILD @@ -593,6 +593,7 @@ cc_library( ":map_hlo_to_lhlo_op", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:StandardOps", ], ) @@ -664,6 +665,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:LinalgOps", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Transforms", ], diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h index 60fff05..ef36f41 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h @@ -66,6 +66,7 @@ MAP_HLO_TO_LHLO(MulOp); MAP_HLO_TO_LHLO(NegOp); MAP_HLO_TO_LHLO(NotOp); MAP_HLO_TO_LHLO(OrOp); +MAP_HLO_TO_LHLO(PowOp); MAP_HLO_TO_LHLO(RealOp); MAP_HLO_TO_LHLO(ReduceOp); MAP_HLO_TO_LHLO(ReshapeOp); diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index a179eb6..eadc32c 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -16,12 +16,16 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_ #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_ +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/iterator_range.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h" +#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/TypeUtilities.h" namespace mlir { @@ -508,6 +512,40 @@ inline Value MapLhloOpToStdScalarOp(Location loc, loc, result_types, args, b); } +template <> +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { + lmhlo::PowOp::Adaptor adaptor(args); + // Floating point can use std::powf + auto result_type = result_types.front(); + if (result_type.isa<::mlir::FloatType>()) + return MapLhloOpToStdScalarOpImpl<::mlir::PowFOp>{}(loc, result_types, args, + b); + + assert(result_type.isa<::mlir::IntegerType>() && + "only float and integer `pow` is supported right now"); + + // There is no powi, so lower to a simple product. Note that HLO does not + // define semantics of negative exponents. + Value init = b->create(loc, b->getIntegerAttr(result_type, 1)); + + Value lowerBound = b->create(loc, 0); + Value upperBound = + b->create(loc, adaptor.rhs(), b->getIndexType()); + Value step = b->create(loc, 1); + return b + ->create( + loc, lowerBound, upperBound, step, llvm::makeArrayRef(init), + [&](OpBuilder& b, Location l, Value v, ValueRange iters) { + Value prod = + b.create<::mlir::MulIOp>(l, adaptor.lhs(), iters.front()); + b.create(l, prod); + }) + .getResult(0); +} + template <> inline Value MapLhloOpToStdScalarOp( Location loc, ArrayRef result_types, ArrayRef args, diff --git a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index f61c7eb..59c2a22 100644 --- a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -650,6 +650,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, + HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 4d44b9c..1a153dd 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" +#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Attributes.h" @@ -957,6 +958,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, @@ -1021,13 +1023,14 @@ struct LhloLegalizeToLinalgPass struct HloLegalizeToLinalgPass : public PassWrapper { void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); + registry.insert(); } void runOnFunction() override { OwningRewritePatternList patterns; ConversionTarget target(getContext()); - target.addLegalDialect(); + target.addLegalDialect(); auto func = getFunction(); mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns); @@ -1075,6 +1078,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index 1932de0..71a8b79 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -745,3 +745,42 @@ func @constant() { return } // CHECK: %[[CONSTANT:.*]] = constant dense<10> : tensor + +// ----- + +// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @float_pow +func @float_pow(%lhs: tensor<2x2xf32>, + %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { + // CHECK: linalg.generic + // CHECK: ^{{[a-z0-9_]*}} + // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: f32 + // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: f32 + // CHECK: %[[RESULT:[a-zA-Z0-9_]*]] = powf %[[ARG0]], %[[ARG1]] + // CHECK: linalg.yield %[[RESULT]] + %0 = "mhlo.power"(%lhs, %rhs) : (tensor<2x2xf32>, + tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + +// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @integer_pow +func @integer_pow(%lhs: tensor<2x2xi32>, + %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { + // CHECK: linalg.generic + // CHECK: ^{{[a-z0-9_]*}} + // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: i32 + // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: i32 + // CHECK: %[[UPPER:.*]] = index_cast %[[ARG1]] + // CHECK: %[[RESULT:.*]] = scf.for {{.*}} to %[[UPPER]] + // CHECK-SAME: step %c1{{[a-zA-Z0-9_]*}} + // CHECK-SAME: iter_args(%[[ITER:.*]] = %c1{{.*}}) -> (i32) { + // CHECK: %[[ACCUM:[a-zA-Z0-9_]*]] = muli %[[ARG0]], %[[ITER]] + // CHECK: scf.yield %[[ACCUM]] + // CHECK: linalg.yield %[[RESULT]] + %0 = "mhlo.power"(%lhs, %rhs) : (tensor<2x2xi32>, + tensor<2x2xi32>) -> tensor<2x2xi32> + return %0 : tensor<2x2xi32> +} diff --git a/tests/lhlo-legalize-to-linalg.mlir b/tests/lhlo-legalize-to-linalg.mlir index 3dba087..716e00a 100644 --- a/tests/lhlo-legalize-to-linalg.mlir +++ b/tests/lhlo-legalize-to-linalg.mlir @@ -1,5 +1,20 @@ // RUN: mlir-hlo-opt %s -lhlo-legalize-to-linalg -split-input-file | FILECHECK_OPTS="" FileCheck %s +// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @element_wise +func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, + %result: memref<2x2xf32>) { + "lmhlo.power"(%lhs, %rhs, %result) + : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %[[RESULT_OUT:.*]]: f32): +// CHECK-NEXT: %[[RESULT:.*]] = powf %[[LHS_IN]], %[[RHS_IN]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 + +// ----- + // CHECK: #map = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @element_wise func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,