Improve performance of lowered chlo.pow with integers

The new lowering takes 6 iterations of a loop always rather than iterating the exponent's number of times.

PiperOrigin-RevId: 355131133
This commit is contained in:
Tres Popp 2021-02-02 03:27:38 -08:00 committed by TensorFlow MLIR Team
parent 0458ae9a22
commit ae722a883f
2 changed files with 62 additions and 40 deletions

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h" #include "mlir/IR/TypeUtilities.h"
namespace mlir { namespace mlir {
@ -588,6 +589,7 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::PowOp>(Location loc,
ArrayRef<Value> args, ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
lmhlo::PowOp::Adaptor adaptor(args); lmhlo::PowOp::Adaptor adaptor(args);
auto lb = ImplicitLocOpBuilder(loc, *b);
// Floating point can use std::powf // Floating point can use std::powf
auto result_type = result_types.front(); auto result_type = result_types.front();
if (result_type.isa<::mlir::FloatType>()) if (result_type.isa<::mlir::FloatType>())
@ -597,53 +599,66 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::PowOp>(Location loc,
assert(result_type.isa<::mlir::IntegerType>() && assert(result_type.isa<::mlir::IntegerType>() &&
"only float and integer `pow` is supported right now"); "only float and integer `pow` is supported right now");
// There is no powi, so lower to a simple product. // Exponentiation by squaring:
Value neg_one = // https://en.wikipedia.org/wiki/Exponentiation_by_squaring;
b->create<ConstantOp>(loc, b->getIntegerAttr(result_type, -1)); Value neg_one = lb.create<ConstantOp>(lb.getIntegerAttr(result_type, -1));
Value zero = b->create<ConstantOp>(loc, b->getIntegerAttr(result_type, 0)); Value zero = lb.create<ConstantOp>(lb.getIntegerAttr(result_type, 0));
Value one = b->create<ConstantOp>(loc, b->getIntegerAttr(result_type, 1)); Value one = lb.create<ConstantOp>(lb.getIntegerAttr(result_type, 1));
Value two = b->create<ConstantOp>(loc, b->getIntegerAttr(result_type, 2)); Value two = lb.create<ConstantOp>(lb.getIntegerAttr(result_type, 2));
Value step = lb.create<ConstantIndexOp>(1);
Value lowerBound = lb.create<ConstantIndexOp>(0);
// Everything else would overflow for any exponent > 1, as 2^64
// is the larget possible exponent for a 64-bit integer, and
// that's 1 << 6.
Value upperBound = lb.create<ConstantIndexOp>(6);
auto original_base = adaptor.lhs();
auto original_exponent = adaptor.rhs();
Value lowerBound = b->create<ConstantIndexOp>(loc, 0); Value accum =
Value upperBound = lb.create<scf::ForOp>(
b->create<IndexCastOp>(loc, adaptor.rhs(), b->getIndexType()); lowerBound, upperBound, step,
Value step = b->create<ConstantIndexOp>(loc, 1); SmallVector<Value>({one, original_base, original_exponent}),
Value for_result = [&](OpBuilder& b, Location, Value v, ValueRange iters) {
b->create<scf::ForOp>( Value accum = iters[0];
loc, lowerBound, upperBound, step, llvm::makeArrayRef(one), Value base = iters[1];
[&](OpBuilder& b, Location l, Value v, ValueRange iters) { Value exponent = iters[2];
Value prod =
b.create<::mlir::MulIOp>(l, adaptor.lhs(), iters.front()); Value condition = b.create<CmpIOp>(
b.create<scf::YieldOp>(l, prod); loc, CmpIPredicate::eq,
b.create<::mlir::AndOp>(loc, exponent, one), one);
Value multiplied = b.create<::mlir::MulIOp>(loc, accum, base);
accum =
b.create<::mlir::SelectOp>(loc, condition, multiplied, accum);
base = b.create<::mlir::MulIOp>(loc, base, base);
exponent =
b.create<::mlir::UnsignedShiftRightOp>(loc, exponent, one);
b.create<scf::YieldOp>(
loc, SmallVector<Value>({accum, base, exponent}));
}) })
.getResult(0); .getResult(0);
Value rhs_is_even = Value rhs_is_even = lb.create<CmpIOp>(
b->create<CmpIOp>(loc, CmpIPredicate::eq, CmpIPredicate::eq, lb.create<SignedRemIOp>(adaptor.rhs(), two), zero);
b->create<SignedRemIOp>(loc, adaptor.rhs(), two), zero);
Value rhs_is_negative = Value rhs_is_negative =
b->create<CmpIOp>(loc, CmpIPredicate::slt, adaptor.rhs(), zero); lb.create<CmpIOp>(CmpIPredicate::slt, adaptor.rhs(), zero);
Value lhs_is_one = Value lhs_is_one = lb.create<CmpIOp>(CmpIPredicate::eq, adaptor.lhs(), one);
b->create<CmpIOp>(loc, CmpIPredicate::eq, adaptor.lhs(), one);
Value lhs_is_neg_one = Value lhs_is_neg_one =
b->create<CmpIOp>(loc, CmpIPredicate::eq, adaptor.lhs(), neg_one); lb.create<CmpIOp>(CmpIPredicate::eq, adaptor.lhs(), neg_one);
// The for_result is correct when the rhs is non-negative. When rhs is // The accum is correct when the rhs is non-negative. When rhs is
// negative, we return 0 for integer, with the exception of lhs values of 1 // negative, we return 0 for integer, with the exception of lhs values of 1
// and -1 which have integer results for negative exponents. Specifically, the // and -1 which have integer results for negative exponents. Specifically, the
// calulation is the following: // calulation is the following:
// //
// - Return for_result if the rhs is not negative. // - Return accum if the rhs is not negative.
// - Return 1 or -1 depending on the parity of rhs when the lhs is -1. // - Return 1 or -1 depending on the parity of rhs when the lhs is -1.
// - Return 1 if lhs is 1. // - Return 1 if lhs is 1.
// - Else return 0. // - Else return 0.
Value if_lhs_is_one = b->create<::mlir::SelectOp>(loc, lhs_is_one, one, zero); Value if_lhs_is_one = lb.create<::mlir::SelectOp>(lhs_is_one, one, zero);
Value if_lhs_is_neg_one = b->create<::mlir::SelectOp>( Value if_lhs_is_neg_one = lb.create<::mlir::SelectOp>(
loc, lhs_is_neg_one, lhs_is_neg_one, lb.create<::mlir::SelectOp>(rhs_is_even, one, neg_one),
b->create<::mlir::SelectOp>(loc, rhs_is_even, one, neg_one),
if_lhs_is_one); if_lhs_is_one);
return b->create<::mlir::SelectOp>(loc, rhs_is_negative, if_lhs_is_neg_one, return lb.create<::mlir::SelectOp>(rhs_is_negative, if_lhs_is_neg_one, accum);
for_result);
} }
template <> template <>

View File

@ -851,12 +851,19 @@ func @integer_pow(%lhs: tensor<2x2xi32>,
// CHECK: ^{{[a-z0-9_]*}} // CHECK: ^{{[a-z0-9_]*}}
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: i32 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: i32
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: i32 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: i32
// CHECK: %[[UPPER:.*]] = index_cast %[[ARG1]] // CHECK: %[[FOR_RESULT:[a-zA-Z0-9_]*]]:3 = scf.for {{.*}} to %c6 step %c1
// CHECK: %[[FOR_RESULT:.*]] = scf.for {{.*}} to %[[UPPER]] // CHECK-SAME: iter_args(
// CHECK-SAME: step %c1{{[a-zA-Z0-9_]*}} // CHECK-SAME: %[[ITER0:.*]] = %c1
// CHECK-SAME: iter_args(%[[ITER:.*]] = %c1{{.*}}) -> (i32) { // CHECK-SAME: %[[ITER1:.*]] = %[[ARG0]]
// CHECK: %[[ACCUM:[a-zA-Z0-9_]*]] = muli %[[ARG0]], %[[ITER]] // CHECK-SAME: %[[ITER2:.*]] = %[[ARG1]]
// CHECK: scf.yield %[[ACCUM]] // CHECK-SAME: ) -> (i32, i32, i32) {
// CHECK: %[[AND:[a-zA-Z0-9_]*]] = and %[[ITER2]], %c1
// CHECK: %[[COND:[a-zA-Z0-9_]*]] = cmpi eq, %[[AND]], %c1
// CHECK: %[[MUL:[a-zA-Z0-9_]*]] = muli %[[ITER0]], %[[ITER1]]
// CHECK: %[[ACCUM:[a-zA-Z0-9_]*]] = select %[[COND]], %[[MUL]], %[[ITER0]]
// CHECK: %[[BASE:[a-zA-Z0-9_]*]] = muli %[[ITER1]], %[[ITER1]]
// CHECK: %[[EXP:[a-zA-Z0-9_]*]] = shift_right_unsigned %[[ITER2]], %c1
// CHECK: scf.yield %[[ACCUM]], %[[BASE]], %[[EXP]]
// CHECK: %[[RHS_PARITY:.*]] = remi_signed %[[ARG1]], %c2 // CHECK: %[[RHS_PARITY:.*]] = remi_signed %[[ARG1]], %c2
// CHECK: %[[RHS_EVEN:.*]] = cmpi eq, %[[RHS_PARITY]], %c0 // CHECK: %[[RHS_EVEN:.*]] = cmpi eq, %[[RHS_PARITY]], %c0
// CHECK: %[[RHS_NEG:.*]] = cmpi slt, %[[ARG1]], %c0 // CHECK: %[[RHS_NEG:.*]] = cmpi slt, %[[ARG1]], %c0
@ -865,7 +872,7 @@ func @integer_pow(%lhs: tensor<2x2xi32>,
// CHECK: %[[VAL5:.*]] = select %[[LHS_ONE]], %c1_i32, %c0 // CHECK: %[[VAL5:.*]] = select %[[LHS_ONE]], %c1_i32, %c0
// CHECK: %[[VAL6:.*]] = select %[[RHS_EVEN]], %c1{{.*}}, %c-1 // CHECK: %[[VAL6:.*]] = select %[[RHS_EVEN]], %c1{{.*}}, %c-1
// CHECK: %[[VAL7:.*]] = select %[[LHS_NEG_ONE]], %[[VAL6]], %[[VAL5]] // CHECK: %[[VAL7:.*]] = select %[[LHS_NEG_ONE]], %[[VAL6]], %[[VAL5]]
// CHECK: %[[RESULT:.*]] = select %[[RHS_NEG]], %[[VAL7]], %[[FOR_RESULT]] // CHECK: %[[RESULT:.*]] = select %[[RHS_NEG]], %[[VAL7]], %[[FOR_RESULT]]#0
// CHECK: linalg.yield %[[RESULT]] // CHECK: linalg.yield %[[RESULT]]
%0 = "mhlo.power"(%lhs, %rhs) : (tensor<2x2xi32>, %0 = "mhlo.power"(%lhs, %rhs) : (tensor<2x2xi32>,
tensor<2x2xi32>) -> tensor<2x2xi32> tensor<2x2xi32>) -> tensor<2x2xi32>