Handle negative exponents for lowering of hlo.pow
PiperOrigin-RevId: 352382812
This commit is contained in:
parent
3763740910
commit
ba2ee556f1
|
@ -556,23 +556,53 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::PowOp>(Location loc,
|
||||||
assert(result_type.isa<::mlir::IntegerType>() &&
|
assert(result_type.isa<::mlir::IntegerType>() &&
|
||||||
"only float and integer `pow` is supported right now");
|
"only float and integer `pow` is supported right now");
|
||||||
|
|
||||||
// There is no powi, so lower to a simple product. Note that HLO does not
|
// There is no powi, so lower to a simple product.
|
||||||
// define semantics of negative exponents.
|
Value neg_one =
|
||||||
Value init = b->create<ConstantOp>(loc, b->getIntegerAttr(result_type, 1));
|
b->create<ConstantOp>(loc, b->getIntegerAttr(result_type, -1));
|
||||||
|
Value zero = b->create<ConstantOp>(loc, b->getIntegerAttr(result_type, 0));
|
||||||
|
Value one = b->create<ConstantOp>(loc, b->getIntegerAttr(result_type, 1));
|
||||||
|
Value two = b->create<ConstantOp>(loc, b->getIntegerAttr(result_type, 2));
|
||||||
|
|
||||||
Value lowerBound = b->create<ConstantIndexOp>(loc, 0);
|
Value lowerBound = b->create<ConstantIndexOp>(loc, 0);
|
||||||
Value upperBound =
|
Value upperBound =
|
||||||
b->create<IndexCastOp>(loc, adaptor.rhs(), b->getIndexType());
|
b->create<IndexCastOp>(loc, adaptor.rhs(), b->getIndexType());
|
||||||
Value step = b->create<ConstantIndexOp>(loc, 1);
|
Value step = b->create<ConstantIndexOp>(loc, 1);
|
||||||
return b
|
Value for_result =
|
||||||
->create<scf::ForOp>(
|
b->create<scf::ForOp>(
|
||||||
loc, lowerBound, upperBound, step, llvm::makeArrayRef(init),
|
loc, lowerBound, upperBound, step, llvm::makeArrayRef(one),
|
||||||
[&](OpBuilder& b, Location l, Value v, ValueRange iters) {
|
[&](OpBuilder& b, Location l, Value v, ValueRange iters) {
|
||||||
Value prod =
|
Value prod =
|
||||||
b.create<::mlir::MulIOp>(l, adaptor.lhs(), iters.front());
|
b.create<::mlir::MulIOp>(l, adaptor.lhs(), iters.front());
|
||||||
b.create<scf::YieldOp>(l, prod);
|
b.create<scf::YieldOp>(l, prod);
|
||||||
})
|
})
|
||||||
.getResult(0);
|
.getResult(0);
|
||||||
|
|
||||||
|
Value rhs_is_even =
|
||||||
|
b->create<CmpIOp>(loc, CmpIPredicate::eq,
|
||||||
|
b->create<SignedRemIOp>(loc, adaptor.rhs(), two), zero);
|
||||||
|
Value rhs_is_negative =
|
||||||
|
b->create<CmpIOp>(loc, CmpIPredicate::slt, adaptor.rhs(), zero);
|
||||||
|
Value lhs_is_one =
|
||||||
|
b->create<CmpIOp>(loc, CmpIPredicate::eq, adaptor.lhs(), one);
|
||||||
|
Value lhs_is_neg_one =
|
||||||
|
b->create<CmpIOp>(loc, CmpIPredicate::eq, adaptor.lhs(), neg_one);
|
||||||
|
|
||||||
|
// The for_result is correct when the rhs is non-negative. When rhs is
|
||||||
|
// negative, we return 0 for integer, with the exception of lhs values of 1
|
||||||
|
// and -1 which have integer results for negative exponents. Specifically, the
|
||||||
|
// calulation is the following:
|
||||||
|
//
|
||||||
|
// - Return for_result if the rhs is not negative.
|
||||||
|
// - Return 1 or -1 depending on the parity of rhs when the lhs is -1.
|
||||||
|
// - Return 1 if lhs is 1.
|
||||||
|
// - Else return 0.
|
||||||
|
Value if_lhs_is_one = b->create<::mlir::SelectOp>(loc, lhs_is_one, one, zero);
|
||||||
|
Value if_lhs_is_neg_one = b->create<::mlir::SelectOp>(
|
||||||
|
loc, lhs_is_neg_one,
|
||||||
|
b->create<::mlir::SelectOp>(loc, rhs_is_even, one, neg_one),
|
||||||
|
if_lhs_is_one);
|
||||||
|
return b->create<::mlir::SelectOp>(loc, rhs_is_negative, if_lhs_is_neg_one,
|
||||||
|
for_result);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
|
|
@ -792,17 +792,26 @@ func @float_pow(%lhs: tensor<2x2xf32>,
|
||||||
// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
|
// CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
|
||||||
// CHECK-LABEL: func @integer_pow
|
// CHECK-LABEL: func @integer_pow
|
||||||
func @integer_pow(%lhs: tensor<2x2xi32>,
|
func @integer_pow(%lhs: tensor<2x2xi32>,
|
||||||
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
||||||
// CHECK: linalg.generic
|
// CHECK: linalg.generic
|
||||||
// CHECK: ^{{[a-z0-9_]*}}
|
// CHECK: ^{{[a-z0-9_]*}}
|
||||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: i32
|
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: i32
|
||||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: i32
|
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: i32
|
||||||
// CHECK: %[[UPPER:.*]] = index_cast %[[ARG1]]
|
// CHECK: %[[UPPER:.*]] = index_cast %[[ARG1]]
|
||||||
// CHECK: %[[RESULT:.*]] = scf.for {{.*}} to %[[UPPER]]
|
// CHECK: %[[FOR_RESULT:.*]] = scf.for {{.*}} to %[[UPPER]]
|
||||||
// CHECK-SAME: step %c1{{[a-zA-Z0-9_]*}}
|
// CHECK-SAME: step %c1{{[a-zA-Z0-9_]*}}
|
||||||
// CHECK-SAME: iter_args(%[[ITER:.*]] = %c1{{.*}}) -> (i32) {
|
// CHECK-SAME: iter_args(%[[ITER:.*]] = %c1{{.*}}) -> (i32) {
|
||||||
// CHECK: %[[ACCUM:[a-zA-Z0-9_]*]] = muli %[[ARG0]], %[[ITER]]
|
// CHECK: %[[ACCUM:[a-zA-Z0-9_]*]] = muli %[[ARG0]], %[[ITER]]
|
||||||
// CHECK: scf.yield %[[ACCUM]]
|
// CHECK: scf.yield %[[ACCUM]]
|
||||||
|
// CHECK: %[[RHS_PARITY:.*]] = remi_signed %[[ARG1]], %c2
|
||||||
|
// CHECK: %[[RHS_EVEN:.*]] = cmpi eq, %[[RHS_PARITY]], %c0
|
||||||
|
// CHECK: %[[RHS_NEG:.*]] = cmpi slt, %[[ARG1]], %c0
|
||||||
|
// CHECK: %[[LHS_ONE:.*]] = cmpi eq, %[[ARG0]], %c1
|
||||||
|
// CHECK: %[[LHS_NEG_ONE:.*]] = cmpi eq, %[[ARG0]], %c-1
|
||||||
|
// CHECK: %[[VAL5:.*]] = select %[[LHS_ONE]], %c1_i32, %c0
|
||||||
|
// CHECK: %[[VAL6:.*]] = select %[[RHS_EVEN]], %c1{{.*}}, %c-1
|
||||||
|
// CHECK: %[[VAL7:.*]] = select %[[LHS_NEG_ONE]], %[[VAL6]], %[[VAL5]]
|
||||||
|
// CHECK: %[[RESULT:.*]] = select %[[RHS_NEG]], %[[VAL7]], %[[FOR_RESULT]]
|
||||||
// CHECK: linalg.yield %[[RESULT]]
|
// CHECK: linalg.yield %[[RESULT]]
|
||||||
%0 = "mhlo.power"(%lhs, %rhs) : (tensor<2x2xi32>,
|
%0 = "mhlo.power"(%lhs, %rhs) : (tensor<2x2xi32>,
|
||||||
tensor<2x2xi32>) -> tensor<2x2xi32>
|
tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||||
|
|
Loading…
Reference in New Issue