From ba2ee556f1d25600ae67b0ec51059fdca27d8ff3 Mon Sep 17 00:00:00 2001 From: Tres Popp Date: Mon, 18 Jan 2021 03:46:41 -0800 Subject: [PATCH] Handle negative exponents for lowering of hlo.pow PiperOrigin-RevId: 352382812 --- .../mhlo/transforms/map_lmhlo_to_scalar_op.h | 54 ++++++++++++++----- tests/hlo-legalize-to-linalg.mlir | 13 ++++- 2 files changed, 53 insertions(+), 14 deletions(-) diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index bc163a6..636cd8c 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -556,23 +556,53 @@ inline Value MapLhloOpToStdScalarOp(Location loc, assert(result_type.isa<::mlir::IntegerType>() && "only float and integer `pow` is supported right now"); - // There is no powi, so lower to a simple product. Note that HLO does not - // define semantics of negative exponents. - Value init = b->create(loc, b->getIntegerAttr(result_type, 1)); + // There is no powi, so lower to a simple product. 
+ Value neg_one = + b->create(loc, b->getIntegerAttr(result_type, -1)); + Value zero = b->create(loc, b->getIntegerAttr(result_type, 0)); + Value one = b->create(loc, b->getIntegerAttr(result_type, 1)); + Value two = b->create(loc, b->getIntegerAttr(result_type, 2)); Value lowerBound = b->create(loc, 0); Value upperBound = b->create(loc, adaptor.rhs(), b->getIndexType()); Value step = b->create(loc, 1); - return b - ->create( - loc, lowerBound, upperBound, step, llvm::makeArrayRef(init), - [&](OpBuilder& b, Location l, Value v, ValueRange iters) { - Value prod = - b.create<::mlir::MulIOp>(l, adaptor.lhs(), iters.front()); - b.create(l, prod); - }) - .getResult(0); + Value for_result = + b->create( + loc, lowerBound, upperBound, step, llvm::makeArrayRef(one), + [&](OpBuilder& b, Location l, Value v, ValueRange iters) { + Value prod = + b.create<::mlir::MulIOp>(l, adaptor.lhs(), iters.front()); + b.create(l, prod); + }) + .getResult(0); + + Value rhs_is_even = + b->create(loc, CmpIPredicate::eq, + b->create(loc, adaptor.rhs(), two), zero); + Value rhs_is_negative = + b->create(loc, CmpIPredicate::slt, adaptor.rhs(), zero); + Value lhs_is_one = + b->create(loc, CmpIPredicate::eq, adaptor.lhs(), one); + Value lhs_is_neg_one = + b->create(loc, CmpIPredicate::eq, adaptor.lhs(), neg_one); + + // The for_result is correct when the rhs is non-negative. When rhs is + // negative, we return 0 for integer, with the exception of lhs values of 1 + // and -1 which have integer results for negative exponents. Specifically, the + // calculation is the following: + // + // - Return for_result if the rhs is not negative. + // - Return 1 or -1 depending on the parity of rhs when the lhs is -1. + // - Return 1 if lhs is 1. + // - Else return 0. 
+ Value if_lhs_is_one = b->create<::mlir::SelectOp>(loc, lhs_is_one, one, zero); + Value if_lhs_is_neg_one = b->create<::mlir::SelectOp>( + loc, lhs_is_neg_one, + b->create<::mlir::SelectOp>(loc, rhs_is_even, one, neg_one), + if_lhs_is_one); + return b->create<::mlir::SelectOp>(loc, rhs_is_negative, if_lhs_is_neg_one, + for_result); } template <> diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index 6f59bf2..298ce7b 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -792,17 +792,26 @@ func @float_pow(%lhs: tensor<2x2xf32>, // CHECK: #map = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @integer_pow func @integer_pow(%lhs: tensor<2x2xi32>, - %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { + %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { // CHECK: linalg.generic // CHECK: ^{{[a-z0-9_]*}} // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: i32 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: i32 // CHECK: %[[UPPER:.*]] = index_cast %[[ARG1]] - // CHECK: %[[RESULT:.*]] = scf.for {{.*}} to %[[UPPER]] + // CHECK: %[[FOR_RESULT:.*]] = scf.for {{.*}} to %[[UPPER]] // CHECK-SAME: step %c1{{[a-zA-Z0-9_]*}} // CHECK-SAME: iter_args(%[[ITER:.*]] = %c1{{.*}}) -> (i32) { // CHECK: %[[ACCUM:[a-zA-Z0-9_]*]] = muli %[[ARG0]], %[[ITER]] // CHECK: scf.yield %[[ACCUM]] + // CHECK: %[[RHS_PARITY:.*]] = remi_signed %[[ARG1]], %c2 + // CHECK: %[[RHS_EVEN:.*]] = cmpi eq, %[[RHS_PARITY]], %c0 + // CHECK: %[[RHS_NEG:.*]] = cmpi slt, %[[ARG1]], %c0 + // CHECK: %[[LHS_ONE:.*]] = cmpi eq, %[[ARG0]], %c1 + // CHECK: %[[LHS_NEG_ONE:.*]] = cmpi eq, %[[ARG0]], %c-1 + // CHECK: %[[VAL5:.*]] = select %[[LHS_ONE]], %c1_i32, %c0 + // CHECK: %[[VAL6:.*]] = select %[[RHS_EVEN]], %c1{{.*}}, %c-1 + // CHECK: %[[VAL7:.*]] = select %[[LHS_NEG_ONE]], %[[VAL6]], %[[VAL5]] + // CHECK: %[[RESULT:.*]] = select %[[RHS_NEG]], %[[VAL7]], %[[FOR_RESULT]] // CHECK: linalg.yield %[[RESULT]] %0 = "mhlo.power"(%lhs, %rhs) : (tensor<2x2xi32>, tensor<2x2xi32>) -> 
tensor<2x2xi32>