Make lmhlo.sign to linalg lowering work for more floating point and integer types

PiperOrigin-RevId: 334400341
This commit is contained in:
Benjamin Kramer 2020-09-29 09:51:37 -07:00 committed by TensorFlow MLIR Team
parent 04bf09382e
commit 336ee14538
2 changed files with 51 additions and 4 deletions

View File

@ -469,11 +469,27 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc,
ArrayRef<Value> args, ArrayRef<Value> args,
OpBuilder* b) { OpBuilder* b) {
Type element_type = args.front().getType(); Type element_type = args.front().getType();
if (element_type.isa<FloatType>()) { if (auto float_type = element_type.dyn_cast<FloatType>()) {
FloatType float_type = element_type.cast<FloatType>(); bool ignored;
APFloat const_value = float_type.isF32() ? APFloat(1.0f) : APFloat(1.0); APFloat one_apfloat(1.0f);
Value one = b->create<mlir::ConstantFloatOp>(loc, const_value, float_type); one_apfloat.convert(float_type.getFloatSemantics(),
APFloat::rmNearestTiesToEven, &ignored);
Value one = b->create<mlir::ConstantFloatOp>(loc, one_apfloat, float_type);
return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]); return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]);
} else if (auto integer_type = element_type.dyn_cast<IntegerType>()) {
// sign(x) = x == 0 ? 0 : ((x s>> 31) | 1)
Value zero =
b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
Value cmp =
b->create<::mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0], zero);
Value bitwidth_minus_one = b->create<::mlir::ConstantIntOp>(
loc, integer_type.getWidth() - 1, integer_type.getWidth());
Value ashr =
b->create<::mlir::SignedShiftRightOp>(loc, args[0], bitwidth_minus_one);
Value one =
b->create<::mlir::ConstantIntOp>(loc, 1, integer_type.getWidth());
Value or_op = b->create<::mlir::OrOp>(loc, ashr, one);
return b->create<::mlir::SelectOp>(loc, cmp, zero, or_op);
} }
return nullptr; return nullptr;
} }

View File

@ -586,6 +586,37 @@ func @sign(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// ----- // -----
// CHECK-LABEL: func @sign_bf16
func @sign_bf16(%input: memref<2x2xbf16>, %result: memref<2x2xbf16>) {
"lmhlo.sign"(%input, %result) : (memref<2x2xbf16>, memref<2x2xbf16>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: bf16, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[CST:.*]] = constant 1.000000e+00 : bf16
// CHECK-NEXT: %[[RESULT:.*]] = copysign %[[CST]], %[[OPERAND_IN]] : bf16
// CHECK-NEXT: linalg.yield %[[RESULT]] : bf16
// -----
// CHECK-LABEL: func @sign_i16
func @sign_i16(%input: memref<2x2xi16>, %result: memref<2x2xi16>) {
"lmhlo.sign"(%input, %result) : (memref<2x2xi16>, memref<2x2xi16>) -> ()
return
}
// CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i16, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[C0:.*]] = constant 0 : i16
// CHECK-NEXT: %[[CMP:.*]] = cmpi "eq", %[[OPERAND_IN]], %[[C0]] : i16
// CHECK-NEXT: %[[C15:.*]] = constant 15 : i16
// CHECK-NEXT: %[[ASHR:.*]] = shift_right_signed %[[OPERAND_IN]], %[[C15]] : i16
// CHECK-NEXT: %[[C1:.*]] = constant 1 : i16
// CHECK-NEXT: %[[OR:.*]] = or %[[ASHR]], %[[C1]] : i16
// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[C0]], %[[OR]] : i16
// CHECK-NEXT: linalg.yield %[[RESULT]] : i16
// -----
// CHECK-LABEL: func @sqrt // CHECK-LABEL: func @sqrt
func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"lmhlo.sqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () "lmhlo.sqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()