Define lowering of [l]mhlo.pow.
For floating point operations, this uses std.pow. For integer operations, this lowers to a loop. This adds a dependency on scf. PiperOrigin-RevId: 348537232
This commit is contained in:
parent
99a0ee378c
commit
a42213b870
2
BUILD
2
BUILD
|
@ -593,6 +593,7 @@ cc_library(
|
||||||
":map_hlo_to_lhlo_op",
|
":map_hlo_to_lhlo_op",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
|
"@llvm-project//mlir:SCFDialect",
|
||||||
"@llvm-project//mlir:StandardOps",
|
"@llvm-project//mlir:StandardOps",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -664,6 +665,7 @@ cc_library(
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:LinalgOps",
|
"@llvm-project//mlir:LinalgOps",
|
||||||
"@llvm-project//mlir:Pass",
|
"@llvm-project//mlir:Pass",
|
||||||
|
"@llvm-project//mlir:SCFDialect",
|
||||||
"@llvm-project//mlir:StandardOps",
|
"@llvm-project//mlir:StandardOps",
|
||||||
"@llvm-project//mlir:Transforms",
|
"@llvm-project//mlir:Transforms",
|
||||||
],
|
],
|
||||||
|
|
|
@ -66,6 +66,7 @@ MAP_HLO_TO_LHLO(MulOp);
|
||||||
MAP_HLO_TO_LHLO(NegOp);
|
MAP_HLO_TO_LHLO(NegOp);
|
||||||
MAP_HLO_TO_LHLO(NotOp);
|
MAP_HLO_TO_LHLO(NotOp);
|
||||||
MAP_HLO_TO_LHLO(OrOp);
|
MAP_HLO_TO_LHLO(OrOp);
|
||||||
|
MAP_HLO_TO_LHLO(PowOp);
|
||||||
MAP_HLO_TO_LHLO(RealOp);
|
MAP_HLO_TO_LHLO(RealOp);
|
||||||
MAP_HLO_TO_LHLO(ReduceOp);
|
MAP_HLO_TO_LHLO(ReduceOp);
|
||||||
MAP_HLO_TO_LHLO(ReshapeOp);
|
MAP_HLO_TO_LHLO(ReshapeOp);
|
||||||
|
|
|
@ -16,12 +16,16 @@ limitations under the License.
|
||||||
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_
|
#ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_
|
||||||
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_
|
#define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_
|
||||||
|
|
||||||
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
#include "llvm/ADT/StringSwitch.h"
|
#include "llvm/ADT/StringSwitch.h"
|
||||||
|
#include "llvm/ADT/iterator_range.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
|
#include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
|
||||||
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "mlir/IR/TypeUtilities.h"
|
#include "mlir/IR/TypeUtilities.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
@ -508,6 +512,40 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
|
||||||
loc, result_types, args, b);
|
loc, result_types, args, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline Value MapLhloOpToStdScalarOp<lmhlo::PowOp>(Location loc,
|
||||||
|
ArrayRef<Type> result_types,
|
||||||
|
ArrayRef<Value> args,
|
||||||
|
OpBuilder* b) {
|
||||||
|
lmhlo::PowOp::Adaptor adaptor(args);
|
||||||
|
// Floating point can use std::powf
|
||||||
|
auto result_type = result_types.front();
|
||||||
|
if (result_type.isa<::mlir::FloatType>())
|
||||||
|
return MapLhloOpToStdScalarOpImpl<::mlir::PowFOp>{}(loc, result_types, args,
|
||||||
|
b);
|
||||||
|
|
||||||
|
assert(result_type.isa<::mlir::IntegerType>() &&
|
||||||
|
"only float and integer `pow` is supported right now");
|
||||||
|
|
||||||
|
// There is no powi, so lower to a simple product. Note that HLO does not
|
||||||
|
// define semantics of negative exponents.
|
||||||
|
Value init = b->create<ConstantOp>(loc, b->getIntegerAttr(result_type, 1));
|
||||||
|
|
||||||
|
Value lowerBound = b->create<ConstantIndexOp>(loc, 0);
|
||||||
|
Value upperBound =
|
||||||
|
b->create<IndexCastOp>(loc, adaptor.rhs(), b->getIndexType());
|
||||||
|
Value step = b->create<ConstantIndexOp>(loc, 1);
|
||||||
|
return b
|
||||||
|
->create<scf::ForOp>(
|
||||||
|
loc, lowerBound, upperBound, step, llvm::makeArrayRef(init),
|
||||||
|
[&](OpBuilder& b, Location l, Value v, ValueRange iters) {
|
||||||
|
Value prod =
|
||||||
|
b.create<::mlir::MulIOp>(l, adaptor.lhs(), iters.front());
|
||||||
|
b.create<scf::YieldOp>(l, prod);
|
||||||
|
})
|
||||||
|
.getResult(0);
|
||||||
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
inline Value MapLhloOpToStdScalarOp<lmhlo::SelectOp>(
|
inline Value MapLhloOpToStdScalarOp<lmhlo::SelectOp>(
|
||||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||||
|
|
|
@ -650,6 +650,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context,
|
||||||
HloToLhloOpConverter<mhlo::NegOp>,
|
HloToLhloOpConverter<mhlo::NegOp>,
|
||||||
HloToLhloOpConverter<mhlo::NotOp>,
|
HloToLhloOpConverter<mhlo::NotOp>,
|
||||||
HloToLhloOpConverter<mhlo::OrOp>,
|
HloToLhloOpConverter<mhlo::OrOp>,
|
||||||
|
HloToLhloOpConverter<mhlo::PowOp>,
|
||||||
HloToLhloOpConverter<mhlo::RealOp>,
|
HloToLhloOpConverter<mhlo::RealOp>,
|
||||||
HloToLhloOpConverter<mhlo::RemOp>,
|
HloToLhloOpConverter<mhlo::RemOp>,
|
||||||
HloToLhloOpConverter<mhlo::RsqrtOp>,
|
HloToLhloOpConverter<mhlo::RsqrtOp>,
|
||||||
|
|
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||||
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
||||||
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/IR/AffineExpr.h"
|
#include "mlir/IR/AffineExpr.h"
|
||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
|
@ -957,6 +958,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
PointwiseToLinalgConverter<lmhlo::NegOp>,
|
PointwiseToLinalgConverter<lmhlo::NegOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::NotOp>,
|
PointwiseToLinalgConverter<lmhlo::NotOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::OrOp>,
|
PointwiseToLinalgConverter<lmhlo::OrOp>,
|
||||||
|
PointwiseToLinalgConverter<lmhlo::PowOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::RealOp>,
|
PointwiseToLinalgConverter<lmhlo::RealOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::RemOp>,
|
PointwiseToLinalgConverter<lmhlo::RemOp>,
|
||||||
PointwiseToLinalgConverter<lmhlo::RsqrtOp>,
|
PointwiseToLinalgConverter<lmhlo::RsqrtOp>,
|
||||||
|
@ -1021,13 +1023,14 @@ struct LhloLegalizeToLinalgPass
|
||||||
struct HloLegalizeToLinalgPass
|
struct HloLegalizeToLinalgPass
|
||||||
: public PassWrapper<HloLegalizeToLinalgPass, FunctionPass> {
|
: public PassWrapper<HloLegalizeToLinalgPass, FunctionPass> {
|
||||||
void getDependentDialects(DialectRegistry& registry) const override {
|
void getDependentDialects(DialectRegistry& registry) const override {
|
||||||
registry.insert<linalg::LinalgDialect>();
|
registry.insert<linalg::LinalgDialect, scf::SCFDialect>();
|
||||||
}
|
}
|
||||||
|
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
ConversionTarget target(getContext());
|
ConversionTarget target(getContext());
|
||||||
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();
|
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
|
||||||
|
scf::SCFDialect>();
|
||||||
|
|
||||||
auto func = getFunction();
|
auto func = getFunction();
|
||||||
mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
|
mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
|
||||||
|
@ -1075,6 +1078,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||||
PointwiseToLinalgConverter<mhlo::NegOp, false>,
|
PointwiseToLinalgConverter<mhlo::NegOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::NotOp, false>,
|
PointwiseToLinalgConverter<mhlo::NotOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::OrOp, false>,
|
PointwiseToLinalgConverter<mhlo::OrOp, false>,
|
||||||
|
PointwiseToLinalgConverter<mhlo::PowOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::RealOp, false>,
|
PointwiseToLinalgConverter<mhlo::RealOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::RemOp, false>,
|
PointwiseToLinalgConverter<mhlo::RemOp, false>,
|
||||||
PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
|
PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
|
||||||
|
|
|
@ -745,3 +745,42 @@ func @constant() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// CHECK: %[[CONSTANT:.*]] = constant dense<10> : tensor<i32>
|
// CHECK: %[[CONSTANT:.*]] = constant dense<10> : tensor<i32>
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
|
||||||
|
// CHECK-LABEL: func @float_pow
|
||||||
|
func @float_pow(%lhs: tensor<2x2xf32>,
|
||||||
|
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK: ^{{[a-z0-9_]*}}
|
||||||
|
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: f32
|
||||||
|
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: f32
|
||||||
|
// CHECK: %[[RESULT:[a-zA-Z0-9_]*]] = powf %[[ARG0]], %[[ARG1]]
|
||||||
|
// CHECK: linalg.yield %[[RESULT]]
|
||||||
|
%0 = "mhlo.power"(%lhs, %rhs) : (tensor<2x2xf32>,
|
||||||
|
tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||||
|
return %0 : tensor<2x2xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
|
||||||
|
// CHECK-LABEL: func @integer_pow
|
||||||
|
func @integer_pow(%lhs: tensor<2x2xi32>,
|
||||||
|
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK: ^{{[a-z0-9_]*}}
|
||||||
|
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: i32
|
||||||
|
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: i32
|
||||||
|
// CHECK: %[[UPPER:.*]] = index_cast %[[ARG1]]
|
||||||
|
// CHECK: %[[RESULT:.*]] = scf.for {{.*}} to %[[UPPER]]
|
||||||
|
// CHECK-SAME: step %c1{{[a-zA-Z0-9_]*}}
|
||||||
|
// CHECK-SAME: iter_args(%[[ITER:.*]] = %c1{{.*}}) -> (i32) {
|
||||||
|
// CHECK: %[[ACCUM:[a-zA-Z0-9_]*]] = muli %[[ARG0]], %[[ITER]]
|
||||||
|
// CHECK: scf.yield %[[ACCUM]]
|
||||||
|
// CHECK: linalg.yield %[[RESULT]]
|
||||||
|
%0 = "mhlo.power"(%lhs, %rhs) : (tensor<2x2xi32>,
|
||||||
|
tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||||
|
return %0 : tensor<2x2xi32>
|
||||||
|
}
|
||||||
|
|
|
@ -1,5 +1,20 @@
|
||||||
// RUN: mlir-hlo-opt %s -lhlo-legalize-to-linalg -split-input-file | FILECHECK_OPTS="" FileCheck %s
|
// RUN: mlir-hlo-opt %s -lhlo-legalize-to-linalg -split-input-file | FILECHECK_OPTS="" FileCheck %s
|
||||||
|
|
||||||
|
// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
|
||||||
|
// CHECK-LABEL: func @element_wise
|
||||||
|
func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
|
||||||
|
%result: memref<2x2xf32>) {
|
||||||
|
"lmhlo.power"(%lhs, %rhs, %result)
|
||||||
|
: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK-NEXT: ^bb0(%[[LHS_IN:.*]]: f32, %[[RHS_IN:.*]]: f32, %[[RESULT_OUT:.*]]: f32):
|
||||||
|
// CHECK-NEXT: %[[RESULT:.*]] = powf %[[LHS_IN]], %[[RHS_IN]] : f32
|
||||||
|
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
|
// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
|
||||||
// CHECK-LABEL: func @element_wise
|
// CHECK-LABEL: func @element_wise
|
||||||
func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
|
func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
|
||||||
|
|
Loading…
Reference in New Issue