Define lowering of [l]mhlo.pow.
For floating point operations, this uses std.pow. For integer operations, this lowers to a loop. This adds a dependency on scf. PiperOrigin-RevId: 348537232
This commit is contained in:
parent
99a0ee378c
commit
a42213b870
2
BUILD
2
BUILD
|
@ -593,6 +593,7 @@ cc_library(
|
|||
":map_hlo_to_lhlo_op",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:SCFDialect",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
],
|
||||
)
|
||||
|
@ -664,6 +665,7 @@ cc_library(
|
|||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:LinalgOps",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:SCFDialect",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
],
|
||||
|
|
|
@ -66,6 +66,7 @@ MAP_HLO_TO_LHLO(MulOp);
|
|||
MAP_HLO_TO_LHLO(NegOp);
|
||||
MAP_HLO_TO_LHLO(NotOp);
|
||||
MAP_HLO_TO_LHLO(OrOp);
|
||||
MAP_HLO_TO_LHLO(PowOp);
|
||||
MAP_HLO_TO_LHLO(RealOp);
|
||||
MAP_HLO_TO_LHLO(ReduceOp);
|
||||
MAP_HLO_TO_LHLO(ReshapeOp);
|
||||
|
|
|
@ -16,12 +16,16 @@ limitations under the License.
|
|||
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ADT/StringSwitch.h"
|
||||
#include "llvm/ADT/iterator_range.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
|
||||
namespace mlir {
|
||||
|
@ -508,6 +512,40 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
|
|||
loc, result_types, args, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::PowOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
lmhlo::PowOp::Adaptor adaptor(args);
|
||||
// Floating point can use std::powf
|
||||
auto result_type = result_types.front();
|
||||
if (result_type.isa<::mlir::FloatType>())
|
||||
return MapLhloOpToStdScalarOpImpl<::mlir::PowFOp>{}(loc, result_types, args,
|
||||
b);
|
||||
|
||||
assert(result_type.isa<::mlir::IntegerType>() &&
|
||||
"only float and integer `pow` is supported right now");
|
||||
|
||||
// There is no powi, so lower to a simple product. Note that HLO does not
|
||||
// define semantics of negative exponents.
|
||||
Value init = b->create<ConstantOp>(loc, b->getIntegerAttr(result_type, 1));
|
||||
|
||||
Value lowerBound = b->create<ConstantIndexOp>(loc, 0);
|
||||
Value upperBound =
|
||||
b->create<IndexCastOp>(loc, adaptor.rhs(), b->getIndexType());
|
||||
Value step = b->create<ConstantIndexOp>(loc, 1);
|
||||
return b
|
||||
->create<scf::ForOp>(
|
||||
loc, lowerBound, upperBound, step, llvm::makeArrayRef(init),
|
||||
[&](OpBuilder& b, Location l, Value v, ValueRange iters) {
|
||||
Value prod =
|
||||
b.create<::mlir::MulIOp>(l, adaptor.lhs(), iters.front());
|
||||
b.create<scf::YieldOp>(l, prod);
|
||||
})
|
||||
.getResult(0);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::SelectOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
|
|
|
@ -650,6 +650,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
|
|||
HloToLhloOpConverter<mhlo::NegOp>,
|
||||
HloToLhloOpConverter<mhlo::NotOp>,
|
||||
HloToLhloOpConverter<mhlo::OrOp>,
|
||||
HloToLhloOpConverter<mhlo::PowOp>,
|
||||
HloToLhloOpConverter<mhlo::RealOp>,
|
||||
HloToLhloOpConverter<mhlo::RemOp>,
|
||||
HloToLhloOpConverter<mhlo::RsqrtOp>,
|
||||
|
|
|
@ -25,6 +25,7 @@ limitations under the License.
|
|||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
|
@ -957,6 +958,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
|||
PointwiseToLinalgConverter<lmhlo::NegOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::NotOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::OrOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::PowOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::RealOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::RemOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::RsqrtOp>,
|
||||
|
@ -1021,13 +1023,14 @@ struct LhloLegalizeToLinalgPass
|
|||
struct HloLegalizeToLinalgPass
|
||||
: public PassWrapper<HloLegalizeToLinalgPass, FunctionPass> {
|
||||
void getDependentDialects(DialectRegistry& registry) const override {
|
||||
registry.insert<linalg::LinalgDialect>();
|
||||
registry.insert<linalg::LinalgDialect, scf::SCFDialect>();
|
||||
}
|
||||
|
||||
void runOnFunction() override {
|
||||
OwningRewritePatternList patterns;
|
||||
ConversionTarget target(getContext());
|
||||
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();
|
||||
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
|
||||
scf::SCFDialect>();
|
||||
|
||||
auto func = getFunction();
|
||||
mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
|
||||
|
@ -1075,6 +1078,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
|||
PointwiseToLinalgConverter<mhlo::NegOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::NotOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::OrOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::PowOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::RealOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::RemOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
|
||||
|
|
|
@ -745,3 +745,42 @@ func @constant() {
|
|||
return
|
||||
}
|
||||
// CHECK: %[[CONSTANT:.*]] = constant dense<10> : tensor<i32>
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
|
||||
// CHECK-LABEL: func @float_pow
|
||||
func @float_pow(%lhs: tensor<2x2xf32>,
|
||||
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: ^{{[a-z0-9_]*}}
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: f32
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: f32
|
||||
// CHECK: %[[RESULT:[a-zA-Z0-9_]*]] = powf %[[ARG0]], %[[ARG1]]
|
||||
// CHECK: linalg.yield %[[RESULT]]
|
||||
%0 = "mhlo.power"(%lhs, %rhs) : (tensor<2x2xf32>,
|
||||
tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %0 : tensor<2x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
|
||||
// CHECK-LABEL: func @integer_pow
|
||||
func @integer_pow(%lhs: tensor<2x2xi32>,
|
||||
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: ^{{[a-z0-9_]*}}
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: i32
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: i32
|
||||
// CHECK: %[[UPPER:.*]] = index_cast %[[ARG1]]
|
||||
// CHECK: %[[RESULT:.*]] = scf.for {{.*}} to %[[UPPER]]
|
||||
// CHECK-SAME: step %c1{{[a-zA-Z0-9_]*}}
|
||||
// CHECK-SAME: iter_args(%[[ITER:.*]] = %c1{{.*}}) -> (i32) {
|
||||
// CHECK: %[[ACCUM:[a-zA-Z0-9_]*]] = muli %[[ARG0]], %[[ITER]]
|
||||
// CHECK: scf.yield %[[ACCUM]]
|
||||
// CHECK: linalg.yield %[[RESULT]]
|
||||
%0 = "mhlo.power"(%lhs, %rhs) : (tensor<2x2xi32>,
|
||||
tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||
return %0 : tensor<2x2xi32>
|
||||
}
|
||||
|
|
|
@ -1,5 +1,20 @@
|
|||
// RUN: mlir-hlo-opt %s -lhlo-legalize-to-linalg -split-input-file | FILECHECK_OPTS="" FileCheck %s
|
||||
|
||||
// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
|
||||
// CHECK-LABEL: func @element_wise
|
||||
func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
|
||||
%result: memref<2x2xf32>) {
|
||||
"lmhlo.power"(%lhs, %rhs, %result)
|
||||
: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %[[RESULT_OUT:.*]]: f32):
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = powf %[[LHS_IN]], %[[RHS_IN]] : f32
|
||||
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
|
||||
// CHECK-LABEL: func @element_wise
|
||||
func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
|
||||
|
|
Loading…
Reference in New Issue