Define lowering of [l]mhlo.pow.

For floating point operations, this uses std.pow.
For integer operations, this lowers to a loop.
This adds a dependency on scf.

PiperOrigin-RevId: 348537232
This commit is contained in:
Tres Popp 2020-12-21 15:26:38 -08:00 committed by TensorFlow MLIR Team
parent 99a0ee378c
commit a42213b870
7 changed files with 102 additions and 2 deletions

2
BUILD
View File

@ -593,6 +593,7 @@ cc_library(
":map_hlo_to_lhlo_op",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps",
],
)
@ -664,6 +665,7 @@ cc_library(
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Transforms",
],

View File

@ -66,6 +66,7 @@ MAP_HLO_TO_LHLO(MulOp);
MAP_HLO_TO_LHLO(NegOp);
MAP_HLO_TO_LHLO(NotOp);
MAP_HLO_TO_LHLO(OrOp);
MAP_HLO_TO_LHLO(PowOp);
MAP_HLO_TO_LHLO(RealOp);
MAP_HLO_TO_LHLO(ReduceOp);
MAP_HLO_TO_LHLO(ReshapeOp);

View File

@ -16,12 +16,16 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/iterator_range.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
namespace mlir {
@ -508,6 +512,40 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::PowOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
lmhlo::PowOp::Adaptor adaptor(args);
// Floating point can use std::powf
auto result_type = result_types.front();
if (result_type.isa<::mlir::FloatType>())
return MapLhloOpToStdScalarOpImpl<::mlir::PowFOp>{}(loc, result_types, args,
b);
assert(result_type.isa<::mlir::IntegerType>() &&
"only float and integer `pow` is supported right now");
// There is no powi, so lower to a simple product. Note that HLO does not
// define semantics of negative exponents.
Value init = b->create<ConstantOp>(loc, b->getIntegerAttr(result_type, 1));
Value lowerBound = b->create<ConstantIndexOp>(loc, 0);
Value upperBound =
b->create<IndexCastOp>(loc, adaptor.rhs(), b->getIndexType());
Value step = b->create<ConstantIndexOp>(loc, 1);
return b
->create<scf::ForOp>(
loc, lowerBound, upperBound, step, llvm::makeArrayRef(init),
[&](OpBuilder& b, Location l, Value v, ValueRange iters) {
Value prod =
b.create<::mlir::MulIOp>(l, adaptor.lhs(), iters.front());
b.create<scf::YieldOp>(l, prod);
})
.getResult(0);
}
template <>
inline Value MapLhloOpToStdScalarOp<lmhlo::SelectOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,

View File

@ -650,6 +650,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
HloToLhloOpConverter<mhlo::NegOp>,
HloToLhloOpConverter<mhlo::NotOp>,
HloToLhloOpConverter<mhlo::OrOp>,
HloToLhloOpConverter<mhlo::PowOp>,
HloToLhloOpConverter<mhlo::RealOp>,
HloToLhloOpConverter<mhlo::RemOp>,
HloToLhloOpConverter<mhlo::RsqrtOp>,

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
@ -957,6 +958,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<lmhlo::NegOp>,
PointwiseToLinalgConverter<lmhlo::NotOp>,
PointwiseToLinalgConverter<lmhlo::OrOp>,
PointwiseToLinalgConverter<lmhlo::PowOp>,
PointwiseToLinalgConverter<lmhlo::RealOp>,
PointwiseToLinalgConverter<lmhlo::RemOp>,
PointwiseToLinalgConverter<lmhlo::RsqrtOp>,
@ -1021,13 +1023,14 @@ struct LhloLegalizeToLinalgPass
struct HloLegalizeToLinalgPass
: public PassWrapper<HloLegalizeToLinalgPass, FunctionPass> {
void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<linalg::LinalgDialect>();
registry.insert<linalg::LinalgDialect, scf::SCFDialect>();
}
void runOnFunction() override {
OwningRewritePatternList patterns;
ConversionTarget target(getContext());
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
scf::SCFDialect>();
auto func = getFunction();
mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
@ -1075,6 +1078,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<mhlo::NegOp, false>,
PointwiseToLinalgConverter<mhlo::NotOp, false>,
PointwiseToLinalgConverter<mhlo::OrOp, false>,
PointwiseToLinalgConverter<mhlo::PowOp, false>,
PointwiseToLinalgConverter<mhlo::RealOp, false>,
PointwiseToLinalgConverter<mhlo::RemOp, false>,
PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,

View File

@ -745,3 +745,42 @@ func @constant() {
return
}
// CHECK: %[[CONSTANT:.*]] = constant dense<10> : tensor<i32>
// -----
// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @float_pow
func @float_pow(%lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic
// CHECK: ^{{[a-z0-9_]*}}
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: f32
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: f32
// CHECK: %[[RESULT:[a-zA-Z0-9_]*]] = powf %[[ARG0]], %[[ARG1]]
// CHECK: linalg.yield %[[RESULT]]
%0 = "mhlo.power"(%lhs, %rhs) : (tensor<2x2xf32>,
tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
// -----
// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @integer_pow
func @integer_pow(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: linalg.generic
// CHECK: ^{{[a-z0-9_]*}}
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: i32
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: i32
// CHECK: %[[UPPER:.*]] = index_cast %[[ARG1]]
// CHECK: %[[RESULT:.*]] = scf.for {{.*}} to %[[UPPER]]
// CHECK-SAME: step %c1{{[a-zA-Z0-9_]*}}
// CHECK-SAME: iter_args(%[[ITER:.*]] = %c1{{.*}}) -> (i32) {
// CHECK: %[[ACCUM:[a-zA-Z0-9_]*]] = muli %[[ARG0]], %[[ITER]]
// CHECK: scf.yield %[[ACCUM]]
// CHECK: linalg.yield %[[RESULT]]
%0 = "mhlo.power"(%lhs, %rhs) : (tensor<2x2xi32>,
tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}

View File

@ -1,5 +1,20 @@
// RUN: mlir-hlo-opt %s -lhlo-legalize-to-linalg -split-input-file | FILECHECK_OPTS="" FileCheck %s
// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @element_wise
func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
%result: memref<2x2xf32>) {
"lmhlo.power"(%lhs, %rhs, %result)
: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %[[RESULT_OUT:.*]]: f32):
// CHECK-NEXT: %[[RESULT:.*]] = powf %[[LHS_IN]], %[[RHS_IN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// -----
// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @element_wise
func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,